From a8933ee340f1a336fd7b868f70b683c1077364fd Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Wed, 18 Feb 2026 04:52:53 -0800 Subject: [PATCH] [overlap] Overlap simulation on 1d, 2d variants of llama3, DSv3 for 64, 256 gpus stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/289, branch: IvanKobzarev/stack/12 ghstack-source-id: d69455ced7cf31a61d259a0d350f2dbd425b2e24 Pull-Request: https://github.com/meta-pytorch/autoparallel/pull/318 --- autoparallel/graph_passes/debug_helpers.py | 98 +- .../tools/overlap_simulator/colls16_8.table | 8 + .../tools/overlap_simulator/colls32_8.table | 8 + .../tools/overlap_simulator/colls64_1.table | 6 + .../tools/overlap_simulator/colls8_8.table | 7 + .../overlap_simulator/repro_dsv3_bw_128.py | 17752 ++++++++++++++++ .../overlap_simulator/repro_dsv3_bw_64.py | 11332 ++++++++++ .../overlap_simulator/repro_dsv3_fw_128.py | 10290 +++++++++ .../overlap_simulator/repro_dsv3_fw_64.py | 8752 ++++++++ .../repro_llama3_8b_bw_256_1d.py | 8954 ++++++++ .../repro_llama3_8b_bw_256_2d.py | 11446 ++++++++++ .../repro_llama3_8b_bw_64_1d.py | 8953 ++++++++ .../repro_llama3_8b_bw_64_2d.py | 5783 +++++ .../repro_llama3_8b_fw_256_1d.py | 4153 ++++ .../repro_llama3_8b_fw_256_2d.py | 5658 +++++ .../repro_llama3_8b_fw_64_1d.py | 4153 ++++ .../repro_llama3_8b_fw_64_2d.py | 5657 +++++ autoparallel/tools/overlap_simulator/run.py | 849 + pyproject.toml | 10 + 19 files changed, 103844 insertions(+), 25 deletions(-) create mode 100644 autoparallel/tools/overlap_simulator/colls16_8.table create mode 100644 autoparallel/tools/overlap_simulator/colls32_8.table create mode 100644 autoparallel/tools/overlap_simulator/colls64_1.table create mode 100644 autoparallel/tools/overlap_simulator/colls8_8.table create mode 100644 autoparallel/tools/overlap_simulator/repro_dsv3_bw_128.py create mode 100644 autoparallel/tools/overlap_simulator/repro_dsv3_bw_64.py create mode 100644 autoparallel/tools/overlap_simulator/repro_dsv3_fw_128.py create mode 100644 autoparallel/tools/overlap_simulator/repro_dsv3_fw_64.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_1d.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_2d.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_1d.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_2d.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_1d.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_2d.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_1d.py create mode 100644 autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_2d.py create mode 100644 autoparallel/tools/overlap_simulator/run.py diff --git a/autoparallel/graph_passes/debug_helpers.py b/autoparallel/graph_passes/debug_helpers.py index 577882b5..c361fa89 100644 --- a/autoparallel/graph_passes/debug_helpers.py +++ b/autoparallel/graph_passes/debug_helpers.py @@ -217,10 +217,39 @@ def _get_tid(node): if _is_communication_node(node): if node.target == torch.ops._c10d_functional.wait_tensor.default: return 0 - return node.args[-1] + return f"group-{node.args[-1]}" return 0 +def _get_category(node): + """Get trace category for node.""" + if _is_communication_node(node): + return "nccl" + if node.op == "call_function": + target_str = str(node.target) + if any(x in target_str.lower() for x in ["mm", "bmm", "matmul", "addmm"]): + return "gemm" + if any(x in target_str.lower() for x in ["flash", "attention", "sdpa"]): + return "attention" + return "kernel" + + +def _get_collective_type(node): + """Get the type of collective operation.""" + if not _is_communication_node(node): + return None + target_str = str(node.target) + if "all_gather" in target_str: + return "AllGather" + elif "reduce_scatter" in target_str: + return "ReduceScatter" + elif "all_reduce" in target_str: + return "AllReduce" + elif "wait_tensor" in target_str: + return "Wait" + return "Unknown" + + def get_repr(arg, mode="full"): def get_dtype_repr(dtype): return dtype_abbrs[dtype] @@ -259,51 +288,70 @@ def get_dtype_repr(dtype): def create_execution_trace( gm: torch.fx.GraphModule, runtime_estimator: Callable[[torch.fx.Node], float], - file_path: str = "fake_trace.json", -): + name: str = "fake_trace", + file_path: str | None = None, +) -> dict[str, Any]: """ Create a perfetto trace from a GraphModule representing its execution trace. This is useful for inspecting communication-computation overlapping for different reordering strategies. """ - trace: dict[str, Any] = {} + launch_overhead = 1 # 1us + ms_to_us = 1000 + trace_events = [] - curr_time = {0: 0} - global_time: dict[torch.fx.Node, int] = {} + curr_time: dict[int | str, float] = {0: 0} + global_time: dict[torch.fx.Node, float] = {} + for node_idx, node in enumerate(gm.graph.nodes): - dur = int(runtime_estimator(node)) + dur = runtime_estimator(node) * ms_to_us tid = _get_tid(node) if tid not in curr_time: curr_time[tid] = curr_time[0] - event = {"ph": "X", "cat": "kernel", "name": str(node), "pid": 0, "tid": tid} + + cat = _get_category(node) + coll_type = _get_collective_type(node) + node_name = f"nccl:{coll_type}:{node.name}" if coll_type else str(node) + + event = {"ph": "X", "cat": cat, "name": node_name, "pid": 0, "tid": tid} + if _is_communication_node(node): if tid == 0 and is_wait_tensor(node) and node.args[0].op != "placeholder": - # if it's wait tensor, let's sync with compute stream - comm_end_time = global_time.pop(node.args[0]) + # if it's wait tensor, walk up chained waits to find the collective + # Now this may happen for all_to_all in dsv3 (wait(wait(all_to_all))) + comm_node = node.args[0] + while is_wait_tensor(comm_node): + comm_node = comm_node.args[0] + comm_end_time = global_time[comm_node] curr_time[tid] = max(curr_time[tid], comm_end_time) else: curr_time[tid] = max(curr_time[0], curr_time[tid]) event["ts"] = curr_time[tid] event["dur"] = dur - launch_overhead = 1 # 1us curr_time[tid] += dur + launch_overhead if tid != 0: curr_time[0] += launch_overhead # keep track of when a given collective will finish global_time[node] = curr_time[tid] - args: dict[str, Any] = {} - args["order"] = node_idx - - args["output"] = get_repr(node, mode="content_only") - node_args = [] - for arg in node.args: - node_args.append(get_repr(arg)) - args["inputs"] = node_args - event["args"] = args - trace_events.append(event) - trace["traceEvents"] = trace_events - trace["traceName"] = "fake_trace.json" - with open(file_path, "w") as fp: - json.dump(trace, fp) + event["args"] = { + "order": node_idx, + "output": get_repr(node, mode="content_only"), + "inputs": [get_repr(arg) for arg in node.args], + } + + if dur > 0.0: + trace_events.append(event) + + trace = { + "traceEvents": trace_events, + "traceName": f"{name}_trace.json", + "displayTimeUnit": "us", + } + + if file_path is not None: + with open(file_path, "w") as fp: + json.dump(trace, fp, indent=2) + + return trace diff --git a/autoparallel/tools/overlap_simulator/colls16_8.table b/autoparallel/tools/overlap_simulator/colls16_8.table new file mode 100644 index 00000000..248f6393 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/colls16_8.table @@ -0,0 +1,8 @@ + Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms) +------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- ------------- + 1 8 all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888 + 1 8 reduce_scatter_tensor 0.0173 0.0238 0.0368 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 + 1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952 + 0 16 all_gather_into_tensor 0.4977 0.8538 1.5291 2.8398 5.4613 10.7042 21.1899 42.1614 84.1045 167.9904 335.763 671.3073 + 0 16 reduce_scatter_tensor 0.1282 0.1638 0.2247 0.3136 0.4858 0.8166 1.4697 2.7525 5.2865 10.3546 20.4909 40.7633 + 0 16 all_gather_into_tensor_out 0.4977 0.8538 1.5291 2.8398 5.4613 10.7042 21.1899 42.1614 84.1045 167.9904 335.763 671.3073 diff --git a/autoparallel/tools/overlap_simulator/colls32_8.table b/autoparallel/tools/overlap_simulator/colls32_8.table new file mode 100644 index 00000000..50426ef2 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/colls32_8.table @@ -0,0 +1,8 @@ + Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms) +------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- ------------- + 1 8 all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888 + 1 8 reduce_scatter_tensor 0.0173 0.0238 0.0368 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 + 1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952 + 0 32 all_gather_into_tensor 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 87.1247 173.807 347.171 693.901 1387.36 + 0 32 reduce_scatter_tensor 0.2114 0.2612 0.3608 0.4615 0.6455 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 + 0 32 all_gather_into_tensor_out 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 87.1247 173.807 347.171 693.901 1387.36 diff --git a/autoparallel/tools/overlap_simulator/colls64_1.table b/autoparallel/tools/overlap_simulator/colls64_1.table new file mode 100644 index 00000000..403d3160 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/colls64_1.table @@ -0,0 +1,6 @@ + Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms) +------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- ------------- + 0 64 all_reduce 0.20 0.35 0.60 1.10 2.10 4.10 8.10 16.10 32.10 64.10 128.10 256.10 + 0 64 all_gather_into_tensor 0.25 0.45 0.80 1.50 2.90 5.70 11.30 22.50 44.90 89.70 179.30 358.50 + 0 64 reduce_scatter_tensor 0.25 0.45 0.80 1.50 2.90 5.70 11.30 22.50 44.90 89.70 179.30 358.50 + 0 64 all_gather_into_tensor_out 0.25 0.45 0.80 1.50 2.90 5.70 11.30 22.50 44.90 89.70 179.30 358.50 diff --git a/autoparallel/tools/overlap_simulator/colls8_8.table b/autoparallel/tools/overlap_simulator/colls8_8.table new file mode 100644 index 00000000..9d75dac7 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/colls8_8.table @@ -0,0 +1,7 @@ + Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms) +------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- ------------- + 1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952 + 1 8 all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888 + 0 8 reduce_scatter_tensor 0.0866 0.1151 0.1566 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 + 0 8 all_gather_into_tensor_out 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 78.4001 156.694 313.281 + 0 8 all_gather_into_tensor 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 78.4001 156.694 313.281 diff --git a/autoparallel/tools/overlap_simulator/repro_dsv3_bw_128.py b/autoparallel/tools/overlap_simulator/repro_dsv3_bw_128.py new file mode 100644 index 00000000..8490ab68 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_dsv3_bw_128.py @@ -0,0 +1,17752 @@ +# fmt: off +# flake8: noqa +# isort: skip_file + +import os +os.environ['PYTORCH_KERNEL_CACHE_PATH'] = '/mnt/mffuse/.cache/torch/kernels' +os.environ['TORCH_DISABLE_ADDR2LINE'] = '1' +os.environ['TORCH_TRACE'] = '/mnt/mffuse/outputs/sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3/torch_trace/' +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' +os.environ['TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE'] = '[${role_name}${rank}|${local_rank}]:' +os.environ['TORCHELASTIC_MAX_RESTARTS'] = '0' +os.environ['TORCHX_INTERNAL_SESSION_ID'] = '03a200cc-023c-47d4-8372-8d223aedc5c2' +os.environ['TORCHX_RUN_PYTHONPATH'] = '' +os.environ['TORCHELASTIC_ERROR_FILE'] = '/tmp/torchelastic_i226b1gg/sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3_ukylfuu9/attempt_0/0/error.json' +os.environ['TORCH_ADDR2LINE_BINARY'] = '/packages/folly.symbolizer/folly-addr2line' +os.environ['TORCHX_JOB_ID'] = 'mast_conda://torchx/sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3' +os.environ['TORCH_NCCL_ASYNC_ERROR_HANDLING'] = '3' +os.environ['TORCHELASTIC_SIGNALS_TO_HANDLE'] = 'SIGTERM,SIGINT,SIGHUP,SIGQUIT' +os.environ['TORCHELASTIC_RUN_ID'] = 'sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3' +os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1' +os.environ['TORCHELASTIC_RESTART_COUNT'] = '0' +os.environ['TORCHELASTIC_USE_AGENT_STORE'] = 'False' +os.environ['PYTORCH_DDP_USE_SIDE_STREAM'] = '0' +os.environ['TORCHINDUCTOR_CACHE_DIR'] = '/tmp/torchinductor_root' +os.environ['TORCH_FR_BUFFER_SIZE'] = '20000' +os.environ['TORCH_NCCL_DUMP_ON_TIMEOUT'] = '1' +os.environ['TORCH_FR_DUMP_TEMP_FILE'] = '/mnt/mffuse_nccl_trace/nccl_trace/sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3/v_0/attempt_0/nccl_trace_rank_' +os.environ['TRITON_CACHE_DIR'] = '/tmp/torchinductor_root/triton/0' + +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims +import torch.distributed as dist +from torch.testing._internal.distributed.fake_pg import FakeStore +import triton +import triton.language as tl + +import torch._dynamo.config +import torch._inductor.config +import torch._functorch.config +import torch.fx.experimental._config +torch._dynamo.config.capture_scalar_outputs = True +torch._inductor.config.allow_buffer_reuse = False +torch._inductor.config.reorder_for_compute_comm_overlap = False +torch._inductor.config.reorder_for_peak_memory = False +torch._inductor.config.max_autotune = False +torch._inductor.config.coordinate_descent_tuning = False +torch._inductor.config.deterministic = False +torch._inductor.config.aten_distributed_optimizations.collective_bucketing = True +torch._inductor.config.aten_distributed_optimizations.insert_overlap_deps = True +torch._inductor.config.wrap_inductor_compiled_regions = False +torch._inductor.config.triton.cudagraphs = False +torch._inductor.config.triton.store_cubin = False +torch._inductor.config.test_configs.runtime_triton_dtype_assert = False +torch._functorch.config.functionalize_rng_ops = False +torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = True +torch._functorch.config.unlift_effect_tokens = True +torch._functorch.config.selective_decompose = False + + + +isolate_fails_code_str = None + + + + + +if "__compile_source__" in globals(): + import inspect as __after_aot_inspect + import linecache as __after_aot_linecache + __after_aot_filename = __after_aot_inspect.currentframe().f_code.co_filename + __after_aot_linecache.cache[__after_aot_filename] = ( + len(__compile_source__), + None, + __compile_source__.splitlines(True), + __after_aot_filename, + ) +# torch version: 2.11.0a0+git5ac4d4b +# torch cuda version: 12.4 +# torch git version: 5ac4d4bf3f85e15fdd6676f46b090568ea91e47e + + +# CUDA Info: +# nvcc not found +# GPU Hardware Info: +# NVIDIA H100 80GB HBM3 : 8 + +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.reset_table() + +@triton.jit +def _fill_indices_kernel_0( + tokens_per_expert_group_ptr, + start_index_values_ptr, + write_offsets_ptr, + output_ptr, + experts_per_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # Number of threads per block +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # map programs (blocks) to the experts and loop (grid stride) if needed + for expert_id in range(pid, experts_per_rank, num_programs): + # read this experts write offset + write_offset = tl.load(write_offsets_ptr + expert_id) + + for r in range(num_ranks): + # index into tokens_per_expert_group array + i = r * experts_per_rank + expert_id + + # load start index and number of tokens for this expert-rank pair + start_index = tl.load(start_index_values_ptr + i) + length = tl.load(tokens_per_expert_group_ptr + i) + + # each thread in block processes tokens in parallel + offsets = tl.arange(0, BLOCK_SIZE) + + # tokens are processed in chunks of BLOCK_SIZE + for chunk_start in range(0, length, BLOCK_SIZE): + chunk_offsets = chunk_start + offsets + + # mask valid indices + mask = chunk_offsets < length + + values = start_index + chunk_offsets + + # destination + dest_indices = write_offset + chunk_offsets + + # store + tl.store(output_ptr + dest_indices, values, mask=mask) + + # update write offset for next rank + write_offset += length + +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(_fill_indices_kernel_0) +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.constant_args={0: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 1: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 2: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 3: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 4: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 5: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 6: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 7: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 8: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 9: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 10: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 11: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 12: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 13: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 14: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 15: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 16: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 17: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 18: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 19: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 20: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 21: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 22: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 23: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 24: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 25: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}} + +from torch.nn import * +# Stub for submodules referenced in backward graph +class GraphModule(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, *args, **kwargs): + pass +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fw_graph0 = GraphModule() + self.joint_graph0 = GraphModule() + self.mask_graph0 = GraphModule() + self.fw_graph1 = GraphModule() + self.joint_graph1 = GraphModule() + self.mask_graph1 = GraphModule() + self.fw_graph2 = GraphModule() + self.joint_graph2 = GraphModule() + self.mask_graph2 = GraphModule() + self.fw_graph3 = GraphModule() + self.joint_graph3 = GraphModule() + self.mask_graph3 = GraphModule() + self.fw_graph4 = GraphModule() + self.joint_graph4 = GraphModule() + self.mask_graph4 = GraphModule() + self.fw_graph5 = GraphModule() + self.joint_graph5 = GraphModule() + self.mask_graph5 = GraphModule() + self.fw_graph6 = GraphModule() + self.joint_graph6 = GraphModule() + self.mask_graph6 = GraphModule() + self.fw_graph7 = GraphModule() + self.joint_graph7 = GraphModule() + self.mask_graph7 = GraphModule() + self.fw_graph8 = GraphModule() + self.joint_graph8 = GraphModule() + self.mask_graph8 = GraphModule() + self.fw_graph9 = GraphModule() + self.joint_graph9 = GraphModule() + self.mask_graph9 = GraphModule() + self.fw_graph10 = GraphModule() + self.joint_graph10 = GraphModule() + self.mask_graph10 = GraphModule() + self.fw_graph11 = GraphModule() + self.joint_graph11 = GraphModule() + self.mask_graph11 = GraphModule() + self.fw_graph12 = GraphModule() + self.joint_graph12 = GraphModule() + self.mask_graph12 = GraphModule() + self.fw_graph13 = GraphModule() + self.joint_graph13 = GraphModule() + self.mask_graph13 = GraphModule() + self.fw_graph14 = GraphModule() + self.joint_graph14 = GraphModule() + self.mask_graph14 = GraphModule() + self.fw_graph15 = GraphModule() + self.joint_graph15 = GraphModule() + self.mask_graph15 = GraphModule() + self.fw_graph16 = GraphModule() + self.joint_graph16 = GraphModule() + self.mask_graph16 = GraphModule() + self.fw_graph17 = GraphModule() + self.joint_graph17 = GraphModule() + self.mask_graph17 = GraphModule() + self.fw_graph18 = GraphModule() + self.joint_graph18 = GraphModule() + self.mask_graph18 = GraphModule() + self.fw_graph19 = GraphModule() + self.joint_graph19 = GraphModule() + self.mask_graph19 = GraphModule() + self.fw_graph20 = GraphModule() + self.joint_graph20 = GraphModule() + self.mask_graph20 = GraphModule() + self.fw_graph21 = GraphModule() + self.joint_graph21 = GraphModule() + self.mask_graph21 = GraphModule() + self.fw_graph22 = GraphModule() + self.joint_graph22 = GraphModule() + self.mask_graph22 = GraphModule() + self.fw_graph23 = GraphModule() + self.joint_graph23 = GraphModule() + self.mask_graph23 = GraphModule() + self.fw_graph24 = GraphModule() + self.joint_graph24 = GraphModule() + self.mask_graph24 = GraphModule() + self.fw_graph25 = GraphModule() + self.joint_graph25 = GraphModule() + self.mask_graph25 = GraphModule() + self.fw_graph26 = GraphModule() + self.joint_graph26 = GraphModule() + self.mask_graph26 = GraphModule() + + + + def forward(self, _local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7, _local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23, _local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31, _local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39, _local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47, _local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55, _local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63, _local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71, _local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79, _local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87, _local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95, _local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103, _local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111, _local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119, _local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127, _local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135, _local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143, _local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151, _local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159, _local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167, _local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175, _local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183, _local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191, _local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199, _local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207, _local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215, _local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223, _local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231, _local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239, _local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247, _local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255, _local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263, _local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271, _local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279, _local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287, _local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295, _local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303, _local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311, _local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319, _local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327, _local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335, _local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343, _local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351, _local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359, _local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367, _local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375, _local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383, _local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391, _local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399, _local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407, _local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415, sym_size_int_1, sym_size_int_5, sym_size_int_9, sym_size_int_13, sym_size_int_17, sym_size_int_21, sym_size_int_25, sym_size_int_29, sym_size_int_33, sym_size_int_37, sym_size_int_41, sym_size_int_45, sym_size_int_49, sym_size_int_53, sym_size_int_57, sym_size_int_61, sym_size_int_65, sym_size_int_69, sym_size_int_73, sym_size_int_77, sym_size_int_81, sym_size_int_85, sym_size_int_89, sym_size_int_93, sym_size_int_97, sym_size_int_101, add_1781, add_1796, add_1811, add_1826, add_1841, add_1856, add_1871, add_1886, add_1901, add_1916, add_1931, add_1946, add_1961, add_1976, add_1991, add_2006, add_2021, add_2036, add_2051, add_2066, add_2081, add_2096, add_2111, add_2126, add_2141, add_2156, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_31, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_47, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_63, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_79, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_95, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_111, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_127, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_143, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_159, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_175, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_191, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_207, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_223, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_239, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_255, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_271, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_287, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_298, primals_299, primals_300, primals_301, primals_303, primals_305, primals_306, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_316, primals_317, primals_319, primals_321, primals_322, primals_323, primals_324, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, primals_335, primals_337, primals_338, primals_339, primals_340, primals_341, primals_342, primals_343, primals_344, primals_345, primals_346, primals_347, primals_348, primals_349, primals_351, primals_353, primals_354, primals_355, primals_356, primals_357, primals_358, primals_359, primals_360, primals_361, primals_362, primals_363, primals_364, primals_365, primals_367, primals_369, primals_370, primals_371, primals_372, primals_373, primals_374, primals_375, primals_376, primals_377, primals_378, primals_379, primals_380, primals_381, primals_383, primals_385, primals_386, primals_387, primals_388, primals_389, primals_390, primals_391, primals_392, primals_393, primals_394, primals_395, primals_396, primals_397, primals_399, primals_401, primals_402, primals_403, primals_404, primals_405, primals_406, primals_407, primals_408, primals_409, primals_410, primals_411, primals_412, primals_413, primals_415, primals_417, primals_418, primals_419, primals_420, primals_421, primals_422, primals_423, primals_424, primals_425, primals_426, primals_427, primals_428, primals_429, primals_431, primals_433, primals_434, primals_435, primals_436, primals_437, primals_438, primals_439, primals_440, embedding, rsqrt, view_3, getitem_2, rsqrt_1, view_17, permute_3, permute_4, permute_5, getitem_6, getitem_7, mm_3, rsqrt_2, view_26, mm_4, mm_5, view_32, add_5, rsqrt_3, view_36, getitem_11, rsqrt_4, view_50, permute_14, permute_15, permute_16, getitem_15, getitem_16, add_8, rsqrt_5, view_58, mm_11, amax, sum_1, getitem_19, getitem_21, div_2, getitem_22, index_1, cumsum_2, _grouped_mm, _grouped_mm_1, mul_35, mm_12, mm_13, mul_55, add_73, rsqrt_6, view_103, getitem_121, rsqrt_7, view_117, permute_29, permute_30, permute_31, getitem_125, getitem_126, add_76, rsqrt_8, view_125, mm_19, amax_1, sum_5, getitem_129, getitem_131, div_7, getitem_132, index_3, cumsum_5, _grouped_mm_3, _grouped_mm_4, mul_84, mm_20, mm_21, mul_104, add_141, rsqrt_9, view_170, getitem_231, rsqrt_10, view_184, permute_44, permute_45, permute_46, getitem_235, getitem_236, add_144, rsqrt_11, view_192, mm_27, amax_2, sum_9, getitem_239, getitem_241, div_12, getitem_242, index_5, cumsum_8, _grouped_mm_6, _grouped_mm_7, mul_133, mm_28, mm_29, mul_153, add_209, rsqrt_12, view_237, getitem_341, rsqrt_13, view_251, permute_59, permute_60, permute_61, getitem_345, getitem_346, add_212, rsqrt_14, view_259, mm_35, amax_3, sum_13, getitem_349, getitem_351, div_17, getitem_352, index_7, cumsum_11, _grouped_mm_9, _grouped_mm_10, mul_182, mm_36, mm_37, mul_202, add_277, rsqrt_15, view_304, getitem_451, rsqrt_16, view_318, permute_74, permute_75, permute_76, getitem_455, getitem_456, add_280, rsqrt_17, view_326, mm_43, amax_4, sum_17, getitem_459, getitem_461, div_22, getitem_462, index_9, cumsum_14, _grouped_mm_12, _grouped_mm_13, mul_231, mm_44, mm_45, mul_251, add_345, rsqrt_18, view_371, getitem_561, rsqrt_19, view_385, permute_89, permute_90, permute_91, getitem_565, getitem_566, add_348, rsqrt_20, view_393, mm_51, amax_5, sum_21, getitem_569, getitem_571, div_27, getitem_572, index_11, cumsum_17, _grouped_mm_15, _grouped_mm_16, mul_280, mm_52, mm_53, mul_300, add_413, rsqrt_21, view_438, getitem_671, rsqrt_22, view_452, permute_104, permute_105, permute_106, getitem_675, getitem_676, add_416, rsqrt_23, view_460, mm_59, amax_6, sum_25, getitem_679, getitem_681, div_32, getitem_682, index_13, cumsum_20, _grouped_mm_18, _grouped_mm_19, mul_329, mm_60, mm_61, mul_349, add_481, rsqrt_24, view_505, getitem_781, rsqrt_25, view_519, permute_119, permute_120, permute_121, getitem_785, getitem_786, add_484, rsqrt_26, view_527, mm_67, amax_7, sum_29, getitem_789, getitem_791, div_37, getitem_792, index_15, cumsum_23, _grouped_mm_21, _grouped_mm_22, mul_378, mm_68, mm_69, mul_398, add_549, rsqrt_27, view_572, getitem_891, rsqrt_28, view_586, permute_134, permute_135, permute_136, getitem_895, getitem_896, add_552, rsqrt_29, view_594, mm_75, amax_8, sum_33, getitem_899, getitem_901, div_42, getitem_902, index_17, cumsum_26, _grouped_mm_24, _grouped_mm_25, mul_427, mm_76, mm_77, mul_447, add_617, rsqrt_30, view_639, getitem_1001, rsqrt_31, view_653, permute_149, permute_150, permute_151, getitem_1005, getitem_1006, add_620, rsqrt_32, view_661, mm_83, amax_9, sum_37, getitem_1009, getitem_1011, div_47, getitem_1012, index_19, cumsum_29, _grouped_mm_27, _grouped_mm_28, mul_476, mm_84, mm_85, mul_496, add_685, rsqrt_33, view_706, getitem_1111, rsqrt_34, view_720, permute_164, permute_165, permute_166, getitem_1115, getitem_1116, add_688, rsqrt_35, view_728, mm_91, amax_10, sum_41, getitem_1119, getitem_1121, div_52, getitem_1122, index_21, cumsum_32, _grouped_mm_30, _grouped_mm_31, mul_525, mm_92, mm_93, mul_545, add_753, rsqrt_36, view_773, getitem_1221, rsqrt_37, view_787, permute_179, permute_180, permute_181, getitem_1225, getitem_1226, add_756, rsqrt_38, view_795, mm_99, amax_11, sum_45, getitem_1229, getitem_1231, div_57, getitem_1232, index_23, cumsum_35, _grouped_mm_33, _grouped_mm_34, mul_574, mm_100, mm_101, mul_594, add_821, rsqrt_39, view_840, getitem_1331, rsqrt_40, view_854, permute_194, permute_195, permute_196, getitem_1335, getitem_1336, add_824, rsqrt_41, view_862, mm_107, amax_12, sum_49, getitem_1339, getitem_1341, div_62, getitem_1342, index_25, cumsum_38, _grouped_mm_36, _grouped_mm_37, mul_623, mm_108, mm_109, mul_643, add_889, rsqrt_42, view_907, getitem_1441, rsqrt_43, view_921, permute_209, permute_210, permute_211, getitem_1445, getitem_1446, add_892, rsqrt_44, view_929, mm_115, amax_13, sum_53, getitem_1449, getitem_1451, div_67, getitem_1452, index_27, cumsum_41, _grouped_mm_39, _grouped_mm_40, mul_672, mm_116, mm_117, mul_692, add_957, rsqrt_45, view_974, getitem_1551, rsqrt_46, view_988, permute_224, permute_225, permute_226, getitem_1555, getitem_1556, add_960, rsqrt_47, view_996, mm_123, amax_14, sum_57, getitem_1559, getitem_1561, div_72, getitem_1562, index_29, cumsum_44, _grouped_mm_42, _grouped_mm_43, mul_721, mm_124, mm_125, mul_741, add_1025, rsqrt_48, view_1041, getitem_1661, rsqrt_49, view_1055, permute_239, permute_240, permute_241, getitem_1665, getitem_1666, add_1028, rsqrt_50, view_1063, mm_131, amax_15, sum_61, getitem_1669, getitem_1671, div_77, getitem_1672, index_31, cumsum_47, _grouped_mm_45, _grouped_mm_46, mul_770, mm_132, mm_133, mul_790, add_1093, rsqrt_51, view_1108, getitem_1771, rsqrt_52, view_1122, permute_254, permute_255, permute_256, getitem_1775, getitem_1776, add_1096, rsqrt_53, view_1130, mm_139, amax_16, sum_65, getitem_1779, getitem_1781, div_82, getitem_1782, index_33, cumsum_50, _grouped_mm_48, _grouped_mm_49, mul_819, mm_140, mm_141, mul_839, add_1161, rsqrt_54, view_1175, getitem_1881, rsqrt_55, view_1189, permute_269, permute_270, permute_271, getitem_1885, getitem_1886, add_1164, rsqrt_56, view_1197, mm_147, amax_17, sum_69, getitem_1889, getitem_1891, div_87, getitem_1892, index_35, cumsum_53, _grouped_mm_51, _grouped_mm_52, mul_868, mm_148, mm_149, mul_888, add_1229, rsqrt_57, view_1242, getitem_1991, rsqrt_58, view_1256, permute_284, permute_285, permute_286, getitem_1995, getitem_1996, add_1232, rsqrt_59, view_1264, mm_155, amax_18, sum_73, getitem_1999, getitem_2001, div_92, getitem_2002, index_37, cumsum_56, _grouped_mm_54, _grouped_mm_55, mul_917, mm_156, mm_157, mul_937, add_1297, rsqrt_60, view_1309, getitem_2101, rsqrt_61, view_1323, permute_299, permute_300, permute_301, getitem_2105, getitem_2106, add_1300, rsqrt_62, view_1331, mm_163, amax_19, sum_77, getitem_2109, getitem_2111, div_97, getitem_2112, index_39, cumsum_59, _grouped_mm_57, _grouped_mm_58, mul_966, mm_164, mm_165, mul_986, add_1365, rsqrt_63, view_1376, getitem_2211, rsqrt_64, view_1390, permute_314, permute_315, permute_316, getitem_2215, getitem_2216, add_1368, rsqrt_65, view_1398, mm_171, amax_20, sum_81, getitem_2219, getitem_2221, div_102, getitem_2222, index_41, cumsum_62, _grouped_mm_60, _grouped_mm_61, mul_1015, mm_172, mm_173, mul_1035, add_1433, rsqrt_66, view_1443, getitem_2321, rsqrt_67, view_1457, permute_329, permute_330, permute_331, getitem_2325, getitem_2326, add_1436, rsqrt_68, view_1465, mm_179, amax_21, sum_85, getitem_2329, getitem_2331, div_107, getitem_2332, index_43, cumsum_65, _grouped_mm_63, _grouped_mm_64, mul_1064, mm_180, mm_181, mul_1084, add_1501, rsqrt_69, view_1510, getitem_2431, rsqrt_70, view_1524, permute_344, permute_345, permute_346, getitem_2435, getitem_2436, add_1504, rsqrt_71, view_1532, mm_187, amax_22, sum_89, getitem_2439, getitem_2441, div_112, getitem_2442, index_45, cumsum_68, _grouped_mm_66, _grouped_mm_67, mul_1113, mm_188, mm_189, mul_1133, add_1569, rsqrt_72, view_1577, getitem_2541, rsqrt_73, view_1591, permute_359, permute_360, permute_361, getitem_2545, getitem_2546, add_1572, rsqrt_74, view_1599, mm_195, amax_23, sum_93, getitem_2549, getitem_2551, div_117, getitem_2552, index_47, cumsum_71, _grouped_mm_69, _grouped_mm_70, mul_1162, mm_196, mm_197, mul_1182, add_1637, rsqrt_75, view_1644, getitem_2651, rsqrt_76, view_1658, permute_374, permute_375, permute_376, getitem_2655, getitem_2656, add_1640, rsqrt_77, view_1666, mm_203, amax_24, sum_97, getitem_2659, getitem_2661, div_122, getitem_2662, index_49, cumsum_74, _grouped_mm_72, _grouped_mm_73, mul_1211, mm_204, mm_205, mul_1231, add_1705, rsqrt_78, view_1711, getitem_2761, rsqrt_79, view_1725, permute_389, permute_390, permute_391, getitem_2765, getitem_2766, add_1708, rsqrt_80, view_1733, mm_211, amax_25, sum_101, getitem_2769, getitem_2771, div_127, getitem_2772, index_51, cumsum_77, _grouped_mm_75, _grouped_mm_76, mul_1260, mm_212, mm_213, mul_1280, add_1773, rsqrt_81, view_1778, permute_406, permute_407, permute_422, permute_426, permute_430, full_default_54, permute_456, permute_457, permute_472, permute_476, permute_480, permute_506, permute_507, permute_522, permute_526, permute_530, permute_556, permute_557, permute_572, permute_576, permute_580, permute_606, permute_607, permute_622, permute_626, permute_630, permute_656, permute_657, permute_672, permute_676, permute_680, permute_706, permute_707, permute_722, permute_726, permute_730, permute_756, permute_757, permute_772, permute_776, permute_780, permute_806, permute_807, permute_822, permute_826, permute_830, permute_856, permute_857, permute_872, permute_876, permute_880, permute_906, permute_907, permute_922, permute_926, permute_930, permute_956, permute_957, permute_972, permute_976, permute_980, permute_1006, permute_1007, permute_1022, permute_1026, permute_1030, permute_1056, permute_1057, permute_1072, permute_1076, permute_1080, permute_1106, permute_1107, permute_1122, permute_1126, permute_1130, permute_1156, permute_1157, permute_1172, permute_1176, permute_1180, permute_1206, permute_1207, permute_1222, permute_1226, permute_1230, permute_1256, permute_1257, permute_1272, permute_1276, permute_1280, permute_1306, permute_1307, permute_1322, permute_1326, permute_1330, permute_1356, permute_1357, permute_1372, permute_1376, permute_1380, permute_1406, permute_1407, permute_1422, permute_1426, permute_1430, permute_1456, permute_1457, permute_1472, permute_1476, permute_1480, permute_1506, permute_1507, permute_1522, permute_1526, permute_1530, permute_1556, permute_1557, permute_1572, permute_1576, permute_1580, permute_1606, permute_1607, permute_1622, permute_1626, permute_1630, permute_1656, permute_1657, permute_1672, permute_1676, permute_1680, tangents_1): + view_1780 = torch.ops.aten.view.default(tangents_1, [8192, 102400]); tangents_1 = None + permute_402 = torch.ops.aten.permute.default(view_1780, [1, 0]) + mm_216 = torch.ops.aten.mm.default(permute_402, view_1778); permute_402 = view_1778 = None + convert_element_type_1444 = torch.ops.prims.convert_element_type.default(primals_440, torch.bfloat16); primals_440 = None + all_gather_into_tensor_454 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1444, 128, '0'); convert_element_type_1444 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_454); all_gather_into_tensor_454 = None + permute_401 = torch.ops.aten.permute.default(wait_tensor_558, [1, 0]); wait_tensor_558 = None + permute_404 = torch.ops.aten.permute.default(permute_401, [1, 0]); permute_401 = None + mm_217 = torch.ops.aten.mm.default(view_1780, permute_404); view_1780 = permute_404 = None + view_1781 = torch.ops.aten.view.default(mm_217, [2, 4096, 2048]); mm_217 = None + convert_element_type_1451 = torch.ops.prims.convert_element_type.default(mm_216, torch.float32); mm_216 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1451, 'avg', 128, '0'); convert_element_type_1451 = None + wait_tensor_559 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1452 = torch.ops.prims.convert_element_type.default(view_1781, torch.float32); view_1781 = None + convert_element_type_1441 = torch.ops.prims.convert_element_type.default(primals_439, torch.bfloat16); primals_439 = None + all_gather_into_tensor_453 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1441, 128, '0'); convert_element_type_1441 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_453); all_gather_into_tensor_453 = None + convert_element_type_1454 = torch.ops.prims.convert_element_type.default(wait_tensor_557, torch.float32); wait_tensor_557 = None + mul_1285 = torch.ops.aten.mul.Tensor(convert_element_type_1452, convert_element_type_1454); convert_element_type_1454 = None + convert_element_type_1442 = torch.ops.prims.convert_element_type.default(add_1773, torch.float32); add_1773 = None + mul_1283 = torch.ops.aten.mul.Tensor(convert_element_type_1442, rsqrt_81); convert_element_type_1442 = None + mul_1287 = torch.ops.aten.mul.Tensor(mul_1283, mul_1285) + sum_105 = torch.ops.aten.sum.dim_IntList(mul_1287, [2], True); mul_1287 = None + div_131 = torch.ops.aten.div.Tensor(mul_1283, 2048) + mul_1288 = torch.ops.aten.mul.Tensor(div_131, sum_105); div_131 = sum_105 = None + sub_624 = torch.ops.aten.sub.Tensor(mul_1285, mul_1288); mul_1285 = mul_1288 = None + mul_1289 = torch.ops.aten.mul.Tensor(sub_624, rsqrt_81); sub_624 = rsqrt_81 = None + mul_1290 = torch.ops.aten.mul.Tensor(convert_element_type_1452, mul_1283); convert_element_type_1452 = mul_1283 = None + sum_106 = torch.ops.aten.sum.dim_IntList(mul_1290, [0, 1]); mul_1290 = None + convert_element_type_1455 = torch.ops.prims.convert_element_type.default(mul_1289, torch.bfloat16); mul_1289 = None + convert_element_type_default_82 = torch.ops.prims.convert_element_type.default(sum_106, torch.float32); sum_106 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_82, 'avg', 128, '0'); convert_element_type_default_82 = None + wait_tensor_560 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + view_1782 = torch.ops.aten.view.default(convert_element_type_1455, [8192, 2048]) + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_1782, 1) + convert_element_type_1458 = torch.ops.prims.convert_element_type.default(unsqueeze_53, torch.float32); unsqueeze_53 = None + bmm_26 = torch.ops.aten.bmm.default(permute_406, convert_element_type_1458); permute_406 = None + bmm_27 = torch.ops.aten.bmm.default(convert_element_type_1458, permute_407); convert_element_type_1458 = permute_407 = None + convert_element_type_1459 = torch.ops.prims.convert_element_type.default(bmm_26, torch.bfloat16); bmm_26 = None + view_1783 = torch.ops.aten.view.default(bmm_27, [8192, 6]); bmm_27 = None + view_1784 = torch.ops.aten.view.default(convert_element_type_1459, [49152, 2048]); convert_element_type_1459 = None + index_52 = torch.ops.aten.index.Tensor(view_1784, [getitem_2771]); view_1784 = getitem_2771 = None + permute_408 = torch.ops.aten.permute.default(view_1782, [1, 0]) + mm_218 = torch.ops.aten.mm.default(permute_408, mul_1280); permute_408 = mul_1280 = None + convert_element_type_1436 = torch.ops.prims.convert_element_type.default(primals_438, torch.bfloat16); primals_438 = None + all_gather_into_tensor_452 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1436, 128, '0'); convert_element_type_1436 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_452); all_gather_into_tensor_452 = None + permute_400 = torch.ops.aten.permute.default(wait_tensor_556, [1, 0]); wait_tensor_556 = None + permute_410 = torch.ops.aten.permute.default(permute_400, [1, 0]); permute_400 = None + mm_219 = torch.ops.aten.mm.default(view_1782, permute_410); view_1782 = permute_410 = None + convert_element_type_1464 = torch.ops.prims.convert_element_type.default(mm_218, torch.float32); mm_218 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1464, 'avg', 128, '0'); convert_element_type_1464 = None + wait_tensor_561 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None + convert_element_type_1431 = torch.ops.prims.convert_element_type.default(mm_212, torch.float32); mm_212 = None + neg_52 = torch.ops.aten.neg.default(convert_element_type_1431) + exp_78 = torch.ops.aten.exp.default(neg_52); neg_52 = None + add_1768 = torch.ops.aten.add.Tensor(exp_78, 1); exp_78 = None + div_130 = torch.ops.aten.div.Tensor(convert_element_type_1431, add_1768) + convert_element_type_1432 = torch.ops.prims.convert_element_type.default(div_130, torch.bfloat16); div_130 = None + mul_1291 = torch.ops.aten.mul.Tensor(mm_219, convert_element_type_1432); convert_element_type_1432 = None + mul_1292 = torch.ops.aten.mul.Tensor(mm_219, mm_213); mm_219 = mm_213 = None + permute_412 = torch.ops.aten.permute.default(mul_1291, [1, 0]) + mm_220 = torch.ops.aten.mm.default(permute_412, view_1733); permute_412 = None + convert_element_type_1433 = torch.ops.prims.convert_element_type.default(primals_437, torch.bfloat16); primals_437 = None + all_gather_into_tensor_451 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1433, 128, '0'); convert_element_type_1433 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_451); all_gather_into_tensor_451 = None + permute_399 = torch.ops.aten.permute.default(wait_tensor_555, [1, 0]); wait_tensor_555 = None + permute_414 = torch.ops.aten.permute.default(permute_399, [1, 0]); permute_399 = None + mm_221 = torch.ops.aten.mm.default(mul_1291, permute_414); mul_1291 = permute_414 = None + convert_element_type_1469 = torch.ops.prims.convert_element_type.default(mm_220, torch.float32); mm_220 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1469, 'avg', 128, '0'); convert_element_type_1469 = None + wait_tensor_562 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); reduce_scatter_tensor_3 = None + convert_element_type_1470 = torch.ops.prims.convert_element_type.default(mul_1292, torch.float32); mul_1292 = None + reciprocal = torch.ops.aten.reciprocal.default(add_1768); add_1768 = None + mul_1293 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_1294 = torch.ops.aten.mul.Tensor(convert_element_type_1470, mul_1293); convert_element_type_1470 = None + sub_625 = torch.ops.aten.sub.Tensor(1, mul_1293); mul_1293 = None + mul_1295 = torch.ops.aten.mul.Tensor(convert_element_type_1431, sub_625); convert_element_type_1431 = sub_625 = None + add_1776 = torch.ops.aten.add.Tensor(mul_1295, 1); mul_1295 = None + mul_1296 = torch.ops.aten.mul.Tensor(mul_1294, add_1776); mul_1294 = add_1776 = None + convert_element_type_1472 = torch.ops.prims.convert_element_type.default(mul_1296, torch.bfloat16); mul_1296 = None + permute_416 = torch.ops.aten.permute.default(convert_element_type_1472, [1, 0]) + mm_222 = torch.ops.aten.mm.default(permute_416, view_1733); permute_416 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(primals_436, torch.bfloat16); primals_436 = None + all_gather_into_tensor_450 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1428, 128, '0'); convert_element_type_1428 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_450); all_gather_into_tensor_450 = None + permute_398 = torch.ops.aten.permute.default(wait_tensor_554, [1, 0]); wait_tensor_554 = None + permute_418 = torch.ops.aten.permute.default(permute_398, [1, 0]); permute_398 = None + mm_223 = torch.ops.aten.mm.default(convert_element_type_1472, permute_418); convert_element_type_1472 = permute_418 = None + add_1777 = torch.ops.aten.add.Tensor(mm_221, mm_223); mm_221 = mm_223 = None + convert_element_type_1477 = torch.ops.prims.convert_element_type.default(mm_222, torch.float32); mm_222 = None + reduce_scatter_tensor_4 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1477, 'avg', 128, '0'); convert_element_type_1477 = None + wait_tensor_563 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + all_to_all_single_78 = torch.ops._c10d_functional.all_to_all_single.default(index_52, [_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415], [_local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407], '1033'); index_52 = None + wait_tensor_564 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_78); all_to_all_single_78 = None + full_348 = torch.ops.aten.full.default([sym_size_int_101, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_101 = None + slice_scatter = torch.ops.aten.slice_scatter.default(full_348, wait_tensor_564, 0, 0, -1); wait_tensor_564 = None + index_53 = torch.ops.aten.index.Tensor(slice_scatter, [getitem_2772]); slice_scatter = None + permute_420 = torch.ops.aten.permute.default(index_53, [1, 0]) + _grouped_mm_78 = torch.ops.aten._grouped_mm.default(permute_420, mul_1260, cumsum_77); permute_420 = mul_1260 = None + _grouped_mm_79 = torch.ops.aten._grouped_mm.default(index_53, permute_422, cumsum_77); index_53 = permute_422 = None + convert_element_type_1426 = torch.ops.prims.convert_element_type.default(_grouped_mm_75, torch.float32); _grouped_mm_75 = None + neg_51 = torch.ops.aten.neg.default(convert_element_type_1426) + exp_77 = torch.ops.aten.exp.default(neg_51); neg_51 = None + add_1732 = torch.ops.aten.add.Tensor(exp_77, 1); exp_77 = None + div_129 = torch.ops.aten.div.Tensor(convert_element_type_1426, add_1732) + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(div_129, torch.bfloat16); div_129 = None + mul_1297 = torch.ops.aten.mul.Tensor(_grouped_mm_79, convert_element_type_1427); convert_element_type_1427 = None + mul_1298 = torch.ops.aten.mul.Tensor(_grouped_mm_79, _grouped_mm_76); _grouped_mm_79 = _grouped_mm_76 = None + permute_424 = torch.ops.aten.permute.default(mul_1297, [1, 0]) + _grouped_mm_80 = torch.ops.aten._grouped_mm.default(permute_424, index_51, cumsum_77); permute_424 = None + _grouped_mm_81 = torch.ops.aten._grouped_mm.default(mul_1297, permute_426, cumsum_77); mul_1297 = permute_426 = None + convert_element_type_1478 = torch.ops.prims.convert_element_type.default(mul_1298, torch.float32); mul_1298 = None + reciprocal_1 = torch.ops.aten.reciprocal.default(add_1732); add_1732 = None + mul_1299 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None + mul_1300 = torch.ops.aten.mul.Tensor(convert_element_type_1478, mul_1299); convert_element_type_1478 = None + sub_626 = torch.ops.aten.sub.Tensor(1, mul_1299); mul_1299 = None + mul_1301 = torch.ops.aten.mul.Tensor(convert_element_type_1426, sub_626); convert_element_type_1426 = sub_626 = None + add_1779 = torch.ops.aten.add.Tensor(mul_1301, 1); mul_1301 = None + mul_1302 = torch.ops.aten.mul.Tensor(mul_1300, add_1779); mul_1300 = add_1779 = None + convert_element_type_1480 = torch.ops.prims.convert_element_type.default(mul_1302, torch.bfloat16); mul_1302 = None + permute_428 = torch.ops.aten.permute.default(convert_element_type_1480, [1, 0]) + _grouped_mm_82 = torch.ops.aten._grouped_mm.default(permute_428, index_51, cumsum_77); permute_428 = index_51 = None + _grouped_mm_83 = torch.ops.aten._grouped_mm.default(convert_element_type_1480, permute_430, cumsum_77); convert_element_type_1480 = permute_430 = cumsum_77 = None + add_1780 = torch.ops.aten.add.Tensor(_grouped_mm_81, _grouped_mm_83); _grouped_mm_81 = _grouped_mm_83 = None + convert_element_type_1481 = torch.ops.prims.convert_element_type.default(_grouped_mm_80, torch.float32); _grouped_mm_80 = None + div_132 = torch.ops.aten.div.Tensor(convert_element_type_1481, 128); convert_element_type_1481 = None + split_157 = torch.ops.aten.split.Tensor(div_132, 88, 1); div_132 = None + getitem_2885 = split_157[0] + getitem_2902 = split_157[1] + getitem_2919 = split_157[2] + getitem_2936 = split_157[3] + getitem_2953 = split_157[4] + getitem_2970 = split_157[5] + getitem_2987 = split_157[6] + getitem_3004 = split_157[7] + getitem_3021 = split_157[8] + getitem_3038 = split_157[9] + getitem_3055 = split_157[10] + getitem_3072 = split_157[11] + getitem_3089 = split_157[12] + getitem_3106 = split_157[13] + getitem_3123 = split_157[14] + getitem_3140 = split_157[15]; split_157 = None + cat_236 = torch.ops.aten.cat.default([getitem_2885, getitem_2902, getitem_2919, getitem_2936, getitem_2953, getitem_2970, getitem_2987, getitem_3004, getitem_3021, getitem_3038, getitem_3055, getitem_3072, getitem_3089, getitem_3106, getitem_3123, getitem_3140]); getitem_2885 = getitem_2902 = getitem_2919 = getitem_2936 = getitem_2953 = getitem_2970 = getitem_2987 = getitem_3004 = getitem_3021 = getitem_3038 = getitem_3055 = getitem_3072 = getitem_3089 = getitem_3106 = getitem_3123 = getitem_3140 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_236, 'sum', 16, '1025'); cat_236 = None + wait_tensor_565 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + convert_element_type_1482 = torch.ops.prims.convert_element_type.default(_grouped_mm_78, torch.float32); _grouped_mm_78 = None + div_133 = torch.ops.aten.div.Tensor(convert_element_type_1482, 128); convert_element_type_1482 = None + split_174 = torch.ops.aten.split.Tensor(div_133, 128, 1); div_133 = None + getitem_3157 = split_174[0] + getitem_3174 = split_174[1] + getitem_3191 = split_174[2] + getitem_3208 = split_174[3] + getitem_3225 = split_174[4] + getitem_3242 = split_174[5] + getitem_3259 = split_174[6] + getitem_3276 = split_174[7] + getitem_3293 = split_174[8] + getitem_3310 = split_174[9] + getitem_3327 = split_174[10] + getitem_3344 = split_174[11] + getitem_3361 = split_174[12] + getitem_3378 = split_174[13] + getitem_3395 = split_174[14] + getitem_3412 = split_174[15]; split_174 = None + cat_237 = torch.ops.aten.cat.default([getitem_3157, getitem_3174, getitem_3191, getitem_3208, getitem_3225, getitem_3242, getitem_3259, getitem_3276, getitem_3293, getitem_3310, getitem_3327, getitem_3344, getitem_3361, getitem_3378, getitem_3395, getitem_3412]); getitem_3157 = getitem_3174 = getitem_3191 = getitem_3208 = getitem_3225 = getitem_3242 = getitem_3259 = getitem_3276 = getitem_3293 = getitem_3310 = getitem_3327 = getitem_3344 = getitem_3361 = getitem_3378 = getitem_3395 = getitem_3412 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_237, 'sum', 16, '1025'); cat_237 = None + wait_tensor_566 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + convert_element_type_1483 = torch.ops.prims.convert_element_type.default(_grouped_mm_82, torch.float32); _grouped_mm_82 = None + div_134 = torch.ops.aten.div.Tensor(convert_element_type_1483, 128); convert_element_type_1483 = None + split_191 = torch.ops.aten.split.Tensor(div_134, 88, 1); div_134 = None + getitem_3429 = split_191[0] + getitem_3446 = split_191[1] + getitem_3463 = split_191[2] + getitem_3480 = split_191[3] + getitem_3497 = split_191[4] + getitem_3514 = split_191[5] + getitem_3531 = split_191[6] + getitem_3548 = split_191[7] + getitem_3565 = split_191[8] + getitem_3582 = split_191[9] + getitem_3599 = split_191[10] + getitem_3616 = split_191[11] + getitem_3633 = split_191[12] + getitem_3650 = split_191[13] + getitem_3667 = split_191[14] + getitem_3684 = split_191[15]; split_191 = None + cat_238 = torch.ops.aten.cat.default([getitem_3429, getitem_3446, getitem_3463, getitem_3480, getitem_3497, getitem_3514, getitem_3531, getitem_3548, getitem_3565, getitem_3582, getitem_3599, getitem_3616, getitem_3633, getitem_3650, getitem_3667, getitem_3684]); getitem_3429 = getitem_3446 = getitem_3463 = getitem_3480 = getitem_3497 = getitem_3514 = getitem_3531 = getitem_3548 = getitem_3565 = getitem_3582 = getitem_3599 = getitem_3616 = getitem_3633 = getitem_3650 = getitem_3667 = getitem_3684 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_238, 'sum', 16, '1025'); cat_238 = None + wait_tensor_567 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + index_put_52 = torch.ops.aten.index_put.default(full_348, [getitem_2772], add_1780, True); full_348 = getitem_2772 = add_1780 = None + slice_162 = torch.ops.aten.slice.Tensor(index_put_52, 0, 0, add_1781); index_put_52 = add_1781 = None + all_to_all_single_79 = torch.ops._c10d_functional.all_to_all_single.default(slice_162, [_local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407], [_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415], '1033'); slice_162 = _local_scalar_dense_400 = _local_scalar_dense_401 = _local_scalar_dense_402 = _local_scalar_dense_403 = _local_scalar_dense_404 = _local_scalar_dense_405 = _local_scalar_dense_406 = _local_scalar_dense_407 = _local_scalar_dense_408 = _local_scalar_dense_409 = _local_scalar_dense_410 = _local_scalar_dense_411 = _local_scalar_dense_412 = _local_scalar_dense_413 = _local_scalar_dense_414 = _local_scalar_dense_415 = None + wait_tensor_568 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_79); all_to_all_single_79 = None + full_default_52 = torch.ops.aten.full.default([8192, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_53 = torch.ops.aten.index_put.default(full_default_52, [div_127], wait_tensor_568, True); div_127 = wait_tensor_568 = None + add_1785 = torch.ops.aten.add.Tensor(add_1777, index_put_53); add_1777 = index_put_53 = None + mul_1303 = torch.ops.aten.mul.Tensor(view_1783, 1.0); view_1783 = None + full_default_53 = torch.ops.aten.full.default([8192, 64], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + scatter_add = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_2769, mul_1303); getitem_2769 = mul_1303 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_211, torch.float32); mm_211 = None + sub_600 = torch.ops.aten.sub.Tensor(convert_element_type_1415, amax_25); convert_element_type_1415 = amax_25 = None + exp_76 = torch.ops.aten.exp.default(sub_600); sub_600 = None + div_126 = torch.ops.aten.div.Tensor(exp_76, sum_101); exp_76 = sum_101 = None + mul_1304 = torch.ops.aten.mul.Tensor(scatter_add, div_126); scatter_add = None + sum_107 = torch.ops.aten.sum.dim_IntList(mul_1304, [1], True) + neg_55 = torch.ops.aten.neg.default(div_126); div_126 = None + fma = torch.ops.prims.fma.default(neg_55, sum_107, mul_1304); neg_55 = sum_107 = mul_1304 = None + convert_element_type_1484 = torch.ops.prims.convert_element_type.default(fma, torch.bfloat16); fma = None + permute_432 = torch.ops.aten.permute.default(convert_element_type_1484, [1, 0]) + mm_224 = torch.ops.aten.mm.default(permute_432, view_1733); permute_432 = view_1733 = None + convert_element_type_1412 = torch.ops.prims.convert_element_type.default(primals_431, torch.bfloat16); primals_431 = None + all_gather_into_tensor_443 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1412, 128, '0'); convert_element_type_1412 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_443); all_gather_into_tensor_443 = None + slice_159 = torch.ops.aten.slice.Tensor(wait_tensor_543, 0, 0, 64); wait_tensor_543 = None + permute_394 = torch.ops.aten.permute.default(slice_159, [1, 0]); slice_159 = None + permute_434 = torch.ops.aten.permute.default(permute_394, [1, 0]); permute_394 = None + mm_225 = torch.ops.aten.mm.default(convert_element_type_1484, permute_434); convert_element_type_1484 = permute_434 = None + add_1786 = torch.ops.aten.add.Tensor(add_1785, mm_225); add_1785 = mm_225 = None + convert_element_type_1489 = torch.ops.prims.convert_element_type.default(mm_224, torch.float32); mm_224 = None + split_207 = torch.ops.aten.split.Tensor(convert_element_type_1489, 1); convert_element_type_1489 = None + getitem_3685 = split_207[0] + getitem_3686 = split_207[1] + getitem_3687 = split_207[2] + getitem_3688 = split_207[3] + getitem_3689 = split_207[4] + getitem_3690 = split_207[5] + getitem_3691 = split_207[6] + getitem_3692 = split_207[7] + getitem_3693 = split_207[8] + getitem_3694 = split_207[9] + getitem_3695 = split_207[10] + getitem_3696 = split_207[11] + getitem_3697 = split_207[12] + getitem_3698 = split_207[13] + getitem_3699 = split_207[14] + getitem_3700 = split_207[15] + getitem_3701 = split_207[16] + getitem_3702 = split_207[17] + getitem_3703 = split_207[18] + getitem_3704 = split_207[19] + getitem_3705 = split_207[20] + getitem_3706 = split_207[21] + getitem_3707 = split_207[22] + getitem_3708 = split_207[23] + getitem_3709 = split_207[24] + getitem_3710 = split_207[25] + getitem_3711 = split_207[26] + getitem_3712 = split_207[27] + getitem_3713 = split_207[28] + getitem_3714 = split_207[29] + getitem_3715 = split_207[30] + getitem_3716 = split_207[31] + getitem_3717 = split_207[32] + getitem_3718 = split_207[33] + getitem_3719 = split_207[34] + getitem_3720 = split_207[35] + getitem_3721 = split_207[36] + getitem_3722 = split_207[37] + getitem_3723 = split_207[38] + getitem_3724 = split_207[39] + getitem_3725 = split_207[40] + getitem_3726 = split_207[41] + getitem_3727 = split_207[42] + getitem_3728 = split_207[43] + getitem_3729 = split_207[44] + getitem_3730 = split_207[45] + getitem_3731 = split_207[46] + getitem_3732 = split_207[47] + getitem_3733 = split_207[48] + getitem_3734 = split_207[49] + getitem_3735 = split_207[50] + getitem_3736 = split_207[51] + getitem_3737 = split_207[52] + getitem_3738 = split_207[53] + getitem_3739 = split_207[54] + getitem_3740 = split_207[55] + getitem_3741 = split_207[56] + getitem_3742 = split_207[57] + getitem_3743 = split_207[58] + getitem_3744 = split_207[59] + getitem_3745 = split_207[60] + getitem_3746 = split_207[61] + getitem_3747 = split_207[62] + getitem_3748 = split_207[63]; split_207 = None + constant_pad_nd = torch.ops.aten.constant_pad_nd.default(full_default_54, [0, 0, 0, 1], 0.0) + cat_239 = torch.ops.aten.cat.default([getitem_3685, getitem_3686, getitem_3687, getitem_3688, getitem_3689, getitem_3690, getitem_3691, getitem_3692, getitem_3693, getitem_3694, getitem_3695, getitem_3696, getitem_3697, getitem_3698, getitem_3699, getitem_3700, getitem_3701, getitem_3702, getitem_3703, getitem_3704, getitem_3705, getitem_3706, getitem_3707, getitem_3708, getitem_3709, getitem_3710, getitem_3711, getitem_3712, getitem_3713, getitem_3714, getitem_3715, getitem_3716, getitem_3717, getitem_3718, getitem_3719, getitem_3720, getitem_3721, getitem_3722, getitem_3723, getitem_3724, getitem_3725, getitem_3726, getitem_3727, getitem_3728, getitem_3729, getitem_3730, getitem_3731, getitem_3732, getitem_3733, getitem_3734, getitem_3735, getitem_3736, getitem_3737, getitem_3738, getitem_3739, getitem_3740, getitem_3741, getitem_3742, getitem_3743, getitem_3744, getitem_3745, getitem_3746, getitem_3747, getitem_3748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_3685 = getitem_3686 = getitem_3687 = getitem_3688 = getitem_3689 = getitem_3690 = getitem_3691 = getitem_3692 = getitem_3693 = getitem_3694 = getitem_3695 = getitem_3696 = getitem_3697 = getitem_3698 = getitem_3699 = getitem_3700 = getitem_3701 = getitem_3702 = getitem_3703 = getitem_3704 = getitem_3705 = getitem_3706 = getitem_3707 = getitem_3708 = getitem_3709 = getitem_3710 = getitem_3711 = getitem_3712 = getitem_3713 = getitem_3714 = getitem_3715 = getitem_3716 = getitem_3717 = getitem_3718 = getitem_3719 = getitem_3720 = getitem_3721 = getitem_3722 = getitem_3723 = getitem_3724 = getitem_3725 = getitem_3726 = getitem_3727 = getitem_3728 = getitem_3729 = getitem_3730 = getitem_3731 = getitem_3732 = getitem_3733 = getitem_3734 = getitem_3735 = getitem_3736 = getitem_3737 = getitem_3738 = getitem_3739 = getitem_3740 = getitem_3741 = getitem_3742 = getitem_3743 = getitem_3744 = getitem_3745 = getitem_3746 = getitem_3747 = getitem_3748 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_239, 'avg', 128, '0'); cat_239 = None + wait_tensor_569 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + view_1785 = torch.ops.aten.view.default(add_1786, [2, 4096, 2048]); add_1786 = None + convert_element_type_1490 = torch.ops.prims.convert_element_type.default(view_1785, torch.float32); view_1785 = None + convert_element_type_1409 = torch.ops.prims.convert_element_type.default(primals_429, torch.bfloat16); primals_429 = None + all_gather_into_tensor_442 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1409, 128, '0'); convert_element_type_1409 = None + wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_442); all_gather_into_tensor_442 = None + convert_element_type_1492 = torch.ops.prims.convert_element_type.default(wait_tensor_542, torch.float32); wait_tensor_542 = None + mul_1305 = torch.ops.aten.mul.Tensor(convert_element_type_1490, convert_element_type_1492); convert_element_type_1492 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(add_1708, torch.float32); add_1708 = None + mul_1240 = torch.ops.aten.mul.Tensor(convert_element_type_1410, rsqrt_80); convert_element_type_1410 = None + mul_1307 = torch.ops.aten.mul.Tensor(mul_1240, mul_1305) + sum_108 = torch.ops.aten.sum.dim_IntList(mul_1307, [2], True); mul_1307 = None + div_135 = torch.ops.aten.div.Tensor(mul_1240, 2048) + mul_1308 = torch.ops.aten.mul.Tensor(div_135, sum_108); div_135 = sum_108 = None + sub_628 = torch.ops.aten.sub.Tensor(mul_1305, mul_1308); mul_1305 = mul_1308 = None + mul_1309 = torch.ops.aten.mul.Tensor(sub_628, rsqrt_80); sub_628 = rsqrt_80 = None + mul_1310 = torch.ops.aten.mul.Tensor(convert_element_type_1490, mul_1240); convert_element_type_1490 = mul_1240 = None + sum_109 = torch.ops.aten.sum.dim_IntList(mul_1310, [0, 1]); mul_1310 = None + convert_element_type_1493 = torch.ops.prims.convert_element_type.default(mul_1309, torch.bfloat16); mul_1309 = None + add_1787 = torch.ops.aten.add.Tensor(convert_element_type_1455, convert_element_type_1493); convert_element_type_1455 = convert_element_type_1493 = None + convert_element_type_default_81 = torch.ops.prims.convert_element_type.default(sum_109, torch.float32); sum_109 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_81, 'avg', 128, '0'); convert_element_type_default_81 = None + wait_tensor_570 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + view_1786 = torch.ops.aten.view.default(add_1787, [8192, 2048]) + permute_436 = torch.ops.aten.permute.default(view_1786, [1, 0]) + permute_392 = torch.ops.aten.permute.default(getitem_2765, [0, 2, 1, 3]) + view_1728 = torch.ops.aten.view.default(permute_392, [2, 4096, -1]); permute_392 = None + view_1730 = torch.ops.aten.view.default(view_1728, [8192, 2048]); view_1728 = None + mm_226 = torch.ops.aten.mm.default(permute_436, view_1730); permute_436 = view_1730 = None + convert_element_type_1406 = torch.ops.prims.convert_element_type.default(primals_428, torch.bfloat16); primals_428 = None + all_gather_into_tensor_441 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1406, 128, '0'); convert_element_type_1406 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_441); all_gather_into_tensor_441 = None + permute_393 = torch.ops.aten.permute.default(wait_tensor_541, [1, 0]); wait_tensor_541 = None + permute_438 = torch.ops.aten.permute.default(permute_393, [1, 0]); permute_393 = None + mm_227 = torch.ops.aten.mm.default(view_1786, permute_438); view_1786 = permute_438 = None + view_1787 = torch.ops.aten.view.default(mm_227, [2, 4096, 2048]); mm_227 = None + convert_element_type_1500 = torch.ops.prims.convert_element_type.default(mm_226, torch.float32); mm_226 = None + reduce_scatter_tensor_10 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1500, 'avg', 128, '0'); convert_element_type_1500 = None + wait_tensor_571 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + view_1788 = torch.ops.aten.view.default(view_1787, [2, 4096, 16, 128]); view_1787 = None + permute_440 = torch.ops.aten.permute.default(view_1788, [0, 2, 1, 3]); view_1788 = None + fw_graph0 = self.fw_graph0 + joint_graph0 = self.joint_graph0 + mask_graph0 = self.mask_graph0 + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(permute_389, permute_390, permute_391, getitem_2765, getitem_2766, permute_440, None, fw_graph0, joint_graph0, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph0), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_389 = permute_390 = permute_391 = getitem_2765 = getitem_2766 = permute_440 = fw_graph0 = joint_graph0 = mask_graph0 = None + getitem_3749 = flex_attention_backward[0] + getitem_3750 = flex_attention_backward[1] + getitem_3751 = flex_attention_backward[2]; flex_attention_backward = None + permute_441 = torch.ops.aten.permute.default(getitem_3751, [0, 2, 1, 3]); getitem_3751 = None + permute_442 = torch.ops.aten.permute.default(getitem_3750, [0, 2, 1, 3]); getitem_3750 = None + permute_443 = torch.ops.aten.permute.default(getitem_3749, [0, 2, 1, 3]); getitem_3749 = None + slice_164 = torch.ops.aten.slice.Tensor(permute_442, 3, 0, 128) + slice_165 = torch.ops.aten.slice.Tensor(permute_442, 3, 128, 192); permute_442 = None + sum_110 = torch.ops.aten.sum.dim_IntList(slice_165, [2], True); slice_165 = None + cat_240 = torch.ops.aten.cat.default([slice_164, permute_441], 3); slice_164 = permute_441 = None + view_1789 = torch.ops.aten.view.default(cat_240, [2, 4096, 4096]); cat_240 = None + view_1790 = torch.ops.aten.view.default(view_1789, [8192, 4096]); view_1789 = None + permute_444 = torch.ops.aten.permute.default(view_1790, [1, 0]) + mm_228 = torch.ops.aten.mm.default(permute_444, view_1725); permute_444 = view_1725 = None + convert_element_type_1403 = torch.ops.prims.convert_element_type.default(primals_427, torch.bfloat16); primals_427 = None + all_gather_into_tensor_440 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1403, 128, '0'); convert_element_type_1403 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_440); all_gather_into_tensor_440 = None + permute_388 = torch.ops.aten.permute.default(wait_tensor_540, [1, 0]); wait_tensor_540 = None + permute_446 = torch.ops.aten.permute.default(permute_388, [1, 0]); permute_388 = None + mm_229 = torch.ops.aten.mm.default(view_1790, permute_446); view_1790 = permute_446 = None + view_1791 = torch.ops.aten.view.default(mm_229, [2, 4096, 512]); mm_229 = None + convert_element_type_1505 = torch.ops.prims.convert_element_type.default(mm_228, torch.float32); mm_228 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1505, 'avg', 128, '0'); convert_element_type_1505 = None + wait_tensor_572 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + convert_element_type_1506 = torch.ops.prims.convert_element_type.default(view_1791, torch.float32); view_1791 = None + convert_element_type_1400 = torch.ops.prims.convert_element_type.default(primals_426, torch.bfloat16); primals_426 = None + all_gather_into_tensor_439 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1400, 128, '0'); convert_element_type_1400 = None + wait_tensor_539 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_439); all_gather_into_tensor_439 = None + convert_element_type_1508 = torch.ops.prims.convert_element_type.default(wait_tensor_539, torch.float32); wait_tensor_539 = None + mul_1311 = torch.ops.aten.mul.Tensor(convert_element_type_1506, convert_element_type_1508); convert_element_type_1508 = None + convert_element_type_1401 = torch.ops.prims.convert_element_type.default(getitem_2761, torch.float32); getitem_2761 = None + mul_1238 = torch.ops.aten.mul.Tensor(convert_element_type_1401, rsqrt_79); convert_element_type_1401 = None + mul_1313 = torch.ops.aten.mul.Tensor(mul_1238, mul_1311) + sum_111 = torch.ops.aten.sum.dim_IntList(mul_1313, [2], True); mul_1313 = None + div_136 = torch.ops.aten.div.Tensor(mul_1238, 512) + mul_1314 = torch.ops.aten.mul.Tensor(div_136, sum_111); div_136 = sum_111 = None + sub_629 = torch.ops.aten.sub.Tensor(mul_1311, mul_1314); mul_1311 = mul_1314 = None + mul_1315 = torch.ops.aten.mul.Tensor(sub_629, rsqrt_79); sub_629 = rsqrt_79 = None + mul_1316 = torch.ops.aten.mul.Tensor(convert_element_type_1506, mul_1238); convert_element_type_1506 = mul_1238 = None + sum_112 = torch.ops.aten.sum.dim_IntList(mul_1316, [0, 1]); mul_1316 = None + convert_element_type_1509 = torch.ops.prims.convert_element_type.default(mul_1315, torch.bfloat16); mul_1315 = None + convert_element_type_default_80 = torch.ops.prims.convert_element_type.default(sum_112, torch.float32); sum_112 = None + reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_80, 'avg', 128, '0'); convert_element_type_default_80 = None + wait_tensor_573 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + convert_element_type_1512 = torch.ops.prims.convert_element_type.default(sum_110, torch.float32); sum_110 = None + view_1792 = torch.ops.aten.view.default(convert_element_type_1512, [2, 4096, 1, 32, 2]); convert_element_type_1512 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1792); view_1792 = None + view_7 = torch.ops.aten.view.default(primals_3, [1, 4096, 1, 32]); primals_3 = None + _conj = torch.ops.aten._conj.default(view_7); view_7 = None + clone_9 = torch.ops.aten.clone.default(_conj); _conj = None + mul_1317 = torch.ops.aten.mul.Tensor(view_as_complex_54, clone_9); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_1317); mul_1317 = None + view_1793 = torch.ops.aten.view.default(view_as_real_54, [2, 4096, 1, 64]); view_as_real_54 = None + convert_element_type_1513 = torch.ops.prims.convert_element_type.default(view_1793, torch.bfloat16); view_1793 = None + squeeze_26 = torch.ops.aten.squeeze.dim(convert_element_type_1513, 2); convert_element_type_1513 = None + cat_241 = torch.ops.aten.cat.default([convert_element_type_1509, squeeze_26], 2); convert_element_type_1509 = squeeze_26 = None + view_1794 = torch.ops.aten.view.default(cat_241, [8192, 576]); cat_241 = None + permute_448 = torch.ops.aten.permute.default(view_1794, [1, 0]) + mm_230 = torch.ops.aten.mm.default(permute_448, view_1711); permute_448 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(primals_425, torch.bfloat16); primals_425 = None + all_gather_into_tensor_438 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1395, 128, '0'); convert_element_type_1395 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_438); all_gather_into_tensor_438 = None + slice_157 = torch.ops.aten.slice.Tensor(wait_tensor_538, 0, 0, 576); wait_tensor_538 = None + permute_387 = torch.ops.aten.permute.default(slice_157, [1, 0]); slice_157 = None + permute_450 = torch.ops.aten.permute.default(permute_387, [1, 0]); permute_387 = None + mm_231 = torch.ops.aten.mm.default(view_1794, permute_450); view_1794 = permute_450 = None + view_1795 = torch.ops.aten.view.default(mm_231, [2, 4096, 2048]); mm_231 = None + convert_element_type_1518 = torch.ops.prims.convert_element_type.default(mm_230, torch.float32); mm_230 = None + split_208 = torch.ops.aten.split.Tensor(convert_element_type_1518, 5); convert_element_type_1518 = None + getitem_3753 = split_208[0] + getitem_3754 = split_208[1] + getitem_3755 = split_208[2] + getitem_3756 = split_208[3] + getitem_3757 = split_208[4] + getitem_3758 = split_208[5] + getitem_3759 = split_208[6] + getitem_3760 = split_208[7] + getitem_3761 = split_208[8] + getitem_3762 = split_208[9] + getitem_3763 = split_208[10] + getitem_3764 = split_208[11] + getitem_3765 = split_208[12] + getitem_3766 = split_208[13] + getitem_3767 = split_208[14] + getitem_3768 = split_208[15] + getitem_3769 = split_208[16] + getitem_3770 = split_208[17] + getitem_3771 = split_208[18] + getitem_3772 = split_208[19] + getitem_3773 = split_208[20] + getitem_3774 = split_208[21] + getitem_3775 = split_208[22] + getitem_3776 = split_208[23] + getitem_3777 = split_208[24] + getitem_3778 = split_208[25] + getitem_3779 = split_208[26] + getitem_3780 = split_208[27] + getitem_3781 = split_208[28] + getitem_3782 = split_208[29] + getitem_3783 = split_208[30] + getitem_3784 = split_208[31] + getitem_3785 = split_208[32] + getitem_3786 = split_208[33] + getitem_3787 = split_208[34] + getitem_3788 = split_208[35] + getitem_3789 = split_208[36] + getitem_3790 = split_208[37] + getitem_3791 = split_208[38] + getitem_3792 = split_208[39] + getitem_3793 = split_208[40] + getitem_3794 = split_208[41] + getitem_3795 = split_208[42] + getitem_3796 = split_208[43] + getitem_3797 = split_208[44] + getitem_3798 = split_208[45] + getitem_3799 = split_208[46] + getitem_3800 = split_208[47] + getitem_3801 = split_208[48] + getitem_3802 = split_208[49] + getitem_3803 = split_208[50] + getitem_3804 = split_208[51] + getitem_3805 = split_208[52] + getitem_3806 = split_208[53] + getitem_3807 = split_208[54] + getitem_3808 = split_208[55] + getitem_3809 = split_208[56] + getitem_3810 = split_208[57] + getitem_3811 = split_208[58] + getitem_3812 = split_208[59] + getitem_3813 = split_208[60] + getitem_3814 = split_208[61] + getitem_3815 = split_208[62] + getitem_3816 = split_208[63] + getitem_3817 = split_208[64] + getitem_3818 = split_208[65] + getitem_3819 = split_208[66] + getitem_3820 = split_208[67] + getitem_3821 = split_208[68] + getitem_3822 = split_208[69] + getitem_3823 = split_208[70] + getitem_3824 = split_208[71] + getitem_3825 = split_208[72] + getitem_3826 = split_208[73] + getitem_3827 = split_208[74] + getitem_3828 = split_208[75] + getitem_3829 = split_208[76] + getitem_3830 = split_208[77] + getitem_3831 = split_208[78] + getitem_3832 = split_208[79] + getitem_3833 = split_208[80] + getitem_3834 = split_208[81] + getitem_3835 = split_208[82] + getitem_3836 = split_208[83] + getitem_3837 = split_208[84] + getitem_3838 = split_208[85] + getitem_3839 = split_208[86] + getitem_3840 = split_208[87] + getitem_3841 = split_208[88] + getitem_3842 = split_208[89] + getitem_3843 = split_208[90] + getitem_3844 = split_208[91] + getitem_3845 = split_208[92] + getitem_3846 = split_208[93] + getitem_3847 = split_208[94] + getitem_3848 = split_208[95] + getitem_3849 = split_208[96] + getitem_3850 = split_208[97] + getitem_3851 = split_208[98] + getitem_3852 = split_208[99] + getitem_3853 = split_208[100] + getitem_3854 = split_208[101] + getitem_3855 = split_208[102] + getitem_3856 = split_208[103] + getitem_3857 = split_208[104] + getitem_3858 = split_208[105] + getitem_3859 = split_208[106] + getitem_3860 = split_208[107] + getitem_3861 = split_208[108] + getitem_3862 = split_208[109] + getitem_3863 = split_208[110] + getitem_3864 = split_208[111] + getitem_3865 = split_208[112] + getitem_3866 = split_208[113] + getitem_3867 = split_208[114] + getitem_3868 = split_208[115]; split_208 = None + constant_pad_nd_64 = torch.ops.aten.constant_pad_nd.default(getitem_3868, [0, 0, 0, 4], 0.0); getitem_3868 = None + constant_pad_nd_65 = torch.ops.aten.constant_pad_nd.default(full_default_54, [0, 0, 0, 5], 0.0); full_default_54 = None + cat_242 = torch.ops.aten.cat.default([getitem_3753, getitem_3754, getitem_3755, getitem_3756, getitem_3757, getitem_3758, getitem_3759, getitem_3760, getitem_3761, getitem_3762, getitem_3763, getitem_3764, getitem_3765, getitem_3766, getitem_3767, getitem_3768, getitem_3769, getitem_3770, getitem_3771, getitem_3772, getitem_3773, getitem_3774, getitem_3775, getitem_3776, getitem_3777, getitem_3778, getitem_3779, getitem_3780, getitem_3781, getitem_3782, getitem_3783, getitem_3784, getitem_3785, getitem_3786, getitem_3787, getitem_3788, getitem_3789, getitem_3790, getitem_3791, getitem_3792, getitem_3793, getitem_3794, getitem_3795, getitem_3796, getitem_3797, getitem_3798, getitem_3799, getitem_3800, getitem_3801, getitem_3802, getitem_3803, getitem_3804, getitem_3805, getitem_3806, getitem_3807, getitem_3808, getitem_3809, getitem_3810, getitem_3811, getitem_3812, getitem_3813, getitem_3814, getitem_3815, getitem_3816, getitem_3817, getitem_3818, getitem_3819, getitem_3820, getitem_3821, getitem_3822, getitem_3823, getitem_3824, getitem_3825, getitem_3826, getitem_3827, getitem_3828, getitem_3829, getitem_3830, getitem_3831, getitem_3832, getitem_3833, getitem_3834, getitem_3835, getitem_3836, getitem_3837, getitem_3838, getitem_3839, getitem_3840, getitem_3841, getitem_3842, getitem_3843, getitem_3844, getitem_3845, getitem_3846, getitem_3847, getitem_3848, getitem_3849, getitem_3850, getitem_3851, getitem_3852, getitem_3853, getitem_3854, getitem_3855, getitem_3856, getitem_3857, getitem_3858, getitem_3859, getitem_3860, getitem_3861, getitem_3862, getitem_3863, getitem_3864, getitem_3865, getitem_3866, getitem_3867, constant_pad_nd_64, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_3753 = getitem_3754 = getitem_3755 = getitem_3756 = getitem_3757 = getitem_3758 = getitem_3759 = getitem_3760 = getitem_3761 = getitem_3762 = getitem_3763 = getitem_3764 = getitem_3765 = getitem_3766 = getitem_3767 = getitem_3768 = getitem_3769 = getitem_3770 = getitem_3771 = getitem_3772 = getitem_3773 = getitem_3774 = getitem_3775 = getitem_3776 = getitem_3777 = getitem_3778 = getitem_3779 = getitem_3780 = getitem_3781 = getitem_3782 = getitem_3783 = getitem_3784 = getitem_3785 = getitem_3786 = getitem_3787 = getitem_3788 = getitem_3789 = getitem_3790 = getitem_3791 = getitem_3792 = getitem_3793 = getitem_3794 = getitem_3795 = getitem_3796 = getitem_3797 = getitem_3798 = getitem_3799 = getitem_3800 = getitem_3801 = getitem_3802 = getitem_3803 = getitem_3804 = getitem_3805 = getitem_3806 = getitem_3807 = getitem_3808 = getitem_3809 = getitem_3810 = getitem_3811 = getitem_3812 = getitem_3813 = getitem_3814 = getitem_3815 = getitem_3816 = getitem_3817 = getitem_3818 = getitem_3819 = getitem_3820 = getitem_3821 = getitem_3822 = getitem_3823 = getitem_3824 = getitem_3825 = getitem_3826 = getitem_3827 = getitem_3828 = getitem_3829 = getitem_3830 = getitem_3831 = getitem_3832 = getitem_3833 = getitem_3834 = getitem_3835 = getitem_3836 = getitem_3837 = getitem_3838 = getitem_3839 = getitem_3840 = getitem_3841 = getitem_3842 = getitem_3843 = getitem_3844 = getitem_3845 = getitem_3846 = getitem_3847 = getitem_3848 = getitem_3849 = getitem_3850 = getitem_3851 = getitem_3852 = getitem_3853 = getitem_3854 = getitem_3855 = getitem_3856 = getitem_3857 = getitem_3858 = getitem_3859 = getitem_3860 = getitem_3861 = getitem_3862 = getitem_3863 = getitem_3864 = getitem_3865 = getitem_3866 = getitem_3867 = constant_pad_nd_64 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_242, 'avg', 128, '0'); cat_242 = None + wait_tensor_574 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + slice_166 = torch.ops.aten.slice.Tensor(permute_443, 3, 0, 128) + slice_167 = torch.ops.aten.slice.Tensor(permute_443, 3, 128, 192); permute_443 = None + convert_element_type_1519 = torch.ops.prims.convert_element_type.default(slice_167, torch.float32); slice_167 = None + view_1796 = torch.ops.aten.view.default(convert_element_type_1519, [2, 4096, 16, 32, 2]); convert_element_type_1519 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1796); view_1796 = None + mul_1318 = torch.ops.aten.mul.Tensor(view_as_complex_55, clone_9); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_1318); mul_1318 = None + view_1797 = torch.ops.aten.view.default(view_as_real_55, [2, 4096, 16, 64]); view_as_real_55 = None + convert_element_type_1520 = torch.ops.prims.convert_element_type.default(view_1797, torch.bfloat16); view_1797 = None + cat_243 = torch.ops.aten.cat.default([slice_166, convert_element_type_1520], 3); slice_166 = convert_element_type_1520 = None + view_1798 = torch.ops.aten.view.default(cat_243, [2, 4096, 3072]); cat_243 = None + view_1799 = torch.ops.aten.view.default(view_1798, [8192, 3072]); view_1798 = None + permute_452 = torch.ops.aten.permute.default(view_1799, [1, 0]) + mm_232 = torch.ops.aten.mm.default(permute_452, view_1711); permute_452 = view_1711 = None + convert_element_type_1390 = torch.ops.prims.convert_element_type.default(primals_424, torch.bfloat16); primals_424 = None + all_gather_into_tensor_437 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1390, 128, '0'); convert_element_type_1390 = None + wait_tensor_537 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_437); all_gather_into_tensor_437 = None + permute_386 = torch.ops.aten.permute.default(wait_tensor_537, [1, 0]); wait_tensor_537 = None + permute_454 = torch.ops.aten.permute.default(permute_386, [1, 0]); permute_386 = None + mm_233 = torch.ops.aten.mm.default(view_1799, permute_454); view_1799 = permute_454 = None + view_1800 = torch.ops.aten.view.default(mm_233, [2, 4096, 2048]); mm_233 = None + add_1788 = torch.ops.aten.add.Tensor(view_1795, view_1800); view_1795 = view_1800 = None + convert_element_type_1525 = torch.ops.prims.convert_element_type.default(mm_232, torch.float32); mm_232 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1525, 'avg', 128, '0'); convert_element_type_1525 = None + wait_tensor_575 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + convert_element_type_1526 = torch.ops.prims.convert_element_type.default(add_1788, torch.float32); add_1788 = None + convert_element_type_1387 = torch.ops.prims.convert_element_type.default(primals_423, torch.bfloat16); primals_423 = None + all_gather_into_tensor_436 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1387, 128, '0'); convert_element_type_1387 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_436); all_gather_into_tensor_436 = None + convert_element_type_1528 = torch.ops.prims.convert_element_type.default(wait_tensor_536, torch.float32); wait_tensor_536 = None + mul_1319 = torch.ops.aten.mul.Tensor(convert_element_type_1526, convert_element_type_1528); convert_element_type_1528 = None + convert_element_type_1388 = torch.ops.prims.convert_element_type.default(add_1705, torch.float32); add_1705 = None + mul_1234 = torch.ops.aten.mul.Tensor(convert_element_type_1388, rsqrt_78); convert_element_type_1388 = None + mul_1321 = torch.ops.aten.mul.Tensor(mul_1234, mul_1319) + sum_113 = torch.ops.aten.sum.dim_IntList(mul_1321, [2], True); mul_1321 = None + div_137 = torch.ops.aten.div.Tensor(mul_1234, 2048) + mul_1322 = torch.ops.aten.mul.Tensor(div_137, sum_113); div_137 = sum_113 = None + sub_630 = torch.ops.aten.sub.Tensor(mul_1319, mul_1322); mul_1319 = mul_1322 = None + mul_1323 = torch.ops.aten.mul.Tensor(sub_630, rsqrt_78); sub_630 = rsqrt_78 = None + mul_1324 = torch.ops.aten.mul.Tensor(convert_element_type_1526, mul_1234); convert_element_type_1526 = mul_1234 = None + sum_114 = torch.ops.aten.sum.dim_IntList(mul_1324, [0, 1]); mul_1324 = None + convert_element_type_1529 = torch.ops.prims.convert_element_type.default(mul_1323, torch.bfloat16); mul_1323 = None + add_1789 = torch.ops.aten.add.Tensor(add_1787, convert_element_type_1529); add_1787 = convert_element_type_1529 = None + convert_element_type_default_79 = torch.ops.prims.convert_element_type.default(sum_114, torch.float32); sum_114 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_79, 'avg', 128, '0'); convert_element_type_default_79 = None + wait_tensor_576 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + view_1801 = torch.ops.aten.view.default(add_1789, [8192, 2048]) + unsqueeze_54 = torch.ops.aten.unsqueeze.default(view_1801, 1) + convert_element_type_1532 = torch.ops.prims.convert_element_type.default(unsqueeze_54, torch.float32); unsqueeze_54 = None + bmm_28 = torch.ops.aten.bmm.default(permute_456, convert_element_type_1532); permute_456 = None + bmm_29 = torch.ops.aten.bmm.default(convert_element_type_1532, permute_457); convert_element_type_1532 = permute_457 = None + convert_element_type_1533 = torch.ops.prims.convert_element_type.default(bmm_28, torch.bfloat16); bmm_28 = None + view_1802 = torch.ops.aten.view.default(bmm_29, [8192, 6]); bmm_29 = None + view_1803 = torch.ops.aten.view.default(convert_element_type_1533, [49152, 2048]); convert_element_type_1533 = None + index_54 = torch.ops.aten.index.Tensor(view_1803, [getitem_2661]); view_1803 = getitem_2661 = None + permute_458 = torch.ops.aten.permute.default(view_1801, [1, 0]) + mm_234 = torch.ops.aten.mm.default(permute_458, mul_1231); permute_458 = mul_1231 = None + convert_element_type_1382 = torch.ops.prims.convert_element_type.default(primals_422, torch.bfloat16); primals_422 = None + all_gather_into_tensor_435 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1382, 128, '0'); convert_element_type_1382 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_435); all_gather_into_tensor_435 = None + permute_385 = torch.ops.aten.permute.default(wait_tensor_535, [1, 0]); wait_tensor_535 = None + permute_460 = torch.ops.aten.permute.default(permute_385, [1, 0]); permute_385 = None + mm_235 = torch.ops.aten.mm.default(view_1801, permute_460); view_1801 = permute_460 = None + convert_element_type_1538 = torch.ops.prims.convert_element_type.default(mm_234, torch.float32); mm_234 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1538, 'avg', 128, '0'); convert_element_type_1538 = None + wait_tensor_577 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + convert_element_type_1377 = torch.ops.prims.convert_element_type.default(mm_204, torch.float32); mm_204 = None + neg_50 = torch.ops.aten.neg.default(convert_element_type_1377) + exp_75 = torch.ops.aten.exp.default(neg_50); neg_50 = None + add_1700 = torch.ops.aten.add.Tensor(exp_75, 1); exp_75 = None + div_125 = torch.ops.aten.div.Tensor(convert_element_type_1377, add_1700) + convert_element_type_1378 = torch.ops.prims.convert_element_type.default(div_125, torch.bfloat16); div_125 = None + mul_1325 = torch.ops.aten.mul.Tensor(mm_235, convert_element_type_1378); convert_element_type_1378 = None + mul_1326 = torch.ops.aten.mul.Tensor(mm_235, mm_205); mm_235 = mm_205 = None + permute_462 = torch.ops.aten.permute.default(mul_1325, [1, 0]) + mm_236 = torch.ops.aten.mm.default(permute_462, view_1666); permute_462 = None + convert_element_type_1379 = torch.ops.prims.convert_element_type.default(primals_421, torch.bfloat16); primals_421 = None + all_gather_into_tensor_434 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1379, 128, '0'); convert_element_type_1379 = None + wait_tensor_534 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_434); all_gather_into_tensor_434 = None + permute_384 = torch.ops.aten.permute.default(wait_tensor_534, [1, 0]); wait_tensor_534 = None + permute_464 = torch.ops.aten.permute.default(permute_384, [1, 0]); permute_384 = None + mm_237 = torch.ops.aten.mm.default(mul_1325, permute_464); mul_1325 = permute_464 = None + convert_element_type_1543 = torch.ops.prims.convert_element_type.default(mm_236, torch.float32); mm_236 = None + reduce_scatter_tensor_17 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1543, 'avg', 128, '0'); convert_element_type_1543 = None + wait_tensor_578 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + convert_element_type_1544 = torch.ops.prims.convert_element_type.default(mul_1326, torch.float32); mul_1326 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_1700); add_1700 = None + mul_1327 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_1328 = torch.ops.aten.mul.Tensor(convert_element_type_1544, mul_1327); convert_element_type_1544 = None + sub_631 = torch.ops.aten.sub.Tensor(1, mul_1327); mul_1327 = None + mul_1329 = torch.ops.aten.mul.Tensor(convert_element_type_1377, sub_631); convert_element_type_1377 = sub_631 = None + add_1791 = torch.ops.aten.add.Tensor(mul_1329, 1); mul_1329 = None + mul_1330 = torch.ops.aten.mul.Tensor(mul_1328, add_1791); mul_1328 = add_1791 = None + convert_element_type_1546 = torch.ops.prims.convert_element_type.default(mul_1330, torch.bfloat16); mul_1330 = None + permute_466 = torch.ops.aten.permute.default(convert_element_type_1546, [1, 0]) + mm_238 = torch.ops.aten.mm.default(permute_466, view_1666); permute_466 = None + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(primals_420, torch.bfloat16); primals_420 = None + all_gather_into_tensor_433 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1374, 128, '0'); convert_element_type_1374 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_433); all_gather_into_tensor_433 = None + permute_383 = torch.ops.aten.permute.default(wait_tensor_533, [1, 0]); wait_tensor_533 = None + permute_468 = torch.ops.aten.permute.default(permute_383, [1, 0]); permute_383 = None + mm_239 = torch.ops.aten.mm.default(convert_element_type_1546, permute_468); convert_element_type_1546 = permute_468 = None + add_1792 = torch.ops.aten.add.Tensor(mm_237, mm_239); mm_237 = mm_239 = None + convert_element_type_1551 = torch.ops.prims.convert_element_type.default(mm_238, torch.float32); mm_238 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1551, 'avg', 128, '0'); convert_element_type_1551 = None + wait_tensor_579 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + all_to_all_single_80 = torch.ops._c10d_functional.all_to_all_single.default(index_54, [_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399], [_local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391], '1033'); index_54 = None + wait_tensor_580 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_80); all_to_all_single_80 = None + full_354 = torch.ops.aten.full.default([sym_size_int_97, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_97 = None + slice_scatter_1 = torch.ops.aten.slice_scatter.default(full_354, wait_tensor_580, 0, 0, -1); wait_tensor_580 = None + index_55 = torch.ops.aten.index.Tensor(slice_scatter_1, [getitem_2662]); slice_scatter_1 = None + permute_470 = torch.ops.aten.permute.default(index_55, [1, 0]) + _grouped_mm_84 = torch.ops.aten._grouped_mm.default(permute_470, mul_1211, cumsum_74); permute_470 = mul_1211 = None + _grouped_mm_85 = torch.ops.aten._grouped_mm.default(index_55, permute_472, cumsum_74); index_55 = permute_472 = None + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(_grouped_mm_72, torch.float32); _grouped_mm_72 = None + neg_49 = torch.ops.aten.neg.default(convert_element_type_1372) + exp_74 = torch.ops.aten.exp.default(neg_49); neg_49 = None + add_1664 = torch.ops.aten.add.Tensor(exp_74, 1); exp_74 = None + div_124 = torch.ops.aten.div.Tensor(convert_element_type_1372, add_1664) + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(div_124, torch.bfloat16); div_124 = None + mul_1331 = torch.ops.aten.mul.Tensor(_grouped_mm_85, convert_element_type_1373); convert_element_type_1373 = None + mul_1332 = torch.ops.aten.mul.Tensor(_grouped_mm_85, _grouped_mm_73); _grouped_mm_85 = _grouped_mm_73 = None + permute_474 = torch.ops.aten.permute.default(mul_1331, [1, 0]) + _grouped_mm_86 = torch.ops.aten._grouped_mm.default(permute_474, index_49, cumsum_74); permute_474 = None + _grouped_mm_87 = torch.ops.aten._grouped_mm.default(mul_1331, permute_476, cumsum_74); mul_1331 = permute_476 = None + convert_element_type_1552 = torch.ops.prims.convert_element_type.default(mul_1332, torch.float32); mul_1332 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_1664); add_1664 = None + mul_1333 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_1334 = torch.ops.aten.mul.Tensor(convert_element_type_1552, mul_1333); convert_element_type_1552 = None + sub_632 = torch.ops.aten.sub.Tensor(1, mul_1333); mul_1333 = None + mul_1335 = torch.ops.aten.mul.Tensor(convert_element_type_1372, sub_632); convert_element_type_1372 = sub_632 = None + add_1794 = torch.ops.aten.add.Tensor(mul_1335, 1); mul_1335 = None + mul_1336 = torch.ops.aten.mul.Tensor(mul_1334, add_1794); mul_1334 = add_1794 = None + convert_element_type_1554 = torch.ops.prims.convert_element_type.default(mul_1336, torch.bfloat16); mul_1336 = None + permute_478 = torch.ops.aten.permute.default(convert_element_type_1554, [1, 0]) + _grouped_mm_88 = torch.ops.aten._grouped_mm.default(permute_478, index_49, cumsum_74); permute_478 = index_49 = None + _grouped_mm_89 = torch.ops.aten._grouped_mm.default(convert_element_type_1554, permute_480, cumsum_74); convert_element_type_1554 = permute_480 = cumsum_74 = None + add_1795 = torch.ops.aten.add.Tensor(_grouped_mm_87, _grouped_mm_89); _grouped_mm_87 = _grouped_mm_89 = None + convert_element_type_1555 = torch.ops.prims.convert_element_type.default(_grouped_mm_86, torch.float32); _grouped_mm_86 = None + div_138 = torch.ops.aten.div.Tensor(convert_element_type_1555, 128); convert_element_type_1555 = None + split_210 = torch.ops.aten.split.Tensor(div_138, 88, 1); div_138 = None + getitem_3885 = split_210[0] + getitem_3902 = split_210[1] + getitem_3919 = split_210[2] + getitem_3936 = split_210[3] + getitem_3953 = split_210[4] + getitem_3970 = split_210[5] + getitem_3987 = split_210[6] + getitem_4004 = split_210[7] + getitem_4021 = split_210[8] + getitem_4038 = split_210[9] + getitem_4055 = split_210[10] + getitem_4072 = split_210[11] + getitem_4089 = split_210[12] + getitem_4106 = split_210[13] + getitem_4123 = split_210[14] + getitem_4140 = split_210[15]; split_210 = None + cat_244 = torch.ops.aten.cat.default([getitem_3885, getitem_3902, getitem_3919, getitem_3936, getitem_3953, getitem_3970, getitem_3987, getitem_4004, getitem_4021, getitem_4038, getitem_4055, getitem_4072, getitem_4089, getitem_4106, getitem_4123, getitem_4140]); getitem_3885 = getitem_3902 = getitem_3919 = getitem_3936 = getitem_3953 = getitem_3970 = getitem_3987 = getitem_4004 = getitem_4021 = getitem_4038 = getitem_4055 = getitem_4072 = getitem_4089 = getitem_4106 = getitem_4123 = getitem_4140 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_244, 'sum', 16, '1025'); cat_244 = None + wait_tensor_581 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + convert_element_type_1556 = torch.ops.prims.convert_element_type.default(_grouped_mm_84, torch.float32); _grouped_mm_84 = None + div_139 = torch.ops.aten.div.Tensor(convert_element_type_1556, 128); convert_element_type_1556 = None + split_227 = torch.ops.aten.split.Tensor(div_139, 128, 1); div_139 = None + getitem_4157 = split_227[0] + getitem_4174 = split_227[1] + getitem_4191 = split_227[2] + getitem_4208 = split_227[3] + getitem_4225 = split_227[4] + getitem_4242 = split_227[5] + getitem_4259 = split_227[6] + getitem_4276 = split_227[7] + getitem_4293 = split_227[8] + getitem_4310 = split_227[9] + getitem_4327 = split_227[10] + getitem_4344 = split_227[11] + getitem_4361 = split_227[12] + getitem_4378 = split_227[13] + getitem_4395 = split_227[14] + getitem_4412 = split_227[15]; split_227 = None + cat_245 = torch.ops.aten.cat.default([getitem_4157, getitem_4174, getitem_4191, getitem_4208, getitem_4225, getitem_4242, getitem_4259, getitem_4276, getitem_4293, getitem_4310, getitem_4327, getitem_4344, getitem_4361, getitem_4378, getitem_4395, getitem_4412]); getitem_4157 = getitem_4174 = getitem_4191 = getitem_4208 = getitem_4225 = getitem_4242 = getitem_4259 = getitem_4276 = getitem_4293 = getitem_4310 = getitem_4327 = getitem_4344 = getitem_4361 = getitem_4378 = getitem_4395 = getitem_4412 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_245, 'sum', 16, '1025'); cat_245 = None + wait_tensor_582 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + convert_element_type_1557 = torch.ops.prims.convert_element_type.default(_grouped_mm_88, torch.float32); _grouped_mm_88 = None + div_140 = torch.ops.aten.div.Tensor(convert_element_type_1557, 128); convert_element_type_1557 = None + split_244 = torch.ops.aten.split.Tensor(div_140, 88, 1); div_140 = None + getitem_4429 = split_244[0] + getitem_4446 = split_244[1] + getitem_4463 = split_244[2] + getitem_4480 = split_244[3] + getitem_4497 = split_244[4] + getitem_4514 = split_244[5] + getitem_4531 = split_244[6] + getitem_4548 = split_244[7] + getitem_4565 = split_244[8] + getitem_4582 = split_244[9] + getitem_4599 = split_244[10] + getitem_4616 = split_244[11] + getitem_4633 = split_244[12] + getitem_4650 = split_244[13] + getitem_4667 = split_244[14] + getitem_4684 = split_244[15]; split_244 = None + cat_246 = torch.ops.aten.cat.default([getitem_4429, getitem_4446, getitem_4463, getitem_4480, getitem_4497, getitem_4514, getitem_4531, getitem_4548, getitem_4565, getitem_4582, getitem_4599, getitem_4616, getitem_4633, getitem_4650, getitem_4667, getitem_4684]); getitem_4429 = getitem_4446 = getitem_4463 = getitem_4480 = getitem_4497 = getitem_4514 = getitem_4531 = getitem_4548 = getitem_4565 = getitem_4582 = getitem_4599 = getitem_4616 = getitem_4633 = getitem_4650 = getitem_4667 = getitem_4684 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_246, 'sum', 16, '1025'); cat_246 = None + wait_tensor_583 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + index_put_54 = torch.ops.aten.index_put.default(full_354, [getitem_2662], add_1795, True); full_354 = getitem_2662 = add_1795 = None + slice_168 = torch.ops.aten.slice.Tensor(index_put_54, 0, 0, add_1796); index_put_54 = add_1796 = None + all_to_all_single_81 = torch.ops._c10d_functional.all_to_all_single.default(slice_168, [_local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391], [_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399], '1033'); slice_168 = _local_scalar_dense_384 = _local_scalar_dense_385 = _local_scalar_dense_386 = _local_scalar_dense_387 = _local_scalar_dense_388 = _local_scalar_dense_389 = _local_scalar_dense_390 = _local_scalar_dense_391 = _local_scalar_dense_392 = _local_scalar_dense_393 = _local_scalar_dense_394 = _local_scalar_dense_395 = _local_scalar_dense_396 = _local_scalar_dense_397 = _local_scalar_dense_398 = _local_scalar_dense_399 = None + wait_tensor_584 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_81); all_to_all_single_81 = None + index_put_55 = torch.ops.aten.index_put.default(full_default_52, [div_122], wait_tensor_584, True); div_122 = wait_tensor_584 = None + add_1800 = torch.ops.aten.add.Tensor(add_1792, index_put_55); add_1792 = index_put_55 = None + mul_1337 = torch.ops.aten.mul.Tensor(view_1802, 1.0); view_1802 = None + scatter_add_1 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_2659, mul_1337); getitem_2659 = mul_1337 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_203, torch.float32); mm_203 = None + sub_576 = torch.ops.aten.sub.Tensor(convert_element_type_1361, amax_24); convert_element_type_1361 = amax_24 = None + exp_73 = torch.ops.aten.exp.default(sub_576); sub_576 = None + div_121 = torch.ops.aten.div.Tensor(exp_73, sum_97); exp_73 = sum_97 = None + mul_1338 = torch.ops.aten.mul.Tensor(scatter_add_1, div_121); scatter_add_1 = None + sum_115 = torch.ops.aten.sum.dim_IntList(mul_1338, [1], True) + neg_58 = torch.ops.aten.neg.default(div_121); div_121 = None + fma_1 = torch.ops.prims.fma.default(neg_58, sum_115, mul_1338); neg_58 = sum_115 = mul_1338 = None + convert_element_type_1558 = torch.ops.prims.convert_element_type.default(fma_1, torch.bfloat16); fma_1 = None + permute_482 = torch.ops.aten.permute.default(convert_element_type_1558, [1, 0]) + mm_240 = torch.ops.aten.mm.default(permute_482, view_1666); permute_482 = view_1666 = None + convert_element_type_1358 = torch.ops.prims.convert_element_type.default(primals_415, torch.bfloat16); primals_415 = None + all_gather_into_tensor_426 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1358, 128, '0'); convert_element_type_1358 = None + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_426); all_gather_into_tensor_426 = None + slice_153 = torch.ops.aten.slice.Tensor(wait_tensor_522, 0, 0, 64); wait_tensor_522 = None + permute_379 = torch.ops.aten.permute.default(slice_153, [1, 0]); slice_153 = None + permute_484 = torch.ops.aten.permute.default(permute_379, [1, 0]); permute_379 = None + mm_241 = torch.ops.aten.mm.default(convert_element_type_1558, permute_484); convert_element_type_1558 = permute_484 = None + add_1801 = torch.ops.aten.add.Tensor(add_1800, mm_241); add_1800 = mm_241 = None + convert_element_type_1563 = torch.ops.prims.convert_element_type.default(mm_240, torch.float32); mm_240 = None + split_260 = torch.ops.aten.split.Tensor(convert_element_type_1563, 1); convert_element_type_1563 = None + getitem_4685 = split_260[0] + getitem_4686 = split_260[1] + getitem_4687 = split_260[2] + getitem_4688 = split_260[3] + getitem_4689 = split_260[4] + getitem_4690 = split_260[5] + getitem_4691 = split_260[6] + getitem_4692 = split_260[7] + getitem_4693 = split_260[8] + getitem_4694 = split_260[9] + getitem_4695 = split_260[10] + getitem_4696 = split_260[11] + getitem_4697 = split_260[12] + getitem_4698 = split_260[13] + getitem_4699 = split_260[14] + getitem_4700 = split_260[15] + getitem_4701 = split_260[16] + getitem_4702 = split_260[17] + getitem_4703 = split_260[18] + getitem_4704 = split_260[19] + getitem_4705 = split_260[20] + getitem_4706 = split_260[21] + getitem_4707 = split_260[22] + getitem_4708 = split_260[23] + getitem_4709 = split_260[24] + getitem_4710 = split_260[25] + getitem_4711 = split_260[26] + getitem_4712 = split_260[27] + getitem_4713 = split_260[28] + getitem_4714 = split_260[29] + getitem_4715 = split_260[30] + getitem_4716 = split_260[31] + getitem_4717 = split_260[32] + getitem_4718 = split_260[33] + getitem_4719 = split_260[34] + getitem_4720 = split_260[35] + getitem_4721 = split_260[36] + getitem_4722 = split_260[37] + getitem_4723 = split_260[38] + getitem_4724 = split_260[39] + getitem_4725 = split_260[40] + getitem_4726 = split_260[41] + getitem_4727 = split_260[42] + getitem_4728 = split_260[43] + getitem_4729 = split_260[44] + getitem_4730 = split_260[45] + getitem_4731 = split_260[46] + getitem_4732 = split_260[47] + getitem_4733 = split_260[48] + getitem_4734 = split_260[49] + getitem_4735 = split_260[50] + getitem_4736 = split_260[51] + getitem_4737 = split_260[52] + getitem_4738 = split_260[53] + getitem_4739 = split_260[54] + getitem_4740 = split_260[55] + getitem_4741 = split_260[56] + getitem_4742 = split_260[57] + getitem_4743 = split_260[58] + getitem_4744 = split_260[59] + getitem_4745 = split_260[60] + getitem_4746 = split_260[61] + getitem_4747 = split_260[62] + getitem_4748 = split_260[63]; split_260 = None + cat_247 = torch.ops.aten.cat.default([getitem_4685, getitem_4686, getitem_4687, getitem_4688, getitem_4689, getitem_4690, getitem_4691, getitem_4692, getitem_4693, getitem_4694, getitem_4695, getitem_4696, getitem_4697, getitem_4698, getitem_4699, getitem_4700, getitem_4701, getitem_4702, getitem_4703, getitem_4704, getitem_4705, getitem_4706, getitem_4707, getitem_4708, getitem_4709, getitem_4710, getitem_4711, getitem_4712, getitem_4713, getitem_4714, getitem_4715, getitem_4716, getitem_4717, getitem_4718, getitem_4719, getitem_4720, getitem_4721, getitem_4722, getitem_4723, getitem_4724, getitem_4725, getitem_4726, getitem_4727, getitem_4728, getitem_4729, getitem_4730, getitem_4731, getitem_4732, getitem_4733, getitem_4734, getitem_4735, getitem_4736, getitem_4737, getitem_4738, getitem_4739, getitem_4740, getitem_4741, getitem_4742, getitem_4743, getitem_4744, getitem_4745, getitem_4746, getitem_4747, getitem_4748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_4685 = getitem_4686 = getitem_4687 = getitem_4688 = getitem_4689 = getitem_4690 = getitem_4691 = getitem_4692 = getitem_4693 = getitem_4694 = getitem_4695 = getitem_4696 = getitem_4697 = getitem_4698 = getitem_4699 = getitem_4700 = getitem_4701 = getitem_4702 = getitem_4703 = getitem_4704 = getitem_4705 = getitem_4706 = getitem_4707 = getitem_4708 = getitem_4709 = getitem_4710 = getitem_4711 = getitem_4712 = getitem_4713 = getitem_4714 = getitem_4715 = getitem_4716 = getitem_4717 = getitem_4718 = getitem_4719 = getitem_4720 = getitem_4721 = getitem_4722 = getitem_4723 = getitem_4724 = getitem_4725 = getitem_4726 = getitem_4727 = getitem_4728 = getitem_4729 = getitem_4730 = getitem_4731 = getitem_4732 = getitem_4733 = getitem_4734 = getitem_4735 = getitem_4736 = getitem_4737 = getitem_4738 = getitem_4739 = getitem_4740 = getitem_4741 = getitem_4742 = getitem_4743 = getitem_4744 = getitem_4745 = getitem_4746 = getitem_4747 = getitem_4748 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_247, 'avg', 128, '0'); cat_247 = None + wait_tensor_585 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); reduce_scatter_tensor_22 = None + view_1804 = torch.ops.aten.view.default(add_1801, [2, 4096, 2048]); add_1801 = None + convert_element_type_1564 = torch.ops.prims.convert_element_type.default(view_1804, torch.float32); view_1804 = None + convert_element_type_1355 = torch.ops.prims.convert_element_type.default(primals_413, torch.bfloat16); primals_413 = None + all_gather_into_tensor_425 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1355, 128, '0'); convert_element_type_1355 = None + wait_tensor_521 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_425); all_gather_into_tensor_425 = None + convert_element_type_1566 = torch.ops.prims.convert_element_type.default(wait_tensor_521, torch.float32); wait_tensor_521 = None + mul_1339 = torch.ops.aten.mul.Tensor(convert_element_type_1564, convert_element_type_1566); convert_element_type_1566 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(add_1640, torch.float32); add_1640 = None + mul_1191 = torch.ops.aten.mul.Tensor(convert_element_type_1356, rsqrt_77); convert_element_type_1356 = None + mul_1341 = torch.ops.aten.mul.Tensor(mul_1191, mul_1339) + sum_116 = torch.ops.aten.sum.dim_IntList(mul_1341, [2], True); mul_1341 = None + div_141 = torch.ops.aten.div.Tensor(mul_1191, 2048) + mul_1342 = torch.ops.aten.mul.Tensor(div_141, sum_116); div_141 = sum_116 = None + sub_634 = torch.ops.aten.sub.Tensor(mul_1339, mul_1342); mul_1339 = mul_1342 = None + mul_1343 = torch.ops.aten.mul.Tensor(sub_634, rsqrt_77); sub_634 = rsqrt_77 = None + mul_1344 = torch.ops.aten.mul.Tensor(convert_element_type_1564, mul_1191); convert_element_type_1564 = mul_1191 = None + sum_117 = torch.ops.aten.sum.dim_IntList(mul_1344, [0, 1]); mul_1344 = None + convert_element_type_1567 = torch.ops.prims.convert_element_type.default(mul_1343, torch.bfloat16); mul_1343 = None + add_1802 = torch.ops.aten.add.Tensor(add_1789, convert_element_type_1567); add_1789 = convert_element_type_1567 = None + convert_element_type_default_78 = torch.ops.prims.convert_element_type.default(sum_117, torch.float32); sum_117 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_78, 'avg', 128, '0'); convert_element_type_default_78 = None + wait_tensor_586 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = None + view_1805 = torch.ops.aten.view.default(add_1802, [8192, 2048]) + permute_486 = torch.ops.aten.permute.default(view_1805, [1, 0]) + permute_377 = torch.ops.aten.permute.default(getitem_2655, [0, 2, 1, 3]) + view_1661 = torch.ops.aten.view.default(permute_377, [2, 4096, -1]); permute_377 = None + view_1663 = torch.ops.aten.view.default(view_1661, [8192, 2048]); view_1661 = None + mm_242 = torch.ops.aten.mm.default(permute_486, view_1663); permute_486 = view_1663 = None + convert_element_type_1352 = torch.ops.prims.convert_element_type.default(primals_412, torch.bfloat16); primals_412 = None + all_gather_into_tensor_424 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1352, 128, '0'); convert_element_type_1352 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_424); all_gather_into_tensor_424 = None + permute_378 = torch.ops.aten.permute.default(wait_tensor_520, [1, 0]); wait_tensor_520 = None + permute_488 = torch.ops.aten.permute.default(permute_378, [1, 0]); permute_378 = None + mm_243 = torch.ops.aten.mm.default(view_1805, permute_488); view_1805 = permute_488 = None + view_1806 = torch.ops.aten.view.default(mm_243, [2, 4096, 2048]); mm_243 = None + convert_element_type_1574 = torch.ops.prims.convert_element_type.default(mm_242, torch.float32); mm_242 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1574, 'avg', 128, '0'); convert_element_type_1574 = None + wait_tensor_587 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + view_1807 = torch.ops.aten.view.default(view_1806, [2, 4096, 16, 128]); view_1806 = None + permute_490 = torch.ops.aten.permute.default(view_1807, [0, 2, 1, 3]); view_1807 = None + fw_graph1 = self.fw_graph1 + joint_graph1 = self.joint_graph1 + mask_graph1 = self.mask_graph1 + flex_attention_backward_1 = torch.ops.higher_order.flex_attention_backward(permute_374, permute_375, permute_376, getitem_2655, getitem_2656, permute_490, None, fw_graph1, joint_graph1, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph1), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_374 = permute_375 = permute_376 = getitem_2655 = getitem_2656 = permute_490 = fw_graph1 = joint_graph1 = mask_graph1 = None + getitem_4749 = flex_attention_backward_1[0] + getitem_4750 = flex_attention_backward_1[1] + getitem_4751 = flex_attention_backward_1[2]; flex_attention_backward_1 = None + permute_491 = torch.ops.aten.permute.default(getitem_4751, [0, 2, 1, 3]); getitem_4751 = None + permute_492 = torch.ops.aten.permute.default(getitem_4750, [0, 2, 1, 3]); getitem_4750 = None + permute_493 = torch.ops.aten.permute.default(getitem_4749, [0, 2, 1, 3]); getitem_4749 = None + slice_170 = torch.ops.aten.slice.Tensor(permute_492, 3, 0, 128) + slice_171 = torch.ops.aten.slice.Tensor(permute_492, 3, 128, 192); permute_492 = None + sum_118 = torch.ops.aten.sum.dim_IntList(slice_171, [2], True); slice_171 = None + cat_248 = torch.ops.aten.cat.default([slice_170, permute_491], 3); slice_170 = permute_491 = None + view_1808 = torch.ops.aten.view.default(cat_248, [2, 4096, 4096]); cat_248 = None + view_1809 = torch.ops.aten.view.default(view_1808, [8192, 4096]); view_1808 = None + permute_494 = torch.ops.aten.permute.default(view_1809, [1, 0]) + mm_244 = torch.ops.aten.mm.default(permute_494, view_1658); permute_494 = view_1658 = None + convert_element_type_1349 = torch.ops.prims.convert_element_type.default(primals_411, torch.bfloat16); primals_411 = None + all_gather_into_tensor_423 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1349, 128, '0'); convert_element_type_1349 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_423); all_gather_into_tensor_423 = None + permute_373 = torch.ops.aten.permute.default(wait_tensor_519, [1, 0]); wait_tensor_519 = None + permute_496 = torch.ops.aten.permute.default(permute_373, [1, 0]); permute_373 = None + mm_245 = torch.ops.aten.mm.default(view_1809, permute_496); view_1809 = permute_496 = None + view_1810 = torch.ops.aten.view.default(mm_245, [2, 4096, 512]); mm_245 = None + convert_element_type_1579 = torch.ops.prims.convert_element_type.default(mm_244, torch.float32); mm_244 = None + reduce_scatter_tensor_25 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1579, 'avg', 128, '0'); convert_element_type_1579 = None + wait_tensor_588 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + convert_element_type_1580 = torch.ops.prims.convert_element_type.default(view_1810, torch.float32); view_1810 = None + convert_element_type_1346 = torch.ops.prims.convert_element_type.default(primals_410, torch.bfloat16); primals_410 = None + all_gather_into_tensor_422 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1346, 128, '0'); convert_element_type_1346 = None + wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_422); all_gather_into_tensor_422 = None + convert_element_type_1582 = torch.ops.prims.convert_element_type.default(wait_tensor_518, torch.float32); wait_tensor_518 = None + mul_1345 = torch.ops.aten.mul.Tensor(convert_element_type_1580, convert_element_type_1582); convert_element_type_1582 = None + convert_element_type_1347 = torch.ops.prims.convert_element_type.default(getitem_2651, torch.float32); getitem_2651 = None + mul_1189 = torch.ops.aten.mul.Tensor(convert_element_type_1347, rsqrt_76); convert_element_type_1347 = None + mul_1347 = torch.ops.aten.mul.Tensor(mul_1189, mul_1345) + sum_119 = torch.ops.aten.sum.dim_IntList(mul_1347, [2], True); mul_1347 = None + div_142 = torch.ops.aten.div.Tensor(mul_1189, 512) + mul_1348 = torch.ops.aten.mul.Tensor(div_142, sum_119); div_142 = sum_119 = None + sub_635 = torch.ops.aten.sub.Tensor(mul_1345, mul_1348); mul_1345 = mul_1348 = None + mul_1349 = torch.ops.aten.mul.Tensor(sub_635, rsqrt_76); sub_635 = rsqrt_76 = None + mul_1350 = torch.ops.aten.mul.Tensor(convert_element_type_1580, mul_1189); convert_element_type_1580 = mul_1189 = None + sum_120 = torch.ops.aten.sum.dim_IntList(mul_1350, [0, 1]); mul_1350 = None + convert_element_type_1583 = torch.ops.prims.convert_element_type.default(mul_1349, torch.bfloat16); mul_1349 = None + convert_element_type_default_77 = torch.ops.prims.convert_element_type.default(sum_120, torch.float32); sum_120 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_77, 'avg', 128, '0'); convert_element_type_default_77 = None + wait_tensor_589 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + convert_element_type_1586 = torch.ops.prims.convert_element_type.default(sum_118, torch.float32); sum_118 = None + view_1811 = torch.ops.aten.view.default(convert_element_type_1586, [2, 4096, 1, 32, 2]); convert_element_type_1586 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_1811); view_1811 = None + mul_1351 = torch.ops.aten.mul.Tensor(view_as_complex_56, clone_9); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_1351); mul_1351 = None + view_1812 = torch.ops.aten.view.default(view_as_real_56, [2, 4096, 1, 64]); view_as_real_56 = None + convert_element_type_1587 = torch.ops.prims.convert_element_type.default(view_1812, torch.bfloat16); view_1812 = None + squeeze_27 = torch.ops.aten.squeeze.dim(convert_element_type_1587, 2); convert_element_type_1587 = None + cat_249 = torch.ops.aten.cat.default([convert_element_type_1583, squeeze_27], 2); convert_element_type_1583 = squeeze_27 = None + view_1813 = torch.ops.aten.view.default(cat_249, [8192, 576]); cat_249 = None + permute_498 = torch.ops.aten.permute.default(view_1813, [1, 0]) + mm_246 = torch.ops.aten.mm.default(permute_498, view_1644); permute_498 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(primals_409, torch.bfloat16); primals_409 = None + all_gather_into_tensor_421 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1341, 128, '0'); convert_element_type_1341 = None + wait_tensor_517 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_421); all_gather_into_tensor_421 = None + slice_151 = torch.ops.aten.slice.Tensor(wait_tensor_517, 0, 0, 576); wait_tensor_517 = None + permute_372 = torch.ops.aten.permute.default(slice_151, [1, 0]); slice_151 = None + permute_500 = torch.ops.aten.permute.default(permute_372, [1, 0]); permute_372 = None + mm_247 = torch.ops.aten.mm.default(view_1813, permute_500); view_1813 = permute_500 = None + view_1814 = torch.ops.aten.view.default(mm_247, [2, 4096, 2048]); mm_247 = None + convert_element_type_1592 = torch.ops.prims.convert_element_type.default(mm_246, torch.float32); mm_246 = None + split_261 = torch.ops.aten.split.Tensor(convert_element_type_1592, 5); convert_element_type_1592 = None + getitem_4753 = split_261[0] + getitem_4754 = split_261[1] + getitem_4755 = split_261[2] + getitem_4756 = split_261[3] + getitem_4757 = split_261[4] + getitem_4758 = split_261[5] + getitem_4759 = split_261[6] + getitem_4760 = split_261[7] + getitem_4761 = split_261[8] + getitem_4762 = split_261[9] + getitem_4763 = split_261[10] + getitem_4764 = split_261[11] + getitem_4765 = split_261[12] + getitem_4766 = split_261[13] + getitem_4767 = split_261[14] + getitem_4768 = split_261[15] + getitem_4769 = split_261[16] + getitem_4770 = split_261[17] + getitem_4771 = split_261[18] + getitem_4772 = split_261[19] + getitem_4773 = split_261[20] + getitem_4774 = split_261[21] + getitem_4775 = split_261[22] + getitem_4776 = split_261[23] + getitem_4777 = split_261[24] + getitem_4778 = split_261[25] + getitem_4779 = split_261[26] + getitem_4780 = split_261[27] + getitem_4781 = split_261[28] + getitem_4782 = split_261[29] + getitem_4783 = split_261[30] + getitem_4784 = split_261[31] + getitem_4785 = split_261[32] + getitem_4786 = split_261[33] + getitem_4787 = split_261[34] + getitem_4788 = split_261[35] + getitem_4789 = split_261[36] + getitem_4790 = split_261[37] + getitem_4791 = split_261[38] + getitem_4792 = split_261[39] + getitem_4793 = split_261[40] + getitem_4794 = split_261[41] + getitem_4795 = split_261[42] + getitem_4796 = split_261[43] + getitem_4797 = split_261[44] + getitem_4798 = split_261[45] + getitem_4799 = split_261[46] + getitem_4800 = split_261[47] + getitem_4801 = split_261[48] + getitem_4802 = split_261[49] + getitem_4803 = split_261[50] + getitem_4804 = split_261[51] + getitem_4805 = split_261[52] + getitem_4806 = split_261[53] + getitem_4807 = split_261[54] + getitem_4808 = split_261[55] + getitem_4809 = split_261[56] + getitem_4810 = split_261[57] + getitem_4811 = split_261[58] + getitem_4812 = split_261[59] + getitem_4813 = split_261[60] + getitem_4814 = split_261[61] + getitem_4815 = split_261[62] + getitem_4816 = split_261[63] + getitem_4817 = split_261[64] + getitem_4818 = split_261[65] + getitem_4819 = split_261[66] + getitem_4820 = split_261[67] + getitem_4821 = split_261[68] + getitem_4822 = split_261[69] + getitem_4823 = split_261[70] + getitem_4824 = split_261[71] + getitem_4825 = split_261[72] + getitem_4826 = split_261[73] + getitem_4827 = split_261[74] + getitem_4828 = split_261[75] + getitem_4829 = split_261[76] + getitem_4830 = split_261[77] + getitem_4831 = split_261[78] + getitem_4832 = split_261[79] + getitem_4833 = split_261[80] + getitem_4834 = split_261[81] + getitem_4835 = split_261[82] + getitem_4836 = split_261[83] + getitem_4837 = split_261[84] + getitem_4838 = split_261[85] + getitem_4839 = split_261[86] + getitem_4840 = split_261[87] + getitem_4841 = split_261[88] + getitem_4842 = split_261[89] + getitem_4843 = split_261[90] + getitem_4844 = split_261[91] + getitem_4845 = split_261[92] + getitem_4846 = split_261[93] + getitem_4847 = split_261[94] + getitem_4848 = split_261[95] + getitem_4849 = split_261[96] + getitem_4850 = split_261[97] + getitem_4851 = split_261[98] + getitem_4852 = split_261[99] + getitem_4853 = split_261[100] + getitem_4854 = split_261[101] + getitem_4855 = split_261[102] + getitem_4856 = split_261[103] + getitem_4857 = split_261[104] + getitem_4858 = split_261[105] + getitem_4859 = split_261[106] + getitem_4860 = split_261[107] + getitem_4861 = split_261[108] + getitem_4862 = split_261[109] + getitem_4863 = split_261[110] + getitem_4864 = split_261[111] + getitem_4865 = split_261[112] + getitem_4866 = split_261[113] + getitem_4867 = split_261[114] + getitem_4868 = split_261[115]; split_261 = None + constant_pad_nd_141 = torch.ops.aten.constant_pad_nd.default(getitem_4868, [0, 0, 0, 4], 0.0); getitem_4868 = None + cat_250 = torch.ops.aten.cat.default([getitem_4753, getitem_4754, getitem_4755, getitem_4756, getitem_4757, getitem_4758, getitem_4759, getitem_4760, getitem_4761, getitem_4762, getitem_4763, getitem_4764, getitem_4765, getitem_4766, getitem_4767, getitem_4768, getitem_4769, getitem_4770, getitem_4771, getitem_4772, getitem_4773, getitem_4774, getitem_4775, getitem_4776, getitem_4777, getitem_4778, getitem_4779, getitem_4780, getitem_4781, getitem_4782, getitem_4783, getitem_4784, getitem_4785, getitem_4786, getitem_4787, getitem_4788, getitem_4789, getitem_4790, getitem_4791, getitem_4792, getitem_4793, getitem_4794, getitem_4795, getitem_4796, getitem_4797, getitem_4798, getitem_4799, getitem_4800, getitem_4801, getitem_4802, getitem_4803, getitem_4804, getitem_4805, getitem_4806, getitem_4807, getitem_4808, getitem_4809, getitem_4810, getitem_4811, getitem_4812, getitem_4813, getitem_4814, getitem_4815, getitem_4816, getitem_4817, getitem_4818, getitem_4819, getitem_4820, getitem_4821, getitem_4822, getitem_4823, getitem_4824, getitem_4825, getitem_4826, getitem_4827, getitem_4828, getitem_4829, getitem_4830, getitem_4831, getitem_4832, getitem_4833, getitem_4834, getitem_4835, getitem_4836, getitem_4837, getitem_4838, getitem_4839, getitem_4840, getitem_4841, getitem_4842, getitem_4843, getitem_4844, getitem_4845, getitem_4846, getitem_4847, getitem_4848, getitem_4849, getitem_4850, getitem_4851, getitem_4852, getitem_4853, getitem_4854, getitem_4855, getitem_4856, getitem_4857, getitem_4858, getitem_4859, getitem_4860, getitem_4861, getitem_4862, getitem_4863, getitem_4864, getitem_4865, getitem_4866, getitem_4867, constant_pad_nd_141, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_4753 = getitem_4754 = getitem_4755 = getitem_4756 = getitem_4757 = getitem_4758 = getitem_4759 = getitem_4760 = getitem_4761 = getitem_4762 = getitem_4763 = getitem_4764 = getitem_4765 = getitem_4766 = getitem_4767 = getitem_4768 = getitem_4769 = getitem_4770 = getitem_4771 = getitem_4772 = getitem_4773 = getitem_4774 = getitem_4775 = getitem_4776 = getitem_4777 = getitem_4778 = getitem_4779 = getitem_4780 = getitem_4781 = getitem_4782 = getitem_4783 = getitem_4784 = getitem_4785 = getitem_4786 = getitem_4787 = getitem_4788 = getitem_4789 = getitem_4790 = getitem_4791 = getitem_4792 = getitem_4793 = getitem_4794 = getitem_4795 = getitem_4796 = getitem_4797 = getitem_4798 = getitem_4799 = getitem_4800 = getitem_4801 = getitem_4802 = getitem_4803 = getitem_4804 = getitem_4805 = getitem_4806 = getitem_4807 = getitem_4808 = getitem_4809 = getitem_4810 = getitem_4811 = getitem_4812 = getitem_4813 = getitem_4814 = getitem_4815 = getitem_4816 = getitem_4817 = getitem_4818 = getitem_4819 = getitem_4820 = getitem_4821 = getitem_4822 = getitem_4823 = getitem_4824 = getitem_4825 = getitem_4826 = getitem_4827 = getitem_4828 = getitem_4829 = getitem_4830 = getitem_4831 = getitem_4832 = getitem_4833 = getitem_4834 = getitem_4835 = getitem_4836 = getitem_4837 = getitem_4838 = getitem_4839 = getitem_4840 = getitem_4841 = getitem_4842 = getitem_4843 = getitem_4844 = getitem_4845 = getitem_4846 = getitem_4847 = getitem_4848 = getitem_4849 = getitem_4850 = getitem_4851 = getitem_4852 = getitem_4853 = getitem_4854 = getitem_4855 = getitem_4856 = getitem_4857 = getitem_4858 = getitem_4859 = getitem_4860 = getitem_4861 = getitem_4862 = getitem_4863 = getitem_4864 = getitem_4865 = getitem_4866 = getitem_4867 = constant_pad_nd_141 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_250, 'avg', 128, '0'); cat_250 = None + wait_tensor_590 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + slice_172 = torch.ops.aten.slice.Tensor(permute_493, 3, 0, 128) + slice_173 = torch.ops.aten.slice.Tensor(permute_493, 3, 128, 192); permute_493 = None + convert_element_type_1593 = torch.ops.prims.convert_element_type.default(slice_173, torch.float32); slice_173 = None + view_1815 = torch.ops.aten.view.default(convert_element_type_1593, [2, 4096, 16, 32, 2]); convert_element_type_1593 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_1815); view_1815 = None + mul_1352 = torch.ops.aten.mul.Tensor(view_as_complex_57, clone_9); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_1352); mul_1352 = None + view_1816 = torch.ops.aten.view.default(view_as_real_57, [2, 4096, 16, 64]); view_as_real_57 = None + convert_element_type_1594 = torch.ops.prims.convert_element_type.default(view_1816, torch.bfloat16); view_1816 = None + cat_251 = torch.ops.aten.cat.default([slice_172, convert_element_type_1594], 3); slice_172 = convert_element_type_1594 = None + view_1817 = torch.ops.aten.view.default(cat_251, [2, 4096, 3072]); cat_251 = None + view_1818 = torch.ops.aten.view.default(view_1817, [8192, 3072]); view_1817 = None + permute_502 = torch.ops.aten.permute.default(view_1818, [1, 0]) + mm_248 = torch.ops.aten.mm.default(permute_502, view_1644); permute_502 = view_1644 = None + convert_element_type_1336 = torch.ops.prims.convert_element_type.default(primals_408, torch.bfloat16); primals_408 = None + all_gather_into_tensor_420 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1336, 128, '0'); convert_element_type_1336 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_420); all_gather_into_tensor_420 = None + permute_371 = torch.ops.aten.permute.default(wait_tensor_516, [1, 0]); wait_tensor_516 = None + permute_504 = torch.ops.aten.permute.default(permute_371, [1, 0]); permute_371 = None + mm_249 = torch.ops.aten.mm.default(view_1818, permute_504); view_1818 = permute_504 = None + view_1819 = torch.ops.aten.view.default(mm_249, [2, 4096, 2048]); mm_249 = None + add_1803 = torch.ops.aten.add.Tensor(view_1814, view_1819); view_1814 = view_1819 = None + convert_element_type_1599 = torch.ops.prims.convert_element_type.default(mm_248, torch.float32); mm_248 = None + reduce_scatter_tensor_28 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1599, 'avg', 128, '0'); convert_element_type_1599 = None + wait_tensor_591 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + convert_element_type_1600 = torch.ops.prims.convert_element_type.default(add_1803, torch.float32); add_1803 = None + convert_element_type_1333 = torch.ops.prims.convert_element_type.default(primals_407, torch.bfloat16); primals_407 = None + all_gather_into_tensor_419 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1333, 128, '0'); convert_element_type_1333 = None + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_419); all_gather_into_tensor_419 = None + convert_element_type_1602 = torch.ops.prims.convert_element_type.default(wait_tensor_515, torch.float32); wait_tensor_515 = None + mul_1353 = torch.ops.aten.mul.Tensor(convert_element_type_1600, convert_element_type_1602); convert_element_type_1602 = None + convert_element_type_1334 = torch.ops.prims.convert_element_type.default(add_1637, torch.float32); add_1637 = None + mul_1185 = torch.ops.aten.mul.Tensor(convert_element_type_1334, rsqrt_75); convert_element_type_1334 = None + mul_1355 = torch.ops.aten.mul.Tensor(mul_1185, mul_1353) + sum_121 = torch.ops.aten.sum.dim_IntList(mul_1355, [2], True); mul_1355 = None + div_143 = torch.ops.aten.div.Tensor(mul_1185, 2048) + mul_1356 = torch.ops.aten.mul.Tensor(div_143, sum_121); div_143 = sum_121 = None + sub_636 = torch.ops.aten.sub.Tensor(mul_1353, mul_1356); mul_1353 = mul_1356 = None + mul_1357 = torch.ops.aten.mul.Tensor(sub_636, rsqrt_75); sub_636 = rsqrt_75 = None + mul_1358 = torch.ops.aten.mul.Tensor(convert_element_type_1600, mul_1185); convert_element_type_1600 = mul_1185 = None + sum_122 = torch.ops.aten.sum.dim_IntList(mul_1358, [0, 1]); mul_1358 = None + convert_element_type_1603 = torch.ops.prims.convert_element_type.default(mul_1357, torch.bfloat16); mul_1357 = None + add_1804 = torch.ops.aten.add.Tensor(add_1802, convert_element_type_1603); add_1802 = convert_element_type_1603 = None + convert_element_type_default_76 = torch.ops.prims.convert_element_type.default(sum_122, torch.float32); sum_122 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_76, 'avg', 128, '0'); convert_element_type_default_76 = None + wait_tensor_592 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + view_1820 = torch.ops.aten.view.default(add_1804, [8192, 2048]) + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_1820, 1) + convert_element_type_1606 = torch.ops.prims.convert_element_type.default(unsqueeze_55, torch.float32); unsqueeze_55 = None + bmm_30 = torch.ops.aten.bmm.default(permute_506, convert_element_type_1606); permute_506 = None + bmm_31 = torch.ops.aten.bmm.default(convert_element_type_1606, permute_507); convert_element_type_1606 = permute_507 = None + convert_element_type_1607 = torch.ops.prims.convert_element_type.default(bmm_30, torch.bfloat16); bmm_30 = None + view_1821 = torch.ops.aten.view.default(bmm_31, [8192, 6]); bmm_31 = None + view_1822 = torch.ops.aten.view.default(convert_element_type_1607, [49152, 2048]); convert_element_type_1607 = None + index_56 = torch.ops.aten.index.Tensor(view_1822, [getitem_2551]); view_1822 = getitem_2551 = None + permute_508 = torch.ops.aten.permute.default(view_1820, [1, 0]) + mm_250 = torch.ops.aten.mm.default(permute_508, mul_1182); permute_508 = mul_1182 = None + convert_element_type_1328 = torch.ops.prims.convert_element_type.default(primals_406, torch.bfloat16); primals_406 = None + all_gather_into_tensor_418 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1328, 128, '0'); convert_element_type_1328 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_418); all_gather_into_tensor_418 = None + permute_370 = torch.ops.aten.permute.default(wait_tensor_514, [1, 0]); wait_tensor_514 = None + permute_510 = torch.ops.aten.permute.default(permute_370, [1, 0]); permute_370 = None + mm_251 = torch.ops.aten.mm.default(view_1820, permute_510); view_1820 = permute_510 = None + convert_element_type_1612 = torch.ops.prims.convert_element_type.default(mm_250, torch.float32); mm_250 = None + reduce_scatter_tensor_30 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1612, 'avg', 128, '0'); convert_element_type_1612 = None + wait_tensor_593 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + convert_element_type_1323 = torch.ops.prims.convert_element_type.default(mm_196, torch.float32); mm_196 = None + neg_48 = torch.ops.aten.neg.default(convert_element_type_1323) + exp_72 = torch.ops.aten.exp.default(neg_48); neg_48 = None + add_1632 = torch.ops.aten.add.Tensor(exp_72, 1); exp_72 = None + div_120 = torch.ops.aten.div.Tensor(convert_element_type_1323, add_1632) + convert_element_type_1324 = torch.ops.prims.convert_element_type.default(div_120, torch.bfloat16); div_120 = None + mul_1359 = torch.ops.aten.mul.Tensor(mm_251, convert_element_type_1324); convert_element_type_1324 = None + mul_1360 = torch.ops.aten.mul.Tensor(mm_251, mm_197); mm_251 = mm_197 = None + permute_512 = torch.ops.aten.permute.default(mul_1359, [1, 0]) + mm_252 = torch.ops.aten.mm.default(permute_512, view_1599); permute_512 = None + convert_element_type_1325 = torch.ops.prims.convert_element_type.default(primals_405, torch.bfloat16); primals_405 = None + all_gather_into_tensor_417 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1325, 128, '0'); convert_element_type_1325 = None + wait_tensor_513 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_417); all_gather_into_tensor_417 = None + permute_369 = torch.ops.aten.permute.default(wait_tensor_513, [1, 0]); wait_tensor_513 = None + permute_514 = torch.ops.aten.permute.default(permute_369, [1, 0]); permute_369 = None + mm_253 = torch.ops.aten.mm.default(mul_1359, permute_514); mul_1359 = permute_514 = None + convert_element_type_1617 = torch.ops.prims.convert_element_type.default(mm_252, torch.float32); mm_252 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1617, 'avg', 128, '0'); convert_element_type_1617 = None + wait_tensor_594 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + convert_element_type_1618 = torch.ops.prims.convert_element_type.default(mul_1360, torch.float32); mul_1360 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_1632); add_1632 = None + mul_1361 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_1362 = torch.ops.aten.mul.Tensor(convert_element_type_1618, mul_1361); convert_element_type_1618 = None + sub_637 = torch.ops.aten.sub.Tensor(1, mul_1361); mul_1361 = None + mul_1363 = torch.ops.aten.mul.Tensor(convert_element_type_1323, sub_637); convert_element_type_1323 = sub_637 = None + add_1806 = torch.ops.aten.add.Tensor(mul_1363, 1); mul_1363 = None + mul_1364 = torch.ops.aten.mul.Tensor(mul_1362, add_1806); mul_1362 = add_1806 = None + convert_element_type_1620 = torch.ops.prims.convert_element_type.default(mul_1364, torch.bfloat16); mul_1364 = None + permute_516 = torch.ops.aten.permute.default(convert_element_type_1620, [1, 0]) + mm_254 = torch.ops.aten.mm.default(permute_516, view_1599); permute_516 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(primals_404, torch.bfloat16); primals_404 = None + all_gather_into_tensor_416 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1320, 128, '0'); convert_element_type_1320 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_416); all_gather_into_tensor_416 = None + permute_368 = torch.ops.aten.permute.default(wait_tensor_512, [1, 0]); wait_tensor_512 = None + permute_518 = torch.ops.aten.permute.default(permute_368, [1, 0]); permute_368 = None + mm_255 = torch.ops.aten.mm.default(convert_element_type_1620, permute_518); convert_element_type_1620 = permute_518 = None + add_1807 = torch.ops.aten.add.Tensor(mm_253, mm_255); mm_253 = mm_255 = None + convert_element_type_1625 = torch.ops.prims.convert_element_type.default(mm_254, torch.float32); mm_254 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1625, 'avg', 128, '0'); convert_element_type_1625 = None + wait_tensor_595 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + all_to_all_single_82 = torch.ops._c10d_functional.all_to_all_single.default(index_56, [_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383], [_local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375], '1033'); index_56 = None + wait_tensor_596 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_82); all_to_all_single_82 = None + full_360 = torch.ops.aten.full.default([sym_size_int_93, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_93 = None + slice_scatter_2 = torch.ops.aten.slice_scatter.default(full_360, wait_tensor_596, 0, 0, -1); wait_tensor_596 = None + index_57 = torch.ops.aten.index.Tensor(slice_scatter_2, [getitem_2552]); slice_scatter_2 = None + permute_520 = torch.ops.aten.permute.default(index_57, [1, 0]) + _grouped_mm_90 = torch.ops.aten._grouped_mm.default(permute_520, mul_1162, cumsum_71); permute_520 = mul_1162 = None + _grouped_mm_91 = torch.ops.aten._grouped_mm.default(index_57, permute_522, cumsum_71); index_57 = permute_522 = None + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(_grouped_mm_69, torch.float32); _grouped_mm_69 = None + neg_47 = torch.ops.aten.neg.default(convert_element_type_1318) + exp_71 = torch.ops.aten.exp.default(neg_47); neg_47 = None + add_1596 = torch.ops.aten.add.Tensor(exp_71, 1); exp_71 = None + div_119 = torch.ops.aten.div.Tensor(convert_element_type_1318, add_1596) + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(div_119, torch.bfloat16); div_119 = None + mul_1365 = torch.ops.aten.mul.Tensor(_grouped_mm_91, convert_element_type_1319); convert_element_type_1319 = None + mul_1366 = torch.ops.aten.mul.Tensor(_grouped_mm_91, _grouped_mm_70); _grouped_mm_91 = _grouped_mm_70 = None + permute_524 = torch.ops.aten.permute.default(mul_1365, [1, 0]) + _grouped_mm_92 = torch.ops.aten._grouped_mm.default(permute_524, index_47, cumsum_71); permute_524 = None + _grouped_mm_93 = torch.ops.aten._grouped_mm.default(mul_1365, permute_526, cumsum_71); mul_1365 = permute_526 = None + convert_element_type_1626 = torch.ops.prims.convert_element_type.default(mul_1366, torch.float32); mul_1366 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_1596); add_1596 = None + mul_1367 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_1368 = torch.ops.aten.mul.Tensor(convert_element_type_1626, mul_1367); convert_element_type_1626 = None + sub_638 = torch.ops.aten.sub.Tensor(1, mul_1367); mul_1367 = None + mul_1369 = torch.ops.aten.mul.Tensor(convert_element_type_1318, sub_638); convert_element_type_1318 = sub_638 = None + add_1809 = torch.ops.aten.add.Tensor(mul_1369, 1); mul_1369 = None + mul_1370 = torch.ops.aten.mul.Tensor(mul_1368, add_1809); mul_1368 = add_1809 = None + convert_element_type_1628 = torch.ops.prims.convert_element_type.default(mul_1370, torch.bfloat16); mul_1370 = None + permute_528 = torch.ops.aten.permute.default(convert_element_type_1628, [1, 0]) + _grouped_mm_94 = torch.ops.aten._grouped_mm.default(permute_528, index_47, cumsum_71); permute_528 = index_47 = None + _grouped_mm_95 = torch.ops.aten._grouped_mm.default(convert_element_type_1628, permute_530, cumsum_71); convert_element_type_1628 = permute_530 = cumsum_71 = None + add_1810 = torch.ops.aten.add.Tensor(_grouped_mm_93, _grouped_mm_95); _grouped_mm_93 = _grouped_mm_95 = None + convert_element_type_1629 = torch.ops.prims.convert_element_type.default(_grouped_mm_92, torch.float32); _grouped_mm_92 = None + div_144 = torch.ops.aten.div.Tensor(convert_element_type_1629, 128); convert_element_type_1629 = None + split_263 = torch.ops.aten.split.Tensor(div_144, 88, 1); div_144 = None + getitem_4885 = split_263[0] + getitem_4902 = split_263[1] + getitem_4919 = split_263[2] + getitem_4936 = split_263[3] + getitem_4953 = split_263[4] + getitem_4970 = split_263[5] + getitem_4987 = split_263[6] + getitem_5004 = split_263[7] + getitem_5021 = split_263[8] + getitem_5038 = split_263[9] + getitem_5055 = split_263[10] + getitem_5072 = split_263[11] + getitem_5089 = split_263[12] + getitem_5106 = split_263[13] + getitem_5123 = split_263[14] + getitem_5140 = split_263[15]; split_263 = None + cat_252 = torch.ops.aten.cat.default([getitem_4885, getitem_4902, getitem_4919, getitem_4936, getitem_4953, getitem_4970, getitem_4987, getitem_5004, getitem_5021, getitem_5038, getitem_5055, getitem_5072, getitem_5089, getitem_5106, getitem_5123, getitem_5140]); getitem_4885 = getitem_4902 = getitem_4919 = getitem_4936 = getitem_4953 = getitem_4970 = getitem_4987 = getitem_5004 = getitem_5021 = getitem_5038 = getitem_5055 = getitem_5072 = getitem_5089 = getitem_5106 = getitem_5123 = getitem_5140 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_252, 'sum', 16, '1025'); cat_252 = None + wait_tensor_597 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + convert_element_type_1630 = torch.ops.prims.convert_element_type.default(_grouped_mm_90, torch.float32); _grouped_mm_90 = None + div_145 = torch.ops.aten.div.Tensor(convert_element_type_1630, 128); convert_element_type_1630 = None + split_280 = torch.ops.aten.split.Tensor(div_145, 128, 1); div_145 = None + getitem_5157 = split_280[0] + getitem_5174 = split_280[1] + getitem_5191 = split_280[2] + getitem_5208 = split_280[3] + getitem_5225 = split_280[4] + getitem_5242 = split_280[5] + getitem_5259 = split_280[6] + getitem_5276 = split_280[7] + getitem_5293 = split_280[8] + getitem_5310 = split_280[9] + getitem_5327 = split_280[10] + getitem_5344 = split_280[11] + getitem_5361 = split_280[12] + getitem_5378 = split_280[13] + getitem_5395 = split_280[14] + getitem_5412 = split_280[15]; split_280 = None + cat_253 = torch.ops.aten.cat.default([getitem_5157, getitem_5174, getitem_5191, getitem_5208, getitem_5225, getitem_5242, getitem_5259, getitem_5276, getitem_5293, getitem_5310, getitem_5327, getitem_5344, getitem_5361, getitem_5378, getitem_5395, getitem_5412]); getitem_5157 = getitem_5174 = getitem_5191 = getitem_5208 = getitem_5225 = getitem_5242 = getitem_5259 = getitem_5276 = getitem_5293 = getitem_5310 = getitem_5327 = getitem_5344 = getitem_5361 = getitem_5378 = getitem_5395 = getitem_5412 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_253, 'sum', 16, '1025'); cat_253 = None + wait_tensor_598 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + convert_element_type_1631 = torch.ops.prims.convert_element_type.default(_grouped_mm_94, torch.float32); _grouped_mm_94 = None + div_146 = torch.ops.aten.div.Tensor(convert_element_type_1631, 128); convert_element_type_1631 = None + split_297 = torch.ops.aten.split.Tensor(div_146, 88, 1); div_146 = None + getitem_5429 = split_297[0] + getitem_5446 = split_297[1] + getitem_5463 = split_297[2] + getitem_5480 = split_297[3] + getitem_5497 = split_297[4] + getitem_5514 = split_297[5] + getitem_5531 = split_297[6] + getitem_5548 = split_297[7] + getitem_5565 = split_297[8] + getitem_5582 = split_297[9] + getitem_5599 = split_297[10] + getitem_5616 = split_297[11] + getitem_5633 = split_297[12] + getitem_5650 = split_297[13] + getitem_5667 = split_297[14] + getitem_5684 = split_297[15]; split_297 = None + cat_254 = torch.ops.aten.cat.default([getitem_5429, getitem_5446, getitem_5463, getitem_5480, getitem_5497, getitem_5514, getitem_5531, getitem_5548, getitem_5565, getitem_5582, getitem_5599, getitem_5616, getitem_5633, getitem_5650, getitem_5667, getitem_5684]); getitem_5429 = getitem_5446 = getitem_5463 = getitem_5480 = getitem_5497 = getitem_5514 = getitem_5531 = getitem_5548 = getitem_5565 = getitem_5582 = getitem_5599 = getitem_5616 = getitem_5633 = getitem_5650 = getitem_5667 = getitem_5684 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_254, 'sum', 16, '1025'); cat_254 = None + wait_tensor_599 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + index_put_56 = torch.ops.aten.index_put.default(full_360, [getitem_2552], add_1810, True); full_360 = getitem_2552 = add_1810 = None + slice_174 = torch.ops.aten.slice.Tensor(index_put_56, 0, 0, add_1811); index_put_56 = add_1811 = None + all_to_all_single_83 = torch.ops._c10d_functional.all_to_all_single.default(slice_174, [_local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375], [_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383], '1033'); slice_174 = _local_scalar_dense_368 = _local_scalar_dense_369 = _local_scalar_dense_370 = _local_scalar_dense_371 = _local_scalar_dense_372 = _local_scalar_dense_373 = _local_scalar_dense_374 = _local_scalar_dense_375 = _local_scalar_dense_376 = _local_scalar_dense_377 = _local_scalar_dense_378 = _local_scalar_dense_379 = _local_scalar_dense_380 = _local_scalar_dense_381 = _local_scalar_dense_382 = _local_scalar_dense_383 = None + wait_tensor_600 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_83); all_to_all_single_83 = None + index_put_57 = torch.ops.aten.index_put.default(full_default_52, [div_117], wait_tensor_600, True); div_117 = wait_tensor_600 = None + add_1815 = torch.ops.aten.add.Tensor(add_1807, index_put_57); add_1807 = index_put_57 = None + mul_1371 = torch.ops.aten.mul.Tensor(view_1821, 1.0); view_1821 = None + scatter_add_2 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_2549, mul_1371); getitem_2549 = mul_1371 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_195, torch.float32); mm_195 = None + sub_552 = torch.ops.aten.sub.Tensor(convert_element_type_1307, amax_23); convert_element_type_1307 = amax_23 = None + exp_70 = torch.ops.aten.exp.default(sub_552); sub_552 = None + div_116 = torch.ops.aten.div.Tensor(exp_70, sum_93); exp_70 = sum_93 = None + mul_1372 = torch.ops.aten.mul.Tensor(scatter_add_2, div_116); scatter_add_2 = None + sum_123 = torch.ops.aten.sum.dim_IntList(mul_1372, [1], True) + neg_61 = torch.ops.aten.neg.default(div_116); div_116 = None + fma_2 = torch.ops.prims.fma.default(neg_61, sum_123, mul_1372); neg_61 = sum_123 = mul_1372 = None + convert_element_type_1632 = torch.ops.prims.convert_element_type.default(fma_2, torch.bfloat16); fma_2 = None + permute_532 = torch.ops.aten.permute.default(convert_element_type_1632, [1, 0]) + mm_256 = torch.ops.aten.mm.default(permute_532, view_1599); permute_532 = view_1599 = None + convert_element_type_1304 = torch.ops.prims.convert_element_type.default(primals_399, torch.bfloat16); primals_399 = None + all_gather_into_tensor_409 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1304, 128, '0'); convert_element_type_1304 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_409); all_gather_into_tensor_409 = None + slice_147 = torch.ops.aten.slice.Tensor(wait_tensor_501, 0, 0, 64); wait_tensor_501 = None + permute_364 = torch.ops.aten.permute.default(slice_147, [1, 0]); slice_147 = None + permute_534 = torch.ops.aten.permute.default(permute_364, [1, 0]); permute_364 = None + mm_257 = torch.ops.aten.mm.default(convert_element_type_1632, permute_534); convert_element_type_1632 = permute_534 = None + add_1816 = torch.ops.aten.add.Tensor(add_1815, mm_257); add_1815 = mm_257 = None + convert_element_type_1637 = torch.ops.prims.convert_element_type.default(mm_256, torch.float32); mm_256 = None + split_313 = torch.ops.aten.split.Tensor(convert_element_type_1637, 1); convert_element_type_1637 = None + getitem_5685 = split_313[0] + getitem_5686 = split_313[1] + getitem_5687 = split_313[2] + getitem_5688 = split_313[3] + getitem_5689 = split_313[4] + getitem_5690 = split_313[5] + getitem_5691 = split_313[6] + getitem_5692 = split_313[7] + getitem_5693 = split_313[8] + getitem_5694 = split_313[9] + getitem_5695 = split_313[10] + getitem_5696 = split_313[11] + getitem_5697 = split_313[12] + getitem_5698 = split_313[13] + getitem_5699 = split_313[14] + getitem_5700 = split_313[15] + getitem_5701 = split_313[16] + getitem_5702 = split_313[17] + getitem_5703 = split_313[18] + getitem_5704 = split_313[19] + getitem_5705 = split_313[20] + getitem_5706 = split_313[21] + getitem_5707 = split_313[22] + getitem_5708 = split_313[23] + getitem_5709 = split_313[24] + getitem_5710 = split_313[25] + getitem_5711 = split_313[26] + getitem_5712 = split_313[27] + getitem_5713 = split_313[28] + getitem_5714 = split_313[29] + getitem_5715 = split_313[30] + getitem_5716 = split_313[31] + getitem_5717 = split_313[32] + getitem_5718 = split_313[33] + getitem_5719 = split_313[34] + getitem_5720 = split_313[35] + getitem_5721 = split_313[36] + getitem_5722 = split_313[37] + getitem_5723 = split_313[38] + getitem_5724 = split_313[39] + getitem_5725 = split_313[40] + getitem_5726 = split_313[41] + getitem_5727 = split_313[42] + getitem_5728 = split_313[43] + getitem_5729 = split_313[44] + getitem_5730 = split_313[45] + getitem_5731 = split_313[46] + getitem_5732 = split_313[47] + getitem_5733 = split_313[48] + getitem_5734 = split_313[49] + getitem_5735 = split_313[50] + getitem_5736 = split_313[51] + getitem_5737 = split_313[52] + getitem_5738 = split_313[53] + getitem_5739 = split_313[54] + getitem_5740 = split_313[55] + getitem_5741 = split_313[56] + getitem_5742 = split_313[57] + getitem_5743 = split_313[58] + getitem_5744 = split_313[59] + getitem_5745 = split_313[60] + getitem_5746 = split_313[61] + getitem_5747 = split_313[62] + getitem_5748 = split_313[63]; split_313 = None + cat_255 = torch.ops.aten.cat.default([getitem_5685, getitem_5686, getitem_5687, getitem_5688, getitem_5689, getitem_5690, getitem_5691, getitem_5692, getitem_5693, getitem_5694, getitem_5695, getitem_5696, getitem_5697, getitem_5698, getitem_5699, getitem_5700, getitem_5701, getitem_5702, getitem_5703, getitem_5704, getitem_5705, getitem_5706, getitem_5707, getitem_5708, getitem_5709, getitem_5710, getitem_5711, getitem_5712, getitem_5713, getitem_5714, getitem_5715, getitem_5716, getitem_5717, getitem_5718, getitem_5719, getitem_5720, getitem_5721, getitem_5722, getitem_5723, getitem_5724, getitem_5725, getitem_5726, getitem_5727, getitem_5728, getitem_5729, getitem_5730, getitem_5731, getitem_5732, getitem_5733, getitem_5734, getitem_5735, getitem_5736, getitem_5737, getitem_5738, getitem_5739, getitem_5740, getitem_5741, getitem_5742, getitem_5743, getitem_5744, getitem_5745, getitem_5746, getitem_5747, getitem_5748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_5685 = getitem_5686 = getitem_5687 = getitem_5688 = getitem_5689 = getitem_5690 = getitem_5691 = getitem_5692 = getitem_5693 = getitem_5694 = getitem_5695 = getitem_5696 = getitem_5697 = getitem_5698 = getitem_5699 = getitem_5700 = getitem_5701 = getitem_5702 = getitem_5703 = getitem_5704 = getitem_5705 = getitem_5706 = getitem_5707 = getitem_5708 = getitem_5709 = getitem_5710 = getitem_5711 = getitem_5712 = getitem_5713 = getitem_5714 = getitem_5715 = getitem_5716 = getitem_5717 = getitem_5718 = getitem_5719 = getitem_5720 = getitem_5721 = getitem_5722 = getitem_5723 = getitem_5724 = getitem_5725 = getitem_5726 = getitem_5727 = getitem_5728 = getitem_5729 = getitem_5730 = getitem_5731 = getitem_5732 = getitem_5733 = getitem_5734 = getitem_5735 = getitem_5736 = getitem_5737 = getitem_5738 = getitem_5739 = getitem_5740 = getitem_5741 = getitem_5742 = getitem_5743 = getitem_5744 = getitem_5745 = getitem_5746 = getitem_5747 = getitem_5748 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_255, 'avg', 128, '0'); cat_255 = None + wait_tensor_601 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + view_1823 = torch.ops.aten.view.default(add_1816, [2, 4096, 2048]); add_1816 = None + convert_element_type_1638 = torch.ops.prims.convert_element_type.default(view_1823, torch.float32); view_1823 = None + convert_element_type_1301 = torch.ops.prims.convert_element_type.default(primals_397, torch.bfloat16); primals_397 = None + all_gather_into_tensor_408 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1301, 128, '0'); convert_element_type_1301 = None + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_408); all_gather_into_tensor_408 = None + convert_element_type_1640 = torch.ops.prims.convert_element_type.default(wait_tensor_500, torch.float32); wait_tensor_500 = None + mul_1373 = torch.ops.aten.mul.Tensor(convert_element_type_1638, convert_element_type_1640); convert_element_type_1640 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(add_1572, torch.float32); add_1572 = None + mul_1142 = torch.ops.aten.mul.Tensor(convert_element_type_1302, rsqrt_74); convert_element_type_1302 = None + mul_1375 = torch.ops.aten.mul.Tensor(mul_1142, mul_1373) + sum_124 = torch.ops.aten.sum.dim_IntList(mul_1375, [2], True); mul_1375 = None + div_147 = torch.ops.aten.div.Tensor(mul_1142, 2048) + mul_1376 = torch.ops.aten.mul.Tensor(div_147, sum_124); div_147 = sum_124 = None + sub_640 = torch.ops.aten.sub.Tensor(mul_1373, mul_1376); mul_1373 = mul_1376 = None + mul_1377 = torch.ops.aten.mul.Tensor(sub_640, rsqrt_74); sub_640 = rsqrt_74 = None + mul_1378 = torch.ops.aten.mul.Tensor(convert_element_type_1638, mul_1142); convert_element_type_1638 = mul_1142 = None + sum_125 = torch.ops.aten.sum.dim_IntList(mul_1378, [0, 1]); mul_1378 = None + convert_element_type_1641 = torch.ops.prims.convert_element_type.default(mul_1377, torch.bfloat16); mul_1377 = None + add_1817 = torch.ops.aten.add.Tensor(add_1804, convert_element_type_1641); add_1804 = convert_element_type_1641 = None + convert_element_type_default_75 = torch.ops.prims.convert_element_type.default(sum_125, torch.float32); sum_125 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_75, 'avg', 128, '0'); convert_element_type_default_75 = None + wait_tensor_602 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + view_1824 = torch.ops.aten.view.default(add_1817, [8192, 2048]) + permute_536 = torch.ops.aten.permute.default(view_1824, [1, 0]) + permute_362 = torch.ops.aten.permute.default(getitem_2545, [0, 2, 1, 3]) + view_1594 = torch.ops.aten.view.default(permute_362, [2, 4096, -1]); permute_362 = None + view_1596 = torch.ops.aten.view.default(view_1594, [8192, 2048]); view_1594 = None + mm_258 = torch.ops.aten.mm.default(permute_536, view_1596); permute_536 = view_1596 = None + convert_element_type_1298 = torch.ops.prims.convert_element_type.default(primals_396, torch.bfloat16); primals_396 = None + all_gather_into_tensor_407 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1298, 128, '0'); convert_element_type_1298 = None + wait_tensor_499 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_407); all_gather_into_tensor_407 = None + permute_363 = torch.ops.aten.permute.default(wait_tensor_499, [1, 0]); wait_tensor_499 = None + permute_538 = torch.ops.aten.permute.default(permute_363, [1, 0]); permute_363 = None + mm_259 = torch.ops.aten.mm.default(view_1824, permute_538); view_1824 = permute_538 = None + view_1825 = torch.ops.aten.view.default(mm_259, [2, 4096, 2048]); mm_259 = None + convert_element_type_1648 = torch.ops.prims.convert_element_type.default(mm_258, torch.float32); mm_258 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1648, 'avg', 128, '0'); convert_element_type_1648 = None + wait_tensor_603 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + view_1826 = torch.ops.aten.view.default(view_1825, [2, 4096, 16, 128]); view_1825 = None + permute_540 = torch.ops.aten.permute.default(view_1826, [0, 2, 1, 3]); view_1826 = None + fw_graph2 = self.fw_graph2 + joint_graph2 = self.joint_graph2 + mask_graph2 = self.mask_graph2 + flex_attention_backward_2 = torch.ops.higher_order.flex_attention_backward(permute_359, permute_360, permute_361, getitem_2545, getitem_2546, permute_540, None, fw_graph2, joint_graph2, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph2), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_359 = permute_360 = permute_361 = getitem_2545 = getitem_2546 = permute_540 = fw_graph2 = joint_graph2 = mask_graph2 = None + getitem_5749 = flex_attention_backward_2[0] + getitem_5750 = flex_attention_backward_2[1] + getitem_5751 = flex_attention_backward_2[2]; flex_attention_backward_2 = None + permute_541 = torch.ops.aten.permute.default(getitem_5751, [0, 2, 1, 3]); getitem_5751 = None + permute_542 = torch.ops.aten.permute.default(getitem_5750, [0, 2, 1, 3]); getitem_5750 = None + permute_543 = torch.ops.aten.permute.default(getitem_5749, [0, 2, 1, 3]); getitem_5749 = None + slice_176 = torch.ops.aten.slice.Tensor(permute_542, 3, 0, 128) + slice_177 = torch.ops.aten.slice.Tensor(permute_542, 3, 128, 192); permute_542 = None + sum_126 = torch.ops.aten.sum.dim_IntList(slice_177, [2], True); slice_177 = None + cat_256 = torch.ops.aten.cat.default([slice_176, permute_541], 3); slice_176 = permute_541 = None + view_1827 = torch.ops.aten.view.default(cat_256, [2, 4096, 4096]); cat_256 = None + view_1828 = torch.ops.aten.view.default(view_1827, [8192, 4096]); view_1827 = None + permute_544 = torch.ops.aten.permute.default(view_1828, [1, 0]) + mm_260 = torch.ops.aten.mm.default(permute_544, view_1591); permute_544 = view_1591 = None + convert_element_type_1295 = torch.ops.prims.convert_element_type.default(primals_395, torch.bfloat16); primals_395 = None + all_gather_into_tensor_406 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1295, 128, '0'); convert_element_type_1295 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_406); all_gather_into_tensor_406 = None + permute_358 = torch.ops.aten.permute.default(wait_tensor_498, [1, 0]); wait_tensor_498 = None + permute_546 = torch.ops.aten.permute.default(permute_358, [1, 0]); permute_358 = None + mm_261 = torch.ops.aten.mm.default(view_1828, permute_546); view_1828 = permute_546 = None + view_1829 = torch.ops.aten.view.default(mm_261, [2, 4096, 512]); mm_261 = None + convert_element_type_1653 = torch.ops.prims.convert_element_type.default(mm_260, torch.float32); mm_260 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1653, 'avg', 128, '0'); convert_element_type_1653 = None + wait_tensor_604 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + convert_element_type_1654 = torch.ops.prims.convert_element_type.default(view_1829, torch.float32); view_1829 = None + convert_element_type_1292 = torch.ops.prims.convert_element_type.default(primals_394, torch.bfloat16); primals_394 = None + all_gather_into_tensor_405 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1292, 128, '0'); convert_element_type_1292 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_405); all_gather_into_tensor_405 = None + convert_element_type_1656 = torch.ops.prims.convert_element_type.default(wait_tensor_497, torch.float32); wait_tensor_497 = None + mul_1379 = torch.ops.aten.mul.Tensor(convert_element_type_1654, convert_element_type_1656); convert_element_type_1656 = None + convert_element_type_1293 = torch.ops.prims.convert_element_type.default(getitem_2541, torch.float32); getitem_2541 = None + mul_1140 = torch.ops.aten.mul.Tensor(convert_element_type_1293, rsqrt_73); convert_element_type_1293 = None + mul_1381 = torch.ops.aten.mul.Tensor(mul_1140, mul_1379) + sum_127 = torch.ops.aten.sum.dim_IntList(mul_1381, [2], True); mul_1381 = None + div_148 = torch.ops.aten.div.Tensor(mul_1140, 512) + mul_1382 = torch.ops.aten.mul.Tensor(div_148, sum_127); div_148 = sum_127 = None + sub_641 = torch.ops.aten.sub.Tensor(mul_1379, mul_1382); mul_1379 = mul_1382 = None + mul_1383 = torch.ops.aten.mul.Tensor(sub_641, rsqrt_73); sub_641 = rsqrt_73 = None + mul_1384 = torch.ops.aten.mul.Tensor(convert_element_type_1654, mul_1140); convert_element_type_1654 = mul_1140 = None + sum_128 = torch.ops.aten.sum.dim_IntList(mul_1384, [0, 1]); mul_1384 = None + convert_element_type_1657 = torch.ops.prims.convert_element_type.default(mul_1383, torch.bfloat16); mul_1383 = None + convert_element_type_default_74 = torch.ops.prims.convert_element_type.default(sum_128, torch.float32); sum_128 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_74, 'avg', 128, '0'); convert_element_type_default_74 = None + wait_tensor_605 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + convert_element_type_1660 = torch.ops.prims.convert_element_type.default(sum_126, torch.float32); sum_126 = None + view_1830 = torch.ops.aten.view.default(convert_element_type_1660, [2, 4096, 1, 32, 2]); convert_element_type_1660 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1830); view_1830 = None + mul_1385 = torch.ops.aten.mul.Tensor(view_as_complex_58, clone_9); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_1385); mul_1385 = None + view_1831 = torch.ops.aten.view.default(view_as_real_58, [2, 4096, 1, 64]); view_as_real_58 = None + convert_element_type_1661 = torch.ops.prims.convert_element_type.default(view_1831, torch.bfloat16); view_1831 = None + squeeze_28 = torch.ops.aten.squeeze.dim(convert_element_type_1661, 2); convert_element_type_1661 = None + cat_257 = torch.ops.aten.cat.default([convert_element_type_1657, squeeze_28], 2); convert_element_type_1657 = squeeze_28 = None + view_1832 = torch.ops.aten.view.default(cat_257, [8192, 576]); cat_257 = None + permute_548 = torch.ops.aten.permute.default(view_1832, [1, 0]) + mm_262 = torch.ops.aten.mm.default(permute_548, view_1577); permute_548 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(primals_393, torch.bfloat16); primals_393 = None + all_gather_into_tensor_404 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1287, 128, '0'); convert_element_type_1287 = None + wait_tensor_496 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_404); all_gather_into_tensor_404 = None + slice_145 = torch.ops.aten.slice.Tensor(wait_tensor_496, 0, 0, 576); wait_tensor_496 = None + permute_357 = torch.ops.aten.permute.default(slice_145, [1, 0]); slice_145 = None + permute_550 = torch.ops.aten.permute.default(permute_357, [1, 0]); permute_357 = None + mm_263 = torch.ops.aten.mm.default(view_1832, permute_550); view_1832 = permute_550 = None + view_1833 = torch.ops.aten.view.default(mm_263, [2, 4096, 2048]); mm_263 = None + convert_element_type_1666 = torch.ops.prims.convert_element_type.default(mm_262, torch.float32); mm_262 = None + split_314 = torch.ops.aten.split.Tensor(convert_element_type_1666, 5); convert_element_type_1666 = None + getitem_5753 = split_314[0] + getitem_5754 = split_314[1] + getitem_5755 = split_314[2] + getitem_5756 = split_314[3] + getitem_5757 = split_314[4] + getitem_5758 = split_314[5] + getitem_5759 = split_314[6] + getitem_5760 = split_314[7] + getitem_5761 = split_314[8] + getitem_5762 = split_314[9] + getitem_5763 = split_314[10] + getitem_5764 = split_314[11] + getitem_5765 = split_314[12] + getitem_5766 = split_314[13] + getitem_5767 = split_314[14] + getitem_5768 = split_314[15] + getitem_5769 = split_314[16] + getitem_5770 = split_314[17] + getitem_5771 = split_314[18] + getitem_5772 = split_314[19] + getitem_5773 = split_314[20] + getitem_5774 = split_314[21] + getitem_5775 = split_314[22] + getitem_5776 = split_314[23] + getitem_5777 = split_314[24] + getitem_5778 = split_314[25] + getitem_5779 = split_314[26] + getitem_5780 = split_314[27] + getitem_5781 = split_314[28] + getitem_5782 = split_314[29] + getitem_5783 = split_314[30] + getitem_5784 = split_314[31] + getitem_5785 = split_314[32] + getitem_5786 = split_314[33] + getitem_5787 = split_314[34] + getitem_5788 = split_314[35] + getitem_5789 = split_314[36] + getitem_5790 = split_314[37] + getitem_5791 = split_314[38] + getitem_5792 = split_314[39] + getitem_5793 = split_314[40] + getitem_5794 = split_314[41] + getitem_5795 = split_314[42] + getitem_5796 = split_314[43] + getitem_5797 = split_314[44] + getitem_5798 = split_314[45] + getitem_5799 = split_314[46] + getitem_5800 = split_314[47] + getitem_5801 = split_314[48] + getitem_5802 = split_314[49] + getitem_5803 = split_314[50] + getitem_5804 = split_314[51] + getitem_5805 = split_314[52] + getitem_5806 = split_314[53] + getitem_5807 = split_314[54] + getitem_5808 = split_314[55] + getitem_5809 = split_314[56] + getitem_5810 = split_314[57] + getitem_5811 = split_314[58] + getitem_5812 = split_314[59] + getitem_5813 = split_314[60] + getitem_5814 = split_314[61] + getitem_5815 = split_314[62] + getitem_5816 = split_314[63] + getitem_5817 = split_314[64] + getitem_5818 = split_314[65] + getitem_5819 = split_314[66] + getitem_5820 = split_314[67] + getitem_5821 = split_314[68] + getitem_5822 = split_314[69] + getitem_5823 = split_314[70] + getitem_5824 = split_314[71] + getitem_5825 = split_314[72] + getitem_5826 = split_314[73] + getitem_5827 = split_314[74] + getitem_5828 = split_314[75] + getitem_5829 = split_314[76] + getitem_5830 = split_314[77] + getitem_5831 = split_314[78] + getitem_5832 = split_314[79] + getitem_5833 = split_314[80] + getitem_5834 = split_314[81] + getitem_5835 = split_314[82] + getitem_5836 = split_314[83] + getitem_5837 = split_314[84] + getitem_5838 = split_314[85] + getitem_5839 = split_314[86] + getitem_5840 = split_314[87] + getitem_5841 = split_314[88] + getitem_5842 = split_314[89] + getitem_5843 = split_314[90] + getitem_5844 = split_314[91] + getitem_5845 = split_314[92] + getitem_5846 = split_314[93] + getitem_5847 = split_314[94] + getitem_5848 = split_314[95] + getitem_5849 = split_314[96] + getitem_5850 = split_314[97] + getitem_5851 = split_314[98] + getitem_5852 = split_314[99] + getitem_5853 = split_314[100] + getitem_5854 = split_314[101] + getitem_5855 = split_314[102] + getitem_5856 = split_314[103] + getitem_5857 = split_314[104] + getitem_5858 = split_314[105] + getitem_5859 = split_314[106] + getitem_5860 = split_314[107] + getitem_5861 = split_314[108] + getitem_5862 = split_314[109] + getitem_5863 = split_314[110] + getitem_5864 = split_314[111] + getitem_5865 = split_314[112] + getitem_5866 = split_314[113] + getitem_5867 = split_314[114] + getitem_5868 = split_314[115]; split_314 = None + constant_pad_nd_218 = torch.ops.aten.constant_pad_nd.default(getitem_5868, [0, 0, 0, 4], 0.0); getitem_5868 = None + cat_258 = torch.ops.aten.cat.default([getitem_5753, getitem_5754, getitem_5755, getitem_5756, getitem_5757, getitem_5758, getitem_5759, getitem_5760, getitem_5761, getitem_5762, getitem_5763, getitem_5764, getitem_5765, getitem_5766, getitem_5767, getitem_5768, getitem_5769, getitem_5770, getitem_5771, getitem_5772, getitem_5773, getitem_5774, getitem_5775, getitem_5776, getitem_5777, getitem_5778, getitem_5779, getitem_5780, getitem_5781, getitem_5782, getitem_5783, getitem_5784, getitem_5785, getitem_5786, getitem_5787, getitem_5788, getitem_5789, getitem_5790, getitem_5791, getitem_5792, getitem_5793, getitem_5794, getitem_5795, getitem_5796, getitem_5797, getitem_5798, getitem_5799, getitem_5800, getitem_5801, getitem_5802, getitem_5803, getitem_5804, getitem_5805, getitem_5806, getitem_5807, getitem_5808, getitem_5809, getitem_5810, getitem_5811, getitem_5812, getitem_5813, getitem_5814, getitem_5815, getitem_5816, getitem_5817, getitem_5818, getitem_5819, getitem_5820, getitem_5821, getitem_5822, getitem_5823, getitem_5824, getitem_5825, getitem_5826, getitem_5827, getitem_5828, getitem_5829, getitem_5830, getitem_5831, getitem_5832, getitem_5833, getitem_5834, getitem_5835, getitem_5836, getitem_5837, getitem_5838, getitem_5839, getitem_5840, getitem_5841, getitem_5842, getitem_5843, getitem_5844, getitem_5845, getitem_5846, getitem_5847, getitem_5848, getitem_5849, getitem_5850, getitem_5851, getitem_5852, getitem_5853, getitem_5854, getitem_5855, getitem_5856, getitem_5857, getitem_5858, getitem_5859, getitem_5860, getitem_5861, getitem_5862, getitem_5863, getitem_5864, getitem_5865, getitem_5866, getitem_5867, constant_pad_nd_218, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_5753 = getitem_5754 = getitem_5755 = getitem_5756 = getitem_5757 = getitem_5758 = getitem_5759 = getitem_5760 = getitem_5761 = getitem_5762 = getitem_5763 = getitem_5764 = getitem_5765 = getitem_5766 = getitem_5767 = getitem_5768 = getitem_5769 = getitem_5770 = getitem_5771 = getitem_5772 = getitem_5773 = getitem_5774 = getitem_5775 = getitem_5776 = getitem_5777 = getitem_5778 = getitem_5779 = getitem_5780 = getitem_5781 = getitem_5782 = getitem_5783 = getitem_5784 = getitem_5785 = getitem_5786 = getitem_5787 = getitem_5788 = getitem_5789 = getitem_5790 = getitem_5791 = getitem_5792 = getitem_5793 = getitem_5794 = getitem_5795 = getitem_5796 = getitem_5797 = getitem_5798 = getitem_5799 = getitem_5800 = getitem_5801 = getitem_5802 = getitem_5803 = getitem_5804 = getitem_5805 = getitem_5806 = getitem_5807 = getitem_5808 = getitem_5809 = getitem_5810 = getitem_5811 = getitem_5812 = getitem_5813 = getitem_5814 = getitem_5815 = getitem_5816 = getitem_5817 = getitem_5818 = getitem_5819 = getitem_5820 = getitem_5821 = getitem_5822 = getitem_5823 = getitem_5824 = getitem_5825 = getitem_5826 = getitem_5827 = getitem_5828 = getitem_5829 = getitem_5830 = getitem_5831 = getitem_5832 = getitem_5833 = getitem_5834 = getitem_5835 = getitem_5836 = getitem_5837 = getitem_5838 = getitem_5839 = getitem_5840 = getitem_5841 = getitem_5842 = getitem_5843 = getitem_5844 = getitem_5845 = getitem_5846 = getitem_5847 = getitem_5848 = getitem_5849 = getitem_5850 = getitem_5851 = getitem_5852 = getitem_5853 = getitem_5854 = getitem_5855 = getitem_5856 = getitem_5857 = getitem_5858 = getitem_5859 = getitem_5860 = getitem_5861 = getitem_5862 = getitem_5863 = getitem_5864 = getitem_5865 = getitem_5866 = getitem_5867 = constant_pad_nd_218 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_258, 'avg', 128, '0'); cat_258 = None + wait_tensor_606 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + slice_178 = torch.ops.aten.slice.Tensor(permute_543, 3, 0, 128) + slice_179 = torch.ops.aten.slice.Tensor(permute_543, 3, 128, 192); permute_543 = None + convert_element_type_1667 = torch.ops.prims.convert_element_type.default(slice_179, torch.float32); slice_179 = None + view_1834 = torch.ops.aten.view.default(convert_element_type_1667, [2, 4096, 16, 32, 2]); convert_element_type_1667 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1834); view_1834 = None + mul_1386 = torch.ops.aten.mul.Tensor(view_as_complex_59, clone_9); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_1386); mul_1386 = None + view_1835 = torch.ops.aten.view.default(view_as_real_59, [2, 4096, 16, 64]); view_as_real_59 = None + convert_element_type_1668 = torch.ops.prims.convert_element_type.default(view_1835, torch.bfloat16); view_1835 = None + cat_259 = torch.ops.aten.cat.default([slice_178, convert_element_type_1668], 3); slice_178 = convert_element_type_1668 = None + view_1836 = torch.ops.aten.view.default(cat_259, [2, 4096, 3072]); cat_259 = None + view_1837 = torch.ops.aten.view.default(view_1836, [8192, 3072]); view_1836 = None + permute_552 = torch.ops.aten.permute.default(view_1837, [1, 0]) + mm_264 = torch.ops.aten.mm.default(permute_552, view_1577); permute_552 = view_1577 = None + convert_element_type_1282 = torch.ops.prims.convert_element_type.default(primals_392, torch.bfloat16); primals_392 = None + all_gather_into_tensor_403 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1282, 128, '0'); convert_element_type_1282 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_403); all_gather_into_tensor_403 = None + permute_356 = torch.ops.aten.permute.default(wait_tensor_495, [1, 0]); wait_tensor_495 = None + permute_554 = torch.ops.aten.permute.default(permute_356, [1, 0]); permute_356 = None + mm_265 = torch.ops.aten.mm.default(view_1837, permute_554); view_1837 = permute_554 = None + view_1838 = torch.ops.aten.view.default(mm_265, [2, 4096, 2048]); mm_265 = None + add_1818 = torch.ops.aten.add.Tensor(view_1833, view_1838); view_1833 = view_1838 = None + convert_element_type_1673 = torch.ops.prims.convert_element_type.default(mm_264, torch.float32); mm_264 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1673, 'avg', 128, '0'); convert_element_type_1673 = None + wait_tensor_607 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + convert_element_type_1674 = torch.ops.prims.convert_element_type.default(add_1818, torch.float32); add_1818 = None + convert_element_type_1279 = torch.ops.prims.convert_element_type.default(primals_391, torch.bfloat16); primals_391 = None + all_gather_into_tensor_402 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1279, 128, '0'); convert_element_type_1279 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_402); all_gather_into_tensor_402 = None + convert_element_type_1676 = torch.ops.prims.convert_element_type.default(wait_tensor_494, torch.float32); wait_tensor_494 = None + mul_1387 = torch.ops.aten.mul.Tensor(convert_element_type_1674, convert_element_type_1676); convert_element_type_1676 = None + convert_element_type_1280 = torch.ops.prims.convert_element_type.default(add_1569, torch.float32); add_1569 = None + mul_1136 = torch.ops.aten.mul.Tensor(convert_element_type_1280, rsqrt_72); convert_element_type_1280 = None + mul_1389 = torch.ops.aten.mul.Tensor(mul_1136, mul_1387) + sum_129 = torch.ops.aten.sum.dim_IntList(mul_1389, [2], True); mul_1389 = None + div_149 = torch.ops.aten.div.Tensor(mul_1136, 2048) + mul_1390 = torch.ops.aten.mul.Tensor(div_149, sum_129); div_149 = sum_129 = None + sub_642 = torch.ops.aten.sub.Tensor(mul_1387, mul_1390); mul_1387 = mul_1390 = None + mul_1391 = torch.ops.aten.mul.Tensor(sub_642, rsqrt_72); sub_642 = rsqrt_72 = None + mul_1392 = torch.ops.aten.mul.Tensor(convert_element_type_1674, mul_1136); convert_element_type_1674 = mul_1136 = None + sum_130 = torch.ops.aten.sum.dim_IntList(mul_1392, [0, 1]); mul_1392 = None + convert_element_type_1677 = torch.ops.prims.convert_element_type.default(mul_1391, torch.bfloat16); mul_1391 = None + add_1819 = torch.ops.aten.add.Tensor(add_1817, convert_element_type_1677); add_1817 = convert_element_type_1677 = None + convert_element_type_default_73 = torch.ops.prims.convert_element_type.default(sum_130, torch.float32); sum_130 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_73, 'avg', 128, '0'); convert_element_type_default_73 = None + wait_tensor_608 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + view_1839 = torch.ops.aten.view.default(add_1819, [8192, 2048]) + unsqueeze_56 = torch.ops.aten.unsqueeze.default(view_1839, 1) + convert_element_type_1680 = torch.ops.prims.convert_element_type.default(unsqueeze_56, torch.float32); unsqueeze_56 = None + bmm_32 = torch.ops.aten.bmm.default(permute_556, convert_element_type_1680); permute_556 = None + bmm_33 = torch.ops.aten.bmm.default(convert_element_type_1680, permute_557); convert_element_type_1680 = permute_557 = None + convert_element_type_1681 = torch.ops.prims.convert_element_type.default(bmm_32, torch.bfloat16); bmm_32 = None + view_1840 = torch.ops.aten.view.default(bmm_33, [8192, 6]); bmm_33 = None + view_1841 = torch.ops.aten.view.default(convert_element_type_1681, [49152, 2048]); convert_element_type_1681 = None + index_58 = torch.ops.aten.index.Tensor(view_1841, [getitem_2441]); view_1841 = getitem_2441 = None + permute_558 = torch.ops.aten.permute.default(view_1839, [1, 0]) + mm_266 = torch.ops.aten.mm.default(permute_558, mul_1133); permute_558 = mul_1133 = None + convert_element_type_1274 = torch.ops.prims.convert_element_type.default(primals_390, torch.bfloat16); primals_390 = None + all_gather_into_tensor_401 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1274, 128, '0'); convert_element_type_1274 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_401); all_gather_into_tensor_401 = None + permute_355 = torch.ops.aten.permute.default(wait_tensor_493, [1, 0]); wait_tensor_493 = None + permute_560 = torch.ops.aten.permute.default(permute_355, [1, 0]); permute_355 = None + mm_267 = torch.ops.aten.mm.default(view_1839, permute_560); view_1839 = permute_560 = None + convert_element_type_1686 = torch.ops.prims.convert_element_type.default(mm_266, torch.float32); mm_266 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1686, 'avg', 128, '0'); convert_element_type_1686 = None + wait_tensor_609 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + convert_element_type_1269 = torch.ops.prims.convert_element_type.default(mm_188, torch.float32); mm_188 = None + neg_46 = torch.ops.aten.neg.default(convert_element_type_1269) + exp_69 = torch.ops.aten.exp.default(neg_46); neg_46 = None + add_1564 = torch.ops.aten.add.Tensor(exp_69, 1); exp_69 = None + div_115 = torch.ops.aten.div.Tensor(convert_element_type_1269, add_1564) + convert_element_type_1270 = torch.ops.prims.convert_element_type.default(div_115, torch.bfloat16); div_115 = None + mul_1393 = torch.ops.aten.mul.Tensor(mm_267, convert_element_type_1270); convert_element_type_1270 = None + mul_1394 = torch.ops.aten.mul.Tensor(mm_267, mm_189); mm_267 = mm_189 = None + permute_562 = torch.ops.aten.permute.default(mul_1393, [1, 0]) + mm_268 = torch.ops.aten.mm.default(permute_562, view_1532); permute_562 = None + convert_element_type_1271 = torch.ops.prims.convert_element_type.default(primals_389, torch.bfloat16); primals_389 = None + all_gather_into_tensor_400 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1271, 128, '0'); convert_element_type_1271 = None + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_400); all_gather_into_tensor_400 = None + permute_354 = torch.ops.aten.permute.default(wait_tensor_492, [1, 0]); wait_tensor_492 = None + permute_564 = torch.ops.aten.permute.default(permute_354, [1, 0]); permute_354 = None + mm_269 = torch.ops.aten.mm.default(mul_1393, permute_564); mul_1393 = permute_564 = None + convert_element_type_1691 = torch.ops.prims.convert_element_type.default(mm_268, torch.float32); mm_268 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1691, 'avg', 128, '0'); convert_element_type_1691 = None + wait_tensor_610 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); reduce_scatter_tensor_45 = None + convert_element_type_1692 = torch.ops.prims.convert_element_type.default(mul_1394, torch.float32); mul_1394 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_1564); add_1564 = None + mul_1395 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); reciprocal_6 = None + mul_1396 = torch.ops.aten.mul.Tensor(convert_element_type_1692, mul_1395); convert_element_type_1692 = None + sub_643 = torch.ops.aten.sub.Tensor(1, mul_1395); mul_1395 = None + mul_1397 = torch.ops.aten.mul.Tensor(convert_element_type_1269, sub_643); convert_element_type_1269 = sub_643 = None + add_1821 = torch.ops.aten.add.Tensor(mul_1397, 1); mul_1397 = None + mul_1398 = torch.ops.aten.mul.Tensor(mul_1396, add_1821); mul_1396 = add_1821 = None + convert_element_type_1694 = torch.ops.prims.convert_element_type.default(mul_1398, torch.bfloat16); mul_1398 = None + permute_566 = torch.ops.aten.permute.default(convert_element_type_1694, [1, 0]) + mm_270 = torch.ops.aten.mm.default(permute_566, view_1532); permute_566 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(primals_388, torch.bfloat16); primals_388 = None + all_gather_into_tensor_399 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1266, 128, '0'); convert_element_type_1266 = None + wait_tensor_491 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_399); all_gather_into_tensor_399 = None + permute_353 = torch.ops.aten.permute.default(wait_tensor_491, [1, 0]); wait_tensor_491 = None + permute_568 = torch.ops.aten.permute.default(permute_353, [1, 0]); permute_353 = None + mm_271 = torch.ops.aten.mm.default(convert_element_type_1694, permute_568); convert_element_type_1694 = permute_568 = None + add_1822 = torch.ops.aten.add.Tensor(mm_269, mm_271); mm_269 = mm_271 = None + convert_element_type_1699 = torch.ops.prims.convert_element_type.default(mm_270, torch.float32); mm_270 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1699, 'avg', 128, '0'); convert_element_type_1699 = None + wait_tensor_611 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + all_to_all_single_84 = torch.ops._c10d_functional.all_to_all_single.default(index_58, [_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367], [_local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359], '1033'); index_58 = None + wait_tensor_612 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_84); all_to_all_single_84 = None + full_366 = torch.ops.aten.full.default([sym_size_int_89, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_89 = None + slice_scatter_3 = torch.ops.aten.slice_scatter.default(full_366, wait_tensor_612, 0, 0, -1); wait_tensor_612 = None + index_59 = torch.ops.aten.index.Tensor(slice_scatter_3, [getitem_2442]); slice_scatter_3 = None + permute_570 = torch.ops.aten.permute.default(index_59, [1, 0]) + _grouped_mm_96 = torch.ops.aten._grouped_mm.default(permute_570, mul_1113, cumsum_68); permute_570 = mul_1113 = None + _grouped_mm_97 = torch.ops.aten._grouped_mm.default(index_59, permute_572, cumsum_68); index_59 = permute_572 = None + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(_grouped_mm_66, torch.float32); _grouped_mm_66 = None + neg_45 = torch.ops.aten.neg.default(convert_element_type_1264) + exp_68 = torch.ops.aten.exp.default(neg_45); neg_45 = None + add_1528 = torch.ops.aten.add.Tensor(exp_68, 1); exp_68 = None + div_114 = torch.ops.aten.div.Tensor(convert_element_type_1264, add_1528) + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(div_114, torch.bfloat16); div_114 = None + mul_1399 = torch.ops.aten.mul.Tensor(_grouped_mm_97, convert_element_type_1265); convert_element_type_1265 = None + mul_1400 = torch.ops.aten.mul.Tensor(_grouped_mm_97, _grouped_mm_67); _grouped_mm_97 = _grouped_mm_67 = None + permute_574 = torch.ops.aten.permute.default(mul_1399, [1, 0]) + _grouped_mm_98 = torch.ops.aten._grouped_mm.default(permute_574, index_45, cumsum_68); permute_574 = None + _grouped_mm_99 = torch.ops.aten._grouped_mm.default(mul_1399, permute_576, cumsum_68); mul_1399 = permute_576 = None + convert_element_type_1700 = torch.ops.prims.convert_element_type.default(mul_1400, torch.float32); mul_1400 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_1528); add_1528 = None + mul_1401 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_1402 = torch.ops.aten.mul.Tensor(convert_element_type_1700, mul_1401); convert_element_type_1700 = None + sub_644 = torch.ops.aten.sub.Tensor(1, mul_1401); mul_1401 = None + mul_1403 = torch.ops.aten.mul.Tensor(convert_element_type_1264, sub_644); convert_element_type_1264 = sub_644 = None + add_1824 = torch.ops.aten.add.Tensor(mul_1403, 1); mul_1403 = None + mul_1404 = torch.ops.aten.mul.Tensor(mul_1402, add_1824); mul_1402 = add_1824 = None + convert_element_type_1702 = torch.ops.prims.convert_element_type.default(mul_1404, torch.bfloat16); mul_1404 = None + permute_578 = torch.ops.aten.permute.default(convert_element_type_1702, [1, 0]) + _grouped_mm_100 = torch.ops.aten._grouped_mm.default(permute_578, index_45, cumsum_68); permute_578 = index_45 = None + _grouped_mm_101 = torch.ops.aten._grouped_mm.default(convert_element_type_1702, permute_580, cumsum_68); convert_element_type_1702 = permute_580 = cumsum_68 = None + add_1825 = torch.ops.aten.add.Tensor(_grouped_mm_99, _grouped_mm_101); _grouped_mm_99 = _grouped_mm_101 = None + convert_element_type_1703 = torch.ops.prims.convert_element_type.default(_grouped_mm_98, torch.float32); _grouped_mm_98 = None + div_150 = torch.ops.aten.div.Tensor(convert_element_type_1703, 128); convert_element_type_1703 = None + split_316 = torch.ops.aten.split.Tensor(div_150, 88, 1); div_150 = None + getitem_5885 = split_316[0] + getitem_5902 = split_316[1] + getitem_5919 = split_316[2] + getitem_5936 = split_316[3] + getitem_5953 = split_316[4] + getitem_5970 = split_316[5] + getitem_5987 = split_316[6] + getitem_6004 = split_316[7] + getitem_6021 = split_316[8] + getitem_6038 = split_316[9] + getitem_6055 = split_316[10] + getitem_6072 = split_316[11] + getitem_6089 = split_316[12] + getitem_6106 = split_316[13] + getitem_6123 = split_316[14] + getitem_6140 = split_316[15]; split_316 = None + cat_260 = torch.ops.aten.cat.default([getitem_5885, getitem_5902, getitem_5919, getitem_5936, getitem_5953, getitem_5970, getitem_5987, getitem_6004, getitem_6021, getitem_6038, getitem_6055, getitem_6072, getitem_6089, getitem_6106, getitem_6123, getitem_6140]); getitem_5885 = getitem_5902 = getitem_5919 = getitem_5936 = getitem_5953 = getitem_5970 = getitem_5987 = getitem_6004 = getitem_6021 = getitem_6038 = getitem_6055 = getitem_6072 = getitem_6089 = getitem_6106 = getitem_6123 = getitem_6140 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_260, 'sum', 16, '1025'); cat_260 = None + wait_tensor_613 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + convert_element_type_1704 = torch.ops.prims.convert_element_type.default(_grouped_mm_96, torch.float32); _grouped_mm_96 = None + div_151 = torch.ops.aten.div.Tensor(convert_element_type_1704, 128); convert_element_type_1704 = None + split_333 = torch.ops.aten.split.Tensor(div_151, 128, 1); div_151 = None + getitem_6157 = split_333[0] + getitem_6174 = split_333[1] + getitem_6191 = split_333[2] + getitem_6208 = split_333[3] + getitem_6225 = split_333[4] + getitem_6242 = split_333[5] + getitem_6259 = split_333[6] + getitem_6276 = split_333[7] + getitem_6293 = split_333[8] + getitem_6310 = split_333[9] + getitem_6327 = split_333[10] + getitem_6344 = split_333[11] + getitem_6361 = split_333[12] + getitem_6378 = split_333[13] + getitem_6395 = split_333[14] + getitem_6412 = split_333[15]; split_333 = None + cat_261 = torch.ops.aten.cat.default([getitem_6157, getitem_6174, getitem_6191, getitem_6208, getitem_6225, getitem_6242, getitem_6259, getitem_6276, getitem_6293, getitem_6310, getitem_6327, getitem_6344, getitem_6361, getitem_6378, getitem_6395, getitem_6412]); getitem_6157 = getitem_6174 = getitem_6191 = getitem_6208 = getitem_6225 = getitem_6242 = getitem_6259 = getitem_6276 = getitem_6293 = getitem_6310 = getitem_6327 = getitem_6344 = getitem_6361 = getitem_6378 = getitem_6395 = getitem_6412 = None + reduce_scatter_tensor_48 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_261, 'sum', 16, '1025'); cat_261 = None + wait_tensor_614 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + convert_element_type_1705 = torch.ops.prims.convert_element_type.default(_grouped_mm_100, torch.float32); _grouped_mm_100 = None + div_152 = torch.ops.aten.div.Tensor(convert_element_type_1705, 128); convert_element_type_1705 = None + split_350 = torch.ops.aten.split.Tensor(div_152, 88, 1); div_152 = None + getitem_6429 = split_350[0] + getitem_6446 = split_350[1] + getitem_6463 = split_350[2] + getitem_6480 = split_350[3] + getitem_6497 = split_350[4] + getitem_6514 = split_350[5] + getitem_6531 = split_350[6] + getitem_6548 = split_350[7] + getitem_6565 = split_350[8] + getitem_6582 = split_350[9] + getitem_6599 = split_350[10] + getitem_6616 = split_350[11] + getitem_6633 = split_350[12] + getitem_6650 = split_350[13] + getitem_6667 = split_350[14] + getitem_6684 = split_350[15]; split_350 = None + cat_262 = torch.ops.aten.cat.default([getitem_6429, getitem_6446, getitem_6463, getitem_6480, getitem_6497, getitem_6514, getitem_6531, getitem_6548, getitem_6565, getitem_6582, getitem_6599, getitem_6616, getitem_6633, getitem_6650, getitem_6667, getitem_6684]); getitem_6429 = getitem_6446 = getitem_6463 = getitem_6480 = getitem_6497 = getitem_6514 = getitem_6531 = getitem_6548 = getitem_6565 = getitem_6582 = getitem_6599 = getitem_6616 = getitem_6633 = getitem_6650 = getitem_6667 = getitem_6684 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_262, 'sum', 16, '1025'); cat_262 = None + wait_tensor_615 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + index_put_58 = torch.ops.aten.index_put.default(full_366, [getitem_2442], add_1825, True); full_366 = getitem_2442 = add_1825 = None + slice_180 = torch.ops.aten.slice.Tensor(index_put_58, 0, 0, add_1826); index_put_58 = add_1826 = None + all_to_all_single_85 = torch.ops._c10d_functional.all_to_all_single.default(slice_180, [_local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359], [_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367], '1033'); slice_180 = _local_scalar_dense_352 = _local_scalar_dense_353 = _local_scalar_dense_354 = _local_scalar_dense_355 = _local_scalar_dense_356 = _local_scalar_dense_357 = _local_scalar_dense_358 = _local_scalar_dense_359 = _local_scalar_dense_360 = _local_scalar_dense_361 = _local_scalar_dense_362 = _local_scalar_dense_363 = _local_scalar_dense_364 = _local_scalar_dense_365 = _local_scalar_dense_366 = _local_scalar_dense_367 = None + wait_tensor_616 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_85); all_to_all_single_85 = None + index_put_59 = torch.ops.aten.index_put.default(full_default_52, [div_112], wait_tensor_616, True); div_112 = wait_tensor_616 = None + add_1830 = torch.ops.aten.add.Tensor(add_1822, index_put_59); add_1822 = index_put_59 = None + mul_1405 = torch.ops.aten.mul.Tensor(view_1840, 1.0); view_1840 = None + scatter_add_3 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_2439, mul_1405); getitem_2439 = mul_1405 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_187, torch.float32); mm_187 = None + sub_528 = torch.ops.aten.sub.Tensor(convert_element_type_1253, amax_22); convert_element_type_1253 = amax_22 = None + exp_67 = torch.ops.aten.exp.default(sub_528); sub_528 = None + div_111 = torch.ops.aten.div.Tensor(exp_67, sum_89); exp_67 = sum_89 = None + mul_1406 = torch.ops.aten.mul.Tensor(scatter_add_3, div_111); scatter_add_3 = None + sum_131 = torch.ops.aten.sum.dim_IntList(mul_1406, [1], True) + neg_64 = torch.ops.aten.neg.default(div_111); div_111 = None + fma_3 = torch.ops.prims.fma.default(neg_64, sum_131, mul_1406); neg_64 = sum_131 = mul_1406 = None + convert_element_type_1706 = torch.ops.prims.convert_element_type.default(fma_3, torch.bfloat16); fma_3 = None + permute_582 = torch.ops.aten.permute.default(convert_element_type_1706, [1, 0]) + mm_272 = torch.ops.aten.mm.default(permute_582, view_1532); permute_582 = view_1532 = None + convert_element_type_1250 = torch.ops.prims.convert_element_type.default(primals_383, torch.bfloat16); primals_383 = None + all_gather_into_tensor_392 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1250, 128, '0'); convert_element_type_1250 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_392); all_gather_into_tensor_392 = None + slice_141 = torch.ops.aten.slice.Tensor(wait_tensor_480, 0, 0, 64); wait_tensor_480 = None + permute_349 = torch.ops.aten.permute.default(slice_141, [1, 0]); slice_141 = None + permute_584 = torch.ops.aten.permute.default(permute_349, [1, 0]); permute_349 = None + mm_273 = torch.ops.aten.mm.default(convert_element_type_1706, permute_584); convert_element_type_1706 = permute_584 = None + add_1831 = torch.ops.aten.add.Tensor(add_1830, mm_273); add_1830 = mm_273 = None + convert_element_type_1711 = torch.ops.prims.convert_element_type.default(mm_272, torch.float32); mm_272 = None + split_366 = torch.ops.aten.split.Tensor(convert_element_type_1711, 1); convert_element_type_1711 = None + getitem_6685 = split_366[0] + getitem_6686 = split_366[1] + getitem_6687 = split_366[2] + getitem_6688 = split_366[3] + getitem_6689 = split_366[4] + getitem_6690 = split_366[5] + getitem_6691 = split_366[6] + getitem_6692 = split_366[7] + getitem_6693 = split_366[8] + getitem_6694 = split_366[9] + getitem_6695 = split_366[10] + getitem_6696 = split_366[11] + getitem_6697 = split_366[12] + getitem_6698 = split_366[13] + getitem_6699 = split_366[14] + getitem_6700 = split_366[15] + getitem_6701 = split_366[16] + getitem_6702 = split_366[17] + getitem_6703 = split_366[18] + getitem_6704 = split_366[19] + getitem_6705 = split_366[20] + getitem_6706 = split_366[21] + getitem_6707 = split_366[22] + getitem_6708 = split_366[23] + getitem_6709 = split_366[24] + getitem_6710 = split_366[25] + getitem_6711 = split_366[26] + getitem_6712 = split_366[27] + getitem_6713 = split_366[28] + getitem_6714 = split_366[29] + getitem_6715 = split_366[30] + getitem_6716 = split_366[31] + getitem_6717 = split_366[32] + getitem_6718 = split_366[33] + getitem_6719 = split_366[34] + getitem_6720 = split_366[35] + getitem_6721 = split_366[36] + getitem_6722 = split_366[37] + getitem_6723 = split_366[38] + getitem_6724 = split_366[39] + getitem_6725 = split_366[40] + getitem_6726 = split_366[41] + getitem_6727 = split_366[42] + getitem_6728 = split_366[43] + getitem_6729 = split_366[44] + getitem_6730 = split_366[45] + getitem_6731 = split_366[46] + getitem_6732 = split_366[47] + getitem_6733 = split_366[48] + getitem_6734 = split_366[49] + getitem_6735 = split_366[50] + getitem_6736 = split_366[51] + getitem_6737 = split_366[52] + getitem_6738 = split_366[53] + getitem_6739 = split_366[54] + getitem_6740 = split_366[55] + getitem_6741 = split_366[56] + getitem_6742 = split_366[57] + getitem_6743 = split_366[58] + getitem_6744 = split_366[59] + getitem_6745 = split_366[60] + getitem_6746 = split_366[61] + getitem_6747 = split_366[62] + getitem_6748 = split_366[63]; split_366 = None + cat_263 = torch.ops.aten.cat.default([getitem_6685, getitem_6686, getitem_6687, getitem_6688, getitem_6689, getitem_6690, getitem_6691, getitem_6692, getitem_6693, getitem_6694, getitem_6695, getitem_6696, getitem_6697, getitem_6698, getitem_6699, getitem_6700, getitem_6701, getitem_6702, getitem_6703, getitem_6704, getitem_6705, getitem_6706, getitem_6707, getitem_6708, getitem_6709, getitem_6710, getitem_6711, getitem_6712, getitem_6713, getitem_6714, getitem_6715, getitem_6716, getitem_6717, getitem_6718, getitem_6719, getitem_6720, getitem_6721, getitem_6722, getitem_6723, getitem_6724, getitem_6725, getitem_6726, getitem_6727, getitem_6728, getitem_6729, getitem_6730, getitem_6731, getitem_6732, getitem_6733, getitem_6734, getitem_6735, getitem_6736, getitem_6737, getitem_6738, getitem_6739, getitem_6740, getitem_6741, getitem_6742, getitem_6743, getitem_6744, getitem_6745, getitem_6746, getitem_6747, getitem_6748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_6685 = getitem_6686 = getitem_6687 = getitem_6688 = getitem_6689 = getitem_6690 = getitem_6691 = getitem_6692 = getitem_6693 = getitem_6694 = getitem_6695 = getitem_6696 = getitem_6697 = getitem_6698 = getitem_6699 = getitem_6700 = getitem_6701 = getitem_6702 = getitem_6703 = getitem_6704 = getitem_6705 = getitem_6706 = getitem_6707 = getitem_6708 = getitem_6709 = getitem_6710 = getitem_6711 = getitem_6712 = getitem_6713 = getitem_6714 = getitem_6715 = getitem_6716 = getitem_6717 = getitem_6718 = getitem_6719 = getitem_6720 = getitem_6721 = getitem_6722 = getitem_6723 = getitem_6724 = getitem_6725 = getitem_6726 = getitem_6727 = getitem_6728 = getitem_6729 = getitem_6730 = getitem_6731 = getitem_6732 = getitem_6733 = getitem_6734 = getitem_6735 = getitem_6736 = getitem_6737 = getitem_6738 = getitem_6739 = getitem_6740 = getitem_6741 = getitem_6742 = getitem_6743 = getitem_6744 = getitem_6745 = getitem_6746 = getitem_6747 = getitem_6748 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_263, 'avg', 128, '0'); cat_263 = None + wait_tensor_617 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + view_1842 = torch.ops.aten.view.default(add_1831, [2, 4096, 2048]); add_1831 = None + convert_element_type_1712 = torch.ops.prims.convert_element_type.default(view_1842, torch.float32); view_1842 = None + convert_element_type_1247 = torch.ops.prims.convert_element_type.default(primals_381, torch.bfloat16); primals_381 = None + all_gather_into_tensor_391 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1247, 128, '0'); convert_element_type_1247 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_391); all_gather_into_tensor_391 = None + convert_element_type_1714 = torch.ops.prims.convert_element_type.default(wait_tensor_479, torch.float32); wait_tensor_479 = None + mul_1407 = torch.ops.aten.mul.Tensor(convert_element_type_1712, convert_element_type_1714); convert_element_type_1714 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(add_1504, torch.float32); add_1504 = None + mul_1093 = torch.ops.aten.mul.Tensor(convert_element_type_1248, rsqrt_71); convert_element_type_1248 = None + mul_1409 = torch.ops.aten.mul.Tensor(mul_1093, mul_1407) + sum_132 = torch.ops.aten.sum.dim_IntList(mul_1409, [2], True); mul_1409 = None + div_153 = torch.ops.aten.div.Tensor(mul_1093, 2048) + mul_1410 = torch.ops.aten.mul.Tensor(div_153, sum_132); div_153 = sum_132 = None + sub_646 = torch.ops.aten.sub.Tensor(mul_1407, mul_1410); mul_1407 = mul_1410 = None + mul_1411 = torch.ops.aten.mul.Tensor(sub_646, rsqrt_71); sub_646 = rsqrt_71 = None + mul_1412 = torch.ops.aten.mul.Tensor(convert_element_type_1712, mul_1093); convert_element_type_1712 = mul_1093 = None + sum_133 = torch.ops.aten.sum.dim_IntList(mul_1412, [0, 1]); mul_1412 = None + convert_element_type_1715 = torch.ops.prims.convert_element_type.default(mul_1411, torch.bfloat16); mul_1411 = None + add_1832 = torch.ops.aten.add.Tensor(add_1819, convert_element_type_1715); add_1819 = convert_element_type_1715 = None + convert_element_type_default_72 = torch.ops.prims.convert_element_type.default(sum_133, torch.float32); sum_133 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_72, 'avg', 128, '0'); convert_element_type_default_72 = None + wait_tensor_618 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + view_1843 = torch.ops.aten.view.default(add_1832, [8192, 2048]) + permute_586 = torch.ops.aten.permute.default(view_1843, [1, 0]) + permute_347 = torch.ops.aten.permute.default(getitem_2435, [0, 2, 1, 3]) + view_1527 = torch.ops.aten.view.default(permute_347, [2, 4096, -1]); permute_347 = None + view_1529 = torch.ops.aten.view.default(view_1527, [8192, 2048]); view_1527 = None + mm_274 = torch.ops.aten.mm.default(permute_586, view_1529); permute_586 = view_1529 = None + convert_element_type_1244 = torch.ops.prims.convert_element_type.default(primals_380, torch.bfloat16); primals_380 = None + all_gather_into_tensor_390 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1244, 128, '0'); convert_element_type_1244 = None + wait_tensor_478 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_390); all_gather_into_tensor_390 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_478, [1, 0]); wait_tensor_478 = None + permute_588 = torch.ops.aten.permute.default(permute_348, [1, 0]); permute_348 = None + mm_275 = torch.ops.aten.mm.default(view_1843, permute_588); view_1843 = permute_588 = None + view_1844 = torch.ops.aten.view.default(mm_275, [2, 4096, 2048]); mm_275 = None + convert_element_type_1722 = torch.ops.prims.convert_element_type.default(mm_274, torch.float32); mm_274 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1722, 'avg', 128, '0'); convert_element_type_1722 = None + wait_tensor_619 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + view_1845 = torch.ops.aten.view.default(view_1844, [2, 4096, 16, 128]); view_1844 = None + permute_590 = torch.ops.aten.permute.default(view_1845, [0, 2, 1, 3]); view_1845 = None + fw_graph3 = self.fw_graph3 + joint_graph3 = self.joint_graph3 + mask_graph3 = self.mask_graph3 + flex_attention_backward_3 = torch.ops.higher_order.flex_attention_backward(permute_344, permute_345, permute_346, getitem_2435, getitem_2436, permute_590, None, fw_graph3, joint_graph3, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph3), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_344 = permute_345 = permute_346 = getitem_2435 = getitem_2436 = permute_590 = fw_graph3 = joint_graph3 = mask_graph3 = None + getitem_6749 = flex_attention_backward_3[0] + getitem_6750 = flex_attention_backward_3[1] + getitem_6751 = flex_attention_backward_3[2]; flex_attention_backward_3 = None + permute_591 = torch.ops.aten.permute.default(getitem_6751, [0, 2, 1, 3]); getitem_6751 = None + permute_592 = torch.ops.aten.permute.default(getitem_6750, [0, 2, 1, 3]); getitem_6750 = None + permute_593 = torch.ops.aten.permute.default(getitem_6749, [0, 2, 1, 3]); getitem_6749 = None + slice_182 = torch.ops.aten.slice.Tensor(permute_592, 3, 0, 128) + slice_183 = torch.ops.aten.slice.Tensor(permute_592, 3, 128, 192); permute_592 = None + sum_134 = torch.ops.aten.sum.dim_IntList(slice_183, [2], True); slice_183 = None + cat_264 = torch.ops.aten.cat.default([slice_182, permute_591], 3); slice_182 = permute_591 = None + view_1846 = torch.ops.aten.view.default(cat_264, [2, 4096, 4096]); cat_264 = None + view_1847 = torch.ops.aten.view.default(view_1846, [8192, 4096]); view_1846 = None + permute_594 = torch.ops.aten.permute.default(view_1847, [1, 0]) + mm_276 = torch.ops.aten.mm.default(permute_594, view_1524); permute_594 = view_1524 = None + convert_element_type_1241 = torch.ops.prims.convert_element_type.default(primals_379, torch.bfloat16); primals_379 = None + all_gather_into_tensor_389 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1241, 128, '0'); convert_element_type_1241 = None + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_389); all_gather_into_tensor_389 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_477, [1, 0]); wait_tensor_477 = None + permute_596 = torch.ops.aten.permute.default(permute_343, [1, 0]); permute_343 = None + mm_277 = torch.ops.aten.mm.default(view_1847, permute_596); view_1847 = permute_596 = None + view_1848 = torch.ops.aten.view.default(mm_277, [2, 4096, 512]); mm_277 = None + convert_element_type_1727 = torch.ops.prims.convert_element_type.default(mm_276, torch.float32); mm_276 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1727, 'avg', 128, '0'); convert_element_type_1727 = None + wait_tensor_620 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + convert_element_type_1728 = torch.ops.prims.convert_element_type.default(view_1848, torch.float32); view_1848 = None + convert_element_type_1238 = torch.ops.prims.convert_element_type.default(primals_378, torch.bfloat16); primals_378 = None + all_gather_into_tensor_388 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1238, 128, '0'); convert_element_type_1238 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_388); all_gather_into_tensor_388 = None + convert_element_type_1730 = torch.ops.prims.convert_element_type.default(wait_tensor_476, torch.float32); wait_tensor_476 = None + mul_1413 = torch.ops.aten.mul.Tensor(convert_element_type_1728, convert_element_type_1730); convert_element_type_1730 = None + convert_element_type_1239 = torch.ops.prims.convert_element_type.default(getitem_2431, torch.float32); getitem_2431 = None + mul_1091 = torch.ops.aten.mul.Tensor(convert_element_type_1239, rsqrt_70); convert_element_type_1239 = None + mul_1415 = torch.ops.aten.mul.Tensor(mul_1091, mul_1413) + sum_135 = torch.ops.aten.sum.dim_IntList(mul_1415, [2], True); mul_1415 = None + div_154 = torch.ops.aten.div.Tensor(mul_1091, 512) + mul_1416 = torch.ops.aten.mul.Tensor(div_154, sum_135); div_154 = sum_135 = None + sub_647 = torch.ops.aten.sub.Tensor(mul_1413, mul_1416); mul_1413 = mul_1416 = None + mul_1417 = torch.ops.aten.mul.Tensor(sub_647, rsqrt_70); sub_647 = rsqrt_70 = None + mul_1418 = torch.ops.aten.mul.Tensor(convert_element_type_1728, mul_1091); convert_element_type_1728 = mul_1091 = None + sum_136 = torch.ops.aten.sum.dim_IntList(mul_1418, [0, 1]); mul_1418 = None + convert_element_type_1731 = torch.ops.prims.convert_element_type.default(mul_1417, torch.bfloat16); mul_1417 = None + convert_element_type_default_71 = torch.ops.prims.convert_element_type.default(sum_136, torch.float32); sum_136 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_71, 'avg', 128, '0'); convert_element_type_default_71 = None + wait_tensor_621 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + convert_element_type_1734 = torch.ops.prims.convert_element_type.default(sum_134, torch.float32); sum_134 = None + view_1849 = torch.ops.aten.view.default(convert_element_type_1734, [2, 4096, 1, 32, 2]); convert_element_type_1734 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1849); view_1849 = None + mul_1419 = torch.ops.aten.mul.Tensor(view_as_complex_60, clone_9); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_1419); mul_1419 = None + view_1850 = torch.ops.aten.view.default(view_as_real_60, [2, 4096, 1, 64]); view_as_real_60 = None + convert_element_type_1735 = torch.ops.prims.convert_element_type.default(view_1850, torch.bfloat16); view_1850 = None + squeeze_29 = torch.ops.aten.squeeze.dim(convert_element_type_1735, 2); convert_element_type_1735 = None + cat_265 = torch.ops.aten.cat.default([convert_element_type_1731, squeeze_29], 2); convert_element_type_1731 = squeeze_29 = None + view_1851 = torch.ops.aten.view.default(cat_265, [8192, 576]); cat_265 = None + permute_598 = torch.ops.aten.permute.default(view_1851, [1, 0]) + mm_278 = torch.ops.aten.mm.default(permute_598, view_1510); permute_598 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(primals_377, torch.bfloat16); primals_377 = None + all_gather_into_tensor_387 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1233, 128, '0'); convert_element_type_1233 = None + wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_387); all_gather_into_tensor_387 = None + slice_139 = torch.ops.aten.slice.Tensor(wait_tensor_475, 0, 0, 576); wait_tensor_475 = None + permute_342 = torch.ops.aten.permute.default(slice_139, [1, 0]); slice_139 = None + permute_600 = torch.ops.aten.permute.default(permute_342, [1, 0]); permute_342 = None + mm_279 = torch.ops.aten.mm.default(view_1851, permute_600); view_1851 = permute_600 = None + view_1852 = torch.ops.aten.view.default(mm_279, [2, 4096, 2048]); mm_279 = None + convert_element_type_1740 = torch.ops.prims.convert_element_type.default(mm_278, torch.float32); mm_278 = None + split_367 = torch.ops.aten.split.Tensor(convert_element_type_1740, 5); convert_element_type_1740 = None + getitem_6753 = split_367[0] + getitem_6754 = split_367[1] + getitem_6755 = split_367[2] + getitem_6756 = split_367[3] + getitem_6757 = split_367[4] + getitem_6758 = split_367[5] + getitem_6759 = split_367[6] + getitem_6760 = split_367[7] + getitem_6761 = split_367[8] + getitem_6762 = split_367[9] + getitem_6763 = split_367[10] + getitem_6764 = split_367[11] + getitem_6765 = split_367[12] + getitem_6766 = split_367[13] + getitem_6767 = split_367[14] + getitem_6768 = split_367[15] + getitem_6769 = split_367[16] + getitem_6770 = split_367[17] + getitem_6771 = split_367[18] + getitem_6772 = split_367[19] + getitem_6773 = split_367[20] + getitem_6774 = split_367[21] + getitem_6775 = split_367[22] + getitem_6776 = split_367[23] + getitem_6777 = split_367[24] + getitem_6778 = split_367[25] + getitem_6779 = split_367[26] + getitem_6780 = split_367[27] + getitem_6781 = split_367[28] + getitem_6782 = split_367[29] + getitem_6783 = split_367[30] + getitem_6784 = split_367[31] + getitem_6785 = split_367[32] + getitem_6786 = split_367[33] + getitem_6787 = split_367[34] + getitem_6788 = split_367[35] + getitem_6789 = split_367[36] + getitem_6790 = split_367[37] + getitem_6791 = split_367[38] + getitem_6792 = split_367[39] + getitem_6793 = split_367[40] + getitem_6794 = split_367[41] + getitem_6795 = split_367[42] + getitem_6796 = split_367[43] + getitem_6797 = split_367[44] + getitem_6798 = split_367[45] + getitem_6799 = split_367[46] + getitem_6800 = split_367[47] + getitem_6801 = split_367[48] + getitem_6802 = split_367[49] + getitem_6803 = split_367[50] + getitem_6804 = split_367[51] + getitem_6805 = split_367[52] + getitem_6806 = split_367[53] + getitem_6807 = split_367[54] + getitem_6808 = split_367[55] + getitem_6809 = split_367[56] + getitem_6810 = split_367[57] + getitem_6811 = split_367[58] + getitem_6812 = split_367[59] + getitem_6813 = split_367[60] + getitem_6814 = split_367[61] + getitem_6815 = split_367[62] + getitem_6816 = split_367[63] + getitem_6817 = split_367[64] + getitem_6818 = split_367[65] + getitem_6819 = split_367[66] + getitem_6820 = split_367[67] + getitem_6821 = split_367[68] + getitem_6822 = split_367[69] + getitem_6823 = split_367[70] + getitem_6824 = split_367[71] + getitem_6825 = split_367[72] + getitem_6826 = split_367[73] + getitem_6827 = split_367[74] + getitem_6828 = split_367[75] + getitem_6829 = split_367[76] + getitem_6830 = split_367[77] + getitem_6831 = split_367[78] + getitem_6832 = split_367[79] + getitem_6833 = split_367[80] + getitem_6834 = split_367[81] + getitem_6835 = split_367[82] + getitem_6836 = split_367[83] + getitem_6837 = split_367[84] + getitem_6838 = split_367[85] + getitem_6839 = split_367[86] + getitem_6840 = split_367[87] + getitem_6841 = split_367[88] + getitem_6842 = split_367[89] + getitem_6843 = split_367[90] + getitem_6844 = split_367[91] + getitem_6845 = split_367[92] + getitem_6846 = split_367[93] + getitem_6847 = split_367[94] + getitem_6848 = split_367[95] + getitem_6849 = split_367[96] + getitem_6850 = split_367[97] + getitem_6851 = split_367[98] + getitem_6852 = split_367[99] + getitem_6853 = split_367[100] + getitem_6854 = split_367[101] + getitem_6855 = split_367[102] + getitem_6856 = split_367[103] + getitem_6857 = split_367[104] + getitem_6858 = split_367[105] + getitem_6859 = split_367[106] + getitem_6860 = split_367[107] + getitem_6861 = split_367[108] + getitem_6862 = split_367[109] + getitem_6863 = split_367[110] + getitem_6864 = split_367[111] + getitem_6865 = split_367[112] + getitem_6866 = split_367[113] + getitem_6867 = split_367[114] + getitem_6868 = split_367[115]; split_367 = None + constant_pad_nd_295 = torch.ops.aten.constant_pad_nd.default(getitem_6868, [0, 0, 0, 4], 0.0); getitem_6868 = None + cat_266 = torch.ops.aten.cat.default([getitem_6753, getitem_6754, getitem_6755, getitem_6756, getitem_6757, getitem_6758, getitem_6759, getitem_6760, getitem_6761, getitem_6762, getitem_6763, getitem_6764, getitem_6765, getitem_6766, getitem_6767, getitem_6768, getitem_6769, getitem_6770, getitem_6771, getitem_6772, getitem_6773, getitem_6774, getitem_6775, getitem_6776, getitem_6777, getitem_6778, getitem_6779, getitem_6780, getitem_6781, getitem_6782, getitem_6783, getitem_6784, getitem_6785, getitem_6786, getitem_6787, getitem_6788, getitem_6789, getitem_6790, getitem_6791, getitem_6792, getitem_6793, getitem_6794, getitem_6795, getitem_6796, getitem_6797, getitem_6798, getitem_6799, getitem_6800, getitem_6801, getitem_6802, getitem_6803, getitem_6804, getitem_6805, getitem_6806, getitem_6807, getitem_6808, getitem_6809, getitem_6810, getitem_6811, getitem_6812, getitem_6813, getitem_6814, getitem_6815, getitem_6816, getitem_6817, getitem_6818, getitem_6819, getitem_6820, getitem_6821, getitem_6822, getitem_6823, getitem_6824, getitem_6825, getitem_6826, getitem_6827, getitem_6828, getitem_6829, getitem_6830, getitem_6831, getitem_6832, getitem_6833, getitem_6834, getitem_6835, getitem_6836, getitem_6837, getitem_6838, getitem_6839, getitem_6840, getitem_6841, getitem_6842, getitem_6843, getitem_6844, getitem_6845, getitem_6846, getitem_6847, getitem_6848, getitem_6849, getitem_6850, getitem_6851, getitem_6852, getitem_6853, getitem_6854, getitem_6855, getitem_6856, getitem_6857, getitem_6858, getitem_6859, getitem_6860, getitem_6861, getitem_6862, getitem_6863, getitem_6864, getitem_6865, getitem_6866, getitem_6867, constant_pad_nd_295, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_6753 = getitem_6754 = getitem_6755 = getitem_6756 = getitem_6757 = getitem_6758 = getitem_6759 = getitem_6760 = getitem_6761 = getitem_6762 = getitem_6763 = getitem_6764 = getitem_6765 = getitem_6766 = getitem_6767 = getitem_6768 = getitem_6769 = getitem_6770 = getitem_6771 = getitem_6772 = getitem_6773 = getitem_6774 = getitem_6775 = getitem_6776 = getitem_6777 = getitem_6778 = getitem_6779 = getitem_6780 = getitem_6781 = getitem_6782 = getitem_6783 = getitem_6784 = getitem_6785 = getitem_6786 = getitem_6787 = getitem_6788 = getitem_6789 = getitem_6790 = getitem_6791 = getitem_6792 = getitem_6793 = getitem_6794 = getitem_6795 = getitem_6796 = getitem_6797 = getitem_6798 = getitem_6799 = getitem_6800 = getitem_6801 = getitem_6802 = getitem_6803 = getitem_6804 = getitem_6805 = getitem_6806 = getitem_6807 = getitem_6808 = getitem_6809 = getitem_6810 = getitem_6811 = getitem_6812 = getitem_6813 = getitem_6814 = getitem_6815 = getitem_6816 = getitem_6817 = getitem_6818 = getitem_6819 = getitem_6820 = getitem_6821 = getitem_6822 = getitem_6823 = getitem_6824 = getitem_6825 = getitem_6826 = getitem_6827 = getitem_6828 = getitem_6829 = getitem_6830 = getitem_6831 = getitem_6832 = getitem_6833 = getitem_6834 = getitem_6835 = getitem_6836 = getitem_6837 = getitem_6838 = getitem_6839 = getitem_6840 = getitem_6841 = getitem_6842 = getitem_6843 = getitem_6844 = getitem_6845 = getitem_6846 = getitem_6847 = getitem_6848 = getitem_6849 = getitem_6850 = getitem_6851 = getitem_6852 = getitem_6853 = getitem_6854 = getitem_6855 = getitem_6856 = getitem_6857 = getitem_6858 = getitem_6859 = getitem_6860 = getitem_6861 = getitem_6862 = getitem_6863 = getitem_6864 = getitem_6865 = getitem_6866 = getitem_6867 = constant_pad_nd_295 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_266, 'avg', 128, '0'); cat_266 = None + wait_tensor_622 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); reduce_scatter_tensor_55 = None + slice_184 = torch.ops.aten.slice.Tensor(permute_593, 3, 0, 128) + slice_185 = torch.ops.aten.slice.Tensor(permute_593, 3, 128, 192); permute_593 = None + convert_element_type_1741 = torch.ops.prims.convert_element_type.default(slice_185, torch.float32); slice_185 = None + view_1853 = torch.ops.aten.view.default(convert_element_type_1741, [2, 4096, 16, 32, 2]); convert_element_type_1741 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1853); view_1853 = None + mul_1420 = torch.ops.aten.mul.Tensor(view_as_complex_61, clone_9); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_1420); mul_1420 = None + view_1854 = torch.ops.aten.view.default(view_as_real_61, [2, 4096, 16, 64]); view_as_real_61 = None + convert_element_type_1742 = torch.ops.prims.convert_element_type.default(view_1854, torch.bfloat16); view_1854 = None + cat_267 = torch.ops.aten.cat.default([slice_184, convert_element_type_1742], 3); slice_184 = convert_element_type_1742 = None + view_1855 = torch.ops.aten.view.default(cat_267, [2, 4096, 3072]); cat_267 = None + view_1856 = torch.ops.aten.view.default(view_1855, [8192, 3072]); view_1855 = None + permute_602 = torch.ops.aten.permute.default(view_1856, [1, 0]) + mm_280 = torch.ops.aten.mm.default(permute_602, view_1510); permute_602 = view_1510 = None + convert_element_type_1228 = torch.ops.prims.convert_element_type.default(primals_376, torch.bfloat16); primals_376 = None + all_gather_into_tensor_386 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1228, 128, '0'); convert_element_type_1228 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_386); all_gather_into_tensor_386 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_474, [1, 0]); wait_tensor_474 = None + permute_604 = torch.ops.aten.permute.default(permute_341, [1, 0]); permute_341 = None + mm_281 = torch.ops.aten.mm.default(view_1856, permute_604); view_1856 = permute_604 = None + view_1857 = torch.ops.aten.view.default(mm_281, [2, 4096, 2048]); mm_281 = None + add_1833 = torch.ops.aten.add.Tensor(view_1852, view_1857); view_1852 = view_1857 = None + convert_element_type_1747 = torch.ops.prims.convert_element_type.default(mm_280, torch.float32); mm_280 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1747, 'avg', 128, '0'); convert_element_type_1747 = None + wait_tensor_623 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + convert_element_type_1748 = torch.ops.prims.convert_element_type.default(add_1833, torch.float32); add_1833 = None + convert_element_type_1225 = torch.ops.prims.convert_element_type.default(primals_375, torch.bfloat16); primals_375 = None + all_gather_into_tensor_385 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1225, 128, '0'); convert_element_type_1225 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_385); all_gather_into_tensor_385 = None + convert_element_type_1750 = torch.ops.prims.convert_element_type.default(wait_tensor_473, torch.float32); wait_tensor_473 = None + mul_1421 = torch.ops.aten.mul.Tensor(convert_element_type_1748, convert_element_type_1750); convert_element_type_1750 = None + convert_element_type_1226 = torch.ops.prims.convert_element_type.default(add_1501, torch.float32); add_1501 = None + mul_1087 = torch.ops.aten.mul.Tensor(convert_element_type_1226, rsqrt_69); convert_element_type_1226 = None + mul_1423 = torch.ops.aten.mul.Tensor(mul_1087, mul_1421) + sum_137 = torch.ops.aten.sum.dim_IntList(mul_1423, [2], True); mul_1423 = None + div_155 = torch.ops.aten.div.Tensor(mul_1087, 2048) + mul_1424 = torch.ops.aten.mul.Tensor(div_155, sum_137); div_155 = sum_137 = None + sub_648 = torch.ops.aten.sub.Tensor(mul_1421, mul_1424); mul_1421 = mul_1424 = None + mul_1425 = torch.ops.aten.mul.Tensor(sub_648, rsqrt_69); sub_648 = rsqrt_69 = None + mul_1426 = torch.ops.aten.mul.Tensor(convert_element_type_1748, mul_1087); convert_element_type_1748 = mul_1087 = None + sum_138 = torch.ops.aten.sum.dim_IntList(mul_1426, [0, 1]); mul_1426 = None + convert_element_type_1751 = torch.ops.prims.convert_element_type.default(mul_1425, torch.bfloat16); mul_1425 = None + add_1834 = torch.ops.aten.add.Tensor(add_1832, convert_element_type_1751); add_1832 = convert_element_type_1751 = None + convert_element_type_default_70 = torch.ops.prims.convert_element_type.default(sum_138, torch.float32); sum_138 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_70, 'avg', 128, '0'); convert_element_type_default_70 = None + wait_tensor_624 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + view_1858 = torch.ops.aten.view.default(add_1834, [8192, 2048]) + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_1858, 1) + convert_element_type_1754 = torch.ops.prims.convert_element_type.default(unsqueeze_57, torch.float32); unsqueeze_57 = None + bmm_34 = torch.ops.aten.bmm.default(permute_606, convert_element_type_1754); permute_606 = None + bmm_35 = torch.ops.aten.bmm.default(convert_element_type_1754, permute_607); convert_element_type_1754 = permute_607 = None + convert_element_type_1755 = torch.ops.prims.convert_element_type.default(bmm_34, torch.bfloat16); bmm_34 = None + view_1859 = torch.ops.aten.view.default(bmm_35, [8192, 6]); bmm_35 = None + view_1860 = torch.ops.aten.view.default(convert_element_type_1755, [49152, 2048]); convert_element_type_1755 = None + index_60 = torch.ops.aten.index.Tensor(view_1860, [getitem_2331]); view_1860 = getitem_2331 = None + permute_608 = torch.ops.aten.permute.default(view_1858, [1, 0]) + mm_282 = torch.ops.aten.mm.default(permute_608, mul_1084); permute_608 = mul_1084 = None + convert_element_type_1220 = torch.ops.prims.convert_element_type.default(primals_374, torch.bfloat16); primals_374 = None + all_gather_into_tensor_384 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1220, 128, '0'); convert_element_type_1220 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_384); all_gather_into_tensor_384 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_472, [1, 0]); wait_tensor_472 = None + permute_610 = torch.ops.aten.permute.default(permute_340, [1, 0]); permute_340 = None + mm_283 = torch.ops.aten.mm.default(view_1858, permute_610); view_1858 = permute_610 = None + convert_element_type_1760 = torch.ops.prims.convert_element_type.default(mm_282, torch.float32); mm_282 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1760, 'avg', 128, '0'); convert_element_type_1760 = None + wait_tensor_625 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + convert_element_type_1215 = torch.ops.prims.convert_element_type.default(mm_180, torch.float32); mm_180 = None + neg_44 = torch.ops.aten.neg.default(convert_element_type_1215) + exp_66 = torch.ops.aten.exp.default(neg_44); neg_44 = None + add_1496 = torch.ops.aten.add.Tensor(exp_66, 1); exp_66 = None + div_110 = torch.ops.aten.div.Tensor(convert_element_type_1215, add_1496) + convert_element_type_1216 = torch.ops.prims.convert_element_type.default(div_110, torch.bfloat16); div_110 = None + mul_1427 = torch.ops.aten.mul.Tensor(mm_283, convert_element_type_1216); convert_element_type_1216 = None + mul_1428 = torch.ops.aten.mul.Tensor(mm_283, mm_181); mm_283 = mm_181 = None + permute_612 = torch.ops.aten.permute.default(mul_1427, [1, 0]) + mm_284 = torch.ops.aten.mm.default(permute_612, view_1465); permute_612 = None + convert_element_type_1217 = torch.ops.prims.convert_element_type.default(primals_373, torch.bfloat16); primals_373 = None + all_gather_into_tensor_383 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1217, 128, '0'); convert_element_type_1217 = None + wait_tensor_471 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_383); all_gather_into_tensor_383 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_471, [1, 0]); wait_tensor_471 = None + permute_614 = torch.ops.aten.permute.default(permute_339, [1, 0]); permute_339 = None + mm_285 = torch.ops.aten.mm.default(mul_1427, permute_614); mul_1427 = permute_614 = None + convert_element_type_1765 = torch.ops.prims.convert_element_type.default(mm_284, torch.float32); mm_284 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1765, 'avg', 128, '0'); convert_element_type_1765 = None + wait_tensor_626 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + convert_element_type_1766 = torch.ops.prims.convert_element_type.default(mul_1428, torch.float32); mul_1428 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_1496); add_1496 = None + mul_1429 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_1430 = torch.ops.aten.mul.Tensor(convert_element_type_1766, mul_1429); convert_element_type_1766 = None + sub_649 = torch.ops.aten.sub.Tensor(1, mul_1429); mul_1429 = None + mul_1431 = torch.ops.aten.mul.Tensor(convert_element_type_1215, sub_649); convert_element_type_1215 = sub_649 = None + add_1836 = torch.ops.aten.add.Tensor(mul_1431, 1); mul_1431 = None + mul_1432 = torch.ops.aten.mul.Tensor(mul_1430, add_1836); mul_1430 = add_1836 = None + convert_element_type_1768 = torch.ops.prims.convert_element_type.default(mul_1432, torch.bfloat16); mul_1432 = None + permute_616 = torch.ops.aten.permute.default(convert_element_type_1768, [1, 0]) + mm_286 = torch.ops.aten.mm.default(permute_616, view_1465); permute_616 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(primals_372, torch.bfloat16); primals_372 = None + all_gather_into_tensor_382 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1212, 128, '0'); convert_element_type_1212 = None + wait_tensor_470 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_382); all_gather_into_tensor_382 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_470, [1, 0]); wait_tensor_470 = None + permute_618 = torch.ops.aten.permute.default(permute_338, [1, 0]); permute_338 = None + mm_287 = torch.ops.aten.mm.default(convert_element_type_1768, permute_618); convert_element_type_1768 = permute_618 = None + add_1837 = torch.ops.aten.add.Tensor(mm_285, mm_287); mm_285 = mm_287 = None + convert_element_type_1773 = torch.ops.prims.convert_element_type.default(mm_286, torch.float32); mm_286 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1773, 'avg', 128, '0'); convert_element_type_1773 = None + wait_tensor_627 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + all_to_all_single_86 = torch.ops._c10d_functional.all_to_all_single.default(index_60, [_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351], [_local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343], '1033'); index_60 = None + wait_tensor_628 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_86); all_to_all_single_86 = None + full_372 = torch.ops.aten.full.default([sym_size_int_85, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_85 = None + slice_scatter_4 = torch.ops.aten.slice_scatter.default(full_372, wait_tensor_628, 0, 0, -1); wait_tensor_628 = None + index_61 = torch.ops.aten.index.Tensor(slice_scatter_4, [getitem_2332]); slice_scatter_4 = None + permute_620 = torch.ops.aten.permute.default(index_61, [1, 0]) + _grouped_mm_102 = torch.ops.aten._grouped_mm.default(permute_620, mul_1064, cumsum_65); permute_620 = mul_1064 = None + _grouped_mm_103 = torch.ops.aten._grouped_mm.default(index_61, permute_622, cumsum_65); index_61 = permute_622 = None + convert_element_type_1210 = torch.ops.prims.convert_element_type.default(_grouped_mm_63, torch.float32); _grouped_mm_63 = None + neg_43 = torch.ops.aten.neg.default(convert_element_type_1210) + exp_65 = torch.ops.aten.exp.default(neg_43); neg_43 = None + add_1460 = torch.ops.aten.add.Tensor(exp_65, 1); exp_65 = None + div_109 = torch.ops.aten.div.Tensor(convert_element_type_1210, add_1460) + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(div_109, torch.bfloat16); div_109 = None + mul_1433 = torch.ops.aten.mul.Tensor(_grouped_mm_103, convert_element_type_1211); convert_element_type_1211 = None + mul_1434 = torch.ops.aten.mul.Tensor(_grouped_mm_103, _grouped_mm_64); _grouped_mm_103 = _grouped_mm_64 = None + permute_624 = torch.ops.aten.permute.default(mul_1433, [1, 0]) + _grouped_mm_104 = torch.ops.aten._grouped_mm.default(permute_624, index_43, cumsum_65); permute_624 = None + _grouped_mm_105 = torch.ops.aten._grouped_mm.default(mul_1433, permute_626, cumsum_65); mul_1433 = permute_626 = None + convert_element_type_1774 = torch.ops.prims.convert_element_type.default(mul_1434, torch.float32); mul_1434 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_1460); add_1460 = None + mul_1435 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_1436 = torch.ops.aten.mul.Tensor(convert_element_type_1774, mul_1435); convert_element_type_1774 = None + sub_650 = torch.ops.aten.sub.Tensor(1, mul_1435); mul_1435 = None + mul_1437 = torch.ops.aten.mul.Tensor(convert_element_type_1210, sub_650); convert_element_type_1210 = sub_650 = None + add_1839 = torch.ops.aten.add.Tensor(mul_1437, 1); mul_1437 = None + mul_1438 = torch.ops.aten.mul.Tensor(mul_1436, add_1839); mul_1436 = add_1839 = None + convert_element_type_1776 = torch.ops.prims.convert_element_type.default(mul_1438, torch.bfloat16); mul_1438 = None + permute_628 = torch.ops.aten.permute.default(convert_element_type_1776, [1, 0]) + _grouped_mm_106 = torch.ops.aten._grouped_mm.default(permute_628, index_43, cumsum_65); permute_628 = index_43 = None + _grouped_mm_107 = torch.ops.aten._grouped_mm.default(convert_element_type_1776, permute_630, cumsum_65); convert_element_type_1776 = permute_630 = cumsum_65 = None + add_1840 = torch.ops.aten.add.Tensor(_grouped_mm_105, _grouped_mm_107); _grouped_mm_105 = _grouped_mm_107 = None + convert_element_type_1777 = torch.ops.prims.convert_element_type.default(_grouped_mm_104, torch.float32); _grouped_mm_104 = None + div_156 = torch.ops.aten.div.Tensor(convert_element_type_1777, 128); convert_element_type_1777 = None + split_369 = torch.ops.aten.split.Tensor(div_156, 88, 1); div_156 = None + getitem_6885 = split_369[0] + getitem_6902 = split_369[1] + getitem_6919 = split_369[2] + getitem_6936 = split_369[3] + getitem_6953 = split_369[4] + getitem_6970 = split_369[5] + getitem_6987 = split_369[6] + getitem_7004 = split_369[7] + getitem_7021 = split_369[8] + getitem_7038 = split_369[9] + getitem_7055 = split_369[10] + getitem_7072 = split_369[11] + getitem_7089 = split_369[12] + getitem_7106 = split_369[13] + getitem_7123 = split_369[14] + getitem_7140 = split_369[15]; split_369 = None + cat_268 = torch.ops.aten.cat.default([getitem_6885, getitem_6902, getitem_6919, getitem_6936, getitem_6953, getitem_6970, getitem_6987, getitem_7004, getitem_7021, getitem_7038, getitem_7055, getitem_7072, getitem_7089, getitem_7106, getitem_7123, getitem_7140]); getitem_6885 = getitem_6902 = getitem_6919 = getitem_6936 = getitem_6953 = getitem_6970 = getitem_6987 = getitem_7004 = getitem_7021 = getitem_7038 = getitem_7055 = getitem_7072 = getitem_7089 = getitem_7106 = getitem_7123 = getitem_7140 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_268, 'sum', 16, '1025'); cat_268 = None + wait_tensor_629 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + convert_element_type_1778 = torch.ops.prims.convert_element_type.default(_grouped_mm_102, torch.float32); _grouped_mm_102 = None + div_157 = torch.ops.aten.div.Tensor(convert_element_type_1778, 128); convert_element_type_1778 = None + split_386 = torch.ops.aten.split.Tensor(div_157, 128, 1); div_157 = None + getitem_7157 = split_386[0] + getitem_7174 = split_386[1] + getitem_7191 = split_386[2] + getitem_7208 = split_386[3] + getitem_7225 = split_386[4] + getitem_7242 = split_386[5] + getitem_7259 = split_386[6] + getitem_7276 = split_386[7] + getitem_7293 = split_386[8] + getitem_7310 = split_386[9] + getitem_7327 = split_386[10] + getitem_7344 = split_386[11] + getitem_7361 = split_386[12] + getitem_7378 = split_386[13] + getitem_7395 = split_386[14] + getitem_7412 = split_386[15]; split_386 = None + cat_269 = torch.ops.aten.cat.default([getitem_7157, getitem_7174, getitem_7191, getitem_7208, getitem_7225, getitem_7242, getitem_7259, getitem_7276, getitem_7293, getitem_7310, getitem_7327, getitem_7344, getitem_7361, getitem_7378, getitem_7395, getitem_7412]); getitem_7157 = getitem_7174 = getitem_7191 = getitem_7208 = getitem_7225 = getitem_7242 = getitem_7259 = getitem_7276 = getitem_7293 = getitem_7310 = getitem_7327 = getitem_7344 = getitem_7361 = getitem_7378 = getitem_7395 = getitem_7412 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_269, 'sum', 16, '1025'); cat_269 = None + wait_tensor_630 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + convert_element_type_1779 = torch.ops.prims.convert_element_type.default(_grouped_mm_106, torch.float32); _grouped_mm_106 = None + div_158 = torch.ops.aten.div.Tensor(convert_element_type_1779, 128); convert_element_type_1779 = None + split_403 = torch.ops.aten.split.Tensor(div_158, 88, 1); div_158 = None + getitem_7429 = split_403[0] + getitem_7446 = split_403[1] + getitem_7463 = split_403[2] + getitem_7480 = split_403[3] + getitem_7497 = split_403[4] + getitem_7514 = split_403[5] + getitem_7531 = split_403[6] + getitem_7548 = split_403[7] + getitem_7565 = split_403[8] + getitem_7582 = split_403[9] + getitem_7599 = split_403[10] + getitem_7616 = split_403[11] + getitem_7633 = split_403[12] + getitem_7650 = split_403[13] + getitem_7667 = split_403[14] + getitem_7684 = split_403[15]; split_403 = None + cat_270 = torch.ops.aten.cat.default([getitem_7429, getitem_7446, getitem_7463, getitem_7480, getitem_7497, getitem_7514, getitem_7531, getitem_7548, getitem_7565, getitem_7582, getitem_7599, getitem_7616, getitem_7633, getitem_7650, getitem_7667, getitem_7684]); getitem_7429 = getitem_7446 = getitem_7463 = getitem_7480 = getitem_7497 = getitem_7514 = getitem_7531 = getitem_7548 = getitem_7565 = getitem_7582 = getitem_7599 = getitem_7616 = getitem_7633 = getitem_7650 = getitem_7667 = getitem_7684 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_270, 'sum', 16, '1025'); cat_270 = None + wait_tensor_631 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + index_put_60 = torch.ops.aten.index_put.default(full_372, [getitem_2332], add_1840, True); full_372 = getitem_2332 = add_1840 = None + slice_186 = torch.ops.aten.slice.Tensor(index_put_60, 0, 0, add_1841); index_put_60 = add_1841 = None + all_to_all_single_87 = torch.ops._c10d_functional.all_to_all_single.default(slice_186, [_local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343], [_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351], '1033'); slice_186 = _local_scalar_dense_336 = _local_scalar_dense_337 = _local_scalar_dense_338 = _local_scalar_dense_339 = _local_scalar_dense_340 = _local_scalar_dense_341 = _local_scalar_dense_342 = _local_scalar_dense_343 = _local_scalar_dense_344 = _local_scalar_dense_345 = _local_scalar_dense_346 = _local_scalar_dense_347 = _local_scalar_dense_348 = _local_scalar_dense_349 = _local_scalar_dense_350 = _local_scalar_dense_351 = None + wait_tensor_632 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_87); all_to_all_single_87 = None + index_put_61 = torch.ops.aten.index_put.default(full_default_52, [div_107], wait_tensor_632, True); div_107 = wait_tensor_632 = None + add_1845 = torch.ops.aten.add.Tensor(add_1837, index_put_61); add_1837 = index_put_61 = None + mul_1439 = torch.ops.aten.mul.Tensor(view_1859, 1.0); view_1859 = None + scatter_add_4 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_2329, mul_1439); getitem_2329 = mul_1439 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_179, torch.float32); mm_179 = None + sub_504 = torch.ops.aten.sub.Tensor(convert_element_type_1199, amax_21); convert_element_type_1199 = amax_21 = None + exp_64 = torch.ops.aten.exp.default(sub_504); sub_504 = None + div_106 = torch.ops.aten.div.Tensor(exp_64, sum_85); exp_64 = sum_85 = None + mul_1440 = torch.ops.aten.mul.Tensor(scatter_add_4, div_106); scatter_add_4 = None + sum_139 = torch.ops.aten.sum.dim_IntList(mul_1440, [1], True) + neg_67 = torch.ops.aten.neg.default(div_106); div_106 = None + fma_4 = torch.ops.prims.fma.default(neg_67, sum_139, mul_1440); neg_67 = sum_139 = mul_1440 = None + convert_element_type_1780 = torch.ops.prims.convert_element_type.default(fma_4, torch.bfloat16); fma_4 = None + permute_632 = torch.ops.aten.permute.default(convert_element_type_1780, [1, 0]) + mm_288 = torch.ops.aten.mm.default(permute_632, view_1465); permute_632 = view_1465 = None + convert_element_type_1196 = torch.ops.prims.convert_element_type.default(primals_367, torch.bfloat16); primals_367 = None + all_gather_into_tensor_375 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1196, 128, '0'); convert_element_type_1196 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_375); all_gather_into_tensor_375 = None + slice_135 = torch.ops.aten.slice.Tensor(wait_tensor_459, 0, 0, 64); wait_tensor_459 = None + permute_334 = torch.ops.aten.permute.default(slice_135, [1, 0]); slice_135 = None + permute_634 = torch.ops.aten.permute.default(permute_334, [1, 0]); permute_334 = None + mm_289 = torch.ops.aten.mm.default(convert_element_type_1780, permute_634); convert_element_type_1780 = permute_634 = None + add_1846 = torch.ops.aten.add.Tensor(add_1845, mm_289); add_1845 = mm_289 = None + convert_element_type_1785 = torch.ops.prims.convert_element_type.default(mm_288, torch.float32); mm_288 = None + split_419 = torch.ops.aten.split.Tensor(convert_element_type_1785, 1); convert_element_type_1785 = None + getitem_7685 = split_419[0] + getitem_7686 = split_419[1] + getitem_7687 = split_419[2] + getitem_7688 = split_419[3] + getitem_7689 = split_419[4] + getitem_7690 = split_419[5] + getitem_7691 = split_419[6] + getitem_7692 = split_419[7] + getitem_7693 = split_419[8] + getitem_7694 = split_419[9] + getitem_7695 = split_419[10] + getitem_7696 = split_419[11] + getitem_7697 = split_419[12] + getitem_7698 = split_419[13] + getitem_7699 = split_419[14] + getitem_7700 = split_419[15] + getitem_7701 = split_419[16] + getitem_7702 = split_419[17] + getitem_7703 = split_419[18] + getitem_7704 = split_419[19] + getitem_7705 = split_419[20] + getitem_7706 = split_419[21] + getitem_7707 = split_419[22] + getitem_7708 = split_419[23] + getitem_7709 = split_419[24] + getitem_7710 = split_419[25] + getitem_7711 = split_419[26] + getitem_7712 = split_419[27] + getitem_7713 = split_419[28] + getitem_7714 = split_419[29] + getitem_7715 = split_419[30] + getitem_7716 = split_419[31] + getitem_7717 = split_419[32] + getitem_7718 = split_419[33] + getitem_7719 = split_419[34] + getitem_7720 = split_419[35] + getitem_7721 = split_419[36] + getitem_7722 = split_419[37] + getitem_7723 = split_419[38] + getitem_7724 = split_419[39] + getitem_7725 = split_419[40] + getitem_7726 = split_419[41] + getitem_7727 = split_419[42] + getitem_7728 = split_419[43] + getitem_7729 = split_419[44] + getitem_7730 = split_419[45] + getitem_7731 = split_419[46] + getitem_7732 = split_419[47] + getitem_7733 = split_419[48] + getitem_7734 = split_419[49] + getitem_7735 = split_419[50] + getitem_7736 = split_419[51] + getitem_7737 = split_419[52] + getitem_7738 = split_419[53] + getitem_7739 = split_419[54] + getitem_7740 = split_419[55] + getitem_7741 = split_419[56] + getitem_7742 = split_419[57] + getitem_7743 = split_419[58] + getitem_7744 = split_419[59] + getitem_7745 = split_419[60] + getitem_7746 = split_419[61] + getitem_7747 = split_419[62] + getitem_7748 = split_419[63]; split_419 = None + cat_271 = torch.ops.aten.cat.default([getitem_7685, getitem_7686, getitem_7687, getitem_7688, getitem_7689, getitem_7690, getitem_7691, getitem_7692, getitem_7693, getitem_7694, getitem_7695, getitem_7696, getitem_7697, getitem_7698, getitem_7699, getitem_7700, getitem_7701, getitem_7702, getitem_7703, getitem_7704, getitem_7705, getitem_7706, getitem_7707, getitem_7708, getitem_7709, getitem_7710, getitem_7711, getitem_7712, getitem_7713, getitem_7714, getitem_7715, getitem_7716, getitem_7717, getitem_7718, getitem_7719, getitem_7720, getitem_7721, getitem_7722, getitem_7723, getitem_7724, getitem_7725, getitem_7726, getitem_7727, getitem_7728, getitem_7729, getitem_7730, getitem_7731, getitem_7732, getitem_7733, getitem_7734, getitem_7735, getitem_7736, getitem_7737, getitem_7738, getitem_7739, getitem_7740, getitem_7741, getitem_7742, getitem_7743, getitem_7744, getitem_7745, getitem_7746, getitem_7747, getitem_7748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_7685 = getitem_7686 = getitem_7687 = getitem_7688 = getitem_7689 = getitem_7690 = getitem_7691 = getitem_7692 = getitem_7693 = getitem_7694 = getitem_7695 = getitem_7696 = getitem_7697 = getitem_7698 = getitem_7699 = getitem_7700 = getitem_7701 = getitem_7702 = getitem_7703 = getitem_7704 = getitem_7705 = getitem_7706 = getitem_7707 = getitem_7708 = getitem_7709 = getitem_7710 = getitem_7711 = getitem_7712 = getitem_7713 = getitem_7714 = getitem_7715 = getitem_7716 = getitem_7717 = getitem_7718 = getitem_7719 = getitem_7720 = getitem_7721 = getitem_7722 = getitem_7723 = getitem_7724 = getitem_7725 = getitem_7726 = getitem_7727 = getitem_7728 = getitem_7729 = getitem_7730 = getitem_7731 = getitem_7732 = getitem_7733 = getitem_7734 = getitem_7735 = getitem_7736 = getitem_7737 = getitem_7738 = getitem_7739 = getitem_7740 = getitem_7741 = getitem_7742 = getitem_7743 = getitem_7744 = getitem_7745 = getitem_7746 = getitem_7747 = getitem_7748 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_271, 'avg', 128, '0'); cat_271 = None + wait_tensor_633 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + view_1861 = torch.ops.aten.view.default(add_1846, [2, 4096, 2048]); add_1846 = None + convert_element_type_1786 = torch.ops.prims.convert_element_type.default(view_1861, torch.float32); view_1861 = None + convert_element_type_1193 = torch.ops.prims.convert_element_type.default(primals_365, torch.bfloat16); primals_365 = None + all_gather_into_tensor_374 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1193, 128, '0'); convert_element_type_1193 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_374); all_gather_into_tensor_374 = None + convert_element_type_1788 = torch.ops.prims.convert_element_type.default(wait_tensor_458, torch.float32); wait_tensor_458 = None + mul_1441 = torch.ops.aten.mul.Tensor(convert_element_type_1786, convert_element_type_1788); convert_element_type_1788 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(add_1436, torch.float32); add_1436 = None + mul_1044 = torch.ops.aten.mul.Tensor(convert_element_type_1194, rsqrt_68); convert_element_type_1194 = None + mul_1443 = torch.ops.aten.mul.Tensor(mul_1044, mul_1441) + sum_140 = torch.ops.aten.sum.dim_IntList(mul_1443, [2], True); mul_1443 = None + div_159 = torch.ops.aten.div.Tensor(mul_1044, 2048) + mul_1444 = torch.ops.aten.mul.Tensor(div_159, sum_140); div_159 = sum_140 = None + sub_652 = torch.ops.aten.sub.Tensor(mul_1441, mul_1444); mul_1441 = mul_1444 = None + mul_1445 = torch.ops.aten.mul.Tensor(sub_652, rsqrt_68); sub_652 = rsqrt_68 = None + mul_1446 = torch.ops.aten.mul.Tensor(convert_element_type_1786, mul_1044); convert_element_type_1786 = mul_1044 = None + sum_141 = torch.ops.aten.sum.dim_IntList(mul_1446, [0, 1]); mul_1446 = None + convert_element_type_1789 = torch.ops.prims.convert_element_type.default(mul_1445, torch.bfloat16); mul_1445 = None + add_1847 = torch.ops.aten.add.Tensor(add_1834, convert_element_type_1789); add_1834 = convert_element_type_1789 = None + convert_element_type_default_69 = torch.ops.prims.convert_element_type.default(sum_141, torch.float32); sum_141 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_69, 'avg', 128, '0'); convert_element_type_default_69 = None + wait_tensor_634 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); reduce_scatter_tensor_65 = None + view_1862 = torch.ops.aten.view.default(add_1847, [8192, 2048]) + permute_636 = torch.ops.aten.permute.default(view_1862, [1, 0]) + permute_332 = torch.ops.aten.permute.default(getitem_2325, [0, 2, 1, 3]) + view_1460 = torch.ops.aten.view.default(permute_332, [2, 4096, -1]); permute_332 = None + view_1462 = torch.ops.aten.view.default(view_1460, [8192, 2048]); view_1460 = None + mm_290 = torch.ops.aten.mm.default(permute_636, view_1462); permute_636 = view_1462 = None + convert_element_type_1190 = torch.ops.prims.convert_element_type.default(primals_364, torch.bfloat16); primals_364 = None + all_gather_into_tensor_373 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1190, 128, '0'); convert_element_type_1190 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_373); all_gather_into_tensor_373 = None + permute_333 = torch.ops.aten.permute.default(wait_tensor_457, [1, 0]); wait_tensor_457 = None + permute_638 = torch.ops.aten.permute.default(permute_333, [1, 0]); permute_333 = None + mm_291 = torch.ops.aten.mm.default(view_1862, permute_638); view_1862 = permute_638 = None + view_1863 = torch.ops.aten.view.default(mm_291, [2, 4096, 2048]); mm_291 = None + convert_element_type_1796 = torch.ops.prims.convert_element_type.default(mm_290, torch.float32); mm_290 = None + reduce_scatter_tensor_66 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1796, 'avg', 128, '0'); convert_element_type_1796 = None + wait_tensor_635 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + view_1864 = torch.ops.aten.view.default(view_1863, [2, 4096, 16, 128]); view_1863 = None + permute_640 = torch.ops.aten.permute.default(view_1864, [0, 2, 1, 3]); view_1864 = None + fw_graph4 = self.fw_graph4 + joint_graph4 = self.joint_graph4 + mask_graph4 = self.mask_graph4 + flex_attention_backward_4 = torch.ops.higher_order.flex_attention_backward(permute_329, permute_330, permute_331, getitem_2325, getitem_2326, permute_640, None, fw_graph4, joint_graph4, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph4), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_329 = permute_330 = permute_331 = getitem_2325 = getitem_2326 = permute_640 = fw_graph4 = joint_graph4 = mask_graph4 = None + getitem_7749 = flex_attention_backward_4[0] + getitem_7750 = flex_attention_backward_4[1] + getitem_7751 = flex_attention_backward_4[2]; flex_attention_backward_4 = None + permute_641 = torch.ops.aten.permute.default(getitem_7751, [0, 2, 1, 3]); getitem_7751 = None + permute_642 = torch.ops.aten.permute.default(getitem_7750, [0, 2, 1, 3]); getitem_7750 = None + permute_643 = torch.ops.aten.permute.default(getitem_7749, [0, 2, 1, 3]); getitem_7749 = None + slice_188 = torch.ops.aten.slice.Tensor(permute_642, 3, 0, 128) + slice_189 = torch.ops.aten.slice.Tensor(permute_642, 3, 128, 192); permute_642 = None + sum_142 = torch.ops.aten.sum.dim_IntList(slice_189, [2], True); slice_189 = None + cat_272 = torch.ops.aten.cat.default([slice_188, permute_641], 3); slice_188 = permute_641 = None + view_1865 = torch.ops.aten.view.default(cat_272, [2, 4096, 4096]); cat_272 = None + view_1866 = torch.ops.aten.view.default(view_1865, [8192, 4096]); view_1865 = None + permute_644 = torch.ops.aten.permute.default(view_1866, [1, 0]) + mm_292 = torch.ops.aten.mm.default(permute_644, view_1457); permute_644 = view_1457 = None + convert_element_type_1187 = torch.ops.prims.convert_element_type.default(primals_363, torch.bfloat16); primals_363 = None + all_gather_into_tensor_372 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1187, 128, '0'); convert_element_type_1187 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_372); all_gather_into_tensor_372 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_456, [1, 0]); wait_tensor_456 = None + permute_646 = torch.ops.aten.permute.default(permute_328, [1, 0]); permute_328 = None + mm_293 = torch.ops.aten.mm.default(view_1866, permute_646); view_1866 = permute_646 = None + view_1867 = torch.ops.aten.view.default(mm_293, [2, 4096, 512]); mm_293 = None + convert_element_type_1801 = torch.ops.prims.convert_element_type.default(mm_292, torch.float32); mm_292 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1801, 'avg', 128, '0'); convert_element_type_1801 = None + wait_tensor_636 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + convert_element_type_1802 = torch.ops.prims.convert_element_type.default(view_1867, torch.float32); view_1867 = None + convert_element_type_1184 = torch.ops.prims.convert_element_type.default(primals_362, torch.bfloat16); primals_362 = None + all_gather_into_tensor_371 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1184, 128, '0'); convert_element_type_1184 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_371); all_gather_into_tensor_371 = None + convert_element_type_1804 = torch.ops.prims.convert_element_type.default(wait_tensor_455, torch.float32); wait_tensor_455 = None + mul_1447 = torch.ops.aten.mul.Tensor(convert_element_type_1802, convert_element_type_1804); convert_element_type_1804 = None + convert_element_type_1185 = torch.ops.prims.convert_element_type.default(getitem_2321, torch.float32); getitem_2321 = None + mul_1042 = torch.ops.aten.mul.Tensor(convert_element_type_1185, rsqrt_67); convert_element_type_1185 = None + mul_1449 = torch.ops.aten.mul.Tensor(mul_1042, mul_1447) + sum_143 = torch.ops.aten.sum.dim_IntList(mul_1449, [2], True); mul_1449 = None + div_160 = torch.ops.aten.div.Tensor(mul_1042, 512) + mul_1450 = torch.ops.aten.mul.Tensor(div_160, sum_143); div_160 = sum_143 = None + sub_653 = torch.ops.aten.sub.Tensor(mul_1447, mul_1450); mul_1447 = mul_1450 = None + mul_1451 = torch.ops.aten.mul.Tensor(sub_653, rsqrt_67); sub_653 = rsqrt_67 = None + mul_1452 = torch.ops.aten.mul.Tensor(convert_element_type_1802, mul_1042); convert_element_type_1802 = mul_1042 = None + sum_144 = torch.ops.aten.sum.dim_IntList(mul_1452, [0, 1]); mul_1452 = None + convert_element_type_1805 = torch.ops.prims.convert_element_type.default(mul_1451, torch.bfloat16); mul_1451 = None + convert_element_type_default_68 = torch.ops.prims.convert_element_type.default(sum_144, torch.float32); sum_144 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_68, 'avg', 128, '0'); convert_element_type_default_68 = None + wait_tensor_637 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + convert_element_type_1808 = torch.ops.prims.convert_element_type.default(sum_142, torch.float32); sum_142 = None + view_1868 = torch.ops.aten.view.default(convert_element_type_1808, [2, 4096, 1, 32, 2]); convert_element_type_1808 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1868); view_1868 = None + mul_1453 = torch.ops.aten.mul.Tensor(view_as_complex_62, clone_9); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_1453); mul_1453 = None + view_1869 = torch.ops.aten.view.default(view_as_real_62, [2, 4096, 1, 64]); view_as_real_62 = None + convert_element_type_1809 = torch.ops.prims.convert_element_type.default(view_1869, torch.bfloat16); view_1869 = None + squeeze_30 = torch.ops.aten.squeeze.dim(convert_element_type_1809, 2); convert_element_type_1809 = None + cat_273 = torch.ops.aten.cat.default([convert_element_type_1805, squeeze_30], 2); convert_element_type_1805 = squeeze_30 = None + view_1870 = torch.ops.aten.view.default(cat_273, [8192, 576]); cat_273 = None + permute_648 = torch.ops.aten.permute.default(view_1870, [1, 0]) + mm_294 = torch.ops.aten.mm.default(permute_648, view_1443); permute_648 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(primals_361, torch.bfloat16); primals_361 = None + all_gather_into_tensor_370 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1179, 128, '0'); convert_element_type_1179 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_370); all_gather_into_tensor_370 = None + slice_133 = torch.ops.aten.slice.Tensor(wait_tensor_454, 0, 0, 576); wait_tensor_454 = None + permute_327 = torch.ops.aten.permute.default(slice_133, [1, 0]); slice_133 = None + permute_650 = torch.ops.aten.permute.default(permute_327, [1, 0]); permute_327 = None + mm_295 = torch.ops.aten.mm.default(view_1870, permute_650); view_1870 = permute_650 = None + view_1871 = torch.ops.aten.view.default(mm_295, [2, 4096, 2048]); mm_295 = None + convert_element_type_1814 = torch.ops.prims.convert_element_type.default(mm_294, torch.float32); mm_294 = None + split_420 = torch.ops.aten.split.Tensor(convert_element_type_1814, 5); convert_element_type_1814 = None + getitem_7753 = split_420[0] + getitem_7754 = split_420[1] + getitem_7755 = split_420[2] + getitem_7756 = split_420[3] + getitem_7757 = split_420[4] + getitem_7758 = split_420[5] + getitem_7759 = split_420[6] + getitem_7760 = split_420[7] + getitem_7761 = split_420[8] + getitem_7762 = split_420[9] + getitem_7763 = split_420[10] + getitem_7764 = split_420[11] + getitem_7765 = split_420[12] + getitem_7766 = split_420[13] + getitem_7767 = split_420[14] + getitem_7768 = split_420[15] + getitem_7769 = split_420[16] + getitem_7770 = split_420[17] + getitem_7771 = split_420[18] + getitem_7772 = split_420[19] + getitem_7773 = split_420[20] + getitem_7774 = split_420[21] + getitem_7775 = split_420[22] + getitem_7776 = split_420[23] + getitem_7777 = split_420[24] + getitem_7778 = split_420[25] + getitem_7779 = split_420[26] + getitem_7780 = split_420[27] + getitem_7781 = split_420[28] + getitem_7782 = split_420[29] + getitem_7783 = split_420[30] + getitem_7784 = split_420[31] + getitem_7785 = split_420[32] + getitem_7786 = split_420[33] + getitem_7787 = split_420[34] + getitem_7788 = split_420[35] + getitem_7789 = split_420[36] + getitem_7790 = split_420[37] + getitem_7791 = split_420[38] + getitem_7792 = split_420[39] + getitem_7793 = split_420[40] + getitem_7794 = split_420[41] + getitem_7795 = split_420[42] + getitem_7796 = split_420[43] + getitem_7797 = split_420[44] + getitem_7798 = split_420[45] + getitem_7799 = split_420[46] + getitem_7800 = split_420[47] + getitem_7801 = split_420[48] + getitem_7802 = split_420[49] + getitem_7803 = split_420[50] + getitem_7804 = split_420[51] + getitem_7805 = split_420[52] + getitem_7806 = split_420[53] + getitem_7807 = split_420[54] + getitem_7808 = split_420[55] + getitem_7809 = split_420[56] + getitem_7810 = split_420[57] + getitem_7811 = split_420[58] + getitem_7812 = split_420[59] + getitem_7813 = split_420[60] + getitem_7814 = split_420[61] + getitem_7815 = split_420[62] + getitem_7816 = split_420[63] + getitem_7817 = split_420[64] + getitem_7818 = split_420[65] + getitem_7819 = split_420[66] + getitem_7820 = split_420[67] + getitem_7821 = split_420[68] + getitem_7822 = split_420[69] + getitem_7823 = split_420[70] + getitem_7824 = split_420[71] + getitem_7825 = split_420[72] + getitem_7826 = split_420[73] + getitem_7827 = split_420[74] + getitem_7828 = split_420[75] + getitem_7829 = split_420[76] + getitem_7830 = split_420[77] + getitem_7831 = split_420[78] + getitem_7832 = split_420[79] + getitem_7833 = split_420[80] + getitem_7834 = split_420[81] + getitem_7835 = split_420[82] + getitem_7836 = split_420[83] + getitem_7837 = split_420[84] + getitem_7838 = split_420[85] + getitem_7839 = split_420[86] + getitem_7840 = split_420[87] + getitem_7841 = split_420[88] + getitem_7842 = split_420[89] + getitem_7843 = split_420[90] + getitem_7844 = split_420[91] + getitem_7845 = split_420[92] + getitem_7846 = split_420[93] + getitem_7847 = split_420[94] + getitem_7848 = split_420[95] + getitem_7849 = split_420[96] + getitem_7850 = split_420[97] + getitem_7851 = split_420[98] + getitem_7852 = split_420[99] + getitem_7853 = split_420[100] + getitem_7854 = split_420[101] + getitem_7855 = split_420[102] + getitem_7856 = split_420[103] + getitem_7857 = split_420[104] + getitem_7858 = split_420[105] + getitem_7859 = split_420[106] + getitem_7860 = split_420[107] + getitem_7861 = split_420[108] + getitem_7862 = split_420[109] + getitem_7863 = split_420[110] + getitem_7864 = split_420[111] + getitem_7865 = split_420[112] + getitem_7866 = split_420[113] + getitem_7867 = split_420[114] + getitem_7868 = split_420[115]; split_420 = None + constant_pad_nd_372 = torch.ops.aten.constant_pad_nd.default(getitem_7868, [0, 0, 0, 4], 0.0); getitem_7868 = None + cat_274 = torch.ops.aten.cat.default([getitem_7753, getitem_7754, getitem_7755, getitem_7756, getitem_7757, getitem_7758, getitem_7759, getitem_7760, getitem_7761, getitem_7762, getitem_7763, getitem_7764, getitem_7765, getitem_7766, getitem_7767, getitem_7768, getitem_7769, getitem_7770, getitem_7771, getitem_7772, getitem_7773, getitem_7774, getitem_7775, getitem_7776, getitem_7777, getitem_7778, getitem_7779, getitem_7780, getitem_7781, getitem_7782, getitem_7783, getitem_7784, getitem_7785, getitem_7786, getitem_7787, getitem_7788, getitem_7789, getitem_7790, getitem_7791, getitem_7792, getitem_7793, getitem_7794, getitem_7795, getitem_7796, getitem_7797, getitem_7798, getitem_7799, getitem_7800, getitem_7801, getitem_7802, getitem_7803, getitem_7804, getitem_7805, getitem_7806, getitem_7807, getitem_7808, getitem_7809, getitem_7810, getitem_7811, getitem_7812, getitem_7813, getitem_7814, getitem_7815, getitem_7816, getitem_7817, getitem_7818, getitem_7819, getitem_7820, getitem_7821, getitem_7822, getitem_7823, getitem_7824, getitem_7825, getitem_7826, getitem_7827, getitem_7828, getitem_7829, getitem_7830, getitem_7831, getitem_7832, getitem_7833, getitem_7834, getitem_7835, getitem_7836, getitem_7837, getitem_7838, getitem_7839, getitem_7840, getitem_7841, getitem_7842, getitem_7843, getitem_7844, getitem_7845, getitem_7846, getitem_7847, getitem_7848, getitem_7849, getitem_7850, getitem_7851, getitem_7852, getitem_7853, getitem_7854, getitem_7855, getitem_7856, getitem_7857, getitem_7858, getitem_7859, getitem_7860, getitem_7861, getitem_7862, getitem_7863, getitem_7864, getitem_7865, getitem_7866, getitem_7867, constant_pad_nd_372, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_7753 = getitem_7754 = getitem_7755 = getitem_7756 = getitem_7757 = getitem_7758 = getitem_7759 = getitem_7760 = getitem_7761 = getitem_7762 = getitem_7763 = getitem_7764 = getitem_7765 = getitem_7766 = getitem_7767 = getitem_7768 = getitem_7769 = getitem_7770 = getitem_7771 = getitem_7772 = getitem_7773 = getitem_7774 = getitem_7775 = getitem_7776 = getitem_7777 = getitem_7778 = getitem_7779 = getitem_7780 = getitem_7781 = getitem_7782 = getitem_7783 = getitem_7784 = getitem_7785 = getitem_7786 = getitem_7787 = getitem_7788 = getitem_7789 = getitem_7790 = getitem_7791 = getitem_7792 = getitem_7793 = getitem_7794 = getitem_7795 = getitem_7796 = getitem_7797 = getitem_7798 = getitem_7799 = getitem_7800 = getitem_7801 = getitem_7802 = getitem_7803 = getitem_7804 = getitem_7805 = getitem_7806 = getitem_7807 = getitem_7808 = getitem_7809 = getitem_7810 = getitem_7811 = getitem_7812 = getitem_7813 = getitem_7814 = getitem_7815 = getitem_7816 = getitem_7817 = getitem_7818 = getitem_7819 = getitem_7820 = getitem_7821 = getitem_7822 = getitem_7823 = getitem_7824 = getitem_7825 = getitem_7826 = getitem_7827 = getitem_7828 = getitem_7829 = getitem_7830 = getitem_7831 = getitem_7832 = getitem_7833 = getitem_7834 = getitem_7835 = getitem_7836 = getitem_7837 = getitem_7838 = getitem_7839 = getitem_7840 = getitem_7841 = getitem_7842 = getitem_7843 = getitem_7844 = getitem_7845 = getitem_7846 = getitem_7847 = getitem_7848 = getitem_7849 = getitem_7850 = getitem_7851 = getitem_7852 = getitem_7853 = getitem_7854 = getitem_7855 = getitem_7856 = getitem_7857 = getitem_7858 = getitem_7859 = getitem_7860 = getitem_7861 = getitem_7862 = getitem_7863 = getitem_7864 = getitem_7865 = getitem_7866 = getitem_7867 = constant_pad_nd_372 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_274, 'avg', 128, '0'); cat_274 = None + wait_tensor_638 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + slice_190 = torch.ops.aten.slice.Tensor(permute_643, 3, 0, 128) + slice_191 = torch.ops.aten.slice.Tensor(permute_643, 3, 128, 192); permute_643 = None + convert_element_type_1815 = torch.ops.prims.convert_element_type.default(slice_191, torch.float32); slice_191 = None + view_1872 = torch.ops.aten.view.default(convert_element_type_1815, [2, 4096, 16, 32, 2]); convert_element_type_1815 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1872); view_1872 = None + mul_1454 = torch.ops.aten.mul.Tensor(view_as_complex_63, clone_9); view_as_complex_63 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_1454); mul_1454 = None + view_1873 = torch.ops.aten.view.default(view_as_real_63, [2, 4096, 16, 64]); view_as_real_63 = None + convert_element_type_1816 = torch.ops.prims.convert_element_type.default(view_1873, torch.bfloat16); view_1873 = None + cat_275 = torch.ops.aten.cat.default([slice_190, convert_element_type_1816], 3); slice_190 = convert_element_type_1816 = None + view_1874 = torch.ops.aten.view.default(cat_275, [2, 4096, 3072]); cat_275 = None + view_1875 = torch.ops.aten.view.default(view_1874, [8192, 3072]); view_1874 = None + permute_652 = torch.ops.aten.permute.default(view_1875, [1, 0]) + mm_296 = torch.ops.aten.mm.default(permute_652, view_1443); permute_652 = view_1443 = None + convert_element_type_1174 = torch.ops.prims.convert_element_type.default(primals_360, torch.bfloat16); primals_360 = None + all_gather_into_tensor_369 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1174, 128, '0'); convert_element_type_1174 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_369); all_gather_into_tensor_369 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_453, [1, 0]); wait_tensor_453 = None + permute_654 = torch.ops.aten.permute.default(permute_326, [1, 0]); permute_326 = None + mm_297 = torch.ops.aten.mm.default(view_1875, permute_654); view_1875 = permute_654 = None + view_1876 = torch.ops.aten.view.default(mm_297, [2, 4096, 2048]); mm_297 = None + add_1848 = torch.ops.aten.add.Tensor(view_1871, view_1876); view_1871 = view_1876 = None + convert_element_type_1821 = torch.ops.prims.convert_element_type.default(mm_296, torch.float32); mm_296 = None + reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1821, 'avg', 128, '0'); convert_element_type_1821 = None + wait_tensor_639 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + convert_element_type_1822 = torch.ops.prims.convert_element_type.default(add_1848, torch.float32); add_1848 = None + convert_element_type_1171 = torch.ops.prims.convert_element_type.default(primals_359, torch.bfloat16); primals_359 = None + all_gather_into_tensor_368 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1171, 128, '0'); convert_element_type_1171 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_368); all_gather_into_tensor_368 = None + convert_element_type_1824 = torch.ops.prims.convert_element_type.default(wait_tensor_452, torch.float32); wait_tensor_452 = None + mul_1455 = torch.ops.aten.mul.Tensor(convert_element_type_1822, convert_element_type_1824); convert_element_type_1824 = None + convert_element_type_1172 = torch.ops.prims.convert_element_type.default(add_1433, torch.float32); add_1433 = None + mul_1038 = torch.ops.aten.mul.Tensor(convert_element_type_1172, rsqrt_66); convert_element_type_1172 = None + mul_1457 = torch.ops.aten.mul.Tensor(mul_1038, mul_1455) + sum_145 = torch.ops.aten.sum.dim_IntList(mul_1457, [2], True); mul_1457 = None + div_161 = torch.ops.aten.div.Tensor(mul_1038, 2048) + mul_1458 = torch.ops.aten.mul.Tensor(div_161, sum_145); div_161 = sum_145 = None + sub_654 = torch.ops.aten.sub.Tensor(mul_1455, mul_1458); mul_1455 = mul_1458 = None + mul_1459 = torch.ops.aten.mul.Tensor(sub_654, rsqrt_66); sub_654 = rsqrt_66 = None + mul_1460 = torch.ops.aten.mul.Tensor(convert_element_type_1822, mul_1038); convert_element_type_1822 = mul_1038 = None + sum_146 = torch.ops.aten.sum.dim_IntList(mul_1460, [0, 1]); mul_1460 = None + convert_element_type_1825 = torch.ops.prims.convert_element_type.default(mul_1459, torch.bfloat16); mul_1459 = None + add_1849 = torch.ops.aten.add.Tensor(add_1847, convert_element_type_1825); add_1847 = convert_element_type_1825 = None + convert_element_type_default_67 = torch.ops.prims.convert_element_type.default(sum_146, torch.float32); sum_146 = None + reduce_scatter_tensor_71 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_67, 'avg', 128, '0'); convert_element_type_default_67 = None + wait_tensor_640 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + view_1877 = torch.ops.aten.view.default(add_1849, [8192, 2048]) + unsqueeze_58 = torch.ops.aten.unsqueeze.default(view_1877, 1) + convert_element_type_1828 = torch.ops.prims.convert_element_type.default(unsqueeze_58, torch.float32); unsqueeze_58 = None + bmm_36 = torch.ops.aten.bmm.default(permute_656, convert_element_type_1828); permute_656 = None + bmm_37 = torch.ops.aten.bmm.default(convert_element_type_1828, permute_657); convert_element_type_1828 = permute_657 = None + convert_element_type_1829 = torch.ops.prims.convert_element_type.default(bmm_36, torch.bfloat16); bmm_36 = None + view_1878 = torch.ops.aten.view.default(bmm_37, [8192, 6]); bmm_37 = None + view_1879 = torch.ops.aten.view.default(convert_element_type_1829, [49152, 2048]); convert_element_type_1829 = None + index_62 = torch.ops.aten.index.Tensor(view_1879, [getitem_2221]); view_1879 = getitem_2221 = None + permute_658 = torch.ops.aten.permute.default(view_1877, [1, 0]) + mm_298 = torch.ops.aten.mm.default(permute_658, mul_1035); permute_658 = mul_1035 = None + convert_element_type_1166 = torch.ops.prims.convert_element_type.default(primals_358, torch.bfloat16); primals_358 = None + all_gather_into_tensor_367 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1166, 128, '0'); convert_element_type_1166 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_367); all_gather_into_tensor_367 = None + permute_325 = torch.ops.aten.permute.default(wait_tensor_451, [1, 0]); wait_tensor_451 = None + permute_660 = torch.ops.aten.permute.default(permute_325, [1, 0]); permute_325 = None + mm_299 = torch.ops.aten.mm.default(view_1877, permute_660); view_1877 = permute_660 = None + convert_element_type_1834 = torch.ops.prims.convert_element_type.default(mm_298, torch.float32); mm_298 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1834, 'avg', 128, '0'); convert_element_type_1834 = None + wait_tensor_641 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + convert_element_type_1161 = torch.ops.prims.convert_element_type.default(mm_172, torch.float32); mm_172 = None + neg_42 = torch.ops.aten.neg.default(convert_element_type_1161) + exp_63 = torch.ops.aten.exp.default(neg_42); neg_42 = None + add_1428 = torch.ops.aten.add.Tensor(exp_63, 1); exp_63 = None + div_105 = torch.ops.aten.div.Tensor(convert_element_type_1161, add_1428) + convert_element_type_1162 = torch.ops.prims.convert_element_type.default(div_105, torch.bfloat16); div_105 = None + mul_1461 = torch.ops.aten.mul.Tensor(mm_299, convert_element_type_1162); convert_element_type_1162 = None + mul_1462 = torch.ops.aten.mul.Tensor(mm_299, mm_173); mm_299 = mm_173 = None + permute_662 = torch.ops.aten.permute.default(mul_1461, [1, 0]) + mm_300 = torch.ops.aten.mm.default(permute_662, view_1398); permute_662 = None + convert_element_type_1163 = torch.ops.prims.convert_element_type.default(primals_357, torch.bfloat16); primals_357 = None + all_gather_into_tensor_366 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1163, 128, '0'); convert_element_type_1163 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_366); all_gather_into_tensor_366 = None + permute_324 = torch.ops.aten.permute.default(wait_tensor_450, [1, 0]); wait_tensor_450 = None + permute_664 = torch.ops.aten.permute.default(permute_324, [1, 0]); permute_324 = None + mm_301 = torch.ops.aten.mm.default(mul_1461, permute_664); mul_1461 = permute_664 = None + convert_element_type_1839 = torch.ops.prims.convert_element_type.default(mm_300, torch.float32); mm_300 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1839, 'avg', 128, '0'); convert_element_type_1839 = None + wait_tensor_642 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + convert_element_type_1840 = torch.ops.prims.convert_element_type.default(mul_1462, torch.float32); mul_1462 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_1428); add_1428 = None + mul_1463 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_1464 = torch.ops.aten.mul.Tensor(convert_element_type_1840, mul_1463); convert_element_type_1840 = None + sub_655 = torch.ops.aten.sub.Tensor(1, mul_1463); mul_1463 = None + mul_1465 = torch.ops.aten.mul.Tensor(convert_element_type_1161, sub_655); convert_element_type_1161 = sub_655 = None + add_1851 = torch.ops.aten.add.Tensor(mul_1465, 1); mul_1465 = None + mul_1466 = torch.ops.aten.mul.Tensor(mul_1464, add_1851); mul_1464 = add_1851 = None + convert_element_type_1842 = torch.ops.prims.convert_element_type.default(mul_1466, torch.bfloat16); mul_1466 = None + permute_666 = torch.ops.aten.permute.default(convert_element_type_1842, [1, 0]) + mm_302 = torch.ops.aten.mm.default(permute_666, view_1398); permute_666 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(primals_356, torch.bfloat16); primals_356 = None + all_gather_into_tensor_365 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1158, 128, '0'); convert_element_type_1158 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_365); all_gather_into_tensor_365 = None + permute_323 = torch.ops.aten.permute.default(wait_tensor_449, [1, 0]); wait_tensor_449 = None + permute_668 = torch.ops.aten.permute.default(permute_323, [1, 0]); permute_323 = None + mm_303 = torch.ops.aten.mm.default(convert_element_type_1842, permute_668); convert_element_type_1842 = permute_668 = None + add_1852 = torch.ops.aten.add.Tensor(mm_301, mm_303); mm_301 = mm_303 = None + convert_element_type_1847 = torch.ops.prims.convert_element_type.default(mm_302, torch.float32); mm_302 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1847, 'avg', 128, '0'); convert_element_type_1847 = None + wait_tensor_643 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + all_to_all_single_88 = torch.ops._c10d_functional.all_to_all_single.default(index_62, [_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335], [_local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327], '1033'); index_62 = None + wait_tensor_644 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_88); all_to_all_single_88 = None + full_378 = torch.ops.aten.full.default([sym_size_int_81, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_81 = None + slice_scatter_5 = torch.ops.aten.slice_scatter.default(full_378, wait_tensor_644, 0, 0, -1); wait_tensor_644 = None + index_63 = torch.ops.aten.index.Tensor(slice_scatter_5, [getitem_2222]); slice_scatter_5 = None + permute_670 = torch.ops.aten.permute.default(index_63, [1, 0]) + _grouped_mm_108 = torch.ops.aten._grouped_mm.default(permute_670, mul_1015, cumsum_62); permute_670 = mul_1015 = None + _grouped_mm_109 = torch.ops.aten._grouped_mm.default(index_63, permute_672, cumsum_62); index_63 = permute_672 = None + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(_grouped_mm_60, torch.float32); _grouped_mm_60 = None + neg_41 = torch.ops.aten.neg.default(convert_element_type_1156) + exp_62 = torch.ops.aten.exp.default(neg_41); neg_41 = None + add_1392 = torch.ops.aten.add.Tensor(exp_62, 1); exp_62 = None + div_104 = torch.ops.aten.div.Tensor(convert_element_type_1156, add_1392) + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(div_104, torch.bfloat16); div_104 = None + mul_1467 = torch.ops.aten.mul.Tensor(_grouped_mm_109, convert_element_type_1157); convert_element_type_1157 = None + mul_1468 = torch.ops.aten.mul.Tensor(_grouped_mm_109, _grouped_mm_61); _grouped_mm_109 = _grouped_mm_61 = None + permute_674 = torch.ops.aten.permute.default(mul_1467, [1, 0]) + _grouped_mm_110 = torch.ops.aten._grouped_mm.default(permute_674, index_41, cumsum_62); permute_674 = None + _grouped_mm_111 = torch.ops.aten._grouped_mm.default(mul_1467, permute_676, cumsum_62); mul_1467 = permute_676 = None + convert_element_type_1848 = torch.ops.prims.convert_element_type.default(mul_1468, torch.float32); mul_1468 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_1392); add_1392 = None + mul_1469 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_1470 = torch.ops.aten.mul.Tensor(convert_element_type_1848, mul_1469); convert_element_type_1848 = None + sub_656 = torch.ops.aten.sub.Tensor(1, mul_1469); mul_1469 = None + mul_1471 = torch.ops.aten.mul.Tensor(convert_element_type_1156, sub_656); convert_element_type_1156 = sub_656 = None + add_1854 = torch.ops.aten.add.Tensor(mul_1471, 1); mul_1471 = None + mul_1472 = torch.ops.aten.mul.Tensor(mul_1470, add_1854); mul_1470 = add_1854 = None + convert_element_type_1850 = torch.ops.prims.convert_element_type.default(mul_1472, torch.bfloat16); mul_1472 = None + permute_678 = torch.ops.aten.permute.default(convert_element_type_1850, [1, 0]) + _grouped_mm_112 = torch.ops.aten._grouped_mm.default(permute_678, index_41, cumsum_62); permute_678 = index_41 = None + _grouped_mm_113 = torch.ops.aten._grouped_mm.default(convert_element_type_1850, permute_680, cumsum_62); convert_element_type_1850 = permute_680 = cumsum_62 = None + add_1855 = torch.ops.aten.add.Tensor(_grouped_mm_111, _grouped_mm_113); _grouped_mm_111 = _grouped_mm_113 = None + convert_element_type_1851 = torch.ops.prims.convert_element_type.default(_grouped_mm_110, torch.float32); _grouped_mm_110 = None + div_162 = torch.ops.aten.div.Tensor(convert_element_type_1851, 128); convert_element_type_1851 = None + split_422 = torch.ops.aten.split.Tensor(div_162, 88, 1); div_162 = None + getitem_7885 = split_422[0] + getitem_7902 = split_422[1] + getitem_7919 = split_422[2] + getitem_7936 = split_422[3] + getitem_7953 = split_422[4] + getitem_7970 = split_422[5] + getitem_7987 = split_422[6] + getitem_8004 = split_422[7] + getitem_8021 = split_422[8] + getitem_8038 = split_422[9] + getitem_8055 = split_422[10] + getitem_8072 = split_422[11] + getitem_8089 = split_422[12] + getitem_8106 = split_422[13] + getitem_8123 = split_422[14] + getitem_8140 = split_422[15]; split_422 = None + cat_276 = torch.ops.aten.cat.default([getitem_7885, getitem_7902, getitem_7919, getitem_7936, getitem_7953, getitem_7970, getitem_7987, getitem_8004, getitem_8021, getitem_8038, getitem_8055, getitem_8072, getitem_8089, getitem_8106, getitem_8123, getitem_8140]); getitem_7885 = getitem_7902 = getitem_7919 = getitem_7936 = getitem_7953 = getitem_7970 = getitem_7987 = getitem_8004 = getitem_8021 = getitem_8038 = getitem_8055 = getitem_8072 = getitem_8089 = getitem_8106 = getitem_8123 = getitem_8140 = None + reduce_scatter_tensor_75 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_276, 'sum', 16, '1025'); cat_276 = None + wait_tensor_645 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + convert_element_type_1852 = torch.ops.prims.convert_element_type.default(_grouped_mm_108, torch.float32); _grouped_mm_108 = None + div_163 = torch.ops.aten.div.Tensor(convert_element_type_1852, 128); convert_element_type_1852 = None + split_439 = torch.ops.aten.split.Tensor(div_163, 128, 1); div_163 = None + getitem_8157 = split_439[0] + getitem_8174 = split_439[1] + getitem_8191 = split_439[2] + getitem_8208 = split_439[3] + getitem_8225 = split_439[4] + getitem_8242 = split_439[5] + getitem_8259 = split_439[6] + getitem_8276 = split_439[7] + getitem_8293 = split_439[8] + getitem_8310 = split_439[9] + getitem_8327 = split_439[10] + getitem_8344 = split_439[11] + getitem_8361 = split_439[12] + getitem_8378 = split_439[13] + getitem_8395 = split_439[14] + getitem_8412 = split_439[15]; split_439 = None + cat_277 = torch.ops.aten.cat.default([getitem_8157, getitem_8174, getitem_8191, getitem_8208, getitem_8225, getitem_8242, getitem_8259, getitem_8276, getitem_8293, getitem_8310, getitem_8327, getitem_8344, getitem_8361, getitem_8378, getitem_8395, getitem_8412]); getitem_8157 = getitem_8174 = getitem_8191 = getitem_8208 = getitem_8225 = getitem_8242 = getitem_8259 = getitem_8276 = getitem_8293 = getitem_8310 = getitem_8327 = getitem_8344 = getitem_8361 = getitem_8378 = getitem_8395 = getitem_8412 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_277, 'sum', 16, '1025'); cat_277 = None + wait_tensor_646 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + convert_element_type_1853 = torch.ops.prims.convert_element_type.default(_grouped_mm_112, torch.float32); _grouped_mm_112 = None + div_164 = torch.ops.aten.div.Tensor(convert_element_type_1853, 128); convert_element_type_1853 = None + split_456 = torch.ops.aten.split.Tensor(div_164, 88, 1); div_164 = None + getitem_8429 = split_456[0] + getitem_8446 = split_456[1] + getitem_8463 = split_456[2] + getitem_8480 = split_456[3] + getitem_8497 = split_456[4] + getitem_8514 = split_456[5] + getitem_8531 = split_456[6] + getitem_8548 = split_456[7] + getitem_8565 = split_456[8] + getitem_8582 = split_456[9] + getitem_8599 = split_456[10] + getitem_8616 = split_456[11] + getitem_8633 = split_456[12] + getitem_8650 = split_456[13] + getitem_8667 = split_456[14] + getitem_8684 = split_456[15]; split_456 = None + cat_278 = torch.ops.aten.cat.default([getitem_8429, getitem_8446, getitem_8463, getitem_8480, getitem_8497, getitem_8514, getitem_8531, getitem_8548, getitem_8565, getitem_8582, getitem_8599, getitem_8616, getitem_8633, getitem_8650, getitem_8667, getitem_8684]); getitem_8429 = getitem_8446 = getitem_8463 = getitem_8480 = getitem_8497 = getitem_8514 = getitem_8531 = getitem_8548 = getitem_8565 = getitem_8582 = getitem_8599 = getitem_8616 = getitem_8633 = getitem_8650 = getitem_8667 = getitem_8684 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_278, 'sum', 16, '1025'); cat_278 = None + wait_tensor_647 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + index_put_62 = torch.ops.aten.index_put.default(full_378, [getitem_2222], add_1855, True); full_378 = getitem_2222 = add_1855 = None + slice_192 = torch.ops.aten.slice.Tensor(index_put_62, 0, 0, add_1856); index_put_62 = add_1856 = None + all_to_all_single_89 = torch.ops._c10d_functional.all_to_all_single.default(slice_192, [_local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327], [_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335], '1033'); slice_192 = _local_scalar_dense_320 = _local_scalar_dense_321 = _local_scalar_dense_322 = _local_scalar_dense_323 = _local_scalar_dense_324 = _local_scalar_dense_325 = _local_scalar_dense_326 = _local_scalar_dense_327 = _local_scalar_dense_328 = _local_scalar_dense_329 = _local_scalar_dense_330 = _local_scalar_dense_331 = _local_scalar_dense_332 = _local_scalar_dense_333 = _local_scalar_dense_334 = _local_scalar_dense_335 = None + wait_tensor_648 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_89); all_to_all_single_89 = None + index_put_63 = torch.ops.aten.index_put.default(full_default_52, [div_102], wait_tensor_648, True); div_102 = wait_tensor_648 = None + add_1860 = torch.ops.aten.add.Tensor(add_1852, index_put_63); add_1852 = index_put_63 = None + mul_1473 = torch.ops.aten.mul.Tensor(view_1878, 1.0); view_1878 = None + scatter_add_5 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_2219, mul_1473); getitem_2219 = mul_1473 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_171, torch.float32); mm_171 = None + sub_480 = torch.ops.aten.sub.Tensor(convert_element_type_1145, amax_20); convert_element_type_1145 = amax_20 = None + exp_61 = torch.ops.aten.exp.default(sub_480); sub_480 = None + div_101 = torch.ops.aten.div.Tensor(exp_61, sum_81); exp_61 = sum_81 = None + mul_1474 = torch.ops.aten.mul.Tensor(scatter_add_5, div_101); scatter_add_5 = None + sum_147 = torch.ops.aten.sum.dim_IntList(mul_1474, [1], True) + neg_70 = torch.ops.aten.neg.default(div_101); div_101 = None + fma_5 = torch.ops.prims.fma.default(neg_70, sum_147, mul_1474); neg_70 = sum_147 = mul_1474 = None + convert_element_type_1854 = torch.ops.prims.convert_element_type.default(fma_5, torch.bfloat16); fma_5 = None + permute_682 = torch.ops.aten.permute.default(convert_element_type_1854, [1, 0]) + mm_304 = torch.ops.aten.mm.default(permute_682, view_1398); permute_682 = view_1398 = None + convert_element_type_1142 = torch.ops.prims.convert_element_type.default(primals_351, torch.bfloat16); primals_351 = None + all_gather_into_tensor_358 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1142, 128, '0'); convert_element_type_1142 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_358); all_gather_into_tensor_358 = None + slice_129 = torch.ops.aten.slice.Tensor(wait_tensor_438, 0, 0, 64); wait_tensor_438 = None + permute_319 = torch.ops.aten.permute.default(slice_129, [1, 0]); slice_129 = None + permute_684 = torch.ops.aten.permute.default(permute_319, [1, 0]); permute_319 = None + mm_305 = torch.ops.aten.mm.default(convert_element_type_1854, permute_684); convert_element_type_1854 = permute_684 = None + add_1861 = torch.ops.aten.add.Tensor(add_1860, mm_305); add_1860 = mm_305 = None + convert_element_type_1859 = torch.ops.prims.convert_element_type.default(mm_304, torch.float32); mm_304 = None + split_472 = torch.ops.aten.split.Tensor(convert_element_type_1859, 1); convert_element_type_1859 = None + getitem_8685 = split_472[0] + getitem_8686 = split_472[1] + getitem_8687 = split_472[2] + getitem_8688 = split_472[3] + getitem_8689 = split_472[4] + getitem_8690 = split_472[5] + getitem_8691 = split_472[6] + getitem_8692 = split_472[7] + getitem_8693 = split_472[8] + getitem_8694 = split_472[9] + getitem_8695 = split_472[10] + getitem_8696 = split_472[11] + getitem_8697 = split_472[12] + getitem_8698 = split_472[13] + getitem_8699 = split_472[14] + getitem_8700 = split_472[15] + getitem_8701 = split_472[16] + getitem_8702 = split_472[17] + getitem_8703 = split_472[18] + getitem_8704 = split_472[19] + getitem_8705 = split_472[20] + getitem_8706 = split_472[21] + getitem_8707 = split_472[22] + getitem_8708 = split_472[23] + getitem_8709 = split_472[24] + getitem_8710 = split_472[25] + getitem_8711 = split_472[26] + getitem_8712 = split_472[27] + getitem_8713 = split_472[28] + getitem_8714 = split_472[29] + getitem_8715 = split_472[30] + getitem_8716 = split_472[31] + getitem_8717 = split_472[32] + getitem_8718 = split_472[33] + getitem_8719 = split_472[34] + getitem_8720 = split_472[35] + getitem_8721 = split_472[36] + getitem_8722 = split_472[37] + getitem_8723 = split_472[38] + getitem_8724 = split_472[39] + getitem_8725 = split_472[40] + getitem_8726 = split_472[41] + getitem_8727 = split_472[42] + getitem_8728 = split_472[43] + getitem_8729 = split_472[44] + getitem_8730 = split_472[45] + getitem_8731 = split_472[46] + getitem_8732 = split_472[47] + getitem_8733 = split_472[48] + getitem_8734 = split_472[49] + getitem_8735 = split_472[50] + getitem_8736 = split_472[51] + getitem_8737 = split_472[52] + getitem_8738 = split_472[53] + getitem_8739 = split_472[54] + getitem_8740 = split_472[55] + getitem_8741 = split_472[56] + getitem_8742 = split_472[57] + getitem_8743 = split_472[58] + getitem_8744 = split_472[59] + getitem_8745 = split_472[60] + getitem_8746 = split_472[61] + getitem_8747 = split_472[62] + getitem_8748 = split_472[63]; split_472 = None + cat_279 = torch.ops.aten.cat.default([getitem_8685, getitem_8686, getitem_8687, getitem_8688, getitem_8689, getitem_8690, getitem_8691, getitem_8692, getitem_8693, getitem_8694, getitem_8695, getitem_8696, getitem_8697, getitem_8698, getitem_8699, getitem_8700, getitem_8701, getitem_8702, getitem_8703, getitem_8704, getitem_8705, getitem_8706, getitem_8707, getitem_8708, getitem_8709, getitem_8710, getitem_8711, getitem_8712, getitem_8713, getitem_8714, getitem_8715, getitem_8716, getitem_8717, getitem_8718, getitem_8719, getitem_8720, getitem_8721, getitem_8722, getitem_8723, getitem_8724, getitem_8725, getitem_8726, getitem_8727, getitem_8728, getitem_8729, getitem_8730, getitem_8731, getitem_8732, getitem_8733, getitem_8734, getitem_8735, getitem_8736, getitem_8737, getitem_8738, getitem_8739, getitem_8740, getitem_8741, getitem_8742, getitem_8743, getitem_8744, getitem_8745, getitem_8746, getitem_8747, getitem_8748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_8685 = getitem_8686 = getitem_8687 = getitem_8688 = getitem_8689 = getitem_8690 = getitem_8691 = getitem_8692 = getitem_8693 = getitem_8694 = getitem_8695 = getitem_8696 = getitem_8697 = getitem_8698 = getitem_8699 = getitem_8700 = getitem_8701 = getitem_8702 = getitem_8703 = getitem_8704 = getitem_8705 = getitem_8706 = getitem_8707 = getitem_8708 = getitem_8709 = getitem_8710 = getitem_8711 = getitem_8712 = getitem_8713 = getitem_8714 = getitem_8715 = getitem_8716 = getitem_8717 = getitem_8718 = getitem_8719 = getitem_8720 = getitem_8721 = getitem_8722 = getitem_8723 = getitem_8724 = getitem_8725 = getitem_8726 = getitem_8727 = getitem_8728 = getitem_8729 = getitem_8730 = getitem_8731 = getitem_8732 = getitem_8733 = getitem_8734 = getitem_8735 = getitem_8736 = getitem_8737 = getitem_8738 = getitem_8739 = getitem_8740 = getitem_8741 = getitem_8742 = getitem_8743 = getitem_8744 = getitem_8745 = getitem_8746 = getitem_8747 = getitem_8748 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_279, 'avg', 128, '0'); cat_279 = None + wait_tensor_649 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + view_1880 = torch.ops.aten.view.default(add_1861, [2, 4096, 2048]); add_1861 = None + convert_element_type_1860 = torch.ops.prims.convert_element_type.default(view_1880, torch.float32); view_1880 = None + convert_element_type_1139 = torch.ops.prims.convert_element_type.default(primals_349, torch.bfloat16); primals_349 = None + all_gather_into_tensor_357 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1139, 128, '0'); convert_element_type_1139 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_357); all_gather_into_tensor_357 = None + convert_element_type_1862 = torch.ops.prims.convert_element_type.default(wait_tensor_437, torch.float32); wait_tensor_437 = None + mul_1475 = torch.ops.aten.mul.Tensor(convert_element_type_1860, convert_element_type_1862); convert_element_type_1862 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(add_1368, torch.float32); add_1368 = None + mul_995 = torch.ops.aten.mul.Tensor(convert_element_type_1140, rsqrt_65); convert_element_type_1140 = None + mul_1477 = torch.ops.aten.mul.Tensor(mul_995, mul_1475) + sum_148 = torch.ops.aten.sum.dim_IntList(mul_1477, [2], True); mul_1477 = None + div_165 = torch.ops.aten.div.Tensor(mul_995, 2048) + mul_1478 = torch.ops.aten.mul.Tensor(div_165, sum_148); div_165 = sum_148 = None + sub_658 = torch.ops.aten.sub.Tensor(mul_1475, mul_1478); mul_1475 = mul_1478 = None + mul_1479 = torch.ops.aten.mul.Tensor(sub_658, rsqrt_65); sub_658 = rsqrt_65 = None + mul_1480 = torch.ops.aten.mul.Tensor(convert_element_type_1860, mul_995); convert_element_type_1860 = mul_995 = None + sum_149 = torch.ops.aten.sum.dim_IntList(mul_1480, [0, 1]); mul_1480 = None + convert_element_type_1863 = torch.ops.prims.convert_element_type.default(mul_1479, torch.bfloat16); mul_1479 = None + add_1862 = torch.ops.aten.add.Tensor(add_1849, convert_element_type_1863); add_1849 = convert_element_type_1863 = None + convert_element_type_default_66 = torch.ops.prims.convert_element_type.default(sum_149, torch.float32); sum_149 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_66, 'avg', 128, '0'); convert_element_type_default_66 = None + wait_tensor_650 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + view_1881 = torch.ops.aten.view.default(add_1862, [8192, 2048]) + permute_686 = torch.ops.aten.permute.default(view_1881, [1, 0]) + permute_317 = torch.ops.aten.permute.default(getitem_2215, [0, 2, 1, 3]) + view_1393 = torch.ops.aten.view.default(permute_317, [2, 4096, -1]); permute_317 = None + view_1395 = torch.ops.aten.view.default(view_1393, [8192, 2048]); view_1393 = None + mm_306 = torch.ops.aten.mm.default(permute_686, view_1395); permute_686 = view_1395 = None + convert_element_type_1136 = torch.ops.prims.convert_element_type.default(primals_348, torch.bfloat16); primals_348 = None + all_gather_into_tensor_356 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1136, 128, '0'); convert_element_type_1136 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_356); all_gather_into_tensor_356 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_436, [1, 0]); wait_tensor_436 = None + permute_688 = torch.ops.aten.permute.default(permute_318, [1, 0]); permute_318 = None + mm_307 = torch.ops.aten.mm.default(view_1881, permute_688); view_1881 = permute_688 = None + view_1882 = torch.ops.aten.view.default(mm_307, [2, 4096, 2048]); mm_307 = None + convert_element_type_1870 = torch.ops.prims.convert_element_type.default(mm_306, torch.float32); mm_306 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1870, 'avg', 128, '0'); convert_element_type_1870 = None + wait_tensor_651 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + view_1883 = torch.ops.aten.view.default(view_1882, [2, 4096, 16, 128]); view_1882 = None + permute_690 = torch.ops.aten.permute.default(view_1883, [0, 2, 1, 3]); view_1883 = None + fw_graph5 = self.fw_graph5 + joint_graph5 = self.joint_graph5 + mask_graph5 = self.mask_graph5 + flex_attention_backward_5 = torch.ops.higher_order.flex_attention_backward(permute_314, permute_315, permute_316, getitem_2215, getitem_2216, permute_690, None, fw_graph5, joint_graph5, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph5), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_314 = permute_315 = permute_316 = getitem_2215 = getitem_2216 = permute_690 = fw_graph5 = joint_graph5 = mask_graph5 = None + getitem_8749 = flex_attention_backward_5[0] + getitem_8750 = flex_attention_backward_5[1] + getitem_8751 = flex_attention_backward_5[2]; flex_attention_backward_5 = None + permute_691 = torch.ops.aten.permute.default(getitem_8751, [0, 2, 1, 3]); getitem_8751 = None + permute_692 = torch.ops.aten.permute.default(getitem_8750, [0, 2, 1, 3]); getitem_8750 = None + permute_693 = torch.ops.aten.permute.default(getitem_8749, [0, 2, 1, 3]); getitem_8749 = None + slice_194 = torch.ops.aten.slice.Tensor(permute_692, 3, 0, 128) + slice_195 = torch.ops.aten.slice.Tensor(permute_692, 3, 128, 192); permute_692 = None + sum_150 = torch.ops.aten.sum.dim_IntList(slice_195, [2], True); slice_195 = None + cat_280 = torch.ops.aten.cat.default([slice_194, permute_691], 3); slice_194 = permute_691 = None + view_1884 = torch.ops.aten.view.default(cat_280, [2, 4096, 4096]); cat_280 = None + view_1885 = torch.ops.aten.view.default(view_1884, [8192, 4096]); view_1884 = None + permute_694 = torch.ops.aten.permute.default(view_1885, [1, 0]) + mm_308 = torch.ops.aten.mm.default(permute_694, view_1390); permute_694 = view_1390 = None + convert_element_type_1133 = torch.ops.prims.convert_element_type.default(primals_347, torch.bfloat16); primals_347 = None + all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1133, 128, '0'); convert_element_type_1133 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_313 = torch.ops.aten.permute.default(wait_tensor_435, [1, 0]); wait_tensor_435 = None + permute_696 = torch.ops.aten.permute.default(permute_313, [1, 0]); permute_313 = None + mm_309 = torch.ops.aten.mm.default(view_1885, permute_696); view_1885 = permute_696 = None + view_1886 = torch.ops.aten.view.default(mm_309, [2, 4096, 512]); mm_309 = None + convert_element_type_1875 = torch.ops.prims.convert_element_type.default(mm_308, torch.float32); mm_308 = None + reduce_scatter_tensor_81 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1875, 'avg', 128, '0'); convert_element_type_1875 = None + wait_tensor_652 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + convert_element_type_1876 = torch.ops.prims.convert_element_type.default(view_1886, torch.float32); view_1886 = None + convert_element_type_1130 = torch.ops.prims.convert_element_type.default(primals_346, torch.bfloat16); primals_346 = None + all_gather_into_tensor_354 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1130, 128, '0'); convert_element_type_1130 = None + wait_tensor_434 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_354); all_gather_into_tensor_354 = None + convert_element_type_1878 = torch.ops.prims.convert_element_type.default(wait_tensor_434, torch.float32); wait_tensor_434 = None + mul_1481 = torch.ops.aten.mul.Tensor(convert_element_type_1876, convert_element_type_1878); convert_element_type_1878 = None + convert_element_type_1131 = torch.ops.prims.convert_element_type.default(getitem_2211, torch.float32); getitem_2211 = None + mul_993 = torch.ops.aten.mul.Tensor(convert_element_type_1131, rsqrt_64); convert_element_type_1131 = None + mul_1483 = torch.ops.aten.mul.Tensor(mul_993, mul_1481) + sum_151 = torch.ops.aten.sum.dim_IntList(mul_1483, [2], True); mul_1483 = None + div_166 = torch.ops.aten.div.Tensor(mul_993, 512) + mul_1484 = torch.ops.aten.mul.Tensor(div_166, sum_151); div_166 = sum_151 = None + sub_659 = torch.ops.aten.sub.Tensor(mul_1481, mul_1484); mul_1481 = mul_1484 = None + mul_1485 = torch.ops.aten.mul.Tensor(sub_659, rsqrt_64); sub_659 = rsqrt_64 = None + mul_1486 = torch.ops.aten.mul.Tensor(convert_element_type_1876, mul_993); convert_element_type_1876 = mul_993 = None + sum_152 = torch.ops.aten.sum.dim_IntList(mul_1486, [0, 1]); mul_1486 = None + convert_element_type_1879 = torch.ops.prims.convert_element_type.default(mul_1485, torch.bfloat16); mul_1485 = None + convert_element_type_default_65 = torch.ops.prims.convert_element_type.default(sum_152, torch.float32); sum_152 = None + reduce_scatter_tensor_82 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_65, 'avg', 128, '0'); convert_element_type_default_65 = None + wait_tensor_653 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + convert_element_type_1882 = torch.ops.prims.convert_element_type.default(sum_150, torch.float32); sum_150 = None + view_1887 = torch.ops.aten.view.default(convert_element_type_1882, [2, 4096, 1, 32, 2]); convert_element_type_1882 = None + view_as_complex_64 = torch.ops.aten.view_as_complex.default(view_1887); view_1887 = None + mul_1487 = torch.ops.aten.mul.Tensor(view_as_complex_64, clone_9); view_as_complex_64 = None + view_as_real_64 = torch.ops.aten.view_as_real.default(mul_1487); mul_1487 = None + view_1888 = torch.ops.aten.view.default(view_as_real_64, [2, 4096, 1, 64]); view_as_real_64 = None + convert_element_type_1883 = torch.ops.prims.convert_element_type.default(view_1888, torch.bfloat16); view_1888 = None + squeeze_31 = torch.ops.aten.squeeze.dim(convert_element_type_1883, 2); convert_element_type_1883 = None + cat_281 = torch.ops.aten.cat.default([convert_element_type_1879, squeeze_31], 2); convert_element_type_1879 = squeeze_31 = None + view_1889 = torch.ops.aten.view.default(cat_281, [8192, 576]); cat_281 = None + permute_698 = torch.ops.aten.permute.default(view_1889, [1, 0]) + mm_310 = torch.ops.aten.mm.default(permute_698, view_1376); permute_698 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(primals_345, torch.bfloat16); primals_345 = None + all_gather_into_tensor_353 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1125, 128, '0'); convert_element_type_1125 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + slice_127 = torch.ops.aten.slice.Tensor(wait_tensor_433, 0, 0, 576); wait_tensor_433 = None + permute_312 = torch.ops.aten.permute.default(slice_127, [1, 0]); slice_127 = None + permute_700 = torch.ops.aten.permute.default(permute_312, [1, 0]); permute_312 = None + mm_311 = torch.ops.aten.mm.default(view_1889, permute_700); view_1889 = permute_700 = None + view_1890 = torch.ops.aten.view.default(mm_311, [2, 4096, 2048]); mm_311 = None + convert_element_type_1888 = torch.ops.prims.convert_element_type.default(mm_310, torch.float32); mm_310 = None + split_473 = torch.ops.aten.split.Tensor(convert_element_type_1888, 5); convert_element_type_1888 = None + getitem_8753 = split_473[0] + getitem_8754 = split_473[1] + getitem_8755 = split_473[2] + getitem_8756 = split_473[3] + getitem_8757 = split_473[4] + getitem_8758 = split_473[5] + getitem_8759 = split_473[6] + getitem_8760 = split_473[7] + getitem_8761 = split_473[8] + getitem_8762 = split_473[9] + getitem_8763 = split_473[10] + getitem_8764 = split_473[11] + getitem_8765 = split_473[12] + getitem_8766 = split_473[13] + getitem_8767 = split_473[14] + getitem_8768 = split_473[15] + getitem_8769 = split_473[16] + getitem_8770 = split_473[17] + getitem_8771 = split_473[18] + getitem_8772 = split_473[19] + getitem_8773 = split_473[20] + getitem_8774 = split_473[21] + getitem_8775 = split_473[22] + getitem_8776 = split_473[23] + getitem_8777 = split_473[24] + getitem_8778 = split_473[25] + getitem_8779 = split_473[26] + getitem_8780 = split_473[27] + getitem_8781 = split_473[28] + getitem_8782 = split_473[29] + getitem_8783 = split_473[30] + getitem_8784 = split_473[31] + getitem_8785 = split_473[32] + getitem_8786 = split_473[33] + getitem_8787 = split_473[34] + getitem_8788 = split_473[35] + getitem_8789 = split_473[36] + getitem_8790 = split_473[37] + getitem_8791 = split_473[38] + getitem_8792 = split_473[39] + getitem_8793 = split_473[40] + getitem_8794 = split_473[41] + getitem_8795 = split_473[42] + getitem_8796 = split_473[43] + getitem_8797 = split_473[44] + getitem_8798 = split_473[45] + getitem_8799 = split_473[46] + getitem_8800 = split_473[47] + getitem_8801 = split_473[48] + getitem_8802 = split_473[49] + getitem_8803 = split_473[50] + getitem_8804 = split_473[51] + getitem_8805 = split_473[52] + getitem_8806 = split_473[53] + getitem_8807 = split_473[54] + getitem_8808 = split_473[55] + getitem_8809 = split_473[56] + getitem_8810 = split_473[57] + getitem_8811 = split_473[58] + getitem_8812 = split_473[59] + getitem_8813 = split_473[60] + getitem_8814 = split_473[61] + getitem_8815 = split_473[62] + getitem_8816 = split_473[63] + getitem_8817 = split_473[64] + getitem_8818 = split_473[65] + getitem_8819 = split_473[66] + getitem_8820 = split_473[67] + getitem_8821 = split_473[68] + getitem_8822 = split_473[69] + getitem_8823 = split_473[70] + getitem_8824 = split_473[71] + getitem_8825 = split_473[72] + getitem_8826 = split_473[73] + getitem_8827 = split_473[74] + getitem_8828 = split_473[75] + getitem_8829 = split_473[76] + getitem_8830 = split_473[77] + getitem_8831 = split_473[78] + getitem_8832 = split_473[79] + getitem_8833 = split_473[80] + getitem_8834 = split_473[81] + getitem_8835 = split_473[82] + getitem_8836 = split_473[83] + getitem_8837 = split_473[84] + getitem_8838 = split_473[85] + getitem_8839 = split_473[86] + getitem_8840 = split_473[87] + getitem_8841 = split_473[88] + getitem_8842 = split_473[89] + getitem_8843 = split_473[90] + getitem_8844 = split_473[91] + getitem_8845 = split_473[92] + getitem_8846 = split_473[93] + getitem_8847 = split_473[94] + getitem_8848 = split_473[95] + getitem_8849 = split_473[96] + getitem_8850 = split_473[97] + getitem_8851 = split_473[98] + getitem_8852 = split_473[99] + getitem_8853 = split_473[100] + getitem_8854 = split_473[101] + getitem_8855 = split_473[102] + getitem_8856 = split_473[103] + getitem_8857 = split_473[104] + getitem_8858 = split_473[105] + getitem_8859 = split_473[106] + getitem_8860 = split_473[107] + getitem_8861 = split_473[108] + getitem_8862 = split_473[109] + getitem_8863 = split_473[110] + getitem_8864 = split_473[111] + getitem_8865 = split_473[112] + getitem_8866 = split_473[113] + getitem_8867 = split_473[114] + getitem_8868 = split_473[115]; split_473 = None + constant_pad_nd_449 = torch.ops.aten.constant_pad_nd.default(getitem_8868, [0, 0, 0, 4], 0.0); getitem_8868 = None + cat_282 = torch.ops.aten.cat.default([getitem_8753, getitem_8754, getitem_8755, getitem_8756, getitem_8757, getitem_8758, getitem_8759, getitem_8760, getitem_8761, getitem_8762, getitem_8763, getitem_8764, getitem_8765, getitem_8766, getitem_8767, getitem_8768, getitem_8769, getitem_8770, getitem_8771, getitem_8772, getitem_8773, getitem_8774, getitem_8775, getitem_8776, getitem_8777, getitem_8778, getitem_8779, getitem_8780, getitem_8781, getitem_8782, getitem_8783, getitem_8784, getitem_8785, getitem_8786, getitem_8787, getitem_8788, getitem_8789, getitem_8790, getitem_8791, getitem_8792, getitem_8793, getitem_8794, getitem_8795, getitem_8796, getitem_8797, getitem_8798, getitem_8799, getitem_8800, getitem_8801, getitem_8802, getitem_8803, getitem_8804, getitem_8805, getitem_8806, getitem_8807, getitem_8808, getitem_8809, getitem_8810, getitem_8811, getitem_8812, getitem_8813, getitem_8814, getitem_8815, getitem_8816, getitem_8817, getitem_8818, getitem_8819, getitem_8820, getitem_8821, getitem_8822, getitem_8823, getitem_8824, getitem_8825, getitem_8826, getitem_8827, getitem_8828, getitem_8829, getitem_8830, getitem_8831, getitem_8832, getitem_8833, getitem_8834, getitem_8835, getitem_8836, getitem_8837, getitem_8838, getitem_8839, getitem_8840, getitem_8841, getitem_8842, getitem_8843, getitem_8844, getitem_8845, getitem_8846, getitem_8847, getitem_8848, getitem_8849, getitem_8850, getitem_8851, getitem_8852, getitem_8853, getitem_8854, getitem_8855, getitem_8856, getitem_8857, getitem_8858, getitem_8859, getitem_8860, getitem_8861, getitem_8862, getitem_8863, getitem_8864, getitem_8865, getitem_8866, getitem_8867, constant_pad_nd_449, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_8753 = getitem_8754 = getitem_8755 = getitem_8756 = getitem_8757 = getitem_8758 = getitem_8759 = getitem_8760 = getitem_8761 = getitem_8762 = getitem_8763 = getitem_8764 = getitem_8765 = getitem_8766 = getitem_8767 = getitem_8768 = getitem_8769 = getitem_8770 = getitem_8771 = getitem_8772 = getitem_8773 = getitem_8774 = getitem_8775 = getitem_8776 = getitem_8777 = getitem_8778 = getitem_8779 = getitem_8780 = getitem_8781 = getitem_8782 = getitem_8783 = getitem_8784 = getitem_8785 = getitem_8786 = getitem_8787 = getitem_8788 = getitem_8789 = getitem_8790 = getitem_8791 = getitem_8792 = getitem_8793 = getitem_8794 = getitem_8795 = getitem_8796 = getitem_8797 = getitem_8798 = getitem_8799 = getitem_8800 = getitem_8801 = getitem_8802 = getitem_8803 = getitem_8804 = getitem_8805 = getitem_8806 = getitem_8807 = getitem_8808 = getitem_8809 = getitem_8810 = getitem_8811 = getitem_8812 = getitem_8813 = getitem_8814 = getitem_8815 = getitem_8816 = getitem_8817 = getitem_8818 = getitem_8819 = getitem_8820 = getitem_8821 = getitem_8822 = getitem_8823 = getitem_8824 = getitem_8825 = getitem_8826 = getitem_8827 = getitem_8828 = getitem_8829 = getitem_8830 = getitem_8831 = getitem_8832 = getitem_8833 = getitem_8834 = getitem_8835 = getitem_8836 = getitem_8837 = getitem_8838 = getitem_8839 = getitem_8840 = getitem_8841 = getitem_8842 = getitem_8843 = getitem_8844 = getitem_8845 = getitem_8846 = getitem_8847 = getitem_8848 = getitem_8849 = getitem_8850 = getitem_8851 = getitem_8852 = getitem_8853 = getitem_8854 = getitem_8855 = getitem_8856 = getitem_8857 = getitem_8858 = getitem_8859 = getitem_8860 = getitem_8861 = getitem_8862 = getitem_8863 = getitem_8864 = getitem_8865 = getitem_8866 = getitem_8867 = constant_pad_nd_449 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_282, 'avg', 128, '0'); cat_282 = None + wait_tensor_654 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + slice_196 = torch.ops.aten.slice.Tensor(permute_693, 3, 0, 128) + slice_197 = torch.ops.aten.slice.Tensor(permute_693, 3, 128, 192); permute_693 = None + convert_element_type_1889 = torch.ops.prims.convert_element_type.default(slice_197, torch.float32); slice_197 = None + view_1891 = torch.ops.aten.view.default(convert_element_type_1889, [2, 4096, 16, 32, 2]); convert_element_type_1889 = None + view_as_complex_65 = torch.ops.aten.view_as_complex.default(view_1891); view_1891 = None + mul_1488 = torch.ops.aten.mul.Tensor(view_as_complex_65, clone_9); view_as_complex_65 = None + view_as_real_65 = torch.ops.aten.view_as_real.default(mul_1488); mul_1488 = None + view_1892 = torch.ops.aten.view.default(view_as_real_65, [2, 4096, 16, 64]); view_as_real_65 = None + convert_element_type_1890 = torch.ops.prims.convert_element_type.default(view_1892, torch.bfloat16); view_1892 = None + cat_283 = torch.ops.aten.cat.default([slice_196, convert_element_type_1890], 3); slice_196 = convert_element_type_1890 = None + view_1893 = torch.ops.aten.view.default(cat_283, [2, 4096, 3072]); cat_283 = None + view_1894 = torch.ops.aten.view.default(view_1893, [8192, 3072]); view_1893 = None + permute_702 = torch.ops.aten.permute.default(view_1894, [1, 0]) + mm_312 = torch.ops.aten.mm.default(permute_702, view_1376); permute_702 = view_1376 = None + convert_element_type_1120 = torch.ops.prims.convert_element_type.default(primals_344, torch.bfloat16); primals_344 = None + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1120, 128, '0'); convert_element_type_1120 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_311 = torch.ops.aten.permute.default(wait_tensor_432, [1, 0]); wait_tensor_432 = None + permute_704 = torch.ops.aten.permute.default(permute_311, [1, 0]); permute_311 = None + mm_313 = torch.ops.aten.mm.default(view_1894, permute_704); view_1894 = permute_704 = None + view_1895 = torch.ops.aten.view.default(mm_313, [2, 4096, 2048]); mm_313 = None + add_1863 = torch.ops.aten.add.Tensor(view_1890, view_1895); view_1890 = view_1895 = None + convert_element_type_1895 = torch.ops.prims.convert_element_type.default(mm_312, torch.float32); mm_312 = None + reduce_scatter_tensor_84 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1895, 'avg', 128, '0'); convert_element_type_1895 = None + wait_tensor_655 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + convert_element_type_1896 = torch.ops.prims.convert_element_type.default(add_1863, torch.float32); add_1863 = None + convert_element_type_1117 = torch.ops.prims.convert_element_type.default(primals_343, torch.bfloat16); primals_343 = None + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1117, 128, '0'); convert_element_type_1117 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); all_gather_into_tensor_351 = None + convert_element_type_1898 = torch.ops.prims.convert_element_type.default(wait_tensor_431, torch.float32); wait_tensor_431 = None + mul_1489 = torch.ops.aten.mul.Tensor(convert_element_type_1896, convert_element_type_1898); convert_element_type_1898 = None + convert_element_type_1118 = torch.ops.prims.convert_element_type.default(add_1365, torch.float32); add_1365 = None + mul_989 = torch.ops.aten.mul.Tensor(convert_element_type_1118, rsqrt_63); convert_element_type_1118 = None + mul_1491 = torch.ops.aten.mul.Tensor(mul_989, mul_1489) + sum_153 = torch.ops.aten.sum.dim_IntList(mul_1491, [2], True); mul_1491 = None + div_167 = torch.ops.aten.div.Tensor(mul_989, 2048) + mul_1492 = torch.ops.aten.mul.Tensor(div_167, sum_153); div_167 = sum_153 = None + sub_660 = torch.ops.aten.sub.Tensor(mul_1489, mul_1492); mul_1489 = mul_1492 = None + mul_1493 = torch.ops.aten.mul.Tensor(sub_660, rsqrt_63); sub_660 = rsqrt_63 = None + mul_1494 = torch.ops.aten.mul.Tensor(convert_element_type_1896, mul_989); convert_element_type_1896 = mul_989 = None + sum_154 = torch.ops.aten.sum.dim_IntList(mul_1494, [0, 1]); mul_1494 = None + convert_element_type_1899 = torch.ops.prims.convert_element_type.default(mul_1493, torch.bfloat16); mul_1493 = None + add_1864 = torch.ops.aten.add.Tensor(add_1862, convert_element_type_1899); add_1862 = convert_element_type_1899 = None + convert_element_type_default_64 = torch.ops.prims.convert_element_type.default(sum_154, torch.float32); sum_154 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_64, 'avg', 128, '0'); convert_element_type_default_64 = None + wait_tensor_656 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); reduce_scatter_tensor_85 = None + view_1896 = torch.ops.aten.view.default(add_1864, [8192, 2048]) + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_1896, 1) + convert_element_type_1902 = torch.ops.prims.convert_element_type.default(unsqueeze_59, torch.float32); unsqueeze_59 = None + bmm_38 = torch.ops.aten.bmm.default(permute_706, convert_element_type_1902); permute_706 = None + bmm_39 = torch.ops.aten.bmm.default(convert_element_type_1902, permute_707); convert_element_type_1902 = permute_707 = None + convert_element_type_1903 = torch.ops.prims.convert_element_type.default(bmm_38, torch.bfloat16); bmm_38 = None + view_1897 = torch.ops.aten.view.default(bmm_39, [8192, 6]); bmm_39 = None + view_1898 = torch.ops.aten.view.default(convert_element_type_1903, [49152, 2048]); convert_element_type_1903 = None + index_64 = torch.ops.aten.index.Tensor(view_1898, [getitem_2111]); view_1898 = getitem_2111 = None + permute_708 = torch.ops.aten.permute.default(view_1896, [1, 0]) + mm_314 = torch.ops.aten.mm.default(permute_708, mul_986); permute_708 = mul_986 = None + convert_element_type_1112 = torch.ops.prims.convert_element_type.default(primals_342, torch.bfloat16); primals_342 = None + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1112, 128, '0'); convert_element_type_1112 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_430, [1, 0]); wait_tensor_430 = None + permute_710 = torch.ops.aten.permute.default(permute_310, [1, 0]); permute_310 = None + mm_315 = torch.ops.aten.mm.default(view_1896, permute_710); view_1896 = permute_710 = None + convert_element_type_1908 = torch.ops.prims.convert_element_type.default(mm_314, torch.float32); mm_314 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1908, 'avg', 128, '0'); convert_element_type_1908 = None + wait_tensor_657 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + convert_element_type_1107 = torch.ops.prims.convert_element_type.default(mm_164, torch.float32); mm_164 = None + neg_40 = torch.ops.aten.neg.default(convert_element_type_1107) + exp_60 = torch.ops.aten.exp.default(neg_40); neg_40 = None + add_1360 = torch.ops.aten.add.Tensor(exp_60, 1); exp_60 = None + div_100 = torch.ops.aten.div.Tensor(convert_element_type_1107, add_1360) + convert_element_type_1108 = torch.ops.prims.convert_element_type.default(div_100, torch.bfloat16); div_100 = None + mul_1495 = torch.ops.aten.mul.Tensor(mm_315, convert_element_type_1108); convert_element_type_1108 = None + mul_1496 = torch.ops.aten.mul.Tensor(mm_315, mm_165); mm_315 = mm_165 = None + permute_712 = torch.ops.aten.permute.default(mul_1495, [1, 0]) + mm_316 = torch.ops.aten.mm.default(permute_712, view_1331); permute_712 = None + convert_element_type_1109 = torch.ops.prims.convert_element_type.default(primals_341, torch.bfloat16); primals_341 = None + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1109, 128, '0'); convert_element_type_1109 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_429, [1, 0]); wait_tensor_429 = None + permute_714 = torch.ops.aten.permute.default(permute_309, [1, 0]); permute_309 = None + mm_317 = torch.ops.aten.mm.default(mul_1495, permute_714); mul_1495 = permute_714 = None + convert_element_type_1913 = torch.ops.prims.convert_element_type.default(mm_316, torch.float32); mm_316 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1913, 'avg', 128, '0'); convert_element_type_1913 = None + wait_tensor_658 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + convert_element_type_1914 = torch.ops.prims.convert_element_type.default(mul_1496, torch.float32); mul_1496 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_1360); add_1360 = None + mul_1497 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_1498 = torch.ops.aten.mul.Tensor(convert_element_type_1914, mul_1497); convert_element_type_1914 = None + sub_661 = torch.ops.aten.sub.Tensor(1, mul_1497); mul_1497 = None + mul_1499 = torch.ops.aten.mul.Tensor(convert_element_type_1107, sub_661); convert_element_type_1107 = sub_661 = None + add_1866 = torch.ops.aten.add.Tensor(mul_1499, 1); mul_1499 = None + mul_1500 = torch.ops.aten.mul.Tensor(mul_1498, add_1866); mul_1498 = add_1866 = None + convert_element_type_1916 = torch.ops.prims.convert_element_type.default(mul_1500, torch.bfloat16); mul_1500 = None + permute_716 = torch.ops.aten.permute.default(convert_element_type_1916, [1, 0]) + mm_318 = torch.ops.aten.mm.default(permute_716, view_1331); permute_716 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(primals_340, torch.bfloat16); primals_340 = None + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1104, 128, '0'); convert_element_type_1104 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_428, [1, 0]); wait_tensor_428 = None + permute_718 = torch.ops.aten.permute.default(permute_308, [1, 0]); permute_308 = None + mm_319 = torch.ops.aten.mm.default(convert_element_type_1916, permute_718); convert_element_type_1916 = permute_718 = None + add_1867 = torch.ops.aten.add.Tensor(mm_317, mm_319); mm_317 = mm_319 = None + convert_element_type_1921 = torch.ops.prims.convert_element_type.default(mm_318, torch.float32); mm_318 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1921, 'avg', 128, '0'); convert_element_type_1921 = None + wait_tensor_659 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + all_to_all_single_90 = torch.ops._c10d_functional.all_to_all_single.default(index_64, [_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319], [_local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311], '1033'); index_64 = None + wait_tensor_660 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_90); all_to_all_single_90 = None + full_384 = torch.ops.aten.full.default([sym_size_int_77, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_77 = None + slice_scatter_6 = torch.ops.aten.slice_scatter.default(full_384, wait_tensor_660, 0, 0, -1); wait_tensor_660 = None + index_65 = torch.ops.aten.index.Tensor(slice_scatter_6, [getitem_2112]); slice_scatter_6 = None + permute_720 = torch.ops.aten.permute.default(index_65, [1, 0]) + _grouped_mm_114 = torch.ops.aten._grouped_mm.default(permute_720, mul_966, cumsum_59); permute_720 = mul_966 = None + _grouped_mm_115 = torch.ops.aten._grouped_mm.default(index_65, permute_722, cumsum_59); index_65 = permute_722 = None + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(_grouped_mm_57, torch.float32); _grouped_mm_57 = None + neg_39 = torch.ops.aten.neg.default(convert_element_type_1102) + exp_59 = torch.ops.aten.exp.default(neg_39); neg_39 = None + add_1324 = torch.ops.aten.add.Tensor(exp_59, 1); exp_59 = None + div_99 = torch.ops.aten.div.Tensor(convert_element_type_1102, add_1324) + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(div_99, torch.bfloat16); div_99 = None + mul_1501 = torch.ops.aten.mul.Tensor(_grouped_mm_115, convert_element_type_1103); convert_element_type_1103 = None + mul_1502 = torch.ops.aten.mul.Tensor(_grouped_mm_115, _grouped_mm_58); _grouped_mm_115 = _grouped_mm_58 = None + permute_724 = torch.ops.aten.permute.default(mul_1501, [1, 0]) + _grouped_mm_116 = torch.ops.aten._grouped_mm.default(permute_724, index_39, cumsum_59); permute_724 = None + _grouped_mm_117 = torch.ops.aten._grouped_mm.default(mul_1501, permute_726, cumsum_59); mul_1501 = permute_726 = None + convert_element_type_1922 = torch.ops.prims.convert_element_type.default(mul_1502, torch.float32); mul_1502 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_1324); add_1324 = None + mul_1503 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_1504 = torch.ops.aten.mul.Tensor(convert_element_type_1922, mul_1503); convert_element_type_1922 = None + sub_662 = torch.ops.aten.sub.Tensor(1, mul_1503); mul_1503 = None + mul_1505 = torch.ops.aten.mul.Tensor(convert_element_type_1102, sub_662); convert_element_type_1102 = sub_662 = None + add_1869 = torch.ops.aten.add.Tensor(mul_1505, 1); mul_1505 = None + mul_1506 = torch.ops.aten.mul.Tensor(mul_1504, add_1869); mul_1504 = add_1869 = None + convert_element_type_1924 = torch.ops.prims.convert_element_type.default(mul_1506, torch.bfloat16); mul_1506 = None + permute_728 = torch.ops.aten.permute.default(convert_element_type_1924, [1, 0]) + _grouped_mm_118 = torch.ops.aten._grouped_mm.default(permute_728, index_39, cumsum_59); permute_728 = index_39 = None + _grouped_mm_119 = torch.ops.aten._grouped_mm.default(convert_element_type_1924, permute_730, cumsum_59); convert_element_type_1924 = permute_730 = cumsum_59 = None + add_1870 = torch.ops.aten.add.Tensor(_grouped_mm_117, _grouped_mm_119); _grouped_mm_117 = _grouped_mm_119 = None + convert_element_type_1925 = torch.ops.prims.convert_element_type.default(_grouped_mm_116, torch.float32); _grouped_mm_116 = None + div_168 = torch.ops.aten.div.Tensor(convert_element_type_1925, 128); convert_element_type_1925 = None + split_475 = torch.ops.aten.split.Tensor(div_168, 88, 1); div_168 = None + getitem_8885 = split_475[0] + getitem_8902 = split_475[1] + getitem_8919 = split_475[2] + getitem_8936 = split_475[3] + getitem_8953 = split_475[4] + getitem_8970 = split_475[5] + getitem_8987 = split_475[6] + getitem_9004 = split_475[7] + getitem_9021 = split_475[8] + getitem_9038 = split_475[9] + getitem_9055 = split_475[10] + getitem_9072 = split_475[11] + getitem_9089 = split_475[12] + getitem_9106 = split_475[13] + getitem_9123 = split_475[14] + getitem_9140 = split_475[15]; split_475 = None + cat_284 = torch.ops.aten.cat.default([getitem_8885, getitem_8902, getitem_8919, getitem_8936, getitem_8953, getitem_8970, getitem_8987, getitem_9004, getitem_9021, getitem_9038, getitem_9055, getitem_9072, getitem_9089, getitem_9106, getitem_9123, getitem_9140]); getitem_8885 = getitem_8902 = getitem_8919 = getitem_8936 = getitem_8953 = getitem_8970 = getitem_8987 = getitem_9004 = getitem_9021 = getitem_9038 = getitem_9055 = getitem_9072 = getitem_9089 = getitem_9106 = getitem_9123 = getitem_9140 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_284, 'sum', 16, '1025'); cat_284 = None + wait_tensor_661 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + convert_element_type_1926 = torch.ops.prims.convert_element_type.default(_grouped_mm_114, torch.float32); _grouped_mm_114 = None + div_169 = torch.ops.aten.div.Tensor(convert_element_type_1926, 128); convert_element_type_1926 = None + split_492 = torch.ops.aten.split.Tensor(div_169, 128, 1); div_169 = None + getitem_9157 = split_492[0] + getitem_9174 = split_492[1] + getitem_9191 = split_492[2] + getitem_9208 = split_492[3] + getitem_9225 = split_492[4] + getitem_9242 = split_492[5] + getitem_9259 = split_492[6] + getitem_9276 = split_492[7] + getitem_9293 = split_492[8] + getitem_9310 = split_492[9] + getitem_9327 = split_492[10] + getitem_9344 = split_492[11] + getitem_9361 = split_492[12] + getitem_9378 = split_492[13] + getitem_9395 = split_492[14] + getitem_9412 = split_492[15]; split_492 = None + cat_285 = torch.ops.aten.cat.default([getitem_9157, getitem_9174, getitem_9191, getitem_9208, getitem_9225, getitem_9242, getitem_9259, getitem_9276, getitem_9293, getitem_9310, getitem_9327, getitem_9344, getitem_9361, getitem_9378, getitem_9395, getitem_9412]); getitem_9157 = getitem_9174 = getitem_9191 = getitem_9208 = getitem_9225 = getitem_9242 = getitem_9259 = getitem_9276 = getitem_9293 = getitem_9310 = getitem_9327 = getitem_9344 = getitem_9361 = getitem_9378 = getitem_9395 = getitem_9412 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_285, 'sum', 16, '1025'); cat_285 = None + wait_tensor_662 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + convert_element_type_1927 = torch.ops.prims.convert_element_type.default(_grouped_mm_118, torch.float32); _grouped_mm_118 = None + div_170 = torch.ops.aten.div.Tensor(convert_element_type_1927, 128); convert_element_type_1927 = None + split_509 = torch.ops.aten.split.Tensor(div_170, 88, 1); div_170 = None + getitem_9429 = split_509[0] + getitem_9446 = split_509[1] + getitem_9463 = split_509[2] + getitem_9480 = split_509[3] + getitem_9497 = split_509[4] + getitem_9514 = split_509[5] + getitem_9531 = split_509[6] + getitem_9548 = split_509[7] + getitem_9565 = split_509[8] + getitem_9582 = split_509[9] + getitem_9599 = split_509[10] + getitem_9616 = split_509[11] + getitem_9633 = split_509[12] + getitem_9650 = split_509[13] + getitem_9667 = split_509[14] + getitem_9684 = split_509[15]; split_509 = None + cat_286 = torch.ops.aten.cat.default([getitem_9429, getitem_9446, getitem_9463, getitem_9480, getitem_9497, getitem_9514, getitem_9531, getitem_9548, getitem_9565, getitem_9582, getitem_9599, getitem_9616, getitem_9633, getitem_9650, getitem_9667, getitem_9684]); getitem_9429 = getitem_9446 = getitem_9463 = getitem_9480 = getitem_9497 = getitem_9514 = getitem_9531 = getitem_9548 = getitem_9565 = getitem_9582 = getitem_9599 = getitem_9616 = getitem_9633 = getitem_9650 = getitem_9667 = getitem_9684 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_286, 'sum', 16, '1025'); cat_286 = None + wait_tensor_663 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + index_put_64 = torch.ops.aten.index_put.default(full_384, [getitem_2112], add_1870, True); full_384 = getitem_2112 = add_1870 = None + slice_198 = torch.ops.aten.slice.Tensor(index_put_64, 0, 0, add_1871); index_put_64 = add_1871 = None + all_to_all_single_91 = torch.ops._c10d_functional.all_to_all_single.default(slice_198, [_local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311], [_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319], '1033'); slice_198 = _local_scalar_dense_304 = _local_scalar_dense_305 = _local_scalar_dense_306 = _local_scalar_dense_307 = _local_scalar_dense_308 = _local_scalar_dense_309 = _local_scalar_dense_310 = _local_scalar_dense_311 = _local_scalar_dense_312 = _local_scalar_dense_313 = _local_scalar_dense_314 = _local_scalar_dense_315 = _local_scalar_dense_316 = _local_scalar_dense_317 = _local_scalar_dense_318 = _local_scalar_dense_319 = None + wait_tensor_664 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_91); all_to_all_single_91 = None + index_put_65 = torch.ops.aten.index_put.default(full_default_52, [div_97], wait_tensor_664, True); div_97 = wait_tensor_664 = None + add_1875 = torch.ops.aten.add.Tensor(add_1867, index_put_65); add_1867 = index_put_65 = None + mul_1507 = torch.ops.aten.mul.Tensor(view_1897, 1.0); view_1897 = None + scatter_add_6 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_2109, mul_1507); getitem_2109 = mul_1507 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_163, torch.float32); mm_163 = None + sub_456 = torch.ops.aten.sub.Tensor(convert_element_type_1091, amax_19); convert_element_type_1091 = amax_19 = None + exp_58 = torch.ops.aten.exp.default(sub_456); sub_456 = None + div_96 = torch.ops.aten.div.Tensor(exp_58, sum_77); exp_58 = sum_77 = None + mul_1508 = torch.ops.aten.mul.Tensor(scatter_add_6, div_96); scatter_add_6 = None + sum_155 = torch.ops.aten.sum.dim_IntList(mul_1508, [1], True) + neg_73 = torch.ops.aten.neg.default(div_96); div_96 = None + fma_6 = torch.ops.prims.fma.default(neg_73, sum_155, mul_1508); neg_73 = sum_155 = mul_1508 = None + convert_element_type_1928 = torch.ops.prims.convert_element_type.default(fma_6, torch.bfloat16); fma_6 = None + permute_732 = torch.ops.aten.permute.default(convert_element_type_1928, [1, 0]) + mm_320 = torch.ops.aten.mm.default(permute_732, view_1331); permute_732 = view_1331 = None + convert_element_type_1088 = torch.ops.prims.convert_element_type.default(primals_335, torch.bfloat16); primals_335 = None + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1088, 128, '0'); convert_element_type_1088 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + slice_123 = torch.ops.aten.slice.Tensor(wait_tensor_417, 0, 0, 64); wait_tensor_417 = None + permute_304 = torch.ops.aten.permute.default(slice_123, [1, 0]); slice_123 = None + permute_734 = torch.ops.aten.permute.default(permute_304, [1, 0]); permute_304 = None + mm_321 = torch.ops.aten.mm.default(convert_element_type_1928, permute_734); convert_element_type_1928 = permute_734 = None + add_1876 = torch.ops.aten.add.Tensor(add_1875, mm_321); add_1875 = mm_321 = None + convert_element_type_1933 = torch.ops.prims.convert_element_type.default(mm_320, torch.float32); mm_320 = None + split_525 = torch.ops.aten.split.Tensor(convert_element_type_1933, 1); convert_element_type_1933 = None + getitem_9685 = split_525[0] + getitem_9686 = split_525[1] + getitem_9687 = split_525[2] + getitem_9688 = split_525[3] + getitem_9689 = split_525[4] + getitem_9690 = split_525[5] + getitem_9691 = split_525[6] + getitem_9692 = split_525[7] + getitem_9693 = split_525[8] + getitem_9694 = split_525[9] + getitem_9695 = split_525[10] + getitem_9696 = split_525[11] + getitem_9697 = split_525[12] + getitem_9698 = split_525[13] + getitem_9699 = split_525[14] + getitem_9700 = split_525[15] + getitem_9701 = split_525[16] + getitem_9702 = split_525[17] + getitem_9703 = split_525[18] + getitem_9704 = split_525[19] + getitem_9705 = split_525[20] + getitem_9706 = split_525[21] + getitem_9707 = split_525[22] + getitem_9708 = split_525[23] + getitem_9709 = split_525[24] + getitem_9710 = split_525[25] + getitem_9711 = split_525[26] + getitem_9712 = split_525[27] + getitem_9713 = split_525[28] + getitem_9714 = split_525[29] + getitem_9715 = split_525[30] + getitem_9716 = split_525[31] + getitem_9717 = split_525[32] + getitem_9718 = split_525[33] + getitem_9719 = split_525[34] + getitem_9720 = split_525[35] + getitem_9721 = split_525[36] + getitem_9722 = split_525[37] + getitem_9723 = split_525[38] + getitem_9724 = split_525[39] + getitem_9725 = split_525[40] + getitem_9726 = split_525[41] + getitem_9727 = split_525[42] + getitem_9728 = split_525[43] + getitem_9729 = split_525[44] + getitem_9730 = split_525[45] + getitem_9731 = split_525[46] + getitem_9732 = split_525[47] + getitem_9733 = split_525[48] + getitem_9734 = split_525[49] + getitem_9735 = split_525[50] + getitem_9736 = split_525[51] + getitem_9737 = split_525[52] + getitem_9738 = split_525[53] + getitem_9739 = split_525[54] + getitem_9740 = split_525[55] + getitem_9741 = split_525[56] + getitem_9742 = split_525[57] + getitem_9743 = split_525[58] + getitem_9744 = split_525[59] + getitem_9745 = split_525[60] + getitem_9746 = split_525[61] + getitem_9747 = split_525[62] + getitem_9748 = split_525[63]; split_525 = None + cat_287 = torch.ops.aten.cat.default([getitem_9685, getitem_9686, getitem_9687, getitem_9688, getitem_9689, getitem_9690, getitem_9691, getitem_9692, getitem_9693, getitem_9694, getitem_9695, getitem_9696, getitem_9697, getitem_9698, getitem_9699, getitem_9700, getitem_9701, getitem_9702, getitem_9703, getitem_9704, getitem_9705, getitem_9706, getitem_9707, getitem_9708, getitem_9709, getitem_9710, getitem_9711, getitem_9712, getitem_9713, getitem_9714, getitem_9715, getitem_9716, getitem_9717, getitem_9718, getitem_9719, getitem_9720, getitem_9721, getitem_9722, getitem_9723, getitem_9724, getitem_9725, getitem_9726, getitem_9727, getitem_9728, getitem_9729, getitem_9730, getitem_9731, getitem_9732, getitem_9733, getitem_9734, getitem_9735, getitem_9736, getitem_9737, getitem_9738, getitem_9739, getitem_9740, getitem_9741, getitem_9742, getitem_9743, getitem_9744, getitem_9745, getitem_9746, getitem_9747, getitem_9748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_9685 = getitem_9686 = getitem_9687 = getitem_9688 = getitem_9689 = getitem_9690 = getitem_9691 = getitem_9692 = getitem_9693 = getitem_9694 = getitem_9695 = getitem_9696 = getitem_9697 = getitem_9698 = getitem_9699 = getitem_9700 = getitem_9701 = getitem_9702 = getitem_9703 = getitem_9704 = getitem_9705 = getitem_9706 = getitem_9707 = getitem_9708 = getitem_9709 = getitem_9710 = getitem_9711 = getitem_9712 = getitem_9713 = getitem_9714 = getitem_9715 = getitem_9716 = getitem_9717 = getitem_9718 = getitem_9719 = getitem_9720 = getitem_9721 = getitem_9722 = getitem_9723 = getitem_9724 = getitem_9725 = getitem_9726 = getitem_9727 = getitem_9728 = getitem_9729 = getitem_9730 = getitem_9731 = getitem_9732 = getitem_9733 = getitem_9734 = getitem_9735 = getitem_9736 = getitem_9737 = getitem_9738 = getitem_9739 = getitem_9740 = getitem_9741 = getitem_9742 = getitem_9743 = getitem_9744 = getitem_9745 = getitem_9746 = getitem_9747 = getitem_9748 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_287, 'avg', 128, '0'); cat_287 = None + wait_tensor_665 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + view_1899 = torch.ops.aten.view.default(add_1876, [2, 4096, 2048]); add_1876 = None + convert_element_type_1934 = torch.ops.prims.convert_element_type.default(view_1899, torch.float32); view_1899 = None + convert_element_type_1085 = torch.ops.prims.convert_element_type.default(primals_333, torch.bfloat16); primals_333 = None + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1085, 128, '0'); convert_element_type_1085 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + convert_element_type_1936 = torch.ops.prims.convert_element_type.default(wait_tensor_416, torch.float32); wait_tensor_416 = None + mul_1509 = torch.ops.aten.mul.Tensor(convert_element_type_1934, convert_element_type_1936); convert_element_type_1936 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(add_1300, torch.float32); add_1300 = None + mul_946 = torch.ops.aten.mul.Tensor(convert_element_type_1086, rsqrt_62); convert_element_type_1086 = None + mul_1511 = torch.ops.aten.mul.Tensor(mul_946, mul_1509) + sum_156 = torch.ops.aten.sum.dim_IntList(mul_1511, [2], True); mul_1511 = None + div_171 = torch.ops.aten.div.Tensor(mul_946, 2048) + mul_1512 = torch.ops.aten.mul.Tensor(div_171, sum_156); div_171 = sum_156 = None + sub_664 = torch.ops.aten.sub.Tensor(mul_1509, mul_1512); mul_1509 = mul_1512 = None + mul_1513 = torch.ops.aten.mul.Tensor(sub_664, rsqrt_62); sub_664 = rsqrt_62 = None + mul_1514 = torch.ops.aten.mul.Tensor(convert_element_type_1934, mul_946); convert_element_type_1934 = mul_946 = None + sum_157 = torch.ops.aten.sum.dim_IntList(mul_1514, [0, 1]); mul_1514 = None + convert_element_type_1937 = torch.ops.prims.convert_element_type.default(mul_1513, torch.bfloat16); mul_1513 = None + add_1877 = torch.ops.aten.add.Tensor(add_1864, convert_element_type_1937); add_1864 = convert_element_type_1937 = None + convert_element_type_default_63 = torch.ops.prims.convert_element_type.default(sum_157, torch.float32); sum_157 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_63, 'avg', 128, '0'); convert_element_type_default_63 = None + wait_tensor_666 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + view_1900 = torch.ops.aten.view.default(add_1877, [8192, 2048]) + permute_736 = torch.ops.aten.permute.default(view_1900, [1, 0]) + permute_302 = torch.ops.aten.permute.default(getitem_2105, [0, 2, 1, 3]) + view_1326 = torch.ops.aten.view.default(permute_302, [2, 4096, -1]); permute_302 = None + view_1328 = torch.ops.aten.view.default(view_1326, [8192, 2048]); view_1326 = None + mm_322 = torch.ops.aten.mm.default(permute_736, view_1328); permute_736 = view_1328 = None + convert_element_type_1082 = torch.ops.prims.convert_element_type.default(primals_332, torch.bfloat16); primals_332 = None + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1082, 128, '0'); convert_element_type_1082 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + permute_303 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); wait_tensor_415 = None + permute_738 = torch.ops.aten.permute.default(permute_303, [1, 0]); permute_303 = None + mm_323 = torch.ops.aten.mm.default(view_1900, permute_738); view_1900 = permute_738 = None + view_1901 = torch.ops.aten.view.default(mm_323, [2, 4096, 2048]); mm_323 = None + convert_element_type_1944 = torch.ops.prims.convert_element_type.default(mm_322, torch.float32); mm_322 = None + reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1944, 'avg', 128, '0'); convert_element_type_1944 = None + wait_tensor_667 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + view_1902 = torch.ops.aten.view.default(view_1901, [2, 4096, 16, 128]); view_1901 = None + permute_740 = torch.ops.aten.permute.default(view_1902, [0, 2, 1, 3]); view_1902 = None + fw_graph6 = self.fw_graph6 + joint_graph6 = self.joint_graph6 + mask_graph6 = self.mask_graph6 + flex_attention_backward_6 = torch.ops.higher_order.flex_attention_backward(permute_299, permute_300, permute_301, getitem_2105, getitem_2106, permute_740, None, fw_graph6, joint_graph6, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph6), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_299 = permute_300 = permute_301 = getitem_2105 = getitem_2106 = permute_740 = fw_graph6 = joint_graph6 = mask_graph6 = None + getitem_9749 = flex_attention_backward_6[0] + getitem_9750 = flex_attention_backward_6[1] + getitem_9751 = flex_attention_backward_6[2]; flex_attention_backward_6 = None + permute_741 = torch.ops.aten.permute.default(getitem_9751, [0, 2, 1, 3]); getitem_9751 = None + permute_742 = torch.ops.aten.permute.default(getitem_9750, [0, 2, 1, 3]); getitem_9750 = None + permute_743 = torch.ops.aten.permute.default(getitem_9749, [0, 2, 1, 3]); getitem_9749 = None + slice_200 = torch.ops.aten.slice.Tensor(permute_742, 3, 0, 128) + slice_201 = torch.ops.aten.slice.Tensor(permute_742, 3, 128, 192); permute_742 = None + sum_158 = torch.ops.aten.sum.dim_IntList(slice_201, [2], True); slice_201 = None + cat_288 = torch.ops.aten.cat.default([slice_200, permute_741], 3); slice_200 = permute_741 = None + view_1903 = torch.ops.aten.view.default(cat_288, [2, 4096, 4096]); cat_288 = None + view_1904 = torch.ops.aten.view.default(view_1903, [8192, 4096]); view_1903 = None + permute_744 = torch.ops.aten.permute.default(view_1904, [1, 0]) + mm_324 = torch.ops.aten.mm.default(permute_744, view_1323); permute_744 = view_1323 = None + convert_element_type_1079 = torch.ops.prims.convert_element_type.default(primals_331, torch.bfloat16); primals_331 = None + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1079, 128, '0'); convert_element_type_1079 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + permute_746 = torch.ops.aten.permute.default(permute_298, [1, 0]); permute_298 = None + mm_325 = torch.ops.aten.mm.default(view_1904, permute_746); view_1904 = permute_746 = None + view_1905 = torch.ops.aten.view.default(mm_325, [2, 4096, 512]); mm_325 = None + convert_element_type_1949 = torch.ops.prims.convert_element_type.default(mm_324, torch.float32); mm_324 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1949, 'avg', 128, '0'); convert_element_type_1949 = None + wait_tensor_668 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + convert_element_type_1950 = torch.ops.prims.convert_element_type.default(view_1905, torch.float32); view_1905 = None + convert_element_type_1076 = torch.ops.prims.convert_element_type.default(primals_330, torch.bfloat16); primals_330 = None + all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1076, 128, '0'); convert_element_type_1076 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1952 = torch.ops.prims.convert_element_type.default(wait_tensor_413, torch.float32); wait_tensor_413 = None + mul_1515 = torch.ops.aten.mul.Tensor(convert_element_type_1950, convert_element_type_1952); convert_element_type_1952 = None + convert_element_type_1077 = torch.ops.prims.convert_element_type.default(getitem_2101, torch.float32); getitem_2101 = None + mul_944 = torch.ops.aten.mul.Tensor(convert_element_type_1077, rsqrt_61); convert_element_type_1077 = None + mul_1517 = torch.ops.aten.mul.Tensor(mul_944, mul_1515) + sum_159 = torch.ops.aten.sum.dim_IntList(mul_1517, [2], True); mul_1517 = None + div_172 = torch.ops.aten.div.Tensor(mul_944, 512) + mul_1518 = torch.ops.aten.mul.Tensor(div_172, sum_159); div_172 = sum_159 = None + sub_665 = torch.ops.aten.sub.Tensor(mul_1515, mul_1518); mul_1515 = mul_1518 = None + mul_1519 = torch.ops.aten.mul.Tensor(sub_665, rsqrt_61); sub_665 = rsqrt_61 = None + mul_1520 = torch.ops.aten.mul.Tensor(convert_element_type_1950, mul_944); convert_element_type_1950 = mul_944 = None + sum_160 = torch.ops.aten.sum.dim_IntList(mul_1520, [0, 1]); mul_1520 = None + convert_element_type_1953 = torch.ops.prims.convert_element_type.default(mul_1519, torch.bfloat16); mul_1519 = None + convert_element_type_default_62 = torch.ops.prims.convert_element_type.default(sum_160, torch.float32); sum_160 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_62, 'avg', 128, '0'); convert_element_type_default_62 = None + wait_tensor_669 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + convert_element_type_1956 = torch.ops.prims.convert_element_type.default(sum_158, torch.float32); sum_158 = None + view_1906 = torch.ops.aten.view.default(convert_element_type_1956, [2, 4096, 1, 32, 2]); convert_element_type_1956 = None + view_as_complex_66 = torch.ops.aten.view_as_complex.default(view_1906); view_1906 = None + mul_1521 = torch.ops.aten.mul.Tensor(view_as_complex_66, clone_9); view_as_complex_66 = None + view_as_real_66 = torch.ops.aten.view_as_real.default(mul_1521); mul_1521 = None + view_1907 = torch.ops.aten.view.default(view_as_real_66, [2, 4096, 1, 64]); view_as_real_66 = None + convert_element_type_1957 = torch.ops.prims.convert_element_type.default(view_1907, torch.bfloat16); view_1907 = None + squeeze_32 = torch.ops.aten.squeeze.dim(convert_element_type_1957, 2); convert_element_type_1957 = None + cat_289 = torch.ops.aten.cat.default([convert_element_type_1953, squeeze_32], 2); convert_element_type_1953 = squeeze_32 = None + view_1908 = torch.ops.aten.view.default(cat_289, [8192, 576]); cat_289 = None + permute_748 = torch.ops.aten.permute.default(view_1908, [1, 0]) + mm_326 = torch.ops.aten.mm.default(permute_748, view_1309); permute_748 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(primals_329, torch.bfloat16); primals_329 = None + all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1071, 128, '0'); convert_element_type_1071 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + slice_121 = torch.ops.aten.slice.Tensor(wait_tensor_412, 0, 0, 576); wait_tensor_412 = None + permute_297 = torch.ops.aten.permute.default(slice_121, [1, 0]); slice_121 = None + permute_750 = torch.ops.aten.permute.default(permute_297, [1, 0]); permute_297 = None + mm_327 = torch.ops.aten.mm.default(view_1908, permute_750); view_1908 = permute_750 = None + view_1909 = torch.ops.aten.view.default(mm_327, [2, 4096, 2048]); mm_327 = None + convert_element_type_1962 = torch.ops.prims.convert_element_type.default(mm_326, torch.float32); mm_326 = None + split_526 = torch.ops.aten.split.Tensor(convert_element_type_1962, 5); convert_element_type_1962 = None + getitem_9753 = split_526[0] + getitem_9754 = split_526[1] + getitem_9755 = split_526[2] + getitem_9756 = split_526[3] + getitem_9757 = split_526[4] + getitem_9758 = split_526[5] + getitem_9759 = split_526[6] + getitem_9760 = split_526[7] + getitem_9761 = split_526[8] + getitem_9762 = split_526[9] + getitem_9763 = split_526[10] + getitem_9764 = split_526[11] + getitem_9765 = split_526[12] + getitem_9766 = split_526[13] + getitem_9767 = split_526[14] + getitem_9768 = split_526[15] + getitem_9769 = split_526[16] + getitem_9770 = split_526[17] + getitem_9771 = split_526[18] + getitem_9772 = split_526[19] + getitem_9773 = split_526[20] + getitem_9774 = split_526[21] + getitem_9775 = split_526[22] + getitem_9776 = split_526[23] + getitem_9777 = split_526[24] + getitem_9778 = split_526[25] + getitem_9779 = split_526[26] + getitem_9780 = split_526[27] + getitem_9781 = split_526[28] + getitem_9782 = split_526[29] + getitem_9783 = split_526[30] + getitem_9784 = split_526[31] + getitem_9785 = split_526[32] + getitem_9786 = split_526[33] + getitem_9787 = split_526[34] + getitem_9788 = split_526[35] + getitem_9789 = split_526[36] + getitem_9790 = split_526[37] + getitem_9791 = split_526[38] + getitem_9792 = split_526[39] + getitem_9793 = split_526[40] + getitem_9794 = split_526[41] + getitem_9795 = split_526[42] + getitem_9796 = split_526[43] + getitem_9797 = split_526[44] + getitem_9798 = split_526[45] + getitem_9799 = split_526[46] + getitem_9800 = split_526[47] + getitem_9801 = split_526[48] + getitem_9802 = split_526[49] + getitem_9803 = split_526[50] + getitem_9804 = split_526[51] + getitem_9805 = split_526[52] + getitem_9806 = split_526[53] + getitem_9807 = split_526[54] + getitem_9808 = split_526[55] + getitem_9809 = split_526[56] + getitem_9810 = split_526[57] + getitem_9811 = split_526[58] + getitem_9812 = split_526[59] + getitem_9813 = split_526[60] + getitem_9814 = split_526[61] + getitem_9815 = split_526[62] + getitem_9816 = split_526[63] + getitem_9817 = split_526[64] + getitem_9818 = split_526[65] + getitem_9819 = split_526[66] + getitem_9820 = split_526[67] + getitem_9821 = split_526[68] + getitem_9822 = split_526[69] + getitem_9823 = split_526[70] + getitem_9824 = split_526[71] + getitem_9825 = split_526[72] + getitem_9826 = split_526[73] + getitem_9827 = split_526[74] + getitem_9828 = split_526[75] + getitem_9829 = split_526[76] + getitem_9830 = split_526[77] + getitem_9831 = split_526[78] + getitem_9832 = split_526[79] + getitem_9833 = split_526[80] + getitem_9834 = split_526[81] + getitem_9835 = split_526[82] + getitem_9836 = split_526[83] + getitem_9837 = split_526[84] + getitem_9838 = split_526[85] + getitem_9839 = split_526[86] + getitem_9840 = split_526[87] + getitem_9841 = split_526[88] + getitem_9842 = split_526[89] + getitem_9843 = split_526[90] + getitem_9844 = split_526[91] + getitem_9845 = split_526[92] + getitem_9846 = split_526[93] + getitem_9847 = split_526[94] + getitem_9848 = split_526[95] + getitem_9849 = split_526[96] + getitem_9850 = split_526[97] + getitem_9851 = split_526[98] + getitem_9852 = split_526[99] + getitem_9853 = split_526[100] + getitem_9854 = split_526[101] + getitem_9855 = split_526[102] + getitem_9856 = split_526[103] + getitem_9857 = split_526[104] + getitem_9858 = split_526[105] + getitem_9859 = split_526[106] + getitem_9860 = split_526[107] + getitem_9861 = split_526[108] + getitem_9862 = split_526[109] + getitem_9863 = split_526[110] + getitem_9864 = split_526[111] + getitem_9865 = split_526[112] + getitem_9866 = split_526[113] + getitem_9867 = split_526[114] + getitem_9868 = split_526[115]; split_526 = None + constant_pad_nd_526 = torch.ops.aten.constant_pad_nd.default(getitem_9868, [0, 0, 0, 4], 0.0); getitem_9868 = None + cat_290 = torch.ops.aten.cat.default([getitem_9753, getitem_9754, getitem_9755, getitem_9756, getitem_9757, getitem_9758, getitem_9759, getitem_9760, getitem_9761, getitem_9762, getitem_9763, getitem_9764, getitem_9765, getitem_9766, getitem_9767, getitem_9768, getitem_9769, getitem_9770, getitem_9771, getitem_9772, getitem_9773, getitem_9774, getitem_9775, getitem_9776, getitem_9777, getitem_9778, getitem_9779, getitem_9780, getitem_9781, getitem_9782, getitem_9783, getitem_9784, getitem_9785, getitem_9786, getitem_9787, getitem_9788, getitem_9789, getitem_9790, getitem_9791, getitem_9792, getitem_9793, getitem_9794, getitem_9795, getitem_9796, getitem_9797, getitem_9798, getitem_9799, getitem_9800, getitem_9801, getitem_9802, getitem_9803, getitem_9804, getitem_9805, getitem_9806, getitem_9807, getitem_9808, getitem_9809, getitem_9810, getitem_9811, getitem_9812, getitem_9813, getitem_9814, getitem_9815, getitem_9816, getitem_9817, getitem_9818, getitem_9819, getitem_9820, getitem_9821, getitem_9822, getitem_9823, getitem_9824, getitem_9825, getitem_9826, getitem_9827, getitem_9828, getitem_9829, getitem_9830, getitem_9831, getitem_9832, getitem_9833, getitem_9834, getitem_9835, getitem_9836, getitem_9837, getitem_9838, getitem_9839, getitem_9840, getitem_9841, getitem_9842, getitem_9843, getitem_9844, getitem_9845, getitem_9846, getitem_9847, getitem_9848, getitem_9849, getitem_9850, getitem_9851, getitem_9852, getitem_9853, getitem_9854, getitem_9855, getitem_9856, getitem_9857, getitem_9858, getitem_9859, getitem_9860, getitem_9861, getitem_9862, getitem_9863, getitem_9864, getitem_9865, getitem_9866, getitem_9867, constant_pad_nd_526, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_9753 = getitem_9754 = getitem_9755 = getitem_9756 = getitem_9757 = getitem_9758 = getitem_9759 = getitem_9760 = getitem_9761 = getitem_9762 = getitem_9763 = getitem_9764 = getitem_9765 = getitem_9766 = getitem_9767 = getitem_9768 = getitem_9769 = getitem_9770 = getitem_9771 = getitem_9772 = getitem_9773 = getitem_9774 = getitem_9775 = getitem_9776 = getitem_9777 = getitem_9778 = getitem_9779 = getitem_9780 = getitem_9781 = getitem_9782 = getitem_9783 = getitem_9784 = getitem_9785 = getitem_9786 = getitem_9787 = getitem_9788 = getitem_9789 = getitem_9790 = getitem_9791 = getitem_9792 = getitem_9793 = getitem_9794 = getitem_9795 = getitem_9796 = getitem_9797 = getitem_9798 = getitem_9799 = getitem_9800 = getitem_9801 = getitem_9802 = getitem_9803 = getitem_9804 = getitem_9805 = getitem_9806 = getitem_9807 = getitem_9808 = getitem_9809 = getitem_9810 = getitem_9811 = getitem_9812 = getitem_9813 = getitem_9814 = getitem_9815 = getitem_9816 = getitem_9817 = getitem_9818 = getitem_9819 = getitem_9820 = getitem_9821 = getitem_9822 = getitem_9823 = getitem_9824 = getitem_9825 = getitem_9826 = getitem_9827 = getitem_9828 = getitem_9829 = getitem_9830 = getitem_9831 = getitem_9832 = getitem_9833 = getitem_9834 = getitem_9835 = getitem_9836 = getitem_9837 = getitem_9838 = getitem_9839 = getitem_9840 = getitem_9841 = getitem_9842 = getitem_9843 = getitem_9844 = getitem_9845 = getitem_9846 = getitem_9847 = getitem_9848 = getitem_9849 = getitem_9850 = getitem_9851 = getitem_9852 = getitem_9853 = getitem_9854 = getitem_9855 = getitem_9856 = getitem_9857 = getitem_9858 = getitem_9859 = getitem_9860 = getitem_9861 = getitem_9862 = getitem_9863 = getitem_9864 = getitem_9865 = getitem_9866 = getitem_9867 = constant_pad_nd_526 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_290, 'avg', 128, '0'); cat_290 = None + wait_tensor_670 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + slice_202 = torch.ops.aten.slice.Tensor(permute_743, 3, 0, 128) + slice_203 = torch.ops.aten.slice.Tensor(permute_743, 3, 128, 192); permute_743 = None + convert_element_type_1963 = torch.ops.prims.convert_element_type.default(slice_203, torch.float32); slice_203 = None + view_1910 = torch.ops.aten.view.default(convert_element_type_1963, [2, 4096, 16, 32, 2]); convert_element_type_1963 = None + view_as_complex_67 = torch.ops.aten.view_as_complex.default(view_1910); view_1910 = None + mul_1522 = torch.ops.aten.mul.Tensor(view_as_complex_67, clone_9); view_as_complex_67 = None + view_as_real_67 = torch.ops.aten.view_as_real.default(mul_1522); mul_1522 = None + view_1911 = torch.ops.aten.view.default(view_as_real_67, [2, 4096, 16, 64]); view_as_real_67 = None + convert_element_type_1964 = torch.ops.prims.convert_element_type.default(view_1911, torch.bfloat16); view_1911 = None + cat_291 = torch.ops.aten.cat.default([slice_202, convert_element_type_1964], 3); slice_202 = convert_element_type_1964 = None + view_1912 = torch.ops.aten.view.default(cat_291, [2, 4096, 3072]); cat_291 = None + view_1913 = torch.ops.aten.view.default(view_1912, [8192, 3072]); view_1912 = None + permute_752 = torch.ops.aten.permute.default(view_1913, [1, 0]) + mm_328 = torch.ops.aten.mm.default(permute_752, view_1309); permute_752 = view_1309 = None + convert_element_type_1066 = torch.ops.prims.convert_element_type.default(primals_328, torch.bfloat16); primals_328 = None + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1066, 128, '0'); convert_element_type_1066 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_411, [1, 0]); wait_tensor_411 = None + permute_754 = torch.ops.aten.permute.default(permute_296, [1, 0]); permute_296 = None + mm_329 = torch.ops.aten.mm.default(view_1913, permute_754); view_1913 = permute_754 = None + view_1914 = torch.ops.aten.view.default(mm_329, [2, 4096, 2048]); mm_329 = None + add_1878 = torch.ops.aten.add.Tensor(view_1909, view_1914); view_1909 = view_1914 = None + convert_element_type_1969 = torch.ops.prims.convert_element_type.default(mm_328, torch.float32); mm_328 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1969, 'avg', 128, '0'); convert_element_type_1969 = None + wait_tensor_671 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + convert_element_type_1970 = torch.ops.prims.convert_element_type.default(add_1878, torch.float32); add_1878 = None + convert_element_type_1063 = torch.ops.prims.convert_element_type.default(primals_327, torch.bfloat16); primals_327 = None + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1063, 128, '0'); convert_element_type_1063 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + convert_element_type_1972 = torch.ops.prims.convert_element_type.default(wait_tensor_410, torch.float32); wait_tensor_410 = None + mul_1523 = torch.ops.aten.mul.Tensor(convert_element_type_1970, convert_element_type_1972); convert_element_type_1972 = None + convert_element_type_1064 = torch.ops.prims.convert_element_type.default(add_1297, torch.float32); add_1297 = None + mul_940 = torch.ops.aten.mul.Tensor(convert_element_type_1064, rsqrt_60); convert_element_type_1064 = None + mul_1525 = torch.ops.aten.mul.Tensor(mul_940, mul_1523) + sum_161 = torch.ops.aten.sum.dim_IntList(mul_1525, [2], True); mul_1525 = None + div_173 = torch.ops.aten.div.Tensor(mul_940, 2048) + mul_1526 = torch.ops.aten.mul.Tensor(div_173, sum_161); div_173 = sum_161 = None + sub_666 = torch.ops.aten.sub.Tensor(mul_1523, mul_1526); mul_1523 = mul_1526 = None + mul_1527 = torch.ops.aten.mul.Tensor(sub_666, rsqrt_60); sub_666 = rsqrt_60 = None + mul_1528 = torch.ops.aten.mul.Tensor(convert_element_type_1970, mul_940); convert_element_type_1970 = mul_940 = None + sum_162 = torch.ops.aten.sum.dim_IntList(mul_1528, [0, 1]); mul_1528 = None + convert_element_type_1973 = torch.ops.prims.convert_element_type.default(mul_1527, torch.bfloat16); mul_1527 = None + add_1879 = torch.ops.aten.add.Tensor(add_1877, convert_element_type_1973); add_1877 = convert_element_type_1973 = None + convert_element_type_default_61 = torch.ops.prims.convert_element_type.default(sum_162, torch.float32); sum_162 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_61, 'avg', 128, '0'); convert_element_type_default_61 = None + wait_tensor_672 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + view_1915 = torch.ops.aten.view.default(add_1879, [8192, 2048]) + unsqueeze_60 = torch.ops.aten.unsqueeze.default(view_1915, 1) + convert_element_type_1976 = torch.ops.prims.convert_element_type.default(unsqueeze_60, torch.float32); unsqueeze_60 = None + bmm_40 = torch.ops.aten.bmm.default(permute_756, convert_element_type_1976); permute_756 = None + bmm_41 = torch.ops.aten.bmm.default(convert_element_type_1976, permute_757); convert_element_type_1976 = permute_757 = None + convert_element_type_1977 = torch.ops.prims.convert_element_type.default(bmm_40, torch.bfloat16); bmm_40 = None + view_1916 = torch.ops.aten.view.default(bmm_41, [8192, 6]); bmm_41 = None + view_1917 = torch.ops.aten.view.default(convert_element_type_1977, [49152, 2048]); convert_element_type_1977 = None + index_66 = torch.ops.aten.index.Tensor(view_1917, [getitem_2001]); view_1917 = getitem_2001 = None + permute_758 = torch.ops.aten.permute.default(view_1915, [1, 0]) + mm_330 = torch.ops.aten.mm.default(permute_758, mul_937); permute_758 = mul_937 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(primals_326, torch.bfloat16); primals_326 = None + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1058, 128, '0'); convert_element_type_1058 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + permute_760 = torch.ops.aten.permute.default(permute_295, [1, 0]); permute_295 = None + mm_331 = torch.ops.aten.mm.default(view_1915, permute_760); view_1915 = permute_760 = None + convert_element_type_1982 = torch.ops.prims.convert_element_type.default(mm_330, torch.float32); mm_330 = None + reduce_scatter_tensor_100 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1982, 'avg', 128, '0'); convert_element_type_1982 = None + wait_tensor_673 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); reduce_scatter_tensor_100 = None + convert_element_type_1053 = torch.ops.prims.convert_element_type.default(mm_156, torch.float32); mm_156 = None + neg_38 = torch.ops.aten.neg.default(convert_element_type_1053) + exp_57 = torch.ops.aten.exp.default(neg_38); neg_38 = None + add_1292 = torch.ops.aten.add.Tensor(exp_57, 1); exp_57 = None + div_95 = torch.ops.aten.div.Tensor(convert_element_type_1053, add_1292) + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(div_95, torch.bfloat16); div_95 = None + mul_1529 = torch.ops.aten.mul.Tensor(mm_331, convert_element_type_1054); convert_element_type_1054 = None + mul_1530 = torch.ops.aten.mul.Tensor(mm_331, mm_157); mm_331 = mm_157 = None + permute_762 = torch.ops.aten.permute.default(mul_1529, [1, 0]) + mm_332 = torch.ops.aten.mm.default(permute_762, view_1264); permute_762 = None + convert_element_type_1055 = torch.ops.prims.convert_element_type.default(primals_325, torch.bfloat16); primals_325 = None + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1055, 128, '0'); convert_element_type_1055 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + permute_764 = torch.ops.aten.permute.default(permute_294, [1, 0]); permute_294 = None + mm_333 = torch.ops.aten.mm.default(mul_1529, permute_764); mul_1529 = permute_764 = None + convert_element_type_1987 = torch.ops.prims.convert_element_type.default(mm_332, torch.float32); mm_332 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1987, 'avg', 128, '0'); convert_element_type_1987 = None + wait_tensor_674 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + convert_element_type_1988 = torch.ops.prims.convert_element_type.default(mul_1530, torch.float32); mul_1530 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_1292); add_1292 = None + mul_1531 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_1532 = torch.ops.aten.mul.Tensor(convert_element_type_1988, mul_1531); convert_element_type_1988 = None + sub_667 = torch.ops.aten.sub.Tensor(1, mul_1531); mul_1531 = None + mul_1533 = torch.ops.aten.mul.Tensor(convert_element_type_1053, sub_667); convert_element_type_1053 = sub_667 = None + add_1881 = torch.ops.aten.add.Tensor(mul_1533, 1); mul_1533 = None + mul_1534 = torch.ops.aten.mul.Tensor(mul_1532, add_1881); mul_1532 = add_1881 = None + convert_element_type_1990 = torch.ops.prims.convert_element_type.default(mul_1534, torch.bfloat16); mul_1534 = None + permute_766 = torch.ops.aten.permute.default(convert_element_type_1990, [1, 0]) + mm_334 = torch.ops.aten.mm.default(permute_766, view_1264); permute_766 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(primals_324, torch.bfloat16); primals_324 = None + all_gather_into_tensor_331 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1050, 128, '0'); convert_element_type_1050 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + permute_768 = torch.ops.aten.permute.default(permute_293, [1, 0]); permute_293 = None + mm_335 = torch.ops.aten.mm.default(convert_element_type_1990, permute_768); convert_element_type_1990 = permute_768 = None + add_1882 = torch.ops.aten.add.Tensor(mm_333, mm_335); mm_333 = mm_335 = None + convert_element_type_1995 = torch.ops.prims.convert_element_type.default(mm_334, torch.float32); mm_334 = None + reduce_scatter_tensor_102 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1995, 'avg', 128, '0'); convert_element_type_1995 = None + wait_tensor_675 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); reduce_scatter_tensor_102 = None + all_to_all_single_92 = torch.ops._c10d_functional.all_to_all_single.default(index_66, [_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303], [_local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295], '1033'); index_66 = None + wait_tensor_676 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_92); all_to_all_single_92 = None + full_390 = torch.ops.aten.full.default([sym_size_int_73, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_73 = None + slice_scatter_7 = torch.ops.aten.slice_scatter.default(full_390, wait_tensor_676, 0, 0, -1); wait_tensor_676 = None + index_67 = torch.ops.aten.index.Tensor(slice_scatter_7, [getitem_2002]); slice_scatter_7 = None + permute_770 = torch.ops.aten.permute.default(index_67, [1, 0]) + _grouped_mm_120 = torch.ops.aten._grouped_mm.default(permute_770, mul_917, cumsum_56); permute_770 = mul_917 = None + _grouped_mm_121 = torch.ops.aten._grouped_mm.default(index_67, permute_772, cumsum_56); index_67 = permute_772 = None + convert_element_type_1048 = torch.ops.prims.convert_element_type.default(_grouped_mm_54, torch.float32); _grouped_mm_54 = None + neg_37 = torch.ops.aten.neg.default(convert_element_type_1048) + exp_56 = torch.ops.aten.exp.default(neg_37); neg_37 = None + add_1256 = torch.ops.aten.add.Tensor(exp_56, 1); exp_56 = None + div_94 = torch.ops.aten.div.Tensor(convert_element_type_1048, add_1256) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(div_94, torch.bfloat16); div_94 = None + mul_1535 = torch.ops.aten.mul.Tensor(_grouped_mm_121, convert_element_type_1049); convert_element_type_1049 = None + mul_1536 = torch.ops.aten.mul.Tensor(_grouped_mm_121, _grouped_mm_55); _grouped_mm_121 = _grouped_mm_55 = None + permute_774 = torch.ops.aten.permute.default(mul_1535, [1, 0]) + _grouped_mm_122 = torch.ops.aten._grouped_mm.default(permute_774, index_37, cumsum_56); permute_774 = None + _grouped_mm_123 = torch.ops.aten._grouped_mm.default(mul_1535, permute_776, cumsum_56); mul_1535 = permute_776 = None + convert_element_type_1996 = torch.ops.prims.convert_element_type.default(mul_1536, torch.float32); mul_1536 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_1256); add_1256 = None + mul_1537 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_1538 = torch.ops.aten.mul.Tensor(convert_element_type_1996, mul_1537); convert_element_type_1996 = None + sub_668 = torch.ops.aten.sub.Tensor(1, mul_1537); mul_1537 = None + mul_1539 = torch.ops.aten.mul.Tensor(convert_element_type_1048, sub_668); convert_element_type_1048 = sub_668 = None + add_1884 = torch.ops.aten.add.Tensor(mul_1539, 1); mul_1539 = None + mul_1540 = torch.ops.aten.mul.Tensor(mul_1538, add_1884); mul_1538 = add_1884 = None + convert_element_type_1998 = torch.ops.prims.convert_element_type.default(mul_1540, torch.bfloat16); mul_1540 = None + permute_778 = torch.ops.aten.permute.default(convert_element_type_1998, [1, 0]) + _grouped_mm_124 = torch.ops.aten._grouped_mm.default(permute_778, index_37, cumsum_56); permute_778 = index_37 = None + _grouped_mm_125 = torch.ops.aten._grouped_mm.default(convert_element_type_1998, permute_780, cumsum_56); convert_element_type_1998 = permute_780 = cumsum_56 = None + add_1885 = torch.ops.aten.add.Tensor(_grouped_mm_123, _grouped_mm_125); _grouped_mm_123 = _grouped_mm_125 = None + convert_element_type_1999 = torch.ops.prims.convert_element_type.default(_grouped_mm_122, torch.float32); _grouped_mm_122 = None + div_174 = torch.ops.aten.div.Tensor(convert_element_type_1999, 128); convert_element_type_1999 = None + split_528 = torch.ops.aten.split.Tensor(div_174, 88, 1); div_174 = None + getitem_9885 = split_528[0] + getitem_9902 = split_528[1] + getitem_9919 = split_528[2] + getitem_9936 = split_528[3] + getitem_9953 = split_528[4] + getitem_9970 = split_528[5] + getitem_9987 = split_528[6] + getitem_10004 = split_528[7] + getitem_10021 = split_528[8] + getitem_10038 = split_528[9] + getitem_10055 = split_528[10] + getitem_10072 = split_528[11] + getitem_10089 = split_528[12] + getitem_10106 = split_528[13] + getitem_10123 = split_528[14] + getitem_10140 = split_528[15]; split_528 = None + cat_292 = torch.ops.aten.cat.default([getitem_9885, getitem_9902, getitem_9919, getitem_9936, getitem_9953, getitem_9970, getitem_9987, getitem_10004, getitem_10021, getitem_10038, getitem_10055, getitem_10072, getitem_10089, getitem_10106, getitem_10123, getitem_10140]); getitem_9885 = getitem_9902 = getitem_9919 = getitem_9936 = getitem_9953 = getitem_9970 = getitem_9987 = getitem_10004 = getitem_10021 = getitem_10038 = getitem_10055 = getitem_10072 = getitem_10089 = getitem_10106 = getitem_10123 = getitem_10140 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_292, 'sum', 16, '1025'); cat_292 = None + wait_tensor_677 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + convert_element_type_2000 = torch.ops.prims.convert_element_type.default(_grouped_mm_120, torch.float32); _grouped_mm_120 = None + div_175 = torch.ops.aten.div.Tensor(convert_element_type_2000, 128); convert_element_type_2000 = None + split_545 = torch.ops.aten.split.Tensor(div_175, 128, 1); div_175 = None + getitem_10157 = split_545[0] + getitem_10174 = split_545[1] + getitem_10191 = split_545[2] + getitem_10208 = split_545[3] + getitem_10225 = split_545[4] + getitem_10242 = split_545[5] + getitem_10259 = split_545[6] + getitem_10276 = split_545[7] + getitem_10293 = split_545[8] + getitem_10310 = split_545[9] + getitem_10327 = split_545[10] + getitem_10344 = split_545[11] + getitem_10361 = split_545[12] + getitem_10378 = split_545[13] + getitem_10395 = split_545[14] + getitem_10412 = split_545[15]; split_545 = None + cat_293 = torch.ops.aten.cat.default([getitem_10157, getitem_10174, getitem_10191, getitem_10208, getitem_10225, getitem_10242, getitem_10259, getitem_10276, getitem_10293, getitem_10310, getitem_10327, getitem_10344, getitem_10361, getitem_10378, getitem_10395, getitem_10412]); getitem_10157 = getitem_10174 = getitem_10191 = getitem_10208 = getitem_10225 = getitem_10242 = getitem_10259 = getitem_10276 = getitem_10293 = getitem_10310 = getitem_10327 = getitem_10344 = getitem_10361 = getitem_10378 = getitem_10395 = getitem_10412 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_293, 'sum', 16, '1025'); cat_293 = None + wait_tensor_678 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + convert_element_type_2001 = torch.ops.prims.convert_element_type.default(_grouped_mm_124, torch.float32); _grouped_mm_124 = None + div_176 = torch.ops.aten.div.Tensor(convert_element_type_2001, 128); convert_element_type_2001 = None + split_562 = torch.ops.aten.split.Tensor(div_176, 88, 1); div_176 = None + getitem_10429 = split_562[0] + getitem_10446 = split_562[1] + getitem_10463 = split_562[2] + getitem_10480 = split_562[3] + getitem_10497 = split_562[4] + getitem_10514 = split_562[5] + getitem_10531 = split_562[6] + getitem_10548 = split_562[7] + getitem_10565 = split_562[8] + getitem_10582 = split_562[9] + getitem_10599 = split_562[10] + getitem_10616 = split_562[11] + getitem_10633 = split_562[12] + getitem_10650 = split_562[13] + getitem_10667 = split_562[14] + getitem_10684 = split_562[15]; split_562 = None + cat_294 = torch.ops.aten.cat.default([getitem_10429, getitem_10446, getitem_10463, getitem_10480, getitem_10497, getitem_10514, getitem_10531, getitem_10548, getitem_10565, getitem_10582, getitem_10599, getitem_10616, getitem_10633, getitem_10650, getitem_10667, getitem_10684]); getitem_10429 = getitem_10446 = getitem_10463 = getitem_10480 = getitem_10497 = getitem_10514 = getitem_10531 = getitem_10548 = getitem_10565 = getitem_10582 = getitem_10599 = getitem_10616 = getitem_10633 = getitem_10650 = getitem_10667 = getitem_10684 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_294, 'sum', 16, '1025'); cat_294 = None + wait_tensor_679 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + index_put_66 = torch.ops.aten.index_put.default(full_390, [getitem_2002], add_1885, True); full_390 = getitem_2002 = add_1885 = None + slice_204 = torch.ops.aten.slice.Tensor(index_put_66, 0, 0, add_1886); index_put_66 = add_1886 = None + all_to_all_single_93 = torch.ops._c10d_functional.all_to_all_single.default(slice_204, [_local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295], [_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303], '1033'); slice_204 = _local_scalar_dense_288 = _local_scalar_dense_289 = _local_scalar_dense_290 = _local_scalar_dense_291 = _local_scalar_dense_292 = _local_scalar_dense_293 = _local_scalar_dense_294 = _local_scalar_dense_295 = _local_scalar_dense_296 = _local_scalar_dense_297 = _local_scalar_dense_298 = _local_scalar_dense_299 = _local_scalar_dense_300 = _local_scalar_dense_301 = _local_scalar_dense_302 = _local_scalar_dense_303 = None + wait_tensor_680 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_93); all_to_all_single_93 = None + index_put_67 = torch.ops.aten.index_put.default(full_default_52, [div_92], wait_tensor_680, True); div_92 = wait_tensor_680 = None + add_1890 = torch.ops.aten.add.Tensor(add_1882, index_put_67); add_1882 = index_put_67 = None + mul_1541 = torch.ops.aten.mul.Tensor(view_1916, 1.0); view_1916 = None + scatter_add_7 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1999, mul_1541); getitem_1999 = mul_1541 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(mm_155, torch.float32); mm_155 = None + sub_432 = torch.ops.aten.sub.Tensor(convert_element_type_1037, amax_18); convert_element_type_1037 = amax_18 = None + exp_55 = torch.ops.aten.exp.default(sub_432); sub_432 = None + div_91 = torch.ops.aten.div.Tensor(exp_55, sum_73); exp_55 = sum_73 = None + mul_1542 = torch.ops.aten.mul.Tensor(scatter_add_7, div_91); scatter_add_7 = None + sum_163 = torch.ops.aten.sum.dim_IntList(mul_1542, [1], True) + neg_76 = torch.ops.aten.neg.default(div_91); div_91 = None + fma_7 = torch.ops.prims.fma.default(neg_76, sum_163, mul_1542); neg_76 = sum_163 = mul_1542 = None + convert_element_type_2002 = torch.ops.prims.convert_element_type.default(fma_7, torch.bfloat16); fma_7 = None + permute_782 = torch.ops.aten.permute.default(convert_element_type_2002, [1, 0]) + mm_336 = torch.ops.aten.mm.default(permute_782, view_1264); permute_782 = view_1264 = None + convert_element_type_1034 = torch.ops.prims.convert_element_type.default(primals_319, torch.bfloat16); primals_319 = None + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1034, 128, '0'); convert_element_type_1034 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + slice_117 = torch.ops.aten.slice.Tensor(wait_tensor_396, 0, 0, 64); wait_tensor_396 = None + permute_289 = torch.ops.aten.permute.default(slice_117, [1, 0]); slice_117 = None + permute_784 = torch.ops.aten.permute.default(permute_289, [1, 0]); permute_289 = None + mm_337 = torch.ops.aten.mm.default(convert_element_type_2002, permute_784); convert_element_type_2002 = permute_784 = None + add_1891 = torch.ops.aten.add.Tensor(add_1890, mm_337); add_1890 = mm_337 = None + convert_element_type_2007 = torch.ops.prims.convert_element_type.default(mm_336, torch.float32); mm_336 = None + split_578 = torch.ops.aten.split.Tensor(convert_element_type_2007, 1); convert_element_type_2007 = None + getitem_10685 = split_578[0] + getitem_10686 = split_578[1] + getitem_10687 = split_578[2] + getitem_10688 = split_578[3] + getitem_10689 = split_578[4] + getitem_10690 = split_578[5] + getitem_10691 = split_578[6] + getitem_10692 = split_578[7] + getitem_10693 = split_578[8] + getitem_10694 = split_578[9] + getitem_10695 = split_578[10] + getitem_10696 = split_578[11] + getitem_10697 = split_578[12] + getitem_10698 = split_578[13] + getitem_10699 = split_578[14] + getitem_10700 = split_578[15] + getitem_10701 = split_578[16] + getitem_10702 = split_578[17] + getitem_10703 = split_578[18] + getitem_10704 = split_578[19] + getitem_10705 = split_578[20] + getitem_10706 = split_578[21] + getitem_10707 = split_578[22] + getitem_10708 = split_578[23] + getitem_10709 = split_578[24] + getitem_10710 = split_578[25] + getitem_10711 = split_578[26] + getitem_10712 = split_578[27] + getitem_10713 = split_578[28] + getitem_10714 = split_578[29] + getitem_10715 = split_578[30] + getitem_10716 = split_578[31] + getitem_10717 = split_578[32] + getitem_10718 = split_578[33] + getitem_10719 = split_578[34] + getitem_10720 = split_578[35] + getitem_10721 = split_578[36] + getitem_10722 = split_578[37] + getitem_10723 = split_578[38] + getitem_10724 = split_578[39] + getitem_10725 = split_578[40] + getitem_10726 = split_578[41] + getitem_10727 = split_578[42] + getitem_10728 = split_578[43] + getitem_10729 = split_578[44] + getitem_10730 = split_578[45] + getitem_10731 = split_578[46] + getitem_10732 = split_578[47] + getitem_10733 = split_578[48] + getitem_10734 = split_578[49] + getitem_10735 = split_578[50] + getitem_10736 = split_578[51] + getitem_10737 = split_578[52] + getitem_10738 = split_578[53] + getitem_10739 = split_578[54] + getitem_10740 = split_578[55] + getitem_10741 = split_578[56] + getitem_10742 = split_578[57] + getitem_10743 = split_578[58] + getitem_10744 = split_578[59] + getitem_10745 = split_578[60] + getitem_10746 = split_578[61] + getitem_10747 = split_578[62] + getitem_10748 = split_578[63]; split_578 = None + cat_295 = torch.ops.aten.cat.default([getitem_10685, getitem_10686, getitem_10687, getitem_10688, getitem_10689, getitem_10690, getitem_10691, getitem_10692, getitem_10693, getitem_10694, getitem_10695, getitem_10696, getitem_10697, getitem_10698, getitem_10699, getitem_10700, getitem_10701, getitem_10702, getitem_10703, getitem_10704, getitem_10705, getitem_10706, getitem_10707, getitem_10708, getitem_10709, getitem_10710, getitem_10711, getitem_10712, getitem_10713, getitem_10714, getitem_10715, getitem_10716, getitem_10717, getitem_10718, getitem_10719, getitem_10720, getitem_10721, getitem_10722, getitem_10723, getitem_10724, getitem_10725, getitem_10726, getitem_10727, getitem_10728, getitem_10729, getitem_10730, getitem_10731, getitem_10732, getitem_10733, getitem_10734, getitem_10735, getitem_10736, getitem_10737, getitem_10738, getitem_10739, getitem_10740, getitem_10741, getitem_10742, getitem_10743, getitem_10744, getitem_10745, getitem_10746, getitem_10747, getitem_10748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_10685 = getitem_10686 = getitem_10687 = getitem_10688 = getitem_10689 = getitem_10690 = getitem_10691 = getitem_10692 = getitem_10693 = getitem_10694 = getitem_10695 = getitem_10696 = getitem_10697 = getitem_10698 = getitem_10699 = getitem_10700 = getitem_10701 = getitem_10702 = getitem_10703 = getitem_10704 = getitem_10705 = getitem_10706 = getitem_10707 = getitem_10708 = getitem_10709 = getitem_10710 = getitem_10711 = getitem_10712 = getitem_10713 = getitem_10714 = getitem_10715 = getitem_10716 = getitem_10717 = getitem_10718 = getitem_10719 = getitem_10720 = getitem_10721 = getitem_10722 = getitem_10723 = getitem_10724 = getitem_10725 = getitem_10726 = getitem_10727 = getitem_10728 = getitem_10729 = getitem_10730 = getitem_10731 = getitem_10732 = getitem_10733 = getitem_10734 = getitem_10735 = getitem_10736 = getitem_10737 = getitem_10738 = getitem_10739 = getitem_10740 = getitem_10741 = getitem_10742 = getitem_10743 = getitem_10744 = getitem_10745 = getitem_10746 = getitem_10747 = getitem_10748 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_295, 'avg', 128, '0'); cat_295 = None + wait_tensor_681 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + view_1918 = torch.ops.aten.view.default(add_1891, [2, 4096, 2048]); add_1891 = None + convert_element_type_2008 = torch.ops.prims.convert_element_type.default(view_1918, torch.float32); view_1918 = None + convert_element_type_1031 = torch.ops.prims.convert_element_type.default(primals_317, torch.bfloat16); primals_317 = None + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1031, 128, '0'); convert_element_type_1031 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + convert_element_type_2010 = torch.ops.prims.convert_element_type.default(wait_tensor_395, torch.float32); wait_tensor_395 = None + mul_1543 = torch.ops.aten.mul.Tensor(convert_element_type_2008, convert_element_type_2010); convert_element_type_2010 = None + convert_element_type_1032 = torch.ops.prims.convert_element_type.default(add_1232, torch.float32); add_1232 = None + mul_897 = torch.ops.aten.mul.Tensor(convert_element_type_1032, rsqrt_59); convert_element_type_1032 = None + mul_1545 = torch.ops.aten.mul.Tensor(mul_897, mul_1543) + sum_164 = torch.ops.aten.sum.dim_IntList(mul_1545, [2], True); mul_1545 = None + div_177 = torch.ops.aten.div.Tensor(mul_897, 2048) + mul_1546 = torch.ops.aten.mul.Tensor(div_177, sum_164); div_177 = sum_164 = None + sub_670 = torch.ops.aten.sub.Tensor(mul_1543, mul_1546); mul_1543 = mul_1546 = None + mul_1547 = torch.ops.aten.mul.Tensor(sub_670, rsqrt_59); sub_670 = rsqrt_59 = None + mul_1548 = torch.ops.aten.mul.Tensor(convert_element_type_2008, mul_897); convert_element_type_2008 = mul_897 = None + sum_165 = torch.ops.aten.sum.dim_IntList(mul_1548, [0, 1]); mul_1548 = None + convert_element_type_2011 = torch.ops.prims.convert_element_type.default(mul_1547, torch.bfloat16); mul_1547 = None + add_1892 = torch.ops.aten.add.Tensor(add_1879, convert_element_type_2011); add_1879 = convert_element_type_2011 = None + convert_element_type_default_60 = torch.ops.prims.convert_element_type.default(sum_165, torch.float32); sum_165 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_60, 'avg', 128, '0'); convert_element_type_default_60 = None + wait_tensor_682 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_1919 = torch.ops.aten.view.default(add_1892, [8192, 2048]) + permute_786 = torch.ops.aten.permute.default(view_1919, [1, 0]) + permute_287 = torch.ops.aten.permute.default(getitem_1995, [0, 2, 1, 3]) + view_1259 = torch.ops.aten.view.default(permute_287, [2, 4096, -1]); permute_287 = None + view_1261 = torch.ops.aten.view.default(view_1259, [8192, 2048]); view_1259 = None + mm_338 = torch.ops.aten.mm.default(permute_786, view_1261); permute_786 = view_1261 = None + convert_element_type_1028 = torch.ops.prims.convert_element_type.default(primals_316, torch.bfloat16); primals_316 = None + all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1028, 128, '0'); convert_element_type_1028 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + permute_788 = torch.ops.aten.permute.default(permute_288, [1, 0]); permute_288 = None + mm_339 = torch.ops.aten.mm.default(view_1919, permute_788); view_1919 = permute_788 = None + view_1920 = torch.ops.aten.view.default(mm_339, [2, 4096, 2048]); mm_339 = None + convert_element_type_2018 = torch.ops.prims.convert_element_type.default(mm_338, torch.float32); mm_338 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2018, 'avg', 128, '0'); convert_element_type_2018 = None + wait_tensor_683 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + view_1921 = torch.ops.aten.view.default(view_1920, [2, 4096, 16, 128]); view_1920 = None + permute_790 = torch.ops.aten.permute.default(view_1921, [0, 2, 1, 3]); view_1921 = None + fw_graph7 = self.fw_graph7 + joint_graph7 = self.joint_graph7 + mask_graph7 = self.mask_graph7 + flex_attention_backward_7 = torch.ops.higher_order.flex_attention_backward(permute_284, permute_285, permute_286, getitem_1995, getitem_1996, permute_790, None, fw_graph7, joint_graph7, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph7), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_284 = permute_285 = permute_286 = getitem_1995 = getitem_1996 = permute_790 = fw_graph7 = joint_graph7 = mask_graph7 = None + getitem_10749 = flex_attention_backward_7[0] + getitem_10750 = flex_attention_backward_7[1] + getitem_10751 = flex_attention_backward_7[2]; flex_attention_backward_7 = None + permute_791 = torch.ops.aten.permute.default(getitem_10751, [0, 2, 1, 3]); getitem_10751 = None + permute_792 = torch.ops.aten.permute.default(getitem_10750, [0, 2, 1, 3]); getitem_10750 = None + permute_793 = torch.ops.aten.permute.default(getitem_10749, [0, 2, 1, 3]); getitem_10749 = None + slice_206 = torch.ops.aten.slice.Tensor(permute_792, 3, 0, 128) + slice_207 = torch.ops.aten.slice.Tensor(permute_792, 3, 128, 192); permute_792 = None + sum_166 = torch.ops.aten.sum.dim_IntList(slice_207, [2], True); slice_207 = None + cat_296 = torch.ops.aten.cat.default([slice_206, permute_791], 3); slice_206 = permute_791 = None + view_1922 = torch.ops.aten.view.default(cat_296, [2, 4096, 4096]); cat_296 = None + view_1923 = torch.ops.aten.view.default(view_1922, [8192, 4096]); view_1922 = None + permute_794 = torch.ops.aten.permute.default(view_1923, [1, 0]) + mm_340 = torch.ops.aten.mm.default(permute_794, view_1256); permute_794 = view_1256 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(primals_315, torch.bfloat16); primals_315 = None + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1025, 128, '0'); convert_element_type_1025 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_393, [1, 0]); wait_tensor_393 = None + permute_796 = torch.ops.aten.permute.default(permute_283, [1, 0]); permute_283 = None + mm_341 = torch.ops.aten.mm.default(view_1923, permute_796); view_1923 = permute_796 = None + view_1924 = torch.ops.aten.view.default(mm_341, [2, 4096, 512]); mm_341 = None + convert_element_type_2023 = torch.ops.prims.convert_element_type.default(mm_340, torch.float32); mm_340 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2023, 'avg', 128, '0'); convert_element_type_2023 = None + wait_tensor_684 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 = None + convert_element_type_2024 = torch.ops.prims.convert_element_type.default(view_1924, torch.float32); view_1924 = None + convert_element_type_1022 = torch.ops.prims.convert_element_type.default(primals_314, torch.bfloat16); primals_314 = None + all_gather_into_tensor_320 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1022, 128, '0'); convert_element_type_1022 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_2026 = torch.ops.prims.convert_element_type.default(wait_tensor_392, torch.float32); wait_tensor_392 = None + mul_1549 = torch.ops.aten.mul.Tensor(convert_element_type_2024, convert_element_type_2026); convert_element_type_2026 = None + convert_element_type_1023 = torch.ops.prims.convert_element_type.default(getitem_1991, torch.float32); getitem_1991 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_1023, rsqrt_58); convert_element_type_1023 = None + mul_1551 = torch.ops.aten.mul.Tensor(mul_895, mul_1549) + sum_167 = torch.ops.aten.sum.dim_IntList(mul_1551, [2], True); mul_1551 = None + div_178 = torch.ops.aten.div.Tensor(mul_895, 512) + mul_1552 = torch.ops.aten.mul.Tensor(div_178, sum_167); div_178 = sum_167 = None + sub_671 = torch.ops.aten.sub.Tensor(mul_1549, mul_1552); mul_1549 = mul_1552 = None + mul_1553 = torch.ops.aten.mul.Tensor(sub_671, rsqrt_58); sub_671 = rsqrt_58 = None + mul_1554 = torch.ops.aten.mul.Tensor(convert_element_type_2024, mul_895); convert_element_type_2024 = mul_895 = None + sum_168 = torch.ops.aten.sum.dim_IntList(mul_1554, [0, 1]); mul_1554 = None + convert_element_type_2027 = torch.ops.prims.convert_element_type.default(mul_1553, torch.bfloat16); mul_1553 = None + convert_element_type_default_59 = torch.ops.prims.convert_element_type.default(sum_168, torch.float32); sum_168 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_59, 'avg', 128, '0'); convert_element_type_default_59 = None + wait_tensor_685 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + convert_element_type_2030 = torch.ops.prims.convert_element_type.default(sum_166, torch.float32); sum_166 = None + view_1925 = torch.ops.aten.view.default(convert_element_type_2030, [2, 4096, 1, 32, 2]); convert_element_type_2030 = None + view_as_complex_68 = torch.ops.aten.view_as_complex.default(view_1925); view_1925 = None + mul_1555 = torch.ops.aten.mul.Tensor(view_as_complex_68, clone_9); view_as_complex_68 = None + view_as_real_68 = torch.ops.aten.view_as_real.default(mul_1555); mul_1555 = None + view_1926 = torch.ops.aten.view.default(view_as_real_68, [2, 4096, 1, 64]); view_as_real_68 = None + convert_element_type_2031 = torch.ops.prims.convert_element_type.default(view_1926, torch.bfloat16); view_1926 = None + squeeze_33 = torch.ops.aten.squeeze.dim(convert_element_type_2031, 2); convert_element_type_2031 = None + cat_297 = torch.ops.aten.cat.default([convert_element_type_2027, squeeze_33], 2); convert_element_type_2027 = squeeze_33 = None + view_1927 = torch.ops.aten.view.default(cat_297, [8192, 576]); cat_297 = None + permute_798 = torch.ops.aten.permute.default(view_1927, [1, 0]) + mm_342 = torch.ops.aten.mm.default(permute_798, view_1242); permute_798 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(primals_313, torch.bfloat16); primals_313 = None + all_gather_into_tensor_319 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1017, 128, '0'); convert_element_type_1017 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + slice_115 = torch.ops.aten.slice.Tensor(wait_tensor_391, 0, 0, 576); wait_tensor_391 = None + permute_282 = torch.ops.aten.permute.default(slice_115, [1, 0]); slice_115 = None + permute_800 = torch.ops.aten.permute.default(permute_282, [1, 0]); permute_282 = None + mm_343 = torch.ops.aten.mm.default(view_1927, permute_800); view_1927 = permute_800 = None + view_1928 = torch.ops.aten.view.default(mm_343, [2, 4096, 2048]); mm_343 = None + convert_element_type_2036 = torch.ops.prims.convert_element_type.default(mm_342, torch.float32); mm_342 = None + split_579 = torch.ops.aten.split.Tensor(convert_element_type_2036, 5); convert_element_type_2036 = None + getitem_10753 = split_579[0] + getitem_10754 = split_579[1] + getitem_10755 = split_579[2] + getitem_10756 = split_579[3] + getitem_10757 = split_579[4] + getitem_10758 = split_579[5] + getitem_10759 = split_579[6] + getitem_10760 = split_579[7] + getitem_10761 = split_579[8] + getitem_10762 = split_579[9] + getitem_10763 = split_579[10] + getitem_10764 = split_579[11] + getitem_10765 = split_579[12] + getitem_10766 = split_579[13] + getitem_10767 = split_579[14] + getitem_10768 = split_579[15] + getitem_10769 = split_579[16] + getitem_10770 = split_579[17] + getitem_10771 = split_579[18] + getitem_10772 = split_579[19] + getitem_10773 = split_579[20] + getitem_10774 = split_579[21] + getitem_10775 = split_579[22] + getitem_10776 = split_579[23] + getitem_10777 = split_579[24] + getitem_10778 = split_579[25] + getitem_10779 = split_579[26] + getitem_10780 = split_579[27] + getitem_10781 = split_579[28] + getitem_10782 = split_579[29] + getitem_10783 = split_579[30] + getitem_10784 = split_579[31] + getitem_10785 = split_579[32] + getitem_10786 = split_579[33] + getitem_10787 = split_579[34] + getitem_10788 = split_579[35] + getitem_10789 = split_579[36] + getitem_10790 = split_579[37] + getitem_10791 = split_579[38] + getitem_10792 = split_579[39] + getitem_10793 = split_579[40] + getitem_10794 = split_579[41] + getitem_10795 = split_579[42] + getitem_10796 = split_579[43] + getitem_10797 = split_579[44] + getitem_10798 = split_579[45] + getitem_10799 = split_579[46] + getitem_10800 = split_579[47] + getitem_10801 = split_579[48] + getitem_10802 = split_579[49] + getitem_10803 = split_579[50] + getitem_10804 = split_579[51] + getitem_10805 = split_579[52] + getitem_10806 = split_579[53] + getitem_10807 = split_579[54] + getitem_10808 = split_579[55] + getitem_10809 = split_579[56] + getitem_10810 = split_579[57] + getitem_10811 = split_579[58] + getitem_10812 = split_579[59] + getitem_10813 = split_579[60] + getitem_10814 = split_579[61] + getitem_10815 = split_579[62] + getitem_10816 = split_579[63] + getitem_10817 = split_579[64] + getitem_10818 = split_579[65] + getitem_10819 = split_579[66] + getitem_10820 = split_579[67] + getitem_10821 = split_579[68] + getitem_10822 = split_579[69] + getitem_10823 = split_579[70] + getitem_10824 = split_579[71] + getitem_10825 = split_579[72] + getitem_10826 = split_579[73] + getitem_10827 = split_579[74] + getitem_10828 = split_579[75] + getitem_10829 = split_579[76] + getitem_10830 = split_579[77] + getitem_10831 = split_579[78] + getitem_10832 = split_579[79] + getitem_10833 = split_579[80] + getitem_10834 = split_579[81] + getitem_10835 = split_579[82] + getitem_10836 = split_579[83] + getitem_10837 = split_579[84] + getitem_10838 = split_579[85] + getitem_10839 = split_579[86] + getitem_10840 = split_579[87] + getitem_10841 = split_579[88] + getitem_10842 = split_579[89] + getitem_10843 = split_579[90] + getitem_10844 = split_579[91] + getitem_10845 = split_579[92] + getitem_10846 = split_579[93] + getitem_10847 = split_579[94] + getitem_10848 = split_579[95] + getitem_10849 = split_579[96] + getitem_10850 = split_579[97] + getitem_10851 = split_579[98] + getitem_10852 = split_579[99] + getitem_10853 = split_579[100] + getitem_10854 = split_579[101] + getitem_10855 = split_579[102] + getitem_10856 = split_579[103] + getitem_10857 = split_579[104] + getitem_10858 = split_579[105] + getitem_10859 = split_579[106] + getitem_10860 = split_579[107] + getitem_10861 = split_579[108] + getitem_10862 = split_579[109] + getitem_10863 = split_579[110] + getitem_10864 = split_579[111] + getitem_10865 = split_579[112] + getitem_10866 = split_579[113] + getitem_10867 = split_579[114] + getitem_10868 = split_579[115]; split_579 = None + constant_pad_nd_603 = torch.ops.aten.constant_pad_nd.default(getitem_10868, [0, 0, 0, 4], 0.0); getitem_10868 = None + cat_298 = torch.ops.aten.cat.default([getitem_10753, getitem_10754, getitem_10755, getitem_10756, getitem_10757, getitem_10758, getitem_10759, getitem_10760, getitem_10761, getitem_10762, getitem_10763, getitem_10764, getitem_10765, getitem_10766, getitem_10767, getitem_10768, getitem_10769, getitem_10770, getitem_10771, getitem_10772, getitem_10773, getitem_10774, getitem_10775, getitem_10776, getitem_10777, getitem_10778, getitem_10779, getitem_10780, getitem_10781, getitem_10782, getitem_10783, getitem_10784, getitem_10785, getitem_10786, getitem_10787, getitem_10788, getitem_10789, getitem_10790, getitem_10791, getitem_10792, getitem_10793, getitem_10794, getitem_10795, getitem_10796, getitem_10797, getitem_10798, getitem_10799, getitem_10800, getitem_10801, getitem_10802, getitem_10803, getitem_10804, getitem_10805, getitem_10806, getitem_10807, getitem_10808, getitem_10809, getitem_10810, getitem_10811, getitem_10812, getitem_10813, getitem_10814, getitem_10815, getitem_10816, getitem_10817, getitem_10818, getitem_10819, getitem_10820, getitem_10821, getitem_10822, getitem_10823, getitem_10824, getitem_10825, getitem_10826, getitem_10827, getitem_10828, getitem_10829, getitem_10830, getitem_10831, getitem_10832, getitem_10833, getitem_10834, getitem_10835, getitem_10836, getitem_10837, getitem_10838, getitem_10839, getitem_10840, getitem_10841, getitem_10842, getitem_10843, getitem_10844, getitem_10845, getitem_10846, getitem_10847, getitem_10848, getitem_10849, getitem_10850, getitem_10851, getitem_10852, getitem_10853, getitem_10854, getitem_10855, getitem_10856, getitem_10857, getitem_10858, getitem_10859, getitem_10860, getitem_10861, getitem_10862, getitem_10863, getitem_10864, getitem_10865, getitem_10866, getitem_10867, constant_pad_nd_603, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_10753 = getitem_10754 = getitem_10755 = getitem_10756 = getitem_10757 = getitem_10758 = getitem_10759 = getitem_10760 = getitem_10761 = getitem_10762 = getitem_10763 = getitem_10764 = getitem_10765 = getitem_10766 = getitem_10767 = getitem_10768 = getitem_10769 = getitem_10770 = getitem_10771 = getitem_10772 = getitem_10773 = getitem_10774 = getitem_10775 = getitem_10776 = getitem_10777 = getitem_10778 = getitem_10779 = getitem_10780 = getitem_10781 = getitem_10782 = getitem_10783 = getitem_10784 = getitem_10785 = getitem_10786 = getitem_10787 = getitem_10788 = getitem_10789 = getitem_10790 = getitem_10791 = getitem_10792 = getitem_10793 = getitem_10794 = getitem_10795 = getitem_10796 = getitem_10797 = getitem_10798 = getitem_10799 = getitem_10800 = getitem_10801 = getitem_10802 = getitem_10803 = getitem_10804 = getitem_10805 = getitem_10806 = getitem_10807 = getitem_10808 = getitem_10809 = getitem_10810 = getitem_10811 = getitem_10812 = getitem_10813 = getitem_10814 = getitem_10815 = getitem_10816 = getitem_10817 = getitem_10818 = getitem_10819 = getitem_10820 = getitem_10821 = getitem_10822 = getitem_10823 = getitem_10824 = getitem_10825 = getitem_10826 = getitem_10827 = getitem_10828 = getitem_10829 = getitem_10830 = getitem_10831 = getitem_10832 = getitem_10833 = getitem_10834 = getitem_10835 = getitem_10836 = getitem_10837 = getitem_10838 = getitem_10839 = getitem_10840 = getitem_10841 = getitem_10842 = getitem_10843 = getitem_10844 = getitem_10845 = getitem_10846 = getitem_10847 = getitem_10848 = getitem_10849 = getitem_10850 = getitem_10851 = getitem_10852 = getitem_10853 = getitem_10854 = getitem_10855 = getitem_10856 = getitem_10857 = getitem_10858 = getitem_10859 = getitem_10860 = getitem_10861 = getitem_10862 = getitem_10863 = getitem_10864 = getitem_10865 = getitem_10866 = getitem_10867 = constant_pad_nd_603 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_298, 'avg', 128, '0'); cat_298 = None + wait_tensor_686 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + slice_208 = torch.ops.aten.slice.Tensor(permute_793, 3, 0, 128) + slice_209 = torch.ops.aten.slice.Tensor(permute_793, 3, 128, 192); permute_793 = None + convert_element_type_2037 = torch.ops.prims.convert_element_type.default(slice_209, torch.float32); slice_209 = None + view_1929 = torch.ops.aten.view.default(convert_element_type_2037, [2, 4096, 16, 32, 2]); convert_element_type_2037 = None + view_as_complex_69 = torch.ops.aten.view_as_complex.default(view_1929); view_1929 = None + mul_1556 = torch.ops.aten.mul.Tensor(view_as_complex_69, clone_9); view_as_complex_69 = None + view_as_real_69 = torch.ops.aten.view_as_real.default(mul_1556); mul_1556 = None + view_1930 = torch.ops.aten.view.default(view_as_real_69, [2, 4096, 16, 64]); view_as_real_69 = None + convert_element_type_2038 = torch.ops.prims.convert_element_type.default(view_1930, torch.bfloat16); view_1930 = None + cat_299 = torch.ops.aten.cat.default([slice_208, convert_element_type_2038], 3); slice_208 = convert_element_type_2038 = None + view_1931 = torch.ops.aten.view.default(cat_299, [2, 4096, 3072]); cat_299 = None + view_1932 = torch.ops.aten.view.default(view_1931, [8192, 3072]); view_1931 = None + permute_802 = torch.ops.aten.permute.default(view_1932, [1, 0]) + mm_344 = torch.ops.aten.mm.default(permute_802, view_1242); permute_802 = view_1242 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(primals_312, torch.bfloat16); primals_312 = None + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 128, '0'); convert_element_type_1012 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_281 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + permute_804 = torch.ops.aten.permute.default(permute_281, [1, 0]); permute_281 = None + mm_345 = torch.ops.aten.mm.default(view_1932, permute_804); view_1932 = permute_804 = None + view_1933 = torch.ops.aten.view.default(mm_345, [2, 4096, 2048]); mm_345 = None + add_1893 = torch.ops.aten.add.Tensor(view_1928, view_1933); view_1928 = view_1933 = None + convert_element_type_2043 = torch.ops.prims.convert_element_type.default(mm_344, torch.float32); mm_344 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2043, 'avg', 128, '0'); convert_element_type_2043 = None + wait_tensor_687 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + convert_element_type_2044 = torch.ops.prims.convert_element_type.default(add_1893, torch.float32); add_1893 = None + convert_element_type_1009 = torch.ops.prims.convert_element_type.default(primals_311, torch.bfloat16); primals_311 = None + all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1009, 128, '0'); convert_element_type_1009 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + convert_element_type_2046 = torch.ops.prims.convert_element_type.default(wait_tensor_389, torch.float32); wait_tensor_389 = None + mul_1557 = torch.ops.aten.mul.Tensor(convert_element_type_2044, convert_element_type_2046); convert_element_type_2046 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(add_1229, torch.float32); add_1229 = None + mul_891 = torch.ops.aten.mul.Tensor(convert_element_type_1010, rsqrt_57); convert_element_type_1010 = None + mul_1559 = torch.ops.aten.mul.Tensor(mul_891, mul_1557) + sum_169 = torch.ops.aten.sum.dim_IntList(mul_1559, [2], True); mul_1559 = None + div_179 = torch.ops.aten.div.Tensor(mul_891, 2048) + mul_1560 = torch.ops.aten.mul.Tensor(div_179, sum_169); div_179 = sum_169 = None + sub_672 = torch.ops.aten.sub.Tensor(mul_1557, mul_1560); mul_1557 = mul_1560 = None + mul_1561 = torch.ops.aten.mul.Tensor(sub_672, rsqrt_57); sub_672 = rsqrt_57 = None + mul_1562 = torch.ops.aten.mul.Tensor(convert_element_type_2044, mul_891); convert_element_type_2044 = mul_891 = None + sum_170 = torch.ops.aten.sum.dim_IntList(mul_1562, [0, 1]); mul_1562 = None + convert_element_type_2047 = torch.ops.prims.convert_element_type.default(mul_1561, torch.bfloat16); mul_1561 = None + add_1894 = torch.ops.aten.add.Tensor(add_1892, convert_element_type_2047); add_1892 = convert_element_type_2047 = None + convert_element_type_default_58 = torch.ops.prims.convert_element_type.default(sum_170, torch.float32); sum_170 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_58, 'avg', 128, '0'); convert_element_type_default_58 = None + wait_tensor_688 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + view_1934 = torch.ops.aten.view.default(add_1894, [8192, 2048]) + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1934, 1) + convert_element_type_2050 = torch.ops.prims.convert_element_type.default(unsqueeze_61, torch.float32); unsqueeze_61 = None + bmm_42 = torch.ops.aten.bmm.default(permute_806, convert_element_type_2050); permute_806 = None + bmm_43 = torch.ops.aten.bmm.default(convert_element_type_2050, permute_807); convert_element_type_2050 = permute_807 = None + convert_element_type_2051 = torch.ops.prims.convert_element_type.default(bmm_42, torch.bfloat16); bmm_42 = None + view_1935 = torch.ops.aten.view.default(bmm_43, [8192, 6]); bmm_43 = None + view_1936 = torch.ops.aten.view.default(convert_element_type_2051, [49152, 2048]); convert_element_type_2051 = None + index_68 = torch.ops.aten.index.Tensor(view_1936, [getitem_1891]); view_1936 = getitem_1891 = None + permute_808 = torch.ops.aten.permute.default(view_1934, [1, 0]) + mm_346 = torch.ops.aten.mm.default(permute_808, mul_888); permute_808 = mul_888 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(primals_310, torch.bfloat16); primals_310 = None + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1004, 128, '0'); convert_element_type_1004 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + permute_280 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + permute_810 = torch.ops.aten.permute.default(permute_280, [1, 0]); permute_280 = None + mm_347 = torch.ops.aten.mm.default(view_1934, permute_810); view_1934 = permute_810 = None + convert_element_type_2056 = torch.ops.prims.convert_element_type.default(mm_346, torch.float32); mm_346 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2056, 'avg', 128, '0'); convert_element_type_2056 = None + wait_tensor_689 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + convert_element_type_999 = torch.ops.prims.convert_element_type.default(mm_148, torch.float32); mm_148 = None + neg_36 = torch.ops.aten.neg.default(convert_element_type_999) + exp_54 = torch.ops.aten.exp.default(neg_36); neg_36 = None + add_1224 = torch.ops.aten.add.Tensor(exp_54, 1); exp_54 = None + div_90 = torch.ops.aten.div.Tensor(convert_element_type_999, add_1224) + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(div_90, torch.bfloat16); div_90 = None + mul_1563 = torch.ops.aten.mul.Tensor(mm_347, convert_element_type_1000); convert_element_type_1000 = None + mul_1564 = torch.ops.aten.mul.Tensor(mm_347, mm_149); mm_347 = mm_149 = None + permute_812 = torch.ops.aten.permute.default(mul_1563, [1, 0]) + mm_348 = torch.ops.aten.mm.default(permute_812, view_1197); permute_812 = None + convert_element_type_1001 = torch.ops.prims.convert_element_type.default(primals_309, torch.bfloat16); primals_309 = None + all_gather_into_tensor_315 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1001, 128, '0'); convert_element_type_1001 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + permute_279 = torch.ops.aten.permute.default(wait_tensor_387, [1, 0]); wait_tensor_387 = None + permute_814 = torch.ops.aten.permute.default(permute_279, [1, 0]); permute_279 = None + mm_349 = torch.ops.aten.mm.default(mul_1563, permute_814); mul_1563 = permute_814 = None + convert_element_type_2061 = torch.ops.prims.convert_element_type.default(mm_348, torch.float32); mm_348 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2061, 'avg', 128, '0'); convert_element_type_2061 = None + wait_tensor_690 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + convert_element_type_2062 = torch.ops.prims.convert_element_type.default(mul_1564, torch.float32); mul_1564 = None + reciprocal_16 = torch.ops.aten.reciprocal.default(add_1224); add_1224 = None + mul_1565 = torch.ops.aten.mul.Tensor(reciprocal_16, 1); reciprocal_16 = None + mul_1566 = torch.ops.aten.mul.Tensor(convert_element_type_2062, mul_1565); convert_element_type_2062 = None + sub_673 = torch.ops.aten.sub.Tensor(1, mul_1565); mul_1565 = None + mul_1567 = torch.ops.aten.mul.Tensor(convert_element_type_999, sub_673); convert_element_type_999 = sub_673 = None + add_1896 = torch.ops.aten.add.Tensor(mul_1567, 1); mul_1567 = None + mul_1568 = torch.ops.aten.mul.Tensor(mul_1566, add_1896); mul_1566 = add_1896 = None + convert_element_type_2064 = torch.ops.prims.convert_element_type.default(mul_1568, torch.bfloat16); mul_1568 = None + permute_816 = torch.ops.aten.permute.default(convert_element_type_2064, [1, 0]) + mm_350 = torch.ops.aten.mm.default(permute_816, view_1197); permute_816 = None + convert_element_type_996 = torch.ops.prims.convert_element_type.default(primals_308, torch.bfloat16); primals_308 = None + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_996, 128, '0'); convert_element_type_996 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_278 = torch.ops.aten.permute.default(wait_tensor_386, [1, 0]); wait_tensor_386 = None + permute_818 = torch.ops.aten.permute.default(permute_278, [1, 0]); permute_278 = None + mm_351 = torch.ops.aten.mm.default(convert_element_type_2064, permute_818); convert_element_type_2064 = permute_818 = None + add_1897 = torch.ops.aten.add.Tensor(mm_349, mm_351); mm_349 = mm_351 = None + convert_element_type_2069 = torch.ops.prims.convert_element_type.default(mm_350, torch.float32); mm_350 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2069, 'avg', 128, '0'); convert_element_type_2069 = None + wait_tensor_691 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + all_to_all_single_94 = torch.ops._c10d_functional.all_to_all_single.default(index_68, [_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287], [_local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279], '1033'); index_68 = None + wait_tensor_692 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_94); all_to_all_single_94 = None + full_396 = torch.ops.aten.full.default([sym_size_int_69, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_69 = None + slice_scatter_8 = torch.ops.aten.slice_scatter.default(full_396, wait_tensor_692, 0, 0, -1); wait_tensor_692 = None + index_69 = torch.ops.aten.index.Tensor(slice_scatter_8, [getitem_1892]); slice_scatter_8 = None + permute_820 = torch.ops.aten.permute.default(index_69, [1, 0]) + _grouped_mm_126 = torch.ops.aten._grouped_mm.default(permute_820, mul_868, cumsum_53); permute_820 = mul_868 = None + _grouped_mm_127 = torch.ops.aten._grouped_mm.default(index_69, permute_822, cumsum_53); index_69 = permute_822 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(_grouped_mm_51, torch.float32); _grouped_mm_51 = None + neg_35 = torch.ops.aten.neg.default(convert_element_type_994) + exp_53 = torch.ops.aten.exp.default(neg_35); neg_35 = None + add_1188 = torch.ops.aten.add.Tensor(exp_53, 1); exp_53 = None + div_89 = torch.ops.aten.div.Tensor(convert_element_type_994, add_1188) + convert_element_type_995 = torch.ops.prims.convert_element_type.default(div_89, torch.bfloat16); div_89 = None + mul_1569 = torch.ops.aten.mul.Tensor(_grouped_mm_127, convert_element_type_995); convert_element_type_995 = None + mul_1570 = torch.ops.aten.mul.Tensor(_grouped_mm_127, _grouped_mm_52); _grouped_mm_127 = _grouped_mm_52 = None + permute_824 = torch.ops.aten.permute.default(mul_1569, [1, 0]) + _grouped_mm_128 = torch.ops.aten._grouped_mm.default(permute_824, index_35, cumsum_53); permute_824 = None + _grouped_mm_129 = torch.ops.aten._grouped_mm.default(mul_1569, permute_826, cumsum_53); mul_1569 = permute_826 = None + convert_element_type_2070 = torch.ops.prims.convert_element_type.default(mul_1570, torch.float32); mul_1570 = None + reciprocal_17 = torch.ops.aten.reciprocal.default(add_1188); add_1188 = None + mul_1571 = torch.ops.aten.mul.Tensor(reciprocal_17, 1); reciprocal_17 = None + mul_1572 = torch.ops.aten.mul.Tensor(convert_element_type_2070, mul_1571); convert_element_type_2070 = None + sub_674 = torch.ops.aten.sub.Tensor(1, mul_1571); mul_1571 = None + mul_1573 = torch.ops.aten.mul.Tensor(convert_element_type_994, sub_674); convert_element_type_994 = sub_674 = None + add_1899 = torch.ops.aten.add.Tensor(mul_1573, 1); mul_1573 = None + mul_1574 = torch.ops.aten.mul.Tensor(mul_1572, add_1899); mul_1572 = add_1899 = None + convert_element_type_2072 = torch.ops.prims.convert_element_type.default(mul_1574, torch.bfloat16); mul_1574 = None + permute_828 = torch.ops.aten.permute.default(convert_element_type_2072, [1, 0]) + _grouped_mm_130 = torch.ops.aten._grouped_mm.default(permute_828, index_35, cumsum_53); permute_828 = index_35 = None + _grouped_mm_131 = torch.ops.aten._grouped_mm.default(convert_element_type_2072, permute_830, cumsum_53); convert_element_type_2072 = permute_830 = cumsum_53 = None + add_1900 = torch.ops.aten.add.Tensor(_grouped_mm_129, _grouped_mm_131); _grouped_mm_129 = _grouped_mm_131 = None + convert_element_type_2073 = torch.ops.prims.convert_element_type.default(_grouped_mm_128, torch.float32); _grouped_mm_128 = None + div_180 = torch.ops.aten.div.Tensor(convert_element_type_2073, 128); convert_element_type_2073 = None + split_581 = torch.ops.aten.split.Tensor(div_180, 88, 1); div_180 = None + getitem_10885 = split_581[0] + getitem_10902 = split_581[1] + getitem_10919 = split_581[2] + getitem_10936 = split_581[3] + getitem_10953 = split_581[4] + getitem_10970 = split_581[5] + getitem_10987 = split_581[6] + getitem_11004 = split_581[7] + getitem_11021 = split_581[8] + getitem_11038 = split_581[9] + getitem_11055 = split_581[10] + getitem_11072 = split_581[11] + getitem_11089 = split_581[12] + getitem_11106 = split_581[13] + getitem_11123 = split_581[14] + getitem_11140 = split_581[15]; split_581 = None + cat_300 = torch.ops.aten.cat.default([getitem_10885, getitem_10902, getitem_10919, getitem_10936, getitem_10953, getitem_10970, getitem_10987, getitem_11004, getitem_11021, getitem_11038, getitem_11055, getitem_11072, getitem_11089, getitem_11106, getitem_11123, getitem_11140]); getitem_10885 = getitem_10902 = getitem_10919 = getitem_10936 = getitem_10953 = getitem_10970 = getitem_10987 = getitem_11004 = getitem_11021 = getitem_11038 = getitem_11055 = getitem_11072 = getitem_11089 = getitem_11106 = getitem_11123 = getitem_11140 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_300, 'sum', 16, '1025'); cat_300 = None + wait_tensor_693 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + convert_element_type_2074 = torch.ops.prims.convert_element_type.default(_grouped_mm_126, torch.float32); _grouped_mm_126 = None + div_181 = torch.ops.aten.div.Tensor(convert_element_type_2074, 128); convert_element_type_2074 = None + split_598 = torch.ops.aten.split.Tensor(div_181, 128, 1); div_181 = None + getitem_11157 = split_598[0] + getitem_11174 = split_598[1] + getitem_11191 = split_598[2] + getitem_11208 = split_598[3] + getitem_11225 = split_598[4] + getitem_11242 = split_598[5] + getitem_11259 = split_598[6] + getitem_11276 = split_598[7] + getitem_11293 = split_598[8] + getitem_11310 = split_598[9] + getitem_11327 = split_598[10] + getitem_11344 = split_598[11] + getitem_11361 = split_598[12] + getitem_11378 = split_598[13] + getitem_11395 = split_598[14] + getitem_11412 = split_598[15]; split_598 = None + cat_301 = torch.ops.aten.cat.default([getitem_11157, getitem_11174, getitem_11191, getitem_11208, getitem_11225, getitem_11242, getitem_11259, getitem_11276, getitem_11293, getitem_11310, getitem_11327, getitem_11344, getitem_11361, getitem_11378, getitem_11395, getitem_11412]); getitem_11157 = getitem_11174 = getitem_11191 = getitem_11208 = getitem_11225 = getitem_11242 = getitem_11259 = getitem_11276 = getitem_11293 = getitem_11310 = getitem_11327 = getitem_11344 = getitem_11361 = getitem_11378 = getitem_11395 = getitem_11412 = None + reduce_scatter_tensor_118 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_301, 'sum', 16, '1025'); cat_301 = None + wait_tensor_694 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + convert_element_type_2075 = torch.ops.prims.convert_element_type.default(_grouped_mm_130, torch.float32); _grouped_mm_130 = None + div_182 = torch.ops.aten.div.Tensor(convert_element_type_2075, 128); convert_element_type_2075 = None + split_615 = torch.ops.aten.split.Tensor(div_182, 88, 1); div_182 = None + getitem_11429 = split_615[0] + getitem_11446 = split_615[1] + getitem_11463 = split_615[2] + getitem_11480 = split_615[3] + getitem_11497 = split_615[4] + getitem_11514 = split_615[5] + getitem_11531 = split_615[6] + getitem_11548 = split_615[7] + getitem_11565 = split_615[8] + getitem_11582 = split_615[9] + getitem_11599 = split_615[10] + getitem_11616 = split_615[11] + getitem_11633 = split_615[12] + getitem_11650 = split_615[13] + getitem_11667 = split_615[14] + getitem_11684 = split_615[15]; split_615 = None + cat_302 = torch.ops.aten.cat.default([getitem_11429, getitem_11446, getitem_11463, getitem_11480, getitem_11497, getitem_11514, getitem_11531, getitem_11548, getitem_11565, getitem_11582, getitem_11599, getitem_11616, getitem_11633, getitem_11650, getitem_11667, getitem_11684]); getitem_11429 = getitem_11446 = getitem_11463 = getitem_11480 = getitem_11497 = getitem_11514 = getitem_11531 = getitem_11548 = getitem_11565 = getitem_11582 = getitem_11599 = getitem_11616 = getitem_11633 = getitem_11650 = getitem_11667 = getitem_11684 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_302, 'sum', 16, '1025'); cat_302 = None + wait_tensor_695 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + index_put_68 = torch.ops.aten.index_put.default(full_396, [getitem_1892], add_1900, True); full_396 = getitem_1892 = add_1900 = None + slice_210 = torch.ops.aten.slice.Tensor(index_put_68, 0, 0, add_1901); index_put_68 = add_1901 = None + all_to_all_single_95 = torch.ops._c10d_functional.all_to_all_single.default(slice_210, [_local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279], [_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287], '1033'); slice_210 = _local_scalar_dense_272 = _local_scalar_dense_273 = _local_scalar_dense_274 = _local_scalar_dense_275 = _local_scalar_dense_276 = _local_scalar_dense_277 = _local_scalar_dense_278 = _local_scalar_dense_279 = _local_scalar_dense_280 = _local_scalar_dense_281 = _local_scalar_dense_282 = _local_scalar_dense_283 = _local_scalar_dense_284 = _local_scalar_dense_285 = _local_scalar_dense_286 = _local_scalar_dense_287 = None + wait_tensor_696 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_95); all_to_all_single_95 = None + index_put_69 = torch.ops.aten.index_put.default(full_default_52, [div_87], wait_tensor_696, True); div_87 = wait_tensor_696 = None + add_1905 = torch.ops.aten.add.Tensor(add_1897, index_put_69); add_1897 = index_put_69 = None + mul_1575 = torch.ops.aten.mul.Tensor(view_1935, 1.0); view_1935 = None + scatter_add_8 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1889, mul_1575); getitem_1889 = mul_1575 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(mm_147, torch.float32); mm_147 = None + sub_408 = torch.ops.aten.sub.Tensor(convert_element_type_983, amax_17); convert_element_type_983 = amax_17 = None + exp_52 = torch.ops.aten.exp.default(sub_408); sub_408 = None + div_86 = torch.ops.aten.div.Tensor(exp_52, sum_69); exp_52 = sum_69 = None + mul_1576 = torch.ops.aten.mul.Tensor(scatter_add_8, div_86); scatter_add_8 = None + sum_171 = torch.ops.aten.sum.dim_IntList(mul_1576, [1], True) + neg_79 = torch.ops.aten.neg.default(div_86); div_86 = None + fma_8 = torch.ops.prims.fma.default(neg_79, sum_171, mul_1576); neg_79 = sum_171 = mul_1576 = None + convert_element_type_2076 = torch.ops.prims.convert_element_type.default(fma_8, torch.bfloat16); fma_8 = None + permute_832 = torch.ops.aten.permute.default(convert_element_type_2076, [1, 0]) + mm_352 = torch.ops.aten.mm.default(permute_832, view_1197); permute_832 = view_1197 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_303, torch.bfloat16); primals_303 = None + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 128, '0'); convert_element_type_980 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + slice_111 = torch.ops.aten.slice.Tensor(wait_tensor_375, 0, 0, 64); wait_tensor_375 = None + permute_274 = torch.ops.aten.permute.default(slice_111, [1, 0]); slice_111 = None + permute_834 = torch.ops.aten.permute.default(permute_274, [1, 0]); permute_274 = None + mm_353 = torch.ops.aten.mm.default(convert_element_type_2076, permute_834); convert_element_type_2076 = permute_834 = None + add_1906 = torch.ops.aten.add.Tensor(add_1905, mm_353); add_1905 = mm_353 = None + convert_element_type_2081 = torch.ops.prims.convert_element_type.default(mm_352, torch.float32); mm_352 = None + split_631 = torch.ops.aten.split.Tensor(convert_element_type_2081, 1); convert_element_type_2081 = None + getitem_11685 = split_631[0] + getitem_11686 = split_631[1] + getitem_11687 = split_631[2] + getitem_11688 = split_631[3] + getitem_11689 = split_631[4] + getitem_11690 = split_631[5] + getitem_11691 = split_631[6] + getitem_11692 = split_631[7] + getitem_11693 = split_631[8] + getitem_11694 = split_631[9] + getitem_11695 = split_631[10] + getitem_11696 = split_631[11] + getitem_11697 = split_631[12] + getitem_11698 = split_631[13] + getitem_11699 = split_631[14] + getitem_11700 = split_631[15] + getitem_11701 = split_631[16] + getitem_11702 = split_631[17] + getitem_11703 = split_631[18] + getitem_11704 = split_631[19] + getitem_11705 = split_631[20] + getitem_11706 = split_631[21] + getitem_11707 = split_631[22] + getitem_11708 = split_631[23] + getitem_11709 = split_631[24] + getitem_11710 = split_631[25] + getitem_11711 = split_631[26] + getitem_11712 = split_631[27] + getitem_11713 = split_631[28] + getitem_11714 = split_631[29] + getitem_11715 = split_631[30] + getitem_11716 = split_631[31] + getitem_11717 = split_631[32] + getitem_11718 = split_631[33] + getitem_11719 = split_631[34] + getitem_11720 = split_631[35] + getitem_11721 = split_631[36] + getitem_11722 = split_631[37] + getitem_11723 = split_631[38] + getitem_11724 = split_631[39] + getitem_11725 = split_631[40] + getitem_11726 = split_631[41] + getitem_11727 = split_631[42] + getitem_11728 = split_631[43] + getitem_11729 = split_631[44] + getitem_11730 = split_631[45] + getitem_11731 = split_631[46] + getitem_11732 = split_631[47] + getitem_11733 = split_631[48] + getitem_11734 = split_631[49] + getitem_11735 = split_631[50] + getitem_11736 = split_631[51] + getitem_11737 = split_631[52] + getitem_11738 = split_631[53] + getitem_11739 = split_631[54] + getitem_11740 = split_631[55] + getitem_11741 = split_631[56] + getitem_11742 = split_631[57] + getitem_11743 = split_631[58] + getitem_11744 = split_631[59] + getitem_11745 = split_631[60] + getitem_11746 = split_631[61] + getitem_11747 = split_631[62] + getitem_11748 = split_631[63]; split_631 = None + cat_303 = torch.ops.aten.cat.default([getitem_11685, getitem_11686, getitem_11687, getitem_11688, getitem_11689, getitem_11690, getitem_11691, getitem_11692, getitem_11693, getitem_11694, getitem_11695, getitem_11696, getitem_11697, getitem_11698, getitem_11699, getitem_11700, getitem_11701, getitem_11702, getitem_11703, getitem_11704, getitem_11705, getitem_11706, getitem_11707, getitem_11708, getitem_11709, getitem_11710, getitem_11711, getitem_11712, getitem_11713, getitem_11714, getitem_11715, getitem_11716, getitem_11717, getitem_11718, getitem_11719, getitem_11720, getitem_11721, getitem_11722, getitem_11723, getitem_11724, getitem_11725, getitem_11726, getitem_11727, getitem_11728, getitem_11729, getitem_11730, getitem_11731, getitem_11732, getitem_11733, getitem_11734, getitem_11735, getitem_11736, getitem_11737, getitem_11738, getitem_11739, getitem_11740, getitem_11741, getitem_11742, getitem_11743, getitem_11744, getitem_11745, getitem_11746, getitem_11747, getitem_11748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_11685 = getitem_11686 = getitem_11687 = getitem_11688 = getitem_11689 = getitem_11690 = getitem_11691 = getitem_11692 = getitem_11693 = getitem_11694 = getitem_11695 = getitem_11696 = getitem_11697 = getitem_11698 = getitem_11699 = getitem_11700 = getitem_11701 = getitem_11702 = getitem_11703 = getitem_11704 = getitem_11705 = getitem_11706 = getitem_11707 = getitem_11708 = getitem_11709 = getitem_11710 = getitem_11711 = getitem_11712 = getitem_11713 = getitem_11714 = getitem_11715 = getitem_11716 = getitem_11717 = getitem_11718 = getitem_11719 = getitem_11720 = getitem_11721 = getitem_11722 = getitem_11723 = getitem_11724 = getitem_11725 = getitem_11726 = getitem_11727 = getitem_11728 = getitem_11729 = getitem_11730 = getitem_11731 = getitem_11732 = getitem_11733 = getitem_11734 = getitem_11735 = getitem_11736 = getitem_11737 = getitem_11738 = getitem_11739 = getitem_11740 = getitem_11741 = getitem_11742 = getitem_11743 = getitem_11744 = getitem_11745 = getitem_11746 = getitem_11747 = getitem_11748 = None + reduce_scatter_tensor_120 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_303, 'avg', 128, '0'); cat_303 = None + wait_tensor_697 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + view_1937 = torch.ops.aten.view.default(add_1906, [2, 4096, 2048]); add_1906 = None + convert_element_type_2082 = torch.ops.prims.convert_element_type.default(view_1937, torch.float32); view_1937 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_301, torch.bfloat16); primals_301 = None + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 128, '0'); convert_element_type_977 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + convert_element_type_2084 = torch.ops.prims.convert_element_type.default(wait_tensor_374, torch.float32); wait_tensor_374 = None + mul_1577 = torch.ops.aten.mul.Tensor(convert_element_type_2082, convert_element_type_2084); convert_element_type_2084 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_1164, torch.float32); add_1164 = None + mul_848 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_56); convert_element_type_978 = None + mul_1579 = torch.ops.aten.mul.Tensor(mul_848, mul_1577) + sum_172 = torch.ops.aten.sum.dim_IntList(mul_1579, [2], True); mul_1579 = None + div_183 = torch.ops.aten.div.Tensor(mul_848, 2048) + mul_1580 = torch.ops.aten.mul.Tensor(div_183, sum_172); div_183 = sum_172 = None + sub_676 = torch.ops.aten.sub.Tensor(mul_1577, mul_1580); mul_1577 = mul_1580 = None + mul_1581 = torch.ops.aten.mul.Tensor(sub_676, rsqrt_56); sub_676 = rsqrt_56 = None + mul_1582 = torch.ops.aten.mul.Tensor(convert_element_type_2082, mul_848); convert_element_type_2082 = mul_848 = None + sum_173 = torch.ops.aten.sum.dim_IntList(mul_1582, [0, 1]); mul_1582 = None + convert_element_type_2085 = torch.ops.prims.convert_element_type.default(mul_1581, torch.bfloat16); mul_1581 = None + add_1907 = torch.ops.aten.add.Tensor(add_1894, convert_element_type_2085); add_1894 = convert_element_type_2085 = None + convert_element_type_default_57 = torch.ops.prims.convert_element_type.default(sum_173, torch.float32); sum_173 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_57, 'avg', 128, '0'); convert_element_type_default_57 = None + wait_tensor_698 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + view_1938 = torch.ops.aten.view.default(add_1907, [8192, 2048]) + permute_836 = torch.ops.aten.permute.default(view_1938, [1, 0]) + permute_272 = torch.ops.aten.permute.default(getitem_1885, [0, 2, 1, 3]) + view_1192 = torch.ops.aten.view.default(permute_272, [2, 4096, -1]); permute_272 = None + view_1194 = torch.ops.aten.view.default(view_1192, [8192, 2048]); view_1192 = None + mm_354 = torch.ops.aten.mm.default(permute_836, view_1194); permute_836 = view_1194 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_300, torch.bfloat16); primals_300 = None + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 128, '0'); convert_element_type_974 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_373, [1, 0]); wait_tensor_373 = None + permute_838 = torch.ops.aten.permute.default(permute_273, [1, 0]); permute_273 = None + mm_355 = torch.ops.aten.mm.default(view_1938, permute_838); view_1938 = permute_838 = None + view_1939 = torch.ops.aten.view.default(mm_355, [2, 4096, 2048]); mm_355 = None + convert_element_type_2092 = torch.ops.prims.convert_element_type.default(mm_354, torch.float32); mm_354 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2092, 'avg', 128, '0'); convert_element_type_2092 = None + wait_tensor_699 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + view_1940 = torch.ops.aten.view.default(view_1939, [2, 4096, 16, 128]); view_1939 = None + permute_840 = torch.ops.aten.permute.default(view_1940, [0, 2, 1, 3]); view_1940 = None + fw_graph8 = self.fw_graph8 + joint_graph8 = self.joint_graph8 + mask_graph8 = self.mask_graph8 + flex_attention_backward_8 = torch.ops.higher_order.flex_attention_backward(permute_269, permute_270, permute_271, getitem_1885, getitem_1886, permute_840, None, fw_graph8, joint_graph8, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph8), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_269 = permute_270 = permute_271 = getitem_1885 = getitem_1886 = permute_840 = fw_graph8 = joint_graph8 = mask_graph8 = None + getitem_11749 = flex_attention_backward_8[0] + getitem_11750 = flex_attention_backward_8[1] + getitem_11751 = flex_attention_backward_8[2]; flex_attention_backward_8 = None + permute_841 = torch.ops.aten.permute.default(getitem_11751, [0, 2, 1, 3]); getitem_11751 = None + permute_842 = torch.ops.aten.permute.default(getitem_11750, [0, 2, 1, 3]); getitem_11750 = None + permute_843 = torch.ops.aten.permute.default(getitem_11749, [0, 2, 1, 3]); getitem_11749 = None + slice_212 = torch.ops.aten.slice.Tensor(permute_842, 3, 0, 128) + slice_213 = torch.ops.aten.slice.Tensor(permute_842, 3, 128, 192); permute_842 = None + sum_174 = torch.ops.aten.sum.dim_IntList(slice_213, [2], True); slice_213 = None + cat_304 = torch.ops.aten.cat.default([slice_212, permute_841], 3); slice_212 = permute_841 = None + view_1941 = torch.ops.aten.view.default(cat_304, [2, 4096, 4096]); cat_304 = None + view_1942 = torch.ops.aten.view.default(view_1941, [8192, 4096]); view_1941 = None + permute_844 = torch.ops.aten.permute.default(view_1942, [1, 0]) + mm_356 = torch.ops.aten.mm.default(permute_844, view_1189); permute_844 = view_1189 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(primals_299, torch.bfloat16); primals_299 = None + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_971, 128, '0'); convert_element_type_971 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + permute_268 = torch.ops.aten.permute.default(wait_tensor_372, [1, 0]); wait_tensor_372 = None + permute_846 = torch.ops.aten.permute.default(permute_268, [1, 0]); permute_268 = None + mm_357 = torch.ops.aten.mm.default(view_1942, permute_846); view_1942 = permute_846 = None + view_1943 = torch.ops.aten.view.default(mm_357, [2, 4096, 512]); mm_357 = None + convert_element_type_2097 = torch.ops.prims.convert_element_type.default(mm_356, torch.float32); mm_356 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2097, 'avg', 128, '0'); convert_element_type_2097 = None + wait_tensor_700 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + convert_element_type_2098 = torch.ops.prims.convert_element_type.default(view_1943, torch.float32); view_1943 = None + convert_element_type_968 = torch.ops.prims.convert_element_type.default(primals_298, torch.bfloat16); primals_298 = None + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_968, 128, '0'); convert_element_type_968 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + convert_element_type_2100 = torch.ops.prims.convert_element_type.default(wait_tensor_371, torch.float32); wait_tensor_371 = None + mul_1583 = torch.ops.aten.mul.Tensor(convert_element_type_2098, convert_element_type_2100); convert_element_type_2100 = None + convert_element_type_969 = torch.ops.prims.convert_element_type.default(getitem_1881, torch.float32); getitem_1881 = None + mul_846 = torch.ops.aten.mul.Tensor(convert_element_type_969, rsqrt_55); convert_element_type_969 = None + mul_1585 = torch.ops.aten.mul.Tensor(mul_846, mul_1583) + sum_175 = torch.ops.aten.sum.dim_IntList(mul_1585, [2], True); mul_1585 = None + div_184 = torch.ops.aten.div.Tensor(mul_846, 512) + mul_1586 = torch.ops.aten.mul.Tensor(div_184, sum_175); div_184 = sum_175 = None + sub_677 = torch.ops.aten.sub.Tensor(mul_1583, mul_1586); mul_1583 = mul_1586 = None + mul_1587 = torch.ops.aten.mul.Tensor(sub_677, rsqrt_55); sub_677 = rsqrt_55 = None + mul_1588 = torch.ops.aten.mul.Tensor(convert_element_type_2098, mul_846); convert_element_type_2098 = mul_846 = None + sum_176 = torch.ops.aten.sum.dim_IntList(mul_1588, [0, 1]); mul_1588 = None + convert_element_type_2101 = torch.ops.prims.convert_element_type.default(mul_1587, torch.bfloat16); mul_1587 = None + convert_element_type_default_56 = torch.ops.prims.convert_element_type.default(sum_176, torch.float32); sum_176 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_56, 'avg', 128, '0'); convert_element_type_default_56 = None + wait_tensor_701 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + convert_element_type_2104 = torch.ops.prims.convert_element_type.default(sum_174, torch.float32); sum_174 = None + view_1944 = torch.ops.aten.view.default(convert_element_type_2104, [2, 4096, 1, 32, 2]); convert_element_type_2104 = None + view_as_complex_70 = torch.ops.aten.view_as_complex.default(view_1944); view_1944 = None + mul_1589 = torch.ops.aten.mul.Tensor(view_as_complex_70, clone_9); view_as_complex_70 = None + view_as_real_70 = torch.ops.aten.view_as_real.default(mul_1589); mul_1589 = None + view_1945 = torch.ops.aten.view.default(view_as_real_70, [2, 4096, 1, 64]); view_as_real_70 = None + convert_element_type_2105 = torch.ops.prims.convert_element_type.default(view_1945, torch.bfloat16); view_1945 = None + squeeze_34 = torch.ops.aten.squeeze.dim(convert_element_type_2105, 2); convert_element_type_2105 = None + cat_305 = torch.ops.aten.cat.default([convert_element_type_2101, squeeze_34], 2); convert_element_type_2101 = squeeze_34 = None + view_1946 = torch.ops.aten.view.default(cat_305, [8192, 576]); cat_305 = None + permute_848 = torch.ops.aten.permute.default(view_1946, [1, 0]) + mm_358 = torch.ops.aten.mm.default(permute_848, view_1175); permute_848 = None + convert_element_type_963 = torch.ops.prims.convert_element_type.default(primals_297, torch.bfloat16); primals_297 = None + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_963, 128, '0'); convert_element_type_963 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + slice_109 = torch.ops.aten.slice.Tensor(wait_tensor_370, 0, 0, 576); wait_tensor_370 = None + permute_267 = torch.ops.aten.permute.default(slice_109, [1, 0]); slice_109 = None + permute_850 = torch.ops.aten.permute.default(permute_267, [1, 0]); permute_267 = None + mm_359 = torch.ops.aten.mm.default(view_1946, permute_850); view_1946 = permute_850 = None + view_1947 = torch.ops.aten.view.default(mm_359, [2, 4096, 2048]); mm_359 = None + convert_element_type_2110 = torch.ops.prims.convert_element_type.default(mm_358, torch.float32); mm_358 = None + split_632 = torch.ops.aten.split.Tensor(convert_element_type_2110, 5); convert_element_type_2110 = None + getitem_11753 = split_632[0] + getitem_11754 = split_632[1] + getitem_11755 = split_632[2] + getitem_11756 = split_632[3] + getitem_11757 = split_632[4] + getitem_11758 = split_632[5] + getitem_11759 = split_632[6] + getitem_11760 = split_632[7] + getitem_11761 = split_632[8] + getitem_11762 = split_632[9] + getitem_11763 = split_632[10] + getitem_11764 = split_632[11] + getitem_11765 = split_632[12] + getitem_11766 = split_632[13] + getitem_11767 = split_632[14] + getitem_11768 = split_632[15] + getitem_11769 = split_632[16] + getitem_11770 = split_632[17] + getitem_11771 = split_632[18] + getitem_11772 = split_632[19] + getitem_11773 = split_632[20] + getitem_11774 = split_632[21] + getitem_11775 = split_632[22] + getitem_11776 = split_632[23] + getitem_11777 = split_632[24] + getitem_11778 = split_632[25] + getitem_11779 = split_632[26] + getitem_11780 = split_632[27] + getitem_11781 = split_632[28] + getitem_11782 = split_632[29] + getitem_11783 = split_632[30] + getitem_11784 = split_632[31] + getitem_11785 = split_632[32] + getitem_11786 = split_632[33] + getitem_11787 = split_632[34] + getitem_11788 = split_632[35] + getitem_11789 = split_632[36] + getitem_11790 = split_632[37] + getitem_11791 = split_632[38] + getitem_11792 = split_632[39] + getitem_11793 = split_632[40] + getitem_11794 = split_632[41] + getitem_11795 = split_632[42] + getitem_11796 = split_632[43] + getitem_11797 = split_632[44] + getitem_11798 = split_632[45] + getitem_11799 = split_632[46] + getitem_11800 = split_632[47] + getitem_11801 = split_632[48] + getitem_11802 = split_632[49] + getitem_11803 = split_632[50] + getitem_11804 = split_632[51] + getitem_11805 = split_632[52] + getitem_11806 = split_632[53] + getitem_11807 = split_632[54] + getitem_11808 = split_632[55] + getitem_11809 = split_632[56] + getitem_11810 = split_632[57] + getitem_11811 = split_632[58] + getitem_11812 = split_632[59] + getitem_11813 = split_632[60] + getitem_11814 = split_632[61] + getitem_11815 = split_632[62] + getitem_11816 = split_632[63] + getitem_11817 = split_632[64] + getitem_11818 = split_632[65] + getitem_11819 = split_632[66] + getitem_11820 = split_632[67] + getitem_11821 = split_632[68] + getitem_11822 = split_632[69] + getitem_11823 = split_632[70] + getitem_11824 = split_632[71] + getitem_11825 = split_632[72] + getitem_11826 = split_632[73] + getitem_11827 = split_632[74] + getitem_11828 = split_632[75] + getitem_11829 = split_632[76] + getitem_11830 = split_632[77] + getitem_11831 = split_632[78] + getitem_11832 = split_632[79] + getitem_11833 = split_632[80] + getitem_11834 = split_632[81] + getitem_11835 = split_632[82] + getitem_11836 = split_632[83] + getitem_11837 = split_632[84] + getitem_11838 = split_632[85] + getitem_11839 = split_632[86] + getitem_11840 = split_632[87] + getitem_11841 = split_632[88] + getitem_11842 = split_632[89] + getitem_11843 = split_632[90] + getitem_11844 = split_632[91] + getitem_11845 = split_632[92] + getitem_11846 = split_632[93] + getitem_11847 = split_632[94] + getitem_11848 = split_632[95] + getitem_11849 = split_632[96] + getitem_11850 = split_632[97] + getitem_11851 = split_632[98] + getitem_11852 = split_632[99] + getitem_11853 = split_632[100] + getitem_11854 = split_632[101] + getitem_11855 = split_632[102] + getitem_11856 = split_632[103] + getitem_11857 = split_632[104] + getitem_11858 = split_632[105] + getitem_11859 = split_632[106] + getitem_11860 = split_632[107] + getitem_11861 = split_632[108] + getitem_11862 = split_632[109] + getitem_11863 = split_632[110] + getitem_11864 = split_632[111] + getitem_11865 = split_632[112] + getitem_11866 = split_632[113] + getitem_11867 = split_632[114] + getitem_11868 = split_632[115]; split_632 = None + constant_pad_nd_680 = torch.ops.aten.constant_pad_nd.default(getitem_11868, [0, 0, 0, 4], 0.0); getitem_11868 = None + cat_306 = torch.ops.aten.cat.default([getitem_11753, getitem_11754, getitem_11755, getitem_11756, getitem_11757, getitem_11758, getitem_11759, getitem_11760, getitem_11761, getitem_11762, getitem_11763, getitem_11764, getitem_11765, getitem_11766, getitem_11767, getitem_11768, getitem_11769, getitem_11770, getitem_11771, getitem_11772, getitem_11773, getitem_11774, getitem_11775, getitem_11776, getitem_11777, getitem_11778, getitem_11779, getitem_11780, getitem_11781, getitem_11782, getitem_11783, getitem_11784, getitem_11785, getitem_11786, getitem_11787, getitem_11788, getitem_11789, getitem_11790, getitem_11791, getitem_11792, getitem_11793, getitem_11794, getitem_11795, getitem_11796, getitem_11797, getitem_11798, getitem_11799, getitem_11800, getitem_11801, getitem_11802, getitem_11803, getitem_11804, getitem_11805, getitem_11806, getitem_11807, getitem_11808, getitem_11809, getitem_11810, getitem_11811, getitem_11812, getitem_11813, getitem_11814, getitem_11815, getitem_11816, getitem_11817, getitem_11818, getitem_11819, getitem_11820, getitem_11821, getitem_11822, getitem_11823, getitem_11824, getitem_11825, getitem_11826, getitem_11827, getitem_11828, getitem_11829, getitem_11830, getitem_11831, getitem_11832, getitem_11833, getitem_11834, getitem_11835, getitem_11836, getitem_11837, getitem_11838, getitem_11839, getitem_11840, getitem_11841, getitem_11842, getitem_11843, getitem_11844, getitem_11845, getitem_11846, getitem_11847, getitem_11848, getitem_11849, getitem_11850, getitem_11851, getitem_11852, getitem_11853, getitem_11854, getitem_11855, getitem_11856, getitem_11857, getitem_11858, getitem_11859, getitem_11860, getitem_11861, getitem_11862, getitem_11863, getitem_11864, getitem_11865, getitem_11866, getitem_11867, constant_pad_nd_680, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_11753 = getitem_11754 = getitem_11755 = getitem_11756 = getitem_11757 = getitem_11758 = getitem_11759 = getitem_11760 = getitem_11761 = getitem_11762 = getitem_11763 = getitem_11764 = getitem_11765 = getitem_11766 = getitem_11767 = getitem_11768 = getitem_11769 = getitem_11770 = getitem_11771 = getitem_11772 = getitem_11773 = getitem_11774 = getitem_11775 = getitem_11776 = getitem_11777 = getitem_11778 = getitem_11779 = getitem_11780 = getitem_11781 = getitem_11782 = getitem_11783 = getitem_11784 = getitem_11785 = getitem_11786 = getitem_11787 = getitem_11788 = getitem_11789 = getitem_11790 = getitem_11791 = getitem_11792 = getitem_11793 = getitem_11794 = getitem_11795 = getitem_11796 = getitem_11797 = getitem_11798 = getitem_11799 = getitem_11800 = getitem_11801 = getitem_11802 = getitem_11803 = getitem_11804 = getitem_11805 = getitem_11806 = getitem_11807 = getitem_11808 = getitem_11809 = getitem_11810 = getitem_11811 = getitem_11812 = getitem_11813 = getitem_11814 = getitem_11815 = getitem_11816 = getitem_11817 = getitem_11818 = getitem_11819 = getitem_11820 = getitem_11821 = getitem_11822 = getitem_11823 = getitem_11824 = getitem_11825 = getitem_11826 = getitem_11827 = getitem_11828 = getitem_11829 = getitem_11830 = getitem_11831 = getitem_11832 = getitem_11833 = getitem_11834 = getitem_11835 = getitem_11836 = getitem_11837 = getitem_11838 = getitem_11839 = getitem_11840 = getitem_11841 = getitem_11842 = getitem_11843 = getitem_11844 = getitem_11845 = getitem_11846 = getitem_11847 = getitem_11848 = getitem_11849 = getitem_11850 = getitem_11851 = getitem_11852 = getitem_11853 = getitem_11854 = getitem_11855 = getitem_11856 = getitem_11857 = getitem_11858 = getitem_11859 = getitem_11860 = getitem_11861 = getitem_11862 = getitem_11863 = getitem_11864 = getitem_11865 = getitem_11866 = getitem_11867 = constant_pad_nd_680 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_306, 'avg', 128, '0'); cat_306 = None + wait_tensor_702 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + slice_214 = torch.ops.aten.slice.Tensor(permute_843, 3, 0, 128) + slice_215 = torch.ops.aten.slice.Tensor(permute_843, 3, 128, 192); permute_843 = None + convert_element_type_2111 = torch.ops.prims.convert_element_type.default(slice_215, torch.float32); slice_215 = None + view_1948 = torch.ops.aten.view.default(convert_element_type_2111, [2, 4096, 16, 32, 2]); convert_element_type_2111 = None + view_as_complex_71 = torch.ops.aten.view_as_complex.default(view_1948); view_1948 = None + mul_1590 = torch.ops.aten.mul.Tensor(view_as_complex_71, clone_9); view_as_complex_71 = None + view_as_real_71 = torch.ops.aten.view_as_real.default(mul_1590); mul_1590 = None + view_1949 = torch.ops.aten.view.default(view_as_real_71, [2, 4096, 16, 64]); view_as_real_71 = None + convert_element_type_2112 = torch.ops.prims.convert_element_type.default(view_1949, torch.bfloat16); view_1949 = None + cat_307 = torch.ops.aten.cat.default([slice_214, convert_element_type_2112], 3); slice_214 = convert_element_type_2112 = None + view_1950 = torch.ops.aten.view.default(cat_307, [2, 4096, 3072]); cat_307 = None + view_1951 = torch.ops.aten.view.default(view_1950, [8192, 3072]); view_1950 = None + permute_852 = torch.ops.aten.permute.default(view_1951, [1, 0]) + mm_360 = torch.ops.aten.mm.default(permute_852, view_1175); permute_852 = view_1175 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_296, torch.bfloat16); primals_296 = None + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 128, '0'); convert_element_type_958 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_369, [1, 0]); wait_tensor_369 = None + permute_854 = torch.ops.aten.permute.default(permute_266, [1, 0]); permute_266 = None + mm_361 = torch.ops.aten.mm.default(view_1951, permute_854); view_1951 = permute_854 = None + view_1952 = torch.ops.aten.view.default(mm_361, [2, 4096, 2048]); mm_361 = None + add_1908 = torch.ops.aten.add.Tensor(view_1947, view_1952); view_1947 = view_1952 = None + convert_element_type_2117 = torch.ops.prims.convert_element_type.default(mm_360, torch.float32); mm_360 = None + reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2117, 'avg', 128, '0'); convert_element_type_2117 = None + wait_tensor_703 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + convert_element_type_2118 = torch.ops.prims.convert_element_type.default(add_1908, torch.float32); add_1908 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_295, torch.bfloat16); primals_295 = None + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 128, '0'); convert_element_type_955 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + convert_element_type_2120 = torch.ops.prims.convert_element_type.default(wait_tensor_368, torch.float32); wait_tensor_368 = None + mul_1591 = torch.ops.aten.mul.Tensor(convert_element_type_2118, convert_element_type_2120); convert_element_type_2120 = None + convert_element_type_956 = torch.ops.prims.convert_element_type.default(add_1161, torch.float32); add_1161 = None + mul_842 = torch.ops.aten.mul.Tensor(convert_element_type_956, rsqrt_54); convert_element_type_956 = None + mul_1593 = torch.ops.aten.mul.Tensor(mul_842, mul_1591) + sum_177 = torch.ops.aten.sum.dim_IntList(mul_1593, [2], True); mul_1593 = None + div_185 = torch.ops.aten.div.Tensor(mul_842, 2048) + mul_1594 = torch.ops.aten.mul.Tensor(div_185, sum_177); div_185 = sum_177 = None + sub_678 = torch.ops.aten.sub.Tensor(mul_1591, mul_1594); mul_1591 = mul_1594 = None + mul_1595 = torch.ops.aten.mul.Tensor(sub_678, rsqrt_54); sub_678 = rsqrt_54 = None + mul_1596 = torch.ops.aten.mul.Tensor(convert_element_type_2118, mul_842); convert_element_type_2118 = mul_842 = None + sum_178 = torch.ops.aten.sum.dim_IntList(mul_1596, [0, 1]); mul_1596 = None + convert_element_type_2121 = torch.ops.prims.convert_element_type.default(mul_1595, torch.bfloat16); mul_1595 = None + add_1909 = torch.ops.aten.add.Tensor(add_1907, convert_element_type_2121); add_1907 = convert_element_type_2121 = None + convert_element_type_default_55 = torch.ops.prims.convert_element_type.default(sum_178, torch.float32); sum_178 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_55, 'avg', 128, '0'); convert_element_type_default_55 = None + wait_tensor_704 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); reduce_scatter_tensor_127 = None + view_1953 = torch.ops.aten.view.default(add_1909, [8192, 2048]) + unsqueeze_62 = torch.ops.aten.unsqueeze.default(view_1953, 1) + convert_element_type_2124 = torch.ops.prims.convert_element_type.default(unsqueeze_62, torch.float32); unsqueeze_62 = None + bmm_44 = torch.ops.aten.bmm.default(permute_856, convert_element_type_2124); permute_856 = None + bmm_45 = torch.ops.aten.bmm.default(convert_element_type_2124, permute_857); convert_element_type_2124 = permute_857 = None + convert_element_type_2125 = torch.ops.prims.convert_element_type.default(bmm_44, torch.bfloat16); bmm_44 = None + view_1954 = torch.ops.aten.view.default(bmm_45, [8192, 6]); bmm_45 = None + view_1955 = torch.ops.aten.view.default(convert_element_type_2125, [49152, 2048]); convert_element_type_2125 = None + index_70 = torch.ops.aten.index.Tensor(view_1955, [getitem_1781]); view_1955 = getitem_1781 = None + permute_858 = torch.ops.aten.permute.default(view_1953, [1, 0]) + mm_362 = torch.ops.aten.mm.default(permute_858, mul_839); permute_858 = mul_839 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(primals_294, torch.bfloat16); primals_294 = None + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_950, 128, '0'); convert_element_type_950 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_367, [1, 0]); wait_tensor_367 = None + permute_860 = torch.ops.aten.permute.default(permute_265, [1, 0]); permute_265 = None + mm_363 = torch.ops.aten.mm.default(view_1953, permute_860); view_1953 = permute_860 = None + convert_element_type_2130 = torch.ops.prims.convert_element_type.default(mm_362, torch.float32); mm_362 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2130, 'avg', 128, '0'); convert_element_type_2130 = None + wait_tensor_705 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(mm_140, torch.float32); mm_140 = None + neg_34 = torch.ops.aten.neg.default(convert_element_type_945) + exp_51 = torch.ops.aten.exp.default(neg_34); neg_34 = None + add_1156 = torch.ops.aten.add.Tensor(exp_51, 1); exp_51 = None + div_85 = torch.ops.aten.div.Tensor(convert_element_type_945, add_1156) + convert_element_type_946 = torch.ops.prims.convert_element_type.default(div_85, torch.bfloat16); div_85 = None + mul_1597 = torch.ops.aten.mul.Tensor(mm_363, convert_element_type_946); convert_element_type_946 = None + mul_1598 = torch.ops.aten.mul.Tensor(mm_363, mm_141); mm_363 = mm_141 = None + permute_862 = torch.ops.aten.permute.default(mul_1597, [1, 0]) + mm_364 = torch.ops.aten.mm.default(permute_862, view_1130); permute_862 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 128, '0'); convert_element_type_947 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_366, [1, 0]); wait_tensor_366 = None + permute_864 = torch.ops.aten.permute.default(permute_264, [1, 0]); permute_264 = None + mm_365 = torch.ops.aten.mm.default(mul_1597, permute_864); mul_1597 = permute_864 = None + convert_element_type_2135 = torch.ops.prims.convert_element_type.default(mm_364, torch.float32); mm_364 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2135, 'avg', 128, '0'); convert_element_type_2135 = None + wait_tensor_706 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + convert_element_type_2136 = torch.ops.prims.convert_element_type.default(mul_1598, torch.float32); mul_1598 = None + reciprocal_18 = torch.ops.aten.reciprocal.default(add_1156); add_1156 = None + mul_1599 = torch.ops.aten.mul.Tensor(reciprocal_18, 1); reciprocal_18 = None + mul_1600 = torch.ops.aten.mul.Tensor(convert_element_type_2136, mul_1599); convert_element_type_2136 = None + sub_679 = torch.ops.aten.sub.Tensor(1, mul_1599); mul_1599 = None + mul_1601 = torch.ops.aten.mul.Tensor(convert_element_type_945, sub_679); convert_element_type_945 = sub_679 = None + add_1911 = torch.ops.aten.add.Tensor(mul_1601, 1); mul_1601 = None + mul_1602 = torch.ops.aten.mul.Tensor(mul_1600, add_1911); mul_1600 = add_1911 = None + convert_element_type_2138 = torch.ops.prims.convert_element_type.default(mul_1602, torch.bfloat16); mul_1602 = None + permute_866 = torch.ops.aten.permute.default(convert_element_type_2138, [1, 0]) + mm_366 = torch.ops.aten.mm.default(permute_866, view_1130); permute_866 = None + convert_element_type_942 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_942, 128, '0'); convert_element_type_942 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_365, [1, 0]); wait_tensor_365 = None + permute_868 = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None + mm_367 = torch.ops.aten.mm.default(convert_element_type_2138, permute_868); convert_element_type_2138 = permute_868 = None + add_1912 = torch.ops.aten.add.Tensor(mm_365, mm_367); mm_365 = mm_367 = None + convert_element_type_2143 = torch.ops.prims.convert_element_type.default(mm_366, torch.float32); mm_366 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2143, 'avg', 128, '0'); convert_element_type_2143 = None + wait_tensor_707 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + all_to_all_single_96 = torch.ops._c10d_functional.all_to_all_single.default(index_70, [_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271], [_local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263], '1033'); index_70 = None + wait_tensor_708 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_96); all_to_all_single_96 = None + full_402 = torch.ops.aten.full.default([sym_size_int_65, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_65 = None + slice_scatter_9 = torch.ops.aten.slice_scatter.default(full_402, wait_tensor_708, 0, 0, -1); wait_tensor_708 = None + index_71 = torch.ops.aten.index.Tensor(slice_scatter_9, [getitem_1782]); slice_scatter_9 = None + permute_870 = torch.ops.aten.permute.default(index_71, [1, 0]) + _grouped_mm_132 = torch.ops.aten._grouped_mm.default(permute_870, mul_819, cumsum_50); permute_870 = mul_819 = None + _grouped_mm_133 = torch.ops.aten._grouped_mm.default(index_71, permute_872, cumsum_50); index_71 = permute_872 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(_grouped_mm_48, torch.float32); _grouped_mm_48 = None + neg_33 = torch.ops.aten.neg.default(convert_element_type_940) + exp_50 = torch.ops.aten.exp.default(neg_33); neg_33 = None + add_1120 = torch.ops.aten.add.Tensor(exp_50, 1); exp_50 = None + div_84 = torch.ops.aten.div.Tensor(convert_element_type_940, add_1120) + convert_element_type_941 = torch.ops.prims.convert_element_type.default(div_84, torch.bfloat16); div_84 = None + mul_1603 = torch.ops.aten.mul.Tensor(_grouped_mm_133, convert_element_type_941); convert_element_type_941 = None + mul_1604 = torch.ops.aten.mul.Tensor(_grouped_mm_133, _grouped_mm_49); _grouped_mm_133 = _grouped_mm_49 = None + permute_874 = torch.ops.aten.permute.default(mul_1603, [1, 0]) + _grouped_mm_134 = torch.ops.aten._grouped_mm.default(permute_874, index_33, cumsum_50); permute_874 = None + _grouped_mm_135 = torch.ops.aten._grouped_mm.default(mul_1603, permute_876, cumsum_50); mul_1603 = permute_876 = None + convert_element_type_2144 = torch.ops.prims.convert_element_type.default(mul_1604, torch.float32); mul_1604 = None + reciprocal_19 = torch.ops.aten.reciprocal.default(add_1120); add_1120 = None + mul_1605 = torch.ops.aten.mul.Tensor(reciprocal_19, 1); reciprocal_19 = None + mul_1606 = torch.ops.aten.mul.Tensor(convert_element_type_2144, mul_1605); convert_element_type_2144 = None + sub_680 = torch.ops.aten.sub.Tensor(1, mul_1605); mul_1605 = None + mul_1607 = torch.ops.aten.mul.Tensor(convert_element_type_940, sub_680); convert_element_type_940 = sub_680 = None + add_1914 = torch.ops.aten.add.Tensor(mul_1607, 1); mul_1607 = None + mul_1608 = torch.ops.aten.mul.Tensor(mul_1606, add_1914); mul_1606 = add_1914 = None + convert_element_type_2146 = torch.ops.prims.convert_element_type.default(mul_1608, torch.bfloat16); mul_1608 = None + permute_878 = torch.ops.aten.permute.default(convert_element_type_2146, [1, 0]) + _grouped_mm_136 = torch.ops.aten._grouped_mm.default(permute_878, index_33, cumsum_50); permute_878 = index_33 = None + _grouped_mm_137 = torch.ops.aten._grouped_mm.default(convert_element_type_2146, permute_880, cumsum_50); convert_element_type_2146 = permute_880 = cumsum_50 = None + add_1915 = torch.ops.aten.add.Tensor(_grouped_mm_135, _grouped_mm_137); _grouped_mm_135 = _grouped_mm_137 = None + convert_element_type_2147 = torch.ops.prims.convert_element_type.default(_grouped_mm_134, torch.float32); _grouped_mm_134 = None + div_186 = torch.ops.aten.div.Tensor(convert_element_type_2147, 128); convert_element_type_2147 = None + split_634 = torch.ops.aten.split.Tensor(div_186, 88, 1); div_186 = None + getitem_11885 = split_634[0] + getitem_11902 = split_634[1] + getitem_11919 = split_634[2] + getitem_11936 = split_634[3] + getitem_11953 = split_634[4] + getitem_11970 = split_634[5] + getitem_11987 = split_634[6] + getitem_12004 = split_634[7] + getitem_12021 = split_634[8] + getitem_12038 = split_634[9] + getitem_12055 = split_634[10] + getitem_12072 = split_634[11] + getitem_12089 = split_634[12] + getitem_12106 = split_634[13] + getitem_12123 = split_634[14] + getitem_12140 = split_634[15]; split_634 = None + cat_308 = torch.ops.aten.cat.default([getitem_11885, getitem_11902, getitem_11919, getitem_11936, getitem_11953, getitem_11970, getitem_11987, getitem_12004, getitem_12021, getitem_12038, getitem_12055, getitem_12072, getitem_12089, getitem_12106, getitem_12123, getitem_12140]); getitem_11885 = getitem_11902 = getitem_11919 = getitem_11936 = getitem_11953 = getitem_11970 = getitem_11987 = getitem_12004 = getitem_12021 = getitem_12038 = getitem_12055 = getitem_12072 = getitem_12089 = getitem_12106 = getitem_12123 = getitem_12140 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_308, 'sum', 16, '1025'); cat_308 = None + wait_tensor_709 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + convert_element_type_2148 = torch.ops.prims.convert_element_type.default(_grouped_mm_132, torch.float32); _grouped_mm_132 = None + div_187 = torch.ops.aten.div.Tensor(convert_element_type_2148, 128); convert_element_type_2148 = None + split_651 = torch.ops.aten.split.Tensor(div_187, 128, 1); div_187 = None + getitem_12157 = split_651[0] + getitem_12174 = split_651[1] + getitem_12191 = split_651[2] + getitem_12208 = split_651[3] + getitem_12225 = split_651[4] + getitem_12242 = split_651[5] + getitem_12259 = split_651[6] + getitem_12276 = split_651[7] + getitem_12293 = split_651[8] + getitem_12310 = split_651[9] + getitem_12327 = split_651[10] + getitem_12344 = split_651[11] + getitem_12361 = split_651[12] + getitem_12378 = split_651[13] + getitem_12395 = split_651[14] + getitem_12412 = split_651[15]; split_651 = None + cat_309 = torch.ops.aten.cat.default([getitem_12157, getitem_12174, getitem_12191, getitem_12208, getitem_12225, getitem_12242, getitem_12259, getitem_12276, getitem_12293, getitem_12310, getitem_12327, getitem_12344, getitem_12361, getitem_12378, getitem_12395, getitem_12412]); getitem_12157 = getitem_12174 = getitem_12191 = getitem_12208 = getitem_12225 = getitem_12242 = getitem_12259 = getitem_12276 = getitem_12293 = getitem_12310 = getitem_12327 = getitem_12344 = getitem_12361 = getitem_12378 = getitem_12395 = getitem_12412 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_309, 'sum', 16, '1025'); cat_309 = None + wait_tensor_710 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + convert_element_type_2149 = torch.ops.prims.convert_element_type.default(_grouped_mm_136, torch.float32); _grouped_mm_136 = None + div_188 = torch.ops.aten.div.Tensor(convert_element_type_2149, 128); convert_element_type_2149 = None + split_668 = torch.ops.aten.split.Tensor(div_188, 88, 1); div_188 = None + getitem_12429 = split_668[0] + getitem_12446 = split_668[1] + getitem_12463 = split_668[2] + getitem_12480 = split_668[3] + getitem_12497 = split_668[4] + getitem_12514 = split_668[5] + getitem_12531 = split_668[6] + getitem_12548 = split_668[7] + getitem_12565 = split_668[8] + getitem_12582 = split_668[9] + getitem_12599 = split_668[10] + getitem_12616 = split_668[11] + getitem_12633 = split_668[12] + getitem_12650 = split_668[13] + getitem_12667 = split_668[14] + getitem_12684 = split_668[15]; split_668 = None + cat_310 = torch.ops.aten.cat.default([getitem_12429, getitem_12446, getitem_12463, getitem_12480, getitem_12497, getitem_12514, getitem_12531, getitem_12548, getitem_12565, getitem_12582, getitem_12599, getitem_12616, getitem_12633, getitem_12650, getitem_12667, getitem_12684]); getitem_12429 = getitem_12446 = getitem_12463 = getitem_12480 = getitem_12497 = getitem_12514 = getitem_12531 = getitem_12548 = getitem_12565 = getitem_12582 = getitem_12599 = getitem_12616 = getitem_12633 = getitem_12650 = getitem_12667 = getitem_12684 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_310, 'sum', 16, '1025'); cat_310 = None + wait_tensor_711 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + index_put_70 = torch.ops.aten.index_put.default(full_402, [getitem_1782], add_1915, True); full_402 = getitem_1782 = add_1915 = None + slice_216 = torch.ops.aten.slice.Tensor(index_put_70, 0, 0, add_1916); index_put_70 = add_1916 = None + all_to_all_single_97 = torch.ops._c10d_functional.all_to_all_single.default(slice_216, [_local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263], [_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271], '1033'); slice_216 = _local_scalar_dense_256 = _local_scalar_dense_257 = _local_scalar_dense_258 = _local_scalar_dense_259 = _local_scalar_dense_260 = _local_scalar_dense_261 = _local_scalar_dense_262 = _local_scalar_dense_263 = _local_scalar_dense_264 = _local_scalar_dense_265 = _local_scalar_dense_266 = _local_scalar_dense_267 = _local_scalar_dense_268 = _local_scalar_dense_269 = _local_scalar_dense_270 = _local_scalar_dense_271 = None + wait_tensor_712 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_97); all_to_all_single_97 = None + index_put_71 = torch.ops.aten.index_put.default(full_default_52, [div_82], wait_tensor_712, True); div_82 = wait_tensor_712 = None + add_1920 = torch.ops.aten.add.Tensor(add_1912, index_put_71); add_1912 = index_put_71 = None + mul_1609 = torch.ops.aten.mul.Tensor(view_1954, 1.0); view_1954 = None + scatter_add_9 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1779, mul_1609); getitem_1779 = mul_1609 = None + convert_element_type_929 = torch.ops.prims.convert_element_type.default(mm_139, torch.float32); mm_139 = None + sub_384 = torch.ops.aten.sub.Tensor(convert_element_type_929, amax_16); convert_element_type_929 = amax_16 = None + exp_49 = torch.ops.aten.exp.default(sub_384); sub_384 = None + div_81 = torch.ops.aten.div.Tensor(exp_49, sum_65); exp_49 = sum_65 = None + mul_1610 = torch.ops.aten.mul.Tensor(scatter_add_9, div_81); scatter_add_9 = None + sum_179 = torch.ops.aten.sum.dim_IntList(mul_1610, [1], True) + neg_82 = torch.ops.aten.neg.default(div_81); div_81 = None + fma_9 = torch.ops.prims.fma.default(neg_82, sum_179, mul_1610); neg_82 = sum_179 = mul_1610 = None + convert_element_type_2150 = torch.ops.prims.convert_element_type.default(fma_9, torch.bfloat16); fma_9 = None + permute_882 = torch.ops.aten.permute.default(convert_element_type_2150, [1, 0]) + mm_368 = torch.ops.aten.mm.default(permute_882, view_1130); permute_882 = view_1130 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_926, 128, '0'); convert_element_type_926 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + slice_105 = torch.ops.aten.slice.Tensor(wait_tensor_354, 0, 0, 64); wait_tensor_354 = None + permute_259 = torch.ops.aten.permute.default(slice_105, [1, 0]); slice_105 = None + permute_884 = torch.ops.aten.permute.default(permute_259, [1, 0]); permute_259 = None + mm_369 = torch.ops.aten.mm.default(convert_element_type_2150, permute_884); convert_element_type_2150 = permute_884 = None + add_1921 = torch.ops.aten.add.Tensor(add_1920, mm_369); add_1920 = mm_369 = None + convert_element_type_2155 = torch.ops.prims.convert_element_type.default(mm_368, torch.float32); mm_368 = None + split_684 = torch.ops.aten.split.Tensor(convert_element_type_2155, 1); convert_element_type_2155 = None + getitem_12685 = split_684[0] + getitem_12686 = split_684[1] + getitem_12687 = split_684[2] + getitem_12688 = split_684[3] + getitem_12689 = split_684[4] + getitem_12690 = split_684[5] + getitem_12691 = split_684[6] + getitem_12692 = split_684[7] + getitem_12693 = split_684[8] + getitem_12694 = split_684[9] + getitem_12695 = split_684[10] + getitem_12696 = split_684[11] + getitem_12697 = split_684[12] + getitem_12698 = split_684[13] + getitem_12699 = split_684[14] + getitem_12700 = split_684[15] + getitem_12701 = split_684[16] + getitem_12702 = split_684[17] + getitem_12703 = split_684[18] + getitem_12704 = split_684[19] + getitem_12705 = split_684[20] + getitem_12706 = split_684[21] + getitem_12707 = split_684[22] + getitem_12708 = split_684[23] + getitem_12709 = split_684[24] + getitem_12710 = split_684[25] + getitem_12711 = split_684[26] + getitem_12712 = split_684[27] + getitem_12713 = split_684[28] + getitem_12714 = split_684[29] + getitem_12715 = split_684[30] + getitem_12716 = split_684[31] + getitem_12717 = split_684[32] + getitem_12718 = split_684[33] + getitem_12719 = split_684[34] + getitem_12720 = split_684[35] + getitem_12721 = split_684[36] + getitem_12722 = split_684[37] + getitem_12723 = split_684[38] + getitem_12724 = split_684[39] + getitem_12725 = split_684[40] + getitem_12726 = split_684[41] + getitem_12727 = split_684[42] + getitem_12728 = split_684[43] + getitem_12729 = split_684[44] + getitem_12730 = split_684[45] + getitem_12731 = split_684[46] + getitem_12732 = split_684[47] + getitem_12733 = split_684[48] + getitem_12734 = split_684[49] + getitem_12735 = split_684[50] + getitem_12736 = split_684[51] + getitem_12737 = split_684[52] + getitem_12738 = split_684[53] + getitem_12739 = split_684[54] + getitem_12740 = split_684[55] + getitem_12741 = split_684[56] + getitem_12742 = split_684[57] + getitem_12743 = split_684[58] + getitem_12744 = split_684[59] + getitem_12745 = split_684[60] + getitem_12746 = split_684[61] + getitem_12747 = split_684[62] + getitem_12748 = split_684[63]; split_684 = None + cat_311 = torch.ops.aten.cat.default([getitem_12685, getitem_12686, getitem_12687, getitem_12688, getitem_12689, getitem_12690, getitem_12691, getitem_12692, getitem_12693, getitem_12694, getitem_12695, getitem_12696, getitem_12697, getitem_12698, getitem_12699, getitem_12700, getitem_12701, getitem_12702, getitem_12703, getitem_12704, getitem_12705, getitem_12706, getitem_12707, getitem_12708, getitem_12709, getitem_12710, getitem_12711, getitem_12712, getitem_12713, getitem_12714, getitem_12715, getitem_12716, getitem_12717, getitem_12718, getitem_12719, getitem_12720, getitem_12721, getitem_12722, getitem_12723, getitem_12724, getitem_12725, getitem_12726, getitem_12727, getitem_12728, getitem_12729, getitem_12730, getitem_12731, getitem_12732, getitem_12733, getitem_12734, getitem_12735, getitem_12736, getitem_12737, getitem_12738, getitem_12739, getitem_12740, getitem_12741, getitem_12742, getitem_12743, getitem_12744, getitem_12745, getitem_12746, getitem_12747, getitem_12748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_12685 = getitem_12686 = getitem_12687 = getitem_12688 = getitem_12689 = getitem_12690 = getitem_12691 = getitem_12692 = getitem_12693 = getitem_12694 = getitem_12695 = getitem_12696 = getitem_12697 = getitem_12698 = getitem_12699 = getitem_12700 = getitem_12701 = getitem_12702 = getitem_12703 = getitem_12704 = getitem_12705 = getitem_12706 = getitem_12707 = getitem_12708 = getitem_12709 = getitem_12710 = getitem_12711 = getitem_12712 = getitem_12713 = getitem_12714 = getitem_12715 = getitem_12716 = getitem_12717 = getitem_12718 = getitem_12719 = getitem_12720 = getitem_12721 = getitem_12722 = getitem_12723 = getitem_12724 = getitem_12725 = getitem_12726 = getitem_12727 = getitem_12728 = getitem_12729 = getitem_12730 = getitem_12731 = getitem_12732 = getitem_12733 = getitem_12734 = getitem_12735 = getitem_12736 = getitem_12737 = getitem_12738 = getitem_12739 = getitem_12740 = getitem_12741 = getitem_12742 = getitem_12743 = getitem_12744 = getitem_12745 = getitem_12746 = getitem_12747 = getitem_12748 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_311, 'avg', 128, '0'); cat_311 = None + wait_tensor_713 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + view_1956 = torch.ops.aten.view.default(add_1921, [2, 4096, 2048]); add_1921 = None + convert_element_type_2156 = torch.ops.prims.convert_element_type.default(view_1956, torch.float32); view_1956 = None + convert_element_type_923 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_923, 128, '0'); convert_element_type_923 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_2158 = torch.ops.prims.convert_element_type.default(wait_tensor_353, torch.float32); wait_tensor_353 = None + mul_1611 = torch.ops.aten.mul.Tensor(convert_element_type_2156, convert_element_type_2158); convert_element_type_2158 = None + convert_element_type_924 = torch.ops.prims.convert_element_type.default(add_1096, torch.float32); add_1096 = None + mul_799 = torch.ops.aten.mul.Tensor(convert_element_type_924, rsqrt_53); convert_element_type_924 = None + mul_1613 = torch.ops.aten.mul.Tensor(mul_799, mul_1611) + sum_180 = torch.ops.aten.sum.dim_IntList(mul_1613, [2], True); mul_1613 = None + div_189 = torch.ops.aten.div.Tensor(mul_799, 2048) + mul_1614 = torch.ops.aten.mul.Tensor(div_189, sum_180); div_189 = sum_180 = None + sub_682 = torch.ops.aten.sub.Tensor(mul_1611, mul_1614); mul_1611 = mul_1614 = None + mul_1615 = torch.ops.aten.mul.Tensor(sub_682, rsqrt_53); sub_682 = rsqrt_53 = None + mul_1616 = torch.ops.aten.mul.Tensor(convert_element_type_2156, mul_799); convert_element_type_2156 = mul_799 = None + sum_181 = torch.ops.aten.sum.dim_IntList(mul_1616, [0, 1]); mul_1616 = None + convert_element_type_2159 = torch.ops.prims.convert_element_type.default(mul_1615, torch.bfloat16); mul_1615 = None + add_1922 = torch.ops.aten.add.Tensor(add_1909, convert_element_type_2159); add_1909 = convert_element_type_2159 = None + convert_element_type_default_54 = torch.ops.prims.convert_element_type.default(sum_181, torch.float32); sum_181 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_54, 'avg', 128, '0'); convert_element_type_default_54 = None + wait_tensor_714 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + view_1957 = torch.ops.aten.view.default(add_1922, [8192, 2048]) + permute_886 = torch.ops.aten.permute.default(view_1957, [1, 0]) + permute_257 = torch.ops.aten.permute.default(getitem_1775, [0, 2, 1, 3]) + view_1125 = torch.ops.aten.view.default(permute_257, [2, 4096, -1]); permute_257 = None + view_1127 = torch.ops.aten.view.default(view_1125, [8192, 2048]); view_1125 = None + mm_370 = torch.ops.aten.mm.default(permute_886, view_1127); permute_886 = view_1127 = None + convert_element_type_920 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_920, 128, '0'); convert_element_type_920 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_258 = torch.ops.aten.permute.default(wait_tensor_352, [1, 0]); wait_tensor_352 = None + permute_888 = torch.ops.aten.permute.default(permute_258, [1, 0]); permute_258 = None + mm_371 = torch.ops.aten.mm.default(view_1957, permute_888); view_1957 = permute_888 = None + view_1958 = torch.ops.aten.view.default(mm_371, [2, 4096, 2048]); mm_371 = None + convert_element_type_2166 = torch.ops.prims.convert_element_type.default(mm_370, torch.float32); mm_370 = None + reduce_scatter_tensor_136 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2166, 'avg', 128, '0'); convert_element_type_2166 = None + wait_tensor_715 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + view_1959 = torch.ops.aten.view.default(view_1958, [2, 4096, 16, 128]); view_1958 = None + permute_890 = torch.ops.aten.permute.default(view_1959, [0, 2, 1, 3]); view_1959 = None + fw_graph9 = self.fw_graph9 + joint_graph9 = self.joint_graph9 + mask_graph9 = self.mask_graph9 + flex_attention_backward_9 = torch.ops.higher_order.flex_attention_backward(permute_254, permute_255, permute_256, getitem_1775, getitem_1776, permute_890, None, fw_graph9, joint_graph9, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph9), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_254 = permute_255 = permute_256 = getitem_1775 = getitem_1776 = permute_890 = fw_graph9 = joint_graph9 = mask_graph9 = None + getitem_12749 = flex_attention_backward_9[0] + getitem_12750 = flex_attention_backward_9[1] + getitem_12751 = flex_attention_backward_9[2]; flex_attention_backward_9 = None + permute_891 = torch.ops.aten.permute.default(getitem_12751, [0, 2, 1, 3]); getitem_12751 = None + permute_892 = torch.ops.aten.permute.default(getitem_12750, [0, 2, 1, 3]); getitem_12750 = None + permute_893 = torch.ops.aten.permute.default(getitem_12749, [0, 2, 1, 3]); getitem_12749 = None + slice_218 = torch.ops.aten.slice.Tensor(permute_892, 3, 0, 128) + slice_219 = torch.ops.aten.slice.Tensor(permute_892, 3, 128, 192); permute_892 = None + sum_182 = torch.ops.aten.sum.dim_IntList(slice_219, [2], True); slice_219 = None + cat_312 = torch.ops.aten.cat.default([slice_218, permute_891], 3); slice_218 = permute_891 = None + view_1960 = torch.ops.aten.view.default(cat_312, [2, 4096, 4096]); cat_312 = None + view_1961 = torch.ops.aten.view.default(view_1960, [8192, 4096]); view_1960 = None + permute_894 = torch.ops.aten.permute.default(view_1961, [1, 0]) + mm_372 = torch.ops.aten.mm.default(permute_894, view_1122); permute_894 = view_1122 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_917, 128, '0'); convert_element_type_917 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + permute_896 = torch.ops.aten.permute.default(permute_253, [1, 0]); permute_253 = None + mm_373 = torch.ops.aten.mm.default(view_1961, permute_896); view_1961 = permute_896 = None + view_1962 = torch.ops.aten.view.default(mm_373, [2, 4096, 512]); mm_373 = None + convert_element_type_2171 = torch.ops.prims.convert_element_type.default(mm_372, torch.float32); mm_372 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2171, 'avg', 128, '0'); convert_element_type_2171 = None + wait_tensor_716 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + convert_element_type_2172 = torch.ops.prims.convert_element_type.default(view_1962, torch.float32); view_1962 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 128, '0'); convert_element_type_914 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + convert_element_type_2174 = torch.ops.prims.convert_element_type.default(wait_tensor_350, torch.float32); wait_tensor_350 = None + mul_1617 = torch.ops.aten.mul.Tensor(convert_element_type_2172, convert_element_type_2174); convert_element_type_2174 = None + convert_element_type_915 = torch.ops.prims.convert_element_type.default(getitem_1771, torch.float32); getitem_1771 = None + mul_797 = torch.ops.aten.mul.Tensor(convert_element_type_915, rsqrt_52); convert_element_type_915 = None + mul_1619 = torch.ops.aten.mul.Tensor(mul_797, mul_1617) + sum_183 = torch.ops.aten.sum.dim_IntList(mul_1619, [2], True); mul_1619 = None + div_190 = torch.ops.aten.div.Tensor(mul_797, 512) + mul_1620 = torch.ops.aten.mul.Tensor(div_190, sum_183); div_190 = sum_183 = None + sub_683 = torch.ops.aten.sub.Tensor(mul_1617, mul_1620); mul_1617 = mul_1620 = None + mul_1621 = torch.ops.aten.mul.Tensor(sub_683, rsqrt_52); sub_683 = rsqrt_52 = None + mul_1622 = torch.ops.aten.mul.Tensor(convert_element_type_2172, mul_797); convert_element_type_2172 = mul_797 = None + sum_184 = torch.ops.aten.sum.dim_IntList(mul_1622, [0, 1]); mul_1622 = None + convert_element_type_2175 = torch.ops.prims.convert_element_type.default(mul_1621, torch.bfloat16); mul_1621 = None + convert_element_type_default_53 = torch.ops.prims.convert_element_type.default(sum_184, torch.float32); sum_184 = None + reduce_scatter_tensor_138 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_53, 'avg', 128, '0'); convert_element_type_default_53 = None + wait_tensor_717 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + convert_element_type_2178 = torch.ops.prims.convert_element_type.default(sum_182, torch.float32); sum_182 = None + view_1963 = torch.ops.aten.view.default(convert_element_type_2178, [2, 4096, 1, 32, 2]); convert_element_type_2178 = None + view_as_complex_72 = torch.ops.aten.view_as_complex.default(view_1963); view_1963 = None + mul_1623 = torch.ops.aten.mul.Tensor(view_as_complex_72, clone_9); view_as_complex_72 = None + view_as_real_72 = torch.ops.aten.view_as_real.default(mul_1623); mul_1623 = None + view_1964 = torch.ops.aten.view.default(view_as_real_72, [2, 4096, 1, 64]); view_as_real_72 = None + convert_element_type_2179 = torch.ops.prims.convert_element_type.default(view_1964, torch.bfloat16); view_1964 = None + squeeze_35 = torch.ops.aten.squeeze.dim(convert_element_type_2179, 2); convert_element_type_2179 = None + cat_313 = torch.ops.aten.cat.default([convert_element_type_2175, squeeze_35], 2); convert_element_type_2175 = squeeze_35 = None + view_1965 = torch.ops.aten.view.default(cat_313, [8192, 576]); cat_313 = None + permute_898 = torch.ops.aten.permute.default(view_1965, [1, 0]) + mm_374 = torch.ops.aten.mm.default(permute_898, view_1108); permute_898 = None + convert_element_type_909 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_909, 128, '0'); convert_element_type_909 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + slice_103 = torch.ops.aten.slice.Tensor(wait_tensor_349, 0, 0, 576); wait_tensor_349 = None + permute_252 = torch.ops.aten.permute.default(slice_103, [1, 0]); slice_103 = None + permute_900 = torch.ops.aten.permute.default(permute_252, [1, 0]); permute_252 = None + mm_375 = torch.ops.aten.mm.default(view_1965, permute_900); view_1965 = permute_900 = None + view_1966 = torch.ops.aten.view.default(mm_375, [2, 4096, 2048]); mm_375 = None + convert_element_type_2184 = torch.ops.prims.convert_element_type.default(mm_374, torch.float32); mm_374 = None + split_685 = torch.ops.aten.split.Tensor(convert_element_type_2184, 5); convert_element_type_2184 = None + getitem_12753 = split_685[0] + getitem_12754 = split_685[1] + getitem_12755 = split_685[2] + getitem_12756 = split_685[3] + getitem_12757 = split_685[4] + getitem_12758 = split_685[5] + getitem_12759 = split_685[6] + getitem_12760 = split_685[7] + getitem_12761 = split_685[8] + getitem_12762 = split_685[9] + getitem_12763 = split_685[10] + getitem_12764 = split_685[11] + getitem_12765 = split_685[12] + getitem_12766 = split_685[13] + getitem_12767 = split_685[14] + getitem_12768 = split_685[15] + getitem_12769 = split_685[16] + getitem_12770 = split_685[17] + getitem_12771 = split_685[18] + getitem_12772 = split_685[19] + getitem_12773 = split_685[20] + getitem_12774 = split_685[21] + getitem_12775 = split_685[22] + getitem_12776 = split_685[23] + getitem_12777 = split_685[24] + getitem_12778 = split_685[25] + getitem_12779 = split_685[26] + getitem_12780 = split_685[27] + getitem_12781 = split_685[28] + getitem_12782 = split_685[29] + getitem_12783 = split_685[30] + getitem_12784 = split_685[31] + getitem_12785 = split_685[32] + getitem_12786 = split_685[33] + getitem_12787 = split_685[34] + getitem_12788 = split_685[35] + getitem_12789 = split_685[36] + getitem_12790 = split_685[37] + getitem_12791 = split_685[38] + getitem_12792 = split_685[39] + getitem_12793 = split_685[40] + getitem_12794 = split_685[41] + getitem_12795 = split_685[42] + getitem_12796 = split_685[43] + getitem_12797 = split_685[44] + getitem_12798 = split_685[45] + getitem_12799 = split_685[46] + getitem_12800 = split_685[47] + getitem_12801 = split_685[48] + getitem_12802 = split_685[49] + getitem_12803 = split_685[50] + getitem_12804 = split_685[51] + getitem_12805 = split_685[52] + getitem_12806 = split_685[53] + getitem_12807 = split_685[54] + getitem_12808 = split_685[55] + getitem_12809 = split_685[56] + getitem_12810 = split_685[57] + getitem_12811 = split_685[58] + getitem_12812 = split_685[59] + getitem_12813 = split_685[60] + getitem_12814 = split_685[61] + getitem_12815 = split_685[62] + getitem_12816 = split_685[63] + getitem_12817 = split_685[64] + getitem_12818 = split_685[65] + getitem_12819 = split_685[66] + getitem_12820 = split_685[67] + getitem_12821 = split_685[68] + getitem_12822 = split_685[69] + getitem_12823 = split_685[70] + getitem_12824 = split_685[71] + getitem_12825 = split_685[72] + getitem_12826 = split_685[73] + getitem_12827 = split_685[74] + getitem_12828 = split_685[75] + getitem_12829 = split_685[76] + getitem_12830 = split_685[77] + getitem_12831 = split_685[78] + getitem_12832 = split_685[79] + getitem_12833 = split_685[80] + getitem_12834 = split_685[81] + getitem_12835 = split_685[82] + getitem_12836 = split_685[83] + getitem_12837 = split_685[84] + getitem_12838 = split_685[85] + getitem_12839 = split_685[86] + getitem_12840 = split_685[87] + getitem_12841 = split_685[88] + getitem_12842 = split_685[89] + getitem_12843 = split_685[90] + getitem_12844 = split_685[91] + getitem_12845 = split_685[92] + getitem_12846 = split_685[93] + getitem_12847 = split_685[94] + getitem_12848 = split_685[95] + getitem_12849 = split_685[96] + getitem_12850 = split_685[97] + getitem_12851 = split_685[98] + getitem_12852 = split_685[99] + getitem_12853 = split_685[100] + getitem_12854 = split_685[101] + getitem_12855 = split_685[102] + getitem_12856 = split_685[103] + getitem_12857 = split_685[104] + getitem_12858 = split_685[105] + getitem_12859 = split_685[106] + getitem_12860 = split_685[107] + getitem_12861 = split_685[108] + getitem_12862 = split_685[109] + getitem_12863 = split_685[110] + getitem_12864 = split_685[111] + getitem_12865 = split_685[112] + getitem_12866 = split_685[113] + getitem_12867 = split_685[114] + getitem_12868 = split_685[115]; split_685 = None + constant_pad_nd_757 = torch.ops.aten.constant_pad_nd.default(getitem_12868, [0, 0, 0, 4], 0.0); getitem_12868 = None + cat_314 = torch.ops.aten.cat.default([getitem_12753, getitem_12754, getitem_12755, getitem_12756, getitem_12757, getitem_12758, getitem_12759, getitem_12760, getitem_12761, getitem_12762, getitem_12763, getitem_12764, getitem_12765, getitem_12766, getitem_12767, getitem_12768, getitem_12769, getitem_12770, getitem_12771, getitem_12772, getitem_12773, getitem_12774, getitem_12775, getitem_12776, getitem_12777, getitem_12778, getitem_12779, getitem_12780, getitem_12781, getitem_12782, getitem_12783, getitem_12784, getitem_12785, getitem_12786, getitem_12787, getitem_12788, getitem_12789, getitem_12790, getitem_12791, getitem_12792, getitem_12793, getitem_12794, getitem_12795, getitem_12796, getitem_12797, getitem_12798, getitem_12799, getitem_12800, getitem_12801, getitem_12802, getitem_12803, getitem_12804, getitem_12805, getitem_12806, getitem_12807, getitem_12808, getitem_12809, getitem_12810, getitem_12811, getitem_12812, getitem_12813, getitem_12814, getitem_12815, getitem_12816, getitem_12817, getitem_12818, getitem_12819, getitem_12820, getitem_12821, getitem_12822, getitem_12823, getitem_12824, getitem_12825, getitem_12826, getitem_12827, getitem_12828, getitem_12829, getitem_12830, getitem_12831, getitem_12832, getitem_12833, getitem_12834, getitem_12835, getitem_12836, getitem_12837, getitem_12838, getitem_12839, getitem_12840, getitem_12841, getitem_12842, getitem_12843, getitem_12844, getitem_12845, getitem_12846, getitem_12847, getitem_12848, getitem_12849, getitem_12850, getitem_12851, getitem_12852, getitem_12853, getitem_12854, getitem_12855, getitem_12856, getitem_12857, getitem_12858, getitem_12859, getitem_12860, getitem_12861, getitem_12862, getitem_12863, getitem_12864, getitem_12865, getitem_12866, getitem_12867, constant_pad_nd_757, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_12753 = getitem_12754 = getitem_12755 = getitem_12756 = getitem_12757 = getitem_12758 = getitem_12759 = getitem_12760 = getitem_12761 = getitem_12762 = getitem_12763 = getitem_12764 = getitem_12765 = getitem_12766 = getitem_12767 = getitem_12768 = getitem_12769 = getitem_12770 = getitem_12771 = getitem_12772 = getitem_12773 = getitem_12774 = getitem_12775 = getitem_12776 = getitem_12777 = getitem_12778 = getitem_12779 = getitem_12780 = getitem_12781 = getitem_12782 = getitem_12783 = getitem_12784 = getitem_12785 = getitem_12786 = getitem_12787 = getitem_12788 = getitem_12789 = getitem_12790 = getitem_12791 = getitem_12792 = getitem_12793 = getitem_12794 = getitem_12795 = getitem_12796 = getitem_12797 = getitem_12798 = getitem_12799 = getitem_12800 = getitem_12801 = getitem_12802 = getitem_12803 = getitem_12804 = getitem_12805 = getitem_12806 = getitem_12807 = getitem_12808 = getitem_12809 = getitem_12810 = getitem_12811 = getitem_12812 = getitem_12813 = getitem_12814 = getitem_12815 = getitem_12816 = getitem_12817 = getitem_12818 = getitem_12819 = getitem_12820 = getitem_12821 = getitem_12822 = getitem_12823 = getitem_12824 = getitem_12825 = getitem_12826 = getitem_12827 = getitem_12828 = getitem_12829 = getitem_12830 = getitem_12831 = getitem_12832 = getitem_12833 = getitem_12834 = getitem_12835 = getitem_12836 = getitem_12837 = getitem_12838 = getitem_12839 = getitem_12840 = getitem_12841 = getitem_12842 = getitem_12843 = getitem_12844 = getitem_12845 = getitem_12846 = getitem_12847 = getitem_12848 = getitem_12849 = getitem_12850 = getitem_12851 = getitem_12852 = getitem_12853 = getitem_12854 = getitem_12855 = getitem_12856 = getitem_12857 = getitem_12858 = getitem_12859 = getitem_12860 = getitem_12861 = getitem_12862 = getitem_12863 = getitem_12864 = getitem_12865 = getitem_12866 = getitem_12867 = constant_pad_nd_757 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_314, 'avg', 128, '0'); cat_314 = None + wait_tensor_718 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + slice_220 = torch.ops.aten.slice.Tensor(permute_893, 3, 0, 128) + slice_221 = torch.ops.aten.slice.Tensor(permute_893, 3, 128, 192); permute_893 = None + convert_element_type_2185 = torch.ops.prims.convert_element_type.default(slice_221, torch.float32); slice_221 = None + view_1967 = torch.ops.aten.view.default(convert_element_type_2185, [2, 4096, 16, 32, 2]); convert_element_type_2185 = None + view_as_complex_73 = torch.ops.aten.view_as_complex.default(view_1967); view_1967 = None + mul_1624 = torch.ops.aten.mul.Tensor(view_as_complex_73, clone_9); view_as_complex_73 = None + view_as_real_73 = torch.ops.aten.view_as_real.default(mul_1624); mul_1624 = None + view_1968 = torch.ops.aten.view.default(view_as_real_73, [2, 4096, 16, 64]); view_as_real_73 = None + convert_element_type_2186 = torch.ops.prims.convert_element_type.default(view_1968, torch.bfloat16); view_1968 = None + cat_315 = torch.ops.aten.cat.default([slice_220, convert_element_type_2186], 3); slice_220 = convert_element_type_2186 = None + view_1969 = torch.ops.aten.view.default(cat_315, [2, 4096, 3072]); cat_315 = None + view_1970 = torch.ops.aten.view.default(view_1969, [8192, 3072]); view_1969 = None + permute_902 = torch.ops.aten.permute.default(view_1970, [1, 0]) + mm_376 = torch.ops.aten.mm.default(permute_902, view_1108); permute_902 = view_1108 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_904, 128, '0'); convert_element_type_904 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_348, [1, 0]); wait_tensor_348 = None + permute_904 = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None + mm_377 = torch.ops.aten.mm.default(view_1970, permute_904); view_1970 = permute_904 = None + view_1971 = torch.ops.aten.view.default(mm_377, [2, 4096, 2048]); mm_377 = None + add_1923 = torch.ops.aten.add.Tensor(view_1966, view_1971); view_1966 = view_1971 = None + convert_element_type_2191 = torch.ops.prims.convert_element_type.default(mm_376, torch.float32); mm_376 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2191, 'avg', 128, '0'); convert_element_type_2191 = None + wait_tensor_719 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + convert_element_type_2192 = torch.ops.prims.convert_element_type.default(add_1923, torch.float32); add_1923 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16); primals_279 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 128, '0'); convert_element_type_901 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + convert_element_type_2194 = torch.ops.prims.convert_element_type.default(wait_tensor_347, torch.float32); wait_tensor_347 = None + mul_1625 = torch.ops.aten.mul.Tensor(convert_element_type_2192, convert_element_type_2194); convert_element_type_2194 = None + convert_element_type_902 = torch.ops.prims.convert_element_type.default(add_1093, torch.float32); add_1093 = None + mul_793 = torch.ops.aten.mul.Tensor(convert_element_type_902, rsqrt_51); convert_element_type_902 = None + mul_1627 = torch.ops.aten.mul.Tensor(mul_793, mul_1625) + sum_185 = torch.ops.aten.sum.dim_IntList(mul_1627, [2], True); mul_1627 = None + div_191 = torch.ops.aten.div.Tensor(mul_793, 2048) + mul_1628 = torch.ops.aten.mul.Tensor(div_191, sum_185); div_191 = sum_185 = None + sub_684 = torch.ops.aten.sub.Tensor(mul_1625, mul_1628); mul_1625 = mul_1628 = None + mul_1629 = torch.ops.aten.mul.Tensor(sub_684, rsqrt_51); sub_684 = rsqrt_51 = None + mul_1630 = torch.ops.aten.mul.Tensor(convert_element_type_2192, mul_793); convert_element_type_2192 = mul_793 = None + sum_186 = torch.ops.aten.sum.dim_IntList(mul_1630, [0, 1]); mul_1630 = None + convert_element_type_2195 = torch.ops.prims.convert_element_type.default(mul_1629, torch.bfloat16); mul_1629 = None + add_1924 = torch.ops.aten.add.Tensor(add_1922, convert_element_type_2195); add_1922 = convert_element_type_2195 = None + convert_element_type_default_52 = torch.ops.prims.convert_element_type.default(sum_186, torch.float32); sum_186 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_52, 'avg', 128, '0'); convert_element_type_default_52 = None + wait_tensor_720 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_1972 = torch.ops.aten.view.default(add_1924, [8192, 2048]) + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1972, 1) + convert_element_type_2198 = torch.ops.prims.convert_element_type.default(unsqueeze_63, torch.float32); unsqueeze_63 = None + bmm_46 = torch.ops.aten.bmm.default(permute_906, convert_element_type_2198); permute_906 = None + bmm_47 = torch.ops.aten.bmm.default(convert_element_type_2198, permute_907); convert_element_type_2198 = permute_907 = None + convert_element_type_2199 = torch.ops.prims.convert_element_type.default(bmm_46, torch.bfloat16); bmm_46 = None + view_1973 = torch.ops.aten.view.default(bmm_47, [8192, 6]); bmm_47 = None + view_1974 = torch.ops.aten.view.default(convert_element_type_2199, [49152, 2048]); convert_element_type_2199 = None + index_72 = torch.ops.aten.index.Tensor(view_1974, [getitem_1671]); view_1974 = getitem_1671 = None + permute_908 = torch.ops.aten.permute.default(view_1972, [1, 0]) + mm_378 = torch.ops.aten.mm.default(permute_908, mul_790); permute_908 = mul_790 = None + convert_element_type_896 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16); primals_278 = None + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_896, 128, '0'); convert_element_type_896 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_346, [1, 0]); wait_tensor_346 = None + permute_910 = torch.ops.aten.permute.default(permute_250, [1, 0]); permute_250 = None + mm_379 = torch.ops.aten.mm.default(view_1972, permute_910); view_1972 = permute_910 = None + convert_element_type_2204 = torch.ops.prims.convert_element_type.default(mm_378, torch.float32); mm_378 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2204, 'avg', 128, '0'); convert_element_type_2204 = None + wait_tensor_721 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + convert_element_type_891 = torch.ops.prims.convert_element_type.default(mm_132, torch.float32); mm_132 = None + neg_32 = torch.ops.aten.neg.default(convert_element_type_891) + exp_48 = torch.ops.aten.exp.default(neg_32); neg_32 = None + add_1088 = torch.ops.aten.add.Tensor(exp_48, 1); exp_48 = None + div_80 = torch.ops.aten.div.Tensor(convert_element_type_891, add_1088) + convert_element_type_892 = torch.ops.prims.convert_element_type.default(div_80, torch.bfloat16); div_80 = None + mul_1631 = torch.ops.aten.mul.Tensor(mm_379, convert_element_type_892); convert_element_type_892 = None + mul_1632 = torch.ops.aten.mul.Tensor(mm_379, mm_133); mm_379 = mm_133 = None + permute_912 = torch.ops.aten.permute.default(mul_1631, [1, 0]) + mm_380 = torch.ops.aten.mm.default(permute_912, view_1063); permute_912 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16); primals_277 = None + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_893, 128, '0'); convert_element_type_893 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); wait_tensor_345 = None + permute_914 = torch.ops.aten.permute.default(permute_249, [1, 0]); permute_249 = None + mm_381 = torch.ops.aten.mm.default(mul_1631, permute_914); mul_1631 = permute_914 = None + convert_element_type_2209 = torch.ops.prims.convert_element_type.default(mm_380, torch.float32); mm_380 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2209, 'avg', 128, '0'); convert_element_type_2209 = None + wait_tensor_722 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + convert_element_type_2210 = torch.ops.prims.convert_element_type.default(mul_1632, torch.float32); mul_1632 = None + reciprocal_20 = torch.ops.aten.reciprocal.default(add_1088); add_1088 = None + mul_1633 = torch.ops.aten.mul.Tensor(reciprocal_20, 1); reciprocal_20 = None + mul_1634 = torch.ops.aten.mul.Tensor(convert_element_type_2210, mul_1633); convert_element_type_2210 = None + sub_685 = torch.ops.aten.sub.Tensor(1, mul_1633); mul_1633 = None + mul_1635 = torch.ops.aten.mul.Tensor(convert_element_type_891, sub_685); convert_element_type_891 = sub_685 = None + add_1926 = torch.ops.aten.add.Tensor(mul_1635, 1); mul_1635 = None + mul_1636 = torch.ops.aten.mul.Tensor(mul_1634, add_1926); mul_1634 = add_1926 = None + convert_element_type_2212 = torch.ops.prims.convert_element_type.default(mul_1636, torch.bfloat16); mul_1636 = None + permute_916 = torch.ops.aten.permute.default(convert_element_type_2212, [1, 0]) + mm_382 = torch.ops.aten.mm.default(permute_916, view_1063); permute_916 = None + convert_element_type_888 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16); primals_276 = None + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_888, 128, '0'); convert_element_type_888 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_248 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + permute_918 = torch.ops.aten.permute.default(permute_248, [1, 0]); permute_248 = None + mm_383 = torch.ops.aten.mm.default(convert_element_type_2212, permute_918); convert_element_type_2212 = permute_918 = None + add_1927 = torch.ops.aten.add.Tensor(mm_381, mm_383); mm_381 = mm_383 = None + convert_element_type_2217 = torch.ops.prims.convert_element_type.default(mm_382, torch.float32); mm_382 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2217, 'avg', 128, '0'); convert_element_type_2217 = None + wait_tensor_723 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + all_to_all_single_98 = torch.ops._c10d_functional.all_to_all_single.default(index_72, [_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255], [_local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247], '1033'); index_72 = None + wait_tensor_724 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_98); all_to_all_single_98 = None + full_408 = torch.ops.aten.full.default([sym_size_int_61, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_61 = None + slice_scatter_10 = torch.ops.aten.slice_scatter.default(full_408, wait_tensor_724, 0, 0, -1); wait_tensor_724 = None + index_73 = torch.ops.aten.index.Tensor(slice_scatter_10, [getitem_1672]); slice_scatter_10 = None + permute_920 = torch.ops.aten.permute.default(index_73, [1, 0]) + _grouped_mm_138 = torch.ops.aten._grouped_mm.default(permute_920, mul_770, cumsum_47); permute_920 = mul_770 = None + _grouped_mm_139 = torch.ops.aten._grouped_mm.default(index_73, permute_922, cumsum_47); index_73 = permute_922 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(_grouped_mm_45, torch.float32); _grouped_mm_45 = None + neg_31 = torch.ops.aten.neg.default(convert_element_type_886) + exp_47 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_1052 = torch.ops.aten.add.Tensor(exp_47, 1); exp_47 = None + div_79 = torch.ops.aten.div.Tensor(convert_element_type_886, add_1052) + convert_element_type_887 = torch.ops.prims.convert_element_type.default(div_79, torch.bfloat16); div_79 = None + mul_1637 = torch.ops.aten.mul.Tensor(_grouped_mm_139, convert_element_type_887); convert_element_type_887 = None + mul_1638 = torch.ops.aten.mul.Tensor(_grouped_mm_139, _grouped_mm_46); _grouped_mm_139 = _grouped_mm_46 = None + permute_924 = torch.ops.aten.permute.default(mul_1637, [1, 0]) + _grouped_mm_140 = torch.ops.aten._grouped_mm.default(permute_924, index_31, cumsum_47); permute_924 = None + _grouped_mm_141 = torch.ops.aten._grouped_mm.default(mul_1637, permute_926, cumsum_47); mul_1637 = permute_926 = None + convert_element_type_2218 = torch.ops.prims.convert_element_type.default(mul_1638, torch.float32); mul_1638 = None + reciprocal_21 = torch.ops.aten.reciprocal.default(add_1052); add_1052 = None + mul_1639 = torch.ops.aten.mul.Tensor(reciprocal_21, 1); reciprocal_21 = None + mul_1640 = torch.ops.aten.mul.Tensor(convert_element_type_2218, mul_1639); convert_element_type_2218 = None + sub_686 = torch.ops.aten.sub.Tensor(1, mul_1639); mul_1639 = None + mul_1641 = torch.ops.aten.mul.Tensor(convert_element_type_886, sub_686); convert_element_type_886 = sub_686 = None + add_1929 = torch.ops.aten.add.Tensor(mul_1641, 1); mul_1641 = None + mul_1642 = torch.ops.aten.mul.Tensor(mul_1640, add_1929); mul_1640 = add_1929 = None + convert_element_type_2220 = torch.ops.prims.convert_element_type.default(mul_1642, torch.bfloat16); mul_1642 = None + permute_928 = torch.ops.aten.permute.default(convert_element_type_2220, [1, 0]) + _grouped_mm_142 = torch.ops.aten._grouped_mm.default(permute_928, index_31, cumsum_47); permute_928 = index_31 = None + _grouped_mm_143 = torch.ops.aten._grouped_mm.default(convert_element_type_2220, permute_930, cumsum_47); convert_element_type_2220 = permute_930 = cumsum_47 = None + add_1930 = torch.ops.aten.add.Tensor(_grouped_mm_141, _grouped_mm_143); _grouped_mm_141 = _grouped_mm_143 = None + convert_element_type_2221 = torch.ops.prims.convert_element_type.default(_grouped_mm_140, torch.float32); _grouped_mm_140 = None + div_192 = torch.ops.aten.div.Tensor(convert_element_type_2221, 128); convert_element_type_2221 = None + split_687 = torch.ops.aten.split.Tensor(div_192, 88, 1); div_192 = None + getitem_12885 = split_687[0] + getitem_12902 = split_687[1] + getitem_12919 = split_687[2] + getitem_12936 = split_687[3] + getitem_12953 = split_687[4] + getitem_12970 = split_687[5] + getitem_12987 = split_687[6] + getitem_13004 = split_687[7] + getitem_13021 = split_687[8] + getitem_13038 = split_687[9] + getitem_13055 = split_687[10] + getitem_13072 = split_687[11] + getitem_13089 = split_687[12] + getitem_13106 = split_687[13] + getitem_13123 = split_687[14] + getitem_13140 = split_687[15]; split_687 = None + cat_316 = torch.ops.aten.cat.default([getitem_12885, getitem_12902, getitem_12919, getitem_12936, getitem_12953, getitem_12970, getitem_12987, getitem_13004, getitem_13021, getitem_13038, getitem_13055, getitem_13072, getitem_13089, getitem_13106, getitem_13123, getitem_13140]); getitem_12885 = getitem_12902 = getitem_12919 = getitem_12936 = getitem_12953 = getitem_12970 = getitem_12987 = getitem_13004 = getitem_13021 = getitem_13038 = getitem_13055 = getitem_13072 = getitem_13089 = getitem_13106 = getitem_13123 = getitem_13140 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_316, 'sum', 16, '1025'); cat_316 = None + wait_tensor_725 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + convert_element_type_2222 = torch.ops.prims.convert_element_type.default(_grouped_mm_138, torch.float32); _grouped_mm_138 = None + div_193 = torch.ops.aten.div.Tensor(convert_element_type_2222, 128); convert_element_type_2222 = None + split_704 = torch.ops.aten.split.Tensor(div_193, 128, 1); div_193 = None + getitem_13157 = split_704[0] + getitem_13174 = split_704[1] + getitem_13191 = split_704[2] + getitem_13208 = split_704[3] + getitem_13225 = split_704[4] + getitem_13242 = split_704[5] + getitem_13259 = split_704[6] + getitem_13276 = split_704[7] + getitem_13293 = split_704[8] + getitem_13310 = split_704[9] + getitem_13327 = split_704[10] + getitem_13344 = split_704[11] + getitem_13361 = split_704[12] + getitem_13378 = split_704[13] + getitem_13395 = split_704[14] + getitem_13412 = split_704[15]; split_704 = None + cat_317 = torch.ops.aten.cat.default([getitem_13157, getitem_13174, getitem_13191, getitem_13208, getitem_13225, getitem_13242, getitem_13259, getitem_13276, getitem_13293, getitem_13310, getitem_13327, getitem_13344, getitem_13361, getitem_13378, getitem_13395, getitem_13412]); getitem_13157 = getitem_13174 = getitem_13191 = getitem_13208 = getitem_13225 = getitem_13242 = getitem_13259 = getitem_13276 = getitem_13293 = getitem_13310 = getitem_13327 = getitem_13344 = getitem_13361 = getitem_13378 = getitem_13395 = getitem_13412 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_317, 'sum', 16, '1025'); cat_317 = None + wait_tensor_726 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + convert_element_type_2223 = torch.ops.prims.convert_element_type.default(_grouped_mm_142, torch.float32); _grouped_mm_142 = None + div_194 = torch.ops.aten.div.Tensor(convert_element_type_2223, 128); convert_element_type_2223 = None + split_721 = torch.ops.aten.split.Tensor(div_194, 88, 1); div_194 = None + getitem_13429 = split_721[0] + getitem_13446 = split_721[1] + getitem_13463 = split_721[2] + getitem_13480 = split_721[3] + getitem_13497 = split_721[4] + getitem_13514 = split_721[5] + getitem_13531 = split_721[6] + getitem_13548 = split_721[7] + getitem_13565 = split_721[8] + getitem_13582 = split_721[9] + getitem_13599 = split_721[10] + getitem_13616 = split_721[11] + getitem_13633 = split_721[12] + getitem_13650 = split_721[13] + getitem_13667 = split_721[14] + getitem_13684 = split_721[15]; split_721 = None + cat_318 = torch.ops.aten.cat.default([getitem_13429, getitem_13446, getitem_13463, getitem_13480, getitem_13497, getitem_13514, getitem_13531, getitem_13548, getitem_13565, getitem_13582, getitem_13599, getitem_13616, getitem_13633, getitem_13650, getitem_13667, getitem_13684]); getitem_13429 = getitem_13446 = getitem_13463 = getitem_13480 = getitem_13497 = getitem_13514 = getitem_13531 = getitem_13548 = getitem_13565 = getitem_13582 = getitem_13599 = getitem_13616 = getitem_13633 = getitem_13650 = getitem_13667 = getitem_13684 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_318, 'sum', 16, '1025'); cat_318 = None + wait_tensor_727 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + index_put_72 = torch.ops.aten.index_put.default(full_408, [getitem_1672], add_1930, True); full_408 = getitem_1672 = add_1930 = None + slice_222 = torch.ops.aten.slice.Tensor(index_put_72, 0, 0, add_1931); index_put_72 = add_1931 = None + all_to_all_single_99 = torch.ops._c10d_functional.all_to_all_single.default(slice_222, [_local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247], [_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255], '1033'); slice_222 = _local_scalar_dense_240 = _local_scalar_dense_241 = _local_scalar_dense_242 = _local_scalar_dense_243 = _local_scalar_dense_244 = _local_scalar_dense_245 = _local_scalar_dense_246 = _local_scalar_dense_247 = _local_scalar_dense_248 = _local_scalar_dense_249 = _local_scalar_dense_250 = _local_scalar_dense_251 = _local_scalar_dense_252 = _local_scalar_dense_253 = _local_scalar_dense_254 = _local_scalar_dense_255 = None + wait_tensor_728 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_99); all_to_all_single_99 = None + index_put_73 = torch.ops.aten.index_put.default(full_default_52, [div_77], wait_tensor_728, True); div_77 = wait_tensor_728 = None + add_1935 = torch.ops.aten.add.Tensor(add_1927, index_put_73); add_1927 = index_put_73 = None + mul_1643 = torch.ops.aten.mul.Tensor(view_1973, 1.0); view_1973 = None + scatter_add_10 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1669, mul_1643); getitem_1669 = mul_1643 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(mm_131, torch.float32); mm_131 = None + sub_360 = torch.ops.aten.sub.Tensor(convert_element_type_875, amax_15); convert_element_type_875 = amax_15 = None + exp_46 = torch.ops.aten.exp.default(sub_360); sub_360 = None + div_76 = torch.ops.aten.div.Tensor(exp_46, sum_61); exp_46 = sum_61 = None + mul_1644 = torch.ops.aten.mul.Tensor(scatter_add_10, div_76); scatter_add_10 = None + sum_187 = torch.ops.aten.sum.dim_IntList(mul_1644, [1], True) + neg_85 = torch.ops.aten.neg.default(div_76); div_76 = None + fma_10 = torch.ops.prims.fma.default(neg_85, sum_187, mul_1644); neg_85 = sum_187 = mul_1644 = None + convert_element_type_2224 = torch.ops.prims.convert_element_type.default(fma_10, torch.bfloat16); fma_10 = None + permute_932 = torch.ops.aten.permute.default(convert_element_type_2224, [1, 0]) + mm_384 = torch.ops.aten.mm.default(permute_932, view_1063); permute_932 = view_1063 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16); primals_271 = None + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_872, 128, '0'); convert_element_type_872 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + slice_99 = torch.ops.aten.slice.Tensor(wait_tensor_333, 0, 0, 64); wait_tensor_333 = None + permute_244 = torch.ops.aten.permute.default(slice_99, [1, 0]); slice_99 = None + permute_934 = torch.ops.aten.permute.default(permute_244, [1, 0]); permute_244 = None + mm_385 = torch.ops.aten.mm.default(convert_element_type_2224, permute_934); convert_element_type_2224 = permute_934 = None + add_1936 = torch.ops.aten.add.Tensor(add_1935, mm_385); add_1935 = mm_385 = None + convert_element_type_2229 = torch.ops.prims.convert_element_type.default(mm_384, torch.float32); mm_384 = None + split_737 = torch.ops.aten.split.Tensor(convert_element_type_2229, 1); convert_element_type_2229 = None + getitem_13685 = split_737[0] + getitem_13686 = split_737[1] + getitem_13687 = split_737[2] + getitem_13688 = split_737[3] + getitem_13689 = split_737[4] + getitem_13690 = split_737[5] + getitem_13691 = split_737[6] + getitem_13692 = split_737[7] + getitem_13693 = split_737[8] + getitem_13694 = split_737[9] + getitem_13695 = split_737[10] + getitem_13696 = split_737[11] + getitem_13697 = split_737[12] + getitem_13698 = split_737[13] + getitem_13699 = split_737[14] + getitem_13700 = split_737[15] + getitem_13701 = split_737[16] + getitem_13702 = split_737[17] + getitem_13703 = split_737[18] + getitem_13704 = split_737[19] + getitem_13705 = split_737[20] + getitem_13706 = split_737[21] + getitem_13707 = split_737[22] + getitem_13708 = split_737[23] + getitem_13709 = split_737[24] + getitem_13710 = split_737[25] + getitem_13711 = split_737[26] + getitem_13712 = split_737[27] + getitem_13713 = split_737[28] + getitem_13714 = split_737[29] + getitem_13715 = split_737[30] + getitem_13716 = split_737[31] + getitem_13717 = split_737[32] + getitem_13718 = split_737[33] + getitem_13719 = split_737[34] + getitem_13720 = split_737[35] + getitem_13721 = split_737[36] + getitem_13722 = split_737[37] + getitem_13723 = split_737[38] + getitem_13724 = split_737[39] + getitem_13725 = split_737[40] + getitem_13726 = split_737[41] + getitem_13727 = split_737[42] + getitem_13728 = split_737[43] + getitem_13729 = split_737[44] + getitem_13730 = split_737[45] + getitem_13731 = split_737[46] + getitem_13732 = split_737[47] + getitem_13733 = split_737[48] + getitem_13734 = split_737[49] + getitem_13735 = split_737[50] + getitem_13736 = split_737[51] + getitem_13737 = split_737[52] + getitem_13738 = split_737[53] + getitem_13739 = split_737[54] + getitem_13740 = split_737[55] + getitem_13741 = split_737[56] + getitem_13742 = split_737[57] + getitem_13743 = split_737[58] + getitem_13744 = split_737[59] + getitem_13745 = split_737[60] + getitem_13746 = split_737[61] + getitem_13747 = split_737[62] + getitem_13748 = split_737[63]; split_737 = None + cat_319 = torch.ops.aten.cat.default([getitem_13685, getitem_13686, getitem_13687, getitem_13688, getitem_13689, getitem_13690, getitem_13691, getitem_13692, getitem_13693, getitem_13694, getitem_13695, getitem_13696, getitem_13697, getitem_13698, getitem_13699, getitem_13700, getitem_13701, getitem_13702, getitem_13703, getitem_13704, getitem_13705, getitem_13706, getitem_13707, getitem_13708, getitem_13709, getitem_13710, getitem_13711, getitem_13712, getitem_13713, getitem_13714, getitem_13715, getitem_13716, getitem_13717, getitem_13718, getitem_13719, getitem_13720, getitem_13721, getitem_13722, getitem_13723, getitem_13724, getitem_13725, getitem_13726, getitem_13727, getitem_13728, getitem_13729, getitem_13730, getitem_13731, getitem_13732, getitem_13733, getitem_13734, getitem_13735, getitem_13736, getitem_13737, getitem_13738, getitem_13739, getitem_13740, getitem_13741, getitem_13742, getitem_13743, getitem_13744, getitem_13745, getitem_13746, getitem_13747, getitem_13748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_13685 = getitem_13686 = getitem_13687 = getitem_13688 = getitem_13689 = getitem_13690 = getitem_13691 = getitem_13692 = getitem_13693 = getitem_13694 = getitem_13695 = getitem_13696 = getitem_13697 = getitem_13698 = getitem_13699 = getitem_13700 = getitem_13701 = getitem_13702 = getitem_13703 = getitem_13704 = getitem_13705 = getitem_13706 = getitem_13707 = getitem_13708 = getitem_13709 = getitem_13710 = getitem_13711 = getitem_13712 = getitem_13713 = getitem_13714 = getitem_13715 = getitem_13716 = getitem_13717 = getitem_13718 = getitem_13719 = getitem_13720 = getitem_13721 = getitem_13722 = getitem_13723 = getitem_13724 = getitem_13725 = getitem_13726 = getitem_13727 = getitem_13728 = getitem_13729 = getitem_13730 = getitem_13731 = getitem_13732 = getitem_13733 = getitem_13734 = getitem_13735 = getitem_13736 = getitem_13737 = getitem_13738 = getitem_13739 = getitem_13740 = getitem_13741 = getitem_13742 = getitem_13743 = getitem_13744 = getitem_13745 = getitem_13746 = getitem_13747 = getitem_13748 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_319, 'avg', 128, '0'); cat_319 = None + wait_tensor_729 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + view_1975 = torch.ops.aten.view.default(add_1936, [2, 4096, 2048]); add_1936 = None + convert_element_type_2230 = torch.ops.prims.convert_element_type.default(view_1975, torch.float32); view_1975 = None + convert_element_type_869 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_869, 128, '0'); convert_element_type_869 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + convert_element_type_2232 = torch.ops.prims.convert_element_type.default(wait_tensor_332, torch.float32); wait_tensor_332 = None + mul_1645 = torch.ops.aten.mul.Tensor(convert_element_type_2230, convert_element_type_2232); convert_element_type_2232 = None + convert_element_type_870 = torch.ops.prims.convert_element_type.default(add_1028, torch.float32); add_1028 = None + mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_870, rsqrt_50); convert_element_type_870 = None + mul_1647 = torch.ops.aten.mul.Tensor(mul_750, mul_1645) + sum_188 = torch.ops.aten.sum.dim_IntList(mul_1647, [2], True); mul_1647 = None + div_195 = torch.ops.aten.div.Tensor(mul_750, 2048) + mul_1648 = torch.ops.aten.mul.Tensor(div_195, sum_188); div_195 = sum_188 = None + sub_688 = torch.ops.aten.sub.Tensor(mul_1645, mul_1648); mul_1645 = mul_1648 = None + mul_1649 = torch.ops.aten.mul.Tensor(sub_688, rsqrt_50); sub_688 = rsqrt_50 = None + mul_1650 = torch.ops.aten.mul.Tensor(convert_element_type_2230, mul_750); convert_element_type_2230 = mul_750 = None + sum_189 = torch.ops.aten.sum.dim_IntList(mul_1650, [0, 1]); mul_1650 = None + convert_element_type_2233 = torch.ops.prims.convert_element_type.default(mul_1649, torch.bfloat16); mul_1649 = None + add_1937 = torch.ops.aten.add.Tensor(add_1924, convert_element_type_2233); add_1924 = convert_element_type_2233 = None + convert_element_type_default_51 = torch.ops.prims.convert_element_type.default(sum_189, torch.float32); sum_189 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_51, 'avg', 128, '0'); convert_element_type_default_51 = None + wait_tensor_730 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + view_1976 = torch.ops.aten.view.default(add_1937, [8192, 2048]) + permute_936 = torch.ops.aten.permute.default(view_1976, [1, 0]) + permute_242 = torch.ops.aten.permute.default(getitem_1665, [0, 2, 1, 3]) + view_1058 = torch.ops.aten.view.default(permute_242, [2, 4096, -1]); permute_242 = None + view_1060 = torch.ops.aten.view.default(view_1058, [8192, 2048]); view_1058 = None + mm_386 = torch.ops.aten.mm.default(permute_936, view_1060); permute_936 = view_1060 = None + convert_element_type_866 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_866, 128, '0'); convert_element_type_866 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + permute_938 = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None + mm_387 = torch.ops.aten.mm.default(view_1976, permute_938); view_1976 = permute_938 = None + view_1977 = torch.ops.aten.view.default(mm_387, [2, 4096, 2048]); mm_387 = None + convert_element_type_2240 = torch.ops.prims.convert_element_type.default(mm_386, torch.float32); mm_386 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2240, 'avg', 128, '0'); convert_element_type_2240 = None + wait_tensor_731 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = None + view_1978 = torch.ops.aten.view.default(view_1977, [2, 4096, 16, 128]); view_1977 = None + permute_940 = torch.ops.aten.permute.default(view_1978, [0, 2, 1, 3]); view_1978 = None + fw_graph10 = self.fw_graph10 + joint_graph10 = self.joint_graph10 + mask_graph10 = self.mask_graph10 + flex_attention_backward_10 = torch.ops.higher_order.flex_attention_backward(permute_239, permute_240, permute_241, getitem_1665, getitem_1666, permute_940, None, fw_graph10, joint_graph10, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph10), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_239 = permute_240 = permute_241 = getitem_1665 = getitem_1666 = permute_940 = fw_graph10 = joint_graph10 = mask_graph10 = None + getitem_13749 = flex_attention_backward_10[0] + getitem_13750 = flex_attention_backward_10[1] + getitem_13751 = flex_attention_backward_10[2]; flex_attention_backward_10 = None + permute_941 = torch.ops.aten.permute.default(getitem_13751, [0, 2, 1, 3]); getitem_13751 = None + permute_942 = torch.ops.aten.permute.default(getitem_13750, [0, 2, 1, 3]); getitem_13750 = None + permute_943 = torch.ops.aten.permute.default(getitem_13749, [0, 2, 1, 3]); getitem_13749 = None + slice_224 = torch.ops.aten.slice.Tensor(permute_942, 3, 0, 128) + slice_225 = torch.ops.aten.slice.Tensor(permute_942, 3, 128, 192); permute_942 = None + sum_190 = torch.ops.aten.sum.dim_IntList(slice_225, [2], True); slice_225 = None + cat_320 = torch.ops.aten.cat.default([slice_224, permute_941], 3); slice_224 = permute_941 = None + view_1979 = torch.ops.aten.view.default(cat_320, [2, 4096, 4096]); cat_320 = None + view_1980 = torch.ops.aten.view.default(view_1979, [8192, 4096]); view_1979 = None + permute_944 = torch.ops.aten.permute.default(view_1980, [1, 0]) + mm_388 = torch.ops.aten.mm.default(permute_944, view_1055); permute_944 = view_1055 = None + convert_element_type_863 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_863, 128, '0'); convert_element_type_863 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + permute_946 = torch.ops.aten.permute.default(permute_238, [1, 0]); permute_238 = None + mm_389 = torch.ops.aten.mm.default(view_1980, permute_946); view_1980 = permute_946 = None + view_1981 = torch.ops.aten.view.default(mm_389, [2, 4096, 512]); mm_389 = None + convert_element_type_2245 = torch.ops.prims.convert_element_type.default(mm_388, torch.float32); mm_388 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2245, 'avg', 128, '0'); convert_element_type_2245 = None + wait_tensor_732 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + convert_element_type_2246 = torch.ops.prims.convert_element_type.default(view_1981, torch.float32); view_1981 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_860, 128, '0'); convert_element_type_860 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + convert_element_type_2248 = torch.ops.prims.convert_element_type.default(wait_tensor_329, torch.float32); wait_tensor_329 = None + mul_1651 = torch.ops.aten.mul.Tensor(convert_element_type_2246, convert_element_type_2248); convert_element_type_2248 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(getitem_1661, torch.float32); getitem_1661 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_861, rsqrt_49); convert_element_type_861 = None + mul_1653 = torch.ops.aten.mul.Tensor(mul_748, mul_1651) + sum_191 = torch.ops.aten.sum.dim_IntList(mul_1653, [2], True); mul_1653 = None + div_196 = torch.ops.aten.div.Tensor(mul_748, 512) + mul_1654 = torch.ops.aten.mul.Tensor(div_196, sum_191); div_196 = sum_191 = None + sub_689 = torch.ops.aten.sub.Tensor(mul_1651, mul_1654); mul_1651 = mul_1654 = None + mul_1655 = torch.ops.aten.mul.Tensor(sub_689, rsqrt_49); sub_689 = rsqrt_49 = None + mul_1656 = torch.ops.aten.mul.Tensor(convert_element_type_2246, mul_748); convert_element_type_2246 = mul_748 = None + sum_192 = torch.ops.aten.sum.dim_IntList(mul_1656, [0, 1]); mul_1656 = None + convert_element_type_2249 = torch.ops.prims.convert_element_type.default(mul_1655, torch.bfloat16); mul_1655 = None + convert_element_type_default_50 = torch.ops.prims.convert_element_type.default(sum_192, torch.float32); sum_192 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_50, 'avg', 128, '0'); convert_element_type_default_50 = None + wait_tensor_733 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + convert_element_type_2252 = torch.ops.prims.convert_element_type.default(sum_190, torch.float32); sum_190 = None + view_1982 = torch.ops.aten.view.default(convert_element_type_2252, [2, 4096, 1, 32, 2]); convert_element_type_2252 = None + view_as_complex_74 = torch.ops.aten.view_as_complex.default(view_1982); view_1982 = None + mul_1657 = torch.ops.aten.mul.Tensor(view_as_complex_74, clone_9); view_as_complex_74 = None + view_as_real_74 = torch.ops.aten.view_as_real.default(mul_1657); mul_1657 = None + view_1983 = torch.ops.aten.view.default(view_as_real_74, [2, 4096, 1, 64]); view_as_real_74 = None + convert_element_type_2253 = torch.ops.prims.convert_element_type.default(view_1983, torch.bfloat16); view_1983 = None + squeeze_36 = torch.ops.aten.squeeze.dim(convert_element_type_2253, 2); convert_element_type_2253 = None + cat_321 = torch.ops.aten.cat.default([convert_element_type_2249, squeeze_36], 2); convert_element_type_2249 = squeeze_36 = None + view_1984 = torch.ops.aten.view.default(cat_321, [8192, 576]); cat_321 = None + permute_948 = torch.ops.aten.permute.default(view_1984, [1, 0]) + mm_390 = torch.ops.aten.mm.default(permute_948, view_1041); permute_948 = None + convert_element_type_855 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_855, 128, '0'); convert_element_type_855 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + slice_97 = torch.ops.aten.slice.Tensor(wait_tensor_328, 0, 0, 576); wait_tensor_328 = None + permute_237 = torch.ops.aten.permute.default(slice_97, [1, 0]); slice_97 = None + permute_950 = torch.ops.aten.permute.default(permute_237, [1, 0]); permute_237 = None + mm_391 = torch.ops.aten.mm.default(view_1984, permute_950); view_1984 = permute_950 = None + view_1985 = torch.ops.aten.view.default(mm_391, [2, 4096, 2048]); mm_391 = None + convert_element_type_2258 = torch.ops.prims.convert_element_type.default(mm_390, torch.float32); mm_390 = None + split_738 = torch.ops.aten.split.Tensor(convert_element_type_2258, 5); convert_element_type_2258 = None + getitem_13753 = split_738[0] + getitem_13754 = split_738[1] + getitem_13755 = split_738[2] + getitem_13756 = split_738[3] + getitem_13757 = split_738[4] + getitem_13758 = split_738[5] + getitem_13759 = split_738[6] + getitem_13760 = split_738[7] + getitem_13761 = split_738[8] + getitem_13762 = split_738[9] + getitem_13763 = split_738[10] + getitem_13764 = split_738[11] + getitem_13765 = split_738[12] + getitem_13766 = split_738[13] + getitem_13767 = split_738[14] + getitem_13768 = split_738[15] + getitem_13769 = split_738[16] + getitem_13770 = split_738[17] + getitem_13771 = split_738[18] + getitem_13772 = split_738[19] + getitem_13773 = split_738[20] + getitem_13774 = split_738[21] + getitem_13775 = split_738[22] + getitem_13776 = split_738[23] + getitem_13777 = split_738[24] + getitem_13778 = split_738[25] + getitem_13779 = split_738[26] + getitem_13780 = split_738[27] + getitem_13781 = split_738[28] + getitem_13782 = split_738[29] + getitem_13783 = split_738[30] + getitem_13784 = split_738[31] + getitem_13785 = split_738[32] + getitem_13786 = split_738[33] + getitem_13787 = split_738[34] + getitem_13788 = split_738[35] + getitem_13789 = split_738[36] + getitem_13790 = split_738[37] + getitem_13791 = split_738[38] + getitem_13792 = split_738[39] + getitem_13793 = split_738[40] + getitem_13794 = split_738[41] + getitem_13795 = split_738[42] + getitem_13796 = split_738[43] + getitem_13797 = split_738[44] + getitem_13798 = split_738[45] + getitem_13799 = split_738[46] + getitem_13800 = split_738[47] + getitem_13801 = split_738[48] + getitem_13802 = split_738[49] + getitem_13803 = split_738[50] + getitem_13804 = split_738[51] + getitem_13805 = split_738[52] + getitem_13806 = split_738[53] + getitem_13807 = split_738[54] + getitem_13808 = split_738[55] + getitem_13809 = split_738[56] + getitem_13810 = split_738[57] + getitem_13811 = split_738[58] + getitem_13812 = split_738[59] + getitem_13813 = split_738[60] + getitem_13814 = split_738[61] + getitem_13815 = split_738[62] + getitem_13816 = split_738[63] + getitem_13817 = split_738[64] + getitem_13818 = split_738[65] + getitem_13819 = split_738[66] + getitem_13820 = split_738[67] + getitem_13821 = split_738[68] + getitem_13822 = split_738[69] + getitem_13823 = split_738[70] + getitem_13824 = split_738[71] + getitem_13825 = split_738[72] + getitem_13826 = split_738[73] + getitem_13827 = split_738[74] + getitem_13828 = split_738[75] + getitem_13829 = split_738[76] + getitem_13830 = split_738[77] + getitem_13831 = split_738[78] + getitem_13832 = split_738[79] + getitem_13833 = split_738[80] + getitem_13834 = split_738[81] + getitem_13835 = split_738[82] + getitem_13836 = split_738[83] + getitem_13837 = split_738[84] + getitem_13838 = split_738[85] + getitem_13839 = split_738[86] + getitem_13840 = split_738[87] + getitem_13841 = split_738[88] + getitem_13842 = split_738[89] + getitem_13843 = split_738[90] + getitem_13844 = split_738[91] + getitem_13845 = split_738[92] + getitem_13846 = split_738[93] + getitem_13847 = split_738[94] + getitem_13848 = split_738[95] + getitem_13849 = split_738[96] + getitem_13850 = split_738[97] + getitem_13851 = split_738[98] + getitem_13852 = split_738[99] + getitem_13853 = split_738[100] + getitem_13854 = split_738[101] + getitem_13855 = split_738[102] + getitem_13856 = split_738[103] + getitem_13857 = split_738[104] + getitem_13858 = split_738[105] + getitem_13859 = split_738[106] + getitem_13860 = split_738[107] + getitem_13861 = split_738[108] + getitem_13862 = split_738[109] + getitem_13863 = split_738[110] + getitem_13864 = split_738[111] + getitem_13865 = split_738[112] + getitem_13866 = split_738[113] + getitem_13867 = split_738[114] + getitem_13868 = split_738[115]; split_738 = None + constant_pad_nd_834 = torch.ops.aten.constant_pad_nd.default(getitem_13868, [0, 0, 0, 4], 0.0); getitem_13868 = None + cat_322 = torch.ops.aten.cat.default([getitem_13753, getitem_13754, getitem_13755, getitem_13756, getitem_13757, getitem_13758, getitem_13759, getitem_13760, getitem_13761, getitem_13762, getitem_13763, getitem_13764, getitem_13765, getitem_13766, getitem_13767, getitem_13768, getitem_13769, getitem_13770, getitem_13771, getitem_13772, getitem_13773, getitem_13774, getitem_13775, getitem_13776, getitem_13777, getitem_13778, getitem_13779, getitem_13780, getitem_13781, getitem_13782, getitem_13783, getitem_13784, getitem_13785, getitem_13786, getitem_13787, getitem_13788, getitem_13789, getitem_13790, getitem_13791, getitem_13792, getitem_13793, getitem_13794, getitem_13795, getitem_13796, getitem_13797, getitem_13798, getitem_13799, getitem_13800, getitem_13801, getitem_13802, getitem_13803, getitem_13804, getitem_13805, getitem_13806, getitem_13807, getitem_13808, getitem_13809, getitem_13810, getitem_13811, getitem_13812, getitem_13813, getitem_13814, getitem_13815, getitem_13816, getitem_13817, getitem_13818, getitem_13819, getitem_13820, getitem_13821, getitem_13822, getitem_13823, getitem_13824, getitem_13825, getitem_13826, getitem_13827, getitem_13828, getitem_13829, getitem_13830, getitem_13831, getitem_13832, getitem_13833, getitem_13834, getitem_13835, getitem_13836, getitem_13837, getitem_13838, getitem_13839, getitem_13840, getitem_13841, getitem_13842, getitem_13843, getitem_13844, getitem_13845, getitem_13846, getitem_13847, getitem_13848, getitem_13849, getitem_13850, getitem_13851, getitem_13852, getitem_13853, getitem_13854, getitem_13855, getitem_13856, getitem_13857, getitem_13858, getitem_13859, getitem_13860, getitem_13861, getitem_13862, getitem_13863, getitem_13864, getitem_13865, getitem_13866, getitem_13867, constant_pad_nd_834, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_13753 = getitem_13754 = getitem_13755 = getitem_13756 = getitem_13757 = getitem_13758 = getitem_13759 = getitem_13760 = getitem_13761 = getitem_13762 = getitem_13763 = getitem_13764 = getitem_13765 = getitem_13766 = getitem_13767 = getitem_13768 = getitem_13769 = getitem_13770 = getitem_13771 = getitem_13772 = getitem_13773 = getitem_13774 = getitem_13775 = getitem_13776 = getitem_13777 = getitem_13778 = getitem_13779 = getitem_13780 = getitem_13781 = getitem_13782 = getitem_13783 = getitem_13784 = getitem_13785 = getitem_13786 = getitem_13787 = getitem_13788 = getitem_13789 = getitem_13790 = getitem_13791 = getitem_13792 = getitem_13793 = getitem_13794 = getitem_13795 = getitem_13796 = getitem_13797 = getitem_13798 = getitem_13799 = getitem_13800 = getitem_13801 = getitem_13802 = getitem_13803 = getitem_13804 = getitem_13805 = getitem_13806 = getitem_13807 = getitem_13808 = getitem_13809 = getitem_13810 = getitem_13811 = getitem_13812 = getitem_13813 = getitem_13814 = getitem_13815 = getitem_13816 = getitem_13817 = getitem_13818 = getitem_13819 = getitem_13820 = getitem_13821 = getitem_13822 = getitem_13823 = getitem_13824 = getitem_13825 = getitem_13826 = getitem_13827 = getitem_13828 = getitem_13829 = getitem_13830 = getitem_13831 = getitem_13832 = getitem_13833 = getitem_13834 = getitem_13835 = getitem_13836 = getitem_13837 = getitem_13838 = getitem_13839 = getitem_13840 = getitem_13841 = getitem_13842 = getitem_13843 = getitem_13844 = getitem_13845 = getitem_13846 = getitem_13847 = getitem_13848 = getitem_13849 = getitem_13850 = getitem_13851 = getitem_13852 = getitem_13853 = getitem_13854 = getitem_13855 = getitem_13856 = getitem_13857 = getitem_13858 = getitem_13859 = getitem_13860 = getitem_13861 = getitem_13862 = getitem_13863 = getitem_13864 = getitem_13865 = getitem_13866 = getitem_13867 = constant_pad_nd_834 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_322, 'avg', 128, '0'); cat_322 = None + wait_tensor_734 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + slice_226 = torch.ops.aten.slice.Tensor(permute_943, 3, 0, 128) + slice_227 = torch.ops.aten.slice.Tensor(permute_943, 3, 128, 192); permute_943 = None + convert_element_type_2259 = torch.ops.prims.convert_element_type.default(slice_227, torch.float32); slice_227 = None + view_1986 = torch.ops.aten.view.default(convert_element_type_2259, [2, 4096, 16, 32, 2]); convert_element_type_2259 = None + view_as_complex_75 = torch.ops.aten.view_as_complex.default(view_1986); view_1986 = None + mul_1658 = torch.ops.aten.mul.Tensor(view_as_complex_75, clone_9); view_as_complex_75 = None + view_as_real_75 = torch.ops.aten.view_as_real.default(mul_1658); mul_1658 = None + view_1987 = torch.ops.aten.view.default(view_as_real_75, [2, 4096, 16, 64]); view_as_real_75 = None + convert_element_type_2260 = torch.ops.prims.convert_element_type.default(view_1987, torch.bfloat16); view_1987 = None + cat_323 = torch.ops.aten.cat.default([slice_226, convert_element_type_2260], 3); slice_226 = convert_element_type_2260 = None + view_1988 = torch.ops.aten.view.default(cat_323, [2, 4096, 3072]); cat_323 = None + view_1989 = torch.ops.aten.view.default(view_1988, [8192, 3072]); view_1988 = None + permute_952 = torch.ops.aten.permute.default(view_1989, [1, 0]) + mm_392 = torch.ops.aten.mm.default(permute_952, view_1041); permute_952 = view_1041 = None + convert_element_type_850 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_850, 128, '0'); convert_element_type_850 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_236 = torch.ops.aten.permute.default(wait_tensor_327, [1, 0]); wait_tensor_327 = None + permute_954 = torch.ops.aten.permute.default(permute_236, [1, 0]); permute_236 = None + mm_393 = torch.ops.aten.mm.default(view_1989, permute_954); view_1989 = permute_954 = None + view_1990 = torch.ops.aten.view.default(mm_393, [2, 4096, 2048]); mm_393 = None + add_1938 = torch.ops.aten.add.Tensor(view_1985, view_1990); view_1985 = view_1990 = None + convert_element_type_2265 = torch.ops.prims.convert_element_type.default(mm_392, torch.float32); mm_392 = None + reduce_scatter_tensor_154 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2265, 'avg', 128, '0'); convert_element_type_2265 = None + wait_tensor_735 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + convert_element_type_2266 = torch.ops.prims.convert_element_type.default(add_1938, torch.float32); add_1938 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 128, '0'); convert_element_type_847 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + convert_element_type_2268 = torch.ops.prims.convert_element_type.default(wait_tensor_326, torch.float32); wait_tensor_326 = None + mul_1659 = torch.ops.aten.mul.Tensor(convert_element_type_2266, convert_element_type_2268); convert_element_type_2268 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(add_1025, torch.float32); add_1025 = None + mul_744 = torch.ops.aten.mul.Tensor(convert_element_type_848, rsqrt_48); convert_element_type_848 = None + mul_1661 = torch.ops.aten.mul.Tensor(mul_744, mul_1659) + sum_193 = torch.ops.aten.sum.dim_IntList(mul_1661, [2], True); mul_1661 = None + div_197 = torch.ops.aten.div.Tensor(mul_744, 2048) + mul_1662 = torch.ops.aten.mul.Tensor(div_197, sum_193); div_197 = sum_193 = None + sub_690 = torch.ops.aten.sub.Tensor(mul_1659, mul_1662); mul_1659 = mul_1662 = None + mul_1663 = torch.ops.aten.mul.Tensor(sub_690, rsqrt_48); sub_690 = rsqrt_48 = None + mul_1664 = torch.ops.aten.mul.Tensor(convert_element_type_2266, mul_744); convert_element_type_2266 = mul_744 = None + sum_194 = torch.ops.aten.sum.dim_IntList(mul_1664, [0, 1]); mul_1664 = None + convert_element_type_2269 = torch.ops.prims.convert_element_type.default(mul_1663, torch.bfloat16); mul_1663 = None + add_1939 = torch.ops.aten.add.Tensor(add_1937, convert_element_type_2269); add_1937 = convert_element_type_2269 = None + convert_element_type_default_49 = torch.ops.prims.convert_element_type.default(sum_194, torch.float32); sum_194 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_49, 'avg', 128, '0'); convert_element_type_default_49 = None + wait_tensor_736 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); reduce_scatter_tensor_155 = None + view_1991 = torch.ops.aten.view.default(add_1939, [8192, 2048]) + unsqueeze_64 = torch.ops.aten.unsqueeze.default(view_1991, 1) + convert_element_type_2272 = torch.ops.prims.convert_element_type.default(unsqueeze_64, torch.float32); unsqueeze_64 = None + bmm_48 = torch.ops.aten.bmm.default(permute_956, convert_element_type_2272); permute_956 = None + bmm_49 = torch.ops.aten.bmm.default(convert_element_type_2272, permute_957); convert_element_type_2272 = permute_957 = None + convert_element_type_2273 = torch.ops.prims.convert_element_type.default(bmm_48, torch.bfloat16); bmm_48 = None + view_1992 = torch.ops.aten.view.default(bmm_49, [8192, 6]); bmm_49 = None + view_1993 = torch.ops.aten.view.default(convert_element_type_2273, [49152, 2048]); convert_element_type_2273 = None + index_74 = torch.ops.aten.index.Tensor(view_1993, [getitem_1561]); view_1993 = getitem_1561 = None + permute_958 = torch.ops.aten.permute.default(view_1991, [1, 0]) + mm_394 = torch.ops.aten.mm.default(permute_958, mul_741); permute_958 = mul_741 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 128, '0'); convert_element_type_842 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_235 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + permute_960 = torch.ops.aten.permute.default(permute_235, [1, 0]); permute_235 = None + mm_395 = torch.ops.aten.mm.default(view_1991, permute_960); view_1991 = permute_960 = None + convert_element_type_2278 = torch.ops.prims.convert_element_type.default(mm_394, torch.float32); mm_394 = None + reduce_scatter_tensor_156 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2278, 'avg', 128, '0'); convert_element_type_2278 = None + wait_tensor_737 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + convert_element_type_837 = torch.ops.prims.convert_element_type.default(mm_124, torch.float32); mm_124 = None + neg_30 = torch.ops.aten.neg.default(convert_element_type_837) + exp_45 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_1020 = torch.ops.aten.add.Tensor(exp_45, 1); exp_45 = None + div_75 = torch.ops.aten.div.Tensor(convert_element_type_837, add_1020) + convert_element_type_838 = torch.ops.prims.convert_element_type.default(div_75, torch.bfloat16); div_75 = None + mul_1665 = torch.ops.aten.mul.Tensor(mm_395, convert_element_type_838); convert_element_type_838 = None + mul_1666 = torch.ops.aten.mul.Tensor(mm_395, mm_125); mm_395 = mm_125 = None + permute_962 = torch.ops.aten.permute.default(mul_1665, [1, 0]) + mm_396 = torch.ops.aten.mm.default(permute_962, view_996); permute_962 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16); primals_261 = None + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_839, 128, '0'); convert_element_type_839 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_234 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + permute_964 = torch.ops.aten.permute.default(permute_234, [1, 0]); permute_234 = None + mm_397 = torch.ops.aten.mm.default(mul_1665, permute_964); mul_1665 = permute_964 = None + convert_element_type_2283 = torch.ops.prims.convert_element_type.default(mm_396, torch.float32); mm_396 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2283, 'avg', 128, '0'); convert_element_type_2283 = None + wait_tensor_738 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + convert_element_type_2284 = torch.ops.prims.convert_element_type.default(mul_1666, torch.float32); mul_1666 = None + reciprocal_22 = torch.ops.aten.reciprocal.default(add_1020); add_1020 = None + mul_1667 = torch.ops.aten.mul.Tensor(reciprocal_22, 1); reciprocal_22 = None + mul_1668 = torch.ops.aten.mul.Tensor(convert_element_type_2284, mul_1667); convert_element_type_2284 = None + sub_691 = torch.ops.aten.sub.Tensor(1, mul_1667); mul_1667 = None + mul_1669 = torch.ops.aten.mul.Tensor(convert_element_type_837, sub_691); convert_element_type_837 = sub_691 = None + add_1941 = torch.ops.aten.add.Tensor(mul_1669, 1); mul_1669 = None + mul_1670 = torch.ops.aten.mul.Tensor(mul_1668, add_1941); mul_1668 = add_1941 = None + convert_element_type_2286 = torch.ops.prims.convert_element_type.default(mul_1670, torch.bfloat16); mul_1670 = None + permute_966 = torch.ops.aten.permute.default(convert_element_type_2286, [1, 0]) + mm_398 = torch.ops.aten.mm.default(permute_966, view_996); permute_966 = None + convert_element_type_834 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16); primals_260 = None + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_834, 128, '0'); convert_element_type_834 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + permute_968 = torch.ops.aten.permute.default(permute_233, [1, 0]); permute_233 = None + mm_399 = torch.ops.aten.mm.default(convert_element_type_2286, permute_968); convert_element_type_2286 = permute_968 = None + add_1942 = torch.ops.aten.add.Tensor(mm_397, mm_399); mm_397 = mm_399 = None + convert_element_type_2291 = torch.ops.prims.convert_element_type.default(mm_398, torch.float32); mm_398 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2291, 'avg', 128, '0'); convert_element_type_2291 = None + wait_tensor_739 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + all_to_all_single_100 = torch.ops._c10d_functional.all_to_all_single.default(index_74, [_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239], [_local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231], '1033'); index_74 = None + wait_tensor_740 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_100); all_to_all_single_100 = None + full_414 = torch.ops.aten.full.default([sym_size_int_57, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_57 = None + slice_scatter_11 = torch.ops.aten.slice_scatter.default(full_414, wait_tensor_740, 0, 0, -1); wait_tensor_740 = None + index_75 = torch.ops.aten.index.Tensor(slice_scatter_11, [getitem_1562]); slice_scatter_11 = None + permute_970 = torch.ops.aten.permute.default(index_75, [1, 0]) + _grouped_mm_144 = torch.ops.aten._grouped_mm.default(permute_970, mul_721, cumsum_44); permute_970 = mul_721 = None + _grouped_mm_145 = torch.ops.aten._grouped_mm.default(index_75, permute_972, cumsum_44); index_75 = permute_972 = None + convert_element_type_832 = torch.ops.prims.convert_element_type.default(_grouped_mm_42, torch.float32); _grouped_mm_42 = None + neg_29 = torch.ops.aten.neg.default(convert_element_type_832) + exp_44 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_984 = torch.ops.aten.add.Tensor(exp_44, 1); exp_44 = None + div_74 = torch.ops.aten.div.Tensor(convert_element_type_832, add_984) + convert_element_type_833 = torch.ops.prims.convert_element_type.default(div_74, torch.bfloat16); div_74 = None + mul_1671 = torch.ops.aten.mul.Tensor(_grouped_mm_145, convert_element_type_833); convert_element_type_833 = None + mul_1672 = torch.ops.aten.mul.Tensor(_grouped_mm_145, _grouped_mm_43); _grouped_mm_145 = _grouped_mm_43 = None + permute_974 = torch.ops.aten.permute.default(mul_1671, [1, 0]) + _grouped_mm_146 = torch.ops.aten._grouped_mm.default(permute_974, index_29, cumsum_44); permute_974 = None + _grouped_mm_147 = torch.ops.aten._grouped_mm.default(mul_1671, permute_976, cumsum_44); mul_1671 = permute_976 = None + convert_element_type_2292 = torch.ops.prims.convert_element_type.default(mul_1672, torch.float32); mul_1672 = None + reciprocal_23 = torch.ops.aten.reciprocal.default(add_984); add_984 = None + mul_1673 = torch.ops.aten.mul.Tensor(reciprocal_23, 1); reciprocal_23 = None + mul_1674 = torch.ops.aten.mul.Tensor(convert_element_type_2292, mul_1673); convert_element_type_2292 = None + sub_692 = torch.ops.aten.sub.Tensor(1, mul_1673); mul_1673 = None + mul_1675 = torch.ops.aten.mul.Tensor(convert_element_type_832, sub_692); convert_element_type_832 = sub_692 = None + add_1944 = torch.ops.aten.add.Tensor(mul_1675, 1); mul_1675 = None + mul_1676 = torch.ops.aten.mul.Tensor(mul_1674, add_1944); mul_1674 = add_1944 = None + convert_element_type_2294 = torch.ops.prims.convert_element_type.default(mul_1676, torch.bfloat16); mul_1676 = None + permute_978 = torch.ops.aten.permute.default(convert_element_type_2294, [1, 0]) + _grouped_mm_148 = torch.ops.aten._grouped_mm.default(permute_978, index_29, cumsum_44); permute_978 = index_29 = None + _grouped_mm_149 = torch.ops.aten._grouped_mm.default(convert_element_type_2294, permute_980, cumsum_44); convert_element_type_2294 = permute_980 = cumsum_44 = None + add_1945 = torch.ops.aten.add.Tensor(_grouped_mm_147, _grouped_mm_149); _grouped_mm_147 = _grouped_mm_149 = None + convert_element_type_2295 = torch.ops.prims.convert_element_type.default(_grouped_mm_146, torch.float32); _grouped_mm_146 = None + div_198 = torch.ops.aten.div.Tensor(convert_element_type_2295, 128); convert_element_type_2295 = None + split_740 = torch.ops.aten.split.Tensor(div_198, 88, 1); div_198 = None + getitem_13885 = split_740[0] + getitem_13902 = split_740[1] + getitem_13919 = split_740[2] + getitem_13936 = split_740[3] + getitem_13953 = split_740[4] + getitem_13970 = split_740[5] + getitem_13987 = split_740[6] + getitem_14004 = split_740[7] + getitem_14021 = split_740[8] + getitem_14038 = split_740[9] + getitem_14055 = split_740[10] + getitem_14072 = split_740[11] + getitem_14089 = split_740[12] + getitem_14106 = split_740[13] + getitem_14123 = split_740[14] + getitem_14140 = split_740[15]; split_740 = None + cat_324 = torch.ops.aten.cat.default([getitem_13885, getitem_13902, getitem_13919, getitem_13936, getitem_13953, getitem_13970, getitem_13987, getitem_14004, getitem_14021, getitem_14038, getitem_14055, getitem_14072, getitem_14089, getitem_14106, getitem_14123, getitem_14140]); getitem_13885 = getitem_13902 = getitem_13919 = getitem_13936 = getitem_13953 = getitem_13970 = getitem_13987 = getitem_14004 = getitem_14021 = getitem_14038 = getitem_14055 = getitem_14072 = getitem_14089 = getitem_14106 = getitem_14123 = getitem_14140 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_324, 'sum', 16, '1025'); cat_324 = None + wait_tensor_741 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + convert_element_type_2296 = torch.ops.prims.convert_element_type.default(_grouped_mm_144, torch.float32); _grouped_mm_144 = None + div_199 = torch.ops.aten.div.Tensor(convert_element_type_2296, 128); convert_element_type_2296 = None + split_757 = torch.ops.aten.split.Tensor(div_199, 128, 1); div_199 = None + getitem_14157 = split_757[0] + getitem_14174 = split_757[1] + getitem_14191 = split_757[2] + getitem_14208 = split_757[3] + getitem_14225 = split_757[4] + getitem_14242 = split_757[5] + getitem_14259 = split_757[6] + getitem_14276 = split_757[7] + getitem_14293 = split_757[8] + getitem_14310 = split_757[9] + getitem_14327 = split_757[10] + getitem_14344 = split_757[11] + getitem_14361 = split_757[12] + getitem_14378 = split_757[13] + getitem_14395 = split_757[14] + getitem_14412 = split_757[15]; split_757 = None + cat_325 = torch.ops.aten.cat.default([getitem_14157, getitem_14174, getitem_14191, getitem_14208, getitem_14225, getitem_14242, getitem_14259, getitem_14276, getitem_14293, getitem_14310, getitem_14327, getitem_14344, getitem_14361, getitem_14378, getitem_14395, getitem_14412]); getitem_14157 = getitem_14174 = getitem_14191 = getitem_14208 = getitem_14225 = getitem_14242 = getitem_14259 = getitem_14276 = getitem_14293 = getitem_14310 = getitem_14327 = getitem_14344 = getitem_14361 = getitem_14378 = getitem_14395 = getitem_14412 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_325, 'sum', 16, '1025'); cat_325 = None + wait_tensor_742 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + convert_element_type_2297 = torch.ops.prims.convert_element_type.default(_grouped_mm_148, torch.float32); _grouped_mm_148 = None + div_200 = torch.ops.aten.div.Tensor(convert_element_type_2297, 128); convert_element_type_2297 = None + split_774 = torch.ops.aten.split.Tensor(div_200, 88, 1); div_200 = None + getitem_14429 = split_774[0] + getitem_14446 = split_774[1] + getitem_14463 = split_774[2] + getitem_14480 = split_774[3] + getitem_14497 = split_774[4] + getitem_14514 = split_774[5] + getitem_14531 = split_774[6] + getitem_14548 = split_774[7] + getitem_14565 = split_774[8] + getitem_14582 = split_774[9] + getitem_14599 = split_774[10] + getitem_14616 = split_774[11] + getitem_14633 = split_774[12] + getitem_14650 = split_774[13] + getitem_14667 = split_774[14] + getitem_14684 = split_774[15]; split_774 = None + cat_326 = torch.ops.aten.cat.default([getitem_14429, getitem_14446, getitem_14463, getitem_14480, getitem_14497, getitem_14514, getitem_14531, getitem_14548, getitem_14565, getitem_14582, getitem_14599, getitem_14616, getitem_14633, getitem_14650, getitem_14667, getitem_14684]); getitem_14429 = getitem_14446 = getitem_14463 = getitem_14480 = getitem_14497 = getitem_14514 = getitem_14531 = getitem_14548 = getitem_14565 = getitem_14582 = getitem_14599 = getitem_14616 = getitem_14633 = getitem_14650 = getitem_14667 = getitem_14684 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_326, 'sum', 16, '1025'); cat_326 = None + wait_tensor_743 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + index_put_74 = torch.ops.aten.index_put.default(full_414, [getitem_1562], add_1945, True); full_414 = getitem_1562 = add_1945 = None + slice_228 = torch.ops.aten.slice.Tensor(index_put_74, 0, 0, add_1946); index_put_74 = add_1946 = None + all_to_all_single_101 = torch.ops._c10d_functional.all_to_all_single.default(slice_228, [_local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231], [_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239], '1033'); slice_228 = _local_scalar_dense_224 = _local_scalar_dense_225 = _local_scalar_dense_226 = _local_scalar_dense_227 = _local_scalar_dense_228 = _local_scalar_dense_229 = _local_scalar_dense_230 = _local_scalar_dense_231 = _local_scalar_dense_232 = _local_scalar_dense_233 = _local_scalar_dense_234 = _local_scalar_dense_235 = _local_scalar_dense_236 = _local_scalar_dense_237 = _local_scalar_dense_238 = _local_scalar_dense_239 = None + wait_tensor_744 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_101); all_to_all_single_101 = None + index_put_75 = torch.ops.aten.index_put.default(full_default_52, [div_72], wait_tensor_744, True); div_72 = wait_tensor_744 = None + add_1950 = torch.ops.aten.add.Tensor(add_1942, index_put_75); add_1942 = index_put_75 = None + mul_1677 = torch.ops.aten.mul.Tensor(view_1992, 1.0); view_1992 = None + scatter_add_11 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1559, mul_1677); getitem_1559 = mul_1677 = None + convert_element_type_821 = torch.ops.prims.convert_element_type.default(mm_123, torch.float32); mm_123 = None + sub_336 = torch.ops.aten.sub.Tensor(convert_element_type_821, amax_14); convert_element_type_821 = amax_14 = None + exp_43 = torch.ops.aten.exp.default(sub_336); sub_336 = None + div_71 = torch.ops.aten.div.Tensor(exp_43, sum_57); exp_43 = sum_57 = None + mul_1678 = torch.ops.aten.mul.Tensor(scatter_add_11, div_71); scatter_add_11 = None + sum_195 = torch.ops.aten.sum.dim_IntList(mul_1678, [1], True) + neg_88 = torch.ops.aten.neg.default(div_71); div_71 = None + fma_11 = torch.ops.prims.fma.default(neg_88, sum_195, mul_1678); neg_88 = sum_195 = mul_1678 = None + convert_element_type_2298 = torch.ops.prims.convert_element_type.default(fma_11, torch.bfloat16); fma_11 = None + permute_982 = torch.ops.aten.permute.default(convert_element_type_2298, [1, 0]) + mm_400 = torch.ops.aten.mm.default(permute_982, view_996); permute_982 = view_996 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16); primals_255 = None + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_818, 128, '0'); convert_element_type_818 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + slice_93 = torch.ops.aten.slice.Tensor(wait_tensor_312, 0, 0, 64); wait_tensor_312 = None + permute_229 = torch.ops.aten.permute.default(slice_93, [1, 0]); slice_93 = None + permute_984 = torch.ops.aten.permute.default(permute_229, [1, 0]); permute_229 = None + mm_401 = torch.ops.aten.mm.default(convert_element_type_2298, permute_984); convert_element_type_2298 = permute_984 = None + add_1951 = torch.ops.aten.add.Tensor(add_1950, mm_401); add_1950 = mm_401 = None + convert_element_type_2303 = torch.ops.prims.convert_element_type.default(mm_400, torch.float32); mm_400 = None + split_790 = torch.ops.aten.split.Tensor(convert_element_type_2303, 1); convert_element_type_2303 = None + getitem_14685 = split_790[0] + getitem_14686 = split_790[1] + getitem_14687 = split_790[2] + getitem_14688 = split_790[3] + getitem_14689 = split_790[4] + getitem_14690 = split_790[5] + getitem_14691 = split_790[6] + getitem_14692 = split_790[7] + getitem_14693 = split_790[8] + getitem_14694 = split_790[9] + getitem_14695 = split_790[10] + getitem_14696 = split_790[11] + getitem_14697 = split_790[12] + getitem_14698 = split_790[13] + getitem_14699 = split_790[14] + getitem_14700 = split_790[15] + getitem_14701 = split_790[16] + getitem_14702 = split_790[17] + getitem_14703 = split_790[18] + getitem_14704 = split_790[19] + getitem_14705 = split_790[20] + getitem_14706 = split_790[21] + getitem_14707 = split_790[22] + getitem_14708 = split_790[23] + getitem_14709 = split_790[24] + getitem_14710 = split_790[25] + getitem_14711 = split_790[26] + getitem_14712 = split_790[27] + getitem_14713 = split_790[28] + getitem_14714 = split_790[29] + getitem_14715 = split_790[30] + getitem_14716 = split_790[31] + getitem_14717 = split_790[32] + getitem_14718 = split_790[33] + getitem_14719 = split_790[34] + getitem_14720 = split_790[35] + getitem_14721 = split_790[36] + getitem_14722 = split_790[37] + getitem_14723 = split_790[38] + getitem_14724 = split_790[39] + getitem_14725 = split_790[40] + getitem_14726 = split_790[41] + getitem_14727 = split_790[42] + getitem_14728 = split_790[43] + getitem_14729 = split_790[44] + getitem_14730 = split_790[45] + getitem_14731 = split_790[46] + getitem_14732 = split_790[47] + getitem_14733 = split_790[48] + getitem_14734 = split_790[49] + getitem_14735 = split_790[50] + getitem_14736 = split_790[51] + getitem_14737 = split_790[52] + getitem_14738 = split_790[53] + getitem_14739 = split_790[54] + getitem_14740 = split_790[55] + getitem_14741 = split_790[56] + getitem_14742 = split_790[57] + getitem_14743 = split_790[58] + getitem_14744 = split_790[59] + getitem_14745 = split_790[60] + getitem_14746 = split_790[61] + getitem_14747 = split_790[62] + getitem_14748 = split_790[63]; split_790 = None + cat_327 = torch.ops.aten.cat.default([getitem_14685, getitem_14686, getitem_14687, getitem_14688, getitem_14689, getitem_14690, getitem_14691, getitem_14692, getitem_14693, getitem_14694, getitem_14695, getitem_14696, getitem_14697, getitem_14698, getitem_14699, getitem_14700, getitem_14701, getitem_14702, getitem_14703, getitem_14704, getitem_14705, getitem_14706, getitem_14707, getitem_14708, getitem_14709, getitem_14710, getitem_14711, getitem_14712, getitem_14713, getitem_14714, getitem_14715, getitem_14716, getitem_14717, getitem_14718, getitem_14719, getitem_14720, getitem_14721, getitem_14722, getitem_14723, getitem_14724, getitem_14725, getitem_14726, getitem_14727, getitem_14728, getitem_14729, getitem_14730, getitem_14731, getitem_14732, getitem_14733, getitem_14734, getitem_14735, getitem_14736, getitem_14737, getitem_14738, getitem_14739, getitem_14740, getitem_14741, getitem_14742, getitem_14743, getitem_14744, getitem_14745, getitem_14746, getitem_14747, getitem_14748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_14685 = getitem_14686 = getitem_14687 = getitem_14688 = getitem_14689 = getitem_14690 = getitem_14691 = getitem_14692 = getitem_14693 = getitem_14694 = getitem_14695 = getitem_14696 = getitem_14697 = getitem_14698 = getitem_14699 = getitem_14700 = getitem_14701 = getitem_14702 = getitem_14703 = getitem_14704 = getitem_14705 = getitem_14706 = getitem_14707 = getitem_14708 = getitem_14709 = getitem_14710 = getitem_14711 = getitem_14712 = getitem_14713 = getitem_14714 = getitem_14715 = getitem_14716 = getitem_14717 = getitem_14718 = getitem_14719 = getitem_14720 = getitem_14721 = getitem_14722 = getitem_14723 = getitem_14724 = getitem_14725 = getitem_14726 = getitem_14727 = getitem_14728 = getitem_14729 = getitem_14730 = getitem_14731 = getitem_14732 = getitem_14733 = getitem_14734 = getitem_14735 = getitem_14736 = getitem_14737 = getitem_14738 = getitem_14739 = getitem_14740 = getitem_14741 = getitem_14742 = getitem_14743 = getitem_14744 = getitem_14745 = getitem_14746 = getitem_14747 = getitem_14748 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_327, 'avg', 128, '0'); cat_327 = None + wait_tensor_745 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + view_1994 = torch.ops.aten.view.default(add_1951, [2, 4096, 2048]); add_1951 = None + convert_element_type_2304 = torch.ops.prims.convert_element_type.default(view_1994, torch.float32); view_1994 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16); primals_253 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 128, '0'); convert_element_type_815 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + convert_element_type_2306 = torch.ops.prims.convert_element_type.default(wait_tensor_311, torch.float32); wait_tensor_311 = None + mul_1679 = torch.ops.aten.mul.Tensor(convert_element_type_2304, convert_element_type_2306); convert_element_type_2306 = None + convert_element_type_816 = torch.ops.prims.convert_element_type.default(add_960, torch.float32); add_960 = None + mul_701 = torch.ops.aten.mul.Tensor(convert_element_type_816, rsqrt_47); convert_element_type_816 = None + mul_1681 = torch.ops.aten.mul.Tensor(mul_701, mul_1679) + sum_196 = torch.ops.aten.sum.dim_IntList(mul_1681, [2], True); mul_1681 = None + div_201 = torch.ops.aten.div.Tensor(mul_701, 2048) + mul_1682 = torch.ops.aten.mul.Tensor(div_201, sum_196); div_201 = sum_196 = None + sub_694 = torch.ops.aten.sub.Tensor(mul_1679, mul_1682); mul_1679 = mul_1682 = None + mul_1683 = torch.ops.aten.mul.Tensor(sub_694, rsqrt_47); sub_694 = rsqrt_47 = None + mul_1684 = torch.ops.aten.mul.Tensor(convert_element_type_2304, mul_701); convert_element_type_2304 = mul_701 = None + sum_197 = torch.ops.aten.sum.dim_IntList(mul_1684, [0, 1]); mul_1684 = None + convert_element_type_2307 = torch.ops.prims.convert_element_type.default(mul_1683, torch.bfloat16); mul_1683 = None + add_1952 = torch.ops.aten.add.Tensor(add_1939, convert_element_type_2307); add_1939 = convert_element_type_2307 = None + convert_element_type_default_48 = torch.ops.prims.convert_element_type.default(sum_197, torch.float32); sum_197 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_48, 'avg', 128, '0'); convert_element_type_default_48 = None + wait_tensor_746 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_1995 = torch.ops.aten.view.default(add_1952, [8192, 2048]) + permute_986 = torch.ops.aten.permute.default(view_1995, [1, 0]) + permute_227 = torch.ops.aten.permute.default(getitem_1555, [0, 2, 1, 3]) + view_991 = torch.ops.aten.view.default(permute_227, [2, 4096, -1]); permute_227 = None + view_993 = torch.ops.aten.view.default(view_991, [8192, 2048]); view_991 = None + mm_402 = torch.ops.aten.mm.default(permute_986, view_993); permute_986 = view_993 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 128, '0'); convert_element_type_812 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + permute_988 = torch.ops.aten.permute.default(permute_228, [1, 0]); permute_228 = None + mm_403 = torch.ops.aten.mm.default(view_1995, permute_988); view_1995 = permute_988 = None + view_1996 = torch.ops.aten.view.default(mm_403, [2, 4096, 2048]); mm_403 = None + convert_element_type_2314 = torch.ops.prims.convert_element_type.default(mm_402, torch.float32); mm_402 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2314, 'avg', 128, '0'); convert_element_type_2314 = None + wait_tensor_747 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 = None + view_1997 = torch.ops.aten.view.default(view_1996, [2, 4096, 16, 128]); view_1996 = None + permute_990 = torch.ops.aten.permute.default(view_1997, [0, 2, 1, 3]); view_1997 = None + fw_graph11 = self.fw_graph11 + joint_graph11 = self.joint_graph11 + mask_graph11 = self.mask_graph11 + flex_attention_backward_11 = torch.ops.higher_order.flex_attention_backward(permute_224, permute_225, permute_226, getitem_1555, getitem_1556, permute_990, None, fw_graph11, joint_graph11, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph11), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_224 = permute_225 = permute_226 = getitem_1555 = getitem_1556 = permute_990 = fw_graph11 = joint_graph11 = mask_graph11 = None + getitem_14749 = flex_attention_backward_11[0] + getitem_14750 = flex_attention_backward_11[1] + getitem_14751 = flex_attention_backward_11[2]; flex_attention_backward_11 = None + permute_991 = torch.ops.aten.permute.default(getitem_14751, [0, 2, 1, 3]); getitem_14751 = None + permute_992 = torch.ops.aten.permute.default(getitem_14750, [0, 2, 1, 3]); getitem_14750 = None + permute_993 = torch.ops.aten.permute.default(getitem_14749, [0, 2, 1, 3]); getitem_14749 = None + slice_230 = torch.ops.aten.slice.Tensor(permute_992, 3, 0, 128) + slice_231 = torch.ops.aten.slice.Tensor(permute_992, 3, 128, 192); permute_992 = None + sum_198 = torch.ops.aten.sum.dim_IntList(slice_231, [2], True); slice_231 = None + cat_328 = torch.ops.aten.cat.default([slice_230, permute_991], 3); slice_230 = permute_991 = None + view_1998 = torch.ops.aten.view.default(cat_328, [2, 4096, 4096]); cat_328 = None + view_1999 = torch.ops.aten.view.default(view_1998, [8192, 4096]); view_1998 = None + permute_994 = torch.ops.aten.permute.default(view_1999, [1, 0]) + mm_404 = torch.ops.aten.mm.default(permute_994, view_988); permute_994 = view_988 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 128, '0'); convert_element_type_809 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_223 = torch.ops.aten.permute.default(wait_tensor_309, [1, 0]); wait_tensor_309 = None + permute_996 = torch.ops.aten.permute.default(permute_223, [1, 0]); permute_223 = None + mm_405 = torch.ops.aten.mm.default(view_1999, permute_996); view_1999 = permute_996 = None + view_2000 = torch.ops.aten.view.default(mm_405, [2, 4096, 512]); mm_405 = None + convert_element_type_2319 = torch.ops.prims.convert_element_type.default(mm_404, torch.float32); mm_404 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2319, 'avg', 128, '0'); convert_element_type_2319 = None + wait_tensor_748 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + convert_element_type_2320 = torch.ops.prims.convert_element_type.default(view_2000, torch.float32); view_2000 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_806, 128, '0'); convert_element_type_806 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + convert_element_type_2322 = torch.ops.prims.convert_element_type.default(wait_tensor_308, torch.float32); wait_tensor_308 = None + mul_1685 = torch.ops.aten.mul.Tensor(convert_element_type_2320, convert_element_type_2322); convert_element_type_2322 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(getitem_1551, torch.float32); getitem_1551 = None + mul_699 = torch.ops.aten.mul.Tensor(convert_element_type_807, rsqrt_46); convert_element_type_807 = None + mul_1687 = torch.ops.aten.mul.Tensor(mul_699, mul_1685) + sum_199 = torch.ops.aten.sum.dim_IntList(mul_1687, [2], True); mul_1687 = None + div_202 = torch.ops.aten.div.Tensor(mul_699, 512) + mul_1688 = torch.ops.aten.mul.Tensor(div_202, sum_199); div_202 = sum_199 = None + sub_695 = torch.ops.aten.sub.Tensor(mul_1685, mul_1688); mul_1685 = mul_1688 = None + mul_1689 = torch.ops.aten.mul.Tensor(sub_695, rsqrt_46); sub_695 = rsqrt_46 = None + mul_1690 = torch.ops.aten.mul.Tensor(convert_element_type_2320, mul_699); convert_element_type_2320 = mul_699 = None + sum_200 = torch.ops.aten.sum.dim_IntList(mul_1690, [0, 1]); mul_1690 = None + convert_element_type_2323 = torch.ops.prims.convert_element_type.default(mul_1689, torch.bfloat16); mul_1689 = None + convert_element_type_default_47 = torch.ops.prims.convert_element_type.default(sum_200, torch.float32); sum_200 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_47, 'avg', 128, '0'); convert_element_type_default_47 = None + wait_tensor_749 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + convert_element_type_2326 = torch.ops.prims.convert_element_type.default(sum_198, torch.float32); sum_198 = None + view_2001 = torch.ops.aten.view.default(convert_element_type_2326, [2, 4096, 1, 32, 2]); convert_element_type_2326 = None + view_as_complex_76 = torch.ops.aten.view_as_complex.default(view_2001); view_2001 = None + mul_1691 = torch.ops.aten.mul.Tensor(view_as_complex_76, clone_9); view_as_complex_76 = None + view_as_real_76 = torch.ops.aten.view_as_real.default(mul_1691); mul_1691 = None + view_2002 = torch.ops.aten.view.default(view_as_real_76, [2, 4096, 1, 64]); view_as_real_76 = None + convert_element_type_2327 = torch.ops.prims.convert_element_type.default(view_2002, torch.bfloat16); view_2002 = None + squeeze_37 = torch.ops.aten.squeeze.dim(convert_element_type_2327, 2); convert_element_type_2327 = None + cat_329 = torch.ops.aten.cat.default([convert_element_type_2323, squeeze_37], 2); convert_element_type_2323 = squeeze_37 = None + view_2003 = torch.ops.aten.view.default(cat_329, [8192, 576]); cat_329 = None + permute_998 = torch.ops.aten.permute.default(view_2003, [1, 0]) + mm_406 = torch.ops.aten.mm.default(permute_998, view_974); permute_998 = None + convert_element_type_801 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_801, 128, '0'); convert_element_type_801 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + slice_91 = torch.ops.aten.slice.Tensor(wait_tensor_307, 0, 0, 576); wait_tensor_307 = None + permute_222 = torch.ops.aten.permute.default(slice_91, [1, 0]); slice_91 = None + permute_1000 = torch.ops.aten.permute.default(permute_222, [1, 0]); permute_222 = None + mm_407 = torch.ops.aten.mm.default(view_2003, permute_1000); view_2003 = permute_1000 = None + view_2004 = torch.ops.aten.view.default(mm_407, [2, 4096, 2048]); mm_407 = None + convert_element_type_2332 = torch.ops.prims.convert_element_type.default(mm_406, torch.float32); mm_406 = None + split_791 = torch.ops.aten.split.Tensor(convert_element_type_2332, 5); convert_element_type_2332 = None + getitem_14753 = split_791[0] + getitem_14754 = split_791[1] + getitem_14755 = split_791[2] + getitem_14756 = split_791[3] + getitem_14757 = split_791[4] + getitem_14758 = split_791[5] + getitem_14759 = split_791[6] + getitem_14760 = split_791[7] + getitem_14761 = split_791[8] + getitem_14762 = split_791[9] + getitem_14763 = split_791[10] + getitem_14764 = split_791[11] + getitem_14765 = split_791[12] + getitem_14766 = split_791[13] + getitem_14767 = split_791[14] + getitem_14768 = split_791[15] + getitem_14769 = split_791[16] + getitem_14770 = split_791[17] + getitem_14771 = split_791[18] + getitem_14772 = split_791[19] + getitem_14773 = split_791[20] + getitem_14774 = split_791[21] + getitem_14775 = split_791[22] + getitem_14776 = split_791[23] + getitem_14777 = split_791[24] + getitem_14778 = split_791[25] + getitem_14779 = split_791[26] + getitem_14780 = split_791[27] + getitem_14781 = split_791[28] + getitem_14782 = split_791[29] + getitem_14783 = split_791[30] + getitem_14784 = split_791[31] + getitem_14785 = split_791[32] + getitem_14786 = split_791[33] + getitem_14787 = split_791[34] + getitem_14788 = split_791[35] + getitem_14789 = split_791[36] + getitem_14790 = split_791[37] + getitem_14791 = split_791[38] + getitem_14792 = split_791[39] + getitem_14793 = split_791[40] + getitem_14794 = split_791[41] + getitem_14795 = split_791[42] + getitem_14796 = split_791[43] + getitem_14797 = split_791[44] + getitem_14798 = split_791[45] + getitem_14799 = split_791[46] + getitem_14800 = split_791[47] + getitem_14801 = split_791[48] + getitem_14802 = split_791[49] + getitem_14803 = split_791[50] + getitem_14804 = split_791[51] + getitem_14805 = split_791[52] + getitem_14806 = split_791[53] + getitem_14807 = split_791[54] + getitem_14808 = split_791[55] + getitem_14809 = split_791[56] + getitem_14810 = split_791[57] + getitem_14811 = split_791[58] + getitem_14812 = split_791[59] + getitem_14813 = split_791[60] + getitem_14814 = split_791[61] + getitem_14815 = split_791[62] + getitem_14816 = split_791[63] + getitem_14817 = split_791[64] + getitem_14818 = split_791[65] + getitem_14819 = split_791[66] + getitem_14820 = split_791[67] + getitem_14821 = split_791[68] + getitem_14822 = split_791[69] + getitem_14823 = split_791[70] + getitem_14824 = split_791[71] + getitem_14825 = split_791[72] + getitem_14826 = split_791[73] + getitem_14827 = split_791[74] + getitem_14828 = split_791[75] + getitem_14829 = split_791[76] + getitem_14830 = split_791[77] + getitem_14831 = split_791[78] + getitem_14832 = split_791[79] + getitem_14833 = split_791[80] + getitem_14834 = split_791[81] + getitem_14835 = split_791[82] + getitem_14836 = split_791[83] + getitem_14837 = split_791[84] + getitem_14838 = split_791[85] + getitem_14839 = split_791[86] + getitem_14840 = split_791[87] + getitem_14841 = split_791[88] + getitem_14842 = split_791[89] + getitem_14843 = split_791[90] + getitem_14844 = split_791[91] + getitem_14845 = split_791[92] + getitem_14846 = split_791[93] + getitem_14847 = split_791[94] + getitem_14848 = split_791[95] + getitem_14849 = split_791[96] + getitem_14850 = split_791[97] + getitem_14851 = split_791[98] + getitem_14852 = split_791[99] + getitem_14853 = split_791[100] + getitem_14854 = split_791[101] + getitem_14855 = split_791[102] + getitem_14856 = split_791[103] + getitem_14857 = split_791[104] + getitem_14858 = split_791[105] + getitem_14859 = split_791[106] + getitem_14860 = split_791[107] + getitem_14861 = split_791[108] + getitem_14862 = split_791[109] + getitem_14863 = split_791[110] + getitem_14864 = split_791[111] + getitem_14865 = split_791[112] + getitem_14866 = split_791[113] + getitem_14867 = split_791[114] + getitem_14868 = split_791[115]; split_791 = None + constant_pad_nd_911 = torch.ops.aten.constant_pad_nd.default(getitem_14868, [0, 0, 0, 4], 0.0); getitem_14868 = None + cat_330 = torch.ops.aten.cat.default([getitem_14753, getitem_14754, getitem_14755, getitem_14756, getitem_14757, getitem_14758, getitem_14759, getitem_14760, getitem_14761, getitem_14762, getitem_14763, getitem_14764, getitem_14765, getitem_14766, getitem_14767, getitem_14768, getitem_14769, getitem_14770, getitem_14771, getitem_14772, getitem_14773, getitem_14774, getitem_14775, getitem_14776, getitem_14777, getitem_14778, getitem_14779, getitem_14780, getitem_14781, getitem_14782, getitem_14783, getitem_14784, getitem_14785, getitem_14786, getitem_14787, getitem_14788, getitem_14789, getitem_14790, getitem_14791, getitem_14792, getitem_14793, getitem_14794, getitem_14795, getitem_14796, getitem_14797, getitem_14798, getitem_14799, getitem_14800, getitem_14801, getitem_14802, getitem_14803, getitem_14804, getitem_14805, getitem_14806, getitem_14807, getitem_14808, getitem_14809, getitem_14810, getitem_14811, getitem_14812, getitem_14813, getitem_14814, getitem_14815, getitem_14816, getitem_14817, getitem_14818, getitem_14819, getitem_14820, getitem_14821, getitem_14822, getitem_14823, getitem_14824, getitem_14825, getitem_14826, getitem_14827, getitem_14828, getitem_14829, getitem_14830, getitem_14831, getitem_14832, getitem_14833, getitem_14834, getitem_14835, getitem_14836, getitem_14837, getitem_14838, getitem_14839, getitem_14840, getitem_14841, getitem_14842, getitem_14843, getitem_14844, getitem_14845, getitem_14846, getitem_14847, getitem_14848, getitem_14849, getitem_14850, getitem_14851, getitem_14852, getitem_14853, getitem_14854, getitem_14855, getitem_14856, getitem_14857, getitem_14858, getitem_14859, getitem_14860, getitem_14861, getitem_14862, getitem_14863, getitem_14864, getitem_14865, getitem_14866, getitem_14867, constant_pad_nd_911, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_14753 = getitem_14754 = getitem_14755 = getitem_14756 = getitem_14757 = getitem_14758 = getitem_14759 = getitem_14760 = getitem_14761 = getitem_14762 = getitem_14763 = getitem_14764 = getitem_14765 = getitem_14766 = getitem_14767 = getitem_14768 = getitem_14769 = getitem_14770 = getitem_14771 = getitem_14772 = getitem_14773 = getitem_14774 = getitem_14775 = getitem_14776 = getitem_14777 = getitem_14778 = getitem_14779 = getitem_14780 = getitem_14781 = getitem_14782 = getitem_14783 = getitem_14784 = getitem_14785 = getitem_14786 = getitem_14787 = getitem_14788 = getitem_14789 = getitem_14790 = getitem_14791 = getitem_14792 = getitem_14793 = getitem_14794 = getitem_14795 = getitem_14796 = getitem_14797 = getitem_14798 = getitem_14799 = getitem_14800 = getitem_14801 = getitem_14802 = getitem_14803 = getitem_14804 = getitem_14805 = getitem_14806 = getitem_14807 = getitem_14808 = getitem_14809 = getitem_14810 = getitem_14811 = getitem_14812 = getitem_14813 = getitem_14814 = getitem_14815 = getitem_14816 = getitem_14817 = getitem_14818 = getitem_14819 = getitem_14820 = getitem_14821 = getitem_14822 = getitem_14823 = getitem_14824 = getitem_14825 = getitem_14826 = getitem_14827 = getitem_14828 = getitem_14829 = getitem_14830 = getitem_14831 = getitem_14832 = getitem_14833 = getitem_14834 = getitem_14835 = getitem_14836 = getitem_14837 = getitem_14838 = getitem_14839 = getitem_14840 = getitem_14841 = getitem_14842 = getitem_14843 = getitem_14844 = getitem_14845 = getitem_14846 = getitem_14847 = getitem_14848 = getitem_14849 = getitem_14850 = getitem_14851 = getitem_14852 = getitem_14853 = getitem_14854 = getitem_14855 = getitem_14856 = getitem_14857 = getitem_14858 = getitem_14859 = getitem_14860 = getitem_14861 = getitem_14862 = getitem_14863 = getitem_14864 = getitem_14865 = getitem_14866 = getitem_14867 = constant_pad_nd_911 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_330, 'avg', 128, '0'); cat_330 = None + wait_tensor_750 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + slice_232 = torch.ops.aten.slice.Tensor(permute_993, 3, 0, 128) + slice_233 = torch.ops.aten.slice.Tensor(permute_993, 3, 128, 192); permute_993 = None + convert_element_type_2333 = torch.ops.prims.convert_element_type.default(slice_233, torch.float32); slice_233 = None + view_2005 = torch.ops.aten.view.default(convert_element_type_2333, [2, 4096, 16, 32, 2]); convert_element_type_2333 = None + view_as_complex_77 = torch.ops.aten.view_as_complex.default(view_2005); view_2005 = None + mul_1692 = torch.ops.aten.mul.Tensor(view_as_complex_77, clone_9); view_as_complex_77 = None + view_as_real_77 = torch.ops.aten.view_as_real.default(mul_1692); mul_1692 = None + view_2006 = torch.ops.aten.view.default(view_as_real_77, [2, 4096, 16, 64]); view_as_real_77 = None + convert_element_type_2334 = torch.ops.prims.convert_element_type.default(view_2006, torch.bfloat16); view_2006 = None + cat_331 = torch.ops.aten.cat.default([slice_232, convert_element_type_2334], 3); slice_232 = convert_element_type_2334 = None + view_2007 = torch.ops.aten.view.default(cat_331, [2, 4096, 3072]); cat_331 = None + view_2008 = torch.ops.aten.view.default(view_2007, [8192, 3072]); view_2007 = None + permute_1002 = torch.ops.aten.permute.default(view_2008, [1, 0]) + mm_408 = torch.ops.aten.mm.default(permute_1002, view_974); permute_1002 = view_974 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 128, '0'); convert_element_type_796 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + permute_1004 = torch.ops.aten.permute.default(permute_221, [1, 0]); permute_221 = None + mm_409 = torch.ops.aten.mm.default(view_2008, permute_1004); view_2008 = permute_1004 = None + view_2009 = torch.ops.aten.view.default(mm_409, [2, 4096, 2048]); mm_409 = None + add_1953 = torch.ops.aten.add.Tensor(view_2004, view_2009); view_2004 = view_2009 = None + convert_element_type_2339 = torch.ops.prims.convert_element_type.default(mm_408, torch.float32); mm_408 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2339, 'avg', 128, '0'); convert_element_type_2339 = None + wait_tensor_751 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + convert_element_type_2340 = torch.ops.prims.convert_element_type.default(add_1953, torch.float32); add_1953 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 128, '0'); convert_element_type_793 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_2342 = torch.ops.prims.convert_element_type.default(wait_tensor_305, torch.float32); wait_tensor_305 = None + mul_1693 = torch.ops.aten.mul.Tensor(convert_element_type_2340, convert_element_type_2342); convert_element_type_2342 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_957, torch.float32); add_957 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_45); convert_element_type_794 = None + mul_1695 = torch.ops.aten.mul.Tensor(mul_695, mul_1693) + sum_201 = torch.ops.aten.sum.dim_IntList(mul_1695, [2], True); mul_1695 = None + div_203 = torch.ops.aten.div.Tensor(mul_695, 2048) + mul_1696 = torch.ops.aten.mul.Tensor(div_203, sum_201); div_203 = sum_201 = None + sub_696 = torch.ops.aten.sub.Tensor(mul_1693, mul_1696); mul_1693 = mul_1696 = None + mul_1697 = torch.ops.aten.mul.Tensor(sub_696, rsqrt_45); sub_696 = rsqrt_45 = None + mul_1698 = torch.ops.aten.mul.Tensor(convert_element_type_2340, mul_695); convert_element_type_2340 = mul_695 = None + sum_202 = torch.ops.aten.sum.dim_IntList(mul_1698, [0, 1]); mul_1698 = None + convert_element_type_2343 = torch.ops.prims.convert_element_type.default(mul_1697, torch.bfloat16); mul_1697 = None + add_1954 = torch.ops.aten.add.Tensor(add_1952, convert_element_type_2343); add_1952 = convert_element_type_2343 = None + convert_element_type_default_46 = torch.ops.prims.convert_element_type.default(sum_202, torch.float32); sum_202 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_46, 'avg', 128, '0'); convert_element_type_default_46 = None + wait_tensor_752 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + view_2010 = torch.ops.aten.view.default(add_1954, [8192, 2048]) + unsqueeze_65 = torch.ops.aten.unsqueeze.default(view_2010, 1) + convert_element_type_2346 = torch.ops.prims.convert_element_type.default(unsqueeze_65, torch.float32); unsqueeze_65 = None + bmm_50 = torch.ops.aten.bmm.default(permute_1006, convert_element_type_2346); permute_1006 = None + bmm_51 = torch.ops.aten.bmm.default(convert_element_type_2346, permute_1007); convert_element_type_2346 = permute_1007 = None + convert_element_type_2347 = torch.ops.prims.convert_element_type.default(bmm_50, torch.bfloat16); bmm_50 = None + view_2011 = torch.ops.aten.view.default(bmm_51, [8192, 6]); bmm_51 = None + view_2012 = torch.ops.aten.view.default(convert_element_type_2347, [49152, 2048]); convert_element_type_2347 = None + index_76 = torch.ops.aten.index.Tensor(view_2012, [getitem_1451]); view_2012 = getitem_1451 = None + permute_1008 = torch.ops.aten.permute.default(view_2010, [1, 0]) + mm_410 = torch.ops.aten.mm.default(permute_1008, mul_692); permute_1008 = mul_692 = None + convert_element_type_788 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_788, 128, '0'); convert_element_type_788 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); wait_tensor_304 = None + permute_1010 = torch.ops.aten.permute.default(permute_220, [1, 0]); permute_220 = None + mm_411 = torch.ops.aten.mm.default(view_2010, permute_1010); view_2010 = permute_1010 = None + convert_element_type_2352 = torch.ops.prims.convert_element_type.default(mm_410, torch.float32); mm_410 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2352, 'avg', 128, '0'); convert_element_type_2352 = None + wait_tensor_753 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + convert_element_type_783 = torch.ops.prims.convert_element_type.default(mm_116, torch.float32); mm_116 = None + neg_28 = torch.ops.aten.neg.default(convert_element_type_783) + exp_42 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_952 = torch.ops.aten.add.Tensor(exp_42, 1); exp_42 = None + div_70 = torch.ops.aten.div.Tensor(convert_element_type_783, add_952) + convert_element_type_784 = torch.ops.prims.convert_element_type.default(div_70, torch.bfloat16); div_70 = None + mul_1699 = torch.ops.aten.mul.Tensor(mm_411, convert_element_type_784); convert_element_type_784 = None + mul_1700 = torch.ops.aten.mul.Tensor(mm_411, mm_117); mm_411 = mm_117 = None + permute_1012 = torch.ops.aten.permute.default(mul_1699, [1, 0]) + mm_412 = torch.ops.aten.mm.default(permute_1012, view_929); permute_1012 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_785, 128, '0'); convert_element_type_785 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + permute_1014 = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None + mm_413 = torch.ops.aten.mm.default(mul_1699, permute_1014); mul_1699 = permute_1014 = None + convert_element_type_2357 = torch.ops.prims.convert_element_type.default(mm_412, torch.float32); mm_412 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2357, 'avg', 128, '0'); convert_element_type_2357 = None + wait_tensor_754 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + convert_element_type_2358 = torch.ops.prims.convert_element_type.default(mul_1700, torch.float32); mul_1700 = None + reciprocal_24 = torch.ops.aten.reciprocal.default(add_952); add_952 = None + mul_1701 = torch.ops.aten.mul.Tensor(reciprocal_24, 1); reciprocal_24 = None + mul_1702 = torch.ops.aten.mul.Tensor(convert_element_type_2358, mul_1701); convert_element_type_2358 = None + sub_697 = torch.ops.aten.sub.Tensor(1, mul_1701); mul_1701 = None + mul_1703 = torch.ops.aten.mul.Tensor(convert_element_type_783, sub_697); convert_element_type_783 = sub_697 = None + add_1956 = torch.ops.aten.add.Tensor(mul_1703, 1); mul_1703 = None + mul_1704 = torch.ops.aten.mul.Tensor(mul_1702, add_1956); mul_1702 = add_1956 = None + convert_element_type_2360 = torch.ops.prims.convert_element_type.default(mul_1704, torch.bfloat16); mul_1704 = None + permute_1016 = torch.ops.aten.permute.default(convert_element_type_2360, [1, 0]) + mm_414 = torch.ops.aten.mm.default(permute_1016, view_929); permute_1016 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_780, 128, '0'); convert_element_type_780 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_302, [1, 0]); wait_tensor_302 = None + permute_1018 = torch.ops.aten.permute.default(permute_218, [1, 0]); permute_218 = None + mm_415 = torch.ops.aten.mm.default(convert_element_type_2360, permute_1018); convert_element_type_2360 = permute_1018 = None + add_1957 = torch.ops.aten.add.Tensor(mm_413, mm_415); mm_413 = mm_415 = None + convert_element_type_2365 = torch.ops.prims.convert_element_type.default(mm_414, torch.float32); mm_414 = None + reduce_scatter_tensor_172 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2365, 'avg', 128, '0'); convert_element_type_2365 = None + wait_tensor_755 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + all_to_all_single_102 = torch.ops._c10d_functional.all_to_all_single.default(index_76, [_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223], [_local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215], '1033'); index_76 = None + wait_tensor_756 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_102); all_to_all_single_102 = None + full_420 = torch.ops.aten.full.default([sym_size_int_53, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_53 = None + slice_scatter_12 = torch.ops.aten.slice_scatter.default(full_420, wait_tensor_756, 0, 0, -1); wait_tensor_756 = None + index_77 = torch.ops.aten.index.Tensor(slice_scatter_12, [getitem_1452]); slice_scatter_12 = None + permute_1020 = torch.ops.aten.permute.default(index_77, [1, 0]) + _grouped_mm_150 = torch.ops.aten._grouped_mm.default(permute_1020, mul_672, cumsum_41); permute_1020 = mul_672 = None + _grouped_mm_151 = torch.ops.aten._grouped_mm.default(index_77, permute_1022, cumsum_41); index_77 = permute_1022 = None + convert_element_type_778 = torch.ops.prims.convert_element_type.default(_grouped_mm_39, torch.float32); _grouped_mm_39 = None + neg_27 = torch.ops.aten.neg.default(convert_element_type_778) + exp_41 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_916 = torch.ops.aten.add.Tensor(exp_41, 1); exp_41 = None + div_69 = torch.ops.aten.div.Tensor(convert_element_type_778, add_916) + convert_element_type_779 = torch.ops.prims.convert_element_type.default(div_69, torch.bfloat16); div_69 = None + mul_1705 = torch.ops.aten.mul.Tensor(_grouped_mm_151, convert_element_type_779); convert_element_type_779 = None + mul_1706 = torch.ops.aten.mul.Tensor(_grouped_mm_151, _grouped_mm_40); _grouped_mm_151 = _grouped_mm_40 = None + permute_1024 = torch.ops.aten.permute.default(mul_1705, [1, 0]) + _grouped_mm_152 = torch.ops.aten._grouped_mm.default(permute_1024, index_27, cumsum_41); permute_1024 = None + _grouped_mm_153 = torch.ops.aten._grouped_mm.default(mul_1705, permute_1026, cumsum_41); mul_1705 = permute_1026 = None + convert_element_type_2366 = torch.ops.prims.convert_element_type.default(mul_1706, torch.float32); mul_1706 = None + reciprocal_25 = torch.ops.aten.reciprocal.default(add_916); add_916 = None + mul_1707 = torch.ops.aten.mul.Tensor(reciprocal_25, 1); reciprocal_25 = None + mul_1708 = torch.ops.aten.mul.Tensor(convert_element_type_2366, mul_1707); convert_element_type_2366 = None + sub_698 = torch.ops.aten.sub.Tensor(1, mul_1707); mul_1707 = None + mul_1709 = torch.ops.aten.mul.Tensor(convert_element_type_778, sub_698); convert_element_type_778 = sub_698 = None + add_1959 = torch.ops.aten.add.Tensor(mul_1709, 1); mul_1709 = None + mul_1710 = torch.ops.aten.mul.Tensor(mul_1708, add_1959); mul_1708 = add_1959 = None + convert_element_type_2368 = torch.ops.prims.convert_element_type.default(mul_1710, torch.bfloat16); mul_1710 = None + permute_1028 = torch.ops.aten.permute.default(convert_element_type_2368, [1, 0]) + _grouped_mm_154 = torch.ops.aten._grouped_mm.default(permute_1028, index_27, cumsum_41); permute_1028 = index_27 = None + _grouped_mm_155 = torch.ops.aten._grouped_mm.default(convert_element_type_2368, permute_1030, cumsum_41); convert_element_type_2368 = permute_1030 = cumsum_41 = None + add_1960 = torch.ops.aten.add.Tensor(_grouped_mm_153, _grouped_mm_155); _grouped_mm_153 = _grouped_mm_155 = None + convert_element_type_2369 = torch.ops.prims.convert_element_type.default(_grouped_mm_152, torch.float32); _grouped_mm_152 = None + div_204 = torch.ops.aten.div.Tensor(convert_element_type_2369, 128); convert_element_type_2369 = None + split_793 = torch.ops.aten.split.Tensor(div_204, 88, 1); div_204 = None + getitem_14885 = split_793[0] + getitem_14902 = split_793[1] + getitem_14919 = split_793[2] + getitem_14936 = split_793[3] + getitem_14953 = split_793[4] + getitem_14970 = split_793[5] + getitem_14987 = split_793[6] + getitem_15004 = split_793[7] + getitem_15021 = split_793[8] + getitem_15038 = split_793[9] + getitem_15055 = split_793[10] + getitem_15072 = split_793[11] + getitem_15089 = split_793[12] + getitem_15106 = split_793[13] + getitem_15123 = split_793[14] + getitem_15140 = split_793[15]; split_793 = None + cat_332 = torch.ops.aten.cat.default([getitem_14885, getitem_14902, getitem_14919, getitem_14936, getitem_14953, getitem_14970, getitem_14987, getitem_15004, getitem_15021, getitem_15038, getitem_15055, getitem_15072, getitem_15089, getitem_15106, getitem_15123, getitem_15140]); getitem_14885 = getitem_14902 = getitem_14919 = getitem_14936 = getitem_14953 = getitem_14970 = getitem_14987 = getitem_15004 = getitem_15021 = getitem_15038 = getitem_15055 = getitem_15072 = getitem_15089 = getitem_15106 = getitem_15123 = getitem_15140 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_332, 'sum', 16, '1025'); cat_332 = None + wait_tensor_757 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + convert_element_type_2370 = torch.ops.prims.convert_element_type.default(_grouped_mm_150, torch.float32); _grouped_mm_150 = None + div_205 = torch.ops.aten.div.Tensor(convert_element_type_2370, 128); convert_element_type_2370 = None + split_810 = torch.ops.aten.split.Tensor(div_205, 128, 1); div_205 = None + getitem_15157 = split_810[0] + getitem_15174 = split_810[1] + getitem_15191 = split_810[2] + getitem_15208 = split_810[3] + getitem_15225 = split_810[4] + getitem_15242 = split_810[5] + getitem_15259 = split_810[6] + getitem_15276 = split_810[7] + getitem_15293 = split_810[8] + getitem_15310 = split_810[9] + getitem_15327 = split_810[10] + getitem_15344 = split_810[11] + getitem_15361 = split_810[12] + getitem_15378 = split_810[13] + getitem_15395 = split_810[14] + getitem_15412 = split_810[15]; split_810 = None + cat_333 = torch.ops.aten.cat.default([getitem_15157, getitem_15174, getitem_15191, getitem_15208, getitem_15225, getitem_15242, getitem_15259, getitem_15276, getitem_15293, getitem_15310, getitem_15327, getitem_15344, getitem_15361, getitem_15378, getitem_15395, getitem_15412]); getitem_15157 = getitem_15174 = getitem_15191 = getitem_15208 = getitem_15225 = getitem_15242 = getitem_15259 = getitem_15276 = getitem_15293 = getitem_15310 = getitem_15327 = getitem_15344 = getitem_15361 = getitem_15378 = getitem_15395 = getitem_15412 = None + reduce_scatter_tensor_174 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_333, 'sum', 16, '1025'); cat_333 = None + wait_tensor_758 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + convert_element_type_2371 = torch.ops.prims.convert_element_type.default(_grouped_mm_154, torch.float32); _grouped_mm_154 = None + div_206 = torch.ops.aten.div.Tensor(convert_element_type_2371, 128); convert_element_type_2371 = None + split_827 = torch.ops.aten.split.Tensor(div_206, 88, 1); div_206 = None + getitem_15429 = split_827[0] + getitem_15446 = split_827[1] + getitem_15463 = split_827[2] + getitem_15480 = split_827[3] + getitem_15497 = split_827[4] + getitem_15514 = split_827[5] + getitem_15531 = split_827[6] + getitem_15548 = split_827[7] + getitem_15565 = split_827[8] + getitem_15582 = split_827[9] + getitem_15599 = split_827[10] + getitem_15616 = split_827[11] + getitem_15633 = split_827[12] + getitem_15650 = split_827[13] + getitem_15667 = split_827[14] + getitem_15684 = split_827[15]; split_827 = None + cat_334 = torch.ops.aten.cat.default([getitem_15429, getitem_15446, getitem_15463, getitem_15480, getitem_15497, getitem_15514, getitem_15531, getitem_15548, getitem_15565, getitem_15582, getitem_15599, getitem_15616, getitem_15633, getitem_15650, getitem_15667, getitem_15684]); getitem_15429 = getitem_15446 = getitem_15463 = getitem_15480 = getitem_15497 = getitem_15514 = getitem_15531 = getitem_15548 = getitem_15565 = getitem_15582 = getitem_15599 = getitem_15616 = getitem_15633 = getitem_15650 = getitem_15667 = getitem_15684 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_334, 'sum', 16, '1025'); cat_334 = None + wait_tensor_759 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + index_put_76 = torch.ops.aten.index_put.default(full_420, [getitem_1452], add_1960, True); full_420 = getitem_1452 = add_1960 = None + slice_234 = torch.ops.aten.slice.Tensor(index_put_76, 0, 0, add_1961); index_put_76 = add_1961 = None + all_to_all_single_103 = torch.ops._c10d_functional.all_to_all_single.default(slice_234, [_local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215], [_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223], '1033'); slice_234 = _local_scalar_dense_208 = _local_scalar_dense_209 = _local_scalar_dense_210 = _local_scalar_dense_211 = _local_scalar_dense_212 = _local_scalar_dense_213 = _local_scalar_dense_214 = _local_scalar_dense_215 = _local_scalar_dense_216 = _local_scalar_dense_217 = _local_scalar_dense_218 = _local_scalar_dense_219 = _local_scalar_dense_220 = _local_scalar_dense_221 = _local_scalar_dense_222 = _local_scalar_dense_223 = None + wait_tensor_760 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_103); all_to_all_single_103 = None + index_put_77 = torch.ops.aten.index_put.default(full_default_52, [div_67], wait_tensor_760, True); div_67 = wait_tensor_760 = None + add_1965 = torch.ops.aten.add.Tensor(add_1957, index_put_77); add_1957 = index_put_77 = None + mul_1711 = torch.ops.aten.mul.Tensor(view_2011, 1.0); view_2011 = None + scatter_add_12 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1449, mul_1711); getitem_1449 = mul_1711 = None + convert_element_type_767 = torch.ops.prims.convert_element_type.default(mm_115, torch.float32); mm_115 = None + sub_312 = torch.ops.aten.sub.Tensor(convert_element_type_767, amax_13); convert_element_type_767 = amax_13 = None + exp_40 = torch.ops.aten.exp.default(sub_312); sub_312 = None + div_66 = torch.ops.aten.div.Tensor(exp_40, sum_53); exp_40 = sum_53 = None + mul_1712 = torch.ops.aten.mul.Tensor(scatter_add_12, div_66); scatter_add_12 = None + sum_203 = torch.ops.aten.sum.dim_IntList(mul_1712, [1], True) + neg_91 = torch.ops.aten.neg.default(div_66); div_66 = None + fma_12 = torch.ops.prims.fma.default(neg_91, sum_203, mul_1712); neg_91 = sum_203 = mul_1712 = None + convert_element_type_2372 = torch.ops.prims.convert_element_type.default(fma_12, torch.bfloat16); fma_12 = None + permute_1032 = torch.ops.aten.permute.default(convert_element_type_2372, [1, 0]) + mm_416 = torch.ops.aten.mm.default(permute_1032, view_929); permute_1032 = view_929 = None + convert_element_type_764 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16); primals_239 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_764, 128, '0'); convert_element_type_764 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + slice_87 = torch.ops.aten.slice.Tensor(wait_tensor_291, 0, 0, 64); wait_tensor_291 = None + permute_214 = torch.ops.aten.permute.default(slice_87, [1, 0]); slice_87 = None + permute_1034 = torch.ops.aten.permute.default(permute_214, [1, 0]); permute_214 = None + mm_417 = torch.ops.aten.mm.default(convert_element_type_2372, permute_1034); convert_element_type_2372 = permute_1034 = None + add_1966 = torch.ops.aten.add.Tensor(add_1965, mm_417); add_1965 = mm_417 = None + convert_element_type_2377 = torch.ops.prims.convert_element_type.default(mm_416, torch.float32); mm_416 = None + split_843 = torch.ops.aten.split.Tensor(convert_element_type_2377, 1); convert_element_type_2377 = None + getitem_15685 = split_843[0] + getitem_15686 = split_843[1] + getitem_15687 = split_843[2] + getitem_15688 = split_843[3] + getitem_15689 = split_843[4] + getitem_15690 = split_843[5] + getitem_15691 = split_843[6] + getitem_15692 = split_843[7] + getitem_15693 = split_843[8] + getitem_15694 = split_843[9] + getitem_15695 = split_843[10] + getitem_15696 = split_843[11] + getitem_15697 = split_843[12] + getitem_15698 = split_843[13] + getitem_15699 = split_843[14] + getitem_15700 = split_843[15] + getitem_15701 = split_843[16] + getitem_15702 = split_843[17] + getitem_15703 = split_843[18] + getitem_15704 = split_843[19] + getitem_15705 = split_843[20] + getitem_15706 = split_843[21] + getitem_15707 = split_843[22] + getitem_15708 = split_843[23] + getitem_15709 = split_843[24] + getitem_15710 = split_843[25] + getitem_15711 = split_843[26] + getitem_15712 = split_843[27] + getitem_15713 = split_843[28] + getitem_15714 = split_843[29] + getitem_15715 = split_843[30] + getitem_15716 = split_843[31] + getitem_15717 = split_843[32] + getitem_15718 = split_843[33] + getitem_15719 = split_843[34] + getitem_15720 = split_843[35] + getitem_15721 = split_843[36] + getitem_15722 = split_843[37] + getitem_15723 = split_843[38] + getitem_15724 = split_843[39] + getitem_15725 = split_843[40] + getitem_15726 = split_843[41] + getitem_15727 = split_843[42] + getitem_15728 = split_843[43] + getitem_15729 = split_843[44] + getitem_15730 = split_843[45] + getitem_15731 = split_843[46] + getitem_15732 = split_843[47] + getitem_15733 = split_843[48] + getitem_15734 = split_843[49] + getitem_15735 = split_843[50] + getitem_15736 = split_843[51] + getitem_15737 = split_843[52] + getitem_15738 = split_843[53] + getitem_15739 = split_843[54] + getitem_15740 = split_843[55] + getitem_15741 = split_843[56] + getitem_15742 = split_843[57] + getitem_15743 = split_843[58] + getitem_15744 = split_843[59] + getitem_15745 = split_843[60] + getitem_15746 = split_843[61] + getitem_15747 = split_843[62] + getitem_15748 = split_843[63]; split_843 = None + cat_335 = torch.ops.aten.cat.default([getitem_15685, getitem_15686, getitem_15687, getitem_15688, getitem_15689, getitem_15690, getitem_15691, getitem_15692, getitem_15693, getitem_15694, getitem_15695, getitem_15696, getitem_15697, getitem_15698, getitem_15699, getitem_15700, getitem_15701, getitem_15702, getitem_15703, getitem_15704, getitem_15705, getitem_15706, getitem_15707, getitem_15708, getitem_15709, getitem_15710, getitem_15711, getitem_15712, getitem_15713, getitem_15714, getitem_15715, getitem_15716, getitem_15717, getitem_15718, getitem_15719, getitem_15720, getitem_15721, getitem_15722, getitem_15723, getitem_15724, getitem_15725, getitem_15726, getitem_15727, getitem_15728, getitem_15729, getitem_15730, getitem_15731, getitem_15732, getitem_15733, getitem_15734, getitem_15735, getitem_15736, getitem_15737, getitem_15738, getitem_15739, getitem_15740, getitem_15741, getitem_15742, getitem_15743, getitem_15744, getitem_15745, getitem_15746, getitem_15747, getitem_15748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_15685 = getitem_15686 = getitem_15687 = getitem_15688 = getitem_15689 = getitem_15690 = getitem_15691 = getitem_15692 = getitem_15693 = getitem_15694 = getitem_15695 = getitem_15696 = getitem_15697 = getitem_15698 = getitem_15699 = getitem_15700 = getitem_15701 = getitem_15702 = getitem_15703 = getitem_15704 = getitem_15705 = getitem_15706 = getitem_15707 = getitem_15708 = getitem_15709 = getitem_15710 = getitem_15711 = getitem_15712 = getitem_15713 = getitem_15714 = getitem_15715 = getitem_15716 = getitem_15717 = getitem_15718 = getitem_15719 = getitem_15720 = getitem_15721 = getitem_15722 = getitem_15723 = getitem_15724 = getitem_15725 = getitem_15726 = getitem_15727 = getitem_15728 = getitem_15729 = getitem_15730 = getitem_15731 = getitem_15732 = getitem_15733 = getitem_15734 = getitem_15735 = getitem_15736 = getitem_15737 = getitem_15738 = getitem_15739 = getitem_15740 = getitem_15741 = getitem_15742 = getitem_15743 = getitem_15744 = getitem_15745 = getitem_15746 = getitem_15747 = getitem_15748 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_335, 'avg', 128, '0'); cat_335 = None + wait_tensor_761 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + view_2013 = torch.ops.aten.view.default(add_1966, [2, 4096, 2048]); add_1966 = None + convert_element_type_2378 = torch.ops.prims.convert_element_type.default(view_2013, torch.float32); view_2013 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16); primals_237 = None + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_761, 128, '0'); convert_element_type_761 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_2380 = torch.ops.prims.convert_element_type.default(wait_tensor_290, torch.float32); wait_tensor_290 = None + mul_1713 = torch.ops.aten.mul.Tensor(convert_element_type_2378, convert_element_type_2380); convert_element_type_2380 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(add_892, torch.float32); add_892 = None + mul_652 = torch.ops.aten.mul.Tensor(convert_element_type_762, rsqrt_44); convert_element_type_762 = None + mul_1715 = torch.ops.aten.mul.Tensor(mul_652, mul_1713) + sum_204 = torch.ops.aten.sum.dim_IntList(mul_1715, [2], True); mul_1715 = None + div_207 = torch.ops.aten.div.Tensor(mul_652, 2048) + mul_1716 = torch.ops.aten.mul.Tensor(div_207, sum_204); div_207 = sum_204 = None + sub_700 = torch.ops.aten.sub.Tensor(mul_1713, mul_1716); mul_1713 = mul_1716 = None + mul_1717 = torch.ops.aten.mul.Tensor(sub_700, rsqrt_44); sub_700 = rsqrt_44 = None + mul_1718 = torch.ops.aten.mul.Tensor(convert_element_type_2378, mul_652); convert_element_type_2378 = mul_652 = None + sum_205 = torch.ops.aten.sum.dim_IntList(mul_1718, [0, 1]); mul_1718 = None + convert_element_type_2381 = torch.ops.prims.convert_element_type.default(mul_1717, torch.bfloat16); mul_1717 = None + add_1967 = torch.ops.aten.add.Tensor(add_1954, convert_element_type_2381); add_1954 = convert_element_type_2381 = None + convert_element_type_default_45 = torch.ops.prims.convert_element_type.default(sum_205, torch.float32); sum_205 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_45, 'avg', 128, '0'); convert_element_type_default_45 = None + wait_tensor_762 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + view_2014 = torch.ops.aten.view.default(add_1967, [8192, 2048]) + permute_1036 = torch.ops.aten.permute.default(view_2014, [1, 0]) + permute_212 = torch.ops.aten.permute.default(getitem_1445, [0, 2, 1, 3]) + view_924 = torch.ops.aten.view.default(permute_212, [2, 4096, -1]); permute_212 = None + view_926 = torch.ops.aten.view.default(view_924, [8192, 2048]); view_924 = None + mm_418 = torch.ops.aten.mm.default(permute_1036, view_926); permute_1036 = view_926 = None + convert_element_type_758 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16); primals_236 = None + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_758, 128, '0'); convert_element_type_758 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_213 = torch.ops.aten.permute.default(wait_tensor_289, [1, 0]); wait_tensor_289 = None + permute_1038 = torch.ops.aten.permute.default(permute_213, [1, 0]); permute_213 = None + mm_419 = torch.ops.aten.mm.default(view_2014, permute_1038); view_2014 = permute_1038 = None + view_2015 = torch.ops.aten.view.default(mm_419, [2, 4096, 2048]); mm_419 = None + convert_element_type_2388 = torch.ops.prims.convert_element_type.default(mm_418, torch.float32); mm_418 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2388, 'avg', 128, '0'); convert_element_type_2388 = None + wait_tensor_763 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + view_2016 = torch.ops.aten.view.default(view_2015, [2, 4096, 16, 128]); view_2015 = None + permute_1040 = torch.ops.aten.permute.default(view_2016, [0, 2, 1, 3]); view_2016 = None + fw_graph12 = self.fw_graph12 + joint_graph12 = self.joint_graph12 + mask_graph12 = self.mask_graph12 + flex_attention_backward_12 = torch.ops.higher_order.flex_attention_backward(permute_209, permute_210, permute_211, getitem_1445, getitem_1446, permute_1040, None, fw_graph12, joint_graph12, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph12), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_209 = permute_210 = permute_211 = getitem_1445 = getitem_1446 = permute_1040 = fw_graph12 = joint_graph12 = mask_graph12 = None + getitem_15749 = flex_attention_backward_12[0] + getitem_15750 = flex_attention_backward_12[1] + getitem_15751 = flex_attention_backward_12[2]; flex_attention_backward_12 = None + permute_1041 = torch.ops.aten.permute.default(getitem_15751, [0, 2, 1, 3]); getitem_15751 = None + permute_1042 = torch.ops.aten.permute.default(getitem_15750, [0, 2, 1, 3]); getitem_15750 = None + permute_1043 = torch.ops.aten.permute.default(getitem_15749, [0, 2, 1, 3]); getitem_15749 = None + slice_236 = torch.ops.aten.slice.Tensor(permute_1042, 3, 0, 128) + slice_237 = torch.ops.aten.slice.Tensor(permute_1042, 3, 128, 192); permute_1042 = None + sum_206 = torch.ops.aten.sum.dim_IntList(slice_237, [2], True); slice_237 = None + cat_336 = torch.ops.aten.cat.default([slice_236, permute_1041], 3); slice_236 = permute_1041 = None + view_2017 = torch.ops.aten.view.default(cat_336, [2, 4096, 4096]); cat_336 = None + view_2018 = torch.ops.aten.view.default(view_2017, [8192, 4096]); view_2017 = None + permute_1044 = torch.ops.aten.permute.default(view_2018, [1, 0]) + mm_420 = torch.ops.aten.mm.default(permute_1044, view_921); permute_1044 = view_921 = None + convert_element_type_755 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16); primals_235 = None + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_755, 128, '0'); convert_element_type_755 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + permute_1046 = torch.ops.aten.permute.default(permute_208, [1, 0]); permute_208 = None + mm_421 = torch.ops.aten.mm.default(view_2018, permute_1046); view_2018 = permute_1046 = None + view_2019 = torch.ops.aten.view.default(mm_421, [2, 4096, 512]); mm_421 = None + convert_element_type_2393 = torch.ops.prims.convert_element_type.default(mm_420, torch.float32); mm_420 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2393, 'avg', 128, '0'); convert_element_type_2393 = None + wait_tensor_764 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + convert_element_type_2394 = torch.ops.prims.convert_element_type.default(view_2019, torch.float32); view_2019 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_752, 128, '0'); convert_element_type_752 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_2396 = torch.ops.prims.convert_element_type.default(wait_tensor_287, torch.float32); wait_tensor_287 = None + mul_1719 = torch.ops.aten.mul.Tensor(convert_element_type_2394, convert_element_type_2396); convert_element_type_2396 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(getitem_1441, torch.float32); getitem_1441 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_753, rsqrt_43); convert_element_type_753 = None + mul_1721 = torch.ops.aten.mul.Tensor(mul_650, mul_1719) + sum_207 = torch.ops.aten.sum.dim_IntList(mul_1721, [2], True); mul_1721 = None + div_208 = torch.ops.aten.div.Tensor(mul_650, 512) + mul_1722 = torch.ops.aten.mul.Tensor(div_208, sum_207); div_208 = sum_207 = None + sub_701 = torch.ops.aten.sub.Tensor(mul_1719, mul_1722); mul_1719 = mul_1722 = None + mul_1723 = torch.ops.aten.mul.Tensor(sub_701, rsqrt_43); sub_701 = rsqrt_43 = None + mul_1724 = torch.ops.aten.mul.Tensor(convert_element_type_2394, mul_650); convert_element_type_2394 = mul_650 = None + sum_208 = torch.ops.aten.sum.dim_IntList(mul_1724, [0, 1]); mul_1724 = None + convert_element_type_2397 = torch.ops.prims.convert_element_type.default(mul_1723, torch.bfloat16); mul_1723 = None + convert_element_type_default_44 = torch.ops.prims.convert_element_type.default(sum_208, torch.float32); sum_208 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_44, 'avg', 128, '0'); convert_element_type_default_44 = None + wait_tensor_765 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + convert_element_type_2400 = torch.ops.prims.convert_element_type.default(sum_206, torch.float32); sum_206 = None + view_2020 = torch.ops.aten.view.default(convert_element_type_2400, [2, 4096, 1, 32, 2]); convert_element_type_2400 = None + view_as_complex_78 = torch.ops.aten.view_as_complex.default(view_2020); view_2020 = None + mul_1725 = torch.ops.aten.mul.Tensor(view_as_complex_78, clone_9); view_as_complex_78 = None + view_as_real_78 = torch.ops.aten.view_as_real.default(mul_1725); mul_1725 = None + view_2021 = torch.ops.aten.view.default(view_as_real_78, [2, 4096, 1, 64]); view_as_real_78 = None + convert_element_type_2401 = torch.ops.prims.convert_element_type.default(view_2021, torch.bfloat16); view_2021 = None + squeeze_38 = torch.ops.aten.squeeze.dim(convert_element_type_2401, 2); convert_element_type_2401 = None + cat_337 = torch.ops.aten.cat.default([convert_element_type_2397, squeeze_38], 2); convert_element_type_2397 = squeeze_38 = None + view_2022 = torch.ops.aten.view.default(cat_337, [8192, 576]); cat_337 = None + permute_1048 = torch.ops.aten.permute.default(view_2022, [1, 0]) + mm_422 = torch.ops.aten.mm.default(permute_1048, view_907); permute_1048 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_747, 128, '0'); convert_element_type_747 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + slice_85 = torch.ops.aten.slice.Tensor(wait_tensor_286, 0, 0, 576); wait_tensor_286 = None + permute_207 = torch.ops.aten.permute.default(slice_85, [1, 0]); slice_85 = None + permute_1050 = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None + mm_423 = torch.ops.aten.mm.default(view_2022, permute_1050); view_2022 = permute_1050 = None + view_2023 = torch.ops.aten.view.default(mm_423, [2, 4096, 2048]); mm_423 = None + convert_element_type_2406 = torch.ops.prims.convert_element_type.default(mm_422, torch.float32); mm_422 = None + split_844 = torch.ops.aten.split.Tensor(convert_element_type_2406, 5); convert_element_type_2406 = None + getitem_15753 = split_844[0] + getitem_15754 = split_844[1] + getitem_15755 = split_844[2] + getitem_15756 = split_844[3] + getitem_15757 = split_844[4] + getitem_15758 = split_844[5] + getitem_15759 = split_844[6] + getitem_15760 = split_844[7] + getitem_15761 = split_844[8] + getitem_15762 = split_844[9] + getitem_15763 = split_844[10] + getitem_15764 = split_844[11] + getitem_15765 = split_844[12] + getitem_15766 = split_844[13] + getitem_15767 = split_844[14] + getitem_15768 = split_844[15] + getitem_15769 = split_844[16] + getitem_15770 = split_844[17] + getitem_15771 = split_844[18] + getitem_15772 = split_844[19] + getitem_15773 = split_844[20] + getitem_15774 = split_844[21] + getitem_15775 = split_844[22] + getitem_15776 = split_844[23] + getitem_15777 = split_844[24] + getitem_15778 = split_844[25] + getitem_15779 = split_844[26] + getitem_15780 = split_844[27] + getitem_15781 = split_844[28] + getitem_15782 = split_844[29] + getitem_15783 = split_844[30] + getitem_15784 = split_844[31] + getitem_15785 = split_844[32] + getitem_15786 = split_844[33] + getitem_15787 = split_844[34] + getitem_15788 = split_844[35] + getitem_15789 = split_844[36] + getitem_15790 = split_844[37] + getitem_15791 = split_844[38] + getitem_15792 = split_844[39] + getitem_15793 = split_844[40] + getitem_15794 = split_844[41] + getitem_15795 = split_844[42] + getitem_15796 = split_844[43] + getitem_15797 = split_844[44] + getitem_15798 = split_844[45] + getitem_15799 = split_844[46] + getitem_15800 = split_844[47] + getitem_15801 = split_844[48] + getitem_15802 = split_844[49] + getitem_15803 = split_844[50] + getitem_15804 = split_844[51] + getitem_15805 = split_844[52] + getitem_15806 = split_844[53] + getitem_15807 = split_844[54] + getitem_15808 = split_844[55] + getitem_15809 = split_844[56] + getitem_15810 = split_844[57] + getitem_15811 = split_844[58] + getitem_15812 = split_844[59] + getitem_15813 = split_844[60] + getitem_15814 = split_844[61] + getitem_15815 = split_844[62] + getitem_15816 = split_844[63] + getitem_15817 = split_844[64] + getitem_15818 = split_844[65] + getitem_15819 = split_844[66] + getitem_15820 = split_844[67] + getitem_15821 = split_844[68] + getitem_15822 = split_844[69] + getitem_15823 = split_844[70] + getitem_15824 = split_844[71] + getitem_15825 = split_844[72] + getitem_15826 = split_844[73] + getitem_15827 = split_844[74] + getitem_15828 = split_844[75] + getitem_15829 = split_844[76] + getitem_15830 = split_844[77] + getitem_15831 = split_844[78] + getitem_15832 = split_844[79] + getitem_15833 = split_844[80] + getitem_15834 = split_844[81] + getitem_15835 = split_844[82] + getitem_15836 = split_844[83] + getitem_15837 = split_844[84] + getitem_15838 = split_844[85] + getitem_15839 = split_844[86] + getitem_15840 = split_844[87] + getitem_15841 = split_844[88] + getitem_15842 = split_844[89] + getitem_15843 = split_844[90] + getitem_15844 = split_844[91] + getitem_15845 = split_844[92] + getitem_15846 = split_844[93] + getitem_15847 = split_844[94] + getitem_15848 = split_844[95] + getitem_15849 = split_844[96] + getitem_15850 = split_844[97] + getitem_15851 = split_844[98] + getitem_15852 = split_844[99] + getitem_15853 = split_844[100] + getitem_15854 = split_844[101] + getitem_15855 = split_844[102] + getitem_15856 = split_844[103] + getitem_15857 = split_844[104] + getitem_15858 = split_844[105] + getitem_15859 = split_844[106] + getitem_15860 = split_844[107] + getitem_15861 = split_844[108] + getitem_15862 = split_844[109] + getitem_15863 = split_844[110] + getitem_15864 = split_844[111] + getitem_15865 = split_844[112] + getitem_15866 = split_844[113] + getitem_15867 = split_844[114] + getitem_15868 = split_844[115]; split_844 = None + constant_pad_nd_988 = torch.ops.aten.constant_pad_nd.default(getitem_15868, [0, 0, 0, 4], 0.0); getitem_15868 = None + cat_338 = torch.ops.aten.cat.default([getitem_15753, getitem_15754, getitem_15755, getitem_15756, getitem_15757, getitem_15758, getitem_15759, getitem_15760, getitem_15761, getitem_15762, getitem_15763, getitem_15764, getitem_15765, getitem_15766, getitem_15767, getitem_15768, getitem_15769, getitem_15770, getitem_15771, getitem_15772, getitem_15773, getitem_15774, getitem_15775, getitem_15776, getitem_15777, getitem_15778, getitem_15779, getitem_15780, getitem_15781, getitem_15782, getitem_15783, getitem_15784, getitem_15785, getitem_15786, getitem_15787, getitem_15788, getitem_15789, getitem_15790, getitem_15791, getitem_15792, getitem_15793, getitem_15794, getitem_15795, getitem_15796, getitem_15797, getitem_15798, getitem_15799, getitem_15800, getitem_15801, getitem_15802, getitem_15803, getitem_15804, getitem_15805, getitem_15806, getitem_15807, getitem_15808, getitem_15809, getitem_15810, getitem_15811, getitem_15812, getitem_15813, getitem_15814, getitem_15815, getitem_15816, getitem_15817, getitem_15818, getitem_15819, getitem_15820, getitem_15821, getitem_15822, getitem_15823, getitem_15824, getitem_15825, getitem_15826, getitem_15827, getitem_15828, getitem_15829, getitem_15830, getitem_15831, getitem_15832, getitem_15833, getitem_15834, getitem_15835, getitem_15836, getitem_15837, getitem_15838, getitem_15839, getitem_15840, getitem_15841, getitem_15842, getitem_15843, getitem_15844, getitem_15845, getitem_15846, getitem_15847, getitem_15848, getitem_15849, getitem_15850, getitem_15851, getitem_15852, getitem_15853, getitem_15854, getitem_15855, getitem_15856, getitem_15857, getitem_15858, getitem_15859, getitem_15860, getitem_15861, getitem_15862, getitem_15863, getitem_15864, getitem_15865, getitem_15866, getitem_15867, constant_pad_nd_988, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_15753 = getitem_15754 = getitem_15755 = getitem_15756 = getitem_15757 = getitem_15758 = getitem_15759 = getitem_15760 = getitem_15761 = getitem_15762 = getitem_15763 = getitem_15764 = getitem_15765 = getitem_15766 = getitem_15767 = getitem_15768 = getitem_15769 = getitem_15770 = getitem_15771 = getitem_15772 = getitem_15773 = getitem_15774 = getitem_15775 = getitem_15776 = getitem_15777 = getitem_15778 = getitem_15779 = getitem_15780 = getitem_15781 = getitem_15782 = getitem_15783 = getitem_15784 = getitem_15785 = getitem_15786 = getitem_15787 = getitem_15788 = getitem_15789 = getitem_15790 = getitem_15791 = getitem_15792 = getitem_15793 = getitem_15794 = getitem_15795 = getitem_15796 = getitem_15797 = getitem_15798 = getitem_15799 = getitem_15800 = getitem_15801 = getitem_15802 = getitem_15803 = getitem_15804 = getitem_15805 = getitem_15806 = getitem_15807 = getitem_15808 = getitem_15809 = getitem_15810 = getitem_15811 = getitem_15812 = getitem_15813 = getitem_15814 = getitem_15815 = getitem_15816 = getitem_15817 = getitem_15818 = getitem_15819 = getitem_15820 = getitem_15821 = getitem_15822 = getitem_15823 = getitem_15824 = getitem_15825 = getitem_15826 = getitem_15827 = getitem_15828 = getitem_15829 = getitem_15830 = getitem_15831 = getitem_15832 = getitem_15833 = getitem_15834 = getitem_15835 = getitem_15836 = getitem_15837 = getitem_15838 = getitem_15839 = getitem_15840 = getitem_15841 = getitem_15842 = getitem_15843 = getitem_15844 = getitem_15845 = getitem_15846 = getitem_15847 = getitem_15848 = getitem_15849 = getitem_15850 = getitem_15851 = getitem_15852 = getitem_15853 = getitem_15854 = getitem_15855 = getitem_15856 = getitem_15857 = getitem_15858 = getitem_15859 = getitem_15860 = getitem_15861 = getitem_15862 = getitem_15863 = getitem_15864 = getitem_15865 = getitem_15866 = getitem_15867 = constant_pad_nd_988 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_338, 'avg', 128, '0'); cat_338 = None + wait_tensor_766 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + slice_238 = torch.ops.aten.slice.Tensor(permute_1043, 3, 0, 128) + slice_239 = torch.ops.aten.slice.Tensor(permute_1043, 3, 128, 192); permute_1043 = None + convert_element_type_2407 = torch.ops.prims.convert_element_type.default(slice_239, torch.float32); slice_239 = None + view_2024 = torch.ops.aten.view.default(convert_element_type_2407, [2, 4096, 16, 32, 2]); convert_element_type_2407 = None + view_as_complex_79 = torch.ops.aten.view_as_complex.default(view_2024); view_2024 = None + mul_1726 = torch.ops.aten.mul.Tensor(view_as_complex_79, clone_9); view_as_complex_79 = None + view_as_real_79 = torch.ops.aten.view_as_real.default(mul_1726); mul_1726 = None + view_2025 = torch.ops.aten.view.default(view_as_real_79, [2, 4096, 16, 64]); view_as_real_79 = None + convert_element_type_2408 = torch.ops.prims.convert_element_type.default(view_2025, torch.bfloat16); view_2025 = None + cat_339 = torch.ops.aten.cat.default([slice_238, convert_element_type_2408], 3); slice_238 = convert_element_type_2408 = None + view_2026 = torch.ops.aten.view.default(cat_339, [2, 4096, 3072]); cat_339 = None + view_2027 = torch.ops.aten.view.default(view_2026, [8192, 3072]); view_2026 = None + permute_1052 = torch.ops.aten.permute.default(view_2027, [1, 0]) + mm_424 = torch.ops.aten.mm.default(permute_1052, view_907); permute_1052 = view_907 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_742, 128, '0'); convert_element_type_742 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + permute_1054 = torch.ops.aten.permute.default(permute_206, [1, 0]); permute_206 = None + mm_425 = torch.ops.aten.mm.default(view_2027, permute_1054); view_2027 = permute_1054 = None + view_2028 = torch.ops.aten.view.default(mm_425, [2, 4096, 2048]); mm_425 = None + add_1968 = torch.ops.aten.add.Tensor(view_2023, view_2028); view_2023 = view_2028 = None + convert_element_type_2413 = torch.ops.prims.convert_element_type.default(mm_424, torch.float32); mm_424 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2413, 'avg', 128, '0'); convert_element_type_2413 = None + wait_tensor_767 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + convert_element_type_2414 = torch.ops.prims.convert_element_type.default(add_1968, torch.float32); add_1968 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_739, 128, '0'); convert_element_type_739 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_2416 = torch.ops.prims.convert_element_type.default(wait_tensor_284, torch.float32); wait_tensor_284 = None + mul_1727 = torch.ops.aten.mul.Tensor(convert_element_type_2414, convert_element_type_2416); convert_element_type_2416 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(add_889, torch.float32); add_889 = None + mul_646 = torch.ops.aten.mul.Tensor(convert_element_type_740, rsqrt_42); convert_element_type_740 = None + mul_1729 = torch.ops.aten.mul.Tensor(mul_646, mul_1727) + sum_209 = torch.ops.aten.sum.dim_IntList(mul_1729, [2], True); mul_1729 = None + div_209 = torch.ops.aten.div.Tensor(mul_646, 2048) + mul_1730 = torch.ops.aten.mul.Tensor(div_209, sum_209); div_209 = sum_209 = None + sub_702 = torch.ops.aten.sub.Tensor(mul_1727, mul_1730); mul_1727 = mul_1730 = None + mul_1731 = torch.ops.aten.mul.Tensor(sub_702, rsqrt_42); sub_702 = rsqrt_42 = None + mul_1732 = torch.ops.aten.mul.Tensor(convert_element_type_2414, mul_646); convert_element_type_2414 = mul_646 = None + sum_210 = torch.ops.aten.sum.dim_IntList(mul_1732, [0, 1]); mul_1732 = None + convert_element_type_2417 = torch.ops.prims.convert_element_type.default(mul_1731, torch.bfloat16); mul_1731 = None + add_1969 = torch.ops.aten.add.Tensor(add_1967, convert_element_type_2417); add_1967 = convert_element_type_2417 = None + convert_element_type_default_43 = torch.ops.prims.convert_element_type.default(sum_210, torch.float32); sum_210 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_43, 'avg', 128, '0'); convert_element_type_default_43 = None + wait_tensor_768 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + view_2029 = torch.ops.aten.view.default(add_1969, [8192, 2048]) + unsqueeze_66 = torch.ops.aten.unsqueeze.default(view_2029, 1) + convert_element_type_2420 = torch.ops.prims.convert_element_type.default(unsqueeze_66, torch.float32); unsqueeze_66 = None + bmm_52 = torch.ops.aten.bmm.default(permute_1056, convert_element_type_2420); permute_1056 = None + bmm_53 = torch.ops.aten.bmm.default(convert_element_type_2420, permute_1057); convert_element_type_2420 = permute_1057 = None + convert_element_type_2421 = torch.ops.prims.convert_element_type.default(bmm_52, torch.bfloat16); bmm_52 = None + view_2030 = torch.ops.aten.view.default(bmm_53, [8192, 6]); bmm_53 = None + view_2031 = torch.ops.aten.view.default(convert_element_type_2421, [49152, 2048]); convert_element_type_2421 = None + index_78 = torch.ops.aten.index.Tensor(view_2031, [getitem_1341]); view_2031 = getitem_1341 = None + permute_1058 = torch.ops.aten.permute.default(view_2029, [1, 0]) + mm_426 = torch.ops.aten.mm.default(permute_1058, mul_643); permute_1058 = mul_643 = None + convert_element_type_734 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_734, 128, '0'); convert_element_type_734 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + permute_1060 = torch.ops.aten.permute.default(permute_205, [1, 0]); permute_205 = None + mm_427 = torch.ops.aten.mm.default(view_2029, permute_1060); view_2029 = permute_1060 = None + convert_element_type_2426 = torch.ops.prims.convert_element_type.default(mm_426, torch.float32); mm_426 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2426, 'avg', 128, '0'); convert_element_type_2426 = None + wait_tensor_769 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mm_108, torch.float32); mm_108 = None + neg_26 = torch.ops.aten.neg.default(convert_element_type_729) + exp_39 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_884 = torch.ops.aten.add.Tensor(exp_39, 1); exp_39 = None + div_65 = torch.ops.aten.div.Tensor(convert_element_type_729, add_884) + convert_element_type_730 = torch.ops.prims.convert_element_type.default(div_65, torch.bfloat16); div_65 = None + mul_1733 = torch.ops.aten.mul.Tensor(mm_427, convert_element_type_730); convert_element_type_730 = None + mul_1734 = torch.ops.aten.mul.Tensor(mm_427, mm_109); mm_427 = mm_109 = None + permute_1062 = torch.ops.aten.permute.default(mul_1733, [1, 0]) + mm_428 = torch.ops.aten.mm.default(permute_1062, view_862); permute_1062 = None + convert_element_type_731 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_731, 128, '0'); convert_element_type_731 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_204 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + permute_1064 = torch.ops.aten.permute.default(permute_204, [1, 0]); permute_204 = None + mm_429 = torch.ops.aten.mm.default(mul_1733, permute_1064); mul_1733 = permute_1064 = None + convert_element_type_2431 = torch.ops.prims.convert_element_type.default(mm_428, torch.float32); mm_428 = None + reduce_scatter_tensor_185 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2431, 'avg', 128, '0'); convert_element_type_2431 = None + wait_tensor_770 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + convert_element_type_2432 = torch.ops.prims.convert_element_type.default(mul_1734, torch.float32); mul_1734 = None + reciprocal_26 = torch.ops.aten.reciprocal.default(add_884); add_884 = None + mul_1735 = torch.ops.aten.mul.Tensor(reciprocal_26, 1); reciprocal_26 = None + mul_1736 = torch.ops.aten.mul.Tensor(convert_element_type_2432, mul_1735); convert_element_type_2432 = None + sub_703 = torch.ops.aten.sub.Tensor(1, mul_1735); mul_1735 = None + mul_1737 = torch.ops.aten.mul.Tensor(convert_element_type_729, sub_703); convert_element_type_729 = sub_703 = None + add_1971 = torch.ops.aten.add.Tensor(mul_1737, 1); mul_1737 = None + mul_1738 = torch.ops.aten.mul.Tensor(mul_1736, add_1971); mul_1736 = add_1971 = None + convert_element_type_2434 = torch.ops.prims.convert_element_type.default(mul_1738, torch.bfloat16); mul_1738 = None + permute_1066 = torch.ops.aten.permute.default(convert_element_type_2434, [1, 0]) + mm_430 = torch.ops.aten.mm.default(permute_1066, view_862); permute_1066 = None + convert_element_type_726 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_726, 128, '0'); convert_element_type_726 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_203 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + permute_1068 = torch.ops.aten.permute.default(permute_203, [1, 0]); permute_203 = None + mm_431 = torch.ops.aten.mm.default(convert_element_type_2434, permute_1068); convert_element_type_2434 = permute_1068 = None + add_1972 = torch.ops.aten.add.Tensor(mm_429, mm_431); mm_429 = mm_431 = None + convert_element_type_2439 = torch.ops.prims.convert_element_type.default(mm_430, torch.float32); mm_430 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2439, 'avg', 128, '0'); convert_element_type_2439 = None + wait_tensor_771 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + all_to_all_single_104 = torch.ops._c10d_functional.all_to_all_single.default(index_78, [_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207], [_local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199], '1033'); index_78 = None + wait_tensor_772 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_104); all_to_all_single_104 = None + full_426 = torch.ops.aten.full.default([sym_size_int_49, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_49 = None + slice_scatter_13 = torch.ops.aten.slice_scatter.default(full_426, wait_tensor_772, 0, 0, -1); wait_tensor_772 = None + index_79 = torch.ops.aten.index.Tensor(slice_scatter_13, [getitem_1342]); slice_scatter_13 = None + permute_1070 = torch.ops.aten.permute.default(index_79, [1, 0]) + _grouped_mm_156 = torch.ops.aten._grouped_mm.default(permute_1070, mul_623, cumsum_38); permute_1070 = mul_623 = None + _grouped_mm_157 = torch.ops.aten._grouped_mm.default(index_79, permute_1072, cumsum_38); index_79 = permute_1072 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(_grouped_mm_36, torch.float32); _grouped_mm_36 = None + neg_25 = torch.ops.aten.neg.default(convert_element_type_724) + exp_38 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_848 = torch.ops.aten.add.Tensor(exp_38, 1); exp_38 = None + div_64 = torch.ops.aten.div.Tensor(convert_element_type_724, add_848) + convert_element_type_725 = torch.ops.prims.convert_element_type.default(div_64, torch.bfloat16); div_64 = None + mul_1739 = torch.ops.aten.mul.Tensor(_grouped_mm_157, convert_element_type_725); convert_element_type_725 = None + mul_1740 = torch.ops.aten.mul.Tensor(_grouped_mm_157, _grouped_mm_37); _grouped_mm_157 = _grouped_mm_37 = None + permute_1074 = torch.ops.aten.permute.default(mul_1739, [1, 0]) + _grouped_mm_158 = torch.ops.aten._grouped_mm.default(permute_1074, index_25, cumsum_38); permute_1074 = None + _grouped_mm_159 = torch.ops.aten._grouped_mm.default(mul_1739, permute_1076, cumsum_38); mul_1739 = permute_1076 = None + convert_element_type_2440 = torch.ops.prims.convert_element_type.default(mul_1740, torch.float32); mul_1740 = None + reciprocal_27 = torch.ops.aten.reciprocal.default(add_848); add_848 = None + mul_1741 = torch.ops.aten.mul.Tensor(reciprocal_27, 1); reciprocal_27 = None + mul_1742 = torch.ops.aten.mul.Tensor(convert_element_type_2440, mul_1741); convert_element_type_2440 = None + sub_704 = torch.ops.aten.sub.Tensor(1, mul_1741); mul_1741 = None + mul_1743 = torch.ops.aten.mul.Tensor(convert_element_type_724, sub_704); convert_element_type_724 = sub_704 = None + add_1974 = torch.ops.aten.add.Tensor(mul_1743, 1); mul_1743 = None + mul_1744 = torch.ops.aten.mul.Tensor(mul_1742, add_1974); mul_1742 = add_1974 = None + convert_element_type_2442 = torch.ops.prims.convert_element_type.default(mul_1744, torch.bfloat16); mul_1744 = None + permute_1078 = torch.ops.aten.permute.default(convert_element_type_2442, [1, 0]) + _grouped_mm_160 = torch.ops.aten._grouped_mm.default(permute_1078, index_25, cumsum_38); permute_1078 = index_25 = None + _grouped_mm_161 = torch.ops.aten._grouped_mm.default(convert_element_type_2442, permute_1080, cumsum_38); convert_element_type_2442 = permute_1080 = cumsum_38 = None + add_1975 = torch.ops.aten.add.Tensor(_grouped_mm_159, _grouped_mm_161); _grouped_mm_159 = _grouped_mm_161 = None + convert_element_type_2443 = torch.ops.prims.convert_element_type.default(_grouped_mm_158, torch.float32); _grouped_mm_158 = None + div_210 = torch.ops.aten.div.Tensor(convert_element_type_2443, 128); convert_element_type_2443 = None + split_846 = torch.ops.aten.split.Tensor(div_210, 88, 1); div_210 = None + getitem_15885 = split_846[0] + getitem_15902 = split_846[1] + getitem_15919 = split_846[2] + getitem_15936 = split_846[3] + getitem_15953 = split_846[4] + getitem_15970 = split_846[5] + getitem_15987 = split_846[6] + getitem_16004 = split_846[7] + getitem_16021 = split_846[8] + getitem_16038 = split_846[9] + getitem_16055 = split_846[10] + getitem_16072 = split_846[11] + getitem_16089 = split_846[12] + getitem_16106 = split_846[13] + getitem_16123 = split_846[14] + getitem_16140 = split_846[15]; split_846 = None + cat_340 = torch.ops.aten.cat.default([getitem_15885, getitem_15902, getitem_15919, getitem_15936, getitem_15953, getitem_15970, getitem_15987, getitem_16004, getitem_16021, getitem_16038, getitem_16055, getitem_16072, getitem_16089, getitem_16106, getitem_16123, getitem_16140]); getitem_15885 = getitem_15902 = getitem_15919 = getitem_15936 = getitem_15953 = getitem_15970 = getitem_15987 = getitem_16004 = getitem_16021 = getitem_16038 = getitem_16055 = getitem_16072 = getitem_16089 = getitem_16106 = getitem_16123 = getitem_16140 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_340, 'sum', 16, '1025'); cat_340 = None + wait_tensor_773 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + convert_element_type_2444 = torch.ops.prims.convert_element_type.default(_grouped_mm_156, torch.float32); _grouped_mm_156 = None + div_211 = torch.ops.aten.div.Tensor(convert_element_type_2444, 128); convert_element_type_2444 = None + split_863 = torch.ops.aten.split.Tensor(div_211, 128, 1); div_211 = None + getitem_16157 = split_863[0] + getitem_16174 = split_863[1] + getitem_16191 = split_863[2] + getitem_16208 = split_863[3] + getitem_16225 = split_863[4] + getitem_16242 = split_863[5] + getitem_16259 = split_863[6] + getitem_16276 = split_863[7] + getitem_16293 = split_863[8] + getitem_16310 = split_863[9] + getitem_16327 = split_863[10] + getitem_16344 = split_863[11] + getitem_16361 = split_863[12] + getitem_16378 = split_863[13] + getitem_16395 = split_863[14] + getitem_16412 = split_863[15]; split_863 = None + cat_341 = torch.ops.aten.cat.default([getitem_16157, getitem_16174, getitem_16191, getitem_16208, getitem_16225, getitem_16242, getitem_16259, getitem_16276, getitem_16293, getitem_16310, getitem_16327, getitem_16344, getitem_16361, getitem_16378, getitem_16395, getitem_16412]); getitem_16157 = getitem_16174 = getitem_16191 = getitem_16208 = getitem_16225 = getitem_16242 = getitem_16259 = getitem_16276 = getitem_16293 = getitem_16310 = getitem_16327 = getitem_16344 = getitem_16361 = getitem_16378 = getitem_16395 = getitem_16412 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_341, 'sum', 16, '1025'); cat_341 = None + wait_tensor_774 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + convert_element_type_2445 = torch.ops.prims.convert_element_type.default(_grouped_mm_160, torch.float32); _grouped_mm_160 = None + div_212 = torch.ops.aten.div.Tensor(convert_element_type_2445, 128); convert_element_type_2445 = None + split_880 = torch.ops.aten.split.Tensor(div_212, 88, 1); div_212 = None + getitem_16429 = split_880[0] + getitem_16446 = split_880[1] + getitem_16463 = split_880[2] + getitem_16480 = split_880[3] + getitem_16497 = split_880[4] + getitem_16514 = split_880[5] + getitem_16531 = split_880[6] + getitem_16548 = split_880[7] + getitem_16565 = split_880[8] + getitem_16582 = split_880[9] + getitem_16599 = split_880[10] + getitem_16616 = split_880[11] + getitem_16633 = split_880[12] + getitem_16650 = split_880[13] + getitem_16667 = split_880[14] + getitem_16684 = split_880[15]; split_880 = None + cat_342 = torch.ops.aten.cat.default([getitem_16429, getitem_16446, getitem_16463, getitem_16480, getitem_16497, getitem_16514, getitem_16531, getitem_16548, getitem_16565, getitem_16582, getitem_16599, getitem_16616, getitem_16633, getitem_16650, getitem_16667, getitem_16684]); getitem_16429 = getitem_16446 = getitem_16463 = getitem_16480 = getitem_16497 = getitem_16514 = getitem_16531 = getitem_16548 = getitem_16565 = getitem_16582 = getitem_16599 = getitem_16616 = getitem_16633 = getitem_16650 = getitem_16667 = getitem_16684 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_342, 'sum', 16, '1025'); cat_342 = None + wait_tensor_775 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + index_put_78 = torch.ops.aten.index_put.default(full_426, [getitem_1342], add_1975, True); full_426 = getitem_1342 = add_1975 = None + slice_240 = torch.ops.aten.slice.Tensor(index_put_78, 0, 0, add_1976); index_put_78 = add_1976 = None + all_to_all_single_105 = torch.ops._c10d_functional.all_to_all_single.default(slice_240, [_local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199], [_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207], '1033'); slice_240 = _local_scalar_dense_192 = _local_scalar_dense_193 = _local_scalar_dense_194 = _local_scalar_dense_195 = _local_scalar_dense_196 = _local_scalar_dense_197 = _local_scalar_dense_198 = _local_scalar_dense_199 = _local_scalar_dense_200 = _local_scalar_dense_201 = _local_scalar_dense_202 = _local_scalar_dense_203 = _local_scalar_dense_204 = _local_scalar_dense_205 = _local_scalar_dense_206 = _local_scalar_dense_207 = None + wait_tensor_776 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_105); all_to_all_single_105 = None + index_put_79 = torch.ops.aten.index_put.default(full_default_52, [div_62], wait_tensor_776, True); div_62 = wait_tensor_776 = None + add_1980 = torch.ops.aten.add.Tensor(add_1972, index_put_79); add_1972 = index_put_79 = None + mul_1745 = torch.ops.aten.mul.Tensor(view_2030, 1.0); view_2030 = None + scatter_add_13 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1339, mul_1745); getitem_1339 = mul_1745 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(mm_107, torch.float32); mm_107 = None + sub_288 = torch.ops.aten.sub.Tensor(convert_element_type_713, amax_12); convert_element_type_713 = amax_12 = None + exp_37 = torch.ops.aten.exp.default(sub_288); sub_288 = None + div_61 = torch.ops.aten.div.Tensor(exp_37, sum_49); exp_37 = sum_49 = None + mul_1746 = torch.ops.aten.mul.Tensor(scatter_add_13, div_61); scatter_add_13 = None + sum_211 = torch.ops.aten.sum.dim_IntList(mul_1746, [1], True) + neg_94 = torch.ops.aten.neg.default(div_61); div_61 = None + fma_13 = torch.ops.prims.fma.default(neg_94, sum_211, mul_1746); neg_94 = sum_211 = mul_1746 = None + convert_element_type_2446 = torch.ops.prims.convert_element_type.default(fma_13, torch.bfloat16); fma_13 = None + permute_1082 = torch.ops.aten.permute.default(convert_element_type_2446, [1, 0]) + mm_432 = torch.ops.aten.mm.default(permute_1082, view_862); permute_1082 = view_862 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16); primals_223 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 128, '0'); convert_element_type_710 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + slice_81 = torch.ops.aten.slice.Tensor(wait_tensor_270, 0, 0, 64); wait_tensor_270 = None + permute_199 = torch.ops.aten.permute.default(slice_81, [1, 0]); slice_81 = None + permute_1084 = torch.ops.aten.permute.default(permute_199, [1, 0]); permute_199 = None + mm_433 = torch.ops.aten.mm.default(convert_element_type_2446, permute_1084); convert_element_type_2446 = permute_1084 = None + add_1981 = torch.ops.aten.add.Tensor(add_1980, mm_433); add_1980 = mm_433 = None + convert_element_type_2451 = torch.ops.prims.convert_element_type.default(mm_432, torch.float32); mm_432 = None + split_896 = torch.ops.aten.split.Tensor(convert_element_type_2451, 1); convert_element_type_2451 = None + getitem_16685 = split_896[0] + getitem_16686 = split_896[1] + getitem_16687 = split_896[2] + getitem_16688 = split_896[3] + getitem_16689 = split_896[4] + getitem_16690 = split_896[5] + getitem_16691 = split_896[6] + getitem_16692 = split_896[7] + getitem_16693 = split_896[8] + getitem_16694 = split_896[9] + getitem_16695 = split_896[10] + getitem_16696 = split_896[11] + getitem_16697 = split_896[12] + getitem_16698 = split_896[13] + getitem_16699 = split_896[14] + getitem_16700 = split_896[15] + getitem_16701 = split_896[16] + getitem_16702 = split_896[17] + getitem_16703 = split_896[18] + getitem_16704 = split_896[19] + getitem_16705 = split_896[20] + getitem_16706 = split_896[21] + getitem_16707 = split_896[22] + getitem_16708 = split_896[23] + getitem_16709 = split_896[24] + getitem_16710 = split_896[25] + getitem_16711 = split_896[26] + getitem_16712 = split_896[27] + getitem_16713 = split_896[28] + getitem_16714 = split_896[29] + getitem_16715 = split_896[30] + getitem_16716 = split_896[31] + getitem_16717 = split_896[32] + getitem_16718 = split_896[33] + getitem_16719 = split_896[34] + getitem_16720 = split_896[35] + getitem_16721 = split_896[36] + getitem_16722 = split_896[37] + getitem_16723 = split_896[38] + getitem_16724 = split_896[39] + getitem_16725 = split_896[40] + getitem_16726 = split_896[41] + getitem_16727 = split_896[42] + getitem_16728 = split_896[43] + getitem_16729 = split_896[44] + getitem_16730 = split_896[45] + getitem_16731 = split_896[46] + getitem_16732 = split_896[47] + getitem_16733 = split_896[48] + getitem_16734 = split_896[49] + getitem_16735 = split_896[50] + getitem_16736 = split_896[51] + getitem_16737 = split_896[52] + getitem_16738 = split_896[53] + getitem_16739 = split_896[54] + getitem_16740 = split_896[55] + getitem_16741 = split_896[56] + getitem_16742 = split_896[57] + getitem_16743 = split_896[58] + getitem_16744 = split_896[59] + getitem_16745 = split_896[60] + getitem_16746 = split_896[61] + getitem_16747 = split_896[62] + getitem_16748 = split_896[63]; split_896 = None + cat_343 = torch.ops.aten.cat.default([getitem_16685, getitem_16686, getitem_16687, getitem_16688, getitem_16689, getitem_16690, getitem_16691, getitem_16692, getitem_16693, getitem_16694, getitem_16695, getitem_16696, getitem_16697, getitem_16698, getitem_16699, getitem_16700, getitem_16701, getitem_16702, getitem_16703, getitem_16704, getitem_16705, getitem_16706, getitem_16707, getitem_16708, getitem_16709, getitem_16710, getitem_16711, getitem_16712, getitem_16713, getitem_16714, getitem_16715, getitem_16716, getitem_16717, getitem_16718, getitem_16719, getitem_16720, getitem_16721, getitem_16722, getitem_16723, getitem_16724, getitem_16725, getitem_16726, getitem_16727, getitem_16728, getitem_16729, getitem_16730, getitem_16731, getitem_16732, getitem_16733, getitem_16734, getitem_16735, getitem_16736, getitem_16737, getitem_16738, getitem_16739, getitem_16740, getitem_16741, getitem_16742, getitem_16743, getitem_16744, getitem_16745, getitem_16746, getitem_16747, getitem_16748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_16685 = getitem_16686 = getitem_16687 = getitem_16688 = getitem_16689 = getitem_16690 = getitem_16691 = getitem_16692 = getitem_16693 = getitem_16694 = getitem_16695 = getitem_16696 = getitem_16697 = getitem_16698 = getitem_16699 = getitem_16700 = getitem_16701 = getitem_16702 = getitem_16703 = getitem_16704 = getitem_16705 = getitem_16706 = getitem_16707 = getitem_16708 = getitem_16709 = getitem_16710 = getitem_16711 = getitem_16712 = getitem_16713 = getitem_16714 = getitem_16715 = getitem_16716 = getitem_16717 = getitem_16718 = getitem_16719 = getitem_16720 = getitem_16721 = getitem_16722 = getitem_16723 = getitem_16724 = getitem_16725 = getitem_16726 = getitem_16727 = getitem_16728 = getitem_16729 = getitem_16730 = getitem_16731 = getitem_16732 = getitem_16733 = getitem_16734 = getitem_16735 = getitem_16736 = getitem_16737 = getitem_16738 = getitem_16739 = getitem_16740 = getitem_16741 = getitem_16742 = getitem_16743 = getitem_16744 = getitem_16745 = getitem_16746 = getitem_16747 = getitem_16748 = None + reduce_scatter_tensor_190 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_343, 'avg', 128, '0'); cat_343 = None + wait_tensor_777 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + view_2032 = torch.ops.aten.view.default(add_1981, [2, 4096, 2048]); add_1981 = None + convert_element_type_2452 = torch.ops.prims.convert_element_type.default(view_2032, torch.float32); view_2032 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16); primals_221 = None + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_707, 128, '0'); convert_element_type_707 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_2454 = torch.ops.prims.convert_element_type.default(wait_tensor_269, torch.float32); wait_tensor_269 = None + mul_1747 = torch.ops.aten.mul.Tensor(convert_element_type_2452, convert_element_type_2454); convert_element_type_2454 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(add_824, torch.float32); add_824 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_708, rsqrt_41); convert_element_type_708 = None + mul_1749 = torch.ops.aten.mul.Tensor(mul_603, mul_1747) + sum_212 = torch.ops.aten.sum.dim_IntList(mul_1749, [2], True); mul_1749 = None + div_213 = torch.ops.aten.div.Tensor(mul_603, 2048) + mul_1750 = torch.ops.aten.mul.Tensor(div_213, sum_212); div_213 = sum_212 = None + sub_706 = torch.ops.aten.sub.Tensor(mul_1747, mul_1750); mul_1747 = mul_1750 = None + mul_1751 = torch.ops.aten.mul.Tensor(sub_706, rsqrt_41); sub_706 = rsqrt_41 = None + mul_1752 = torch.ops.aten.mul.Tensor(convert_element_type_2452, mul_603); convert_element_type_2452 = mul_603 = None + sum_213 = torch.ops.aten.sum.dim_IntList(mul_1752, [0, 1]); mul_1752 = None + convert_element_type_2455 = torch.ops.prims.convert_element_type.default(mul_1751, torch.bfloat16); mul_1751 = None + add_1982 = torch.ops.aten.add.Tensor(add_1969, convert_element_type_2455); add_1969 = convert_element_type_2455 = None + convert_element_type_default_42 = torch.ops.prims.convert_element_type.default(sum_213, torch.float32); sum_213 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_42, 'avg', 128, '0'); convert_element_type_default_42 = None + wait_tensor_778 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + view_2033 = torch.ops.aten.view.default(add_1982, [8192, 2048]) + permute_1086 = torch.ops.aten.permute.default(view_2033, [1, 0]) + permute_197 = torch.ops.aten.permute.default(getitem_1335, [0, 2, 1, 3]) + view_857 = torch.ops.aten.view.default(permute_197, [2, 4096, -1]); permute_197 = None + view_859 = torch.ops.aten.view.default(view_857, [8192, 2048]); view_857 = None + mm_434 = torch.ops.aten.mm.default(permute_1086, view_859); permute_1086 = view_859 = None + convert_element_type_704 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_704, 128, '0'); convert_element_type_704 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + permute_1088 = torch.ops.aten.permute.default(permute_198, [1, 0]); permute_198 = None + mm_435 = torch.ops.aten.mm.default(view_2033, permute_1088); view_2033 = permute_1088 = None + view_2034 = torch.ops.aten.view.default(mm_435, [2, 4096, 2048]); mm_435 = None + convert_element_type_2462 = torch.ops.prims.convert_element_type.default(mm_434, torch.float32); mm_434 = None + reduce_scatter_tensor_192 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2462, 'avg', 128, '0'); convert_element_type_2462 = None + wait_tensor_779 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + view_2035 = torch.ops.aten.view.default(view_2034, [2, 4096, 16, 128]); view_2034 = None + permute_1090 = torch.ops.aten.permute.default(view_2035, [0, 2, 1, 3]); view_2035 = None + fw_graph13 = self.fw_graph13 + joint_graph13 = self.joint_graph13 + mask_graph13 = self.mask_graph13 + flex_attention_backward_13 = torch.ops.higher_order.flex_attention_backward(permute_194, permute_195, permute_196, getitem_1335, getitem_1336, permute_1090, None, fw_graph13, joint_graph13, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph13), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_194 = permute_195 = permute_196 = getitem_1335 = getitem_1336 = permute_1090 = fw_graph13 = joint_graph13 = mask_graph13 = None + getitem_16749 = flex_attention_backward_13[0] + getitem_16750 = flex_attention_backward_13[1] + getitem_16751 = flex_attention_backward_13[2]; flex_attention_backward_13 = None + permute_1091 = torch.ops.aten.permute.default(getitem_16751, [0, 2, 1, 3]); getitem_16751 = None + permute_1092 = torch.ops.aten.permute.default(getitem_16750, [0, 2, 1, 3]); getitem_16750 = None + permute_1093 = torch.ops.aten.permute.default(getitem_16749, [0, 2, 1, 3]); getitem_16749 = None + slice_242 = torch.ops.aten.slice.Tensor(permute_1092, 3, 0, 128) + slice_243 = torch.ops.aten.slice.Tensor(permute_1092, 3, 128, 192); permute_1092 = None + sum_214 = torch.ops.aten.sum.dim_IntList(slice_243, [2], True); slice_243 = None + cat_344 = torch.ops.aten.cat.default([slice_242, permute_1091], 3); slice_242 = permute_1091 = None + view_2036 = torch.ops.aten.view.default(cat_344, [2, 4096, 4096]); cat_344 = None + view_2037 = torch.ops.aten.view.default(view_2036, [8192, 4096]); view_2036 = None + permute_1094 = torch.ops.aten.permute.default(view_2037, [1, 0]) + mm_436 = torch.ops.aten.mm.default(permute_1094, view_854); permute_1094 = view_854 = None + convert_element_type_701 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_701, 128, '0'); convert_element_type_701 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_193 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + permute_1096 = torch.ops.aten.permute.default(permute_193, [1, 0]); permute_193 = None + mm_437 = torch.ops.aten.mm.default(view_2037, permute_1096); view_2037 = permute_1096 = None + view_2038 = torch.ops.aten.view.default(mm_437, [2, 4096, 512]); mm_437 = None + convert_element_type_2467 = torch.ops.prims.convert_element_type.default(mm_436, torch.float32); mm_436 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2467, 'avg', 128, '0'); convert_element_type_2467 = None + wait_tensor_780 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + convert_element_type_2468 = torch.ops.prims.convert_element_type.default(view_2038, torch.float32); view_2038 = None + convert_element_type_698 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16); primals_218 = None + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_698, 128, '0'); convert_element_type_698 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + convert_element_type_2470 = torch.ops.prims.convert_element_type.default(wait_tensor_266, torch.float32); wait_tensor_266 = None + mul_1753 = torch.ops.aten.mul.Tensor(convert_element_type_2468, convert_element_type_2470); convert_element_type_2470 = None + convert_element_type_699 = torch.ops.prims.convert_element_type.default(getitem_1331, torch.float32); getitem_1331 = None + mul_601 = torch.ops.aten.mul.Tensor(convert_element_type_699, rsqrt_40); convert_element_type_699 = None + mul_1755 = torch.ops.aten.mul.Tensor(mul_601, mul_1753) + sum_215 = torch.ops.aten.sum.dim_IntList(mul_1755, [2], True); mul_1755 = None + div_214 = torch.ops.aten.div.Tensor(mul_601, 512) + mul_1756 = torch.ops.aten.mul.Tensor(div_214, sum_215); div_214 = sum_215 = None + sub_707 = torch.ops.aten.sub.Tensor(mul_1753, mul_1756); mul_1753 = mul_1756 = None + mul_1757 = torch.ops.aten.mul.Tensor(sub_707, rsqrt_40); sub_707 = rsqrt_40 = None + mul_1758 = torch.ops.aten.mul.Tensor(convert_element_type_2468, mul_601); convert_element_type_2468 = mul_601 = None + sum_216 = torch.ops.aten.sum.dim_IntList(mul_1758, [0, 1]); mul_1758 = None + convert_element_type_2471 = torch.ops.prims.convert_element_type.default(mul_1757, torch.bfloat16); mul_1757 = None + convert_element_type_default_41 = torch.ops.prims.convert_element_type.default(sum_216, torch.float32); sum_216 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_41, 'avg', 128, '0'); convert_element_type_default_41 = None + wait_tensor_781 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + convert_element_type_2474 = torch.ops.prims.convert_element_type.default(sum_214, torch.float32); sum_214 = None + view_2039 = torch.ops.aten.view.default(convert_element_type_2474, [2, 4096, 1, 32, 2]); convert_element_type_2474 = None + view_as_complex_80 = torch.ops.aten.view_as_complex.default(view_2039); view_2039 = None + mul_1759 = torch.ops.aten.mul.Tensor(view_as_complex_80, clone_9); view_as_complex_80 = None + view_as_real_80 = torch.ops.aten.view_as_real.default(mul_1759); mul_1759 = None + view_2040 = torch.ops.aten.view.default(view_as_real_80, [2, 4096, 1, 64]); view_as_real_80 = None + convert_element_type_2475 = torch.ops.prims.convert_element_type.default(view_2040, torch.bfloat16); view_2040 = None + squeeze_39 = torch.ops.aten.squeeze.dim(convert_element_type_2475, 2); convert_element_type_2475 = None + cat_345 = torch.ops.aten.cat.default([convert_element_type_2471, squeeze_39], 2); convert_element_type_2471 = squeeze_39 = None + view_2041 = torch.ops.aten.view.default(cat_345, [8192, 576]); cat_345 = None + permute_1098 = torch.ops.aten.permute.default(view_2041, [1, 0]) + mm_438 = torch.ops.aten.mm.default(permute_1098, view_840); permute_1098 = None + convert_element_type_693 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16); primals_217 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_693, 128, '0'); convert_element_type_693 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + slice_79 = torch.ops.aten.slice.Tensor(wait_tensor_265, 0, 0, 576); wait_tensor_265 = None + permute_192 = torch.ops.aten.permute.default(slice_79, [1, 0]); slice_79 = None + permute_1100 = torch.ops.aten.permute.default(permute_192, [1, 0]); permute_192 = None + mm_439 = torch.ops.aten.mm.default(view_2041, permute_1100); view_2041 = permute_1100 = None + view_2042 = torch.ops.aten.view.default(mm_439, [2, 4096, 2048]); mm_439 = None + convert_element_type_2480 = torch.ops.prims.convert_element_type.default(mm_438, torch.float32); mm_438 = None + split_897 = torch.ops.aten.split.Tensor(convert_element_type_2480, 5); convert_element_type_2480 = None + getitem_16753 = split_897[0] + getitem_16754 = split_897[1] + getitem_16755 = split_897[2] + getitem_16756 = split_897[3] + getitem_16757 = split_897[4] + getitem_16758 = split_897[5] + getitem_16759 = split_897[6] + getitem_16760 = split_897[7] + getitem_16761 = split_897[8] + getitem_16762 = split_897[9] + getitem_16763 = split_897[10] + getitem_16764 = split_897[11] + getitem_16765 = split_897[12] + getitem_16766 = split_897[13] + getitem_16767 = split_897[14] + getitem_16768 = split_897[15] + getitem_16769 = split_897[16] + getitem_16770 = split_897[17] + getitem_16771 = split_897[18] + getitem_16772 = split_897[19] + getitem_16773 = split_897[20] + getitem_16774 = split_897[21] + getitem_16775 = split_897[22] + getitem_16776 = split_897[23] + getitem_16777 = split_897[24] + getitem_16778 = split_897[25] + getitem_16779 = split_897[26] + getitem_16780 = split_897[27] + getitem_16781 = split_897[28] + getitem_16782 = split_897[29] + getitem_16783 = split_897[30] + getitem_16784 = split_897[31] + getitem_16785 = split_897[32] + getitem_16786 = split_897[33] + getitem_16787 = split_897[34] + getitem_16788 = split_897[35] + getitem_16789 = split_897[36] + getitem_16790 = split_897[37] + getitem_16791 = split_897[38] + getitem_16792 = split_897[39] + getitem_16793 = split_897[40] + getitem_16794 = split_897[41] + getitem_16795 = split_897[42] + getitem_16796 = split_897[43] + getitem_16797 = split_897[44] + getitem_16798 = split_897[45] + getitem_16799 = split_897[46] + getitem_16800 = split_897[47] + getitem_16801 = split_897[48] + getitem_16802 = split_897[49] + getitem_16803 = split_897[50] + getitem_16804 = split_897[51] + getitem_16805 = split_897[52] + getitem_16806 = split_897[53] + getitem_16807 = split_897[54] + getitem_16808 = split_897[55] + getitem_16809 = split_897[56] + getitem_16810 = split_897[57] + getitem_16811 = split_897[58] + getitem_16812 = split_897[59] + getitem_16813 = split_897[60] + getitem_16814 = split_897[61] + getitem_16815 = split_897[62] + getitem_16816 = split_897[63] + getitem_16817 = split_897[64] + getitem_16818 = split_897[65] + getitem_16819 = split_897[66] + getitem_16820 = split_897[67] + getitem_16821 = split_897[68] + getitem_16822 = split_897[69] + getitem_16823 = split_897[70] + getitem_16824 = split_897[71] + getitem_16825 = split_897[72] + getitem_16826 = split_897[73] + getitem_16827 = split_897[74] + getitem_16828 = split_897[75] + getitem_16829 = split_897[76] + getitem_16830 = split_897[77] + getitem_16831 = split_897[78] + getitem_16832 = split_897[79] + getitem_16833 = split_897[80] + getitem_16834 = split_897[81] + getitem_16835 = split_897[82] + getitem_16836 = split_897[83] + getitem_16837 = split_897[84] + getitem_16838 = split_897[85] + getitem_16839 = split_897[86] + getitem_16840 = split_897[87] + getitem_16841 = split_897[88] + getitem_16842 = split_897[89] + getitem_16843 = split_897[90] + getitem_16844 = split_897[91] + getitem_16845 = split_897[92] + getitem_16846 = split_897[93] + getitem_16847 = split_897[94] + getitem_16848 = split_897[95] + getitem_16849 = split_897[96] + getitem_16850 = split_897[97] + getitem_16851 = split_897[98] + getitem_16852 = split_897[99] + getitem_16853 = split_897[100] + getitem_16854 = split_897[101] + getitem_16855 = split_897[102] + getitem_16856 = split_897[103] + getitem_16857 = split_897[104] + getitem_16858 = split_897[105] + getitem_16859 = split_897[106] + getitem_16860 = split_897[107] + getitem_16861 = split_897[108] + getitem_16862 = split_897[109] + getitem_16863 = split_897[110] + getitem_16864 = split_897[111] + getitem_16865 = split_897[112] + getitem_16866 = split_897[113] + getitem_16867 = split_897[114] + getitem_16868 = split_897[115]; split_897 = None + constant_pad_nd_1065 = torch.ops.aten.constant_pad_nd.default(getitem_16868, [0, 0, 0, 4], 0.0); getitem_16868 = None + cat_346 = torch.ops.aten.cat.default([getitem_16753, getitem_16754, getitem_16755, getitem_16756, getitem_16757, getitem_16758, getitem_16759, getitem_16760, getitem_16761, getitem_16762, getitem_16763, getitem_16764, getitem_16765, getitem_16766, getitem_16767, getitem_16768, getitem_16769, getitem_16770, getitem_16771, getitem_16772, getitem_16773, getitem_16774, getitem_16775, getitem_16776, getitem_16777, getitem_16778, getitem_16779, getitem_16780, getitem_16781, getitem_16782, getitem_16783, getitem_16784, getitem_16785, getitem_16786, getitem_16787, getitem_16788, getitem_16789, getitem_16790, getitem_16791, getitem_16792, getitem_16793, getitem_16794, getitem_16795, getitem_16796, getitem_16797, getitem_16798, getitem_16799, getitem_16800, getitem_16801, getitem_16802, getitem_16803, getitem_16804, getitem_16805, getitem_16806, getitem_16807, getitem_16808, getitem_16809, getitem_16810, getitem_16811, getitem_16812, getitem_16813, getitem_16814, getitem_16815, getitem_16816, getitem_16817, getitem_16818, getitem_16819, getitem_16820, getitem_16821, getitem_16822, getitem_16823, getitem_16824, getitem_16825, getitem_16826, getitem_16827, getitem_16828, getitem_16829, getitem_16830, getitem_16831, getitem_16832, getitem_16833, getitem_16834, getitem_16835, getitem_16836, getitem_16837, getitem_16838, getitem_16839, getitem_16840, getitem_16841, getitem_16842, getitem_16843, getitem_16844, getitem_16845, getitem_16846, getitem_16847, getitem_16848, getitem_16849, getitem_16850, getitem_16851, getitem_16852, getitem_16853, getitem_16854, getitem_16855, getitem_16856, getitem_16857, getitem_16858, getitem_16859, getitem_16860, getitem_16861, getitem_16862, getitem_16863, getitem_16864, getitem_16865, getitem_16866, getitem_16867, constant_pad_nd_1065, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_16753 = getitem_16754 = getitem_16755 = getitem_16756 = getitem_16757 = getitem_16758 = getitem_16759 = getitem_16760 = getitem_16761 = getitem_16762 = getitem_16763 = getitem_16764 = getitem_16765 = getitem_16766 = getitem_16767 = getitem_16768 = getitem_16769 = getitem_16770 = getitem_16771 = getitem_16772 = getitem_16773 = getitem_16774 = getitem_16775 = getitem_16776 = getitem_16777 = getitem_16778 = getitem_16779 = getitem_16780 = getitem_16781 = getitem_16782 = getitem_16783 = getitem_16784 = getitem_16785 = getitem_16786 = getitem_16787 = getitem_16788 = getitem_16789 = getitem_16790 = getitem_16791 = getitem_16792 = getitem_16793 = getitem_16794 = getitem_16795 = getitem_16796 = getitem_16797 = getitem_16798 = getitem_16799 = getitem_16800 = getitem_16801 = getitem_16802 = getitem_16803 = getitem_16804 = getitem_16805 = getitem_16806 = getitem_16807 = getitem_16808 = getitem_16809 = getitem_16810 = getitem_16811 = getitem_16812 = getitem_16813 = getitem_16814 = getitem_16815 = getitem_16816 = getitem_16817 = getitem_16818 = getitem_16819 = getitem_16820 = getitem_16821 = getitem_16822 = getitem_16823 = getitem_16824 = getitem_16825 = getitem_16826 = getitem_16827 = getitem_16828 = getitem_16829 = getitem_16830 = getitem_16831 = getitem_16832 = getitem_16833 = getitem_16834 = getitem_16835 = getitem_16836 = getitem_16837 = getitem_16838 = getitem_16839 = getitem_16840 = getitem_16841 = getitem_16842 = getitem_16843 = getitem_16844 = getitem_16845 = getitem_16846 = getitem_16847 = getitem_16848 = getitem_16849 = getitem_16850 = getitem_16851 = getitem_16852 = getitem_16853 = getitem_16854 = getitem_16855 = getitem_16856 = getitem_16857 = getitem_16858 = getitem_16859 = getitem_16860 = getitem_16861 = getitem_16862 = getitem_16863 = getitem_16864 = getitem_16865 = getitem_16866 = getitem_16867 = constant_pad_nd_1065 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_346, 'avg', 128, '0'); cat_346 = None + wait_tensor_782 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + slice_244 = torch.ops.aten.slice.Tensor(permute_1093, 3, 0, 128) + slice_245 = torch.ops.aten.slice.Tensor(permute_1093, 3, 128, 192); permute_1093 = None + convert_element_type_2481 = torch.ops.prims.convert_element_type.default(slice_245, torch.float32); slice_245 = None + view_2043 = torch.ops.aten.view.default(convert_element_type_2481, [2, 4096, 16, 32, 2]); convert_element_type_2481 = None + view_as_complex_81 = torch.ops.aten.view_as_complex.default(view_2043); view_2043 = None + mul_1760 = torch.ops.aten.mul.Tensor(view_as_complex_81, clone_9); view_as_complex_81 = None + view_as_real_81 = torch.ops.aten.view_as_real.default(mul_1760); mul_1760 = None + view_2044 = torch.ops.aten.view.default(view_as_real_81, [2, 4096, 16, 64]); view_as_real_81 = None + convert_element_type_2482 = torch.ops.prims.convert_element_type.default(view_2044, torch.bfloat16); view_2044 = None + cat_347 = torch.ops.aten.cat.default([slice_244, convert_element_type_2482], 3); slice_244 = convert_element_type_2482 = None + view_2045 = torch.ops.aten.view.default(cat_347, [2, 4096, 3072]); cat_347 = None + view_2046 = torch.ops.aten.view.default(view_2045, [8192, 3072]); view_2045 = None + permute_1102 = torch.ops.aten.permute.default(view_2046, [1, 0]) + mm_440 = torch.ops.aten.mm.default(permute_1102, view_840); permute_1102 = view_840 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 128, '0'); convert_element_type_688 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_191 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + permute_1104 = torch.ops.aten.permute.default(permute_191, [1, 0]); permute_191 = None + mm_441 = torch.ops.aten.mm.default(view_2046, permute_1104); view_2046 = permute_1104 = None + view_2047 = torch.ops.aten.view.default(mm_441, [2, 4096, 2048]); mm_441 = None + add_1983 = torch.ops.aten.add.Tensor(view_2042, view_2047); view_2042 = view_2047 = None + convert_element_type_2487 = torch.ops.prims.convert_element_type.default(mm_440, torch.float32); mm_440 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2487, 'avg', 128, '0'); convert_element_type_2487 = None + wait_tensor_783 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + convert_element_type_2488 = torch.ops.prims.convert_element_type.default(add_1983, torch.float32); add_1983 = None + convert_element_type_685 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_685, 128, '0'); convert_element_type_685 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + convert_element_type_2490 = torch.ops.prims.convert_element_type.default(wait_tensor_263, torch.float32); wait_tensor_263 = None + mul_1761 = torch.ops.aten.mul.Tensor(convert_element_type_2488, convert_element_type_2490); convert_element_type_2490 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(add_821, torch.float32); add_821 = None + mul_597 = torch.ops.aten.mul.Tensor(convert_element_type_686, rsqrt_39); convert_element_type_686 = None + mul_1763 = torch.ops.aten.mul.Tensor(mul_597, mul_1761) + sum_217 = torch.ops.aten.sum.dim_IntList(mul_1763, [2], True); mul_1763 = None + div_215 = torch.ops.aten.div.Tensor(mul_597, 2048) + mul_1764 = torch.ops.aten.mul.Tensor(div_215, sum_217); div_215 = sum_217 = None + sub_708 = torch.ops.aten.sub.Tensor(mul_1761, mul_1764); mul_1761 = mul_1764 = None + mul_1765 = torch.ops.aten.mul.Tensor(sub_708, rsqrt_39); sub_708 = rsqrt_39 = None + mul_1766 = torch.ops.aten.mul.Tensor(convert_element_type_2488, mul_597); convert_element_type_2488 = mul_597 = None + sum_218 = torch.ops.aten.sum.dim_IntList(mul_1766, [0, 1]); mul_1766 = None + convert_element_type_2491 = torch.ops.prims.convert_element_type.default(mul_1765, torch.bfloat16); mul_1765 = None + add_1984 = torch.ops.aten.add.Tensor(add_1982, convert_element_type_2491); add_1982 = convert_element_type_2491 = None + convert_element_type_default_40 = torch.ops.prims.convert_element_type.default(sum_218, torch.float32); sum_218 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_40, 'avg', 128, '0'); convert_element_type_default_40 = None + wait_tensor_784 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + view_2048 = torch.ops.aten.view.default(add_1984, [8192, 2048]) + unsqueeze_67 = torch.ops.aten.unsqueeze.default(view_2048, 1) + convert_element_type_2494 = torch.ops.prims.convert_element_type.default(unsqueeze_67, torch.float32); unsqueeze_67 = None + bmm_54 = torch.ops.aten.bmm.default(permute_1106, convert_element_type_2494); permute_1106 = None + bmm_55 = torch.ops.aten.bmm.default(convert_element_type_2494, permute_1107); convert_element_type_2494 = permute_1107 = None + convert_element_type_2495 = torch.ops.prims.convert_element_type.default(bmm_54, torch.bfloat16); bmm_54 = None + view_2049 = torch.ops.aten.view.default(bmm_55, [8192, 6]); bmm_55 = None + view_2050 = torch.ops.aten.view.default(convert_element_type_2495, [49152, 2048]); convert_element_type_2495 = None + index_80 = torch.ops.aten.index.Tensor(view_2050, [getitem_1231]); view_2050 = getitem_1231 = None + permute_1108 = torch.ops.aten.permute.default(view_2048, [1, 0]) + mm_442 = torch.ops.aten.mm.default(permute_1108, mul_594); permute_1108 = mul_594 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 128, '0'); convert_element_type_680 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_190 = torch.ops.aten.permute.default(wait_tensor_262, [1, 0]); wait_tensor_262 = None + permute_1110 = torch.ops.aten.permute.default(permute_190, [1, 0]); permute_190 = None + mm_443 = torch.ops.aten.mm.default(view_2048, permute_1110); view_2048 = permute_1110 = None + convert_element_type_2500 = torch.ops.prims.convert_element_type.default(mm_442, torch.float32); mm_442 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2500, 'avg', 128, '0'); convert_element_type_2500 = None + wait_tensor_785 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(mm_100, torch.float32); mm_100 = None + neg_24 = torch.ops.aten.neg.default(convert_element_type_675) + exp_36 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_816 = torch.ops.aten.add.Tensor(exp_36, 1); exp_36 = None + div_60 = torch.ops.aten.div.Tensor(convert_element_type_675, add_816) + convert_element_type_676 = torch.ops.prims.convert_element_type.default(div_60, torch.bfloat16); div_60 = None + mul_1767 = torch.ops.aten.mul.Tensor(mm_443, convert_element_type_676); convert_element_type_676 = None + mul_1768 = torch.ops.aten.mul.Tensor(mm_443, mm_101); mm_443 = mm_101 = None + permute_1112 = torch.ops.aten.permute.default(mul_1767, [1, 0]) + mm_444 = torch.ops.aten.mm.default(permute_1112, view_795); permute_1112 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 128, '0'); convert_element_type_677 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + permute_1114 = torch.ops.aten.permute.default(permute_189, [1, 0]); permute_189 = None + mm_445 = torch.ops.aten.mm.default(mul_1767, permute_1114); mul_1767 = permute_1114 = None + convert_element_type_2505 = torch.ops.prims.convert_element_type.default(mm_444, torch.float32); mm_444 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2505, 'avg', 128, '0'); convert_element_type_2505 = None + wait_tensor_786 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + convert_element_type_2506 = torch.ops.prims.convert_element_type.default(mul_1768, torch.float32); mul_1768 = None + reciprocal_28 = torch.ops.aten.reciprocal.default(add_816); add_816 = None + mul_1769 = torch.ops.aten.mul.Tensor(reciprocal_28, 1); reciprocal_28 = None + mul_1770 = torch.ops.aten.mul.Tensor(convert_element_type_2506, mul_1769); convert_element_type_2506 = None + sub_709 = torch.ops.aten.sub.Tensor(1, mul_1769); mul_1769 = None + mul_1771 = torch.ops.aten.mul.Tensor(convert_element_type_675, sub_709); convert_element_type_675 = sub_709 = None + add_1986 = torch.ops.aten.add.Tensor(mul_1771, 1); mul_1771 = None + mul_1772 = torch.ops.aten.mul.Tensor(mul_1770, add_1986); mul_1770 = add_1986 = None + convert_element_type_2508 = torch.ops.prims.convert_element_type.default(mul_1772, torch.bfloat16); mul_1772 = None + permute_1116 = torch.ops.aten.permute.default(convert_element_type_2508, [1, 0]) + mm_446 = torch.ops.aten.mm.default(permute_1116, view_795); permute_1116 = None + convert_element_type_672 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_672, 128, '0'); convert_element_type_672 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + permute_1118 = torch.ops.aten.permute.default(permute_188, [1, 0]); permute_188 = None + mm_447 = torch.ops.aten.mm.default(convert_element_type_2508, permute_1118); convert_element_type_2508 = permute_1118 = None + add_1987 = torch.ops.aten.add.Tensor(mm_445, mm_447); mm_445 = mm_447 = None + convert_element_type_2513 = torch.ops.prims.convert_element_type.default(mm_446, torch.float32); mm_446 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2513, 'avg', 128, '0'); convert_element_type_2513 = None + wait_tensor_787 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + all_to_all_single_106 = torch.ops._c10d_functional.all_to_all_single.default(index_80, [_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191], [_local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183], '1033'); index_80 = None + wait_tensor_788 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_106); all_to_all_single_106 = None + full_432 = torch.ops.aten.full.default([sym_size_int_45, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_45 = None + slice_scatter_14 = torch.ops.aten.slice_scatter.default(full_432, wait_tensor_788, 0, 0, -1); wait_tensor_788 = None + index_81 = torch.ops.aten.index.Tensor(slice_scatter_14, [getitem_1232]); slice_scatter_14 = None + permute_1120 = torch.ops.aten.permute.default(index_81, [1, 0]) + _grouped_mm_162 = torch.ops.aten._grouped_mm.default(permute_1120, mul_574, cumsum_35); permute_1120 = mul_574 = None + _grouped_mm_163 = torch.ops.aten._grouped_mm.default(index_81, permute_1122, cumsum_35); index_81 = permute_1122 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(_grouped_mm_33, torch.float32); _grouped_mm_33 = None + neg_23 = torch.ops.aten.neg.default(convert_element_type_670) + exp_35 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_780 = torch.ops.aten.add.Tensor(exp_35, 1); exp_35 = None + div_59 = torch.ops.aten.div.Tensor(convert_element_type_670, add_780) + convert_element_type_671 = torch.ops.prims.convert_element_type.default(div_59, torch.bfloat16); div_59 = None + mul_1773 = torch.ops.aten.mul.Tensor(_grouped_mm_163, convert_element_type_671); convert_element_type_671 = None + mul_1774 = torch.ops.aten.mul.Tensor(_grouped_mm_163, _grouped_mm_34); _grouped_mm_163 = _grouped_mm_34 = None + permute_1124 = torch.ops.aten.permute.default(mul_1773, [1, 0]) + _grouped_mm_164 = torch.ops.aten._grouped_mm.default(permute_1124, index_23, cumsum_35); permute_1124 = None + _grouped_mm_165 = torch.ops.aten._grouped_mm.default(mul_1773, permute_1126, cumsum_35); mul_1773 = permute_1126 = None + convert_element_type_2514 = torch.ops.prims.convert_element_type.default(mul_1774, torch.float32); mul_1774 = None + reciprocal_29 = torch.ops.aten.reciprocal.default(add_780); add_780 = None + mul_1775 = torch.ops.aten.mul.Tensor(reciprocal_29, 1); reciprocal_29 = None + mul_1776 = torch.ops.aten.mul.Tensor(convert_element_type_2514, mul_1775); convert_element_type_2514 = None + sub_710 = torch.ops.aten.sub.Tensor(1, mul_1775); mul_1775 = None + mul_1777 = torch.ops.aten.mul.Tensor(convert_element_type_670, sub_710); convert_element_type_670 = sub_710 = None + add_1989 = torch.ops.aten.add.Tensor(mul_1777, 1); mul_1777 = None + mul_1778 = torch.ops.aten.mul.Tensor(mul_1776, add_1989); mul_1776 = add_1989 = None + convert_element_type_2516 = torch.ops.prims.convert_element_type.default(mul_1778, torch.bfloat16); mul_1778 = None + permute_1128 = torch.ops.aten.permute.default(convert_element_type_2516, [1, 0]) + _grouped_mm_166 = torch.ops.aten._grouped_mm.default(permute_1128, index_23, cumsum_35); permute_1128 = index_23 = None + _grouped_mm_167 = torch.ops.aten._grouped_mm.default(convert_element_type_2516, permute_1130, cumsum_35); convert_element_type_2516 = permute_1130 = cumsum_35 = None + add_1990 = torch.ops.aten.add.Tensor(_grouped_mm_165, _grouped_mm_167); _grouped_mm_165 = _grouped_mm_167 = None + convert_element_type_2517 = torch.ops.prims.convert_element_type.default(_grouped_mm_164, torch.float32); _grouped_mm_164 = None + div_216 = torch.ops.aten.div.Tensor(convert_element_type_2517, 128); convert_element_type_2517 = None + split_899 = torch.ops.aten.split.Tensor(div_216, 88, 1); div_216 = None + getitem_16885 = split_899[0] + getitem_16902 = split_899[1] + getitem_16919 = split_899[2] + getitem_16936 = split_899[3] + getitem_16953 = split_899[4] + getitem_16970 = split_899[5] + getitem_16987 = split_899[6] + getitem_17004 = split_899[7] + getitem_17021 = split_899[8] + getitem_17038 = split_899[9] + getitem_17055 = split_899[10] + getitem_17072 = split_899[11] + getitem_17089 = split_899[12] + getitem_17106 = split_899[13] + getitem_17123 = split_899[14] + getitem_17140 = split_899[15]; split_899 = None + cat_348 = torch.ops.aten.cat.default([getitem_16885, getitem_16902, getitem_16919, getitem_16936, getitem_16953, getitem_16970, getitem_16987, getitem_17004, getitem_17021, getitem_17038, getitem_17055, getitem_17072, getitem_17089, getitem_17106, getitem_17123, getitem_17140]); getitem_16885 = getitem_16902 = getitem_16919 = getitem_16936 = getitem_16953 = getitem_16970 = getitem_16987 = getitem_17004 = getitem_17021 = getitem_17038 = getitem_17055 = getitem_17072 = getitem_17089 = getitem_17106 = getitem_17123 = getitem_17140 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_348, 'sum', 16, '1025'); cat_348 = None + wait_tensor_789 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + convert_element_type_2518 = torch.ops.prims.convert_element_type.default(_grouped_mm_162, torch.float32); _grouped_mm_162 = None + div_217 = torch.ops.aten.div.Tensor(convert_element_type_2518, 128); convert_element_type_2518 = None + split_916 = torch.ops.aten.split.Tensor(div_217, 128, 1); div_217 = None + getitem_17157 = split_916[0] + getitem_17174 = split_916[1] + getitem_17191 = split_916[2] + getitem_17208 = split_916[3] + getitem_17225 = split_916[4] + getitem_17242 = split_916[5] + getitem_17259 = split_916[6] + getitem_17276 = split_916[7] + getitem_17293 = split_916[8] + getitem_17310 = split_916[9] + getitem_17327 = split_916[10] + getitem_17344 = split_916[11] + getitem_17361 = split_916[12] + getitem_17378 = split_916[13] + getitem_17395 = split_916[14] + getitem_17412 = split_916[15]; split_916 = None + cat_349 = torch.ops.aten.cat.default([getitem_17157, getitem_17174, getitem_17191, getitem_17208, getitem_17225, getitem_17242, getitem_17259, getitem_17276, getitem_17293, getitem_17310, getitem_17327, getitem_17344, getitem_17361, getitem_17378, getitem_17395, getitem_17412]); getitem_17157 = getitem_17174 = getitem_17191 = getitem_17208 = getitem_17225 = getitem_17242 = getitem_17259 = getitem_17276 = getitem_17293 = getitem_17310 = getitem_17327 = getitem_17344 = getitem_17361 = getitem_17378 = getitem_17395 = getitem_17412 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_349, 'sum', 16, '1025'); cat_349 = None + wait_tensor_790 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + convert_element_type_2519 = torch.ops.prims.convert_element_type.default(_grouped_mm_166, torch.float32); _grouped_mm_166 = None + div_218 = torch.ops.aten.div.Tensor(convert_element_type_2519, 128); convert_element_type_2519 = None + split_933 = torch.ops.aten.split.Tensor(div_218, 88, 1); div_218 = None + getitem_17429 = split_933[0] + getitem_17446 = split_933[1] + getitem_17463 = split_933[2] + getitem_17480 = split_933[3] + getitem_17497 = split_933[4] + getitem_17514 = split_933[5] + getitem_17531 = split_933[6] + getitem_17548 = split_933[7] + getitem_17565 = split_933[8] + getitem_17582 = split_933[9] + getitem_17599 = split_933[10] + getitem_17616 = split_933[11] + getitem_17633 = split_933[12] + getitem_17650 = split_933[13] + getitem_17667 = split_933[14] + getitem_17684 = split_933[15]; split_933 = None + cat_350 = torch.ops.aten.cat.default([getitem_17429, getitem_17446, getitem_17463, getitem_17480, getitem_17497, getitem_17514, getitem_17531, getitem_17548, getitem_17565, getitem_17582, getitem_17599, getitem_17616, getitem_17633, getitem_17650, getitem_17667, getitem_17684]); getitem_17429 = getitem_17446 = getitem_17463 = getitem_17480 = getitem_17497 = getitem_17514 = getitem_17531 = getitem_17548 = getitem_17565 = getitem_17582 = getitem_17599 = getitem_17616 = getitem_17633 = getitem_17650 = getitem_17667 = getitem_17684 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_350, 'sum', 16, '1025'); cat_350 = None + wait_tensor_791 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + index_put_80 = torch.ops.aten.index_put.default(full_432, [getitem_1232], add_1990, True); full_432 = getitem_1232 = add_1990 = None + slice_246 = torch.ops.aten.slice.Tensor(index_put_80, 0, 0, add_1991); index_put_80 = add_1991 = None + all_to_all_single_107 = torch.ops._c10d_functional.all_to_all_single.default(slice_246, [_local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183], [_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191], '1033'); slice_246 = _local_scalar_dense_176 = _local_scalar_dense_177 = _local_scalar_dense_178 = _local_scalar_dense_179 = _local_scalar_dense_180 = _local_scalar_dense_181 = _local_scalar_dense_182 = _local_scalar_dense_183 = _local_scalar_dense_184 = _local_scalar_dense_185 = _local_scalar_dense_186 = _local_scalar_dense_187 = _local_scalar_dense_188 = _local_scalar_dense_189 = _local_scalar_dense_190 = _local_scalar_dense_191 = None + wait_tensor_792 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_107); all_to_all_single_107 = None + index_put_81 = torch.ops.aten.index_put.default(full_default_52, [div_57], wait_tensor_792, True); div_57 = wait_tensor_792 = None + add_1995 = torch.ops.aten.add.Tensor(add_1987, index_put_81); add_1987 = index_put_81 = None + mul_1779 = torch.ops.aten.mul.Tensor(view_2049, 1.0); view_2049 = None + scatter_add_14 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1229, mul_1779); getitem_1229 = mul_1779 = None + convert_element_type_659 = torch.ops.prims.convert_element_type.default(mm_99, torch.float32); mm_99 = None + sub_264 = torch.ops.aten.sub.Tensor(convert_element_type_659, amax_11); convert_element_type_659 = amax_11 = None + exp_34 = torch.ops.aten.exp.default(sub_264); sub_264 = None + div_56 = torch.ops.aten.div.Tensor(exp_34, sum_45); exp_34 = sum_45 = None + mul_1780 = torch.ops.aten.mul.Tensor(scatter_add_14, div_56); scatter_add_14 = None + sum_219 = torch.ops.aten.sum.dim_IntList(mul_1780, [1], True) + neg_97 = torch.ops.aten.neg.default(div_56); div_56 = None + fma_14 = torch.ops.prims.fma.default(neg_97, sum_219, mul_1780); neg_97 = sum_219 = mul_1780 = None + convert_element_type_2520 = torch.ops.prims.convert_element_type.default(fma_14, torch.bfloat16); fma_14 = None + permute_1132 = torch.ops.aten.permute.default(convert_element_type_2520, [1, 0]) + mm_448 = torch.ops.aten.mm.default(permute_1132, view_795); permute_1132 = view_795 = None + convert_element_type_656 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16); primals_207 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_656, 128, '0'); convert_element_type_656 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + slice_75 = torch.ops.aten.slice.Tensor(wait_tensor_249, 0, 0, 64); wait_tensor_249 = None + permute_184 = torch.ops.aten.permute.default(slice_75, [1, 0]); slice_75 = None + permute_1134 = torch.ops.aten.permute.default(permute_184, [1, 0]); permute_184 = None + mm_449 = torch.ops.aten.mm.default(convert_element_type_2520, permute_1134); convert_element_type_2520 = permute_1134 = None + add_1996 = torch.ops.aten.add.Tensor(add_1995, mm_449); add_1995 = mm_449 = None + convert_element_type_2525 = torch.ops.prims.convert_element_type.default(mm_448, torch.float32); mm_448 = None + split_949 = torch.ops.aten.split.Tensor(convert_element_type_2525, 1); convert_element_type_2525 = None + getitem_17685 = split_949[0] + getitem_17686 = split_949[1] + getitem_17687 = split_949[2] + getitem_17688 = split_949[3] + getitem_17689 = split_949[4] + getitem_17690 = split_949[5] + getitem_17691 = split_949[6] + getitem_17692 = split_949[7] + getitem_17693 = split_949[8] + getitem_17694 = split_949[9] + getitem_17695 = split_949[10] + getitem_17696 = split_949[11] + getitem_17697 = split_949[12] + getitem_17698 = split_949[13] + getitem_17699 = split_949[14] + getitem_17700 = split_949[15] + getitem_17701 = split_949[16] + getitem_17702 = split_949[17] + getitem_17703 = split_949[18] + getitem_17704 = split_949[19] + getitem_17705 = split_949[20] + getitem_17706 = split_949[21] + getitem_17707 = split_949[22] + getitem_17708 = split_949[23] + getitem_17709 = split_949[24] + getitem_17710 = split_949[25] + getitem_17711 = split_949[26] + getitem_17712 = split_949[27] + getitem_17713 = split_949[28] + getitem_17714 = split_949[29] + getitem_17715 = split_949[30] + getitem_17716 = split_949[31] + getitem_17717 = split_949[32] + getitem_17718 = split_949[33] + getitem_17719 = split_949[34] + getitem_17720 = split_949[35] + getitem_17721 = split_949[36] + getitem_17722 = split_949[37] + getitem_17723 = split_949[38] + getitem_17724 = split_949[39] + getitem_17725 = split_949[40] + getitem_17726 = split_949[41] + getitem_17727 = split_949[42] + getitem_17728 = split_949[43] + getitem_17729 = split_949[44] + getitem_17730 = split_949[45] + getitem_17731 = split_949[46] + getitem_17732 = split_949[47] + getitem_17733 = split_949[48] + getitem_17734 = split_949[49] + getitem_17735 = split_949[50] + getitem_17736 = split_949[51] + getitem_17737 = split_949[52] + getitem_17738 = split_949[53] + getitem_17739 = split_949[54] + getitem_17740 = split_949[55] + getitem_17741 = split_949[56] + getitem_17742 = split_949[57] + getitem_17743 = split_949[58] + getitem_17744 = split_949[59] + getitem_17745 = split_949[60] + getitem_17746 = split_949[61] + getitem_17747 = split_949[62] + getitem_17748 = split_949[63]; split_949 = None + cat_351 = torch.ops.aten.cat.default([getitem_17685, getitem_17686, getitem_17687, getitem_17688, getitem_17689, getitem_17690, getitem_17691, getitem_17692, getitem_17693, getitem_17694, getitem_17695, getitem_17696, getitem_17697, getitem_17698, getitem_17699, getitem_17700, getitem_17701, getitem_17702, getitem_17703, getitem_17704, getitem_17705, getitem_17706, getitem_17707, getitem_17708, getitem_17709, getitem_17710, getitem_17711, getitem_17712, getitem_17713, getitem_17714, getitem_17715, getitem_17716, getitem_17717, getitem_17718, getitem_17719, getitem_17720, getitem_17721, getitem_17722, getitem_17723, getitem_17724, getitem_17725, getitem_17726, getitem_17727, getitem_17728, getitem_17729, getitem_17730, getitem_17731, getitem_17732, getitem_17733, getitem_17734, getitem_17735, getitem_17736, getitem_17737, getitem_17738, getitem_17739, getitem_17740, getitem_17741, getitem_17742, getitem_17743, getitem_17744, getitem_17745, getitem_17746, getitem_17747, getitem_17748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_17685 = getitem_17686 = getitem_17687 = getitem_17688 = getitem_17689 = getitem_17690 = getitem_17691 = getitem_17692 = getitem_17693 = getitem_17694 = getitem_17695 = getitem_17696 = getitem_17697 = getitem_17698 = getitem_17699 = getitem_17700 = getitem_17701 = getitem_17702 = getitem_17703 = getitem_17704 = getitem_17705 = getitem_17706 = getitem_17707 = getitem_17708 = getitem_17709 = getitem_17710 = getitem_17711 = getitem_17712 = getitem_17713 = getitem_17714 = getitem_17715 = getitem_17716 = getitem_17717 = getitem_17718 = getitem_17719 = getitem_17720 = getitem_17721 = getitem_17722 = getitem_17723 = getitem_17724 = getitem_17725 = getitem_17726 = getitem_17727 = getitem_17728 = getitem_17729 = getitem_17730 = getitem_17731 = getitem_17732 = getitem_17733 = getitem_17734 = getitem_17735 = getitem_17736 = getitem_17737 = getitem_17738 = getitem_17739 = getitem_17740 = getitem_17741 = getitem_17742 = getitem_17743 = getitem_17744 = getitem_17745 = getitem_17746 = getitem_17747 = getitem_17748 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_351, 'avg', 128, '0'); cat_351 = None + wait_tensor_793 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + view_2051 = torch.ops.aten.view.default(add_1996, [2, 4096, 2048]); add_1996 = None + convert_element_type_2526 = torch.ops.prims.convert_element_type.default(view_2051, torch.float32); view_2051 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16); primals_205 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_653, 128, '0'); convert_element_type_653 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_2528 = torch.ops.prims.convert_element_type.default(wait_tensor_248, torch.float32); wait_tensor_248 = None + mul_1781 = torch.ops.aten.mul.Tensor(convert_element_type_2526, convert_element_type_2528); convert_element_type_2528 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(add_756, torch.float32); add_756 = None + mul_554 = torch.ops.aten.mul.Tensor(convert_element_type_654, rsqrt_38); convert_element_type_654 = None + mul_1783 = torch.ops.aten.mul.Tensor(mul_554, mul_1781) + sum_220 = torch.ops.aten.sum.dim_IntList(mul_1783, [2], True); mul_1783 = None + div_219 = torch.ops.aten.div.Tensor(mul_554, 2048) + mul_1784 = torch.ops.aten.mul.Tensor(div_219, sum_220); div_219 = sum_220 = None + sub_712 = torch.ops.aten.sub.Tensor(mul_1781, mul_1784); mul_1781 = mul_1784 = None + mul_1785 = torch.ops.aten.mul.Tensor(sub_712, rsqrt_38); sub_712 = rsqrt_38 = None + mul_1786 = torch.ops.aten.mul.Tensor(convert_element_type_2526, mul_554); convert_element_type_2526 = mul_554 = None + sum_221 = torch.ops.aten.sum.dim_IntList(mul_1786, [0, 1]); mul_1786 = None + convert_element_type_2529 = torch.ops.prims.convert_element_type.default(mul_1785, torch.bfloat16); mul_1785 = None + add_1997 = torch.ops.aten.add.Tensor(add_1984, convert_element_type_2529); add_1984 = convert_element_type_2529 = None + convert_element_type_default_39 = torch.ops.prims.convert_element_type.default(sum_221, torch.float32); sum_221 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_39, 'avg', 128, '0'); convert_element_type_default_39 = None + wait_tensor_794 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + view_2052 = torch.ops.aten.view.default(add_1997, [8192, 2048]) + permute_1136 = torch.ops.aten.permute.default(view_2052, [1, 0]) + permute_182 = torch.ops.aten.permute.default(getitem_1225, [0, 2, 1, 3]) + view_790 = torch.ops.aten.view.default(permute_182, [2, 4096, -1]); permute_182 = None + view_792 = torch.ops.aten.view.default(view_790, [8192, 2048]); view_790 = None + mm_450 = torch.ops.aten.mm.default(permute_1136, view_792); permute_1136 = view_792 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16); primals_204 = None + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 128, '0'); convert_element_type_650 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + permute_1138 = torch.ops.aten.permute.default(permute_183, [1, 0]); permute_183 = None + mm_451 = torch.ops.aten.mm.default(view_2052, permute_1138); view_2052 = permute_1138 = None + view_2053 = torch.ops.aten.view.default(mm_451, [2, 4096, 2048]); mm_451 = None + convert_element_type_2536 = torch.ops.prims.convert_element_type.default(mm_450, torch.float32); mm_450 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2536, 'avg', 128, '0'); convert_element_type_2536 = None + wait_tensor_795 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_2054 = torch.ops.aten.view.default(view_2053, [2, 4096, 16, 128]); view_2053 = None + permute_1140 = torch.ops.aten.permute.default(view_2054, [0, 2, 1, 3]); view_2054 = None + fw_graph14 = self.fw_graph14 + joint_graph14 = self.joint_graph14 + mask_graph14 = self.mask_graph14 + flex_attention_backward_14 = torch.ops.higher_order.flex_attention_backward(permute_179, permute_180, permute_181, getitem_1225, getitem_1226, permute_1140, None, fw_graph14, joint_graph14, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph14), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_179 = permute_180 = permute_181 = getitem_1225 = getitem_1226 = permute_1140 = fw_graph14 = joint_graph14 = mask_graph14 = None + getitem_17749 = flex_attention_backward_14[0] + getitem_17750 = flex_attention_backward_14[1] + getitem_17751 = flex_attention_backward_14[2]; flex_attention_backward_14 = None + permute_1141 = torch.ops.aten.permute.default(getitem_17751, [0, 2, 1, 3]); getitem_17751 = None + permute_1142 = torch.ops.aten.permute.default(getitem_17750, [0, 2, 1, 3]); getitem_17750 = None + permute_1143 = torch.ops.aten.permute.default(getitem_17749, [0, 2, 1, 3]); getitem_17749 = None + slice_248 = torch.ops.aten.slice.Tensor(permute_1142, 3, 0, 128) + slice_249 = torch.ops.aten.slice.Tensor(permute_1142, 3, 128, 192); permute_1142 = None + sum_222 = torch.ops.aten.sum.dim_IntList(slice_249, [2], True); slice_249 = None + cat_352 = torch.ops.aten.cat.default([slice_248, permute_1141], 3); slice_248 = permute_1141 = None + view_2055 = torch.ops.aten.view.default(cat_352, [2, 4096, 4096]); cat_352 = None + view_2056 = torch.ops.aten.view.default(view_2055, [8192, 4096]); view_2055 = None + permute_1144 = torch.ops.aten.permute.default(view_2056, [1, 0]) + mm_452 = torch.ops.aten.mm.default(permute_1144, view_787); permute_1144 = view_787 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16); primals_203 = None + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 128, '0'); convert_element_type_647 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + permute_1146 = torch.ops.aten.permute.default(permute_178, [1, 0]); permute_178 = None + mm_453 = torch.ops.aten.mm.default(view_2056, permute_1146); view_2056 = permute_1146 = None + view_2057 = torch.ops.aten.view.default(mm_453, [2, 4096, 512]); mm_453 = None + convert_element_type_2541 = torch.ops.prims.convert_element_type.default(mm_452, torch.float32); mm_452 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2541, 'avg', 128, '0'); convert_element_type_2541 = None + wait_tensor_796 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + convert_element_type_2542 = torch.ops.prims.convert_element_type.default(view_2057, torch.float32); view_2057 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16); primals_202 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 128, '0'); convert_element_type_644 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + convert_element_type_2544 = torch.ops.prims.convert_element_type.default(wait_tensor_245, torch.float32); wait_tensor_245 = None + mul_1787 = torch.ops.aten.mul.Tensor(convert_element_type_2542, convert_element_type_2544); convert_element_type_2544 = None + convert_element_type_645 = torch.ops.prims.convert_element_type.default(getitem_1221, torch.float32); getitem_1221 = None + mul_552 = torch.ops.aten.mul.Tensor(convert_element_type_645, rsqrt_37); convert_element_type_645 = None + mul_1789 = torch.ops.aten.mul.Tensor(mul_552, mul_1787) + sum_223 = torch.ops.aten.sum.dim_IntList(mul_1789, [2], True); mul_1789 = None + div_220 = torch.ops.aten.div.Tensor(mul_552, 512) + mul_1790 = torch.ops.aten.mul.Tensor(div_220, sum_223); div_220 = sum_223 = None + sub_713 = torch.ops.aten.sub.Tensor(mul_1787, mul_1790); mul_1787 = mul_1790 = None + mul_1791 = torch.ops.aten.mul.Tensor(sub_713, rsqrt_37); sub_713 = rsqrt_37 = None + mul_1792 = torch.ops.aten.mul.Tensor(convert_element_type_2542, mul_552); convert_element_type_2542 = mul_552 = None + sum_224 = torch.ops.aten.sum.dim_IntList(mul_1792, [0, 1]); mul_1792 = None + convert_element_type_2545 = torch.ops.prims.convert_element_type.default(mul_1791, torch.bfloat16); mul_1791 = None + convert_element_type_default_38 = torch.ops.prims.convert_element_type.default(sum_224, torch.float32); sum_224 = None + reduce_scatter_tensor_208 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_38, 'avg', 128, '0'); convert_element_type_default_38 = None + wait_tensor_797 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + convert_element_type_2548 = torch.ops.prims.convert_element_type.default(sum_222, torch.float32); sum_222 = None + view_2058 = torch.ops.aten.view.default(convert_element_type_2548, [2, 4096, 1, 32, 2]); convert_element_type_2548 = None + view_as_complex_82 = torch.ops.aten.view_as_complex.default(view_2058); view_2058 = None + mul_1793 = torch.ops.aten.mul.Tensor(view_as_complex_82, clone_9); view_as_complex_82 = None + view_as_real_82 = torch.ops.aten.view_as_real.default(mul_1793); mul_1793 = None + view_2059 = torch.ops.aten.view.default(view_as_real_82, [2, 4096, 1, 64]); view_as_real_82 = None + convert_element_type_2549 = torch.ops.prims.convert_element_type.default(view_2059, torch.bfloat16); view_2059 = None + squeeze_40 = torch.ops.aten.squeeze.dim(convert_element_type_2549, 2); convert_element_type_2549 = None + cat_353 = torch.ops.aten.cat.default([convert_element_type_2545, squeeze_40], 2); convert_element_type_2545 = squeeze_40 = None + view_2060 = torch.ops.aten.view.default(cat_353, [8192, 576]); cat_353 = None + permute_1148 = torch.ops.aten.permute.default(view_2060, [1, 0]) + mm_454 = torch.ops.aten.mm.default(permute_1148, view_773); permute_1148 = None + convert_element_type_639 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16); primals_201 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_639, 128, '0'); convert_element_type_639 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + slice_73 = torch.ops.aten.slice.Tensor(wait_tensor_244, 0, 0, 576); wait_tensor_244 = None + permute_177 = torch.ops.aten.permute.default(slice_73, [1, 0]); slice_73 = None + permute_1150 = torch.ops.aten.permute.default(permute_177, [1, 0]); permute_177 = None + mm_455 = torch.ops.aten.mm.default(view_2060, permute_1150); view_2060 = permute_1150 = None + view_2061 = torch.ops.aten.view.default(mm_455, [2, 4096, 2048]); mm_455 = None + convert_element_type_2554 = torch.ops.prims.convert_element_type.default(mm_454, torch.float32); mm_454 = None + split_950 = torch.ops.aten.split.Tensor(convert_element_type_2554, 5); convert_element_type_2554 = None + getitem_17753 = split_950[0] + getitem_17754 = split_950[1] + getitem_17755 = split_950[2] + getitem_17756 = split_950[3] + getitem_17757 = split_950[4] + getitem_17758 = split_950[5] + getitem_17759 = split_950[6] + getitem_17760 = split_950[7] + getitem_17761 = split_950[8] + getitem_17762 = split_950[9] + getitem_17763 = split_950[10] + getitem_17764 = split_950[11] + getitem_17765 = split_950[12] + getitem_17766 = split_950[13] + getitem_17767 = split_950[14] + getitem_17768 = split_950[15] + getitem_17769 = split_950[16] + getitem_17770 = split_950[17] + getitem_17771 = split_950[18] + getitem_17772 = split_950[19] + getitem_17773 = split_950[20] + getitem_17774 = split_950[21] + getitem_17775 = split_950[22] + getitem_17776 = split_950[23] + getitem_17777 = split_950[24] + getitem_17778 = split_950[25] + getitem_17779 = split_950[26] + getitem_17780 = split_950[27] + getitem_17781 = split_950[28] + getitem_17782 = split_950[29] + getitem_17783 = split_950[30] + getitem_17784 = split_950[31] + getitem_17785 = split_950[32] + getitem_17786 = split_950[33] + getitem_17787 = split_950[34] + getitem_17788 = split_950[35] + getitem_17789 = split_950[36] + getitem_17790 = split_950[37] + getitem_17791 = split_950[38] + getitem_17792 = split_950[39] + getitem_17793 = split_950[40] + getitem_17794 = split_950[41] + getitem_17795 = split_950[42] + getitem_17796 = split_950[43] + getitem_17797 = split_950[44] + getitem_17798 = split_950[45] + getitem_17799 = split_950[46] + getitem_17800 = split_950[47] + getitem_17801 = split_950[48] + getitem_17802 = split_950[49] + getitem_17803 = split_950[50] + getitem_17804 = split_950[51] + getitem_17805 = split_950[52] + getitem_17806 = split_950[53] + getitem_17807 = split_950[54] + getitem_17808 = split_950[55] + getitem_17809 = split_950[56] + getitem_17810 = split_950[57] + getitem_17811 = split_950[58] + getitem_17812 = split_950[59] + getitem_17813 = split_950[60] + getitem_17814 = split_950[61] + getitem_17815 = split_950[62] + getitem_17816 = split_950[63] + getitem_17817 = split_950[64] + getitem_17818 = split_950[65] + getitem_17819 = split_950[66] + getitem_17820 = split_950[67] + getitem_17821 = split_950[68] + getitem_17822 = split_950[69] + getitem_17823 = split_950[70] + getitem_17824 = split_950[71] + getitem_17825 = split_950[72] + getitem_17826 = split_950[73] + getitem_17827 = split_950[74] + getitem_17828 = split_950[75] + getitem_17829 = split_950[76] + getitem_17830 = split_950[77] + getitem_17831 = split_950[78] + getitem_17832 = split_950[79] + getitem_17833 = split_950[80] + getitem_17834 = split_950[81] + getitem_17835 = split_950[82] + getitem_17836 = split_950[83] + getitem_17837 = split_950[84] + getitem_17838 = split_950[85] + getitem_17839 = split_950[86] + getitem_17840 = split_950[87] + getitem_17841 = split_950[88] + getitem_17842 = split_950[89] + getitem_17843 = split_950[90] + getitem_17844 = split_950[91] + getitem_17845 = split_950[92] + getitem_17846 = split_950[93] + getitem_17847 = split_950[94] + getitem_17848 = split_950[95] + getitem_17849 = split_950[96] + getitem_17850 = split_950[97] + getitem_17851 = split_950[98] + getitem_17852 = split_950[99] + getitem_17853 = split_950[100] + getitem_17854 = split_950[101] + getitem_17855 = split_950[102] + getitem_17856 = split_950[103] + getitem_17857 = split_950[104] + getitem_17858 = split_950[105] + getitem_17859 = split_950[106] + getitem_17860 = split_950[107] + getitem_17861 = split_950[108] + getitem_17862 = split_950[109] + getitem_17863 = split_950[110] + getitem_17864 = split_950[111] + getitem_17865 = split_950[112] + getitem_17866 = split_950[113] + getitem_17867 = split_950[114] + getitem_17868 = split_950[115]; split_950 = None + constant_pad_nd_1142 = torch.ops.aten.constant_pad_nd.default(getitem_17868, [0, 0, 0, 4], 0.0); getitem_17868 = None + cat_354 = torch.ops.aten.cat.default([getitem_17753, getitem_17754, getitem_17755, getitem_17756, getitem_17757, getitem_17758, getitem_17759, getitem_17760, getitem_17761, getitem_17762, getitem_17763, getitem_17764, getitem_17765, getitem_17766, getitem_17767, getitem_17768, getitem_17769, getitem_17770, getitem_17771, getitem_17772, getitem_17773, getitem_17774, getitem_17775, getitem_17776, getitem_17777, getitem_17778, getitem_17779, getitem_17780, getitem_17781, getitem_17782, getitem_17783, getitem_17784, getitem_17785, getitem_17786, getitem_17787, getitem_17788, getitem_17789, getitem_17790, getitem_17791, getitem_17792, getitem_17793, getitem_17794, getitem_17795, getitem_17796, getitem_17797, getitem_17798, getitem_17799, getitem_17800, getitem_17801, getitem_17802, getitem_17803, getitem_17804, getitem_17805, getitem_17806, getitem_17807, getitem_17808, getitem_17809, getitem_17810, getitem_17811, getitem_17812, getitem_17813, getitem_17814, getitem_17815, getitem_17816, getitem_17817, getitem_17818, getitem_17819, getitem_17820, getitem_17821, getitem_17822, getitem_17823, getitem_17824, getitem_17825, getitem_17826, getitem_17827, getitem_17828, getitem_17829, getitem_17830, getitem_17831, getitem_17832, getitem_17833, getitem_17834, getitem_17835, getitem_17836, getitem_17837, getitem_17838, getitem_17839, getitem_17840, getitem_17841, getitem_17842, getitem_17843, getitem_17844, getitem_17845, getitem_17846, getitem_17847, getitem_17848, getitem_17849, getitem_17850, getitem_17851, getitem_17852, getitem_17853, getitem_17854, getitem_17855, getitem_17856, getitem_17857, getitem_17858, getitem_17859, getitem_17860, getitem_17861, getitem_17862, getitem_17863, getitem_17864, getitem_17865, getitem_17866, getitem_17867, constant_pad_nd_1142, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_17753 = getitem_17754 = getitem_17755 = getitem_17756 = getitem_17757 = getitem_17758 = getitem_17759 = getitem_17760 = getitem_17761 = getitem_17762 = getitem_17763 = getitem_17764 = getitem_17765 = getitem_17766 = getitem_17767 = getitem_17768 = getitem_17769 = getitem_17770 = getitem_17771 = getitem_17772 = getitem_17773 = getitem_17774 = getitem_17775 = getitem_17776 = getitem_17777 = getitem_17778 = getitem_17779 = getitem_17780 = getitem_17781 = getitem_17782 = getitem_17783 = getitem_17784 = getitem_17785 = getitem_17786 = getitem_17787 = getitem_17788 = getitem_17789 = getitem_17790 = getitem_17791 = getitem_17792 = getitem_17793 = getitem_17794 = getitem_17795 = getitem_17796 = getitem_17797 = getitem_17798 = getitem_17799 = getitem_17800 = getitem_17801 = getitem_17802 = getitem_17803 = getitem_17804 = getitem_17805 = getitem_17806 = getitem_17807 = getitem_17808 = getitem_17809 = getitem_17810 = getitem_17811 = getitem_17812 = getitem_17813 = getitem_17814 = getitem_17815 = getitem_17816 = getitem_17817 = getitem_17818 = getitem_17819 = getitem_17820 = getitem_17821 = getitem_17822 = getitem_17823 = getitem_17824 = getitem_17825 = getitem_17826 = getitem_17827 = getitem_17828 = getitem_17829 = getitem_17830 = getitem_17831 = getitem_17832 = getitem_17833 = getitem_17834 = getitem_17835 = getitem_17836 = getitem_17837 = getitem_17838 = getitem_17839 = getitem_17840 = getitem_17841 = getitem_17842 = getitem_17843 = getitem_17844 = getitem_17845 = getitem_17846 = getitem_17847 = getitem_17848 = getitem_17849 = getitem_17850 = getitem_17851 = getitem_17852 = getitem_17853 = getitem_17854 = getitem_17855 = getitem_17856 = getitem_17857 = getitem_17858 = getitem_17859 = getitem_17860 = getitem_17861 = getitem_17862 = getitem_17863 = getitem_17864 = getitem_17865 = getitem_17866 = getitem_17867 = constant_pad_nd_1142 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_354, 'avg', 128, '0'); cat_354 = None + wait_tensor_798 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + slice_250 = torch.ops.aten.slice.Tensor(permute_1143, 3, 0, 128) + slice_251 = torch.ops.aten.slice.Tensor(permute_1143, 3, 128, 192); permute_1143 = None + convert_element_type_2555 = torch.ops.prims.convert_element_type.default(slice_251, torch.float32); slice_251 = None + view_2062 = torch.ops.aten.view.default(convert_element_type_2555, [2, 4096, 16, 32, 2]); convert_element_type_2555 = None + view_as_complex_83 = torch.ops.aten.view_as_complex.default(view_2062); view_2062 = None + mul_1794 = torch.ops.aten.mul.Tensor(view_as_complex_83, clone_9); view_as_complex_83 = None + view_as_real_83 = torch.ops.aten.view_as_real.default(mul_1794); mul_1794 = None + view_2063 = torch.ops.aten.view.default(view_as_real_83, [2, 4096, 16, 64]); view_as_real_83 = None + convert_element_type_2556 = torch.ops.prims.convert_element_type.default(view_2063, torch.bfloat16); view_2063 = None + cat_355 = torch.ops.aten.cat.default([slice_250, convert_element_type_2556], 3); slice_250 = convert_element_type_2556 = None + view_2064 = torch.ops.aten.view.default(cat_355, [2, 4096, 3072]); cat_355 = None + view_2065 = torch.ops.aten.view.default(view_2064, [8192, 3072]); view_2064 = None + permute_1152 = torch.ops.aten.permute.default(view_2065, [1, 0]) + mm_456 = torch.ops.aten.mm.default(permute_1152, view_773); permute_1152 = view_773 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16); primals_200 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 128, '0'); convert_element_type_634 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + permute_1154 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_457 = torch.ops.aten.mm.default(view_2065, permute_1154); view_2065 = permute_1154 = None + view_2066 = torch.ops.aten.view.default(mm_457, [2, 4096, 2048]); mm_457 = None + add_1998 = torch.ops.aten.add.Tensor(view_2061, view_2066); view_2061 = view_2066 = None + convert_element_type_2561 = torch.ops.prims.convert_element_type.default(mm_456, torch.float32); mm_456 = None + reduce_scatter_tensor_210 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2561, 'avg', 128, '0'); convert_element_type_2561 = None + wait_tensor_799 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); reduce_scatter_tensor_210 = None + convert_element_type_2562 = torch.ops.prims.convert_element_type.default(add_1998, torch.float32); add_1998 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16); primals_199 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 128, '0'); convert_element_type_631 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + convert_element_type_2564 = torch.ops.prims.convert_element_type.default(wait_tensor_242, torch.float32); wait_tensor_242 = None + mul_1795 = torch.ops.aten.mul.Tensor(convert_element_type_2562, convert_element_type_2564); convert_element_type_2564 = None + convert_element_type_632 = torch.ops.prims.convert_element_type.default(add_753, torch.float32); add_753 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_632, rsqrt_36); convert_element_type_632 = None + mul_1797 = torch.ops.aten.mul.Tensor(mul_548, mul_1795) + sum_225 = torch.ops.aten.sum.dim_IntList(mul_1797, [2], True); mul_1797 = None + div_221 = torch.ops.aten.div.Tensor(mul_548, 2048) + mul_1798 = torch.ops.aten.mul.Tensor(div_221, sum_225); div_221 = sum_225 = None + sub_714 = torch.ops.aten.sub.Tensor(mul_1795, mul_1798); mul_1795 = mul_1798 = None + mul_1799 = torch.ops.aten.mul.Tensor(sub_714, rsqrt_36); sub_714 = rsqrt_36 = None + mul_1800 = torch.ops.aten.mul.Tensor(convert_element_type_2562, mul_548); convert_element_type_2562 = mul_548 = None + sum_226 = torch.ops.aten.sum.dim_IntList(mul_1800, [0, 1]); mul_1800 = None + convert_element_type_2565 = torch.ops.prims.convert_element_type.default(mul_1799, torch.bfloat16); mul_1799 = None + add_1999 = torch.ops.aten.add.Tensor(add_1997, convert_element_type_2565); add_1997 = convert_element_type_2565 = None + convert_element_type_default_37 = torch.ops.prims.convert_element_type.default(sum_226, torch.float32); sum_226 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_37, 'avg', 128, '0'); convert_element_type_default_37 = None + wait_tensor_800 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + view_2067 = torch.ops.aten.view.default(add_1999, [8192, 2048]) + unsqueeze_68 = torch.ops.aten.unsqueeze.default(view_2067, 1) + convert_element_type_2568 = torch.ops.prims.convert_element_type.default(unsqueeze_68, torch.float32); unsqueeze_68 = None + bmm_56 = torch.ops.aten.bmm.default(permute_1156, convert_element_type_2568); permute_1156 = None + bmm_57 = torch.ops.aten.bmm.default(convert_element_type_2568, permute_1157); convert_element_type_2568 = permute_1157 = None + convert_element_type_2569 = torch.ops.prims.convert_element_type.default(bmm_56, torch.bfloat16); bmm_56 = None + view_2068 = torch.ops.aten.view.default(bmm_57, [8192, 6]); bmm_57 = None + view_2069 = torch.ops.aten.view.default(convert_element_type_2569, [49152, 2048]); convert_element_type_2569 = None + index_82 = torch.ops.aten.index.Tensor(view_2069, [getitem_1121]); view_2069 = getitem_1121 = None + permute_1158 = torch.ops.aten.permute.default(view_2067, [1, 0]) + mm_458 = torch.ops.aten.mm.default(permute_1158, mul_545); permute_1158 = mul_545 = None + convert_element_type_626 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_626, 128, '0'); convert_element_type_626 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + permute_1160 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_459 = torch.ops.aten.mm.default(view_2067, permute_1160); view_2067 = permute_1160 = None + convert_element_type_2574 = torch.ops.prims.convert_element_type.default(mm_458, torch.float32); mm_458 = None + reduce_scatter_tensor_212 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2574, 'avg', 128, '0'); convert_element_type_2574 = None + wait_tensor_801 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mm_92, torch.float32); mm_92 = None + neg_22 = torch.ops.aten.neg.default(convert_element_type_621) + exp_33 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_748 = torch.ops.aten.add.Tensor(exp_33, 1); exp_33 = None + div_55 = torch.ops.aten.div.Tensor(convert_element_type_621, add_748) + convert_element_type_622 = torch.ops.prims.convert_element_type.default(div_55, torch.bfloat16); div_55 = None + mul_1801 = torch.ops.aten.mul.Tensor(mm_459, convert_element_type_622); convert_element_type_622 = None + mul_1802 = torch.ops.aten.mul.Tensor(mm_459, mm_93); mm_459 = mm_93 = None + permute_1162 = torch.ops.aten.permute.default(mul_1801, [1, 0]) + mm_460 = torch.ops.aten.mm.default(permute_1162, view_728); permute_1162 = None + convert_element_type_623 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_623, 128, '0'); convert_element_type_623 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + permute_1164 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_461 = torch.ops.aten.mm.default(mul_1801, permute_1164); mul_1801 = permute_1164 = None + convert_element_type_2579 = torch.ops.prims.convert_element_type.default(mm_460, torch.float32); mm_460 = None + reduce_scatter_tensor_213 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2579, 'avg', 128, '0'); convert_element_type_2579 = None + wait_tensor_802 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_213); reduce_scatter_tensor_213 = None + convert_element_type_2580 = torch.ops.prims.convert_element_type.default(mul_1802, torch.float32); mul_1802 = None + reciprocal_30 = torch.ops.aten.reciprocal.default(add_748); add_748 = None + mul_1803 = torch.ops.aten.mul.Tensor(reciprocal_30, 1); reciprocal_30 = None + mul_1804 = torch.ops.aten.mul.Tensor(convert_element_type_2580, mul_1803); convert_element_type_2580 = None + sub_715 = torch.ops.aten.sub.Tensor(1, mul_1803); mul_1803 = None + mul_1805 = torch.ops.aten.mul.Tensor(convert_element_type_621, sub_715); convert_element_type_621 = sub_715 = None + add_2001 = torch.ops.aten.add.Tensor(mul_1805, 1); mul_1805 = None + mul_1806 = torch.ops.aten.mul.Tensor(mul_1804, add_2001); mul_1804 = add_2001 = None + convert_element_type_2582 = torch.ops.prims.convert_element_type.default(mul_1806, torch.bfloat16); mul_1806 = None + permute_1166 = torch.ops.aten.permute.default(convert_element_type_2582, [1, 0]) + mm_462 = torch.ops.aten.mm.default(permute_1166, view_728); permute_1166 = None + convert_element_type_618 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_618, 128, '0'); convert_element_type_618 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + permute_1168 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_463 = torch.ops.aten.mm.default(convert_element_type_2582, permute_1168); convert_element_type_2582 = permute_1168 = None + add_2002 = torch.ops.aten.add.Tensor(mm_461, mm_463); mm_461 = mm_463 = None + convert_element_type_2587 = torch.ops.prims.convert_element_type.default(mm_462, torch.float32); mm_462 = None + reduce_scatter_tensor_214 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2587, 'avg', 128, '0'); convert_element_type_2587 = None + wait_tensor_803 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_214); reduce_scatter_tensor_214 = None + all_to_all_single_108 = torch.ops._c10d_functional.all_to_all_single.default(index_82, [_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175], [_local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167], '1033'); index_82 = None + wait_tensor_804 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_108); all_to_all_single_108 = None + full_438 = torch.ops.aten.full.default([sym_size_int_41, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_41 = None + slice_scatter_15 = torch.ops.aten.slice_scatter.default(full_438, wait_tensor_804, 0, 0, -1); wait_tensor_804 = None + index_83 = torch.ops.aten.index.Tensor(slice_scatter_15, [getitem_1122]); slice_scatter_15 = None + permute_1170 = torch.ops.aten.permute.default(index_83, [1, 0]) + _grouped_mm_168 = torch.ops.aten._grouped_mm.default(permute_1170, mul_525, cumsum_32); permute_1170 = mul_525 = None + _grouped_mm_169 = torch.ops.aten._grouped_mm.default(index_83, permute_1172, cumsum_32); index_83 = permute_1172 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(_grouped_mm_30, torch.float32); _grouped_mm_30 = None + neg_21 = torch.ops.aten.neg.default(convert_element_type_616) + exp_32 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_712 = torch.ops.aten.add.Tensor(exp_32, 1); exp_32 = None + div_54 = torch.ops.aten.div.Tensor(convert_element_type_616, add_712) + convert_element_type_617 = torch.ops.prims.convert_element_type.default(div_54, torch.bfloat16); div_54 = None + mul_1807 = torch.ops.aten.mul.Tensor(_grouped_mm_169, convert_element_type_617); convert_element_type_617 = None + mul_1808 = torch.ops.aten.mul.Tensor(_grouped_mm_169, _grouped_mm_31); _grouped_mm_169 = _grouped_mm_31 = None + permute_1174 = torch.ops.aten.permute.default(mul_1807, [1, 0]) + _grouped_mm_170 = torch.ops.aten._grouped_mm.default(permute_1174, index_21, cumsum_32); permute_1174 = None + _grouped_mm_171 = torch.ops.aten._grouped_mm.default(mul_1807, permute_1176, cumsum_32); mul_1807 = permute_1176 = None + convert_element_type_2588 = torch.ops.prims.convert_element_type.default(mul_1808, torch.float32); mul_1808 = None + reciprocal_31 = torch.ops.aten.reciprocal.default(add_712); add_712 = None + mul_1809 = torch.ops.aten.mul.Tensor(reciprocal_31, 1); reciprocal_31 = None + mul_1810 = torch.ops.aten.mul.Tensor(convert_element_type_2588, mul_1809); convert_element_type_2588 = None + sub_716 = torch.ops.aten.sub.Tensor(1, mul_1809); mul_1809 = None + mul_1811 = torch.ops.aten.mul.Tensor(convert_element_type_616, sub_716); convert_element_type_616 = sub_716 = None + add_2004 = torch.ops.aten.add.Tensor(mul_1811, 1); mul_1811 = None + mul_1812 = torch.ops.aten.mul.Tensor(mul_1810, add_2004); mul_1810 = add_2004 = None + convert_element_type_2590 = torch.ops.prims.convert_element_type.default(mul_1812, torch.bfloat16); mul_1812 = None + permute_1178 = torch.ops.aten.permute.default(convert_element_type_2590, [1, 0]) + _grouped_mm_172 = torch.ops.aten._grouped_mm.default(permute_1178, index_21, cumsum_32); permute_1178 = index_21 = None + _grouped_mm_173 = torch.ops.aten._grouped_mm.default(convert_element_type_2590, permute_1180, cumsum_32); convert_element_type_2590 = permute_1180 = cumsum_32 = None + add_2005 = torch.ops.aten.add.Tensor(_grouped_mm_171, _grouped_mm_173); _grouped_mm_171 = _grouped_mm_173 = None + convert_element_type_2591 = torch.ops.prims.convert_element_type.default(_grouped_mm_170, torch.float32); _grouped_mm_170 = None + div_222 = torch.ops.aten.div.Tensor(convert_element_type_2591, 128); convert_element_type_2591 = None + split_952 = torch.ops.aten.split.Tensor(div_222, 88, 1); div_222 = None + getitem_17885 = split_952[0] + getitem_17902 = split_952[1] + getitem_17919 = split_952[2] + getitem_17936 = split_952[3] + getitem_17953 = split_952[4] + getitem_17970 = split_952[5] + getitem_17987 = split_952[6] + getitem_18004 = split_952[7] + getitem_18021 = split_952[8] + getitem_18038 = split_952[9] + getitem_18055 = split_952[10] + getitem_18072 = split_952[11] + getitem_18089 = split_952[12] + getitem_18106 = split_952[13] + getitem_18123 = split_952[14] + getitem_18140 = split_952[15]; split_952 = None + cat_356 = torch.ops.aten.cat.default([getitem_17885, getitem_17902, getitem_17919, getitem_17936, getitem_17953, getitem_17970, getitem_17987, getitem_18004, getitem_18021, getitem_18038, getitem_18055, getitem_18072, getitem_18089, getitem_18106, getitem_18123, getitem_18140]); getitem_17885 = getitem_17902 = getitem_17919 = getitem_17936 = getitem_17953 = getitem_17970 = getitem_17987 = getitem_18004 = getitem_18021 = getitem_18038 = getitem_18055 = getitem_18072 = getitem_18089 = getitem_18106 = getitem_18123 = getitem_18140 = None + reduce_scatter_tensor_215 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_356, 'sum', 16, '1025'); cat_356 = None + wait_tensor_805 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_215); reduce_scatter_tensor_215 = None + convert_element_type_2592 = torch.ops.prims.convert_element_type.default(_grouped_mm_168, torch.float32); _grouped_mm_168 = None + div_223 = torch.ops.aten.div.Tensor(convert_element_type_2592, 128); convert_element_type_2592 = None + split_969 = torch.ops.aten.split.Tensor(div_223, 128, 1); div_223 = None + getitem_18157 = split_969[0] + getitem_18174 = split_969[1] + getitem_18191 = split_969[2] + getitem_18208 = split_969[3] + getitem_18225 = split_969[4] + getitem_18242 = split_969[5] + getitem_18259 = split_969[6] + getitem_18276 = split_969[7] + getitem_18293 = split_969[8] + getitem_18310 = split_969[9] + getitem_18327 = split_969[10] + getitem_18344 = split_969[11] + getitem_18361 = split_969[12] + getitem_18378 = split_969[13] + getitem_18395 = split_969[14] + getitem_18412 = split_969[15]; split_969 = None + cat_357 = torch.ops.aten.cat.default([getitem_18157, getitem_18174, getitem_18191, getitem_18208, getitem_18225, getitem_18242, getitem_18259, getitem_18276, getitem_18293, getitem_18310, getitem_18327, getitem_18344, getitem_18361, getitem_18378, getitem_18395, getitem_18412]); getitem_18157 = getitem_18174 = getitem_18191 = getitem_18208 = getitem_18225 = getitem_18242 = getitem_18259 = getitem_18276 = getitem_18293 = getitem_18310 = getitem_18327 = getitem_18344 = getitem_18361 = getitem_18378 = getitem_18395 = getitem_18412 = None + reduce_scatter_tensor_216 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_357, 'sum', 16, '1025'); cat_357 = None + wait_tensor_806 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_216); reduce_scatter_tensor_216 = None + convert_element_type_2593 = torch.ops.prims.convert_element_type.default(_grouped_mm_172, torch.float32); _grouped_mm_172 = None + div_224 = torch.ops.aten.div.Tensor(convert_element_type_2593, 128); convert_element_type_2593 = None + split_986 = torch.ops.aten.split.Tensor(div_224, 88, 1); div_224 = None + getitem_18429 = split_986[0] + getitem_18446 = split_986[1] + getitem_18463 = split_986[2] + getitem_18480 = split_986[3] + getitem_18497 = split_986[4] + getitem_18514 = split_986[5] + getitem_18531 = split_986[6] + getitem_18548 = split_986[7] + getitem_18565 = split_986[8] + getitem_18582 = split_986[9] + getitem_18599 = split_986[10] + getitem_18616 = split_986[11] + getitem_18633 = split_986[12] + getitem_18650 = split_986[13] + getitem_18667 = split_986[14] + getitem_18684 = split_986[15]; split_986 = None + cat_358 = torch.ops.aten.cat.default([getitem_18429, getitem_18446, getitem_18463, getitem_18480, getitem_18497, getitem_18514, getitem_18531, getitem_18548, getitem_18565, getitem_18582, getitem_18599, getitem_18616, getitem_18633, getitem_18650, getitem_18667, getitem_18684]); getitem_18429 = getitem_18446 = getitem_18463 = getitem_18480 = getitem_18497 = getitem_18514 = getitem_18531 = getitem_18548 = getitem_18565 = getitem_18582 = getitem_18599 = getitem_18616 = getitem_18633 = getitem_18650 = getitem_18667 = getitem_18684 = None + reduce_scatter_tensor_217 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_358, 'sum', 16, '1025'); cat_358 = None + wait_tensor_807 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_217); reduce_scatter_tensor_217 = None + index_put_82 = torch.ops.aten.index_put.default(full_438, [getitem_1122], add_2005, True); full_438 = getitem_1122 = add_2005 = None + slice_252 = torch.ops.aten.slice.Tensor(index_put_82, 0, 0, add_2006); index_put_82 = add_2006 = None + all_to_all_single_109 = torch.ops._c10d_functional.all_to_all_single.default(slice_252, [_local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167], [_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175], '1033'); slice_252 = _local_scalar_dense_160 = _local_scalar_dense_161 = _local_scalar_dense_162 = _local_scalar_dense_163 = _local_scalar_dense_164 = _local_scalar_dense_165 = _local_scalar_dense_166 = _local_scalar_dense_167 = _local_scalar_dense_168 = _local_scalar_dense_169 = _local_scalar_dense_170 = _local_scalar_dense_171 = _local_scalar_dense_172 = _local_scalar_dense_173 = _local_scalar_dense_174 = _local_scalar_dense_175 = None + wait_tensor_808 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_109); all_to_all_single_109 = None + index_put_83 = torch.ops.aten.index_put.default(full_default_52, [div_52], wait_tensor_808, True); div_52 = wait_tensor_808 = None + add_2010 = torch.ops.aten.add.Tensor(add_2002, index_put_83); add_2002 = index_put_83 = None + mul_1813 = torch.ops.aten.mul.Tensor(view_2068, 1.0); view_2068 = None + scatter_add_15 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1119, mul_1813); getitem_1119 = mul_1813 = None + convert_element_type_605 = torch.ops.prims.convert_element_type.default(mm_91, torch.float32); mm_91 = None + sub_240 = torch.ops.aten.sub.Tensor(convert_element_type_605, amax_10); convert_element_type_605 = amax_10 = None + exp_31 = torch.ops.aten.exp.default(sub_240); sub_240 = None + div_51 = torch.ops.aten.div.Tensor(exp_31, sum_41); exp_31 = sum_41 = None + mul_1814 = torch.ops.aten.mul.Tensor(scatter_add_15, div_51); scatter_add_15 = None + sum_227 = torch.ops.aten.sum.dim_IntList(mul_1814, [1], True) + neg_100 = torch.ops.aten.neg.default(div_51); div_51 = None + fma_15 = torch.ops.prims.fma.default(neg_100, sum_227, mul_1814); neg_100 = sum_227 = mul_1814 = None + convert_element_type_2594 = torch.ops.prims.convert_element_type.default(fma_15, torch.bfloat16); fma_15 = None + permute_1182 = torch.ops.aten.permute.default(convert_element_type_2594, [1, 0]) + mm_464 = torch.ops.aten.mm.default(permute_1182, view_728); permute_1182 = view_728 = None + convert_element_type_602 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_602, 128, '0'); convert_element_type_602 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + slice_69 = torch.ops.aten.slice.Tensor(wait_tensor_228, 0, 0, 64); wait_tensor_228 = None + permute_169 = torch.ops.aten.permute.default(slice_69, [1, 0]); slice_69 = None + permute_1184 = torch.ops.aten.permute.default(permute_169, [1, 0]); permute_169 = None + mm_465 = torch.ops.aten.mm.default(convert_element_type_2594, permute_1184); convert_element_type_2594 = permute_1184 = None + add_2011 = torch.ops.aten.add.Tensor(add_2010, mm_465); add_2010 = mm_465 = None + convert_element_type_2599 = torch.ops.prims.convert_element_type.default(mm_464, torch.float32); mm_464 = None + split_1002 = torch.ops.aten.split.Tensor(convert_element_type_2599, 1); convert_element_type_2599 = None + getitem_18685 = split_1002[0] + getitem_18686 = split_1002[1] + getitem_18687 = split_1002[2] + getitem_18688 = split_1002[3] + getitem_18689 = split_1002[4] + getitem_18690 = split_1002[5] + getitem_18691 = split_1002[6] + getitem_18692 = split_1002[7] + getitem_18693 = split_1002[8] + getitem_18694 = split_1002[9] + getitem_18695 = split_1002[10] + getitem_18696 = split_1002[11] + getitem_18697 = split_1002[12] + getitem_18698 = split_1002[13] + getitem_18699 = split_1002[14] + getitem_18700 = split_1002[15] + getitem_18701 = split_1002[16] + getitem_18702 = split_1002[17] + getitem_18703 = split_1002[18] + getitem_18704 = split_1002[19] + getitem_18705 = split_1002[20] + getitem_18706 = split_1002[21] + getitem_18707 = split_1002[22] + getitem_18708 = split_1002[23] + getitem_18709 = split_1002[24] + getitem_18710 = split_1002[25] + getitem_18711 = split_1002[26] + getitem_18712 = split_1002[27] + getitem_18713 = split_1002[28] + getitem_18714 = split_1002[29] + getitem_18715 = split_1002[30] + getitem_18716 = split_1002[31] + getitem_18717 = split_1002[32] + getitem_18718 = split_1002[33] + getitem_18719 = split_1002[34] + getitem_18720 = split_1002[35] + getitem_18721 = split_1002[36] + getitem_18722 = split_1002[37] + getitem_18723 = split_1002[38] + getitem_18724 = split_1002[39] + getitem_18725 = split_1002[40] + getitem_18726 = split_1002[41] + getitem_18727 = split_1002[42] + getitem_18728 = split_1002[43] + getitem_18729 = split_1002[44] + getitem_18730 = split_1002[45] + getitem_18731 = split_1002[46] + getitem_18732 = split_1002[47] + getitem_18733 = split_1002[48] + getitem_18734 = split_1002[49] + getitem_18735 = split_1002[50] + getitem_18736 = split_1002[51] + getitem_18737 = split_1002[52] + getitem_18738 = split_1002[53] + getitem_18739 = split_1002[54] + getitem_18740 = split_1002[55] + getitem_18741 = split_1002[56] + getitem_18742 = split_1002[57] + getitem_18743 = split_1002[58] + getitem_18744 = split_1002[59] + getitem_18745 = split_1002[60] + getitem_18746 = split_1002[61] + getitem_18747 = split_1002[62] + getitem_18748 = split_1002[63]; split_1002 = None + cat_359 = torch.ops.aten.cat.default([getitem_18685, getitem_18686, getitem_18687, getitem_18688, getitem_18689, getitem_18690, getitem_18691, getitem_18692, getitem_18693, getitem_18694, getitem_18695, getitem_18696, getitem_18697, getitem_18698, getitem_18699, getitem_18700, getitem_18701, getitem_18702, getitem_18703, getitem_18704, getitem_18705, getitem_18706, getitem_18707, getitem_18708, getitem_18709, getitem_18710, getitem_18711, getitem_18712, getitem_18713, getitem_18714, getitem_18715, getitem_18716, getitem_18717, getitem_18718, getitem_18719, getitem_18720, getitem_18721, getitem_18722, getitem_18723, getitem_18724, getitem_18725, getitem_18726, getitem_18727, getitem_18728, getitem_18729, getitem_18730, getitem_18731, getitem_18732, getitem_18733, getitem_18734, getitem_18735, getitem_18736, getitem_18737, getitem_18738, getitem_18739, getitem_18740, getitem_18741, getitem_18742, getitem_18743, getitem_18744, getitem_18745, getitem_18746, getitem_18747, getitem_18748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_18685 = getitem_18686 = getitem_18687 = getitem_18688 = getitem_18689 = getitem_18690 = getitem_18691 = getitem_18692 = getitem_18693 = getitem_18694 = getitem_18695 = getitem_18696 = getitem_18697 = getitem_18698 = getitem_18699 = getitem_18700 = getitem_18701 = getitem_18702 = getitem_18703 = getitem_18704 = getitem_18705 = getitem_18706 = getitem_18707 = getitem_18708 = getitem_18709 = getitem_18710 = getitem_18711 = getitem_18712 = getitem_18713 = getitem_18714 = getitem_18715 = getitem_18716 = getitem_18717 = getitem_18718 = getitem_18719 = getitem_18720 = getitem_18721 = getitem_18722 = getitem_18723 = getitem_18724 = getitem_18725 = getitem_18726 = getitem_18727 = getitem_18728 = getitem_18729 = getitem_18730 = getitem_18731 = getitem_18732 = getitem_18733 = getitem_18734 = getitem_18735 = getitem_18736 = getitem_18737 = getitem_18738 = getitem_18739 = getitem_18740 = getitem_18741 = getitem_18742 = getitem_18743 = getitem_18744 = getitem_18745 = getitem_18746 = getitem_18747 = getitem_18748 = None + reduce_scatter_tensor_218 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_359, 'avg', 128, '0'); cat_359 = None + wait_tensor_809 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_218); reduce_scatter_tensor_218 = None + view_2070 = torch.ops.aten.view.default(add_2011, [2, 4096, 2048]); add_2011 = None + convert_element_type_2600 = torch.ops.prims.convert_element_type.default(view_2070, torch.float32); view_2070 = None + convert_element_type_599 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16); primals_189 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_599, 128, '0'); convert_element_type_599 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + convert_element_type_2602 = torch.ops.prims.convert_element_type.default(wait_tensor_227, torch.float32); wait_tensor_227 = None + mul_1815 = torch.ops.aten.mul.Tensor(convert_element_type_2600, convert_element_type_2602); convert_element_type_2602 = None + convert_element_type_600 = torch.ops.prims.convert_element_type.default(add_688, torch.float32); add_688 = None + mul_505 = torch.ops.aten.mul.Tensor(convert_element_type_600, rsqrt_35); convert_element_type_600 = None + mul_1817 = torch.ops.aten.mul.Tensor(mul_505, mul_1815) + sum_228 = torch.ops.aten.sum.dim_IntList(mul_1817, [2], True); mul_1817 = None + div_225 = torch.ops.aten.div.Tensor(mul_505, 2048) + mul_1818 = torch.ops.aten.mul.Tensor(div_225, sum_228); div_225 = sum_228 = None + sub_718 = torch.ops.aten.sub.Tensor(mul_1815, mul_1818); mul_1815 = mul_1818 = None + mul_1819 = torch.ops.aten.mul.Tensor(sub_718, rsqrt_35); sub_718 = rsqrt_35 = None + mul_1820 = torch.ops.aten.mul.Tensor(convert_element_type_2600, mul_505); convert_element_type_2600 = mul_505 = None + sum_229 = torch.ops.aten.sum.dim_IntList(mul_1820, [0, 1]); mul_1820 = None + convert_element_type_2603 = torch.ops.prims.convert_element_type.default(mul_1819, torch.bfloat16); mul_1819 = None + add_2012 = torch.ops.aten.add.Tensor(add_1999, convert_element_type_2603); add_1999 = convert_element_type_2603 = None + convert_element_type_default_36 = torch.ops.prims.convert_element_type.default(sum_229, torch.float32); sum_229 = None + reduce_scatter_tensor_219 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_36, 'avg', 128, '0'); convert_element_type_default_36 = None + wait_tensor_810 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_219); reduce_scatter_tensor_219 = None + view_2071 = torch.ops.aten.view.default(add_2012, [8192, 2048]) + permute_1186 = torch.ops.aten.permute.default(view_2071, [1, 0]) + permute_167 = torch.ops.aten.permute.default(getitem_1115, [0, 2, 1, 3]) + view_723 = torch.ops.aten.view.default(permute_167, [2, 4096, -1]); permute_167 = None + view_725 = torch.ops.aten.view.default(view_723, [8192, 2048]); view_723 = None + mm_466 = torch.ops.aten.mm.default(permute_1186, view_725); permute_1186 = view_725 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16); primals_188 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_596, 128, '0'); convert_element_type_596 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_168 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + permute_1188 = torch.ops.aten.permute.default(permute_168, [1, 0]); permute_168 = None + mm_467 = torch.ops.aten.mm.default(view_2071, permute_1188); view_2071 = permute_1188 = None + view_2072 = torch.ops.aten.view.default(mm_467, [2, 4096, 2048]); mm_467 = None + convert_element_type_2610 = torch.ops.prims.convert_element_type.default(mm_466, torch.float32); mm_466 = None + reduce_scatter_tensor_220 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2610, 'avg', 128, '0'); convert_element_type_2610 = None + wait_tensor_811 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_220); reduce_scatter_tensor_220 = None + view_2073 = torch.ops.aten.view.default(view_2072, [2, 4096, 16, 128]); view_2072 = None + permute_1190 = torch.ops.aten.permute.default(view_2073, [0, 2, 1, 3]); view_2073 = None + fw_graph15 = self.fw_graph15 + joint_graph15 = self.joint_graph15 + mask_graph15 = self.mask_graph15 + flex_attention_backward_15 = torch.ops.higher_order.flex_attention_backward(permute_164, permute_165, permute_166, getitem_1115, getitem_1116, permute_1190, None, fw_graph15, joint_graph15, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph15), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_164 = permute_165 = permute_166 = getitem_1115 = getitem_1116 = permute_1190 = fw_graph15 = joint_graph15 = mask_graph15 = None + getitem_18749 = flex_attention_backward_15[0] + getitem_18750 = flex_attention_backward_15[1] + getitem_18751 = flex_attention_backward_15[2]; flex_attention_backward_15 = None + permute_1191 = torch.ops.aten.permute.default(getitem_18751, [0, 2, 1, 3]); getitem_18751 = None + permute_1192 = torch.ops.aten.permute.default(getitem_18750, [0, 2, 1, 3]); getitem_18750 = None + permute_1193 = torch.ops.aten.permute.default(getitem_18749, [0, 2, 1, 3]); getitem_18749 = None + slice_254 = torch.ops.aten.slice.Tensor(permute_1192, 3, 0, 128) + slice_255 = torch.ops.aten.slice.Tensor(permute_1192, 3, 128, 192); permute_1192 = None + sum_230 = torch.ops.aten.sum.dim_IntList(slice_255, [2], True); slice_255 = None + cat_360 = torch.ops.aten.cat.default([slice_254, permute_1191], 3); slice_254 = permute_1191 = None + view_2074 = torch.ops.aten.view.default(cat_360, [2, 4096, 4096]); cat_360 = None + view_2075 = torch.ops.aten.view.default(view_2074, [8192, 4096]); view_2074 = None + permute_1194 = torch.ops.aten.permute.default(view_2075, [1, 0]) + mm_468 = torch.ops.aten.mm.default(permute_1194, view_720); permute_1194 = view_720 = None + convert_element_type_593 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16); primals_187 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_593, 128, '0'); convert_element_type_593 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + permute_1196 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_469 = torch.ops.aten.mm.default(view_2075, permute_1196); view_2075 = permute_1196 = None + view_2076 = torch.ops.aten.view.default(mm_469, [2, 4096, 512]); mm_469 = None + convert_element_type_2615 = torch.ops.prims.convert_element_type.default(mm_468, torch.float32); mm_468 = None + reduce_scatter_tensor_221 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2615, 'avg', 128, '0'); convert_element_type_2615 = None + wait_tensor_812 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_221); reduce_scatter_tensor_221 = None + convert_element_type_2616 = torch.ops.prims.convert_element_type.default(view_2076, torch.float32); view_2076 = None + convert_element_type_590 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16); primals_186 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_590, 128, '0'); convert_element_type_590 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + convert_element_type_2618 = torch.ops.prims.convert_element_type.default(wait_tensor_224, torch.float32); wait_tensor_224 = None + mul_1821 = torch.ops.aten.mul.Tensor(convert_element_type_2616, convert_element_type_2618); convert_element_type_2618 = None + convert_element_type_591 = torch.ops.prims.convert_element_type.default(getitem_1111, torch.float32); getitem_1111 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_591, rsqrt_34); convert_element_type_591 = None + mul_1823 = torch.ops.aten.mul.Tensor(mul_503, mul_1821) + sum_231 = torch.ops.aten.sum.dim_IntList(mul_1823, [2], True); mul_1823 = None + div_226 = torch.ops.aten.div.Tensor(mul_503, 512) + mul_1824 = torch.ops.aten.mul.Tensor(div_226, sum_231); div_226 = sum_231 = None + sub_719 = torch.ops.aten.sub.Tensor(mul_1821, mul_1824); mul_1821 = mul_1824 = None + mul_1825 = torch.ops.aten.mul.Tensor(sub_719, rsqrt_34); sub_719 = rsqrt_34 = None + mul_1826 = torch.ops.aten.mul.Tensor(convert_element_type_2616, mul_503); convert_element_type_2616 = mul_503 = None + sum_232 = torch.ops.aten.sum.dim_IntList(mul_1826, [0, 1]); mul_1826 = None + convert_element_type_2619 = torch.ops.prims.convert_element_type.default(mul_1825, torch.bfloat16); mul_1825 = None + convert_element_type_default_35 = torch.ops.prims.convert_element_type.default(sum_232, torch.float32); sum_232 = None + reduce_scatter_tensor_222 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_35, 'avg', 128, '0'); convert_element_type_default_35 = None + wait_tensor_813 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_222); reduce_scatter_tensor_222 = None + convert_element_type_2622 = torch.ops.prims.convert_element_type.default(sum_230, torch.float32); sum_230 = None + view_2077 = torch.ops.aten.view.default(convert_element_type_2622, [2, 4096, 1, 32, 2]); convert_element_type_2622 = None + view_as_complex_84 = torch.ops.aten.view_as_complex.default(view_2077); view_2077 = None + mul_1827 = torch.ops.aten.mul.Tensor(view_as_complex_84, clone_9); view_as_complex_84 = None + view_as_real_84 = torch.ops.aten.view_as_real.default(mul_1827); mul_1827 = None + view_2078 = torch.ops.aten.view.default(view_as_real_84, [2, 4096, 1, 64]); view_as_real_84 = None + convert_element_type_2623 = torch.ops.prims.convert_element_type.default(view_2078, torch.bfloat16); view_2078 = None + squeeze_41 = torch.ops.aten.squeeze.dim(convert_element_type_2623, 2); convert_element_type_2623 = None + cat_361 = torch.ops.aten.cat.default([convert_element_type_2619, squeeze_41], 2); convert_element_type_2619 = squeeze_41 = None + view_2079 = torch.ops.aten.view.default(cat_361, [8192, 576]); cat_361 = None + permute_1198 = torch.ops.aten.permute.default(view_2079, [1, 0]) + mm_470 = torch.ops.aten.mm.default(permute_1198, view_706); permute_1198 = None + convert_element_type_585 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16); primals_185 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_585, 128, '0'); convert_element_type_585 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + slice_67 = torch.ops.aten.slice.Tensor(wait_tensor_223, 0, 0, 576); wait_tensor_223 = None + permute_162 = torch.ops.aten.permute.default(slice_67, [1, 0]); slice_67 = None + permute_1200 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_471 = torch.ops.aten.mm.default(view_2079, permute_1200); view_2079 = permute_1200 = None + view_2080 = torch.ops.aten.view.default(mm_471, [2, 4096, 2048]); mm_471 = None + convert_element_type_2628 = torch.ops.prims.convert_element_type.default(mm_470, torch.float32); mm_470 = None + split_1003 = torch.ops.aten.split.Tensor(convert_element_type_2628, 5); convert_element_type_2628 = None + getitem_18753 = split_1003[0] + getitem_18754 = split_1003[1] + getitem_18755 = split_1003[2] + getitem_18756 = split_1003[3] + getitem_18757 = split_1003[4] + getitem_18758 = split_1003[5] + getitem_18759 = split_1003[6] + getitem_18760 = split_1003[7] + getitem_18761 = split_1003[8] + getitem_18762 = split_1003[9] + getitem_18763 = split_1003[10] + getitem_18764 = split_1003[11] + getitem_18765 = split_1003[12] + getitem_18766 = split_1003[13] + getitem_18767 = split_1003[14] + getitem_18768 = split_1003[15] + getitem_18769 = split_1003[16] + getitem_18770 = split_1003[17] + getitem_18771 = split_1003[18] + getitem_18772 = split_1003[19] + getitem_18773 = split_1003[20] + getitem_18774 = split_1003[21] + getitem_18775 = split_1003[22] + getitem_18776 = split_1003[23] + getitem_18777 = split_1003[24] + getitem_18778 = split_1003[25] + getitem_18779 = split_1003[26] + getitem_18780 = split_1003[27] + getitem_18781 = split_1003[28] + getitem_18782 = split_1003[29] + getitem_18783 = split_1003[30] + getitem_18784 = split_1003[31] + getitem_18785 = split_1003[32] + getitem_18786 = split_1003[33] + getitem_18787 = split_1003[34] + getitem_18788 = split_1003[35] + getitem_18789 = split_1003[36] + getitem_18790 = split_1003[37] + getitem_18791 = split_1003[38] + getitem_18792 = split_1003[39] + getitem_18793 = split_1003[40] + getitem_18794 = split_1003[41] + getitem_18795 = split_1003[42] + getitem_18796 = split_1003[43] + getitem_18797 = split_1003[44] + getitem_18798 = split_1003[45] + getitem_18799 = split_1003[46] + getitem_18800 = split_1003[47] + getitem_18801 = split_1003[48] + getitem_18802 = split_1003[49] + getitem_18803 = split_1003[50] + getitem_18804 = split_1003[51] + getitem_18805 = split_1003[52] + getitem_18806 = split_1003[53] + getitem_18807 = split_1003[54] + getitem_18808 = split_1003[55] + getitem_18809 = split_1003[56] + getitem_18810 = split_1003[57] + getitem_18811 = split_1003[58] + getitem_18812 = split_1003[59] + getitem_18813 = split_1003[60] + getitem_18814 = split_1003[61] + getitem_18815 = split_1003[62] + getitem_18816 = split_1003[63] + getitem_18817 = split_1003[64] + getitem_18818 = split_1003[65] + getitem_18819 = split_1003[66] + getitem_18820 = split_1003[67] + getitem_18821 = split_1003[68] + getitem_18822 = split_1003[69] + getitem_18823 = split_1003[70] + getitem_18824 = split_1003[71] + getitem_18825 = split_1003[72] + getitem_18826 = split_1003[73] + getitem_18827 = split_1003[74] + getitem_18828 = split_1003[75] + getitem_18829 = split_1003[76] + getitem_18830 = split_1003[77] + getitem_18831 = split_1003[78] + getitem_18832 = split_1003[79] + getitem_18833 = split_1003[80] + getitem_18834 = split_1003[81] + getitem_18835 = split_1003[82] + getitem_18836 = split_1003[83] + getitem_18837 = split_1003[84] + getitem_18838 = split_1003[85] + getitem_18839 = split_1003[86] + getitem_18840 = split_1003[87] + getitem_18841 = split_1003[88] + getitem_18842 = split_1003[89] + getitem_18843 = split_1003[90] + getitem_18844 = split_1003[91] + getitem_18845 = split_1003[92] + getitem_18846 = split_1003[93] + getitem_18847 = split_1003[94] + getitem_18848 = split_1003[95] + getitem_18849 = split_1003[96] + getitem_18850 = split_1003[97] + getitem_18851 = split_1003[98] + getitem_18852 = split_1003[99] + getitem_18853 = split_1003[100] + getitem_18854 = split_1003[101] + getitem_18855 = split_1003[102] + getitem_18856 = split_1003[103] + getitem_18857 = split_1003[104] + getitem_18858 = split_1003[105] + getitem_18859 = split_1003[106] + getitem_18860 = split_1003[107] + getitem_18861 = split_1003[108] + getitem_18862 = split_1003[109] + getitem_18863 = split_1003[110] + getitem_18864 = split_1003[111] + getitem_18865 = split_1003[112] + getitem_18866 = split_1003[113] + getitem_18867 = split_1003[114] + getitem_18868 = split_1003[115]; split_1003 = None + constant_pad_nd_1219 = torch.ops.aten.constant_pad_nd.default(getitem_18868, [0, 0, 0, 4], 0.0); getitem_18868 = None + cat_362 = torch.ops.aten.cat.default([getitem_18753, getitem_18754, getitem_18755, getitem_18756, getitem_18757, getitem_18758, getitem_18759, getitem_18760, getitem_18761, getitem_18762, getitem_18763, getitem_18764, getitem_18765, getitem_18766, getitem_18767, getitem_18768, getitem_18769, getitem_18770, getitem_18771, getitem_18772, getitem_18773, getitem_18774, getitem_18775, getitem_18776, getitem_18777, getitem_18778, getitem_18779, getitem_18780, getitem_18781, getitem_18782, getitem_18783, getitem_18784, getitem_18785, getitem_18786, getitem_18787, getitem_18788, getitem_18789, getitem_18790, getitem_18791, getitem_18792, getitem_18793, getitem_18794, getitem_18795, getitem_18796, getitem_18797, getitem_18798, getitem_18799, getitem_18800, getitem_18801, getitem_18802, getitem_18803, getitem_18804, getitem_18805, getitem_18806, getitem_18807, getitem_18808, getitem_18809, getitem_18810, getitem_18811, getitem_18812, getitem_18813, getitem_18814, getitem_18815, getitem_18816, getitem_18817, getitem_18818, getitem_18819, getitem_18820, getitem_18821, getitem_18822, getitem_18823, getitem_18824, getitem_18825, getitem_18826, getitem_18827, getitem_18828, getitem_18829, getitem_18830, getitem_18831, getitem_18832, getitem_18833, getitem_18834, getitem_18835, getitem_18836, getitem_18837, getitem_18838, getitem_18839, getitem_18840, getitem_18841, getitem_18842, getitem_18843, getitem_18844, getitem_18845, getitem_18846, getitem_18847, getitem_18848, getitem_18849, getitem_18850, getitem_18851, getitem_18852, getitem_18853, getitem_18854, getitem_18855, getitem_18856, getitem_18857, getitem_18858, getitem_18859, getitem_18860, getitem_18861, getitem_18862, getitem_18863, getitem_18864, getitem_18865, getitem_18866, getitem_18867, constant_pad_nd_1219, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_18753 = getitem_18754 = getitem_18755 = getitem_18756 = getitem_18757 = getitem_18758 = getitem_18759 = getitem_18760 = getitem_18761 = getitem_18762 = getitem_18763 = getitem_18764 = getitem_18765 = getitem_18766 = getitem_18767 = getitem_18768 = getitem_18769 = getitem_18770 = getitem_18771 = getitem_18772 = getitem_18773 = getitem_18774 = getitem_18775 = getitem_18776 = getitem_18777 = getitem_18778 = getitem_18779 = getitem_18780 = getitem_18781 = getitem_18782 = getitem_18783 = getitem_18784 = getitem_18785 = getitem_18786 = getitem_18787 = getitem_18788 = getitem_18789 = getitem_18790 = getitem_18791 = getitem_18792 = getitem_18793 = getitem_18794 = getitem_18795 = getitem_18796 = getitem_18797 = getitem_18798 = getitem_18799 = getitem_18800 = getitem_18801 = getitem_18802 = getitem_18803 = getitem_18804 = getitem_18805 = getitem_18806 = getitem_18807 = getitem_18808 = getitem_18809 = getitem_18810 = getitem_18811 = getitem_18812 = getitem_18813 = getitem_18814 = getitem_18815 = getitem_18816 = getitem_18817 = getitem_18818 = getitem_18819 = getitem_18820 = getitem_18821 = getitem_18822 = getitem_18823 = getitem_18824 = getitem_18825 = getitem_18826 = getitem_18827 = getitem_18828 = getitem_18829 = getitem_18830 = getitem_18831 = getitem_18832 = getitem_18833 = getitem_18834 = getitem_18835 = getitem_18836 = getitem_18837 = getitem_18838 = getitem_18839 = getitem_18840 = getitem_18841 = getitem_18842 = getitem_18843 = getitem_18844 = getitem_18845 = getitem_18846 = getitem_18847 = getitem_18848 = getitem_18849 = getitem_18850 = getitem_18851 = getitem_18852 = getitem_18853 = getitem_18854 = getitem_18855 = getitem_18856 = getitem_18857 = getitem_18858 = getitem_18859 = getitem_18860 = getitem_18861 = getitem_18862 = getitem_18863 = getitem_18864 = getitem_18865 = getitem_18866 = getitem_18867 = constant_pad_nd_1219 = None + reduce_scatter_tensor_223 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_362, 'avg', 128, '0'); cat_362 = None + wait_tensor_814 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_223); reduce_scatter_tensor_223 = None + slice_256 = torch.ops.aten.slice.Tensor(permute_1193, 3, 0, 128) + slice_257 = torch.ops.aten.slice.Tensor(permute_1193, 3, 128, 192); permute_1193 = None + convert_element_type_2629 = torch.ops.prims.convert_element_type.default(slice_257, torch.float32); slice_257 = None + view_2081 = torch.ops.aten.view.default(convert_element_type_2629, [2, 4096, 16, 32, 2]); convert_element_type_2629 = None + view_as_complex_85 = torch.ops.aten.view_as_complex.default(view_2081); view_2081 = None + mul_1828 = torch.ops.aten.mul.Tensor(view_as_complex_85, clone_9); view_as_complex_85 = None + view_as_real_85 = torch.ops.aten.view_as_real.default(mul_1828); mul_1828 = None + view_2082 = torch.ops.aten.view.default(view_as_real_85, [2, 4096, 16, 64]); view_as_real_85 = None + convert_element_type_2630 = torch.ops.prims.convert_element_type.default(view_2082, torch.bfloat16); view_2082 = None + cat_363 = torch.ops.aten.cat.default([slice_256, convert_element_type_2630], 3); slice_256 = convert_element_type_2630 = None + view_2083 = torch.ops.aten.view.default(cat_363, [2, 4096, 3072]); cat_363 = None + view_2084 = torch.ops.aten.view.default(view_2083, [8192, 3072]); view_2083 = None + permute_1202 = torch.ops.aten.permute.default(view_2084, [1, 0]) + mm_472 = torch.ops.aten.mm.default(permute_1202, view_706); permute_1202 = view_706 = None + convert_element_type_580 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16); primals_184 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_580, 128, '0'); convert_element_type_580 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_222, [1, 0]); wait_tensor_222 = None + permute_1204 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_473 = torch.ops.aten.mm.default(view_2084, permute_1204); view_2084 = permute_1204 = None + view_2085 = torch.ops.aten.view.default(mm_473, [2, 4096, 2048]); mm_473 = None + add_2013 = torch.ops.aten.add.Tensor(view_2080, view_2085); view_2080 = view_2085 = None + convert_element_type_2635 = torch.ops.prims.convert_element_type.default(mm_472, torch.float32); mm_472 = None + reduce_scatter_tensor_224 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2635, 'avg', 128, '0'); convert_element_type_2635 = None + wait_tensor_815 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_224); reduce_scatter_tensor_224 = None + convert_element_type_2636 = torch.ops.prims.convert_element_type.default(add_2013, torch.float32); add_2013 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16); primals_183 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_577, 128, '0'); convert_element_type_577 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_2638 = torch.ops.prims.convert_element_type.default(wait_tensor_221, torch.float32); wait_tensor_221 = None + mul_1829 = torch.ops.aten.mul.Tensor(convert_element_type_2636, convert_element_type_2638); convert_element_type_2638 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(add_685, torch.float32); add_685 = None + mul_499 = torch.ops.aten.mul.Tensor(convert_element_type_578, rsqrt_33); convert_element_type_578 = None + mul_1831 = torch.ops.aten.mul.Tensor(mul_499, mul_1829) + sum_233 = torch.ops.aten.sum.dim_IntList(mul_1831, [2], True); mul_1831 = None + div_227 = torch.ops.aten.div.Tensor(mul_499, 2048) + mul_1832 = torch.ops.aten.mul.Tensor(div_227, sum_233); div_227 = sum_233 = None + sub_720 = torch.ops.aten.sub.Tensor(mul_1829, mul_1832); mul_1829 = mul_1832 = None + mul_1833 = torch.ops.aten.mul.Tensor(sub_720, rsqrt_33); sub_720 = rsqrt_33 = None + mul_1834 = torch.ops.aten.mul.Tensor(convert_element_type_2636, mul_499); convert_element_type_2636 = mul_499 = None + sum_234 = torch.ops.aten.sum.dim_IntList(mul_1834, [0, 1]); mul_1834 = None + convert_element_type_2639 = torch.ops.prims.convert_element_type.default(mul_1833, torch.bfloat16); mul_1833 = None + add_2014 = torch.ops.aten.add.Tensor(add_2012, convert_element_type_2639); add_2012 = convert_element_type_2639 = None + convert_element_type_default_34 = torch.ops.prims.convert_element_type.default(sum_234, torch.float32); sum_234 = None + reduce_scatter_tensor_225 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_34, 'avg', 128, '0'); convert_element_type_default_34 = None + wait_tensor_816 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_225); reduce_scatter_tensor_225 = None + view_2086 = torch.ops.aten.view.default(add_2014, [8192, 2048]) + unsqueeze_69 = torch.ops.aten.unsqueeze.default(view_2086, 1) + convert_element_type_2642 = torch.ops.prims.convert_element_type.default(unsqueeze_69, torch.float32); unsqueeze_69 = None + bmm_58 = torch.ops.aten.bmm.default(permute_1206, convert_element_type_2642); permute_1206 = None + bmm_59 = torch.ops.aten.bmm.default(convert_element_type_2642, permute_1207); convert_element_type_2642 = permute_1207 = None + convert_element_type_2643 = torch.ops.prims.convert_element_type.default(bmm_58, torch.bfloat16); bmm_58 = None + view_2087 = torch.ops.aten.view.default(bmm_59, [8192, 6]); bmm_59 = None + view_2088 = torch.ops.aten.view.default(convert_element_type_2643, [49152, 2048]); convert_element_type_2643 = None + index_84 = torch.ops.aten.index.Tensor(view_2088, [getitem_1011]); view_2088 = getitem_1011 = None + permute_1208 = torch.ops.aten.permute.default(view_2086, [1, 0]) + mm_474 = torch.ops.aten.mm.default(permute_1208, mul_496); permute_1208 = mul_496 = None + convert_element_type_572 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16); primals_182 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_572, 128, '0'); convert_element_type_572 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_160 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + permute_1210 = torch.ops.aten.permute.default(permute_160, [1, 0]); permute_160 = None + mm_475 = torch.ops.aten.mm.default(view_2086, permute_1210); view_2086 = permute_1210 = None + convert_element_type_2648 = torch.ops.prims.convert_element_type.default(mm_474, torch.float32); mm_474 = None + reduce_scatter_tensor_226 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2648, 'avg', 128, '0'); convert_element_type_2648 = None + wait_tensor_817 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_226); reduce_scatter_tensor_226 = None + convert_element_type_567 = torch.ops.prims.convert_element_type.default(mm_84, torch.float32); mm_84 = None + neg_20 = torch.ops.aten.neg.default(convert_element_type_567) + exp_30 = torch.ops.aten.exp.default(neg_20); neg_20 = None + add_680 = torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + div_50 = torch.ops.aten.div.Tensor(convert_element_type_567, add_680) + convert_element_type_568 = torch.ops.prims.convert_element_type.default(div_50, torch.bfloat16); div_50 = None + mul_1835 = torch.ops.aten.mul.Tensor(mm_475, convert_element_type_568); convert_element_type_568 = None + mul_1836 = torch.ops.aten.mul.Tensor(mm_475, mm_85); mm_475 = mm_85 = None + permute_1212 = torch.ops.aten.permute.default(mul_1835, [1, 0]) + mm_476 = torch.ops.aten.mm.default(permute_1212, view_661); permute_1212 = None + convert_element_type_569 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16); primals_181 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_569, 128, '0'); convert_element_type_569 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_159 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + permute_1214 = torch.ops.aten.permute.default(permute_159, [1, 0]); permute_159 = None + mm_477 = torch.ops.aten.mm.default(mul_1835, permute_1214); mul_1835 = permute_1214 = None + convert_element_type_2653 = torch.ops.prims.convert_element_type.default(mm_476, torch.float32); mm_476 = None + reduce_scatter_tensor_227 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2653, 'avg', 128, '0'); convert_element_type_2653 = None + wait_tensor_818 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_227); reduce_scatter_tensor_227 = None + convert_element_type_2654 = torch.ops.prims.convert_element_type.default(mul_1836, torch.float32); mul_1836 = None + reciprocal_32 = torch.ops.aten.reciprocal.default(add_680); add_680 = None + mul_1837 = torch.ops.aten.mul.Tensor(reciprocal_32, 1); reciprocal_32 = None + mul_1838 = torch.ops.aten.mul.Tensor(convert_element_type_2654, mul_1837); convert_element_type_2654 = None + sub_721 = torch.ops.aten.sub.Tensor(1, mul_1837); mul_1837 = None + mul_1839 = torch.ops.aten.mul.Tensor(convert_element_type_567, sub_721); convert_element_type_567 = sub_721 = None + add_2016 = torch.ops.aten.add.Tensor(mul_1839, 1); mul_1839 = None + mul_1840 = torch.ops.aten.mul.Tensor(mul_1838, add_2016); mul_1838 = add_2016 = None + convert_element_type_2656 = torch.ops.prims.convert_element_type.default(mul_1840, torch.bfloat16); mul_1840 = None + permute_1216 = torch.ops.aten.permute.default(convert_element_type_2656, [1, 0]) + mm_478 = torch.ops.aten.mm.default(permute_1216, view_661); permute_1216 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 128, '0'); convert_element_type_564 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_158 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + permute_1218 = torch.ops.aten.permute.default(permute_158, [1, 0]); permute_158 = None + mm_479 = torch.ops.aten.mm.default(convert_element_type_2656, permute_1218); convert_element_type_2656 = permute_1218 = None + add_2017 = torch.ops.aten.add.Tensor(mm_477, mm_479); mm_477 = mm_479 = None + convert_element_type_2661 = torch.ops.prims.convert_element_type.default(mm_478, torch.float32); mm_478 = None + reduce_scatter_tensor_228 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2661, 'avg', 128, '0'); convert_element_type_2661 = None + wait_tensor_819 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_228); reduce_scatter_tensor_228 = None + all_to_all_single_110 = torch.ops._c10d_functional.all_to_all_single.default(index_84, [_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159], [_local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151], '1033'); index_84 = None + wait_tensor_820 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_110); all_to_all_single_110 = None + full_444 = torch.ops.aten.full.default([sym_size_int_37, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_37 = None + slice_scatter_16 = torch.ops.aten.slice_scatter.default(full_444, wait_tensor_820, 0, 0, -1); wait_tensor_820 = None + index_85 = torch.ops.aten.index.Tensor(slice_scatter_16, [getitem_1012]); slice_scatter_16 = None + permute_1220 = torch.ops.aten.permute.default(index_85, [1, 0]) + _grouped_mm_174 = torch.ops.aten._grouped_mm.default(permute_1220, mul_476, cumsum_29); permute_1220 = mul_476 = None + _grouped_mm_175 = torch.ops.aten._grouped_mm.default(index_85, permute_1222, cumsum_29); index_85 = permute_1222 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(_grouped_mm_27, torch.float32); _grouped_mm_27 = None + neg_19 = torch.ops.aten.neg.default(convert_element_type_562) + exp_29 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_644 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + div_49 = torch.ops.aten.div.Tensor(convert_element_type_562, add_644) + convert_element_type_563 = torch.ops.prims.convert_element_type.default(div_49, torch.bfloat16); div_49 = None + mul_1841 = torch.ops.aten.mul.Tensor(_grouped_mm_175, convert_element_type_563); convert_element_type_563 = None + mul_1842 = torch.ops.aten.mul.Tensor(_grouped_mm_175, _grouped_mm_28); _grouped_mm_175 = _grouped_mm_28 = None + permute_1224 = torch.ops.aten.permute.default(mul_1841, [1, 0]) + _grouped_mm_176 = torch.ops.aten._grouped_mm.default(permute_1224, index_19, cumsum_29); permute_1224 = None + _grouped_mm_177 = torch.ops.aten._grouped_mm.default(mul_1841, permute_1226, cumsum_29); mul_1841 = permute_1226 = None + convert_element_type_2662 = torch.ops.prims.convert_element_type.default(mul_1842, torch.float32); mul_1842 = None + reciprocal_33 = torch.ops.aten.reciprocal.default(add_644); add_644 = None + mul_1843 = torch.ops.aten.mul.Tensor(reciprocal_33, 1); reciprocal_33 = None + mul_1844 = torch.ops.aten.mul.Tensor(convert_element_type_2662, mul_1843); convert_element_type_2662 = None + sub_722 = torch.ops.aten.sub.Tensor(1, mul_1843); mul_1843 = None + mul_1845 = torch.ops.aten.mul.Tensor(convert_element_type_562, sub_722); convert_element_type_562 = sub_722 = None + add_2019 = torch.ops.aten.add.Tensor(mul_1845, 1); mul_1845 = None + mul_1846 = torch.ops.aten.mul.Tensor(mul_1844, add_2019); mul_1844 = add_2019 = None + convert_element_type_2664 = torch.ops.prims.convert_element_type.default(mul_1846, torch.bfloat16); mul_1846 = None + permute_1228 = torch.ops.aten.permute.default(convert_element_type_2664, [1, 0]) + _grouped_mm_178 = torch.ops.aten._grouped_mm.default(permute_1228, index_19, cumsum_29); permute_1228 = index_19 = None + _grouped_mm_179 = torch.ops.aten._grouped_mm.default(convert_element_type_2664, permute_1230, cumsum_29); convert_element_type_2664 = permute_1230 = cumsum_29 = None + add_2020 = torch.ops.aten.add.Tensor(_grouped_mm_177, _grouped_mm_179); _grouped_mm_177 = _grouped_mm_179 = None + convert_element_type_2665 = torch.ops.prims.convert_element_type.default(_grouped_mm_176, torch.float32); _grouped_mm_176 = None + div_228 = torch.ops.aten.div.Tensor(convert_element_type_2665, 128); convert_element_type_2665 = None + split_1005 = torch.ops.aten.split.Tensor(div_228, 88, 1); div_228 = None + getitem_18885 = split_1005[0] + getitem_18902 = split_1005[1] + getitem_18919 = split_1005[2] + getitem_18936 = split_1005[3] + getitem_18953 = split_1005[4] + getitem_18970 = split_1005[5] + getitem_18987 = split_1005[6] + getitem_19004 = split_1005[7] + getitem_19021 = split_1005[8] + getitem_19038 = split_1005[9] + getitem_19055 = split_1005[10] + getitem_19072 = split_1005[11] + getitem_19089 = split_1005[12] + getitem_19106 = split_1005[13] + getitem_19123 = split_1005[14] + getitem_19140 = split_1005[15]; split_1005 = None + cat_364 = torch.ops.aten.cat.default([getitem_18885, getitem_18902, getitem_18919, getitem_18936, getitem_18953, getitem_18970, getitem_18987, getitem_19004, getitem_19021, getitem_19038, getitem_19055, getitem_19072, getitem_19089, getitem_19106, getitem_19123, getitem_19140]); getitem_18885 = getitem_18902 = getitem_18919 = getitem_18936 = getitem_18953 = getitem_18970 = getitem_18987 = getitem_19004 = getitem_19021 = getitem_19038 = getitem_19055 = getitem_19072 = getitem_19089 = getitem_19106 = getitem_19123 = getitem_19140 = None + reduce_scatter_tensor_229 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_364, 'sum', 16, '1025'); cat_364 = None + wait_tensor_821 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_229); reduce_scatter_tensor_229 = None + convert_element_type_2666 = torch.ops.prims.convert_element_type.default(_grouped_mm_174, torch.float32); _grouped_mm_174 = None + div_229 = torch.ops.aten.div.Tensor(convert_element_type_2666, 128); convert_element_type_2666 = None + split_1022 = torch.ops.aten.split.Tensor(div_229, 128, 1); div_229 = None + getitem_19157 = split_1022[0] + getitem_19174 = split_1022[1] + getitem_19191 = split_1022[2] + getitem_19208 = split_1022[3] + getitem_19225 = split_1022[4] + getitem_19242 = split_1022[5] + getitem_19259 = split_1022[6] + getitem_19276 = split_1022[7] + getitem_19293 = split_1022[8] + getitem_19310 = split_1022[9] + getitem_19327 = split_1022[10] + getitem_19344 = split_1022[11] + getitem_19361 = split_1022[12] + getitem_19378 = split_1022[13] + getitem_19395 = split_1022[14] + getitem_19412 = split_1022[15]; split_1022 = None + cat_365 = torch.ops.aten.cat.default([getitem_19157, getitem_19174, getitem_19191, getitem_19208, getitem_19225, getitem_19242, getitem_19259, getitem_19276, getitem_19293, getitem_19310, getitem_19327, getitem_19344, getitem_19361, getitem_19378, getitem_19395, getitem_19412]); getitem_19157 = getitem_19174 = getitem_19191 = getitem_19208 = getitem_19225 = getitem_19242 = getitem_19259 = getitem_19276 = getitem_19293 = getitem_19310 = getitem_19327 = getitem_19344 = getitem_19361 = getitem_19378 = getitem_19395 = getitem_19412 = None + reduce_scatter_tensor_230 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_365, 'sum', 16, '1025'); cat_365 = None + wait_tensor_822 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_230); reduce_scatter_tensor_230 = None + convert_element_type_2667 = torch.ops.prims.convert_element_type.default(_grouped_mm_178, torch.float32); _grouped_mm_178 = None + div_230 = torch.ops.aten.div.Tensor(convert_element_type_2667, 128); convert_element_type_2667 = None + split_1039 = torch.ops.aten.split.Tensor(div_230, 88, 1); div_230 = None + getitem_19429 = split_1039[0] + getitem_19446 = split_1039[1] + getitem_19463 = split_1039[2] + getitem_19480 = split_1039[3] + getitem_19497 = split_1039[4] + getitem_19514 = split_1039[5] + getitem_19531 = split_1039[6] + getitem_19548 = split_1039[7] + getitem_19565 = split_1039[8] + getitem_19582 = split_1039[9] + getitem_19599 = split_1039[10] + getitem_19616 = split_1039[11] + getitem_19633 = split_1039[12] + getitem_19650 = split_1039[13] + getitem_19667 = split_1039[14] + getitem_19684 = split_1039[15]; split_1039 = None + cat_366 = torch.ops.aten.cat.default([getitem_19429, getitem_19446, getitem_19463, getitem_19480, getitem_19497, getitem_19514, getitem_19531, getitem_19548, getitem_19565, getitem_19582, getitem_19599, getitem_19616, getitem_19633, getitem_19650, getitem_19667, getitem_19684]); getitem_19429 = getitem_19446 = getitem_19463 = getitem_19480 = getitem_19497 = getitem_19514 = getitem_19531 = getitem_19548 = getitem_19565 = getitem_19582 = getitem_19599 = getitem_19616 = getitem_19633 = getitem_19650 = getitem_19667 = getitem_19684 = None + reduce_scatter_tensor_231 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_366, 'sum', 16, '1025'); cat_366 = None + wait_tensor_823 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_231); reduce_scatter_tensor_231 = None + index_put_84 = torch.ops.aten.index_put.default(full_444, [getitem_1012], add_2020, True); full_444 = getitem_1012 = add_2020 = None + slice_258 = torch.ops.aten.slice.Tensor(index_put_84, 0, 0, add_2021); index_put_84 = add_2021 = None + all_to_all_single_111 = torch.ops._c10d_functional.all_to_all_single.default(slice_258, [_local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151], [_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159], '1033'); slice_258 = _local_scalar_dense_144 = _local_scalar_dense_145 = _local_scalar_dense_146 = _local_scalar_dense_147 = _local_scalar_dense_148 = _local_scalar_dense_149 = _local_scalar_dense_150 = _local_scalar_dense_151 = _local_scalar_dense_152 = _local_scalar_dense_153 = _local_scalar_dense_154 = _local_scalar_dense_155 = _local_scalar_dense_156 = _local_scalar_dense_157 = _local_scalar_dense_158 = _local_scalar_dense_159 = None + wait_tensor_824 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_111); all_to_all_single_111 = None + index_put_85 = torch.ops.aten.index_put.default(full_default_52, [div_47], wait_tensor_824, True); div_47 = wait_tensor_824 = None + add_2025 = torch.ops.aten.add.Tensor(add_2017, index_put_85); add_2017 = index_put_85 = None + mul_1847 = torch.ops.aten.mul.Tensor(view_2087, 1.0); view_2087 = None + scatter_add_16 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_1009, mul_1847); getitem_1009 = mul_1847 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(mm_83, torch.float32); mm_83 = None + sub_216 = torch.ops.aten.sub.Tensor(convert_element_type_551, amax_9); convert_element_type_551 = amax_9 = None + exp_28 = torch.ops.aten.exp.default(sub_216); sub_216 = None + div_46 = torch.ops.aten.div.Tensor(exp_28, sum_37); exp_28 = sum_37 = None + mul_1848 = torch.ops.aten.mul.Tensor(scatter_add_16, div_46); scatter_add_16 = None + sum_235 = torch.ops.aten.sum.dim_IntList(mul_1848, [1], True) + neg_103 = torch.ops.aten.neg.default(div_46); div_46 = None + fma_16 = torch.ops.prims.fma.default(neg_103, sum_235, mul_1848); neg_103 = sum_235 = mul_1848 = None + convert_element_type_2668 = torch.ops.prims.convert_element_type.default(fma_16, torch.bfloat16); fma_16 = None + permute_1232 = torch.ops.aten.permute.default(convert_element_type_2668, [1, 0]) + mm_480 = torch.ops.aten.mm.default(permute_1232, view_661); permute_1232 = view_661 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 128, '0'); convert_element_type_548 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + slice_63 = torch.ops.aten.slice.Tensor(wait_tensor_207, 0, 0, 64); wait_tensor_207 = None + permute_154 = torch.ops.aten.permute.default(slice_63, [1, 0]); slice_63 = None + permute_1234 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_481 = torch.ops.aten.mm.default(convert_element_type_2668, permute_1234); convert_element_type_2668 = permute_1234 = None + add_2026 = torch.ops.aten.add.Tensor(add_2025, mm_481); add_2025 = mm_481 = None + convert_element_type_2673 = torch.ops.prims.convert_element_type.default(mm_480, torch.float32); mm_480 = None + split_1055 = torch.ops.aten.split.Tensor(convert_element_type_2673, 1); convert_element_type_2673 = None + getitem_19685 = split_1055[0] + getitem_19686 = split_1055[1] + getitem_19687 = split_1055[2] + getitem_19688 = split_1055[3] + getitem_19689 = split_1055[4] + getitem_19690 = split_1055[5] + getitem_19691 = split_1055[6] + getitem_19692 = split_1055[7] + getitem_19693 = split_1055[8] + getitem_19694 = split_1055[9] + getitem_19695 = split_1055[10] + getitem_19696 = split_1055[11] + getitem_19697 = split_1055[12] + getitem_19698 = split_1055[13] + getitem_19699 = split_1055[14] + getitem_19700 = split_1055[15] + getitem_19701 = split_1055[16] + getitem_19702 = split_1055[17] + getitem_19703 = split_1055[18] + getitem_19704 = split_1055[19] + getitem_19705 = split_1055[20] + getitem_19706 = split_1055[21] + getitem_19707 = split_1055[22] + getitem_19708 = split_1055[23] + getitem_19709 = split_1055[24] + getitem_19710 = split_1055[25] + getitem_19711 = split_1055[26] + getitem_19712 = split_1055[27] + getitem_19713 = split_1055[28] + getitem_19714 = split_1055[29] + getitem_19715 = split_1055[30] + getitem_19716 = split_1055[31] + getitem_19717 = split_1055[32] + getitem_19718 = split_1055[33] + getitem_19719 = split_1055[34] + getitem_19720 = split_1055[35] + getitem_19721 = split_1055[36] + getitem_19722 = split_1055[37] + getitem_19723 = split_1055[38] + getitem_19724 = split_1055[39] + getitem_19725 = split_1055[40] + getitem_19726 = split_1055[41] + getitem_19727 = split_1055[42] + getitem_19728 = split_1055[43] + getitem_19729 = split_1055[44] + getitem_19730 = split_1055[45] + getitem_19731 = split_1055[46] + getitem_19732 = split_1055[47] + getitem_19733 = split_1055[48] + getitem_19734 = split_1055[49] + getitem_19735 = split_1055[50] + getitem_19736 = split_1055[51] + getitem_19737 = split_1055[52] + getitem_19738 = split_1055[53] + getitem_19739 = split_1055[54] + getitem_19740 = split_1055[55] + getitem_19741 = split_1055[56] + getitem_19742 = split_1055[57] + getitem_19743 = split_1055[58] + getitem_19744 = split_1055[59] + getitem_19745 = split_1055[60] + getitem_19746 = split_1055[61] + getitem_19747 = split_1055[62] + getitem_19748 = split_1055[63]; split_1055 = None + cat_367 = torch.ops.aten.cat.default([getitem_19685, getitem_19686, getitem_19687, getitem_19688, getitem_19689, getitem_19690, getitem_19691, getitem_19692, getitem_19693, getitem_19694, getitem_19695, getitem_19696, getitem_19697, getitem_19698, getitem_19699, getitem_19700, getitem_19701, getitem_19702, getitem_19703, getitem_19704, getitem_19705, getitem_19706, getitem_19707, getitem_19708, getitem_19709, getitem_19710, getitem_19711, getitem_19712, getitem_19713, getitem_19714, getitem_19715, getitem_19716, getitem_19717, getitem_19718, getitem_19719, getitem_19720, getitem_19721, getitem_19722, getitem_19723, getitem_19724, getitem_19725, getitem_19726, getitem_19727, getitem_19728, getitem_19729, getitem_19730, getitem_19731, getitem_19732, getitem_19733, getitem_19734, getitem_19735, getitem_19736, getitem_19737, getitem_19738, getitem_19739, getitem_19740, getitem_19741, getitem_19742, getitem_19743, getitem_19744, getitem_19745, getitem_19746, getitem_19747, getitem_19748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_19685 = getitem_19686 = getitem_19687 = getitem_19688 = getitem_19689 = getitem_19690 = getitem_19691 = getitem_19692 = getitem_19693 = getitem_19694 = getitem_19695 = getitem_19696 = getitem_19697 = getitem_19698 = getitem_19699 = getitem_19700 = getitem_19701 = getitem_19702 = getitem_19703 = getitem_19704 = getitem_19705 = getitem_19706 = getitem_19707 = getitem_19708 = getitem_19709 = getitem_19710 = getitem_19711 = getitem_19712 = getitem_19713 = getitem_19714 = getitem_19715 = getitem_19716 = getitem_19717 = getitem_19718 = getitem_19719 = getitem_19720 = getitem_19721 = getitem_19722 = getitem_19723 = getitem_19724 = getitem_19725 = getitem_19726 = getitem_19727 = getitem_19728 = getitem_19729 = getitem_19730 = getitem_19731 = getitem_19732 = getitem_19733 = getitem_19734 = getitem_19735 = getitem_19736 = getitem_19737 = getitem_19738 = getitem_19739 = getitem_19740 = getitem_19741 = getitem_19742 = getitem_19743 = getitem_19744 = getitem_19745 = getitem_19746 = getitem_19747 = getitem_19748 = None + reduce_scatter_tensor_232 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_367, 'avg', 128, '0'); cat_367 = None + wait_tensor_825 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_232); reduce_scatter_tensor_232 = None + view_2089 = torch.ops.aten.view.default(add_2026, [2, 4096, 2048]); add_2026 = None + convert_element_type_2674 = torch.ops.prims.convert_element_type.default(view_2089, torch.float32); view_2089 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 128, '0'); convert_element_type_545 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + convert_element_type_2676 = torch.ops.prims.convert_element_type.default(wait_tensor_206, torch.float32); wait_tensor_206 = None + mul_1849 = torch.ops.aten.mul.Tensor(convert_element_type_2674, convert_element_type_2676); convert_element_type_2676 = None + convert_element_type_546 = torch.ops.prims.convert_element_type.default(add_620, torch.float32); add_620 = None + mul_456 = torch.ops.aten.mul.Tensor(convert_element_type_546, rsqrt_32); convert_element_type_546 = None + mul_1851 = torch.ops.aten.mul.Tensor(mul_456, mul_1849) + sum_236 = torch.ops.aten.sum.dim_IntList(mul_1851, [2], True); mul_1851 = None + div_231 = torch.ops.aten.div.Tensor(mul_456, 2048) + mul_1852 = torch.ops.aten.mul.Tensor(div_231, sum_236); div_231 = sum_236 = None + sub_724 = torch.ops.aten.sub.Tensor(mul_1849, mul_1852); mul_1849 = mul_1852 = None + mul_1853 = torch.ops.aten.mul.Tensor(sub_724, rsqrt_32); sub_724 = rsqrt_32 = None + mul_1854 = torch.ops.aten.mul.Tensor(convert_element_type_2674, mul_456); convert_element_type_2674 = mul_456 = None + sum_237 = torch.ops.aten.sum.dim_IntList(mul_1854, [0, 1]); mul_1854 = None + convert_element_type_2677 = torch.ops.prims.convert_element_type.default(mul_1853, torch.bfloat16); mul_1853 = None + add_2027 = torch.ops.aten.add.Tensor(add_2014, convert_element_type_2677); add_2014 = convert_element_type_2677 = None + convert_element_type_default_33 = torch.ops.prims.convert_element_type.default(sum_237, torch.float32); sum_237 = None + reduce_scatter_tensor_233 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_33, 'avg', 128, '0'); convert_element_type_default_33 = None + wait_tensor_826 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_233); reduce_scatter_tensor_233 = None + view_2090 = torch.ops.aten.view.default(add_2027, [8192, 2048]) + permute_1236 = torch.ops.aten.permute.default(view_2090, [1, 0]) + permute_152 = torch.ops.aten.permute.default(getitem_1005, [0, 2, 1, 3]) + view_656 = torch.ops.aten.view.default(permute_152, [2, 4096, -1]); permute_152 = None + view_658 = torch.ops.aten.view.default(view_656, [8192, 2048]); view_656 = None + mm_482 = torch.ops.aten.mm.default(permute_1236, view_658); permute_1236 = view_658 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_542, 128, '0'); convert_element_type_542 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + permute_1238 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_483 = torch.ops.aten.mm.default(view_2090, permute_1238); view_2090 = permute_1238 = None + view_2091 = torch.ops.aten.view.default(mm_483, [2, 4096, 2048]); mm_483 = None + convert_element_type_2684 = torch.ops.prims.convert_element_type.default(mm_482, torch.float32); mm_482 = None + reduce_scatter_tensor_234 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2684, 'avg', 128, '0'); convert_element_type_2684 = None + wait_tensor_827 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_234); reduce_scatter_tensor_234 = None + view_2092 = torch.ops.aten.view.default(view_2091, [2, 4096, 16, 128]); view_2091 = None + permute_1240 = torch.ops.aten.permute.default(view_2092, [0, 2, 1, 3]); view_2092 = None + fw_graph16 = self.fw_graph16 + joint_graph16 = self.joint_graph16 + mask_graph16 = self.mask_graph16 + flex_attention_backward_16 = torch.ops.higher_order.flex_attention_backward(permute_149, permute_150, permute_151, getitem_1005, getitem_1006, permute_1240, None, fw_graph16, joint_graph16, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph16), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_149 = permute_150 = permute_151 = getitem_1005 = getitem_1006 = permute_1240 = fw_graph16 = joint_graph16 = mask_graph16 = None + getitem_19749 = flex_attention_backward_16[0] + getitem_19750 = flex_attention_backward_16[1] + getitem_19751 = flex_attention_backward_16[2]; flex_attention_backward_16 = None + permute_1241 = torch.ops.aten.permute.default(getitem_19751, [0, 2, 1, 3]); getitem_19751 = None + permute_1242 = torch.ops.aten.permute.default(getitem_19750, [0, 2, 1, 3]); getitem_19750 = None + permute_1243 = torch.ops.aten.permute.default(getitem_19749, [0, 2, 1, 3]); getitem_19749 = None + slice_260 = torch.ops.aten.slice.Tensor(permute_1242, 3, 0, 128) + slice_261 = torch.ops.aten.slice.Tensor(permute_1242, 3, 128, 192); permute_1242 = None + sum_238 = torch.ops.aten.sum.dim_IntList(slice_261, [2], True); slice_261 = None + cat_368 = torch.ops.aten.cat.default([slice_260, permute_1241], 3); slice_260 = permute_1241 = None + view_2093 = torch.ops.aten.view.default(cat_368, [2, 4096, 4096]); cat_368 = None + view_2094 = torch.ops.aten.view.default(view_2093, [8192, 4096]); view_2093 = None + permute_1244 = torch.ops.aten.permute.default(view_2094, [1, 0]) + mm_484 = torch.ops.aten.mm.default(permute_1244, view_653); permute_1244 = view_653 = None + convert_element_type_539 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16); primals_171 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_539, 128, '0'); convert_element_type_539 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_148 = torch.ops.aten.permute.default(wait_tensor_204, [1, 0]); wait_tensor_204 = None + permute_1246 = torch.ops.aten.permute.default(permute_148, [1, 0]); permute_148 = None + mm_485 = torch.ops.aten.mm.default(view_2094, permute_1246); view_2094 = permute_1246 = None + view_2095 = torch.ops.aten.view.default(mm_485, [2, 4096, 512]); mm_485 = None + convert_element_type_2689 = torch.ops.prims.convert_element_type.default(mm_484, torch.float32); mm_484 = None + reduce_scatter_tensor_235 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2689, 'avg', 128, '0'); convert_element_type_2689 = None + wait_tensor_828 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_235); reduce_scatter_tensor_235 = None + convert_element_type_2690 = torch.ops.prims.convert_element_type.default(view_2095, torch.float32); view_2095 = None + convert_element_type_536 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16); primals_170 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_536, 128, '0'); convert_element_type_536 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + convert_element_type_2692 = torch.ops.prims.convert_element_type.default(wait_tensor_203, torch.float32); wait_tensor_203 = None + mul_1855 = torch.ops.aten.mul.Tensor(convert_element_type_2690, convert_element_type_2692); convert_element_type_2692 = None + convert_element_type_537 = torch.ops.prims.convert_element_type.default(getitem_1001, torch.float32); getitem_1001 = None + mul_454 = torch.ops.aten.mul.Tensor(convert_element_type_537, rsqrt_31); convert_element_type_537 = None + mul_1857 = torch.ops.aten.mul.Tensor(mul_454, mul_1855) + sum_239 = torch.ops.aten.sum.dim_IntList(mul_1857, [2], True); mul_1857 = None + div_232 = torch.ops.aten.div.Tensor(mul_454, 512) + mul_1858 = torch.ops.aten.mul.Tensor(div_232, sum_239); div_232 = sum_239 = None + sub_725 = torch.ops.aten.sub.Tensor(mul_1855, mul_1858); mul_1855 = mul_1858 = None + mul_1859 = torch.ops.aten.mul.Tensor(sub_725, rsqrt_31); sub_725 = rsqrt_31 = None + mul_1860 = torch.ops.aten.mul.Tensor(convert_element_type_2690, mul_454); convert_element_type_2690 = mul_454 = None + sum_240 = torch.ops.aten.sum.dim_IntList(mul_1860, [0, 1]); mul_1860 = None + convert_element_type_2693 = torch.ops.prims.convert_element_type.default(mul_1859, torch.bfloat16); mul_1859 = None + convert_element_type_default_32 = torch.ops.prims.convert_element_type.default(sum_240, torch.float32); sum_240 = None + reduce_scatter_tensor_236 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_32, 'avg', 128, '0'); convert_element_type_default_32 = None + wait_tensor_829 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_236); reduce_scatter_tensor_236 = None + convert_element_type_2696 = torch.ops.prims.convert_element_type.default(sum_238, torch.float32); sum_238 = None + view_2096 = torch.ops.aten.view.default(convert_element_type_2696, [2, 4096, 1, 32, 2]); convert_element_type_2696 = None + view_as_complex_86 = torch.ops.aten.view_as_complex.default(view_2096); view_2096 = None + mul_1861 = torch.ops.aten.mul.Tensor(view_as_complex_86, clone_9); view_as_complex_86 = None + view_as_real_86 = torch.ops.aten.view_as_real.default(mul_1861); mul_1861 = None + view_2097 = torch.ops.aten.view.default(view_as_real_86, [2, 4096, 1, 64]); view_as_real_86 = None + convert_element_type_2697 = torch.ops.prims.convert_element_type.default(view_2097, torch.bfloat16); view_2097 = None + squeeze_42 = torch.ops.aten.squeeze.dim(convert_element_type_2697, 2); convert_element_type_2697 = None + cat_369 = torch.ops.aten.cat.default([convert_element_type_2693, squeeze_42], 2); convert_element_type_2693 = squeeze_42 = None + view_2098 = torch.ops.aten.view.default(cat_369, [8192, 576]); cat_369 = None + permute_1248 = torch.ops.aten.permute.default(view_2098, [1, 0]) + mm_486 = torch.ops.aten.mm.default(permute_1248, view_639); permute_1248 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16); primals_169 = None + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 128, '0'); convert_element_type_531 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + slice_61 = torch.ops.aten.slice.Tensor(wait_tensor_202, 0, 0, 576); wait_tensor_202 = None + permute_147 = torch.ops.aten.permute.default(slice_61, [1, 0]); slice_61 = None + permute_1250 = torch.ops.aten.permute.default(permute_147, [1, 0]); permute_147 = None + mm_487 = torch.ops.aten.mm.default(view_2098, permute_1250); view_2098 = permute_1250 = None + view_2099 = torch.ops.aten.view.default(mm_487, [2, 4096, 2048]); mm_487 = None + convert_element_type_2702 = torch.ops.prims.convert_element_type.default(mm_486, torch.float32); mm_486 = None + split_1056 = torch.ops.aten.split.Tensor(convert_element_type_2702, 5); convert_element_type_2702 = None + getitem_19753 = split_1056[0] + getitem_19754 = split_1056[1] + getitem_19755 = split_1056[2] + getitem_19756 = split_1056[3] + getitem_19757 = split_1056[4] + getitem_19758 = split_1056[5] + getitem_19759 = split_1056[6] + getitem_19760 = split_1056[7] + getitem_19761 = split_1056[8] + getitem_19762 = split_1056[9] + getitem_19763 = split_1056[10] + getitem_19764 = split_1056[11] + getitem_19765 = split_1056[12] + getitem_19766 = split_1056[13] + getitem_19767 = split_1056[14] + getitem_19768 = split_1056[15] + getitem_19769 = split_1056[16] + getitem_19770 = split_1056[17] + getitem_19771 = split_1056[18] + getitem_19772 = split_1056[19] + getitem_19773 = split_1056[20] + getitem_19774 = split_1056[21] + getitem_19775 = split_1056[22] + getitem_19776 = split_1056[23] + getitem_19777 = split_1056[24] + getitem_19778 = split_1056[25] + getitem_19779 = split_1056[26] + getitem_19780 = split_1056[27] + getitem_19781 = split_1056[28] + getitem_19782 = split_1056[29] + getitem_19783 = split_1056[30] + getitem_19784 = split_1056[31] + getitem_19785 = split_1056[32] + getitem_19786 = split_1056[33] + getitem_19787 = split_1056[34] + getitem_19788 = split_1056[35] + getitem_19789 = split_1056[36] + getitem_19790 = split_1056[37] + getitem_19791 = split_1056[38] + getitem_19792 = split_1056[39] + getitem_19793 = split_1056[40] + getitem_19794 = split_1056[41] + getitem_19795 = split_1056[42] + getitem_19796 = split_1056[43] + getitem_19797 = split_1056[44] + getitem_19798 = split_1056[45] + getitem_19799 = split_1056[46] + getitem_19800 = split_1056[47] + getitem_19801 = split_1056[48] + getitem_19802 = split_1056[49] + getitem_19803 = split_1056[50] + getitem_19804 = split_1056[51] + getitem_19805 = split_1056[52] + getitem_19806 = split_1056[53] + getitem_19807 = split_1056[54] + getitem_19808 = split_1056[55] + getitem_19809 = split_1056[56] + getitem_19810 = split_1056[57] + getitem_19811 = split_1056[58] + getitem_19812 = split_1056[59] + getitem_19813 = split_1056[60] + getitem_19814 = split_1056[61] + getitem_19815 = split_1056[62] + getitem_19816 = split_1056[63] + getitem_19817 = split_1056[64] + getitem_19818 = split_1056[65] + getitem_19819 = split_1056[66] + getitem_19820 = split_1056[67] + getitem_19821 = split_1056[68] + getitem_19822 = split_1056[69] + getitem_19823 = split_1056[70] + getitem_19824 = split_1056[71] + getitem_19825 = split_1056[72] + getitem_19826 = split_1056[73] + getitem_19827 = split_1056[74] + getitem_19828 = split_1056[75] + getitem_19829 = split_1056[76] + getitem_19830 = split_1056[77] + getitem_19831 = split_1056[78] + getitem_19832 = split_1056[79] + getitem_19833 = split_1056[80] + getitem_19834 = split_1056[81] + getitem_19835 = split_1056[82] + getitem_19836 = split_1056[83] + getitem_19837 = split_1056[84] + getitem_19838 = split_1056[85] + getitem_19839 = split_1056[86] + getitem_19840 = split_1056[87] + getitem_19841 = split_1056[88] + getitem_19842 = split_1056[89] + getitem_19843 = split_1056[90] + getitem_19844 = split_1056[91] + getitem_19845 = split_1056[92] + getitem_19846 = split_1056[93] + getitem_19847 = split_1056[94] + getitem_19848 = split_1056[95] + getitem_19849 = split_1056[96] + getitem_19850 = split_1056[97] + getitem_19851 = split_1056[98] + getitem_19852 = split_1056[99] + getitem_19853 = split_1056[100] + getitem_19854 = split_1056[101] + getitem_19855 = split_1056[102] + getitem_19856 = split_1056[103] + getitem_19857 = split_1056[104] + getitem_19858 = split_1056[105] + getitem_19859 = split_1056[106] + getitem_19860 = split_1056[107] + getitem_19861 = split_1056[108] + getitem_19862 = split_1056[109] + getitem_19863 = split_1056[110] + getitem_19864 = split_1056[111] + getitem_19865 = split_1056[112] + getitem_19866 = split_1056[113] + getitem_19867 = split_1056[114] + getitem_19868 = split_1056[115]; split_1056 = None + constant_pad_nd_1296 = torch.ops.aten.constant_pad_nd.default(getitem_19868, [0, 0, 0, 4], 0.0); getitem_19868 = None + cat_370 = torch.ops.aten.cat.default([getitem_19753, getitem_19754, getitem_19755, getitem_19756, getitem_19757, getitem_19758, getitem_19759, getitem_19760, getitem_19761, getitem_19762, getitem_19763, getitem_19764, getitem_19765, getitem_19766, getitem_19767, getitem_19768, getitem_19769, getitem_19770, getitem_19771, getitem_19772, getitem_19773, getitem_19774, getitem_19775, getitem_19776, getitem_19777, getitem_19778, getitem_19779, getitem_19780, getitem_19781, getitem_19782, getitem_19783, getitem_19784, getitem_19785, getitem_19786, getitem_19787, getitem_19788, getitem_19789, getitem_19790, getitem_19791, getitem_19792, getitem_19793, getitem_19794, getitem_19795, getitem_19796, getitem_19797, getitem_19798, getitem_19799, getitem_19800, getitem_19801, getitem_19802, getitem_19803, getitem_19804, getitem_19805, getitem_19806, getitem_19807, getitem_19808, getitem_19809, getitem_19810, getitem_19811, getitem_19812, getitem_19813, getitem_19814, getitem_19815, getitem_19816, getitem_19817, getitem_19818, getitem_19819, getitem_19820, getitem_19821, getitem_19822, getitem_19823, getitem_19824, getitem_19825, getitem_19826, getitem_19827, getitem_19828, getitem_19829, getitem_19830, getitem_19831, getitem_19832, getitem_19833, getitem_19834, getitem_19835, getitem_19836, getitem_19837, getitem_19838, getitem_19839, getitem_19840, getitem_19841, getitem_19842, getitem_19843, getitem_19844, getitem_19845, getitem_19846, getitem_19847, getitem_19848, getitem_19849, getitem_19850, getitem_19851, getitem_19852, getitem_19853, getitem_19854, getitem_19855, getitem_19856, getitem_19857, getitem_19858, getitem_19859, getitem_19860, getitem_19861, getitem_19862, getitem_19863, getitem_19864, getitem_19865, getitem_19866, getitem_19867, constant_pad_nd_1296, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_19753 = getitem_19754 = getitem_19755 = getitem_19756 = getitem_19757 = getitem_19758 = getitem_19759 = getitem_19760 = getitem_19761 = getitem_19762 = getitem_19763 = getitem_19764 = getitem_19765 = getitem_19766 = getitem_19767 = getitem_19768 = getitem_19769 = getitem_19770 = getitem_19771 = getitem_19772 = getitem_19773 = getitem_19774 = getitem_19775 = getitem_19776 = getitem_19777 = getitem_19778 = getitem_19779 = getitem_19780 = getitem_19781 = getitem_19782 = getitem_19783 = getitem_19784 = getitem_19785 = getitem_19786 = getitem_19787 = getitem_19788 = getitem_19789 = getitem_19790 = getitem_19791 = getitem_19792 = getitem_19793 = getitem_19794 = getitem_19795 = getitem_19796 = getitem_19797 = getitem_19798 = getitem_19799 = getitem_19800 = getitem_19801 = getitem_19802 = getitem_19803 = getitem_19804 = getitem_19805 = getitem_19806 = getitem_19807 = getitem_19808 = getitem_19809 = getitem_19810 = getitem_19811 = getitem_19812 = getitem_19813 = getitem_19814 = getitem_19815 = getitem_19816 = getitem_19817 = getitem_19818 = getitem_19819 = getitem_19820 = getitem_19821 = getitem_19822 = getitem_19823 = getitem_19824 = getitem_19825 = getitem_19826 = getitem_19827 = getitem_19828 = getitem_19829 = getitem_19830 = getitem_19831 = getitem_19832 = getitem_19833 = getitem_19834 = getitem_19835 = getitem_19836 = getitem_19837 = getitem_19838 = getitem_19839 = getitem_19840 = getitem_19841 = getitem_19842 = getitem_19843 = getitem_19844 = getitem_19845 = getitem_19846 = getitem_19847 = getitem_19848 = getitem_19849 = getitem_19850 = getitem_19851 = getitem_19852 = getitem_19853 = getitem_19854 = getitem_19855 = getitem_19856 = getitem_19857 = getitem_19858 = getitem_19859 = getitem_19860 = getitem_19861 = getitem_19862 = getitem_19863 = getitem_19864 = getitem_19865 = getitem_19866 = getitem_19867 = constant_pad_nd_1296 = None + reduce_scatter_tensor_237 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_370, 'avg', 128, '0'); cat_370 = None + wait_tensor_830 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_237); reduce_scatter_tensor_237 = None + slice_262 = torch.ops.aten.slice.Tensor(permute_1243, 3, 0, 128) + slice_263 = torch.ops.aten.slice.Tensor(permute_1243, 3, 128, 192); permute_1243 = None + convert_element_type_2703 = torch.ops.prims.convert_element_type.default(slice_263, torch.float32); slice_263 = None + view_2100 = torch.ops.aten.view.default(convert_element_type_2703, [2, 4096, 16, 32, 2]); convert_element_type_2703 = None + view_as_complex_87 = torch.ops.aten.view_as_complex.default(view_2100); view_2100 = None + mul_1862 = torch.ops.aten.mul.Tensor(view_as_complex_87, clone_9); view_as_complex_87 = None + view_as_real_87 = torch.ops.aten.view_as_real.default(mul_1862); mul_1862 = None + view_2101 = torch.ops.aten.view.default(view_as_real_87, [2, 4096, 16, 64]); view_as_real_87 = None + convert_element_type_2704 = torch.ops.prims.convert_element_type.default(view_2101, torch.bfloat16); view_2101 = None + cat_371 = torch.ops.aten.cat.default([slice_262, convert_element_type_2704], 3); slice_262 = convert_element_type_2704 = None + view_2102 = torch.ops.aten.view.default(cat_371, [2, 4096, 3072]); cat_371 = None + view_2103 = torch.ops.aten.view.default(view_2102, [8192, 3072]); view_2102 = None + permute_1252 = torch.ops.aten.permute.default(view_2103, [1, 0]) + mm_488 = torch.ops.aten.mm.default(permute_1252, view_639); permute_1252 = view_639 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16); primals_168 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 128, '0'); convert_element_type_526 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_146 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + permute_1254 = torch.ops.aten.permute.default(permute_146, [1, 0]); permute_146 = None + mm_489 = torch.ops.aten.mm.default(view_2103, permute_1254); view_2103 = permute_1254 = None + view_2104 = torch.ops.aten.view.default(mm_489, [2, 4096, 2048]); mm_489 = None + add_2028 = torch.ops.aten.add.Tensor(view_2099, view_2104); view_2099 = view_2104 = None + convert_element_type_2709 = torch.ops.prims.convert_element_type.default(mm_488, torch.float32); mm_488 = None + reduce_scatter_tensor_238 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2709, 'avg', 128, '0'); convert_element_type_2709 = None + wait_tensor_831 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_238); reduce_scatter_tensor_238 = None + convert_element_type_2710 = torch.ops.prims.convert_element_type.default(add_2028, torch.float32); add_2028 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16); primals_167 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 128, '0'); convert_element_type_523 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + convert_element_type_2712 = torch.ops.prims.convert_element_type.default(wait_tensor_200, torch.float32); wait_tensor_200 = None + mul_1863 = torch.ops.aten.mul.Tensor(convert_element_type_2710, convert_element_type_2712); convert_element_type_2712 = None + convert_element_type_524 = torch.ops.prims.convert_element_type.default(add_617, torch.float32); add_617 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_524, rsqrt_30); convert_element_type_524 = None + mul_1865 = torch.ops.aten.mul.Tensor(mul_450, mul_1863) + sum_241 = torch.ops.aten.sum.dim_IntList(mul_1865, [2], True); mul_1865 = None + div_233 = torch.ops.aten.div.Tensor(mul_450, 2048) + mul_1866 = torch.ops.aten.mul.Tensor(div_233, sum_241); div_233 = sum_241 = None + sub_726 = torch.ops.aten.sub.Tensor(mul_1863, mul_1866); mul_1863 = mul_1866 = None + mul_1867 = torch.ops.aten.mul.Tensor(sub_726, rsqrt_30); sub_726 = rsqrt_30 = None + mul_1868 = torch.ops.aten.mul.Tensor(convert_element_type_2710, mul_450); convert_element_type_2710 = mul_450 = None + sum_242 = torch.ops.aten.sum.dim_IntList(mul_1868, [0, 1]); mul_1868 = None + convert_element_type_2713 = torch.ops.prims.convert_element_type.default(mul_1867, torch.bfloat16); mul_1867 = None + add_2029 = torch.ops.aten.add.Tensor(add_2027, convert_element_type_2713); add_2027 = convert_element_type_2713 = None + convert_element_type_default_31 = torch.ops.prims.convert_element_type.default(sum_242, torch.float32); sum_242 = None + reduce_scatter_tensor_239 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_31, 'avg', 128, '0'); convert_element_type_default_31 = None + wait_tensor_832 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_239); reduce_scatter_tensor_239 = None + view_2105 = torch.ops.aten.view.default(add_2029, [8192, 2048]) + unsqueeze_70 = torch.ops.aten.unsqueeze.default(view_2105, 1) + convert_element_type_2716 = torch.ops.prims.convert_element_type.default(unsqueeze_70, torch.float32); unsqueeze_70 = None + bmm_60 = torch.ops.aten.bmm.default(permute_1256, convert_element_type_2716); permute_1256 = None + bmm_61 = torch.ops.aten.bmm.default(convert_element_type_2716, permute_1257); convert_element_type_2716 = permute_1257 = None + convert_element_type_2717 = torch.ops.prims.convert_element_type.default(bmm_60, torch.bfloat16); bmm_60 = None + view_2106 = torch.ops.aten.view.default(bmm_61, [8192, 6]); bmm_61 = None + view_2107 = torch.ops.aten.view.default(convert_element_type_2717, [49152, 2048]); convert_element_type_2717 = None + index_86 = torch.ops.aten.index.Tensor(view_2107, [getitem_901]); view_2107 = getitem_901 = None + permute_1258 = torch.ops.aten.permute.default(view_2105, [1, 0]) + mm_490 = torch.ops.aten.mm.default(permute_1258, mul_447); permute_1258 = mul_447 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16); primals_166 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 128, '0'); convert_element_type_518 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + permute_1260 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_491 = torch.ops.aten.mm.default(view_2105, permute_1260); view_2105 = permute_1260 = None + convert_element_type_2722 = torch.ops.prims.convert_element_type.default(mm_490, torch.float32); mm_490 = None + reduce_scatter_tensor_240 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2722, 'avg', 128, '0'); convert_element_type_2722 = None + wait_tensor_833 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_240); reduce_scatter_tensor_240 = None + convert_element_type_513 = torch.ops.prims.convert_element_type.default(mm_76, torch.float32); mm_76 = None + neg_18 = torch.ops.aten.neg.default(convert_element_type_513) + exp_27 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_612 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + div_45 = torch.ops.aten.div.Tensor(convert_element_type_513, add_612) + convert_element_type_514 = torch.ops.prims.convert_element_type.default(div_45, torch.bfloat16); div_45 = None + mul_1869 = torch.ops.aten.mul.Tensor(mm_491, convert_element_type_514); convert_element_type_514 = None + mul_1870 = torch.ops.aten.mul.Tensor(mm_491, mm_77); mm_491 = mm_77 = None + permute_1262 = torch.ops.aten.permute.default(mul_1869, [1, 0]) + mm_492 = torch.ops.aten.mm.default(permute_1262, view_594); permute_1262 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16); primals_165 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 128, '0'); convert_element_type_515 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + permute_1264 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_493 = torch.ops.aten.mm.default(mul_1869, permute_1264); mul_1869 = permute_1264 = None + convert_element_type_2727 = torch.ops.prims.convert_element_type.default(mm_492, torch.float32); mm_492 = None + reduce_scatter_tensor_241 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2727, 'avg', 128, '0'); convert_element_type_2727 = None + wait_tensor_834 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_241); reduce_scatter_tensor_241 = None + convert_element_type_2728 = torch.ops.prims.convert_element_type.default(mul_1870, torch.float32); mul_1870 = None + reciprocal_34 = torch.ops.aten.reciprocal.default(add_612); add_612 = None + mul_1871 = torch.ops.aten.mul.Tensor(reciprocal_34, 1); reciprocal_34 = None + mul_1872 = torch.ops.aten.mul.Tensor(convert_element_type_2728, mul_1871); convert_element_type_2728 = None + sub_727 = torch.ops.aten.sub.Tensor(1, mul_1871); mul_1871 = None + mul_1873 = torch.ops.aten.mul.Tensor(convert_element_type_513, sub_727); convert_element_type_513 = sub_727 = None + add_2031 = torch.ops.aten.add.Tensor(mul_1873, 1); mul_1873 = None + mul_1874 = torch.ops.aten.mul.Tensor(mul_1872, add_2031); mul_1872 = add_2031 = None + convert_element_type_2730 = torch.ops.prims.convert_element_type.default(mul_1874, torch.bfloat16); mul_1874 = None + permute_1266 = torch.ops.aten.permute.default(convert_element_type_2730, [1, 0]) + mm_494 = torch.ops.aten.mm.default(permute_1266, view_594); permute_1266 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_510, 128, '0'); convert_element_type_510 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + permute_1268 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_495 = torch.ops.aten.mm.default(convert_element_type_2730, permute_1268); convert_element_type_2730 = permute_1268 = None + add_2032 = torch.ops.aten.add.Tensor(mm_493, mm_495); mm_493 = mm_495 = None + convert_element_type_2735 = torch.ops.prims.convert_element_type.default(mm_494, torch.float32); mm_494 = None + reduce_scatter_tensor_242 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2735, 'avg', 128, '0'); convert_element_type_2735 = None + wait_tensor_835 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_242); reduce_scatter_tensor_242 = None + all_to_all_single_112 = torch.ops._c10d_functional.all_to_all_single.default(index_86, [_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143], [_local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135], '1033'); index_86 = None + wait_tensor_836 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_112); all_to_all_single_112 = None + full_450 = torch.ops.aten.full.default([sym_size_int_33, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_33 = None + slice_scatter_17 = torch.ops.aten.slice_scatter.default(full_450, wait_tensor_836, 0, 0, -1); wait_tensor_836 = None + index_87 = torch.ops.aten.index.Tensor(slice_scatter_17, [getitem_902]); slice_scatter_17 = None + permute_1270 = torch.ops.aten.permute.default(index_87, [1, 0]) + _grouped_mm_180 = torch.ops.aten._grouped_mm.default(permute_1270, mul_427, cumsum_26); permute_1270 = mul_427 = None + _grouped_mm_181 = torch.ops.aten._grouped_mm.default(index_87, permute_1272, cumsum_26); index_87 = permute_1272 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(_grouped_mm_24, torch.float32); _grouped_mm_24 = None + neg_17 = torch.ops.aten.neg.default(convert_element_type_508) + exp_26 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_576 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + div_44 = torch.ops.aten.div.Tensor(convert_element_type_508, add_576) + convert_element_type_509 = torch.ops.prims.convert_element_type.default(div_44, torch.bfloat16); div_44 = None + mul_1875 = torch.ops.aten.mul.Tensor(_grouped_mm_181, convert_element_type_509); convert_element_type_509 = None + mul_1876 = torch.ops.aten.mul.Tensor(_grouped_mm_181, _grouped_mm_25); _grouped_mm_181 = _grouped_mm_25 = None + permute_1274 = torch.ops.aten.permute.default(mul_1875, [1, 0]) + _grouped_mm_182 = torch.ops.aten._grouped_mm.default(permute_1274, index_17, cumsum_26); permute_1274 = None + _grouped_mm_183 = torch.ops.aten._grouped_mm.default(mul_1875, permute_1276, cumsum_26); mul_1875 = permute_1276 = None + convert_element_type_2736 = torch.ops.prims.convert_element_type.default(mul_1876, torch.float32); mul_1876 = None + reciprocal_35 = torch.ops.aten.reciprocal.default(add_576); add_576 = None + mul_1877 = torch.ops.aten.mul.Tensor(reciprocal_35, 1); reciprocal_35 = None + mul_1878 = torch.ops.aten.mul.Tensor(convert_element_type_2736, mul_1877); convert_element_type_2736 = None + sub_728 = torch.ops.aten.sub.Tensor(1, mul_1877); mul_1877 = None + mul_1879 = torch.ops.aten.mul.Tensor(convert_element_type_508, sub_728); convert_element_type_508 = sub_728 = None + add_2034 = torch.ops.aten.add.Tensor(mul_1879, 1); mul_1879 = None + mul_1880 = torch.ops.aten.mul.Tensor(mul_1878, add_2034); mul_1878 = add_2034 = None + convert_element_type_2738 = torch.ops.prims.convert_element_type.default(mul_1880, torch.bfloat16); mul_1880 = None + permute_1278 = torch.ops.aten.permute.default(convert_element_type_2738, [1, 0]) + _grouped_mm_184 = torch.ops.aten._grouped_mm.default(permute_1278, index_17, cumsum_26); permute_1278 = index_17 = None + _grouped_mm_185 = torch.ops.aten._grouped_mm.default(convert_element_type_2738, permute_1280, cumsum_26); convert_element_type_2738 = permute_1280 = cumsum_26 = None + add_2035 = torch.ops.aten.add.Tensor(_grouped_mm_183, _grouped_mm_185); _grouped_mm_183 = _grouped_mm_185 = None + convert_element_type_2739 = torch.ops.prims.convert_element_type.default(_grouped_mm_182, torch.float32); _grouped_mm_182 = None + div_234 = torch.ops.aten.div.Tensor(convert_element_type_2739, 128); convert_element_type_2739 = None + split_1058 = torch.ops.aten.split.Tensor(div_234, 88, 1); div_234 = None + getitem_19885 = split_1058[0] + getitem_19902 = split_1058[1] + getitem_19919 = split_1058[2] + getitem_19936 = split_1058[3] + getitem_19953 = split_1058[4] + getitem_19970 = split_1058[5] + getitem_19987 = split_1058[6] + getitem_20004 = split_1058[7] + getitem_20021 = split_1058[8] + getitem_20038 = split_1058[9] + getitem_20055 = split_1058[10] + getitem_20072 = split_1058[11] + getitem_20089 = split_1058[12] + getitem_20106 = split_1058[13] + getitem_20123 = split_1058[14] + getitem_20140 = split_1058[15]; split_1058 = None + cat_372 = torch.ops.aten.cat.default([getitem_19885, getitem_19902, getitem_19919, getitem_19936, getitem_19953, getitem_19970, getitem_19987, getitem_20004, getitem_20021, getitem_20038, getitem_20055, getitem_20072, getitem_20089, getitem_20106, getitem_20123, getitem_20140]); getitem_19885 = getitem_19902 = getitem_19919 = getitem_19936 = getitem_19953 = getitem_19970 = getitem_19987 = getitem_20004 = getitem_20021 = getitem_20038 = getitem_20055 = getitem_20072 = getitem_20089 = getitem_20106 = getitem_20123 = getitem_20140 = None + reduce_scatter_tensor_243 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_372, 'sum', 16, '1025'); cat_372 = None + wait_tensor_837 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_243); reduce_scatter_tensor_243 = None + convert_element_type_2740 = torch.ops.prims.convert_element_type.default(_grouped_mm_180, torch.float32); _grouped_mm_180 = None + div_235 = torch.ops.aten.div.Tensor(convert_element_type_2740, 128); convert_element_type_2740 = None + split_1075 = torch.ops.aten.split.Tensor(div_235, 128, 1); div_235 = None + getitem_20157 = split_1075[0] + getitem_20174 = split_1075[1] + getitem_20191 = split_1075[2] + getitem_20208 = split_1075[3] + getitem_20225 = split_1075[4] + getitem_20242 = split_1075[5] + getitem_20259 = split_1075[6] + getitem_20276 = split_1075[7] + getitem_20293 = split_1075[8] + getitem_20310 = split_1075[9] + getitem_20327 = split_1075[10] + getitem_20344 = split_1075[11] + getitem_20361 = split_1075[12] + getitem_20378 = split_1075[13] + getitem_20395 = split_1075[14] + getitem_20412 = split_1075[15]; split_1075 = None + cat_373 = torch.ops.aten.cat.default([getitem_20157, getitem_20174, getitem_20191, getitem_20208, getitem_20225, getitem_20242, getitem_20259, getitem_20276, getitem_20293, getitem_20310, getitem_20327, getitem_20344, getitem_20361, getitem_20378, getitem_20395, getitem_20412]); getitem_20157 = getitem_20174 = getitem_20191 = getitem_20208 = getitem_20225 = getitem_20242 = getitem_20259 = getitem_20276 = getitem_20293 = getitem_20310 = getitem_20327 = getitem_20344 = getitem_20361 = getitem_20378 = getitem_20395 = getitem_20412 = None + reduce_scatter_tensor_244 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_373, 'sum', 16, '1025'); cat_373 = None + wait_tensor_838 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_244); reduce_scatter_tensor_244 = None + convert_element_type_2741 = torch.ops.prims.convert_element_type.default(_grouped_mm_184, torch.float32); _grouped_mm_184 = None + div_236 = torch.ops.aten.div.Tensor(convert_element_type_2741, 128); convert_element_type_2741 = None + split_1092 = torch.ops.aten.split.Tensor(div_236, 88, 1); div_236 = None + getitem_20429 = split_1092[0] + getitem_20446 = split_1092[1] + getitem_20463 = split_1092[2] + getitem_20480 = split_1092[3] + getitem_20497 = split_1092[4] + getitem_20514 = split_1092[5] + getitem_20531 = split_1092[6] + getitem_20548 = split_1092[7] + getitem_20565 = split_1092[8] + getitem_20582 = split_1092[9] + getitem_20599 = split_1092[10] + getitem_20616 = split_1092[11] + getitem_20633 = split_1092[12] + getitem_20650 = split_1092[13] + getitem_20667 = split_1092[14] + getitem_20684 = split_1092[15]; split_1092 = None + cat_374 = torch.ops.aten.cat.default([getitem_20429, getitem_20446, getitem_20463, getitem_20480, getitem_20497, getitem_20514, getitem_20531, getitem_20548, getitem_20565, getitem_20582, getitem_20599, getitem_20616, getitem_20633, getitem_20650, getitem_20667, getitem_20684]); getitem_20429 = getitem_20446 = getitem_20463 = getitem_20480 = getitem_20497 = getitem_20514 = getitem_20531 = getitem_20548 = getitem_20565 = getitem_20582 = getitem_20599 = getitem_20616 = getitem_20633 = getitem_20650 = getitem_20667 = getitem_20684 = None + reduce_scatter_tensor_245 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_374, 'sum', 16, '1025'); cat_374 = None + wait_tensor_839 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_245); reduce_scatter_tensor_245 = None + index_put_86 = torch.ops.aten.index_put.default(full_450, [getitem_902], add_2035, True); full_450 = getitem_902 = add_2035 = None + slice_264 = torch.ops.aten.slice.Tensor(index_put_86, 0, 0, add_2036); index_put_86 = add_2036 = None + all_to_all_single_113 = torch.ops._c10d_functional.all_to_all_single.default(slice_264, [_local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135], [_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143], '1033'); slice_264 = _local_scalar_dense_128 = _local_scalar_dense_129 = _local_scalar_dense_130 = _local_scalar_dense_131 = _local_scalar_dense_132 = _local_scalar_dense_133 = _local_scalar_dense_134 = _local_scalar_dense_135 = _local_scalar_dense_136 = _local_scalar_dense_137 = _local_scalar_dense_138 = _local_scalar_dense_139 = _local_scalar_dense_140 = _local_scalar_dense_141 = _local_scalar_dense_142 = _local_scalar_dense_143 = None + wait_tensor_840 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_113); all_to_all_single_113 = None + index_put_87 = torch.ops.aten.index_put.default(full_default_52, [div_42], wait_tensor_840, True); div_42 = wait_tensor_840 = None + add_2040 = torch.ops.aten.add.Tensor(add_2032, index_put_87); add_2032 = index_put_87 = None + mul_1881 = torch.ops.aten.mul.Tensor(view_2106, 1.0); view_2106 = None + scatter_add_17 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_899, mul_1881); getitem_899 = mul_1881 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(mm_75, torch.float32); mm_75 = None + sub_192 = torch.ops.aten.sub.Tensor(convert_element_type_497, amax_8); convert_element_type_497 = amax_8 = None + exp_25 = torch.ops.aten.exp.default(sub_192); sub_192 = None + div_41 = torch.ops.aten.div.Tensor(exp_25, sum_33); exp_25 = sum_33 = None + mul_1882 = torch.ops.aten.mul.Tensor(scatter_add_17, div_41); scatter_add_17 = None + sum_243 = torch.ops.aten.sum.dim_IntList(mul_1882, [1], True) + neg_106 = torch.ops.aten.neg.default(div_41); div_41 = None + fma_17 = torch.ops.prims.fma.default(neg_106, sum_243, mul_1882); neg_106 = sum_243 = mul_1882 = None + convert_element_type_2742 = torch.ops.prims.convert_element_type.default(fma_17, torch.bfloat16); fma_17 = None + permute_1282 = torch.ops.aten.permute.default(convert_element_type_2742, [1, 0]) + mm_496 = torch.ops.aten.mm.default(permute_1282, view_594); permute_1282 = view_594 = None + convert_element_type_494 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_494, 128, '0'); convert_element_type_494 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + slice_57 = torch.ops.aten.slice.Tensor(wait_tensor_186, 0, 0, 64); wait_tensor_186 = None + permute_139 = torch.ops.aten.permute.default(slice_57, [1, 0]); slice_57 = None + permute_1284 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_497 = torch.ops.aten.mm.default(convert_element_type_2742, permute_1284); convert_element_type_2742 = permute_1284 = None + add_2041 = torch.ops.aten.add.Tensor(add_2040, mm_497); add_2040 = mm_497 = None + convert_element_type_2747 = torch.ops.prims.convert_element_type.default(mm_496, torch.float32); mm_496 = None + split_1108 = torch.ops.aten.split.Tensor(convert_element_type_2747, 1); convert_element_type_2747 = None + getitem_20685 = split_1108[0] + getitem_20686 = split_1108[1] + getitem_20687 = split_1108[2] + getitem_20688 = split_1108[3] + getitem_20689 = split_1108[4] + getitem_20690 = split_1108[5] + getitem_20691 = split_1108[6] + getitem_20692 = split_1108[7] + getitem_20693 = split_1108[8] + getitem_20694 = split_1108[9] + getitem_20695 = split_1108[10] + getitem_20696 = split_1108[11] + getitem_20697 = split_1108[12] + getitem_20698 = split_1108[13] + getitem_20699 = split_1108[14] + getitem_20700 = split_1108[15] + getitem_20701 = split_1108[16] + getitem_20702 = split_1108[17] + getitem_20703 = split_1108[18] + getitem_20704 = split_1108[19] + getitem_20705 = split_1108[20] + getitem_20706 = split_1108[21] + getitem_20707 = split_1108[22] + getitem_20708 = split_1108[23] + getitem_20709 = split_1108[24] + getitem_20710 = split_1108[25] + getitem_20711 = split_1108[26] + getitem_20712 = split_1108[27] + getitem_20713 = split_1108[28] + getitem_20714 = split_1108[29] + getitem_20715 = split_1108[30] + getitem_20716 = split_1108[31] + getitem_20717 = split_1108[32] + getitem_20718 = split_1108[33] + getitem_20719 = split_1108[34] + getitem_20720 = split_1108[35] + getitem_20721 = split_1108[36] + getitem_20722 = split_1108[37] + getitem_20723 = split_1108[38] + getitem_20724 = split_1108[39] + getitem_20725 = split_1108[40] + getitem_20726 = split_1108[41] + getitem_20727 = split_1108[42] + getitem_20728 = split_1108[43] + getitem_20729 = split_1108[44] + getitem_20730 = split_1108[45] + getitem_20731 = split_1108[46] + getitem_20732 = split_1108[47] + getitem_20733 = split_1108[48] + getitem_20734 = split_1108[49] + getitem_20735 = split_1108[50] + getitem_20736 = split_1108[51] + getitem_20737 = split_1108[52] + getitem_20738 = split_1108[53] + getitem_20739 = split_1108[54] + getitem_20740 = split_1108[55] + getitem_20741 = split_1108[56] + getitem_20742 = split_1108[57] + getitem_20743 = split_1108[58] + getitem_20744 = split_1108[59] + getitem_20745 = split_1108[60] + getitem_20746 = split_1108[61] + getitem_20747 = split_1108[62] + getitem_20748 = split_1108[63]; split_1108 = None + cat_375 = torch.ops.aten.cat.default([getitem_20685, getitem_20686, getitem_20687, getitem_20688, getitem_20689, getitem_20690, getitem_20691, getitem_20692, getitem_20693, getitem_20694, getitem_20695, getitem_20696, getitem_20697, getitem_20698, getitem_20699, getitem_20700, getitem_20701, getitem_20702, getitem_20703, getitem_20704, getitem_20705, getitem_20706, getitem_20707, getitem_20708, getitem_20709, getitem_20710, getitem_20711, getitem_20712, getitem_20713, getitem_20714, getitem_20715, getitem_20716, getitem_20717, getitem_20718, getitem_20719, getitem_20720, getitem_20721, getitem_20722, getitem_20723, getitem_20724, getitem_20725, getitem_20726, getitem_20727, getitem_20728, getitem_20729, getitem_20730, getitem_20731, getitem_20732, getitem_20733, getitem_20734, getitem_20735, getitem_20736, getitem_20737, getitem_20738, getitem_20739, getitem_20740, getitem_20741, getitem_20742, getitem_20743, getitem_20744, getitem_20745, getitem_20746, getitem_20747, getitem_20748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_20685 = getitem_20686 = getitem_20687 = getitem_20688 = getitem_20689 = getitem_20690 = getitem_20691 = getitem_20692 = getitem_20693 = getitem_20694 = getitem_20695 = getitem_20696 = getitem_20697 = getitem_20698 = getitem_20699 = getitem_20700 = getitem_20701 = getitem_20702 = getitem_20703 = getitem_20704 = getitem_20705 = getitem_20706 = getitem_20707 = getitem_20708 = getitem_20709 = getitem_20710 = getitem_20711 = getitem_20712 = getitem_20713 = getitem_20714 = getitem_20715 = getitem_20716 = getitem_20717 = getitem_20718 = getitem_20719 = getitem_20720 = getitem_20721 = getitem_20722 = getitem_20723 = getitem_20724 = getitem_20725 = getitem_20726 = getitem_20727 = getitem_20728 = getitem_20729 = getitem_20730 = getitem_20731 = getitem_20732 = getitem_20733 = getitem_20734 = getitem_20735 = getitem_20736 = getitem_20737 = getitem_20738 = getitem_20739 = getitem_20740 = getitem_20741 = getitem_20742 = getitem_20743 = getitem_20744 = getitem_20745 = getitem_20746 = getitem_20747 = getitem_20748 = None + reduce_scatter_tensor_246 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_375, 'avg', 128, '0'); cat_375 = None + wait_tensor_841 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_246); reduce_scatter_tensor_246 = None + view_2108 = torch.ops.aten.view.default(add_2041, [2, 4096, 2048]); add_2041 = None + convert_element_type_2748 = torch.ops.prims.convert_element_type.default(view_2108, torch.float32); view_2108 = None + convert_element_type_491 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_491, 128, '0'); convert_element_type_491 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + convert_element_type_2750 = torch.ops.prims.convert_element_type.default(wait_tensor_185, torch.float32); wait_tensor_185 = None + mul_1883 = torch.ops.aten.mul.Tensor(convert_element_type_2748, convert_element_type_2750); convert_element_type_2750 = None + convert_element_type_492 = torch.ops.prims.convert_element_type.default(add_552, torch.float32); add_552 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_492, rsqrt_29); convert_element_type_492 = None + mul_1885 = torch.ops.aten.mul.Tensor(mul_407, mul_1883) + sum_244 = torch.ops.aten.sum.dim_IntList(mul_1885, [2], True); mul_1885 = None + div_237 = torch.ops.aten.div.Tensor(mul_407, 2048) + mul_1886 = torch.ops.aten.mul.Tensor(div_237, sum_244); div_237 = sum_244 = None + sub_730 = torch.ops.aten.sub.Tensor(mul_1883, mul_1886); mul_1883 = mul_1886 = None + mul_1887 = torch.ops.aten.mul.Tensor(sub_730, rsqrt_29); sub_730 = rsqrt_29 = None + mul_1888 = torch.ops.aten.mul.Tensor(convert_element_type_2748, mul_407); convert_element_type_2748 = mul_407 = None + sum_245 = torch.ops.aten.sum.dim_IntList(mul_1888, [0, 1]); mul_1888 = None + convert_element_type_2751 = torch.ops.prims.convert_element_type.default(mul_1887, torch.bfloat16); mul_1887 = None + add_2042 = torch.ops.aten.add.Tensor(add_2029, convert_element_type_2751); add_2029 = convert_element_type_2751 = None + convert_element_type_default_30 = torch.ops.prims.convert_element_type.default(sum_245, torch.float32); sum_245 = None + reduce_scatter_tensor_247 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_30, 'avg', 128, '0'); convert_element_type_default_30 = None + wait_tensor_842 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_247); reduce_scatter_tensor_247 = None + view_2109 = torch.ops.aten.view.default(add_2042, [8192, 2048]) + permute_1286 = torch.ops.aten.permute.default(view_2109, [1, 0]) + permute_137 = torch.ops.aten.permute.default(getitem_895, [0, 2, 1, 3]) + view_589 = torch.ops.aten.view.default(permute_137, [2, 4096, -1]); permute_137 = None + view_591 = torch.ops.aten.view.default(view_589, [8192, 2048]); view_589 = None + mm_498 = torch.ops.aten.mm.default(permute_1286, view_591); permute_1286 = view_591 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_488, 128, '0'); convert_element_type_488 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_138 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + permute_1288 = torch.ops.aten.permute.default(permute_138, [1, 0]); permute_138 = None + mm_499 = torch.ops.aten.mm.default(view_2109, permute_1288); view_2109 = permute_1288 = None + view_2110 = torch.ops.aten.view.default(mm_499, [2, 4096, 2048]); mm_499 = None + convert_element_type_2758 = torch.ops.prims.convert_element_type.default(mm_498, torch.float32); mm_498 = None + reduce_scatter_tensor_248 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2758, 'avg', 128, '0'); convert_element_type_2758 = None + wait_tensor_843 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_248); reduce_scatter_tensor_248 = None + view_2111 = torch.ops.aten.view.default(view_2110, [2, 4096, 16, 128]); view_2110 = None + permute_1290 = torch.ops.aten.permute.default(view_2111, [0, 2, 1, 3]); view_2111 = None + fw_graph17 = self.fw_graph17 + joint_graph17 = self.joint_graph17 + mask_graph17 = self.mask_graph17 + flex_attention_backward_17 = torch.ops.higher_order.flex_attention_backward(permute_134, permute_135, permute_136, getitem_895, getitem_896, permute_1290, None, fw_graph17, joint_graph17, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph17), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_134 = permute_135 = permute_136 = getitem_895 = getitem_896 = permute_1290 = fw_graph17 = joint_graph17 = mask_graph17 = None + getitem_20749 = flex_attention_backward_17[0] + getitem_20750 = flex_attention_backward_17[1] + getitem_20751 = flex_attention_backward_17[2]; flex_attention_backward_17 = None + permute_1291 = torch.ops.aten.permute.default(getitem_20751, [0, 2, 1, 3]); getitem_20751 = None + permute_1292 = torch.ops.aten.permute.default(getitem_20750, [0, 2, 1, 3]); getitem_20750 = None + permute_1293 = torch.ops.aten.permute.default(getitem_20749, [0, 2, 1, 3]); getitem_20749 = None + slice_266 = torch.ops.aten.slice.Tensor(permute_1292, 3, 0, 128) + slice_267 = torch.ops.aten.slice.Tensor(permute_1292, 3, 128, 192); permute_1292 = None + sum_246 = torch.ops.aten.sum.dim_IntList(slice_267, [2], True); slice_267 = None + cat_376 = torch.ops.aten.cat.default([slice_266, permute_1291], 3); slice_266 = permute_1291 = None + view_2112 = torch.ops.aten.view.default(cat_376, [2, 4096, 4096]); cat_376 = None + view_2113 = torch.ops.aten.view.default(view_2112, [8192, 4096]); view_2112 = None + permute_1294 = torch.ops.aten.permute.default(view_2113, [1, 0]) + mm_500 = torch.ops.aten.mm.default(permute_1294, view_586); permute_1294 = view_586 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 128, '0'); convert_element_type_485 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + permute_1296 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_501 = torch.ops.aten.mm.default(view_2113, permute_1296); view_2113 = permute_1296 = None + view_2114 = torch.ops.aten.view.default(mm_501, [2, 4096, 512]); mm_501 = None + convert_element_type_2763 = torch.ops.prims.convert_element_type.default(mm_500, torch.float32); mm_500 = None + reduce_scatter_tensor_249 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2763, 'avg', 128, '0'); convert_element_type_2763 = None + wait_tensor_844 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_249); reduce_scatter_tensor_249 = None + convert_element_type_2764 = torch.ops.prims.convert_element_type.default(view_2114, torch.float32); view_2114 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 128, '0'); convert_element_type_482 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_2766 = torch.ops.prims.convert_element_type.default(wait_tensor_182, torch.float32); wait_tensor_182 = None + mul_1889 = torch.ops.aten.mul.Tensor(convert_element_type_2764, convert_element_type_2766); convert_element_type_2766 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(getitem_891, torch.float32); getitem_891 = None + mul_405 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_28); convert_element_type_483 = None + mul_1891 = torch.ops.aten.mul.Tensor(mul_405, mul_1889) + sum_247 = torch.ops.aten.sum.dim_IntList(mul_1891, [2], True); mul_1891 = None + div_238 = torch.ops.aten.div.Tensor(mul_405, 512) + mul_1892 = torch.ops.aten.mul.Tensor(div_238, sum_247); div_238 = sum_247 = None + sub_731 = torch.ops.aten.sub.Tensor(mul_1889, mul_1892); mul_1889 = mul_1892 = None + mul_1893 = torch.ops.aten.mul.Tensor(sub_731, rsqrt_28); sub_731 = rsqrt_28 = None + mul_1894 = torch.ops.aten.mul.Tensor(convert_element_type_2764, mul_405); convert_element_type_2764 = mul_405 = None + sum_248 = torch.ops.aten.sum.dim_IntList(mul_1894, [0, 1]); mul_1894 = None + convert_element_type_2767 = torch.ops.prims.convert_element_type.default(mul_1893, torch.bfloat16); mul_1893 = None + convert_element_type_default_29 = torch.ops.prims.convert_element_type.default(sum_248, torch.float32); sum_248 = None + reduce_scatter_tensor_250 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_29, 'avg', 128, '0'); convert_element_type_default_29 = None + wait_tensor_845 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_250); reduce_scatter_tensor_250 = None + convert_element_type_2770 = torch.ops.prims.convert_element_type.default(sum_246, torch.float32); sum_246 = None + view_2115 = torch.ops.aten.view.default(convert_element_type_2770, [2, 4096, 1, 32, 2]); convert_element_type_2770 = None + view_as_complex_88 = torch.ops.aten.view_as_complex.default(view_2115); view_2115 = None + mul_1895 = torch.ops.aten.mul.Tensor(view_as_complex_88, clone_9); view_as_complex_88 = None + view_as_real_88 = torch.ops.aten.view_as_real.default(mul_1895); mul_1895 = None + view_2116 = torch.ops.aten.view.default(view_as_real_88, [2, 4096, 1, 64]); view_as_real_88 = None + convert_element_type_2771 = torch.ops.prims.convert_element_type.default(view_2116, torch.bfloat16); view_2116 = None + squeeze_43 = torch.ops.aten.squeeze.dim(convert_element_type_2771, 2); convert_element_type_2771 = None + cat_377 = torch.ops.aten.cat.default([convert_element_type_2767, squeeze_43], 2); convert_element_type_2767 = squeeze_43 = None + view_2117 = torch.ops.aten.view.default(cat_377, [8192, 576]); cat_377 = None + permute_1298 = torch.ops.aten.permute.default(view_2117, [1, 0]) + mm_502 = torch.ops.aten.mm.default(permute_1298, view_572); permute_1298 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16); primals_153 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_477, 128, '0'); convert_element_type_477 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + slice_55 = torch.ops.aten.slice.Tensor(wait_tensor_181, 0, 0, 576); wait_tensor_181 = None + permute_132 = torch.ops.aten.permute.default(slice_55, [1, 0]); slice_55 = None + permute_1300 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_503 = torch.ops.aten.mm.default(view_2117, permute_1300); view_2117 = permute_1300 = None + view_2118 = torch.ops.aten.view.default(mm_503, [2, 4096, 2048]); mm_503 = None + convert_element_type_2776 = torch.ops.prims.convert_element_type.default(mm_502, torch.float32); mm_502 = None + split_1109 = torch.ops.aten.split.Tensor(convert_element_type_2776, 5); convert_element_type_2776 = None + getitem_20753 = split_1109[0] + getitem_20754 = split_1109[1] + getitem_20755 = split_1109[2] + getitem_20756 = split_1109[3] + getitem_20757 = split_1109[4] + getitem_20758 = split_1109[5] + getitem_20759 = split_1109[6] + getitem_20760 = split_1109[7] + getitem_20761 = split_1109[8] + getitem_20762 = split_1109[9] + getitem_20763 = split_1109[10] + getitem_20764 = split_1109[11] + getitem_20765 = split_1109[12] + getitem_20766 = split_1109[13] + getitem_20767 = split_1109[14] + getitem_20768 = split_1109[15] + getitem_20769 = split_1109[16] + getitem_20770 = split_1109[17] + getitem_20771 = split_1109[18] + getitem_20772 = split_1109[19] + getitem_20773 = split_1109[20] + getitem_20774 = split_1109[21] + getitem_20775 = split_1109[22] + getitem_20776 = split_1109[23] + getitem_20777 = split_1109[24] + getitem_20778 = split_1109[25] + getitem_20779 = split_1109[26] + getitem_20780 = split_1109[27] + getitem_20781 = split_1109[28] + getitem_20782 = split_1109[29] + getitem_20783 = split_1109[30] + getitem_20784 = split_1109[31] + getitem_20785 = split_1109[32] + getitem_20786 = split_1109[33] + getitem_20787 = split_1109[34] + getitem_20788 = split_1109[35] + getitem_20789 = split_1109[36] + getitem_20790 = split_1109[37] + getitem_20791 = split_1109[38] + getitem_20792 = split_1109[39] + getitem_20793 = split_1109[40] + getitem_20794 = split_1109[41] + getitem_20795 = split_1109[42] + getitem_20796 = split_1109[43] + getitem_20797 = split_1109[44] + getitem_20798 = split_1109[45] + getitem_20799 = split_1109[46] + getitem_20800 = split_1109[47] + getitem_20801 = split_1109[48] + getitem_20802 = split_1109[49] + getitem_20803 = split_1109[50] + getitem_20804 = split_1109[51] + getitem_20805 = split_1109[52] + getitem_20806 = split_1109[53] + getitem_20807 = split_1109[54] + getitem_20808 = split_1109[55] + getitem_20809 = split_1109[56] + getitem_20810 = split_1109[57] + getitem_20811 = split_1109[58] + getitem_20812 = split_1109[59] + getitem_20813 = split_1109[60] + getitem_20814 = split_1109[61] + getitem_20815 = split_1109[62] + getitem_20816 = split_1109[63] + getitem_20817 = split_1109[64] + getitem_20818 = split_1109[65] + getitem_20819 = split_1109[66] + getitem_20820 = split_1109[67] + getitem_20821 = split_1109[68] + getitem_20822 = split_1109[69] + getitem_20823 = split_1109[70] + getitem_20824 = split_1109[71] + getitem_20825 = split_1109[72] + getitem_20826 = split_1109[73] + getitem_20827 = split_1109[74] + getitem_20828 = split_1109[75] + getitem_20829 = split_1109[76] + getitem_20830 = split_1109[77] + getitem_20831 = split_1109[78] + getitem_20832 = split_1109[79] + getitem_20833 = split_1109[80] + getitem_20834 = split_1109[81] + getitem_20835 = split_1109[82] + getitem_20836 = split_1109[83] + getitem_20837 = split_1109[84] + getitem_20838 = split_1109[85] + getitem_20839 = split_1109[86] + getitem_20840 = split_1109[87] + getitem_20841 = split_1109[88] + getitem_20842 = split_1109[89] + getitem_20843 = split_1109[90] + getitem_20844 = split_1109[91] + getitem_20845 = split_1109[92] + getitem_20846 = split_1109[93] + getitem_20847 = split_1109[94] + getitem_20848 = split_1109[95] + getitem_20849 = split_1109[96] + getitem_20850 = split_1109[97] + getitem_20851 = split_1109[98] + getitem_20852 = split_1109[99] + getitem_20853 = split_1109[100] + getitem_20854 = split_1109[101] + getitem_20855 = split_1109[102] + getitem_20856 = split_1109[103] + getitem_20857 = split_1109[104] + getitem_20858 = split_1109[105] + getitem_20859 = split_1109[106] + getitem_20860 = split_1109[107] + getitem_20861 = split_1109[108] + getitem_20862 = split_1109[109] + getitem_20863 = split_1109[110] + getitem_20864 = split_1109[111] + getitem_20865 = split_1109[112] + getitem_20866 = split_1109[113] + getitem_20867 = split_1109[114] + getitem_20868 = split_1109[115]; split_1109 = None + constant_pad_nd_1373 = torch.ops.aten.constant_pad_nd.default(getitem_20868, [0, 0, 0, 4], 0.0); getitem_20868 = None + cat_378 = torch.ops.aten.cat.default([getitem_20753, getitem_20754, getitem_20755, getitem_20756, getitem_20757, getitem_20758, getitem_20759, getitem_20760, getitem_20761, getitem_20762, getitem_20763, getitem_20764, getitem_20765, getitem_20766, getitem_20767, getitem_20768, getitem_20769, getitem_20770, getitem_20771, getitem_20772, getitem_20773, getitem_20774, getitem_20775, getitem_20776, getitem_20777, getitem_20778, getitem_20779, getitem_20780, getitem_20781, getitem_20782, getitem_20783, getitem_20784, getitem_20785, getitem_20786, getitem_20787, getitem_20788, getitem_20789, getitem_20790, getitem_20791, getitem_20792, getitem_20793, getitem_20794, getitem_20795, getitem_20796, getitem_20797, getitem_20798, getitem_20799, getitem_20800, getitem_20801, getitem_20802, getitem_20803, getitem_20804, getitem_20805, getitem_20806, getitem_20807, getitem_20808, getitem_20809, getitem_20810, getitem_20811, getitem_20812, getitem_20813, getitem_20814, getitem_20815, getitem_20816, getitem_20817, getitem_20818, getitem_20819, getitem_20820, getitem_20821, getitem_20822, getitem_20823, getitem_20824, getitem_20825, getitem_20826, getitem_20827, getitem_20828, getitem_20829, getitem_20830, getitem_20831, getitem_20832, getitem_20833, getitem_20834, getitem_20835, getitem_20836, getitem_20837, getitem_20838, getitem_20839, getitem_20840, getitem_20841, getitem_20842, getitem_20843, getitem_20844, getitem_20845, getitem_20846, getitem_20847, getitem_20848, getitem_20849, getitem_20850, getitem_20851, getitem_20852, getitem_20853, getitem_20854, getitem_20855, getitem_20856, getitem_20857, getitem_20858, getitem_20859, getitem_20860, getitem_20861, getitem_20862, getitem_20863, getitem_20864, getitem_20865, getitem_20866, getitem_20867, constant_pad_nd_1373, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_20753 = getitem_20754 = getitem_20755 = getitem_20756 = getitem_20757 = getitem_20758 = getitem_20759 = getitem_20760 = getitem_20761 = getitem_20762 = getitem_20763 = getitem_20764 = getitem_20765 = getitem_20766 = getitem_20767 = getitem_20768 = getitem_20769 = getitem_20770 = getitem_20771 = getitem_20772 = getitem_20773 = getitem_20774 = getitem_20775 = getitem_20776 = getitem_20777 = getitem_20778 = getitem_20779 = getitem_20780 = getitem_20781 = getitem_20782 = getitem_20783 = getitem_20784 = getitem_20785 = getitem_20786 = getitem_20787 = getitem_20788 = getitem_20789 = getitem_20790 = getitem_20791 = getitem_20792 = getitem_20793 = getitem_20794 = getitem_20795 = getitem_20796 = getitem_20797 = getitem_20798 = getitem_20799 = getitem_20800 = getitem_20801 = getitem_20802 = getitem_20803 = getitem_20804 = getitem_20805 = getitem_20806 = getitem_20807 = getitem_20808 = getitem_20809 = getitem_20810 = getitem_20811 = getitem_20812 = getitem_20813 = getitem_20814 = getitem_20815 = getitem_20816 = getitem_20817 = getitem_20818 = getitem_20819 = getitem_20820 = getitem_20821 = getitem_20822 = getitem_20823 = getitem_20824 = getitem_20825 = getitem_20826 = getitem_20827 = getitem_20828 = getitem_20829 = getitem_20830 = getitem_20831 = getitem_20832 = getitem_20833 = getitem_20834 = getitem_20835 = getitem_20836 = getitem_20837 = getitem_20838 = getitem_20839 = getitem_20840 = getitem_20841 = getitem_20842 = getitem_20843 = getitem_20844 = getitem_20845 = getitem_20846 = getitem_20847 = getitem_20848 = getitem_20849 = getitem_20850 = getitem_20851 = getitem_20852 = getitem_20853 = getitem_20854 = getitem_20855 = getitem_20856 = getitem_20857 = getitem_20858 = getitem_20859 = getitem_20860 = getitem_20861 = getitem_20862 = getitem_20863 = getitem_20864 = getitem_20865 = getitem_20866 = getitem_20867 = constant_pad_nd_1373 = None + reduce_scatter_tensor_251 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_378, 'avg', 128, '0'); cat_378 = None + wait_tensor_846 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_251); reduce_scatter_tensor_251 = None + slice_268 = torch.ops.aten.slice.Tensor(permute_1293, 3, 0, 128) + slice_269 = torch.ops.aten.slice.Tensor(permute_1293, 3, 128, 192); permute_1293 = None + convert_element_type_2777 = torch.ops.prims.convert_element_type.default(slice_269, torch.float32); slice_269 = None + view_2119 = torch.ops.aten.view.default(convert_element_type_2777, [2, 4096, 16, 32, 2]); convert_element_type_2777 = None + view_as_complex_89 = torch.ops.aten.view_as_complex.default(view_2119); view_2119 = None + mul_1896 = torch.ops.aten.mul.Tensor(view_as_complex_89, clone_9); view_as_complex_89 = None + view_as_real_89 = torch.ops.aten.view_as_real.default(mul_1896); mul_1896 = None + view_2120 = torch.ops.aten.view.default(view_as_real_89, [2, 4096, 16, 64]); view_as_real_89 = None + convert_element_type_2778 = torch.ops.prims.convert_element_type.default(view_2120, torch.bfloat16); view_2120 = None + cat_379 = torch.ops.aten.cat.default([slice_268, convert_element_type_2778], 3); slice_268 = convert_element_type_2778 = None + view_2121 = torch.ops.aten.view.default(cat_379, [2, 4096, 3072]); cat_379 = None + view_2122 = torch.ops.aten.view.default(view_2121, [8192, 3072]); view_2121 = None + permute_1302 = torch.ops.aten.permute.default(view_2122, [1, 0]) + mm_504 = torch.ops.aten.mm.default(permute_1302, view_572); permute_1302 = view_572 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16); primals_152 = None + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 128, '0'); convert_element_type_472 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_1304 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_505 = torch.ops.aten.mm.default(view_2122, permute_1304); view_2122 = permute_1304 = None + view_2123 = torch.ops.aten.view.default(mm_505, [2, 4096, 2048]); mm_505 = None + add_2043 = torch.ops.aten.add.Tensor(view_2118, view_2123); view_2118 = view_2123 = None + convert_element_type_2783 = torch.ops.prims.convert_element_type.default(mm_504, torch.float32); mm_504 = None + reduce_scatter_tensor_252 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2783, 'avg', 128, '0'); convert_element_type_2783 = None + wait_tensor_847 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_252); reduce_scatter_tensor_252 = None + convert_element_type_2784 = torch.ops.prims.convert_element_type.default(add_2043, torch.float32); add_2043 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16); primals_151 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 128, '0'); convert_element_type_469 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + convert_element_type_2786 = torch.ops.prims.convert_element_type.default(wait_tensor_179, torch.float32); wait_tensor_179 = None + mul_1897 = torch.ops.aten.mul.Tensor(convert_element_type_2784, convert_element_type_2786); convert_element_type_2786 = None + convert_element_type_470 = torch.ops.prims.convert_element_type.default(add_549, torch.float32); add_549 = None + mul_401 = torch.ops.aten.mul.Tensor(convert_element_type_470, rsqrt_27); convert_element_type_470 = None + mul_1899 = torch.ops.aten.mul.Tensor(mul_401, mul_1897) + sum_249 = torch.ops.aten.sum.dim_IntList(mul_1899, [2], True); mul_1899 = None + div_239 = torch.ops.aten.div.Tensor(mul_401, 2048) + mul_1900 = torch.ops.aten.mul.Tensor(div_239, sum_249); div_239 = sum_249 = None + sub_732 = torch.ops.aten.sub.Tensor(mul_1897, mul_1900); mul_1897 = mul_1900 = None + mul_1901 = torch.ops.aten.mul.Tensor(sub_732, rsqrt_27); sub_732 = rsqrt_27 = None + mul_1902 = torch.ops.aten.mul.Tensor(convert_element_type_2784, mul_401); convert_element_type_2784 = mul_401 = None + sum_250 = torch.ops.aten.sum.dim_IntList(mul_1902, [0, 1]); mul_1902 = None + convert_element_type_2787 = torch.ops.prims.convert_element_type.default(mul_1901, torch.bfloat16); mul_1901 = None + add_2044 = torch.ops.aten.add.Tensor(add_2042, convert_element_type_2787); add_2042 = convert_element_type_2787 = None + convert_element_type_default_28 = torch.ops.prims.convert_element_type.default(sum_250, torch.float32); sum_250 = None + reduce_scatter_tensor_253 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_28, 'avg', 128, '0'); convert_element_type_default_28 = None + wait_tensor_848 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_253); reduce_scatter_tensor_253 = None + view_2124 = torch.ops.aten.view.default(add_2044, [8192, 2048]) + unsqueeze_71 = torch.ops.aten.unsqueeze.default(view_2124, 1) + convert_element_type_2790 = torch.ops.prims.convert_element_type.default(unsqueeze_71, torch.float32); unsqueeze_71 = None + bmm_62 = torch.ops.aten.bmm.default(permute_1306, convert_element_type_2790); permute_1306 = None + bmm_63 = torch.ops.aten.bmm.default(convert_element_type_2790, permute_1307); convert_element_type_2790 = permute_1307 = None + convert_element_type_2791 = torch.ops.prims.convert_element_type.default(bmm_62, torch.bfloat16); bmm_62 = None + view_2125 = torch.ops.aten.view.default(bmm_63, [8192, 6]); bmm_63 = None + view_2126 = torch.ops.aten.view.default(convert_element_type_2791, [49152, 2048]); convert_element_type_2791 = None + index_88 = torch.ops.aten.index.Tensor(view_2126, [getitem_791]); view_2126 = getitem_791 = None + permute_1308 = torch.ops.aten.permute.default(view_2124, [1, 0]) + mm_506 = torch.ops.aten.mm.default(permute_1308, mul_398); permute_1308 = mul_398 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16); primals_150 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_464, 128, '0'); convert_element_type_464 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + permute_1310 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_507 = torch.ops.aten.mm.default(view_2124, permute_1310); view_2124 = permute_1310 = None + convert_element_type_2796 = torch.ops.prims.convert_element_type.default(mm_506, torch.float32); mm_506 = None + reduce_scatter_tensor_254 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2796, 'avg', 128, '0'); convert_element_type_2796 = None + wait_tensor_849 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_254); reduce_scatter_tensor_254 = None + convert_element_type_459 = torch.ops.prims.convert_element_type.default(mm_68, torch.float32); mm_68 = None + neg_16 = torch.ops.aten.neg.default(convert_element_type_459) + exp_24 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_544 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + div_40 = torch.ops.aten.div.Tensor(convert_element_type_459, add_544) + convert_element_type_460 = torch.ops.prims.convert_element_type.default(div_40, torch.bfloat16); div_40 = None + mul_1903 = torch.ops.aten.mul.Tensor(mm_507, convert_element_type_460); convert_element_type_460 = None + mul_1904 = torch.ops.aten.mul.Tensor(mm_507, mm_69); mm_507 = mm_69 = None + permute_1312 = torch.ops.aten.permute.default(mul_1903, [1, 0]) + mm_508 = torch.ops.aten.mm.default(permute_1312, view_527); permute_1312 = None + convert_element_type_461 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); primals_149 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_461, 128, '0'); convert_element_type_461 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_177, [1, 0]); wait_tensor_177 = None + permute_1314 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_509 = torch.ops.aten.mm.default(mul_1903, permute_1314); mul_1903 = permute_1314 = None + convert_element_type_2801 = torch.ops.prims.convert_element_type.default(mm_508, torch.float32); mm_508 = None + reduce_scatter_tensor_255 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2801, 'avg', 128, '0'); convert_element_type_2801 = None + wait_tensor_850 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_255); reduce_scatter_tensor_255 = None + convert_element_type_2802 = torch.ops.prims.convert_element_type.default(mul_1904, torch.float32); mul_1904 = None + reciprocal_36 = torch.ops.aten.reciprocal.default(add_544); add_544 = None + mul_1905 = torch.ops.aten.mul.Tensor(reciprocal_36, 1); reciprocal_36 = None + mul_1906 = torch.ops.aten.mul.Tensor(convert_element_type_2802, mul_1905); convert_element_type_2802 = None + sub_733 = torch.ops.aten.sub.Tensor(1, mul_1905); mul_1905 = None + mul_1907 = torch.ops.aten.mul.Tensor(convert_element_type_459, sub_733); convert_element_type_459 = sub_733 = None + add_2046 = torch.ops.aten.add.Tensor(mul_1907, 1); mul_1907 = None + mul_1908 = torch.ops.aten.mul.Tensor(mul_1906, add_2046); mul_1906 = add_2046 = None + convert_element_type_2804 = torch.ops.prims.convert_element_type.default(mul_1908, torch.bfloat16); mul_1908 = None + permute_1316 = torch.ops.aten.permute.default(convert_element_type_2804, [1, 0]) + mm_510 = torch.ops.aten.mm.default(permute_1316, view_527); permute_1316 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_456, 128, '0'); convert_element_type_456 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + permute_1318 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_511 = torch.ops.aten.mm.default(convert_element_type_2804, permute_1318); convert_element_type_2804 = permute_1318 = None + add_2047 = torch.ops.aten.add.Tensor(mm_509, mm_511); mm_509 = mm_511 = None + convert_element_type_2809 = torch.ops.prims.convert_element_type.default(mm_510, torch.float32); mm_510 = None + reduce_scatter_tensor_256 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2809, 'avg', 128, '0'); convert_element_type_2809 = None + wait_tensor_851 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_256); reduce_scatter_tensor_256 = None + all_to_all_single_114 = torch.ops._c10d_functional.all_to_all_single.default(index_88, [_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127], [_local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119], '1033'); index_88 = None + wait_tensor_852 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_114); all_to_all_single_114 = None + full_456 = torch.ops.aten.full.default([sym_size_int_29, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_29 = None + slice_scatter_18 = torch.ops.aten.slice_scatter.default(full_456, wait_tensor_852, 0, 0, -1); wait_tensor_852 = None + index_89 = torch.ops.aten.index.Tensor(slice_scatter_18, [getitem_792]); slice_scatter_18 = None + permute_1320 = torch.ops.aten.permute.default(index_89, [1, 0]) + _grouped_mm_186 = torch.ops.aten._grouped_mm.default(permute_1320, mul_378, cumsum_23); permute_1320 = mul_378 = None + _grouped_mm_187 = torch.ops.aten._grouped_mm.default(index_89, permute_1322, cumsum_23); index_89 = permute_1322 = None + convert_element_type_454 = torch.ops.prims.convert_element_type.default(_grouped_mm_21, torch.float32); _grouped_mm_21 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_454) + exp_23 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_508 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + div_39 = torch.ops.aten.div.Tensor(convert_element_type_454, add_508) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(div_39, torch.bfloat16); div_39 = None + mul_1909 = torch.ops.aten.mul.Tensor(_grouped_mm_187, convert_element_type_455); convert_element_type_455 = None + mul_1910 = torch.ops.aten.mul.Tensor(_grouped_mm_187, _grouped_mm_22); _grouped_mm_187 = _grouped_mm_22 = None + permute_1324 = torch.ops.aten.permute.default(mul_1909, [1, 0]) + _grouped_mm_188 = torch.ops.aten._grouped_mm.default(permute_1324, index_15, cumsum_23); permute_1324 = None + _grouped_mm_189 = torch.ops.aten._grouped_mm.default(mul_1909, permute_1326, cumsum_23); mul_1909 = permute_1326 = None + convert_element_type_2810 = torch.ops.prims.convert_element_type.default(mul_1910, torch.float32); mul_1910 = None + reciprocal_37 = torch.ops.aten.reciprocal.default(add_508); add_508 = None + mul_1911 = torch.ops.aten.mul.Tensor(reciprocal_37, 1); reciprocal_37 = None + mul_1912 = torch.ops.aten.mul.Tensor(convert_element_type_2810, mul_1911); convert_element_type_2810 = None + sub_734 = torch.ops.aten.sub.Tensor(1, mul_1911); mul_1911 = None + mul_1913 = torch.ops.aten.mul.Tensor(convert_element_type_454, sub_734); convert_element_type_454 = sub_734 = None + add_2049 = torch.ops.aten.add.Tensor(mul_1913, 1); mul_1913 = None + mul_1914 = torch.ops.aten.mul.Tensor(mul_1912, add_2049); mul_1912 = add_2049 = None + convert_element_type_2812 = torch.ops.prims.convert_element_type.default(mul_1914, torch.bfloat16); mul_1914 = None + permute_1328 = torch.ops.aten.permute.default(convert_element_type_2812, [1, 0]) + _grouped_mm_190 = torch.ops.aten._grouped_mm.default(permute_1328, index_15, cumsum_23); permute_1328 = index_15 = None + _grouped_mm_191 = torch.ops.aten._grouped_mm.default(convert_element_type_2812, permute_1330, cumsum_23); convert_element_type_2812 = permute_1330 = cumsum_23 = None + add_2050 = torch.ops.aten.add.Tensor(_grouped_mm_189, _grouped_mm_191); _grouped_mm_189 = _grouped_mm_191 = None + convert_element_type_2813 = torch.ops.prims.convert_element_type.default(_grouped_mm_188, torch.float32); _grouped_mm_188 = None + div_240 = torch.ops.aten.div.Tensor(convert_element_type_2813, 128); convert_element_type_2813 = None + split_1111 = torch.ops.aten.split.Tensor(div_240, 88, 1); div_240 = None + getitem_20885 = split_1111[0] + getitem_20902 = split_1111[1] + getitem_20919 = split_1111[2] + getitem_20936 = split_1111[3] + getitem_20953 = split_1111[4] + getitem_20970 = split_1111[5] + getitem_20987 = split_1111[6] + getitem_21004 = split_1111[7] + getitem_21021 = split_1111[8] + getitem_21038 = split_1111[9] + getitem_21055 = split_1111[10] + getitem_21072 = split_1111[11] + getitem_21089 = split_1111[12] + getitem_21106 = split_1111[13] + getitem_21123 = split_1111[14] + getitem_21140 = split_1111[15]; split_1111 = None + cat_380 = torch.ops.aten.cat.default([getitem_20885, getitem_20902, getitem_20919, getitem_20936, getitem_20953, getitem_20970, getitem_20987, getitem_21004, getitem_21021, getitem_21038, getitem_21055, getitem_21072, getitem_21089, getitem_21106, getitem_21123, getitem_21140]); getitem_20885 = getitem_20902 = getitem_20919 = getitem_20936 = getitem_20953 = getitem_20970 = getitem_20987 = getitem_21004 = getitem_21021 = getitem_21038 = getitem_21055 = getitem_21072 = getitem_21089 = getitem_21106 = getitem_21123 = getitem_21140 = None + reduce_scatter_tensor_257 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_380, 'sum', 16, '1025'); cat_380 = None + wait_tensor_853 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_257); reduce_scatter_tensor_257 = None + convert_element_type_2814 = torch.ops.prims.convert_element_type.default(_grouped_mm_186, torch.float32); _grouped_mm_186 = None + div_241 = torch.ops.aten.div.Tensor(convert_element_type_2814, 128); convert_element_type_2814 = None + split_1128 = torch.ops.aten.split.Tensor(div_241, 128, 1); div_241 = None + getitem_21157 = split_1128[0] + getitem_21174 = split_1128[1] + getitem_21191 = split_1128[2] + getitem_21208 = split_1128[3] + getitem_21225 = split_1128[4] + getitem_21242 = split_1128[5] + getitem_21259 = split_1128[6] + getitem_21276 = split_1128[7] + getitem_21293 = split_1128[8] + getitem_21310 = split_1128[9] + getitem_21327 = split_1128[10] + getitem_21344 = split_1128[11] + getitem_21361 = split_1128[12] + getitem_21378 = split_1128[13] + getitem_21395 = split_1128[14] + getitem_21412 = split_1128[15]; split_1128 = None + cat_381 = torch.ops.aten.cat.default([getitem_21157, getitem_21174, getitem_21191, getitem_21208, getitem_21225, getitem_21242, getitem_21259, getitem_21276, getitem_21293, getitem_21310, getitem_21327, getitem_21344, getitem_21361, getitem_21378, getitem_21395, getitem_21412]); getitem_21157 = getitem_21174 = getitem_21191 = getitem_21208 = getitem_21225 = getitem_21242 = getitem_21259 = getitem_21276 = getitem_21293 = getitem_21310 = getitem_21327 = getitem_21344 = getitem_21361 = getitem_21378 = getitem_21395 = getitem_21412 = None + reduce_scatter_tensor_258 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_381, 'sum', 16, '1025'); cat_381 = None + wait_tensor_854 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_258); reduce_scatter_tensor_258 = None + convert_element_type_2815 = torch.ops.prims.convert_element_type.default(_grouped_mm_190, torch.float32); _grouped_mm_190 = None + div_242 = torch.ops.aten.div.Tensor(convert_element_type_2815, 128); convert_element_type_2815 = None + split_1145 = torch.ops.aten.split.Tensor(div_242, 88, 1); div_242 = None + getitem_21429 = split_1145[0] + getitem_21446 = split_1145[1] + getitem_21463 = split_1145[2] + getitem_21480 = split_1145[3] + getitem_21497 = split_1145[4] + getitem_21514 = split_1145[5] + getitem_21531 = split_1145[6] + getitem_21548 = split_1145[7] + getitem_21565 = split_1145[8] + getitem_21582 = split_1145[9] + getitem_21599 = split_1145[10] + getitem_21616 = split_1145[11] + getitem_21633 = split_1145[12] + getitem_21650 = split_1145[13] + getitem_21667 = split_1145[14] + getitem_21684 = split_1145[15]; split_1145 = None + cat_382 = torch.ops.aten.cat.default([getitem_21429, getitem_21446, getitem_21463, getitem_21480, getitem_21497, getitem_21514, getitem_21531, getitem_21548, getitem_21565, getitem_21582, getitem_21599, getitem_21616, getitem_21633, getitem_21650, getitem_21667, getitem_21684]); getitem_21429 = getitem_21446 = getitem_21463 = getitem_21480 = getitem_21497 = getitem_21514 = getitem_21531 = getitem_21548 = getitem_21565 = getitem_21582 = getitem_21599 = getitem_21616 = getitem_21633 = getitem_21650 = getitem_21667 = getitem_21684 = None + reduce_scatter_tensor_259 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_382, 'sum', 16, '1025'); cat_382 = None + wait_tensor_855 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_259); reduce_scatter_tensor_259 = None + index_put_88 = torch.ops.aten.index_put.default(full_456, [getitem_792], add_2050, True); full_456 = getitem_792 = add_2050 = None + slice_270 = torch.ops.aten.slice.Tensor(index_put_88, 0, 0, add_2051); index_put_88 = add_2051 = None + all_to_all_single_115 = torch.ops._c10d_functional.all_to_all_single.default(slice_270, [_local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119], [_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127], '1033'); slice_270 = _local_scalar_dense_112 = _local_scalar_dense_113 = _local_scalar_dense_114 = _local_scalar_dense_115 = _local_scalar_dense_116 = _local_scalar_dense_117 = _local_scalar_dense_118 = _local_scalar_dense_119 = _local_scalar_dense_120 = _local_scalar_dense_121 = _local_scalar_dense_122 = _local_scalar_dense_123 = _local_scalar_dense_124 = _local_scalar_dense_125 = _local_scalar_dense_126 = _local_scalar_dense_127 = None + wait_tensor_856 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_115); all_to_all_single_115 = None + index_put_89 = torch.ops.aten.index_put.default(full_default_52, [div_37], wait_tensor_856, True); div_37 = wait_tensor_856 = None + add_2055 = torch.ops.aten.add.Tensor(add_2047, index_put_89); add_2047 = index_put_89 = None + mul_1915 = torch.ops.aten.mul.Tensor(view_2125, 1.0); view_2125 = None + scatter_add_18 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_789, mul_1915); getitem_789 = mul_1915 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(mm_67, torch.float32); mm_67 = None + sub_168 = torch.ops.aten.sub.Tensor(convert_element_type_443, amax_7); convert_element_type_443 = amax_7 = None + exp_22 = torch.ops.aten.exp.default(sub_168); sub_168 = None + div_36 = torch.ops.aten.div.Tensor(exp_22, sum_29); exp_22 = sum_29 = None + mul_1916 = torch.ops.aten.mul.Tensor(scatter_add_18, div_36); scatter_add_18 = None + sum_251 = torch.ops.aten.sum.dim_IntList(mul_1916, [1], True) + neg_109 = torch.ops.aten.neg.default(div_36); div_36 = None + fma_18 = torch.ops.prims.fma.default(neg_109, sum_251, mul_1916); neg_109 = sum_251 = mul_1916 = None + convert_element_type_2816 = torch.ops.prims.convert_element_type.default(fma_18, torch.bfloat16); fma_18 = None + permute_1332 = torch.ops.aten.permute.default(convert_element_type_2816, [1, 0]) + mm_512 = torch.ops.aten.mm.default(permute_1332, view_527); permute_1332 = view_527 = None + convert_element_type_440 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_440, 128, '0'); convert_element_type_440 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + slice_51 = torch.ops.aten.slice.Tensor(wait_tensor_165, 0, 0, 64); wait_tensor_165 = None + permute_124 = torch.ops.aten.permute.default(slice_51, [1, 0]); slice_51 = None + permute_1334 = torch.ops.aten.permute.default(permute_124, [1, 0]); permute_124 = None + mm_513 = torch.ops.aten.mm.default(convert_element_type_2816, permute_1334); convert_element_type_2816 = permute_1334 = None + add_2056 = torch.ops.aten.add.Tensor(add_2055, mm_513); add_2055 = mm_513 = None + convert_element_type_2821 = torch.ops.prims.convert_element_type.default(mm_512, torch.float32); mm_512 = None + split_1161 = torch.ops.aten.split.Tensor(convert_element_type_2821, 1); convert_element_type_2821 = None + getitem_21685 = split_1161[0] + getitem_21686 = split_1161[1] + getitem_21687 = split_1161[2] + getitem_21688 = split_1161[3] + getitem_21689 = split_1161[4] + getitem_21690 = split_1161[5] + getitem_21691 = split_1161[6] + getitem_21692 = split_1161[7] + getitem_21693 = split_1161[8] + getitem_21694 = split_1161[9] + getitem_21695 = split_1161[10] + getitem_21696 = split_1161[11] + getitem_21697 = split_1161[12] + getitem_21698 = split_1161[13] + getitem_21699 = split_1161[14] + getitem_21700 = split_1161[15] + getitem_21701 = split_1161[16] + getitem_21702 = split_1161[17] + getitem_21703 = split_1161[18] + getitem_21704 = split_1161[19] + getitem_21705 = split_1161[20] + getitem_21706 = split_1161[21] + getitem_21707 = split_1161[22] + getitem_21708 = split_1161[23] + getitem_21709 = split_1161[24] + getitem_21710 = split_1161[25] + getitem_21711 = split_1161[26] + getitem_21712 = split_1161[27] + getitem_21713 = split_1161[28] + getitem_21714 = split_1161[29] + getitem_21715 = split_1161[30] + getitem_21716 = split_1161[31] + getitem_21717 = split_1161[32] + getitem_21718 = split_1161[33] + getitem_21719 = split_1161[34] + getitem_21720 = split_1161[35] + getitem_21721 = split_1161[36] + getitem_21722 = split_1161[37] + getitem_21723 = split_1161[38] + getitem_21724 = split_1161[39] + getitem_21725 = split_1161[40] + getitem_21726 = split_1161[41] + getitem_21727 = split_1161[42] + getitem_21728 = split_1161[43] + getitem_21729 = split_1161[44] + getitem_21730 = split_1161[45] + getitem_21731 = split_1161[46] + getitem_21732 = split_1161[47] + getitem_21733 = split_1161[48] + getitem_21734 = split_1161[49] + getitem_21735 = split_1161[50] + getitem_21736 = split_1161[51] + getitem_21737 = split_1161[52] + getitem_21738 = split_1161[53] + getitem_21739 = split_1161[54] + getitem_21740 = split_1161[55] + getitem_21741 = split_1161[56] + getitem_21742 = split_1161[57] + getitem_21743 = split_1161[58] + getitem_21744 = split_1161[59] + getitem_21745 = split_1161[60] + getitem_21746 = split_1161[61] + getitem_21747 = split_1161[62] + getitem_21748 = split_1161[63]; split_1161 = None + cat_383 = torch.ops.aten.cat.default([getitem_21685, getitem_21686, getitem_21687, getitem_21688, getitem_21689, getitem_21690, getitem_21691, getitem_21692, getitem_21693, getitem_21694, getitem_21695, getitem_21696, getitem_21697, getitem_21698, getitem_21699, getitem_21700, getitem_21701, getitem_21702, getitem_21703, getitem_21704, getitem_21705, getitem_21706, getitem_21707, getitem_21708, getitem_21709, getitem_21710, getitem_21711, getitem_21712, getitem_21713, getitem_21714, getitem_21715, getitem_21716, getitem_21717, getitem_21718, getitem_21719, getitem_21720, getitem_21721, getitem_21722, getitem_21723, getitem_21724, getitem_21725, getitem_21726, getitem_21727, getitem_21728, getitem_21729, getitem_21730, getitem_21731, getitem_21732, getitem_21733, getitem_21734, getitem_21735, getitem_21736, getitem_21737, getitem_21738, getitem_21739, getitem_21740, getitem_21741, getitem_21742, getitem_21743, getitem_21744, getitem_21745, getitem_21746, getitem_21747, getitem_21748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_21685 = getitem_21686 = getitem_21687 = getitem_21688 = getitem_21689 = getitem_21690 = getitem_21691 = getitem_21692 = getitem_21693 = getitem_21694 = getitem_21695 = getitem_21696 = getitem_21697 = getitem_21698 = getitem_21699 = getitem_21700 = getitem_21701 = getitem_21702 = getitem_21703 = getitem_21704 = getitem_21705 = getitem_21706 = getitem_21707 = getitem_21708 = getitem_21709 = getitem_21710 = getitem_21711 = getitem_21712 = getitem_21713 = getitem_21714 = getitem_21715 = getitem_21716 = getitem_21717 = getitem_21718 = getitem_21719 = getitem_21720 = getitem_21721 = getitem_21722 = getitem_21723 = getitem_21724 = getitem_21725 = getitem_21726 = getitem_21727 = getitem_21728 = getitem_21729 = getitem_21730 = getitem_21731 = getitem_21732 = getitem_21733 = getitem_21734 = getitem_21735 = getitem_21736 = getitem_21737 = getitem_21738 = getitem_21739 = getitem_21740 = getitem_21741 = getitem_21742 = getitem_21743 = getitem_21744 = getitem_21745 = getitem_21746 = getitem_21747 = getitem_21748 = None + reduce_scatter_tensor_260 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_383, 'avg', 128, '0'); cat_383 = None + wait_tensor_857 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_260); reduce_scatter_tensor_260 = None + view_2127 = torch.ops.aten.view.default(add_2056, [2, 4096, 2048]); add_2056 = None + convert_element_type_2822 = torch.ops.prims.convert_element_type.default(view_2127, torch.float32); view_2127 = None + convert_element_type_437 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_437, 128, '0'); convert_element_type_437 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_2824 = torch.ops.prims.convert_element_type.default(wait_tensor_164, torch.float32); wait_tensor_164 = None + mul_1917 = torch.ops.aten.mul.Tensor(convert_element_type_2822, convert_element_type_2824); convert_element_type_2824 = None + convert_element_type_438 = torch.ops.prims.convert_element_type.default(add_484, torch.float32); add_484 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_438, rsqrt_26); convert_element_type_438 = None + mul_1919 = torch.ops.aten.mul.Tensor(mul_358, mul_1917) + sum_252 = torch.ops.aten.sum.dim_IntList(mul_1919, [2], True); mul_1919 = None + div_243 = torch.ops.aten.div.Tensor(mul_358, 2048) + mul_1920 = torch.ops.aten.mul.Tensor(div_243, sum_252); div_243 = sum_252 = None + sub_736 = torch.ops.aten.sub.Tensor(mul_1917, mul_1920); mul_1917 = mul_1920 = None + mul_1921 = torch.ops.aten.mul.Tensor(sub_736, rsqrt_26); sub_736 = rsqrt_26 = None + mul_1922 = torch.ops.aten.mul.Tensor(convert_element_type_2822, mul_358); convert_element_type_2822 = mul_358 = None + sum_253 = torch.ops.aten.sum.dim_IntList(mul_1922, [0, 1]); mul_1922 = None + convert_element_type_2825 = torch.ops.prims.convert_element_type.default(mul_1921, torch.bfloat16); mul_1921 = None + add_2057 = torch.ops.aten.add.Tensor(add_2044, convert_element_type_2825); add_2044 = convert_element_type_2825 = None + convert_element_type_default_27 = torch.ops.prims.convert_element_type.default(sum_253, torch.float32); sum_253 = None + reduce_scatter_tensor_261 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_27, 'avg', 128, '0'); convert_element_type_default_27 = None + wait_tensor_858 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_261); reduce_scatter_tensor_261 = None + view_2128 = torch.ops.aten.view.default(add_2057, [8192, 2048]) + permute_1336 = torch.ops.aten.permute.default(view_2128, [1, 0]) + permute_122 = torch.ops.aten.permute.default(getitem_785, [0, 2, 1, 3]) + view_522 = torch.ops.aten.view.default(permute_122, [2, 4096, -1]); permute_122 = None + view_524 = torch.ops.aten.view.default(view_522, [8192, 2048]); view_522 = None + mm_514 = torch.ops.aten.mm.default(permute_1336, view_524); permute_1336 = view_524 = None + convert_element_type_434 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_434, 128, '0'); convert_element_type_434 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + permute_1338 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_515 = torch.ops.aten.mm.default(view_2128, permute_1338); view_2128 = permute_1338 = None + view_2129 = torch.ops.aten.view.default(mm_515, [2, 4096, 2048]); mm_515 = None + convert_element_type_2832 = torch.ops.prims.convert_element_type.default(mm_514, torch.float32); mm_514 = None + reduce_scatter_tensor_262 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2832, 'avg', 128, '0'); convert_element_type_2832 = None + wait_tensor_859 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_262); reduce_scatter_tensor_262 = None + view_2130 = torch.ops.aten.view.default(view_2129, [2, 4096, 16, 128]); view_2129 = None + permute_1340 = torch.ops.aten.permute.default(view_2130, [0, 2, 1, 3]); view_2130 = None + fw_graph18 = self.fw_graph18 + joint_graph18 = self.joint_graph18 + mask_graph18 = self.mask_graph18 + flex_attention_backward_18 = torch.ops.higher_order.flex_attention_backward(permute_119, permute_120, permute_121, getitem_785, getitem_786, permute_1340, None, fw_graph18, joint_graph18, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph18), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_119 = permute_120 = permute_121 = getitem_785 = getitem_786 = permute_1340 = fw_graph18 = joint_graph18 = mask_graph18 = None + getitem_21749 = flex_attention_backward_18[0] + getitem_21750 = flex_attention_backward_18[1] + getitem_21751 = flex_attention_backward_18[2]; flex_attention_backward_18 = None + permute_1341 = torch.ops.aten.permute.default(getitem_21751, [0, 2, 1, 3]); getitem_21751 = None + permute_1342 = torch.ops.aten.permute.default(getitem_21750, [0, 2, 1, 3]); getitem_21750 = None + permute_1343 = torch.ops.aten.permute.default(getitem_21749, [0, 2, 1, 3]); getitem_21749 = None + slice_272 = torch.ops.aten.slice.Tensor(permute_1342, 3, 0, 128) + slice_273 = torch.ops.aten.slice.Tensor(permute_1342, 3, 128, 192); permute_1342 = None + sum_254 = torch.ops.aten.sum.dim_IntList(slice_273, [2], True); slice_273 = None + cat_384 = torch.ops.aten.cat.default([slice_272, permute_1341], 3); slice_272 = permute_1341 = None + view_2131 = torch.ops.aten.view.default(cat_384, [2, 4096, 4096]); cat_384 = None + view_2132 = torch.ops.aten.view.default(view_2131, [8192, 4096]); view_2131 = None + permute_1344 = torch.ops.aten.permute.default(view_2132, [1, 0]) + mm_516 = torch.ops.aten.mm.default(permute_1344, view_519); permute_1344 = view_519 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_431, 128, '0'); convert_element_type_431 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_1346 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_517 = torch.ops.aten.mm.default(view_2132, permute_1346); view_2132 = permute_1346 = None + view_2133 = torch.ops.aten.view.default(mm_517, [2, 4096, 512]); mm_517 = None + convert_element_type_2837 = torch.ops.prims.convert_element_type.default(mm_516, torch.float32); mm_516 = None + reduce_scatter_tensor_263 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2837, 'avg', 128, '0'); convert_element_type_2837 = None + wait_tensor_860 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_263); reduce_scatter_tensor_263 = None + convert_element_type_2838 = torch.ops.prims.convert_element_type.default(view_2133, torch.float32); view_2133 = None + convert_element_type_428 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_428, 128, '0'); convert_element_type_428 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_2840 = torch.ops.prims.convert_element_type.default(wait_tensor_161, torch.float32); wait_tensor_161 = None + mul_1923 = torch.ops.aten.mul.Tensor(convert_element_type_2838, convert_element_type_2840); convert_element_type_2840 = None + convert_element_type_429 = torch.ops.prims.convert_element_type.default(getitem_781, torch.float32); getitem_781 = None + mul_356 = torch.ops.aten.mul.Tensor(convert_element_type_429, rsqrt_25); convert_element_type_429 = None + mul_1925 = torch.ops.aten.mul.Tensor(mul_356, mul_1923) + sum_255 = torch.ops.aten.sum.dim_IntList(mul_1925, [2], True); mul_1925 = None + div_244 = torch.ops.aten.div.Tensor(mul_356, 512) + mul_1926 = torch.ops.aten.mul.Tensor(div_244, sum_255); div_244 = sum_255 = None + sub_737 = torch.ops.aten.sub.Tensor(mul_1923, mul_1926); mul_1923 = mul_1926 = None + mul_1927 = torch.ops.aten.mul.Tensor(sub_737, rsqrt_25); sub_737 = rsqrt_25 = None + mul_1928 = torch.ops.aten.mul.Tensor(convert_element_type_2838, mul_356); convert_element_type_2838 = mul_356 = None + sum_256 = torch.ops.aten.sum.dim_IntList(mul_1928, [0, 1]); mul_1928 = None + convert_element_type_2841 = torch.ops.prims.convert_element_type.default(mul_1927, torch.bfloat16); mul_1927 = None + convert_element_type_default_26 = torch.ops.prims.convert_element_type.default(sum_256, torch.float32); sum_256 = None + reduce_scatter_tensor_264 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_26, 'avg', 128, '0'); convert_element_type_default_26 = None + wait_tensor_861 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_264); reduce_scatter_tensor_264 = None + convert_element_type_2844 = torch.ops.prims.convert_element_type.default(sum_254, torch.float32); sum_254 = None + view_2134 = torch.ops.aten.view.default(convert_element_type_2844, [2, 4096, 1, 32, 2]); convert_element_type_2844 = None + view_as_complex_90 = torch.ops.aten.view_as_complex.default(view_2134); view_2134 = None + mul_1929 = torch.ops.aten.mul.Tensor(view_as_complex_90, clone_9); view_as_complex_90 = None + view_as_real_90 = torch.ops.aten.view_as_real.default(mul_1929); mul_1929 = None + view_2135 = torch.ops.aten.view.default(view_as_real_90, [2, 4096, 1, 64]); view_as_real_90 = None + convert_element_type_2845 = torch.ops.prims.convert_element_type.default(view_2135, torch.bfloat16); view_2135 = None + squeeze_44 = torch.ops.aten.squeeze.dim(convert_element_type_2845, 2); convert_element_type_2845 = None + cat_385 = torch.ops.aten.cat.default([convert_element_type_2841, squeeze_44], 2); convert_element_type_2841 = squeeze_44 = None + view_2136 = torch.ops.aten.view.default(cat_385, [8192, 576]); cat_385 = None + permute_1348 = torch.ops.aten.permute.default(view_2136, [1, 0]) + mm_518 = torch.ops.aten.mm.default(permute_1348, view_505); permute_1348 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_423, 128, '0'); convert_element_type_423 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + slice_49 = torch.ops.aten.slice.Tensor(wait_tensor_160, 0, 0, 576); wait_tensor_160 = None + permute_117 = torch.ops.aten.permute.default(slice_49, [1, 0]); slice_49 = None + permute_1350 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_519 = torch.ops.aten.mm.default(view_2136, permute_1350); view_2136 = permute_1350 = None + view_2137 = torch.ops.aten.view.default(mm_519, [2, 4096, 2048]); mm_519 = None + convert_element_type_2850 = torch.ops.prims.convert_element_type.default(mm_518, torch.float32); mm_518 = None + split_1162 = torch.ops.aten.split.Tensor(convert_element_type_2850, 5); convert_element_type_2850 = None + getitem_21753 = split_1162[0] + getitem_21754 = split_1162[1] + getitem_21755 = split_1162[2] + getitem_21756 = split_1162[3] + getitem_21757 = split_1162[4] + getitem_21758 = split_1162[5] + getitem_21759 = split_1162[6] + getitem_21760 = split_1162[7] + getitem_21761 = split_1162[8] + getitem_21762 = split_1162[9] + getitem_21763 = split_1162[10] + getitem_21764 = split_1162[11] + getitem_21765 = split_1162[12] + getitem_21766 = split_1162[13] + getitem_21767 = split_1162[14] + getitem_21768 = split_1162[15] + getitem_21769 = split_1162[16] + getitem_21770 = split_1162[17] + getitem_21771 = split_1162[18] + getitem_21772 = split_1162[19] + getitem_21773 = split_1162[20] + getitem_21774 = split_1162[21] + getitem_21775 = split_1162[22] + getitem_21776 = split_1162[23] + getitem_21777 = split_1162[24] + getitem_21778 = split_1162[25] + getitem_21779 = split_1162[26] + getitem_21780 = split_1162[27] + getitem_21781 = split_1162[28] + getitem_21782 = split_1162[29] + getitem_21783 = split_1162[30] + getitem_21784 = split_1162[31] + getitem_21785 = split_1162[32] + getitem_21786 = split_1162[33] + getitem_21787 = split_1162[34] + getitem_21788 = split_1162[35] + getitem_21789 = split_1162[36] + getitem_21790 = split_1162[37] + getitem_21791 = split_1162[38] + getitem_21792 = split_1162[39] + getitem_21793 = split_1162[40] + getitem_21794 = split_1162[41] + getitem_21795 = split_1162[42] + getitem_21796 = split_1162[43] + getitem_21797 = split_1162[44] + getitem_21798 = split_1162[45] + getitem_21799 = split_1162[46] + getitem_21800 = split_1162[47] + getitem_21801 = split_1162[48] + getitem_21802 = split_1162[49] + getitem_21803 = split_1162[50] + getitem_21804 = split_1162[51] + getitem_21805 = split_1162[52] + getitem_21806 = split_1162[53] + getitem_21807 = split_1162[54] + getitem_21808 = split_1162[55] + getitem_21809 = split_1162[56] + getitem_21810 = split_1162[57] + getitem_21811 = split_1162[58] + getitem_21812 = split_1162[59] + getitem_21813 = split_1162[60] + getitem_21814 = split_1162[61] + getitem_21815 = split_1162[62] + getitem_21816 = split_1162[63] + getitem_21817 = split_1162[64] + getitem_21818 = split_1162[65] + getitem_21819 = split_1162[66] + getitem_21820 = split_1162[67] + getitem_21821 = split_1162[68] + getitem_21822 = split_1162[69] + getitem_21823 = split_1162[70] + getitem_21824 = split_1162[71] + getitem_21825 = split_1162[72] + getitem_21826 = split_1162[73] + getitem_21827 = split_1162[74] + getitem_21828 = split_1162[75] + getitem_21829 = split_1162[76] + getitem_21830 = split_1162[77] + getitem_21831 = split_1162[78] + getitem_21832 = split_1162[79] + getitem_21833 = split_1162[80] + getitem_21834 = split_1162[81] + getitem_21835 = split_1162[82] + getitem_21836 = split_1162[83] + getitem_21837 = split_1162[84] + getitem_21838 = split_1162[85] + getitem_21839 = split_1162[86] + getitem_21840 = split_1162[87] + getitem_21841 = split_1162[88] + getitem_21842 = split_1162[89] + getitem_21843 = split_1162[90] + getitem_21844 = split_1162[91] + getitem_21845 = split_1162[92] + getitem_21846 = split_1162[93] + getitem_21847 = split_1162[94] + getitem_21848 = split_1162[95] + getitem_21849 = split_1162[96] + getitem_21850 = split_1162[97] + getitem_21851 = split_1162[98] + getitem_21852 = split_1162[99] + getitem_21853 = split_1162[100] + getitem_21854 = split_1162[101] + getitem_21855 = split_1162[102] + getitem_21856 = split_1162[103] + getitem_21857 = split_1162[104] + getitem_21858 = split_1162[105] + getitem_21859 = split_1162[106] + getitem_21860 = split_1162[107] + getitem_21861 = split_1162[108] + getitem_21862 = split_1162[109] + getitem_21863 = split_1162[110] + getitem_21864 = split_1162[111] + getitem_21865 = split_1162[112] + getitem_21866 = split_1162[113] + getitem_21867 = split_1162[114] + getitem_21868 = split_1162[115]; split_1162 = None + constant_pad_nd_1450 = torch.ops.aten.constant_pad_nd.default(getitem_21868, [0, 0, 0, 4], 0.0); getitem_21868 = None + cat_386 = torch.ops.aten.cat.default([getitem_21753, getitem_21754, getitem_21755, getitem_21756, getitem_21757, getitem_21758, getitem_21759, getitem_21760, getitem_21761, getitem_21762, getitem_21763, getitem_21764, getitem_21765, getitem_21766, getitem_21767, getitem_21768, getitem_21769, getitem_21770, getitem_21771, getitem_21772, getitem_21773, getitem_21774, getitem_21775, getitem_21776, getitem_21777, getitem_21778, getitem_21779, getitem_21780, getitem_21781, getitem_21782, getitem_21783, getitem_21784, getitem_21785, getitem_21786, getitem_21787, getitem_21788, getitem_21789, getitem_21790, getitem_21791, getitem_21792, getitem_21793, getitem_21794, getitem_21795, getitem_21796, getitem_21797, getitem_21798, getitem_21799, getitem_21800, getitem_21801, getitem_21802, getitem_21803, getitem_21804, getitem_21805, getitem_21806, getitem_21807, getitem_21808, getitem_21809, getitem_21810, getitem_21811, getitem_21812, getitem_21813, getitem_21814, getitem_21815, getitem_21816, getitem_21817, getitem_21818, getitem_21819, getitem_21820, getitem_21821, getitem_21822, getitem_21823, getitem_21824, getitem_21825, getitem_21826, getitem_21827, getitem_21828, getitem_21829, getitem_21830, getitem_21831, getitem_21832, getitem_21833, getitem_21834, getitem_21835, getitem_21836, getitem_21837, getitem_21838, getitem_21839, getitem_21840, getitem_21841, getitem_21842, getitem_21843, getitem_21844, getitem_21845, getitem_21846, getitem_21847, getitem_21848, getitem_21849, getitem_21850, getitem_21851, getitem_21852, getitem_21853, getitem_21854, getitem_21855, getitem_21856, getitem_21857, getitem_21858, getitem_21859, getitem_21860, getitem_21861, getitem_21862, getitem_21863, getitem_21864, getitem_21865, getitem_21866, getitem_21867, constant_pad_nd_1450, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_21753 = getitem_21754 = getitem_21755 = getitem_21756 = getitem_21757 = getitem_21758 = getitem_21759 = getitem_21760 = getitem_21761 = getitem_21762 = getitem_21763 = getitem_21764 = getitem_21765 = getitem_21766 = getitem_21767 = getitem_21768 = getitem_21769 = getitem_21770 = getitem_21771 = getitem_21772 = getitem_21773 = getitem_21774 = getitem_21775 = getitem_21776 = getitem_21777 = getitem_21778 = getitem_21779 = getitem_21780 = getitem_21781 = getitem_21782 = getitem_21783 = getitem_21784 = getitem_21785 = getitem_21786 = getitem_21787 = getitem_21788 = getitem_21789 = getitem_21790 = getitem_21791 = getitem_21792 = getitem_21793 = getitem_21794 = getitem_21795 = getitem_21796 = getitem_21797 = getitem_21798 = getitem_21799 = getitem_21800 = getitem_21801 = getitem_21802 = getitem_21803 = getitem_21804 = getitem_21805 = getitem_21806 = getitem_21807 = getitem_21808 = getitem_21809 = getitem_21810 = getitem_21811 = getitem_21812 = getitem_21813 = getitem_21814 = getitem_21815 = getitem_21816 = getitem_21817 = getitem_21818 = getitem_21819 = getitem_21820 = getitem_21821 = getitem_21822 = getitem_21823 = getitem_21824 = getitem_21825 = getitem_21826 = getitem_21827 = getitem_21828 = getitem_21829 = getitem_21830 = getitem_21831 = getitem_21832 = getitem_21833 = getitem_21834 = getitem_21835 = getitem_21836 = getitem_21837 = getitem_21838 = getitem_21839 = getitem_21840 = getitem_21841 = getitem_21842 = getitem_21843 = getitem_21844 = getitem_21845 = getitem_21846 = getitem_21847 = getitem_21848 = getitem_21849 = getitem_21850 = getitem_21851 = getitem_21852 = getitem_21853 = getitem_21854 = getitem_21855 = getitem_21856 = getitem_21857 = getitem_21858 = getitem_21859 = getitem_21860 = getitem_21861 = getitem_21862 = getitem_21863 = getitem_21864 = getitem_21865 = getitem_21866 = getitem_21867 = constant_pad_nd_1450 = None + reduce_scatter_tensor_265 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_386, 'avg', 128, '0'); cat_386 = None + wait_tensor_862 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_265); reduce_scatter_tensor_265 = None + slice_274 = torch.ops.aten.slice.Tensor(permute_1343, 3, 0, 128) + slice_275 = torch.ops.aten.slice.Tensor(permute_1343, 3, 128, 192); permute_1343 = None + convert_element_type_2851 = torch.ops.prims.convert_element_type.default(slice_275, torch.float32); slice_275 = None + view_2138 = torch.ops.aten.view.default(convert_element_type_2851, [2, 4096, 16, 32, 2]); convert_element_type_2851 = None + view_as_complex_91 = torch.ops.aten.view_as_complex.default(view_2138); view_2138 = None + mul_1930 = torch.ops.aten.mul.Tensor(view_as_complex_91, clone_9); view_as_complex_91 = None + view_as_real_91 = torch.ops.aten.view_as_real.default(mul_1930); mul_1930 = None + view_2139 = torch.ops.aten.view.default(view_as_real_91, [2, 4096, 16, 64]); view_as_real_91 = None + convert_element_type_2852 = torch.ops.prims.convert_element_type.default(view_2139, torch.bfloat16); view_2139 = None + cat_387 = torch.ops.aten.cat.default([slice_274, convert_element_type_2852], 3); slice_274 = convert_element_type_2852 = None + view_2140 = torch.ops.aten.view.default(cat_387, [2, 4096, 3072]); cat_387 = None + view_2141 = torch.ops.aten.view.default(view_2140, [8192, 3072]); view_2140 = None + permute_1352 = torch.ops.aten.permute.default(view_2141, [1, 0]) + mm_520 = torch.ops.aten.mm.default(permute_1352, view_505); permute_1352 = view_505 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 128, '0'); convert_element_type_418 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_116 = torch.ops.aten.permute.default(wait_tensor_159, [1, 0]); wait_tensor_159 = None + permute_1354 = torch.ops.aten.permute.default(permute_116, [1, 0]); permute_116 = None + mm_521 = torch.ops.aten.mm.default(view_2141, permute_1354); view_2141 = permute_1354 = None + view_2142 = torch.ops.aten.view.default(mm_521, [2, 4096, 2048]); mm_521 = None + add_2058 = torch.ops.aten.add.Tensor(view_2137, view_2142); view_2137 = view_2142 = None + convert_element_type_2857 = torch.ops.prims.convert_element_type.default(mm_520, torch.float32); mm_520 = None + reduce_scatter_tensor_266 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2857, 'avg', 128, '0'); convert_element_type_2857 = None + wait_tensor_863 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_266); reduce_scatter_tensor_266 = None + convert_element_type_2858 = torch.ops.prims.convert_element_type.default(add_2058, torch.float32); add_2058 = None + convert_element_type_415 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_415, 128, '0'); convert_element_type_415 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + convert_element_type_2860 = torch.ops.prims.convert_element_type.default(wait_tensor_158, torch.float32); wait_tensor_158 = None + mul_1931 = torch.ops.aten.mul.Tensor(convert_element_type_2858, convert_element_type_2860); convert_element_type_2860 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(add_481, torch.float32); add_481 = None + mul_352 = torch.ops.aten.mul.Tensor(convert_element_type_416, rsqrt_24); convert_element_type_416 = None + mul_1933 = torch.ops.aten.mul.Tensor(mul_352, mul_1931) + sum_257 = torch.ops.aten.sum.dim_IntList(mul_1933, [2], True); mul_1933 = None + div_245 = torch.ops.aten.div.Tensor(mul_352, 2048) + mul_1934 = torch.ops.aten.mul.Tensor(div_245, sum_257); div_245 = sum_257 = None + sub_738 = torch.ops.aten.sub.Tensor(mul_1931, mul_1934); mul_1931 = mul_1934 = None + mul_1935 = torch.ops.aten.mul.Tensor(sub_738, rsqrt_24); sub_738 = rsqrt_24 = None + mul_1936 = torch.ops.aten.mul.Tensor(convert_element_type_2858, mul_352); convert_element_type_2858 = mul_352 = None + sum_258 = torch.ops.aten.sum.dim_IntList(mul_1936, [0, 1]); mul_1936 = None + convert_element_type_2861 = torch.ops.prims.convert_element_type.default(mul_1935, torch.bfloat16); mul_1935 = None + add_2059 = torch.ops.aten.add.Tensor(add_2057, convert_element_type_2861); add_2057 = convert_element_type_2861 = None + convert_element_type_default_25 = torch.ops.prims.convert_element_type.default(sum_258, torch.float32); sum_258 = None + reduce_scatter_tensor_267 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_25, 'avg', 128, '0'); convert_element_type_default_25 = None + wait_tensor_864 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_267); reduce_scatter_tensor_267 = None + view_2143 = torch.ops.aten.view.default(add_2059, [8192, 2048]) + unsqueeze_72 = torch.ops.aten.unsqueeze.default(view_2143, 1) + convert_element_type_2864 = torch.ops.prims.convert_element_type.default(unsqueeze_72, torch.float32); unsqueeze_72 = None + bmm_64 = torch.ops.aten.bmm.default(permute_1356, convert_element_type_2864); permute_1356 = None + bmm_65 = torch.ops.aten.bmm.default(convert_element_type_2864, permute_1357); convert_element_type_2864 = permute_1357 = None + convert_element_type_2865 = torch.ops.prims.convert_element_type.default(bmm_64, torch.bfloat16); bmm_64 = None + view_2144 = torch.ops.aten.view.default(bmm_65, [8192, 6]); bmm_65 = None + view_2145 = torch.ops.aten.view.default(convert_element_type_2865, [49152, 2048]); convert_element_type_2865 = None + index_90 = torch.ops.aten.index.Tensor(view_2145, [getitem_681]); view_2145 = getitem_681 = None + permute_1358 = torch.ops.aten.permute.default(view_2143, [1, 0]) + mm_522 = torch.ops.aten.mm.default(permute_1358, mul_349); permute_1358 = mul_349 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_410, 128, '0'); convert_element_type_410 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_115 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + permute_1360 = torch.ops.aten.permute.default(permute_115, [1, 0]); permute_115 = None + mm_523 = torch.ops.aten.mm.default(view_2143, permute_1360); view_2143 = permute_1360 = None + convert_element_type_2870 = torch.ops.prims.convert_element_type.default(mm_522, torch.float32); mm_522 = None + reduce_scatter_tensor_268 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2870, 'avg', 128, '0'); convert_element_type_2870 = None + wait_tensor_865 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_268); reduce_scatter_tensor_268 = None + convert_element_type_405 = torch.ops.prims.convert_element_type.default(mm_60, torch.float32); mm_60 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_405) + exp_21 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_476 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + div_35 = torch.ops.aten.div.Tensor(convert_element_type_405, add_476) + convert_element_type_406 = torch.ops.prims.convert_element_type.default(div_35, torch.bfloat16); div_35 = None + mul_1937 = torch.ops.aten.mul.Tensor(mm_523, convert_element_type_406); convert_element_type_406 = None + mul_1938 = torch.ops.aten.mul.Tensor(mm_523, mm_61); mm_523 = mm_61 = None + permute_1362 = torch.ops.aten.permute.default(mul_1937, [1, 0]) + mm_524 = torch.ops.aten.mm.default(permute_1362, view_460); permute_1362 = None + convert_element_type_407 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_407, 128, '0'); convert_element_type_407 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_114 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + permute_1364 = torch.ops.aten.permute.default(permute_114, [1, 0]); permute_114 = None + mm_525 = torch.ops.aten.mm.default(mul_1937, permute_1364); mul_1937 = permute_1364 = None + convert_element_type_2875 = torch.ops.prims.convert_element_type.default(mm_524, torch.float32); mm_524 = None + reduce_scatter_tensor_269 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2875, 'avg', 128, '0'); convert_element_type_2875 = None + wait_tensor_866 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_269); reduce_scatter_tensor_269 = None + convert_element_type_2876 = torch.ops.prims.convert_element_type.default(mul_1938, torch.float32); mul_1938 = None + reciprocal_38 = torch.ops.aten.reciprocal.default(add_476); add_476 = None + mul_1939 = torch.ops.aten.mul.Tensor(reciprocal_38, 1); reciprocal_38 = None + mul_1940 = torch.ops.aten.mul.Tensor(convert_element_type_2876, mul_1939); convert_element_type_2876 = None + sub_739 = torch.ops.aten.sub.Tensor(1, mul_1939); mul_1939 = None + mul_1941 = torch.ops.aten.mul.Tensor(convert_element_type_405, sub_739); convert_element_type_405 = sub_739 = None + add_2061 = torch.ops.aten.add.Tensor(mul_1941, 1); mul_1941 = None + mul_1942 = torch.ops.aten.mul.Tensor(mul_1940, add_2061); mul_1940 = add_2061 = None + convert_element_type_2878 = torch.ops.prims.convert_element_type.default(mul_1942, torch.bfloat16); mul_1942 = None + permute_1366 = torch.ops.aten.permute.default(convert_element_type_2878, [1, 0]) + mm_526 = torch.ops.aten.mm.default(permute_1366, view_460); permute_1366 = None + convert_element_type_402 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_402, 128, '0'); convert_element_type_402 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_113 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + permute_1368 = torch.ops.aten.permute.default(permute_113, [1, 0]); permute_113 = None + mm_527 = torch.ops.aten.mm.default(convert_element_type_2878, permute_1368); convert_element_type_2878 = permute_1368 = None + add_2062 = torch.ops.aten.add.Tensor(mm_525, mm_527); mm_525 = mm_527 = None + convert_element_type_2883 = torch.ops.prims.convert_element_type.default(mm_526, torch.float32); mm_526 = None + reduce_scatter_tensor_270 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2883, 'avg', 128, '0'); convert_element_type_2883 = None + wait_tensor_867 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_270); reduce_scatter_tensor_270 = None + all_to_all_single_116 = torch.ops._c10d_functional.all_to_all_single.default(index_90, [_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111], [_local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103], '1033'); index_90 = None + wait_tensor_868 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_116); all_to_all_single_116 = None + full_462 = torch.ops.aten.full.default([sym_size_int_25, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_25 = None + slice_scatter_19 = torch.ops.aten.slice_scatter.default(full_462, wait_tensor_868, 0, 0, -1); wait_tensor_868 = None + index_91 = torch.ops.aten.index.Tensor(slice_scatter_19, [getitem_682]); slice_scatter_19 = None + permute_1370 = torch.ops.aten.permute.default(index_91, [1, 0]) + _grouped_mm_192 = torch.ops.aten._grouped_mm.default(permute_1370, mul_329, cumsum_20); permute_1370 = mul_329 = None + _grouped_mm_193 = torch.ops.aten._grouped_mm.default(index_91, permute_1372, cumsum_20); index_91 = permute_1372 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(_grouped_mm_18, torch.float32); _grouped_mm_18 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_400) + exp_20 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_440 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + div_34 = torch.ops.aten.div.Tensor(convert_element_type_400, add_440) + convert_element_type_401 = torch.ops.prims.convert_element_type.default(div_34, torch.bfloat16); div_34 = None + mul_1943 = torch.ops.aten.mul.Tensor(_grouped_mm_193, convert_element_type_401); convert_element_type_401 = None + mul_1944 = torch.ops.aten.mul.Tensor(_grouped_mm_193, _grouped_mm_19); _grouped_mm_193 = _grouped_mm_19 = None + permute_1374 = torch.ops.aten.permute.default(mul_1943, [1, 0]) + _grouped_mm_194 = torch.ops.aten._grouped_mm.default(permute_1374, index_13, cumsum_20); permute_1374 = None + _grouped_mm_195 = torch.ops.aten._grouped_mm.default(mul_1943, permute_1376, cumsum_20); mul_1943 = permute_1376 = None + convert_element_type_2884 = torch.ops.prims.convert_element_type.default(mul_1944, torch.float32); mul_1944 = None + reciprocal_39 = torch.ops.aten.reciprocal.default(add_440); add_440 = None + mul_1945 = torch.ops.aten.mul.Tensor(reciprocal_39, 1); reciprocal_39 = None + mul_1946 = torch.ops.aten.mul.Tensor(convert_element_type_2884, mul_1945); convert_element_type_2884 = None + sub_740 = torch.ops.aten.sub.Tensor(1, mul_1945); mul_1945 = None + mul_1947 = torch.ops.aten.mul.Tensor(convert_element_type_400, sub_740); convert_element_type_400 = sub_740 = None + add_2064 = torch.ops.aten.add.Tensor(mul_1947, 1); mul_1947 = None + mul_1948 = torch.ops.aten.mul.Tensor(mul_1946, add_2064); mul_1946 = add_2064 = None + convert_element_type_2886 = torch.ops.prims.convert_element_type.default(mul_1948, torch.bfloat16); mul_1948 = None + permute_1378 = torch.ops.aten.permute.default(convert_element_type_2886, [1, 0]) + _grouped_mm_196 = torch.ops.aten._grouped_mm.default(permute_1378, index_13, cumsum_20); permute_1378 = index_13 = None + _grouped_mm_197 = torch.ops.aten._grouped_mm.default(convert_element_type_2886, permute_1380, cumsum_20); convert_element_type_2886 = permute_1380 = cumsum_20 = None + add_2065 = torch.ops.aten.add.Tensor(_grouped_mm_195, _grouped_mm_197); _grouped_mm_195 = _grouped_mm_197 = None + convert_element_type_2887 = torch.ops.prims.convert_element_type.default(_grouped_mm_194, torch.float32); _grouped_mm_194 = None + div_246 = torch.ops.aten.div.Tensor(convert_element_type_2887, 128); convert_element_type_2887 = None + split_1164 = torch.ops.aten.split.Tensor(div_246, 88, 1); div_246 = None + getitem_21885 = split_1164[0] + getitem_21902 = split_1164[1] + getitem_21919 = split_1164[2] + getitem_21936 = split_1164[3] + getitem_21953 = split_1164[4] + getitem_21970 = split_1164[5] + getitem_21987 = split_1164[6] + getitem_22004 = split_1164[7] + getitem_22021 = split_1164[8] + getitem_22038 = split_1164[9] + getitem_22055 = split_1164[10] + getitem_22072 = split_1164[11] + getitem_22089 = split_1164[12] + getitem_22106 = split_1164[13] + getitem_22123 = split_1164[14] + getitem_22140 = split_1164[15]; split_1164 = None + cat_388 = torch.ops.aten.cat.default([getitem_21885, getitem_21902, getitem_21919, getitem_21936, getitem_21953, getitem_21970, getitem_21987, getitem_22004, getitem_22021, getitem_22038, getitem_22055, getitem_22072, getitem_22089, getitem_22106, getitem_22123, getitem_22140]); getitem_21885 = getitem_21902 = getitem_21919 = getitem_21936 = getitem_21953 = getitem_21970 = getitem_21987 = getitem_22004 = getitem_22021 = getitem_22038 = getitem_22055 = getitem_22072 = getitem_22089 = getitem_22106 = getitem_22123 = getitem_22140 = None + reduce_scatter_tensor_271 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_388, 'sum', 16, '1025'); cat_388 = None + wait_tensor_869 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_271); reduce_scatter_tensor_271 = None + convert_element_type_2888 = torch.ops.prims.convert_element_type.default(_grouped_mm_192, torch.float32); _grouped_mm_192 = None + div_247 = torch.ops.aten.div.Tensor(convert_element_type_2888, 128); convert_element_type_2888 = None + split_1181 = torch.ops.aten.split.Tensor(div_247, 128, 1); div_247 = None + getitem_22157 = split_1181[0] + getitem_22174 = split_1181[1] + getitem_22191 = split_1181[2] + getitem_22208 = split_1181[3] + getitem_22225 = split_1181[4] + getitem_22242 = split_1181[5] + getitem_22259 = split_1181[6] + getitem_22276 = split_1181[7] + getitem_22293 = split_1181[8] + getitem_22310 = split_1181[9] + getitem_22327 = split_1181[10] + getitem_22344 = split_1181[11] + getitem_22361 = split_1181[12] + getitem_22378 = split_1181[13] + getitem_22395 = split_1181[14] + getitem_22412 = split_1181[15]; split_1181 = None + cat_389 = torch.ops.aten.cat.default([getitem_22157, getitem_22174, getitem_22191, getitem_22208, getitem_22225, getitem_22242, getitem_22259, getitem_22276, getitem_22293, getitem_22310, getitem_22327, getitem_22344, getitem_22361, getitem_22378, getitem_22395, getitem_22412]); getitem_22157 = getitem_22174 = getitem_22191 = getitem_22208 = getitem_22225 = getitem_22242 = getitem_22259 = getitem_22276 = getitem_22293 = getitem_22310 = getitem_22327 = getitem_22344 = getitem_22361 = getitem_22378 = getitem_22395 = getitem_22412 = None + reduce_scatter_tensor_272 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_389, 'sum', 16, '1025'); cat_389 = None + wait_tensor_870 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_272); reduce_scatter_tensor_272 = None + convert_element_type_2889 = torch.ops.prims.convert_element_type.default(_grouped_mm_196, torch.float32); _grouped_mm_196 = None + div_248 = torch.ops.aten.div.Tensor(convert_element_type_2889, 128); convert_element_type_2889 = None + split_1198 = torch.ops.aten.split.Tensor(div_248, 88, 1); div_248 = None + getitem_22429 = split_1198[0] + getitem_22446 = split_1198[1] + getitem_22463 = split_1198[2] + getitem_22480 = split_1198[3] + getitem_22497 = split_1198[4] + getitem_22514 = split_1198[5] + getitem_22531 = split_1198[6] + getitem_22548 = split_1198[7] + getitem_22565 = split_1198[8] + getitem_22582 = split_1198[9] + getitem_22599 = split_1198[10] + getitem_22616 = split_1198[11] + getitem_22633 = split_1198[12] + getitem_22650 = split_1198[13] + getitem_22667 = split_1198[14] + getitem_22684 = split_1198[15]; split_1198 = None + cat_390 = torch.ops.aten.cat.default([getitem_22429, getitem_22446, getitem_22463, getitem_22480, getitem_22497, getitem_22514, getitem_22531, getitem_22548, getitem_22565, getitem_22582, getitem_22599, getitem_22616, getitem_22633, getitem_22650, getitem_22667, getitem_22684]); getitem_22429 = getitem_22446 = getitem_22463 = getitem_22480 = getitem_22497 = getitem_22514 = getitem_22531 = getitem_22548 = getitem_22565 = getitem_22582 = getitem_22599 = getitem_22616 = getitem_22633 = getitem_22650 = getitem_22667 = getitem_22684 = None + reduce_scatter_tensor_273 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_390, 'sum', 16, '1025'); cat_390 = None + wait_tensor_871 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_273); reduce_scatter_tensor_273 = None + index_put_90 = torch.ops.aten.index_put.default(full_462, [getitem_682], add_2065, True); full_462 = getitem_682 = add_2065 = None + slice_276 = torch.ops.aten.slice.Tensor(index_put_90, 0, 0, add_2066); index_put_90 = add_2066 = None + all_to_all_single_117 = torch.ops._c10d_functional.all_to_all_single.default(slice_276, [_local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103], [_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111], '1033'); slice_276 = _local_scalar_dense_96 = _local_scalar_dense_97 = _local_scalar_dense_98 = _local_scalar_dense_99 = _local_scalar_dense_100 = _local_scalar_dense_101 = _local_scalar_dense_102 = _local_scalar_dense_103 = _local_scalar_dense_104 = _local_scalar_dense_105 = _local_scalar_dense_106 = _local_scalar_dense_107 = _local_scalar_dense_108 = _local_scalar_dense_109 = _local_scalar_dense_110 = _local_scalar_dense_111 = None + wait_tensor_872 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_117); all_to_all_single_117 = None + index_put_91 = torch.ops.aten.index_put.default(full_default_52, [div_32], wait_tensor_872, True); div_32 = wait_tensor_872 = None + add_2070 = torch.ops.aten.add.Tensor(add_2062, index_put_91); add_2062 = index_put_91 = None + mul_1949 = torch.ops.aten.mul.Tensor(view_2144, 1.0); view_2144 = None + scatter_add_19 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_679, mul_1949); getitem_679 = mul_1949 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(mm_59, torch.float32); mm_59 = None + sub_144 = torch.ops.aten.sub.Tensor(convert_element_type_389, amax_6); convert_element_type_389 = amax_6 = None + exp_19 = torch.ops.aten.exp.default(sub_144); sub_144 = None + div_31 = torch.ops.aten.div.Tensor(exp_19, sum_25); exp_19 = sum_25 = None + mul_1950 = torch.ops.aten.mul.Tensor(scatter_add_19, div_31); scatter_add_19 = None + sum_259 = torch.ops.aten.sum.dim_IntList(mul_1950, [1], True) + neg_112 = torch.ops.aten.neg.default(div_31); div_31 = None + fma_19 = torch.ops.prims.fma.default(neg_112, sum_259, mul_1950); neg_112 = sum_259 = mul_1950 = None + convert_element_type_2890 = torch.ops.prims.convert_element_type.default(fma_19, torch.bfloat16); fma_19 = None + permute_1382 = torch.ops.aten.permute.default(convert_element_type_2890, [1, 0]) + mm_528 = torch.ops.aten.mm.default(permute_1382, view_460); permute_1382 = view_460 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 128, '0'); convert_element_type_386 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + slice_45 = torch.ops.aten.slice.Tensor(wait_tensor_144, 0, 0, 64); wait_tensor_144 = None + permute_109 = torch.ops.aten.permute.default(slice_45, [1, 0]); slice_45 = None + permute_1384 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_529 = torch.ops.aten.mm.default(convert_element_type_2890, permute_1384); convert_element_type_2890 = permute_1384 = None + add_2071 = torch.ops.aten.add.Tensor(add_2070, mm_529); add_2070 = mm_529 = None + convert_element_type_2895 = torch.ops.prims.convert_element_type.default(mm_528, torch.float32); mm_528 = None + split_1214 = torch.ops.aten.split.Tensor(convert_element_type_2895, 1); convert_element_type_2895 = None + getitem_22685 = split_1214[0] + getitem_22686 = split_1214[1] + getitem_22687 = split_1214[2] + getitem_22688 = split_1214[3] + getitem_22689 = split_1214[4] + getitem_22690 = split_1214[5] + getitem_22691 = split_1214[6] + getitem_22692 = split_1214[7] + getitem_22693 = split_1214[8] + getitem_22694 = split_1214[9] + getitem_22695 = split_1214[10] + getitem_22696 = split_1214[11] + getitem_22697 = split_1214[12] + getitem_22698 = split_1214[13] + getitem_22699 = split_1214[14] + getitem_22700 = split_1214[15] + getitem_22701 = split_1214[16] + getitem_22702 = split_1214[17] + getitem_22703 = split_1214[18] + getitem_22704 = split_1214[19] + getitem_22705 = split_1214[20] + getitem_22706 = split_1214[21] + getitem_22707 = split_1214[22] + getitem_22708 = split_1214[23] + getitem_22709 = split_1214[24] + getitem_22710 = split_1214[25] + getitem_22711 = split_1214[26] + getitem_22712 = split_1214[27] + getitem_22713 = split_1214[28] + getitem_22714 = split_1214[29] + getitem_22715 = split_1214[30] + getitem_22716 = split_1214[31] + getitem_22717 = split_1214[32] + getitem_22718 = split_1214[33] + getitem_22719 = split_1214[34] + getitem_22720 = split_1214[35] + getitem_22721 = split_1214[36] + getitem_22722 = split_1214[37] + getitem_22723 = split_1214[38] + getitem_22724 = split_1214[39] + getitem_22725 = split_1214[40] + getitem_22726 = split_1214[41] + getitem_22727 = split_1214[42] + getitem_22728 = split_1214[43] + getitem_22729 = split_1214[44] + getitem_22730 = split_1214[45] + getitem_22731 = split_1214[46] + getitem_22732 = split_1214[47] + getitem_22733 = split_1214[48] + getitem_22734 = split_1214[49] + getitem_22735 = split_1214[50] + getitem_22736 = split_1214[51] + getitem_22737 = split_1214[52] + getitem_22738 = split_1214[53] + getitem_22739 = split_1214[54] + getitem_22740 = split_1214[55] + getitem_22741 = split_1214[56] + getitem_22742 = split_1214[57] + getitem_22743 = split_1214[58] + getitem_22744 = split_1214[59] + getitem_22745 = split_1214[60] + getitem_22746 = split_1214[61] + getitem_22747 = split_1214[62] + getitem_22748 = split_1214[63]; split_1214 = None + cat_391 = torch.ops.aten.cat.default([getitem_22685, getitem_22686, getitem_22687, getitem_22688, getitem_22689, getitem_22690, getitem_22691, getitem_22692, getitem_22693, getitem_22694, getitem_22695, getitem_22696, getitem_22697, getitem_22698, getitem_22699, getitem_22700, getitem_22701, getitem_22702, getitem_22703, getitem_22704, getitem_22705, getitem_22706, getitem_22707, getitem_22708, getitem_22709, getitem_22710, getitem_22711, getitem_22712, getitem_22713, getitem_22714, getitem_22715, getitem_22716, getitem_22717, getitem_22718, getitem_22719, getitem_22720, getitem_22721, getitem_22722, getitem_22723, getitem_22724, getitem_22725, getitem_22726, getitem_22727, getitem_22728, getitem_22729, getitem_22730, getitem_22731, getitem_22732, getitem_22733, getitem_22734, getitem_22735, getitem_22736, getitem_22737, getitem_22738, getitem_22739, getitem_22740, getitem_22741, getitem_22742, getitem_22743, getitem_22744, getitem_22745, getitem_22746, getitem_22747, getitem_22748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_22685 = getitem_22686 = getitem_22687 = getitem_22688 = getitem_22689 = getitem_22690 = getitem_22691 = getitem_22692 = getitem_22693 = getitem_22694 = getitem_22695 = getitem_22696 = getitem_22697 = getitem_22698 = getitem_22699 = getitem_22700 = getitem_22701 = getitem_22702 = getitem_22703 = getitem_22704 = getitem_22705 = getitem_22706 = getitem_22707 = getitem_22708 = getitem_22709 = getitem_22710 = getitem_22711 = getitem_22712 = getitem_22713 = getitem_22714 = getitem_22715 = getitem_22716 = getitem_22717 = getitem_22718 = getitem_22719 = getitem_22720 = getitem_22721 = getitem_22722 = getitem_22723 = getitem_22724 = getitem_22725 = getitem_22726 = getitem_22727 = getitem_22728 = getitem_22729 = getitem_22730 = getitem_22731 = getitem_22732 = getitem_22733 = getitem_22734 = getitem_22735 = getitem_22736 = getitem_22737 = getitem_22738 = getitem_22739 = getitem_22740 = getitem_22741 = getitem_22742 = getitem_22743 = getitem_22744 = getitem_22745 = getitem_22746 = getitem_22747 = getitem_22748 = None + reduce_scatter_tensor_274 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_391, 'avg', 128, '0'); cat_391 = None + wait_tensor_873 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_274); reduce_scatter_tensor_274 = None + view_2146 = torch.ops.aten.view.default(add_2071, [2, 4096, 2048]); add_2071 = None + convert_element_type_2896 = torch.ops.prims.convert_element_type.default(view_2146, torch.float32); view_2146 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 128, '0'); convert_element_type_383 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + convert_element_type_2898 = torch.ops.prims.convert_element_type.default(wait_tensor_143, torch.float32); wait_tensor_143 = None + mul_1951 = torch.ops.aten.mul.Tensor(convert_element_type_2896, convert_element_type_2898); convert_element_type_2898 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_416, torch.float32); add_416 = None + mul_309 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_1953 = torch.ops.aten.mul.Tensor(mul_309, mul_1951) + sum_260 = torch.ops.aten.sum.dim_IntList(mul_1953, [2], True); mul_1953 = None + div_249 = torch.ops.aten.div.Tensor(mul_309, 2048) + mul_1954 = torch.ops.aten.mul.Tensor(div_249, sum_260); div_249 = sum_260 = None + sub_742 = torch.ops.aten.sub.Tensor(mul_1951, mul_1954); mul_1951 = mul_1954 = None + mul_1955 = torch.ops.aten.mul.Tensor(sub_742, rsqrt_23); sub_742 = rsqrt_23 = None + mul_1956 = torch.ops.aten.mul.Tensor(convert_element_type_2896, mul_309); convert_element_type_2896 = mul_309 = None + sum_261 = torch.ops.aten.sum.dim_IntList(mul_1956, [0, 1]); mul_1956 = None + convert_element_type_2899 = torch.ops.prims.convert_element_type.default(mul_1955, torch.bfloat16); mul_1955 = None + add_2072 = torch.ops.aten.add.Tensor(add_2059, convert_element_type_2899); add_2059 = convert_element_type_2899 = None + convert_element_type_default_24 = torch.ops.prims.convert_element_type.default(sum_261, torch.float32); sum_261 = None + reduce_scatter_tensor_275 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_24, 'avg', 128, '0'); convert_element_type_default_24 = None + wait_tensor_874 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_275); reduce_scatter_tensor_275 = None + view_2147 = torch.ops.aten.view.default(add_2072, [8192, 2048]) + permute_1386 = torch.ops.aten.permute.default(view_2147, [1, 0]) + permute_107 = torch.ops.aten.permute.default(getitem_675, [0, 2, 1, 3]) + view_455 = torch.ops.aten.view.default(permute_107, [2, 4096, -1]); permute_107 = None + view_457 = torch.ops.aten.view.default(view_455, [8192, 2048]); view_455 = None + mm_530 = torch.ops.aten.mm.default(permute_1386, view_457); permute_1386 = view_457 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 128, '0'); convert_element_type_380 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + permute_1388 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_531 = torch.ops.aten.mm.default(view_2147, permute_1388); view_2147 = permute_1388 = None + view_2148 = torch.ops.aten.view.default(mm_531, [2, 4096, 2048]); mm_531 = None + convert_element_type_2906 = torch.ops.prims.convert_element_type.default(mm_530, torch.float32); mm_530 = None + reduce_scatter_tensor_276 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2906, 'avg', 128, '0'); convert_element_type_2906 = None + wait_tensor_875 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_276); reduce_scatter_tensor_276 = None + view_2149 = torch.ops.aten.view.default(view_2148, [2, 4096, 16, 128]); view_2148 = None + permute_1390 = torch.ops.aten.permute.default(view_2149, [0, 2, 1, 3]); view_2149 = None + fw_graph19 = self.fw_graph19 + joint_graph19 = self.joint_graph19 + mask_graph19 = self.mask_graph19 + flex_attention_backward_19 = torch.ops.higher_order.flex_attention_backward(permute_104, permute_105, permute_106, getitem_675, getitem_676, permute_1390, None, fw_graph19, joint_graph19, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph19), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_104 = permute_105 = permute_106 = getitem_675 = getitem_676 = permute_1390 = fw_graph19 = joint_graph19 = mask_graph19 = None + getitem_22749 = flex_attention_backward_19[0] + getitem_22750 = flex_attention_backward_19[1] + getitem_22751 = flex_attention_backward_19[2]; flex_attention_backward_19 = None + permute_1391 = torch.ops.aten.permute.default(getitem_22751, [0, 2, 1, 3]); getitem_22751 = None + permute_1392 = torch.ops.aten.permute.default(getitem_22750, [0, 2, 1, 3]); getitem_22750 = None + permute_1393 = torch.ops.aten.permute.default(getitem_22749, [0, 2, 1, 3]); getitem_22749 = None + slice_278 = torch.ops.aten.slice.Tensor(permute_1392, 3, 0, 128) + slice_279 = torch.ops.aten.slice.Tensor(permute_1392, 3, 128, 192); permute_1392 = None + sum_262 = torch.ops.aten.sum.dim_IntList(slice_279, [2], True); slice_279 = None + cat_392 = torch.ops.aten.cat.default([slice_278, permute_1391], 3); slice_278 = permute_1391 = None + view_2150 = torch.ops.aten.view.default(cat_392, [2, 4096, 4096]); cat_392 = None + view_2151 = torch.ops.aten.view.default(view_2150, [8192, 4096]); view_2150 = None + permute_1394 = torch.ops.aten.permute.default(view_2151, [1, 0]) + mm_532 = torch.ops.aten.mm.default(permute_1394, view_452); permute_1394 = view_452 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_377, 128, '0'); convert_element_type_377 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_103 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + permute_1396 = torch.ops.aten.permute.default(permute_103, [1, 0]); permute_103 = None + mm_533 = torch.ops.aten.mm.default(view_2151, permute_1396); view_2151 = permute_1396 = None + view_2152 = torch.ops.aten.view.default(mm_533, [2, 4096, 512]); mm_533 = None + convert_element_type_2911 = torch.ops.prims.convert_element_type.default(mm_532, torch.float32); mm_532 = None + reduce_scatter_tensor_277 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2911, 'avg', 128, '0'); convert_element_type_2911 = None + wait_tensor_876 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_277); reduce_scatter_tensor_277 = None + convert_element_type_2912 = torch.ops.prims.convert_element_type.default(view_2152, torch.float32); view_2152 = None + convert_element_type_374 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_374, 128, '0'); convert_element_type_374 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + convert_element_type_2914 = torch.ops.prims.convert_element_type.default(wait_tensor_140, torch.float32); wait_tensor_140 = None + mul_1957 = torch.ops.aten.mul.Tensor(convert_element_type_2912, convert_element_type_2914); convert_element_type_2914 = None + convert_element_type_375 = torch.ops.prims.convert_element_type.default(getitem_671, torch.float32); getitem_671 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_375, rsqrt_22); convert_element_type_375 = None + mul_1959 = torch.ops.aten.mul.Tensor(mul_307, mul_1957) + sum_263 = torch.ops.aten.sum.dim_IntList(mul_1959, [2], True); mul_1959 = None + div_250 = torch.ops.aten.div.Tensor(mul_307, 512) + mul_1960 = torch.ops.aten.mul.Tensor(div_250, sum_263); div_250 = sum_263 = None + sub_743 = torch.ops.aten.sub.Tensor(mul_1957, mul_1960); mul_1957 = mul_1960 = None + mul_1961 = torch.ops.aten.mul.Tensor(sub_743, rsqrt_22); sub_743 = rsqrt_22 = None + mul_1962 = torch.ops.aten.mul.Tensor(convert_element_type_2912, mul_307); convert_element_type_2912 = mul_307 = None + sum_264 = torch.ops.aten.sum.dim_IntList(mul_1962, [0, 1]); mul_1962 = None + convert_element_type_2915 = torch.ops.prims.convert_element_type.default(mul_1961, torch.bfloat16); mul_1961 = None + convert_element_type_default_23 = torch.ops.prims.convert_element_type.default(sum_264, torch.float32); sum_264 = None + reduce_scatter_tensor_278 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_23, 'avg', 128, '0'); convert_element_type_default_23 = None + wait_tensor_877 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_278); reduce_scatter_tensor_278 = None + convert_element_type_2918 = torch.ops.prims.convert_element_type.default(sum_262, torch.float32); sum_262 = None + view_2153 = torch.ops.aten.view.default(convert_element_type_2918, [2, 4096, 1, 32, 2]); convert_element_type_2918 = None + view_as_complex_92 = torch.ops.aten.view_as_complex.default(view_2153); view_2153 = None + mul_1963 = torch.ops.aten.mul.Tensor(view_as_complex_92, clone_9); view_as_complex_92 = None + view_as_real_92 = torch.ops.aten.view_as_real.default(mul_1963); mul_1963 = None + view_2154 = torch.ops.aten.view.default(view_as_real_92, [2, 4096, 1, 64]); view_as_real_92 = None + convert_element_type_2919 = torch.ops.prims.convert_element_type.default(view_2154, torch.bfloat16); view_2154 = None + squeeze_45 = torch.ops.aten.squeeze.dim(convert_element_type_2919, 2); convert_element_type_2919 = None + cat_393 = torch.ops.aten.cat.default([convert_element_type_2915, squeeze_45], 2); convert_element_type_2915 = squeeze_45 = None + view_2155 = torch.ops.aten.view.default(cat_393, [8192, 576]); cat_393 = None + permute_1398 = torch.ops.aten.permute.default(view_2155, [1, 0]) + mm_534 = torch.ops.aten.mm.default(permute_1398, view_438); permute_1398 = None + convert_element_type_369 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_369, 128, '0'); convert_element_type_369 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + slice_43 = torch.ops.aten.slice.Tensor(wait_tensor_139, 0, 0, 576); wait_tensor_139 = None + permute_102 = torch.ops.aten.permute.default(slice_43, [1, 0]); slice_43 = None + permute_1400 = torch.ops.aten.permute.default(permute_102, [1, 0]); permute_102 = None + mm_535 = torch.ops.aten.mm.default(view_2155, permute_1400); view_2155 = permute_1400 = None + view_2156 = torch.ops.aten.view.default(mm_535, [2, 4096, 2048]); mm_535 = None + convert_element_type_2924 = torch.ops.prims.convert_element_type.default(mm_534, torch.float32); mm_534 = None + split_1215 = torch.ops.aten.split.Tensor(convert_element_type_2924, 5); convert_element_type_2924 = None + getitem_22753 = split_1215[0] + getitem_22754 = split_1215[1] + getitem_22755 = split_1215[2] + getitem_22756 = split_1215[3] + getitem_22757 = split_1215[4] + getitem_22758 = split_1215[5] + getitem_22759 = split_1215[6] + getitem_22760 = split_1215[7] + getitem_22761 = split_1215[8] + getitem_22762 = split_1215[9] + getitem_22763 = split_1215[10] + getitem_22764 = split_1215[11] + getitem_22765 = split_1215[12] + getitem_22766 = split_1215[13] + getitem_22767 = split_1215[14] + getitem_22768 = split_1215[15] + getitem_22769 = split_1215[16] + getitem_22770 = split_1215[17] + getitem_22771 = split_1215[18] + getitem_22772 = split_1215[19] + getitem_22773 = split_1215[20] + getitem_22774 = split_1215[21] + getitem_22775 = split_1215[22] + getitem_22776 = split_1215[23] + getitem_22777 = split_1215[24] + getitem_22778 = split_1215[25] + getitem_22779 = split_1215[26] + getitem_22780 = split_1215[27] + getitem_22781 = split_1215[28] + getitem_22782 = split_1215[29] + getitem_22783 = split_1215[30] + getitem_22784 = split_1215[31] + getitem_22785 = split_1215[32] + getitem_22786 = split_1215[33] + getitem_22787 = split_1215[34] + getitem_22788 = split_1215[35] + getitem_22789 = split_1215[36] + getitem_22790 = split_1215[37] + getitem_22791 = split_1215[38] + getitem_22792 = split_1215[39] + getitem_22793 = split_1215[40] + getitem_22794 = split_1215[41] + getitem_22795 = split_1215[42] + getitem_22796 = split_1215[43] + getitem_22797 = split_1215[44] + getitem_22798 = split_1215[45] + getitem_22799 = split_1215[46] + getitem_22800 = split_1215[47] + getitem_22801 = split_1215[48] + getitem_22802 = split_1215[49] + getitem_22803 = split_1215[50] + getitem_22804 = split_1215[51] + getitem_22805 = split_1215[52] + getitem_22806 = split_1215[53] + getitem_22807 = split_1215[54] + getitem_22808 = split_1215[55] + getitem_22809 = split_1215[56] + getitem_22810 = split_1215[57] + getitem_22811 = split_1215[58] + getitem_22812 = split_1215[59] + getitem_22813 = split_1215[60] + getitem_22814 = split_1215[61] + getitem_22815 = split_1215[62] + getitem_22816 = split_1215[63] + getitem_22817 = split_1215[64] + getitem_22818 = split_1215[65] + getitem_22819 = split_1215[66] + getitem_22820 = split_1215[67] + getitem_22821 = split_1215[68] + getitem_22822 = split_1215[69] + getitem_22823 = split_1215[70] + getitem_22824 = split_1215[71] + getitem_22825 = split_1215[72] + getitem_22826 = split_1215[73] + getitem_22827 = split_1215[74] + getitem_22828 = split_1215[75] + getitem_22829 = split_1215[76] + getitem_22830 = split_1215[77] + getitem_22831 = split_1215[78] + getitem_22832 = split_1215[79] + getitem_22833 = split_1215[80] + getitem_22834 = split_1215[81] + getitem_22835 = split_1215[82] + getitem_22836 = split_1215[83] + getitem_22837 = split_1215[84] + getitem_22838 = split_1215[85] + getitem_22839 = split_1215[86] + getitem_22840 = split_1215[87] + getitem_22841 = split_1215[88] + getitem_22842 = split_1215[89] + getitem_22843 = split_1215[90] + getitem_22844 = split_1215[91] + getitem_22845 = split_1215[92] + getitem_22846 = split_1215[93] + getitem_22847 = split_1215[94] + getitem_22848 = split_1215[95] + getitem_22849 = split_1215[96] + getitem_22850 = split_1215[97] + getitem_22851 = split_1215[98] + getitem_22852 = split_1215[99] + getitem_22853 = split_1215[100] + getitem_22854 = split_1215[101] + getitem_22855 = split_1215[102] + getitem_22856 = split_1215[103] + getitem_22857 = split_1215[104] + getitem_22858 = split_1215[105] + getitem_22859 = split_1215[106] + getitem_22860 = split_1215[107] + getitem_22861 = split_1215[108] + getitem_22862 = split_1215[109] + getitem_22863 = split_1215[110] + getitem_22864 = split_1215[111] + getitem_22865 = split_1215[112] + getitem_22866 = split_1215[113] + getitem_22867 = split_1215[114] + getitem_22868 = split_1215[115]; split_1215 = None + constant_pad_nd_1527 = torch.ops.aten.constant_pad_nd.default(getitem_22868, [0, 0, 0, 4], 0.0); getitem_22868 = None + cat_394 = torch.ops.aten.cat.default([getitem_22753, getitem_22754, getitem_22755, getitem_22756, getitem_22757, getitem_22758, getitem_22759, getitem_22760, getitem_22761, getitem_22762, getitem_22763, getitem_22764, getitem_22765, getitem_22766, getitem_22767, getitem_22768, getitem_22769, getitem_22770, getitem_22771, getitem_22772, getitem_22773, getitem_22774, getitem_22775, getitem_22776, getitem_22777, getitem_22778, getitem_22779, getitem_22780, getitem_22781, getitem_22782, getitem_22783, getitem_22784, getitem_22785, getitem_22786, getitem_22787, getitem_22788, getitem_22789, getitem_22790, getitem_22791, getitem_22792, getitem_22793, getitem_22794, getitem_22795, getitem_22796, getitem_22797, getitem_22798, getitem_22799, getitem_22800, getitem_22801, getitem_22802, getitem_22803, getitem_22804, getitem_22805, getitem_22806, getitem_22807, getitem_22808, getitem_22809, getitem_22810, getitem_22811, getitem_22812, getitem_22813, getitem_22814, getitem_22815, getitem_22816, getitem_22817, getitem_22818, getitem_22819, getitem_22820, getitem_22821, getitem_22822, getitem_22823, getitem_22824, getitem_22825, getitem_22826, getitem_22827, getitem_22828, getitem_22829, getitem_22830, getitem_22831, getitem_22832, getitem_22833, getitem_22834, getitem_22835, getitem_22836, getitem_22837, getitem_22838, getitem_22839, getitem_22840, getitem_22841, getitem_22842, getitem_22843, getitem_22844, getitem_22845, getitem_22846, getitem_22847, getitem_22848, getitem_22849, getitem_22850, getitem_22851, getitem_22852, getitem_22853, getitem_22854, getitem_22855, getitem_22856, getitem_22857, getitem_22858, getitem_22859, getitem_22860, getitem_22861, getitem_22862, getitem_22863, getitem_22864, getitem_22865, getitem_22866, getitem_22867, constant_pad_nd_1527, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_22753 = getitem_22754 = getitem_22755 = getitem_22756 = getitem_22757 = getitem_22758 = getitem_22759 = getitem_22760 = getitem_22761 = getitem_22762 = getitem_22763 = getitem_22764 = getitem_22765 = getitem_22766 = getitem_22767 = getitem_22768 = getitem_22769 = getitem_22770 = getitem_22771 = getitem_22772 = getitem_22773 = getitem_22774 = getitem_22775 = getitem_22776 = getitem_22777 = getitem_22778 = getitem_22779 = getitem_22780 = getitem_22781 = getitem_22782 = getitem_22783 = getitem_22784 = getitem_22785 = getitem_22786 = getitem_22787 = getitem_22788 = getitem_22789 = getitem_22790 = getitem_22791 = getitem_22792 = getitem_22793 = getitem_22794 = getitem_22795 = getitem_22796 = getitem_22797 = getitem_22798 = getitem_22799 = getitem_22800 = getitem_22801 = getitem_22802 = getitem_22803 = getitem_22804 = getitem_22805 = getitem_22806 = getitem_22807 = getitem_22808 = getitem_22809 = getitem_22810 = getitem_22811 = getitem_22812 = getitem_22813 = getitem_22814 = getitem_22815 = getitem_22816 = getitem_22817 = getitem_22818 = getitem_22819 = getitem_22820 = getitem_22821 = getitem_22822 = getitem_22823 = getitem_22824 = getitem_22825 = getitem_22826 = getitem_22827 = getitem_22828 = getitem_22829 = getitem_22830 = getitem_22831 = getitem_22832 = getitem_22833 = getitem_22834 = getitem_22835 = getitem_22836 = getitem_22837 = getitem_22838 = getitem_22839 = getitem_22840 = getitem_22841 = getitem_22842 = getitem_22843 = getitem_22844 = getitem_22845 = getitem_22846 = getitem_22847 = getitem_22848 = getitem_22849 = getitem_22850 = getitem_22851 = getitem_22852 = getitem_22853 = getitem_22854 = getitem_22855 = getitem_22856 = getitem_22857 = getitem_22858 = getitem_22859 = getitem_22860 = getitem_22861 = getitem_22862 = getitem_22863 = getitem_22864 = getitem_22865 = getitem_22866 = getitem_22867 = constant_pad_nd_1527 = None + reduce_scatter_tensor_279 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_394, 'avg', 128, '0'); cat_394 = None + wait_tensor_878 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_279); reduce_scatter_tensor_279 = None + slice_280 = torch.ops.aten.slice.Tensor(permute_1393, 3, 0, 128) + slice_281 = torch.ops.aten.slice.Tensor(permute_1393, 3, 128, 192); permute_1393 = None + convert_element_type_2925 = torch.ops.prims.convert_element_type.default(slice_281, torch.float32); slice_281 = None + view_2157 = torch.ops.aten.view.default(convert_element_type_2925, [2, 4096, 16, 32, 2]); convert_element_type_2925 = None + view_as_complex_93 = torch.ops.aten.view_as_complex.default(view_2157); view_2157 = None + mul_1964 = torch.ops.aten.mul.Tensor(view_as_complex_93, clone_9); view_as_complex_93 = None + view_as_real_93 = torch.ops.aten.view_as_real.default(mul_1964); mul_1964 = None + view_2158 = torch.ops.aten.view.default(view_as_real_93, [2, 4096, 16, 64]); view_as_real_93 = None + convert_element_type_2926 = torch.ops.prims.convert_element_type.default(view_2158, torch.bfloat16); view_2158 = None + cat_395 = torch.ops.aten.cat.default([slice_280, convert_element_type_2926], 3); slice_280 = convert_element_type_2926 = None + view_2159 = torch.ops.aten.view.default(cat_395, [2, 4096, 3072]); cat_395 = None + view_2160 = torch.ops.aten.view.default(view_2159, [8192, 3072]); view_2159 = None + permute_1402 = torch.ops.aten.permute.default(view_2160, [1, 0]) + mm_536 = torch.ops.aten.mm.default(permute_1402, view_438); permute_1402 = view_438 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 128, '0'); convert_element_type_364 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + permute_1404 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_537 = torch.ops.aten.mm.default(view_2160, permute_1404); view_2160 = permute_1404 = None + view_2161 = torch.ops.aten.view.default(mm_537, [2, 4096, 2048]); mm_537 = None + add_2073 = torch.ops.aten.add.Tensor(view_2156, view_2161); view_2156 = view_2161 = None + convert_element_type_2931 = torch.ops.prims.convert_element_type.default(mm_536, torch.float32); mm_536 = None + reduce_scatter_tensor_280 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2931, 'avg', 128, '0'); convert_element_type_2931 = None + wait_tensor_879 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_280); reduce_scatter_tensor_280 = None + convert_element_type_2932 = torch.ops.prims.convert_element_type.default(add_2073, torch.float32); add_2073 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 128, '0'); convert_element_type_361 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + convert_element_type_2934 = torch.ops.prims.convert_element_type.default(wait_tensor_137, torch.float32); wait_tensor_137 = None + mul_1965 = torch.ops.aten.mul.Tensor(convert_element_type_2932, convert_element_type_2934); convert_element_type_2934 = None + convert_element_type_362 = torch.ops.prims.convert_element_type.default(add_413, torch.float32); add_413 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_362, rsqrt_21); convert_element_type_362 = None + mul_1967 = torch.ops.aten.mul.Tensor(mul_303, mul_1965) + sum_265 = torch.ops.aten.sum.dim_IntList(mul_1967, [2], True); mul_1967 = None + div_251 = torch.ops.aten.div.Tensor(mul_303, 2048) + mul_1968 = torch.ops.aten.mul.Tensor(div_251, sum_265); div_251 = sum_265 = None + sub_744 = torch.ops.aten.sub.Tensor(mul_1965, mul_1968); mul_1965 = mul_1968 = None + mul_1969 = torch.ops.aten.mul.Tensor(sub_744, rsqrt_21); sub_744 = rsqrt_21 = None + mul_1970 = torch.ops.aten.mul.Tensor(convert_element_type_2932, mul_303); convert_element_type_2932 = mul_303 = None + sum_266 = torch.ops.aten.sum.dim_IntList(mul_1970, [0, 1]); mul_1970 = None + convert_element_type_2935 = torch.ops.prims.convert_element_type.default(mul_1969, torch.bfloat16); mul_1969 = None + add_2074 = torch.ops.aten.add.Tensor(add_2072, convert_element_type_2935); add_2072 = convert_element_type_2935 = None + convert_element_type_default_22 = torch.ops.prims.convert_element_type.default(sum_266, torch.float32); sum_266 = None + reduce_scatter_tensor_281 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_22, 'avg', 128, '0'); convert_element_type_default_22 = None + wait_tensor_880 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_281); reduce_scatter_tensor_281 = None + view_2162 = torch.ops.aten.view.default(add_2074, [8192, 2048]) + unsqueeze_73 = torch.ops.aten.unsqueeze.default(view_2162, 1) + convert_element_type_2938 = torch.ops.prims.convert_element_type.default(unsqueeze_73, torch.float32); unsqueeze_73 = None + bmm_66 = torch.ops.aten.bmm.default(permute_1406, convert_element_type_2938); permute_1406 = None + bmm_67 = torch.ops.aten.bmm.default(convert_element_type_2938, permute_1407); convert_element_type_2938 = permute_1407 = None + convert_element_type_2939 = torch.ops.prims.convert_element_type.default(bmm_66, torch.bfloat16); bmm_66 = None + view_2163 = torch.ops.aten.view.default(bmm_67, [8192, 6]); bmm_67 = None + view_2164 = torch.ops.aten.view.default(convert_element_type_2939, [49152, 2048]); convert_element_type_2939 = None + index_92 = torch.ops.aten.index.Tensor(view_2164, [getitem_571]); view_2164 = getitem_571 = None + permute_1408 = torch.ops.aten.permute.default(view_2162, [1, 0]) + mm_538 = torch.ops.aten.mm.default(permute_1408, mul_300); permute_1408 = mul_300 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_356, 128, '0'); convert_element_type_356 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + permute_1410 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_539 = torch.ops.aten.mm.default(view_2162, permute_1410); view_2162 = permute_1410 = None + convert_element_type_2944 = torch.ops.prims.convert_element_type.default(mm_538, torch.float32); mm_538 = None + reduce_scatter_tensor_282 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2944, 'avg', 128, '0'); convert_element_type_2944 = None + wait_tensor_881 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_282); reduce_scatter_tensor_282 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(mm_52, torch.float32); mm_52 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_351) + exp_18 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_408 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + div_30 = torch.ops.aten.div.Tensor(convert_element_type_351, add_408) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(div_30, torch.bfloat16); div_30 = None + mul_1971 = torch.ops.aten.mul.Tensor(mm_539, convert_element_type_352); convert_element_type_352 = None + mul_1972 = torch.ops.aten.mul.Tensor(mm_539, mm_53); mm_539 = mm_53 = None + permute_1412 = torch.ops.aten.permute.default(mul_1971, [1, 0]) + mm_540 = torch.ops.aten.mm.default(permute_1412, view_393); permute_1412 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 128, '0'); convert_element_type_353 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + permute_1414 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_541 = torch.ops.aten.mm.default(mul_1971, permute_1414); mul_1971 = permute_1414 = None + convert_element_type_2949 = torch.ops.prims.convert_element_type.default(mm_540, torch.float32); mm_540 = None + reduce_scatter_tensor_283 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2949, 'avg', 128, '0'); convert_element_type_2949 = None + wait_tensor_882 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_283); reduce_scatter_tensor_283 = None + convert_element_type_2950 = torch.ops.prims.convert_element_type.default(mul_1972, torch.float32); mul_1972 = None + reciprocal_40 = torch.ops.aten.reciprocal.default(add_408); add_408 = None + mul_1973 = torch.ops.aten.mul.Tensor(reciprocal_40, 1); reciprocal_40 = None + mul_1974 = torch.ops.aten.mul.Tensor(convert_element_type_2950, mul_1973); convert_element_type_2950 = None + sub_745 = torch.ops.aten.sub.Tensor(1, mul_1973); mul_1973 = None + mul_1975 = torch.ops.aten.mul.Tensor(convert_element_type_351, sub_745); convert_element_type_351 = sub_745 = None + add_2076 = torch.ops.aten.add.Tensor(mul_1975, 1); mul_1975 = None + mul_1976 = torch.ops.aten.mul.Tensor(mul_1974, add_2076); mul_1974 = add_2076 = None + convert_element_type_2952 = torch.ops.prims.convert_element_type.default(mul_1976, torch.bfloat16); mul_1976 = None + permute_1416 = torch.ops.aten.permute.default(convert_element_type_2952, [1, 0]) + mm_542 = torch.ops.aten.mm.default(permute_1416, view_393); permute_1416 = None + convert_element_type_348 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_348, 128, '0'); convert_element_type_348 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + permute_1418 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None + mm_543 = torch.ops.aten.mm.default(convert_element_type_2952, permute_1418); convert_element_type_2952 = permute_1418 = None + add_2077 = torch.ops.aten.add.Tensor(mm_541, mm_543); mm_541 = mm_543 = None + convert_element_type_2957 = torch.ops.prims.convert_element_type.default(mm_542, torch.float32); mm_542 = None + reduce_scatter_tensor_284 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2957, 'avg', 128, '0'); convert_element_type_2957 = None + wait_tensor_883 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_284); reduce_scatter_tensor_284 = None + all_to_all_single_118 = torch.ops._c10d_functional.all_to_all_single.default(index_92, [_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95], [_local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87], '1033'); index_92 = None + wait_tensor_884 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_118); all_to_all_single_118 = None + full_468 = torch.ops.aten.full.default([sym_size_int_21, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_21 = None + slice_scatter_20 = torch.ops.aten.slice_scatter.default(full_468, wait_tensor_884, 0, 0, -1); wait_tensor_884 = None + index_93 = torch.ops.aten.index.Tensor(slice_scatter_20, [getitem_572]); slice_scatter_20 = None + permute_1420 = torch.ops.aten.permute.default(index_93, [1, 0]) + _grouped_mm_198 = torch.ops.aten._grouped_mm.default(permute_1420, mul_280, cumsum_17); permute_1420 = mul_280 = None + _grouped_mm_199 = torch.ops.aten._grouped_mm.default(index_93, permute_1422, cumsum_17); index_93 = permute_1422 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(_grouped_mm_15, torch.float32); _grouped_mm_15 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_346) + exp_17 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_372 = torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + div_29 = torch.ops.aten.div.Tensor(convert_element_type_346, add_372) + convert_element_type_347 = torch.ops.prims.convert_element_type.default(div_29, torch.bfloat16); div_29 = None + mul_1977 = torch.ops.aten.mul.Tensor(_grouped_mm_199, convert_element_type_347); convert_element_type_347 = None + mul_1978 = torch.ops.aten.mul.Tensor(_grouped_mm_199, _grouped_mm_16); _grouped_mm_199 = _grouped_mm_16 = None + permute_1424 = torch.ops.aten.permute.default(mul_1977, [1, 0]) + _grouped_mm_200 = torch.ops.aten._grouped_mm.default(permute_1424, index_11, cumsum_17); permute_1424 = None + _grouped_mm_201 = torch.ops.aten._grouped_mm.default(mul_1977, permute_1426, cumsum_17); mul_1977 = permute_1426 = None + convert_element_type_2958 = torch.ops.prims.convert_element_type.default(mul_1978, torch.float32); mul_1978 = None + reciprocal_41 = torch.ops.aten.reciprocal.default(add_372); add_372 = None + mul_1979 = torch.ops.aten.mul.Tensor(reciprocal_41, 1); reciprocal_41 = None + mul_1980 = torch.ops.aten.mul.Tensor(convert_element_type_2958, mul_1979); convert_element_type_2958 = None + sub_746 = torch.ops.aten.sub.Tensor(1, mul_1979); mul_1979 = None + mul_1981 = torch.ops.aten.mul.Tensor(convert_element_type_346, sub_746); convert_element_type_346 = sub_746 = None + add_2079 = torch.ops.aten.add.Tensor(mul_1981, 1); mul_1981 = None + mul_1982 = torch.ops.aten.mul.Tensor(mul_1980, add_2079); mul_1980 = add_2079 = None + convert_element_type_2960 = torch.ops.prims.convert_element_type.default(mul_1982, torch.bfloat16); mul_1982 = None + permute_1428 = torch.ops.aten.permute.default(convert_element_type_2960, [1, 0]) + _grouped_mm_202 = torch.ops.aten._grouped_mm.default(permute_1428, index_11, cumsum_17); permute_1428 = index_11 = None + _grouped_mm_203 = torch.ops.aten._grouped_mm.default(convert_element_type_2960, permute_1430, cumsum_17); convert_element_type_2960 = permute_1430 = cumsum_17 = None + add_2080 = torch.ops.aten.add.Tensor(_grouped_mm_201, _grouped_mm_203); _grouped_mm_201 = _grouped_mm_203 = None + convert_element_type_2961 = torch.ops.prims.convert_element_type.default(_grouped_mm_200, torch.float32); _grouped_mm_200 = None + div_252 = torch.ops.aten.div.Tensor(convert_element_type_2961, 128); convert_element_type_2961 = None + split_1217 = torch.ops.aten.split.Tensor(div_252, 88, 1); div_252 = None + getitem_22885 = split_1217[0] + getitem_22902 = split_1217[1] + getitem_22919 = split_1217[2] + getitem_22936 = split_1217[3] + getitem_22953 = split_1217[4] + getitem_22970 = split_1217[5] + getitem_22987 = split_1217[6] + getitem_23004 = split_1217[7] + getitem_23021 = split_1217[8] + getitem_23038 = split_1217[9] + getitem_23055 = split_1217[10] + getitem_23072 = split_1217[11] + getitem_23089 = split_1217[12] + getitem_23106 = split_1217[13] + getitem_23123 = split_1217[14] + getitem_23140 = split_1217[15]; split_1217 = None + cat_396 = torch.ops.aten.cat.default([getitem_22885, getitem_22902, getitem_22919, getitem_22936, getitem_22953, getitem_22970, getitem_22987, getitem_23004, getitem_23021, getitem_23038, getitem_23055, getitem_23072, getitem_23089, getitem_23106, getitem_23123, getitem_23140]); getitem_22885 = getitem_22902 = getitem_22919 = getitem_22936 = getitem_22953 = getitem_22970 = getitem_22987 = getitem_23004 = getitem_23021 = getitem_23038 = getitem_23055 = getitem_23072 = getitem_23089 = getitem_23106 = getitem_23123 = getitem_23140 = None + reduce_scatter_tensor_285 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_396, 'sum', 16, '1025'); cat_396 = None + wait_tensor_885 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_285); reduce_scatter_tensor_285 = None + convert_element_type_2962 = torch.ops.prims.convert_element_type.default(_grouped_mm_198, torch.float32); _grouped_mm_198 = None + div_253 = torch.ops.aten.div.Tensor(convert_element_type_2962, 128); convert_element_type_2962 = None + split_1234 = torch.ops.aten.split.Tensor(div_253, 128, 1); div_253 = None + getitem_23157 = split_1234[0] + getitem_23174 = split_1234[1] + getitem_23191 = split_1234[2] + getitem_23208 = split_1234[3] + getitem_23225 = split_1234[4] + getitem_23242 = split_1234[5] + getitem_23259 = split_1234[6] + getitem_23276 = split_1234[7] + getitem_23293 = split_1234[8] + getitem_23310 = split_1234[9] + getitem_23327 = split_1234[10] + getitem_23344 = split_1234[11] + getitem_23361 = split_1234[12] + getitem_23378 = split_1234[13] + getitem_23395 = split_1234[14] + getitem_23412 = split_1234[15]; split_1234 = None + cat_397 = torch.ops.aten.cat.default([getitem_23157, getitem_23174, getitem_23191, getitem_23208, getitem_23225, getitem_23242, getitem_23259, getitem_23276, getitem_23293, getitem_23310, getitem_23327, getitem_23344, getitem_23361, getitem_23378, getitem_23395, getitem_23412]); getitem_23157 = getitem_23174 = getitem_23191 = getitem_23208 = getitem_23225 = getitem_23242 = getitem_23259 = getitem_23276 = getitem_23293 = getitem_23310 = getitem_23327 = getitem_23344 = getitem_23361 = getitem_23378 = getitem_23395 = getitem_23412 = None + reduce_scatter_tensor_286 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_397, 'sum', 16, '1025'); cat_397 = None + wait_tensor_886 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_286); reduce_scatter_tensor_286 = None + convert_element_type_2963 = torch.ops.prims.convert_element_type.default(_grouped_mm_202, torch.float32); _grouped_mm_202 = None + div_254 = torch.ops.aten.div.Tensor(convert_element_type_2963, 128); convert_element_type_2963 = None + split_1251 = torch.ops.aten.split.Tensor(div_254, 88, 1); div_254 = None + getitem_23429 = split_1251[0] + getitem_23446 = split_1251[1] + getitem_23463 = split_1251[2] + getitem_23480 = split_1251[3] + getitem_23497 = split_1251[4] + getitem_23514 = split_1251[5] + getitem_23531 = split_1251[6] + getitem_23548 = split_1251[7] + getitem_23565 = split_1251[8] + getitem_23582 = split_1251[9] + getitem_23599 = split_1251[10] + getitem_23616 = split_1251[11] + getitem_23633 = split_1251[12] + getitem_23650 = split_1251[13] + getitem_23667 = split_1251[14] + getitem_23684 = split_1251[15]; split_1251 = None + cat_398 = torch.ops.aten.cat.default([getitem_23429, getitem_23446, getitem_23463, getitem_23480, getitem_23497, getitem_23514, getitem_23531, getitem_23548, getitem_23565, getitem_23582, getitem_23599, getitem_23616, getitem_23633, getitem_23650, getitem_23667, getitem_23684]); getitem_23429 = getitem_23446 = getitem_23463 = getitem_23480 = getitem_23497 = getitem_23514 = getitem_23531 = getitem_23548 = getitem_23565 = getitem_23582 = getitem_23599 = getitem_23616 = getitem_23633 = getitem_23650 = getitem_23667 = getitem_23684 = None + reduce_scatter_tensor_287 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_398, 'sum', 16, '1025'); cat_398 = None + wait_tensor_887 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_287); reduce_scatter_tensor_287 = None + index_put_92 = torch.ops.aten.index_put.default(full_468, [getitem_572], add_2080, True); full_468 = getitem_572 = add_2080 = None + slice_282 = torch.ops.aten.slice.Tensor(index_put_92, 0, 0, add_2081); index_put_92 = add_2081 = None + all_to_all_single_119 = torch.ops._c10d_functional.all_to_all_single.default(slice_282, [_local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87], [_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95], '1033'); slice_282 = _local_scalar_dense_80 = _local_scalar_dense_81 = _local_scalar_dense_82 = _local_scalar_dense_83 = _local_scalar_dense_84 = _local_scalar_dense_85 = _local_scalar_dense_86 = _local_scalar_dense_87 = _local_scalar_dense_88 = _local_scalar_dense_89 = _local_scalar_dense_90 = _local_scalar_dense_91 = _local_scalar_dense_92 = _local_scalar_dense_93 = _local_scalar_dense_94 = _local_scalar_dense_95 = None + wait_tensor_888 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_119); all_to_all_single_119 = None + index_put_93 = torch.ops.aten.index_put.default(full_default_52, [div_27], wait_tensor_888, True); div_27 = wait_tensor_888 = None + add_2085 = torch.ops.aten.add.Tensor(add_2077, index_put_93); add_2077 = index_put_93 = None + mul_1983 = torch.ops.aten.mul.Tensor(view_2163, 1.0); view_2163 = None + scatter_add_20 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_569, mul_1983); getitem_569 = mul_1983 = None + convert_element_type_335 = torch.ops.prims.convert_element_type.default(mm_51, torch.float32); mm_51 = None + sub_120 = torch.ops.aten.sub.Tensor(convert_element_type_335, amax_5); convert_element_type_335 = amax_5 = None + exp_16 = torch.ops.aten.exp.default(sub_120); sub_120 = None + div_26 = torch.ops.aten.div.Tensor(exp_16, sum_21); exp_16 = sum_21 = None + mul_1984 = torch.ops.aten.mul.Tensor(scatter_add_20, div_26); scatter_add_20 = None + sum_267 = torch.ops.aten.sum.dim_IntList(mul_1984, [1], True) + neg_115 = torch.ops.aten.neg.default(div_26); div_26 = None + fma_20 = torch.ops.prims.fma.default(neg_115, sum_267, mul_1984); neg_115 = sum_267 = mul_1984 = None + convert_element_type_2964 = torch.ops.prims.convert_element_type.default(fma_20, torch.bfloat16); fma_20 = None + permute_1432 = torch.ops.aten.permute.default(convert_element_type_2964, [1, 0]) + mm_544 = torch.ops.aten.mm.default(permute_1432, view_393); permute_1432 = view_393 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_332, 128, '0'); convert_element_type_332 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + slice_39 = torch.ops.aten.slice.Tensor(wait_tensor_123, 0, 0, 64); wait_tensor_123 = None + permute_94 = torch.ops.aten.permute.default(slice_39, [1, 0]); slice_39 = None + permute_1434 = torch.ops.aten.permute.default(permute_94, [1, 0]); permute_94 = None + mm_545 = torch.ops.aten.mm.default(convert_element_type_2964, permute_1434); convert_element_type_2964 = permute_1434 = None + add_2086 = torch.ops.aten.add.Tensor(add_2085, mm_545); add_2085 = mm_545 = None + convert_element_type_2969 = torch.ops.prims.convert_element_type.default(mm_544, torch.float32); mm_544 = None + split_1267 = torch.ops.aten.split.Tensor(convert_element_type_2969, 1); convert_element_type_2969 = None + getitem_23685 = split_1267[0] + getitem_23686 = split_1267[1] + getitem_23687 = split_1267[2] + getitem_23688 = split_1267[3] + getitem_23689 = split_1267[4] + getitem_23690 = split_1267[5] + getitem_23691 = split_1267[6] + getitem_23692 = split_1267[7] + getitem_23693 = split_1267[8] + getitem_23694 = split_1267[9] + getitem_23695 = split_1267[10] + getitem_23696 = split_1267[11] + getitem_23697 = split_1267[12] + getitem_23698 = split_1267[13] + getitem_23699 = split_1267[14] + getitem_23700 = split_1267[15] + getitem_23701 = split_1267[16] + getitem_23702 = split_1267[17] + getitem_23703 = split_1267[18] + getitem_23704 = split_1267[19] + getitem_23705 = split_1267[20] + getitem_23706 = split_1267[21] + getitem_23707 = split_1267[22] + getitem_23708 = split_1267[23] + getitem_23709 = split_1267[24] + getitem_23710 = split_1267[25] + getitem_23711 = split_1267[26] + getitem_23712 = split_1267[27] + getitem_23713 = split_1267[28] + getitem_23714 = split_1267[29] + getitem_23715 = split_1267[30] + getitem_23716 = split_1267[31] + getitem_23717 = split_1267[32] + getitem_23718 = split_1267[33] + getitem_23719 = split_1267[34] + getitem_23720 = split_1267[35] + getitem_23721 = split_1267[36] + getitem_23722 = split_1267[37] + getitem_23723 = split_1267[38] + getitem_23724 = split_1267[39] + getitem_23725 = split_1267[40] + getitem_23726 = split_1267[41] + getitem_23727 = split_1267[42] + getitem_23728 = split_1267[43] + getitem_23729 = split_1267[44] + getitem_23730 = split_1267[45] + getitem_23731 = split_1267[46] + getitem_23732 = split_1267[47] + getitem_23733 = split_1267[48] + getitem_23734 = split_1267[49] + getitem_23735 = split_1267[50] + getitem_23736 = split_1267[51] + getitem_23737 = split_1267[52] + getitem_23738 = split_1267[53] + getitem_23739 = split_1267[54] + getitem_23740 = split_1267[55] + getitem_23741 = split_1267[56] + getitem_23742 = split_1267[57] + getitem_23743 = split_1267[58] + getitem_23744 = split_1267[59] + getitem_23745 = split_1267[60] + getitem_23746 = split_1267[61] + getitem_23747 = split_1267[62] + getitem_23748 = split_1267[63]; split_1267 = None + cat_399 = torch.ops.aten.cat.default([getitem_23685, getitem_23686, getitem_23687, getitem_23688, getitem_23689, getitem_23690, getitem_23691, getitem_23692, getitem_23693, getitem_23694, getitem_23695, getitem_23696, getitem_23697, getitem_23698, getitem_23699, getitem_23700, getitem_23701, getitem_23702, getitem_23703, getitem_23704, getitem_23705, getitem_23706, getitem_23707, getitem_23708, getitem_23709, getitem_23710, getitem_23711, getitem_23712, getitem_23713, getitem_23714, getitem_23715, getitem_23716, getitem_23717, getitem_23718, getitem_23719, getitem_23720, getitem_23721, getitem_23722, getitem_23723, getitem_23724, getitem_23725, getitem_23726, getitem_23727, getitem_23728, getitem_23729, getitem_23730, getitem_23731, getitem_23732, getitem_23733, getitem_23734, getitem_23735, getitem_23736, getitem_23737, getitem_23738, getitem_23739, getitem_23740, getitem_23741, getitem_23742, getitem_23743, getitem_23744, getitem_23745, getitem_23746, getitem_23747, getitem_23748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_23685 = getitem_23686 = getitem_23687 = getitem_23688 = getitem_23689 = getitem_23690 = getitem_23691 = getitem_23692 = getitem_23693 = getitem_23694 = getitem_23695 = getitem_23696 = getitem_23697 = getitem_23698 = getitem_23699 = getitem_23700 = getitem_23701 = getitem_23702 = getitem_23703 = getitem_23704 = getitem_23705 = getitem_23706 = getitem_23707 = getitem_23708 = getitem_23709 = getitem_23710 = getitem_23711 = getitem_23712 = getitem_23713 = getitem_23714 = getitem_23715 = getitem_23716 = getitem_23717 = getitem_23718 = getitem_23719 = getitem_23720 = getitem_23721 = getitem_23722 = getitem_23723 = getitem_23724 = getitem_23725 = getitem_23726 = getitem_23727 = getitem_23728 = getitem_23729 = getitem_23730 = getitem_23731 = getitem_23732 = getitem_23733 = getitem_23734 = getitem_23735 = getitem_23736 = getitem_23737 = getitem_23738 = getitem_23739 = getitem_23740 = getitem_23741 = getitem_23742 = getitem_23743 = getitem_23744 = getitem_23745 = getitem_23746 = getitem_23747 = getitem_23748 = None + reduce_scatter_tensor_288 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_399, 'avg', 128, '0'); cat_399 = None + wait_tensor_889 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_288); reduce_scatter_tensor_288 = None + view_2165 = torch.ops.aten.view.default(add_2086, [2, 4096, 2048]); add_2086 = None + convert_element_type_2970 = torch.ops.prims.convert_element_type.default(view_2165, torch.float32); view_2165 = None + convert_element_type_329 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_329, 128, '0'); convert_element_type_329 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + convert_element_type_2972 = torch.ops.prims.convert_element_type.default(wait_tensor_122, torch.float32); wait_tensor_122 = None + mul_1985 = torch.ops.aten.mul.Tensor(convert_element_type_2970, convert_element_type_2972); convert_element_type_2972 = None + convert_element_type_330 = torch.ops.prims.convert_element_type.default(add_348, torch.float32); add_348 = None + mul_260 = torch.ops.aten.mul.Tensor(convert_element_type_330, rsqrt_20); convert_element_type_330 = None + mul_1987 = torch.ops.aten.mul.Tensor(mul_260, mul_1985) + sum_268 = torch.ops.aten.sum.dim_IntList(mul_1987, [2], True); mul_1987 = None + div_255 = torch.ops.aten.div.Tensor(mul_260, 2048) + mul_1988 = torch.ops.aten.mul.Tensor(div_255, sum_268); div_255 = sum_268 = None + sub_748 = torch.ops.aten.sub.Tensor(mul_1985, mul_1988); mul_1985 = mul_1988 = None + mul_1989 = torch.ops.aten.mul.Tensor(sub_748, rsqrt_20); sub_748 = rsqrt_20 = None + mul_1990 = torch.ops.aten.mul.Tensor(convert_element_type_2970, mul_260); convert_element_type_2970 = mul_260 = None + sum_269 = torch.ops.aten.sum.dim_IntList(mul_1990, [0, 1]); mul_1990 = None + convert_element_type_2973 = torch.ops.prims.convert_element_type.default(mul_1989, torch.bfloat16); mul_1989 = None + add_2087 = torch.ops.aten.add.Tensor(add_2074, convert_element_type_2973); add_2074 = convert_element_type_2973 = None + convert_element_type_default_21 = torch.ops.prims.convert_element_type.default(sum_269, torch.float32); sum_269 = None + reduce_scatter_tensor_289 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_21, 'avg', 128, '0'); convert_element_type_default_21 = None + wait_tensor_890 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_289); reduce_scatter_tensor_289 = None + view_2166 = torch.ops.aten.view.default(add_2087, [8192, 2048]) + permute_1436 = torch.ops.aten.permute.default(view_2166, [1, 0]) + permute_92 = torch.ops.aten.permute.default(getitem_565, [0, 2, 1, 3]) + view_388 = torch.ops.aten.view.default(permute_92, [2, 4096, -1]); permute_92 = None + view_390 = torch.ops.aten.view.default(view_388, [8192, 2048]); view_388 = None + mm_546 = torch.ops.aten.mm.default(permute_1436, view_390); permute_1436 = view_390 = None + convert_element_type_326 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_326, 128, '0'); convert_element_type_326 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_93 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_1438 = torch.ops.aten.permute.default(permute_93, [1, 0]); permute_93 = None + mm_547 = torch.ops.aten.mm.default(view_2166, permute_1438); view_2166 = permute_1438 = None + view_2167 = torch.ops.aten.view.default(mm_547, [2, 4096, 2048]); mm_547 = None + convert_element_type_2980 = torch.ops.prims.convert_element_type.default(mm_546, torch.float32); mm_546 = None + reduce_scatter_tensor_290 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2980, 'avg', 128, '0'); convert_element_type_2980 = None + wait_tensor_891 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_290); reduce_scatter_tensor_290 = None + view_2168 = torch.ops.aten.view.default(view_2167, [2, 4096, 16, 128]); view_2167 = None + permute_1440 = torch.ops.aten.permute.default(view_2168, [0, 2, 1, 3]); view_2168 = None + fw_graph20 = self.fw_graph20 + joint_graph20 = self.joint_graph20 + mask_graph20 = self.mask_graph20 + flex_attention_backward_20 = torch.ops.higher_order.flex_attention_backward(permute_89, permute_90, permute_91, getitem_565, getitem_566, permute_1440, None, fw_graph20, joint_graph20, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph20), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_89 = permute_90 = permute_91 = getitem_565 = getitem_566 = permute_1440 = fw_graph20 = joint_graph20 = mask_graph20 = None + getitem_23749 = flex_attention_backward_20[0] + getitem_23750 = flex_attention_backward_20[1] + getitem_23751 = flex_attention_backward_20[2]; flex_attention_backward_20 = None + permute_1441 = torch.ops.aten.permute.default(getitem_23751, [0, 2, 1, 3]); getitem_23751 = None + permute_1442 = torch.ops.aten.permute.default(getitem_23750, [0, 2, 1, 3]); getitem_23750 = None + permute_1443 = torch.ops.aten.permute.default(getitem_23749, [0, 2, 1, 3]); getitem_23749 = None + slice_284 = torch.ops.aten.slice.Tensor(permute_1442, 3, 0, 128) + slice_285 = torch.ops.aten.slice.Tensor(permute_1442, 3, 128, 192); permute_1442 = None + sum_270 = torch.ops.aten.sum.dim_IntList(slice_285, [2], True); slice_285 = None + cat_400 = torch.ops.aten.cat.default([slice_284, permute_1441], 3); slice_284 = permute_1441 = None + view_2169 = torch.ops.aten.view.default(cat_400, [2, 4096, 4096]); cat_400 = None + view_2170 = torch.ops.aten.view.default(view_2169, [8192, 4096]); view_2169 = None + permute_1444 = torch.ops.aten.permute.default(view_2170, [1, 0]) + mm_548 = torch.ops.aten.mm.default(permute_1444, view_385); permute_1444 = view_385 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_323, 128, '0'); convert_element_type_323 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + permute_1446 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_549 = torch.ops.aten.mm.default(view_2170, permute_1446); view_2170 = permute_1446 = None + view_2171 = torch.ops.aten.view.default(mm_549, [2, 4096, 512]); mm_549 = None + convert_element_type_2985 = torch.ops.prims.convert_element_type.default(mm_548, torch.float32); mm_548 = None + reduce_scatter_tensor_291 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2985, 'avg', 128, '0'); convert_element_type_2985 = None + wait_tensor_892 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_291); reduce_scatter_tensor_291 = None + convert_element_type_2986 = torch.ops.prims.convert_element_type.default(view_2171, torch.float32); view_2171 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 128, '0'); convert_element_type_320 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + convert_element_type_2988 = torch.ops.prims.convert_element_type.default(wait_tensor_119, torch.float32); wait_tensor_119 = None + mul_1991 = torch.ops.aten.mul.Tensor(convert_element_type_2986, convert_element_type_2988); convert_element_type_2988 = None + convert_element_type_321 = torch.ops.prims.convert_element_type.default(getitem_561, torch.float32); getitem_561 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_321, rsqrt_19); convert_element_type_321 = None + mul_1993 = torch.ops.aten.mul.Tensor(mul_258, mul_1991) + sum_271 = torch.ops.aten.sum.dim_IntList(mul_1993, [2], True); mul_1993 = None + div_256 = torch.ops.aten.div.Tensor(mul_258, 512) + mul_1994 = torch.ops.aten.mul.Tensor(div_256, sum_271); div_256 = sum_271 = None + sub_749 = torch.ops.aten.sub.Tensor(mul_1991, mul_1994); mul_1991 = mul_1994 = None + mul_1995 = torch.ops.aten.mul.Tensor(sub_749, rsqrt_19); sub_749 = rsqrt_19 = None + mul_1996 = torch.ops.aten.mul.Tensor(convert_element_type_2986, mul_258); convert_element_type_2986 = mul_258 = None + sum_272 = torch.ops.aten.sum.dim_IntList(mul_1996, [0, 1]); mul_1996 = None + convert_element_type_2989 = torch.ops.prims.convert_element_type.default(mul_1995, torch.bfloat16); mul_1995 = None + convert_element_type_default_20 = torch.ops.prims.convert_element_type.default(sum_272, torch.float32); sum_272 = None + reduce_scatter_tensor_292 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_20, 'avg', 128, '0'); convert_element_type_default_20 = None + wait_tensor_893 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_292); reduce_scatter_tensor_292 = None + convert_element_type_2992 = torch.ops.prims.convert_element_type.default(sum_270, torch.float32); sum_270 = None + view_2172 = torch.ops.aten.view.default(convert_element_type_2992, [2, 4096, 1, 32, 2]); convert_element_type_2992 = None + view_as_complex_94 = torch.ops.aten.view_as_complex.default(view_2172); view_2172 = None + mul_1997 = torch.ops.aten.mul.Tensor(view_as_complex_94, clone_9); view_as_complex_94 = None + view_as_real_94 = torch.ops.aten.view_as_real.default(mul_1997); mul_1997 = None + view_2173 = torch.ops.aten.view.default(view_as_real_94, [2, 4096, 1, 64]); view_as_real_94 = None + convert_element_type_2993 = torch.ops.prims.convert_element_type.default(view_2173, torch.bfloat16); view_2173 = None + squeeze_46 = torch.ops.aten.squeeze.dim(convert_element_type_2993, 2); convert_element_type_2993 = None + cat_401 = torch.ops.aten.cat.default([convert_element_type_2989, squeeze_46], 2); convert_element_type_2989 = squeeze_46 = None + view_2174 = torch.ops.aten.view.default(cat_401, [8192, 576]); cat_401 = None + permute_1448 = torch.ops.aten.permute.default(view_2174, [1, 0]) + mm_550 = torch.ops.aten.mm.default(permute_1448, view_371); permute_1448 = None + convert_element_type_315 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_315, 128, '0'); convert_element_type_315 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + slice_37 = torch.ops.aten.slice.Tensor(wait_tensor_118, 0, 0, 576); wait_tensor_118 = None + permute_87 = torch.ops.aten.permute.default(slice_37, [1, 0]); slice_37 = None + permute_1450 = torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_551 = torch.ops.aten.mm.default(view_2174, permute_1450); view_2174 = permute_1450 = None + view_2175 = torch.ops.aten.view.default(mm_551, [2, 4096, 2048]); mm_551 = None + convert_element_type_2998 = torch.ops.prims.convert_element_type.default(mm_550, torch.float32); mm_550 = None + split_1268 = torch.ops.aten.split.Tensor(convert_element_type_2998, 5); convert_element_type_2998 = None + getitem_23753 = split_1268[0] + getitem_23754 = split_1268[1] + getitem_23755 = split_1268[2] + getitem_23756 = split_1268[3] + getitem_23757 = split_1268[4] + getitem_23758 = split_1268[5] + getitem_23759 = split_1268[6] + getitem_23760 = split_1268[7] + getitem_23761 = split_1268[8] + getitem_23762 = split_1268[9] + getitem_23763 = split_1268[10] + getitem_23764 = split_1268[11] + getitem_23765 = split_1268[12] + getitem_23766 = split_1268[13] + getitem_23767 = split_1268[14] + getitem_23768 = split_1268[15] + getitem_23769 = split_1268[16] + getitem_23770 = split_1268[17] + getitem_23771 = split_1268[18] + getitem_23772 = split_1268[19] + getitem_23773 = split_1268[20] + getitem_23774 = split_1268[21] + getitem_23775 = split_1268[22] + getitem_23776 = split_1268[23] + getitem_23777 = split_1268[24] + getitem_23778 = split_1268[25] + getitem_23779 = split_1268[26] + getitem_23780 = split_1268[27] + getitem_23781 = split_1268[28] + getitem_23782 = split_1268[29] + getitem_23783 = split_1268[30] + getitem_23784 = split_1268[31] + getitem_23785 = split_1268[32] + getitem_23786 = split_1268[33] + getitem_23787 = split_1268[34] + getitem_23788 = split_1268[35] + getitem_23789 = split_1268[36] + getitem_23790 = split_1268[37] + getitem_23791 = split_1268[38] + getitem_23792 = split_1268[39] + getitem_23793 = split_1268[40] + getitem_23794 = split_1268[41] + getitem_23795 = split_1268[42] + getitem_23796 = split_1268[43] + getitem_23797 = split_1268[44] + getitem_23798 = split_1268[45] + getitem_23799 = split_1268[46] + getitem_23800 = split_1268[47] + getitem_23801 = split_1268[48] + getitem_23802 = split_1268[49] + getitem_23803 = split_1268[50] + getitem_23804 = split_1268[51] + getitem_23805 = split_1268[52] + getitem_23806 = split_1268[53] + getitem_23807 = split_1268[54] + getitem_23808 = split_1268[55] + getitem_23809 = split_1268[56] + getitem_23810 = split_1268[57] + getitem_23811 = split_1268[58] + getitem_23812 = split_1268[59] + getitem_23813 = split_1268[60] + getitem_23814 = split_1268[61] + getitem_23815 = split_1268[62] + getitem_23816 = split_1268[63] + getitem_23817 = split_1268[64] + getitem_23818 = split_1268[65] + getitem_23819 = split_1268[66] + getitem_23820 = split_1268[67] + getitem_23821 = split_1268[68] + getitem_23822 = split_1268[69] + getitem_23823 = split_1268[70] + getitem_23824 = split_1268[71] + getitem_23825 = split_1268[72] + getitem_23826 = split_1268[73] + getitem_23827 = split_1268[74] + getitem_23828 = split_1268[75] + getitem_23829 = split_1268[76] + getitem_23830 = split_1268[77] + getitem_23831 = split_1268[78] + getitem_23832 = split_1268[79] + getitem_23833 = split_1268[80] + getitem_23834 = split_1268[81] + getitem_23835 = split_1268[82] + getitem_23836 = split_1268[83] + getitem_23837 = split_1268[84] + getitem_23838 = split_1268[85] + getitem_23839 = split_1268[86] + getitem_23840 = split_1268[87] + getitem_23841 = split_1268[88] + getitem_23842 = split_1268[89] + getitem_23843 = split_1268[90] + getitem_23844 = split_1268[91] + getitem_23845 = split_1268[92] + getitem_23846 = split_1268[93] + getitem_23847 = split_1268[94] + getitem_23848 = split_1268[95] + getitem_23849 = split_1268[96] + getitem_23850 = split_1268[97] + getitem_23851 = split_1268[98] + getitem_23852 = split_1268[99] + getitem_23853 = split_1268[100] + getitem_23854 = split_1268[101] + getitem_23855 = split_1268[102] + getitem_23856 = split_1268[103] + getitem_23857 = split_1268[104] + getitem_23858 = split_1268[105] + getitem_23859 = split_1268[106] + getitem_23860 = split_1268[107] + getitem_23861 = split_1268[108] + getitem_23862 = split_1268[109] + getitem_23863 = split_1268[110] + getitem_23864 = split_1268[111] + getitem_23865 = split_1268[112] + getitem_23866 = split_1268[113] + getitem_23867 = split_1268[114] + getitem_23868 = split_1268[115]; split_1268 = None + constant_pad_nd_1604 = torch.ops.aten.constant_pad_nd.default(getitem_23868, [0, 0, 0, 4], 0.0); getitem_23868 = None + cat_402 = torch.ops.aten.cat.default([getitem_23753, getitem_23754, getitem_23755, getitem_23756, getitem_23757, getitem_23758, getitem_23759, getitem_23760, getitem_23761, getitem_23762, getitem_23763, getitem_23764, getitem_23765, getitem_23766, getitem_23767, getitem_23768, getitem_23769, getitem_23770, getitem_23771, getitem_23772, getitem_23773, getitem_23774, getitem_23775, getitem_23776, getitem_23777, getitem_23778, getitem_23779, getitem_23780, getitem_23781, getitem_23782, getitem_23783, getitem_23784, getitem_23785, getitem_23786, getitem_23787, getitem_23788, getitem_23789, getitem_23790, getitem_23791, getitem_23792, getitem_23793, getitem_23794, getitem_23795, getitem_23796, getitem_23797, getitem_23798, getitem_23799, getitem_23800, getitem_23801, getitem_23802, getitem_23803, getitem_23804, getitem_23805, getitem_23806, getitem_23807, getitem_23808, getitem_23809, getitem_23810, getitem_23811, getitem_23812, getitem_23813, getitem_23814, getitem_23815, getitem_23816, getitem_23817, getitem_23818, getitem_23819, getitem_23820, getitem_23821, getitem_23822, getitem_23823, getitem_23824, getitem_23825, getitem_23826, getitem_23827, getitem_23828, getitem_23829, getitem_23830, getitem_23831, getitem_23832, getitem_23833, getitem_23834, getitem_23835, getitem_23836, getitem_23837, getitem_23838, getitem_23839, getitem_23840, getitem_23841, getitem_23842, getitem_23843, getitem_23844, getitem_23845, getitem_23846, getitem_23847, getitem_23848, getitem_23849, getitem_23850, getitem_23851, getitem_23852, getitem_23853, getitem_23854, getitem_23855, getitem_23856, getitem_23857, getitem_23858, getitem_23859, getitem_23860, getitem_23861, getitem_23862, getitem_23863, getitem_23864, getitem_23865, getitem_23866, getitem_23867, constant_pad_nd_1604, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_23753 = getitem_23754 = getitem_23755 = getitem_23756 = getitem_23757 = getitem_23758 = getitem_23759 = getitem_23760 = getitem_23761 = getitem_23762 = getitem_23763 = getitem_23764 = getitem_23765 = getitem_23766 = getitem_23767 = getitem_23768 = getitem_23769 = getitem_23770 = getitem_23771 = getitem_23772 = getitem_23773 = getitem_23774 = getitem_23775 = getitem_23776 = getitem_23777 = getitem_23778 = getitem_23779 = getitem_23780 = getitem_23781 = getitem_23782 = getitem_23783 = getitem_23784 = getitem_23785 = getitem_23786 = getitem_23787 = getitem_23788 = getitem_23789 = getitem_23790 = getitem_23791 = getitem_23792 = getitem_23793 = getitem_23794 = getitem_23795 = getitem_23796 = getitem_23797 = getitem_23798 = getitem_23799 = getitem_23800 = getitem_23801 = getitem_23802 = getitem_23803 = getitem_23804 = getitem_23805 = getitem_23806 = getitem_23807 = getitem_23808 = getitem_23809 = getitem_23810 = getitem_23811 = getitem_23812 = getitem_23813 = getitem_23814 = getitem_23815 = getitem_23816 = getitem_23817 = getitem_23818 = getitem_23819 = getitem_23820 = getitem_23821 = getitem_23822 = getitem_23823 = getitem_23824 = getitem_23825 = getitem_23826 = getitem_23827 = getitem_23828 = getitem_23829 = getitem_23830 = getitem_23831 = getitem_23832 = getitem_23833 = getitem_23834 = getitem_23835 = getitem_23836 = getitem_23837 = getitem_23838 = getitem_23839 = getitem_23840 = getitem_23841 = getitem_23842 = getitem_23843 = getitem_23844 = getitem_23845 = getitem_23846 = getitem_23847 = getitem_23848 = getitem_23849 = getitem_23850 = getitem_23851 = getitem_23852 = getitem_23853 = getitem_23854 = getitem_23855 = getitem_23856 = getitem_23857 = getitem_23858 = getitem_23859 = getitem_23860 = getitem_23861 = getitem_23862 = getitem_23863 = getitem_23864 = getitem_23865 = getitem_23866 = getitem_23867 = constant_pad_nd_1604 = None + reduce_scatter_tensor_293 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_402, 'avg', 128, '0'); cat_402 = None + wait_tensor_894 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_293); reduce_scatter_tensor_293 = None + slice_286 = torch.ops.aten.slice.Tensor(permute_1443, 3, 0, 128) + slice_287 = torch.ops.aten.slice.Tensor(permute_1443, 3, 128, 192); permute_1443 = None + convert_element_type_2999 = torch.ops.prims.convert_element_type.default(slice_287, torch.float32); slice_287 = None + view_2176 = torch.ops.aten.view.default(convert_element_type_2999, [2, 4096, 16, 32, 2]); convert_element_type_2999 = None + view_as_complex_95 = torch.ops.aten.view_as_complex.default(view_2176); view_2176 = None + mul_1998 = torch.ops.aten.mul.Tensor(view_as_complex_95, clone_9); view_as_complex_95 = None + view_as_real_95 = torch.ops.aten.view_as_real.default(mul_1998); mul_1998 = None + view_2177 = torch.ops.aten.view.default(view_as_real_95, [2, 4096, 16, 64]); view_as_real_95 = None + convert_element_type_3000 = torch.ops.prims.convert_element_type.default(view_2177, torch.bfloat16); view_2177 = None + cat_403 = torch.ops.aten.cat.default([slice_286, convert_element_type_3000], 3); slice_286 = convert_element_type_3000 = None + view_2178 = torch.ops.aten.view.default(cat_403, [2, 4096, 3072]); cat_403 = None + view_2179 = torch.ops.aten.view.default(view_2178, [8192, 3072]); view_2178 = None + permute_1452 = torch.ops.aten.permute.default(view_2179, [1, 0]) + mm_552 = torch.ops.aten.mm.default(permute_1452, view_371); permute_1452 = view_371 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_310, 128, '0'); convert_element_type_310 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_1454 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_553 = torch.ops.aten.mm.default(view_2179, permute_1454); view_2179 = permute_1454 = None + view_2180 = torch.ops.aten.view.default(mm_553, [2, 4096, 2048]); mm_553 = None + add_2088 = torch.ops.aten.add.Tensor(view_2175, view_2180); view_2175 = view_2180 = None + convert_element_type_3005 = torch.ops.prims.convert_element_type.default(mm_552, torch.float32); mm_552 = None + reduce_scatter_tensor_294 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3005, 'avg', 128, '0'); convert_element_type_3005 = None + wait_tensor_895 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_294); reduce_scatter_tensor_294 = None + convert_element_type_3006 = torch.ops.prims.convert_element_type.default(add_2088, torch.float32); add_2088 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 128, '0'); convert_element_type_307 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_3008 = torch.ops.prims.convert_element_type.default(wait_tensor_116, torch.float32); wait_tensor_116 = None + mul_1999 = torch.ops.aten.mul.Tensor(convert_element_type_3006, convert_element_type_3008); convert_element_type_3008 = None + convert_element_type_308 = torch.ops.prims.convert_element_type.default(add_345, torch.float32); add_345 = None + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_308, rsqrt_18); convert_element_type_308 = None + mul_2001 = torch.ops.aten.mul.Tensor(mul_254, mul_1999) + sum_273 = torch.ops.aten.sum.dim_IntList(mul_2001, [2], True); mul_2001 = None + div_257 = torch.ops.aten.div.Tensor(mul_254, 2048) + mul_2002 = torch.ops.aten.mul.Tensor(div_257, sum_273); div_257 = sum_273 = None + sub_750 = torch.ops.aten.sub.Tensor(mul_1999, mul_2002); mul_1999 = mul_2002 = None + mul_2003 = torch.ops.aten.mul.Tensor(sub_750, rsqrt_18); sub_750 = rsqrt_18 = None + mul_2004 = torch.ops.aten.mul.Tensor(convert_element_type_3006, mul_254); convert_element_type_3006 = mul_254 = None + sum_274 = torch.ops.aten.sum.dim_IntList(mul_2004, [0, 1]); mul_2004 = None + convert_element_type_3009 = torch.ops.prims.convert_element_type.default(mul_2003, torch.bfloat16); mul_2003 = None + add_2089 = torch.ops.aten.add.Tensor(add_2087, convert_element_type_3009); add_2087 = convert_element_type_3009 = None + convert_element_type_default_19 = torch.ops.prims.convert_element_type.default(sum_274, torch.float32); sum_274 = None + reduce_scatter_tensor_295 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_19, 'avg', 128, '0'); convert_element_type_default_19 = None + wait_tensor_896 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_295); reduce_scatter_tensor_295 = None + view_2181 = torch.ops.aten.view.default(add_2089, [8192, 2048]) + unsqueeze_74 = torch.ops.aten.unsqueeze.default(view_2181, 1) + convert_element_type_3012 = torch.ops.prims.convert_element_type.default(unsqueeze_74, torch.float32); unsqueeze_74 = None + bmm_68 = torch.ops.aten.bmm.default(permute_1456, convert_element_type_3012); permute_1456 = None + bmm_69 = torch.ops.aten.bmm.default(convert_element_type_3012, permute_1457); convert_element_type_3012 = permute_1457 = None + convert_element_type_3013 = torch.ops.prims.convert_element_type.default(bmm_68, torch.bfloat16); bmm_68 = None + view_2182 = torch.ops.aten.view.default(bmm_69, [8192, 6]); bmm_69 = None + view_2183 = torch.ops.aten.view.default(convert_element_type_3013, [49152, 2048]); convert_element_type_3013 = None + index_94 = torch.ops.aten.index.Tensor(view_2183, [getitem_461]); view_2183 = getitem_461 = None + permute_1458 = torch.ops.aten.permute.default(view_2181, [1, 0]) + mm_554 = torch.ops.aten.mm.default(permute_1458, mul_251); permute_1458 = mul_251 = None + convert_element_type_302 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_302, 128, '0'); convert_element_type_302 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_1460 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_555 = torch.ops.aten.mm.default(view_2181, permute_1460); view_2181 = permute_1460 = None + convert_element_type_3018 = torch.ops.prims.convert_element_type.default(mm_554, torch.float32); mm_554 = None + reduce_scatter_tensor_296 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3018, 'avg', 128, '0'); convert_element_type_3018 = None + wait_tensor_897 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_296); reduce_scatter_tensor_296 = None + convert_element_type_297 = torch.ops.prims.convert_element_type.default(mm_44, torch.float32); mm_44 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_297) + exp_15 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_340 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + div_25 = torch.ops.aten.div.Tensor(convert_element_type_297, add_340) + convert_element_type_298 = torch.ops.prims.convert_element_type.default(div_25, torch.bfloat16); div_25 = None + mul_2005 = torch.ops.aten.mul.Tensor(mm_555, convert_element_type_298); convert_element_type_298 = None + mul_2006 = torch.ops.aten.mul.Tensor(mm_555, mm_45); mm_555 = mm_45 = None + permute_1462 = torch.ops.aten.permute.default(mul_2005, [1, 0]) + mm_556 = torch.ops.aten.mm.default(permute_1462, view_326); permute_1462 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_299, 128, '0'); convert_element_type_299 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_114, [1, 0]); wait_tensor_114 = None + permute_1464 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_557 = torch.ops.aten.mm.default(mul_2005, permute_1464); mul_2005 = permute_1464 = None + convert_element_type_3023 = torch.ops.prims.convert_element_type.default(mm_556, torch.float32); mm_556 = None + reduce_scatter_tensor_297 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3023, 'avg', 128, '0'); convert_element_type_3023 = None + wait_tensor_898 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_297); reduce_scatter_tensor_297 = None + convert_element_type_3024 = torch.ops.prims.convert_element_type.default(mul_2006, torch.float32); mul_2006 = None + reciprocal_42 = torch.ops.aten.reciprocal.default(add_340); add_340 = None + mul_2007 = torch.ops.aten.mul.Tensor(reciprocal_42, 1); reciprocal_42 = None + mul_2008 = torch.ops.aten.mul.Tensor(convert_element_type_3024, mul_2007); convert_element_type_3024 = None + sub_751 = torch.ops.aten.sub.Tensor(1, mul_2007); mul_2007 = None + mul_2009 = torch.ops.aten.mul.Tensor(convert_element_type_297, sub_751); convert_element_type_297 = sub_751 = None + add_2091 = torch.ops.aten.add.Tensor(mul_2009, 1); mul_2009 = None + mul_2010 = torch.ops.aten.mul.Tensor(mul_2008, add_2091); mul_2008 = add_2091 = None + convert_element_type_3026 = torch.ops.prims.convert_element_type.default(mul_2010, torch.bfloat16); mul_2010 = None + permute_1466 = torch.ops.aten.permute.default(convert_element_type_3026, [1, 0]) + mm_558 = torch.ops.aten.mm.default(permute_1466, view_326); permute_1466 = None + convert_element_type_294 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_294, 128, '0'); convert_element_type_294 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_83 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + permute_1468 = torch.ops.aten.permute.default(permute_83, [1, 0]); permute_83 = None + mm_559 = torch.ops.aten.mm.default(convert_element_type_3026, permute_1468); convert_element_type_3026 = permute_1468 = None + add_2092 = torch.ops.aten.add.Tensor(mm_557, mm_559); mm_557 = mm_559 = None + convert_element_type_3031 = torch.ops.prims.convert_element_type.default(mm_558, torch.float32); mm_558 = None + reduce_scatter_tensor_298 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3031, 'avg', 128, '0'); convert_element_type_3031 = None + wait_tensor_899 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_298); reduce_scatter_tensor_298 = None + all_to_all_single_120 = torch.ops._c10d_functional.all_to_all_single.default(index_94, [_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79], [_local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71], '1033'); index_94 = None + wait_tensor_900 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_120); all_to_all_single_120 = None + full_474 = torch.ops.aten.full.default([sym_size_int_17, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_17 = None + slice_scatter_21 = torch.ops.aten.slice_scatter.default(full_474, wait_tensor_900, 0, 0, -1); wait_tensor_900 = None + index_95 = torch.ops.aten.index.Tensor(slice_scatter_21, [getitem_462]); slice_scatter_21 = None + permute_1470 = torch.ops.aten.permute.default(index_95, [1, 0]) + _grouped_mm_204 = torch.ops.aten._grouped_mm.default(permute_1470, mul_231, cumsum_14); permute_1470 = mul_231 = None + _grouped_mm_205 = torch.ops.aten._grouped_mm.default(index_95, permute_1472, cumsum_14); index_95 = permute_1472 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(_grouped_mm_12, torch.float32); _grouped_mm_12 = None + neg_9 = torch.ops.aten.neg.default(convert_element_type_292) + exp_14 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_304 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + div_24 = torch.ops.aten.div.Tensor(convert_element_type_292, add_304) + convert_element_type_293 = torch.ops.prims.convert_element_type.default(div_24, torch.bfloat16); div_24 = None + mul_2011 = torch.ops.aten.mul.Tensor(_grouped_mm_205, convert_element_type_293); convert_element_type_293 = None + mul_2012 = torch.ops.aten.mul.Tensor(_grouped_mm_205, _grouped_mm_13); _grouped_mm_205 = _grouped_mm_13 = None + permute_1474 = torch.ops.aten.permute.default(mul_2011, [1, 0]) + _grouped_mm_206 = torch.ops.aten._grouped_mm.default(permute_1474, index_9, cumsum_14); permute_1474 = None + _grouped_mm_207 = torch.ops.aten._grouped_mm.default(mul_2011, permute_1476, cumsum_14); mul_2011 = permute_1476 = None + convert_element_type_3032 = torch.ops.prims.convert_element_type.default(mul_2012, torch.float32); mul_2012 = None + reciprocal_43 = torch.ops.aten.reciprocal.default(add_304); add_304 = None + mul_2013 = torch.ops.aten.mul.Tensor(reciprocal_43, 1); reciprocal_43 = None + mul_2014 = torch.ops.aten.mul.Tensor(convert_element_type_3032, mul_2013); convert_element_type_3032 = None + sub_752 = torch.ops.aten.sub.Tensor(1, mul_2013); mul_2013 = None + mul_2015 = torch.ops.aten.mul.Tensor(convert_element_type_292, sub_752); convert_element_type_292 = sub_752 = None + add_2094 = torch.ops.aten.add.Tensor(mul_2015, 1); mul_2015 = None + mul_2016 = torch.ops.aten.mul.Tensor(mul_2014, add_2094); mul_2014 = add_2094 = None + convert_element_type_3034 = torch.ops.prims.convert_element_type.default(mul_2016, torch.bfloat16); mul_2016 = None + permute_1478 = torch.ops.aten.permute.default(convert_element_type_3034, [1, 0]) + _grouped_mm_208 = torch.ops.aten._grouped_mm.default(permute_1478, index_9, cumsum_14); permute_1478 = index_9 = None + _grouped_mm_209 = torch.ops.aten._grouped_mm.default(convert_element_type_3034, permute_1480, cumsum_14); convert_element_type_3034 = permute_1480 = cumsum_14 = None + add_2095 = torch.ops.aten.add.Tensor(_grouped_mm_207, _grouped_mm_209); _grouped_mm_207 = _grouped_mm_209 = None + convert_element_type_3035 = torch.ops.prims.convert_element_type.default(_grouped_mm_206, torch.float32); _grouped_mm_206 = None + div_258 = torch.ops.aten.div.Tensor(convert_element_type_3035, 128); convert_element_type_3035 = None + split_1270 = torch.ops.aten.split.Tensor(div_258, 88, 1); div_258 = None + getitem_23885 = split_1270[0] + getitem_23902 = split_1270[1] + getitem_23919 = split_1270[2] + getitem_23936 = split_1270[3] + getitem_23953 = split_1270[4] + getitem_23970 = split_1270[5] + getitem_23987 = split_1270[6] + getitem_24004 = split_1270[7] + getitem_24021 = split_1270[8] + getitem_24038 = split_1270[9] + getitem_24055 = split_1270[10] + getitem_24072 = split_1270[11] + getitem_24089 = split_1270[12] + getitem_24106 = split_1270[13] + getitem_24123 = split_1270[14] + getitem_24140 = split_1270[15]; split_1270 = None + cat_404 = torch.ops.aten.cat.default([getitem_23885, getitem_23902, getitem_23919, getitem_23936, getitem_23953, getitem_23970, getitem_23987, getitem_24004, getitem_24021, getitem_24038, getitem_24055, getitem_24072, getitem_24089, getitem_24106, getitem_24123, getitem_24140]); getitem_23885 = getitem_23902 = getitem_23919 = getitem_23936 = getitem_23953 = getitem_23970 = getitem_23987 = getitem_24004 = getitem_24021 = getitem_24038 = getitem_24055 = getitem_24072 = getitem_24089 = getitem_24106 = getitem_24123 = getitem_24140 = None + reduce_scatter_tensor_299 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_404, 'sum', 16, '1025'); cat_404 = None + wait_tensor_901 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_299); reduce_scatter_tensor_299 = None + convert_element_type_3036 = torch.ops.prims.convert_element_type.default(_grouped_mm_204, torch.float32); _grouped_mm_204 = None + div_259 = torch.ops.aten.div.Tensor(convert_element_type_3036, 128); convert_element_type_3036 = None + split_1287 = torch.ops.aten.split.Tensor(div_259, 128, 1); div_259 = None + getitem_24157 = split_1287[0] + getitem_24174 = split_1287[1] + getitem_24191 = split_1287[2] + getitem_24208 = split_1287[3] + getitem_24225 = split_1287[4] + getitem_24242 = split_1287[5] + getitem_24259 = split_1287[6] + getitem_24276 = split_1287[7] + getitem_24293 = split_1287[8] + getitem_24310 = split_1287[9] + getitem_24327 = split_1287[10] + getitem_24344 = split_1287[11] + getitem_24361 = split_1287[12] + getitem_24378 = split_1287[13] + getitem_24395 = split_1287[14] + getitem_24412 = split_1287[15]; split_1287 = None + cat_405 = torch.ops.aten.cat.default([getitem_24157, getitem_24174, getitem_24191, getitem_24208, getitem_24225, getitem_24242, getitem_24259, getitem_24276, getitem_24293, getitem_24310, getitem_24327, getitem_24344, getitem_24361, getitem_24378, getitem_24395, getitem_24412]); getitem_24157 = getitem_24174 = getitem_24191 = getitem_24208 = getitem_24225 = getitem_24242 = getitem_24259 = getitem_24276 = getitem_24293 = getitem_24310 = getitem_24327 = getitem_24344 = getitem_24361 = getitem_24378 = getitem_24395 = getitem_24412 = None + reduce_scatter_tensor_300 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_405, 'sum', 16, '1025'); cat_405 = None + wait_tensor_902 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_300); reduce_scatter_tensor_300 = None + convert_element_type_3037 = torch.ops.prims.convert_element_type.default(_grouped_mm_208, torch.float32); _grouped_mm_208 = None + div_260 = torch.ops.aten.div.Tensor(convert_element_type_3037, 128); convert_element_type_3037 = None + split_1304 = torch.ops.aten.split.Tensor(div_260, 88, 1); div_260 = None + getitem_24429 = split_1304[0] + getitem_24446 = split_1304[1] + getitem_24463 = split_1304[2] + getitem_24480 = split_1304[3] + getitem_24497 = split_1304[4] + getitem_24514 = split_1304[5] + getitem_24531 = split_1304[6] + getitem_24548 = split_1304[7] + getitem_24565 = split_1304[8] + getitem_24582 = split_1304[9] + getitem_24599 = split_1304[10] + getitem_24616 = split_1304[11] + getitem_24633 = split_1304[12] + getitem_24650 = split_1304[13] + getitem_24667 = split_1304[14] + getitem_24684 = split_1304[15]; split_1304 = None + cat_406 = torch.ops.aten.cat.default([getitem_24429, getitem_24446, getitem_24463, getitem_24480, getitem_24497, getitem_24514, getitem_24531, getitem_24548, getitem_24565, getitem_24582, getitem_24599, getitem_24616, getitem_24633, getitem_24650, getitem_24667, getitem_24684]); getitem_24429 = getitem_24446 = getitem_24463 = getitem_24480 = getitem_24497 = getitem_24514 = getitem_24531 = getitem_24548 = getitem_24565 = getitem_24582 = getitem_24599 = getitem_24616 = getitem_24633 = getitem_24650 = getitem_24667 = getitem_24684 = None + reduce_scatter_tensor_301 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_406, 'sum', 16, '1025'); cat_406 = None + wait_tensor_903 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_301); reduce_scatter_tensor_301 = None + index_put_94 = torch.ops.aten.index_put.default(full_474, [getitem_462], add_2095, True); full_474 = getitem_462 = add_2095 = None + slice_288 = torch.ops.aten.slice.Tensor(index_put_94, 0, 0, add_2096); index_put_94 = add_2096 = None + all_to_all_single_121 = torch.ops._c10d_functional.all_to_all_single.default(slice_288, [_local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71], [_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79], '1033'); slice_288 = _local_scalar_dense_64 = _local_scalar_dense_65 = _local_scalar_dense_66 = _local_scalar_dense_67 = _local_scalar_dense_68 = _local_scalar_dense_69 = _local_scalar_dense_70 = _local_scalar_dense_71 = _local_scalar_dense_72 = _local_scalar_dense_73 = _local_scalar_dense_74 = _local_scalar_dense_75 = _local_scalar_dense_76 = _local_scalar_dense_77 = _local_scalar_dense_78 = _local_scalar_dense_79 = None + wait_tensor_904 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_121); all_to_all_single_121 = None + index_put_95 = torch.ops.aten.index_put.default(full_default_52, [div_22], wait_tensor_904, True); div_22 = wait_tensor_904 = None + add_2100 = torch.ops.aten.add.Tensor(add_2092, index_put_95); add_2092 = index_put_95 = None + mul_2017 = torch.ops.aten.mul.Tensor(view_2182, 1.0); view_2182 = None + scatter_add_21 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_459, mul_2017); getitem_459 = mul_2017 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(mm_43, torch.float32); mm_43 = None + sub_96 = torch.ops.aten.sub.Tensor(convert_element_type_281, amax_4); convert_element_type_281 = amax_4 = None + exp_13 = torch.ops.aten.exp.default(sub_96); sub_96 = None + div_21 = torch.ops.aten.div.Tensor(exp_13, sum_17); exp_13 = sum_17 = None + mul_2018 = torch.ops.aten.mul.Tensor(scatter_add_21, div_21); scatter_add_21 = None + sum_275 = torch.ops.aten.sum.dim_IntList(mul_2018, [1], True) + neg_118 = torch.ops.aten.neg.default(div_21); div_21 = None + fma_21 = torch.ops.prims.fma.default(neg_118, sum_275, mul_2018); neg_118 = sum_275 = mul_2018 = None + convert_element_type_3038 = torch.ops.prims.convert_element_type.default(fma_21, torch.bfloat16); fma_21 = None + permute_1482 = torch.ops.aten.permute.default(convert_element_type_3038, [1, 0]) + mm_560 = torch.ops.aten.mm.default(permute_1482, view_326); permute_1482 = view_326 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_278, 128, '0'); convert_element_type_278 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + slice_33 = torch.ops.aten.slice.Tensor(wait_tensor_102, 0, 0, 64); wait_tensor_102 = None + permute_79 = torch.ops.aten.permute.default(slice_33, [1, 0]); slice_33 = None + permute_1484 = torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_561 = torch.ops.aten.mm.default(convert_element_type_3038, permute_1484); convert_element_type_3038 = permute_1484 = None + add_2101 = torch.ops.aten.add.Tensor(add_2100, mm_561); add_2100 = mm_561 = None + convert_element_type_3043 = torch.ops.prims.convert_element_type.default(mm_560, torch.float32); mm_560 = None + split_1320 = torch.ops.aten.split.Tensor(convert_element_type_3043, 1); convert_element_type_3043 = None + getitem_24685 = split_1320[0] + getitem_24686 = split_1320[1] + getitem_24687 = split_1320[2] + getitem_24688 = split_1320[3] + getitem_24689 = split_1320[4] + getitem_24690 = split_1320[5] + getitem_24691 = split_1320[6] + getitem_24692 = split_1320[7] + getitem_24693 = split_1320[8] + getitem_24694 = split_1320[9] + getitem_24695 = split_1320[10] + getitem_24696 = split_1320[11] + getitem_24697 = split_1320[12] + getitem_24698 = split_1320[13] + getitem_24699 = split_1320[14] + getitem_24700 = split_1320[15] + getitem_24701 = split_1320[16] + getitem_24702 = split_1320[17] + getitem_24703 = split_1320[18] + getitem_24704 = split_1320[19] + getitem_24705 = split_1320[20] + getitem_24706 = split_1320[21] + getitem_24707 = split_1320[22] + getitem_24708 = split_1320[23] + getitem_24709 = split_1320[24] + getitem_24710 = split_1320[25] + getitem_24711 = split_1320[26] + getitem_24712 = split_1320[27] + getitem_24713 = split_1320[28] + getitem_24714 = split_1320[29] + getitem_24715 = split_1320[30] + getitem_24716 = split_1320[31] + getitem_24717 = split_1320[32] + getitem_24718 = split_1320[33] + getitem_24719 = split_1320[34] + getitem_24720 = split_1320[35] + getitem_24721 = split_1320[36] + getitem_24722 = split_1320[37] + getitem_24723 = split_1320[38] + getitem_24724 = split_1320[39] + getitem_24725 = split_1320[40] + getitem_24726 = split_1320[41] + getitem_24727 = split_1320[42] + getitem_24728 = split_1320[43] + getitem_24729 = split_1320[44] + getitem_24730 = split_1320[45] + getitem_24731 = split_1320[46] + getitem_24732 = split_1320[47] + getitem_24733 = split_1320[48] + getitem_24734 = split_1320[49] + getitem_24735 = split_1320[50] + getitem_24736 = split_1320[51] + getitem_24737 = split_1320[52] + getitem_24738 = split_1320[53] + getitem_24739 = split_1320[54] + getitem_24740 = split_1320[55] + getitem_24741 = split_1320[56] + getitem_24742 = split_1320[57] + getitem_24743 = split_1320[58] + getitem_24744 = split_1320[59] + getitem_24745 = split_1320[60] + getitem_24746 = split_1320[61] + getitem_24747 = split_1320[62] + getitem_24748 = split_1320[63]; split_1320 = None + cat_407 = torch.ops.aten.cat.default([getitem_24685, getitem_24686, getitem_24687, getitem_24688, getitem_24689, getitem_24690, getitem_24691, getitem_24692, getitem_24693, getitem_24694, getitem_24695, getitem_24696, getitem_24697, getitem_24698, getitem_24699, getitem_24700, getitem_24701, getitem_24702, getitem_24703, getitem_24704, getitem_24705, getitem_24706, getitem_24707, getitem_24708, getitem_24709, getitem_24710, getitem_24711, getitem_24712, getitem_24713, getitem_24714, getitem_24715, getitem_24716, getitem_24717, getitem_24718, getitem_24719, getitem_24720, getitem_24721, getitem_24722, getitem_24723, getitem_24724, getitem_24725, getitem_24726, getitem_24727, getitem_24728, getitem_24729, getitem_24730, getitem_24731, getitem_24732, getitem_24733, getitem_24734, getitem_24735, getitem_24736, getitem_24737, getitem_24738, getitem_24739, getitem_24740, getitem_24741, getitem_24742, getitem_24743, getitem_24744, getitem_24745, getitem_24746, getitem_24747, getitem_24748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_24685 = getitem_24686 = getitem_24687 = getitem_24688 = getitem_24689 = getitem_24690 = getitem_24691 = getitem_24692 = getitem_24693 = getitem_24694 = getitem_24695 = getitem_24696 = getitem_24697 = getitem_24698 = getitem_24699 = getitem_24700 = getitem_24701 = getitem_24702 = getitem_24703 = getitem_24704 = getitem_24705 = getitem_24706 = getitem_24707 = getitem_24708 = getitem_24709 = getitem_24710 = getitem_24711 = getitem_24712 = getitem_24713 = getitem_24714 = getitem_24715 = getitem_24716 = getitem_24717 = getitem_24718 = getitem_24719 = getitem_24720 = getitem_24721 = getitem_24722 = getitem_24723 = getitem_24724 = getitem_24725 = getitem_24726 = getitem_24727 = getitem_24728 = getitem_24729 = getitem_24730 = getitem_24731 = getitem_24732 = getitem_24733 = getitem_24734 = getitem_24735 = getitem_24736 = getitem_24737 = getitem_24738 = getitem_24739 = getitem_24740 = getitem_24741 = getitem_24742 = getitem_24743 = getitem_24744 = getitem_24745 = getitem_24746 = getitem_24747 = getitem_24748 = None + reduce_scatter_tensor_302 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_407, 'avg', 128, '0'); cat_407 = None + wait_tensor_905 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_302); reduce_scatter_tensor_302 = None + view_2184 = torch.ops.aten.view.default(add_2101, [2, 4096, 2048]); add_2101 = None + convert_element_type_3044 = torch.ops.prims.convert_element_type.default(view_2184, torch.float32); view_2184 = None + convert_element_type_275 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_275, 128, '0'); convert_element_type_275 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + convert_element_type_3046 = torch.ops.prims.convert_element_type.default(wait_tensor_101, torch.float32); wait_tensor_101 = None + mul_2019 = torch.ops.aten.mul.Tensor(convert_element_type_3044, convert_element_type_3046); convert_element_type_3046 = None + convert_element_type_276 = torch.ops.prims.convert_element_type.default(add_280, torch.float32); add_280 = None + mul_211 = torch.ops.aten.mul.Tensor(convert_element_type_276, rsqrt_17); convert_element_type_276 = None + mul_2021 = torch.ops.aten.mul.Tensor(mul_211, mul_2019) + sum_276 = torch.ops.aten.sum.dim_IntList(mul_2021, [2], True); mul_2021 = None + div_261 = torch.ops.aten.div.Tensor(mul_211, 2048) + mul_2022 = torch.ops.aten.mul.Tensor(div_261, sum_276); div_261 = sum_276 = None + sub_754 = torch.ops.aten.sub.Tensor(mul_2019, mul_2022); mul_2019 = mul_2022 = None + mul_2023 = torch.ops.aten.mul.Tensor(sub_754, rsqrt_17); sub_754 = rsqrt_17 = None + mul_2024 = torch.ops.aten.mul.Tensor(convert_element_type_3044, mul_211); convert_element_type_3044 = mul_211 = None + sum_277 = torch.ops.aten.sum.dim_IntList(mul_2024, [0, 1]); mul_2024 = None + convert_element_type_3047 = torch.ops.prims.convert_element_type.default(mul_2023, torch.bfloat16); mul_2023 = None + add_2102 = torch.ops.aten.add.Tensor(add_2089, convert_element_type_3047); add_2089 = convert_element_type_3047 = None + convert_element_type_default_18 = torch.ops.prims.convert_element_type.default(sum_277, torch.float32); sum_277 = None + reduce_scatter_tensor_303 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_18, 'avg', 128, '0'); convert_element_type_default_18 = None + wait_tensor_906 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_303); reduce_scatter_tensor_303 = None + view_2185 = torch.ops.aten.view.default(add_2102, [8192, 2048]) + permute_1486 = torch.ops.aten.permute.default(view_2185, [1, 0]) + permute_77 = torch.ops.aten.permute.default(getitem_455, [0, 2, 1, 3]) + view_321 = torch.ops.aten.view.default(permute_77, [2, 4096, -1]); permute_77 = None + view_323 = torch.ops.aten.view.default(view_321, [8192, 2048]); view_321 = None + mm_562 = torch.ops.aten.mm.default(permute_1486, view_323); permute_1486 = view_323 = None + convert_element_type_272 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16); primals_92 = None + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_272, 128, '0'); convert_element_type_272 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_100, [1, 0]); wait_tensor_100 = None + permute_1488 = torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_563 = torch.ops.aten.mm.default(view_2185, permute_1488); view_2185 = permute_1488 = None + view_2186 = torch.ops.aten.view.default(mm_563, [2, 4096, 2048]); mm_563 = None + convert_element_type_3054 = torch.ops.prims.convert_element_type.default(mm_562, torch.float32); mm_562 = None + reduce_scatter_tensor_304 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3054, 'avg', 128, '0'); convert_element_type_3054 = None + wait_tensor_907 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_304); reduce_scatter_tensor_304 = None + view_2187 = torch.ops.aten.view.default(view_2186, [2, 4096, 16, 128]); view_2186 = None + permute_1490 = torch.ops.aten.permute.default(view_2187, [0, 2, 1, 3]); view_2187 = None + fw_graph21 = self.fw_graph21 + joint_graph21 = self.joint_graph21 + mask_graph21 = self.mask_graph21 + flex_attention_backward_21 = torch.ops.higher_order.flex_attention_backward(permute_74, permute_75, permute_76, getitem_455, getitem_456, permute_1490, None, fw_graph21, joint_graph21, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph21), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_74 = permute_75 = permute_76 = getitem_455 = getitem_456 = permute_1490 = fw_graph21 = joint_graph21 = mask_graph21 = None + getitem_24749 = flex_attention_backward_21[0] + getitem_24750 = flex_attention_backward_21[1] + getitem_24751 = flex_attention_backward_21[2]; flex_attention_backward_21 = None + permute_1491 = torch.ops.aten.permute.default(getitem_24751, [0, 2, 1, 3]); getitem_24751 = None + permute_1492 = torch.ops.aten.permute.default(getitem_24750, [0, 2, 1, 3]); getitem_24750 = None + permute_1493 = torch.ops.aten.permute.default(getitem_24749, [0, 2, 1, 3]); getitem_24749 = None + slice_290 = torch.ops.aten.slice.Tensor(permute_1492, 3, 0, 128) + slice_291 = torch.ops.aten.slice.Tensor(permute_1492, 3, 128, 192); permute_1492 = None + sum_278 = torch.ops.aten.sum.dim_IntList(slice_291, [2], True); slice_291 = None + cat_408 = torch.ops.aten.cat.default([slice_290, permute_1491], 3); slice_290 = permute_1491 = None + view_2188 = torch.ops.aten.view.default(cat_408, [2, 4096, 4096]); cat_408 = None + view_2189 = torch.ops.aten.view.default(view_2188, [8192, 4096]); view_2188 = None + permute_1494 = torch.ops.aten.permute.default(view_2189, [1, 0]) + mm_564 = torch.ops.aten.mm.default(permute_1494, view_318); permute_1494 = view_318 = None + convert_element_type_269 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_269, 128, '0'); convert_element_type_269 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + permute_1496 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_565 = torch.ops.aten.mm.default(view_2189, permute_1496); view_2189 = permute_1496 = None + view_2190 = torch.ops.aten.view.default(mm_565, [2, 4096, 512]); mm_565 = None + convert_element_type_3059 = torch.ops.prims.convert_element_type.default(mm_564, torch.float32); mm_564 = None + reduce_scatter_tensor_305 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3059, 'avg', 128, '0'); convert_element_type_3059 = None + wait_tensor_908 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_305); reduce_scatter_tensor_305 = None + convert_element_type_3060 = torch.ops.prims.convert_element_type.default(view_2190, torch.float32); view_2190 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_266, 128, '0'); convert_element_type_266 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_3062 = torch.ops.prims.convert_element_type.default(wait_tensor_98, torch.float32); wait_tensor_98 = None + mul_2025 = torch.ops.aten.mul.Tensor(convert_element_type_3060, convert_element_type_3062); convert_element_type_3062 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(getitem_451, torch.float32); getitem_451 = None + mul_209 = torch.ops.aten.mul.Tensor(convert_element_type_267, rsqrt_16); convert_element_type_267 = None + mul_2027 = torch.ops.aten.mul.Tensor(mul_209, mul_2025) + sum_279 = torch.ops.aten.sum.dim_IntList(mul_2027, [2], True); mul_2027 = None + div_262 = torch.ops.aten.div.Tensor(mul_209, 512) + mul_2028 = torch.ops.aten.mul.Tensor(div_262, sum_279); div_262 = sum_279 = None + sub_755 = torch.ops.aten.sub.Tensor(mul_2025, mul_2028); mul_2025 = mul_2028 = None + mul_2029 = torch.ops.aten.mul.Tensor(sub_755, rsqrt_16); sub_755 = rsqrt_16 = None + mul_2030 = torch.ops.aten.mul.Tensor(convert_element_type_3060, mul_209); convert_element_type_3060 = mul_209 = None + sum_280 = torch.ops.aten.sum.dim_IntList(mul_2030, [0, 1]); mul_2030 = None + convert_element_type_3063 = torch.ops.prims.convert_element_type.default(mul_2029, torch.bfloat16); mul_2029 = None + convert_element_type_default_17 = torch.ops.prims.convert_element_type.default(sum_280, torch.float32); sum_280 = None + reduce_scatter_tensor_306 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_17, 'avg', 128, '0'); convert_element_type_default_17 = None + wait_tensor_909 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_306); reduce_scatter_tensor_306 = None + convert_element_type_3066 = torch.ops.prims.convert_element_type.default(sum_278, torch.float32); sum_278 = None + view_2191 = torch.ops.aten.view.default(convert_element_type_3066, [2, 4096, 1, 32, 2]); convert_element_type_3066 = None + view_as_complex_96 = torch.ops.aten.view_as_complex.default(view_2191); view_2191 = None + mul_2031 = torch.ops.aten.mul.Tensor(view_as_complex_96, clone_9); view_as_complex_96 = None + view_as_real_96 = torch.ops.aten.view_as_real.default(mul_2031); mul_2031 = None + view_2192 = torch.ops.aten.view.default(view_as_real_96, [2, 4096, 1, 64]); view_as_real_96 = None + convert_element_type_3067 = torch.ops.prims.convert_element_type.default(view_2192, torch.bfloat16); view_2192 = None + squeeze_47 = torch.ops.aten.squeeze.dim(convert_element_type_3067, 2); convert_element_type_3067 = None + cat_409 = torch.ops.aten.cat.default([convert_element_type_3063, squeeze_47], 2); convert_element_type_3063 = squeeze_47 = None + view_2193 = torch.ops.aten.view.default(cat_409, [8192, 576]); cat_409 = None + permute_1498 = torch.ops.aten.permute.default(view_2193, [1, 0]) + mm_566 = torch.ops.aten.mm.default(permute_1498, view_304); permute_1498 = None + convert_element_type_261 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_261, 128, '0'); convert_element_type_261 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + slice_31 = torch.ops.aten.slice.Tensor(wait_tensor_97, 0, 0, 576); wait_tensor_97 = None + permute_72 = torch.ops.aten.permute.default(slice_31, [1, 0]); slice_31 = None + permute_1500 = torch.ops.aten.permute.default(permute_72, [1, 0]); permute_72 = None + mm_567 = torch.ops.aten.mm.default(view_2193, permute_1500); view_2193 = permute_1500 = None + view_2194 = torch.ops.aten.view.default(mm_567, [2, 4096, 2048]); mm_567 = None + convert_element_type_3072 = torch.ops.prims.convert_element_type.default(mm_566, torch.float32); mm_566 = None + split_1321 = torch.ops.aten.split.Tensor(convert_element_type_3072, 5); convert_element_type_3072 = None + getitem_24753 = split_1321[0] + getitem_24754 = split_1321[1] + getitem_24755 = split_1321[2] + getitem_24756 = split_1321[3] + getitem_24757 = split_1321[4] + getitem_24758 = split_1321[5] + getitem_24759 = split_1321[6] + getitem_24760 = split_1321[7] + getitem_24761 = split_1321[8] + getitem_24762 = split_1321[9] + getitem_24763 = split_1321[10] + getitem_24764 = split_1321[11] + getitem_24765 = split_1321[12] + getitem_24766 = split_1321[13] + getitem_24767 = split_1321[14] + getitem_24768 = split_1321[15] + getitem_24769 = split_1321[16] + getitem_24770 = split_1321[17] + getitem_24771 = split_1321[18] + getitem_24772 = split_1321[19] + getitem_24773 = split_1321[20] + getitem_24774 = split_1321[21] + getitem_24775 = split_1321[22] + getitem_24776 = split_1321[23] + getitem_24777 = split_1321[24] + getitem_24778 = split_1321[25] + getitem_24779 = split_1321[26] + getitem_24780 = split_1321[27] + getitem_24781 = split_1321[28] + getitem_24782 = split_1321[29] + getitem_24783 = split_1321[30] + getitem_24784 = split_1321[31] + getitem_24785 = split_1321[32] + getitem_24786 = split_1321[33] + getitem_24787 = split_1321[34] + getitem_24788 = split_1321[35] + getitem_24789 = split_1321[36] + getitem_24790 = split_1321[37] + getitem_24791 = split_1321[38] + getitem_24792 = split_1321[39] + getitem_24793 = split_1321[40] + getitem_24794 = split_1321[41] + getitem_24795 = split_1321[42] + getitem_24796 = split_1321[43] + getitem_24797 = split_1321[44] + getitem_24798 = split_1321[45] + getitem_24799 = split_1321[46] + getitem_24800 = split_1321[47] + getitem_24801 = split_1321[48] + getitem_24802 = split_1321[49] + getitem_24803 = split_1321[50] + getitem_24804 = split_1321[51] + getitem_24805 = split_1321[52] + getitem_24806 = split_1321[53] + getitem_24807 = split_1321[54] + getitem_24808 = split_1321[55] + getitem_24809 = split_1321[56] + getitem_24810 = split_1321[57] + getitem_24811 = split_1321[58] + getitem_24812 = split_1321[59] + getitem_24813 = split_1321[60] + getitem_24814 = split_1321[61] + getitem_24815 = split_1321[62] + getitem_24816 = split_1321[63] + getitem_24817 = split_1321[64] + getitem_24818 = split_1321[65] + getitem_24819 = split_1321[66] + getitem_24820 = split_1321[67] + getitem_24821 = split_1321[68] + getitem_24822 = split_1321[69] + getitem_24823 = split_1321[70] + getitem_24824 = split_1321[71] + getitem_24825 = split_1321[72] + getitem_24826 = split_1321[73] + getitem_24827 = split_1321[74] + getitem_24828 = split_1321[75] + getitem_24829 = split_1321[76] + getitem_24830 = split_1321[77] + getitem_24831 = split_1321[78] + getitem_24832 = split_1321[79] + getitem_24833 = split_1321[80] + getitem_24834 = split_1321[81] + getitem_24835 = split_1321[82] + getitem_24836 = split_1321[83] + getitem_24837 = split_1321[84] + getitem_24838 = split_1321[85] + getitem_24839 = split_1321[86] + getitem_24840 = split_1321[87] + getitem_24841 = split_1321[88] + getitem_24842 = split_1321[89] + getitem_24843 = split_1321[90] + getitem_24844 = split_1321[91] + getitem_24845 = split_1321[92] + getitem_24846 = split_1321[93] + getitem_24847 = split_1321[94] + getitem_24848 = split_1321[95] + getitem_24849 = split_1321[96] + getitem_24850 = split_1321[97] + getitem_24851 = split_1321[98] + getitem_24852 = split_1321[99] + getitem_24853 = split_1321[100] + getitem_24854 = split_1321[101] + getitem_24855 = split_1321[102] + getitem_24856 = split_1321[103] + getitem_24857 = split_1321[104] + getitem_24858 = split_1321[105] + getitem_24859 = split_1321[106] + getitem_24860 = split_1321[107] + getitem_24861 = split_1321[108] + getitem_24862 = split_1321[109] + getitem_24863 = split_1321[110] + getitem_24864 = split_1321[111] + getitem_24865 = split_1321[112] + getitem_24866 = split_1321[113] + getitem_24867 = split_1321[114] + getitem_24868 = split_1321[115]; split_1321 = None + constant_pad_nd_1681 = torch.ops.aten.constant_pad_nd.default(getitem_24868, [0, 0, 0, 4], 0.0); getitem_24868 = None + cat_410 = torch.ops.aten.cat.default([getitem_24753, getitem_24754, getitem_24755, getitem_24756, getitem_24757, getitem_24758, getitem_24759, getitem_24760, getitem_24761, getitem_24762, getitem_24763, getitem_24764, getitem_24765, getitem_24766, getitem_24767, getitem_24768, getitem_24769, getitem_24770, getitem_24771, getitem_24772, getitem_24773, getitem_24774, getitem_24775, getitem_24776, getitem_24777, getitem_24778, getitem_24779, getitem_24780, getitem_24781, getitem_24782, getitem_24783, getitem_24784, getitem_24785, getitem_24786, getitem_24787, getitem_24788, getitem_24789, getitem_24790, getitem_24791, getitem_24792, getitem_24793, getitem_24794, getitem_24795, getitem_24796, getitem_24797, getitem_24798, getitem_24799, getitem_24800, getitem_24801, getitem_24802, getitem_24803, getitem_24804, getitem_24805, getitem_24806, getitem_24807, getitem_24808, getitem_24809, getitem_24810, getitem_24811, getitem_24812, getitem_24813, getitem_24814, getitem_24815, getitem_24816, getitem_24817, getitem_24818, getitem_24819, getitem_24820, getitem_24821, getitem_24822, getitem_24823, getitem_24824, getitem_24825, getitem_24826, getitem_24827, getitem_24828, getitem_24829, getitem_24830, getitem_24831, getitem_24832, getitem_24833, getitem_24834, getitem_24835, getitem_24836, getitem_24837, getitem_24838, getitem_24839, getitem_24840, getitem_24841, getitem_24842, getitem_24843, getitem_24844, getitem_24845, getitem_24846, getitem_24847, getitem_24848, getitem_24849, getitem_24850, getitem_24851, getitem_24852, getitem_24853, getitem_24854, getitem_24855, getitem_24856, getitem_24857, getitem_24858, getitem_24859, getitem_24860, getitem_24861, getitem_24862, getitem_24863, getitem_24864, getitem_24865, getitem_24866, getitem_24867, constant_pad_nd_1681, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_24753 = getitem_24754 = getitem_24755 = getitem_24756 = getitem_24757 = getitem_24758 = getitem_24759 = getitem_24760 = getitem_24761 = getitem_24762 = getitem_24763 = getitem_24764 = getitem_24765 = getitem_24766 = getitem_24767 = getitem_24768 = getitem_24769 = getitem_24770 = getitem_24771 = getitem_24772 = getitem_24773 = getitem_24774 = getitem_24775 = getitem_24776 = getitem_24777 = getitem_24778 = getitem_24779 = getitem_24780 = getitem_24781 = getitem_24782 = getitem_24783 = getitem_24784 = getitem_24785 = getitem_24786 = getitem_24787 = getitem_24788 = getitem_24789 = getitem_24790 = getitem_24791 = getitem_24792 = getitem_24793 = getitem_24794 = getitem_24795 = getitem_24796 = getitem_24797 = getitem_24798 = getitem_24799 = getitem_24800 = getitem_24801 = getitem_24802 = getitem_24803 = getitem_24804 = getitem_24805 = getitem_24806 = getitem_24807 = getitem_24808 = getitem_24809 = getitem_24810 = getitem_24811 = getitem_24812 = getitem_24813 = getitem_24814 = getitem_24815 = getitem_24816 = getitem_24817 = getitem_24818 = getitem_24819 = getitem_24820 = getitem_24821 = getitem_24822 = getitem_24823 = getitem_24824 = getitem_24825 = getitem_24826 = getitem_24827 = getitem_24828 = getitem_24829 = getitem_24830 = getitem_24831 = getitem_24832 = getitem_24833 = getitem_24834 = getitem_24835 = getitem_24836 = getitem_24837 = getitem_24838 = getitem_24839 = getitem_24840 = getitem_24841 = getitem_24842 = getitem_24843 = getitem_24844 = getitem_24845 = getitem_24846 = getitem_24847 = getitem_24848 = getitem_24849 = getitem_24850 = getitem_24851 = getitem_24852 = getitem_24853 = getitem_24854 = getitem_24855 = getitem_24856 = getitem_24857 = getitem_24858 = getitem_24859 = getitem_24860 = getitem_24861 = getitem_24862 = getitem_24863 = getitem_24864 = getitem_24865 = getitem_24866 = getitem_24867 = constant_pad_nd_1681 = None + reduce_scatter_tensor_307 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_410, 'avg', 128, '0'); cat_410 = None + wait_tensor_910 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_307); reduce_scatter_tensor_307 = None + slice_292 = torch.ops.aten.slice.Tensor(permute_1493, 3, 0, 128) + slice_293 = torch.ops.aten.slice.Tensor(permute_1493, 3, 128, 192); permute_1493 = None + convert_element_type_3073 = torch.ops.prims.convert_element_type.default(slice_293, torch.float32); slice_293 = None + view_2195 = torch.ops.aten.view.default(convert_element_type_3073, [2, 4096, 16, 32, 2]); convert_element_type_3073 = None + view_as_complex_97 = torch.ops.aten.view_as_complex.default(view_2195); view_2195 = None + mul_2032 = torch.ops.aten.mul.Tensor(view_as_complex_97, clone_9); view_as_complex_97 = None + view_as_real_97 = torch.ops.aten.view_as_real.default(mul_2032); mul_2032 = None + view_2196 = torch.ops.aten.view.default(view_as_real_97, [2, 4096, 16, 64]); view_as_real_97 = None + convert_element_type_3074 = torch.ops.prims.convert_element_type.default(view_2196, torch.bfloat16); view_2196 = None + cat_411 = torch.ops.aten.cat.default([slice_292, convert_element_type_3074], 3); slice_292 = convert_element_type_3074 = None + view_2197 = torch.ops.aten.view.default(cat_411, [2, 4096, 3072]); cat_411 = None + view_2198 = torch.ops.aten.view.default(view_2197, [8192, 3072]); view_2197 = None + permute_1502 = torch.ops.aten.permute.default(view_2198, [1, 0]) + mm_568 = torch.ops.aten.mm.default(permute_1502, view_304); permute_1502 = view_304 = None + convert_element_type_256 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_256, 128, '0'); convert_element_type_256 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_71 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + permute_1504 = torch.ops.aten.permute.default(permute_71, [1, 0]); permute_71 = None + mm_569 = torch.ops.aten.mm.default(view_2198, permute_1504); view_2198 = permute_1504 = None + view_2199 = torch.ops.aten.view.default(mm_569, [2, 4096, 2048]); mm_569 = None + add_2103 = torch.ops.aten.add.Tensor(view_2194, view_2199); view_2194 = view_2199 = None + convert_element_type_3079 = torch.ops.prims.convert_element_type.default(mm_568, torch.float32); mm_568 = None + reduce_scatter_tensor_308 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3079, 'avg', 128, '0'); convert_element_type_3079 = None + wait_tensor_911 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_308); reduce_scatter_tensor_308 = None + convert_element_type_3080 = torch.ops.prims.convert_element_type.default(add_2103, torch.float32); add_2103 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 128, '0'); convert_element_type_253 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + convert_element_type_3082 = torch.ops.prims.convert_element_type.default(wait_tensor_95, torch.float32); wait_tensor_95 = None + mul_2033 = torch.ops.aten.mul.Tensor(convert_element_type_3080, convert_element_type_3082); convert_element_type_3082 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(add_277, torch.float32); add_277 = None + mul_205 = torch.ops.aten.mul.Tensor(convert_element_type_254, rsqrt_15); convert_element_type_254 = None + mul_2035 = torch.ops.aten.mul.Tensor(mul_205, mul_2033) + sum_281 = torch.ops.aten.sum.dim_IntList(mul_2035, [2], True); mul_2035 = None + div_263 = torch.ops.aten.div.Tensor(mul_205, 2048) + mul_2036 = torch.ops.aten.mul.Tensor(div_263, sum_281); div_263 = sum_281 = None + sub_756 = torch.ops.aten.sub.Tensor(mul_2033, mul_2036); mul_2033 = mul_2036 = None + mul_2037 = torch.ops.aten.mul.Tensor(sub_756, rsqrt_15); sub_756 = rsqrt_15 = None + mul_2038 = torch.ops.aten.mul.Tensor(convert_element_type_3080, mul_205); convert_element_type_3080 = mul_205 = None + sum_282 = torch.ops.aten.sum.dim_IntList(mul_2038, [0, 1]); mul_2038 = None + convert_element_type_3083 = torch.ops.prims.convert_element_type.default(mul_2037, torch.bfloat16); mul_2037 = None + add_2104 = torch.ops.aten.add.Tensor(add_2102, convert_element_type_3083); add_2102 = convert_element_type_3083 = None + convert_element_type_default_16 = torch.ops.prims.convert_element_type.default(sum_282, torch.float32); sum_282 = None + reduce_scatter_tensor_309 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_16, 'avg', 128, '0'); convert_element_type_default_16 = None + wait_tensor_912 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_309); reduce_scatter_tensor_309 = None + view_2200 = torch.ops.aten.view.default(add_2104, [8192, 2048]) + unsqueeze_75 = torch.ops.aten.unsqueeze.default(view_2200, 1) + convert_element_type_3086 = torch.ops.prims.convert_element_type.default(unsqueeze_75, torch.float32); unsqueeze_75 = None + bmm_70 = torch.ops.aten.bmm.default(permute_1506, convert_element_type_3086); permute_1506 = None + bmm_71 = torch.ops.aten.bmm.default(convert_element_type_3086, permute_1507); convert_element_type_3086 = permute_1507 = None + convert_element_type_3087 = torch.ops.prims.convert_element_type.default(bmm_70, torch.bfloat16); bmm_70 = None + view_2201 = torch.ops.aten.view.default(bmm_71, [8192, 6]); bmm_71 = None + view_2202 = torch.ops.aten.view.default(convert_element_type_3087, [49152, 2048]); convert_element_type_3087 = None + index_96 = torch.ops.aten.index.Tensor(view_2202, [getitem_351]); view_2202 = getitem_351 = None + permute_1508 = torch.ops.aten.permute.default(view_2200, [1, 0]) + mm_570 = torch.ops.aten.mm.default(permute_1508, mul_202); permute_1508 = mul_202 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 128, '0'); convert_element_type_248 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + permute_70 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + permute_1510 = torch.ops.aten.permute.default(permute_70, [1, 0]); permute_70 = None + mm_571 = torch.ops.aten.mm.default(view_2200, permute_1510); view_2200 = permute_1510 = None + convert_element_type_3092 = torch.ops.prims.convert_element_type.default(mm_570, torch.float32); mm_570 = None + reduce_scatter_tensor_310 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3092, 'avg', 128, '0'); convert_element_type_3092 = None + wait_tensor_913 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_310); reduce_scatter_tensor_310 = None + convert_element_type_243 = torch.ops.prims.convert_element_type.default(mm_36, torch.float32); mm_36 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_243) + exp_12 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_272 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + div_20 = torch.ops.aten.div.Tensor(convert_element_type_243, add_272) + convert_element_type_244 = torch.ops.prims.convert_element_type.default(div_20, torch.bfloat16); div_20 = None + mul_2039 = torch.ops.aten.mul.Tensor(mm_571, convert_element_type_244); convert_element_type_244 = None + mul_2040 = torch.ops.aten.mul.Tensor(mm_571, mm_37); mm_571 = mm_37 = None + permute_1512 = torch.ops.aten.permute.default(mul_2039, [1, 0]) + mm_572 = torch.ops.aten.mm.default(permute_1512, view_259); permute_1512 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_245, 128, '0'); convert_element_type_245 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_69 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + permute_1514 = torch.ops.aten.permute.default(permute_69, [1, 0]); permute_69 = None + mm_573 = torch.ops.aten.mm.default(mul_2039, permute_1514); mul_2039 = permute_1514 = None + convert_element_type_3097 = torch.ops.prims.convert_element_type.default(mm_572, torch.float32); mm_572 = None + reduce_scatter_tensor_311 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3097, 'avg', 128, '0'); convert_element_type_3097 = None + wait_tensor_914 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_311); reduce_scatter_tensor_311 = None + convert_element_type_3098 = torch.ops.prims.convert_element_type.default(mul_2040, torch.float32); mul_2040 = None + reciprocal_44 = torch.ops.aten.reciprocal.default(add_272); add_272 = None + mul_2041 = torch.ops.aten.mul.Tensor(reciprocal_44, 1); reciprocal_44 = None + mul_2042 = torch.ops.aten.mul.Tensor(convert_element_type_3098, mul_2041); convert_element_type_3098 = None + sub_757 = torch.ops.aten.sub.Tensor(1, mul_2041); mul_2041 = None + mul_2043 = torch.ops.aten.mul.Tensor(convert_element_type_243, sub_757); convert_element_type_243 = sub_757 = None + add_2106 = torch.ops.aten.add.Tensor(mul_2043, 1); mul_2043 = None + mul_2044 = torch.ops.aten.mul.Tensor(mul_2042, add_2106); mul_2042 = add_2106 = None + convert_element_type_3100 = torch.ops.prims.convert_element_type.default(mul_2044, torch.bfloat16); mul_2044 = None + permute_1516 = torch.ops.aten.permute.default(convert_element_type_3100, [1, 0]) + mm_574 = torch.ops.aten.mm.default(permute_1516, view_259); permute_1516 = None + convert_element_type_240 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_240, 128, '0'); convert_element_type_240 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + permute_1518 = torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_575 = torch.ops.aten.mm.default(convert_element_type_3100, permute_1518); convert_element_type_3100 = permute_1518 = None + add_2107 = torch.ops.aten.add.Tensor(mm_573, mm_575); mm_573 = mm_575 = None + convert_element_type_3105 = torch.ops.prims.convert_element_type.default(mm_574, torch.float32); mm_574 = None + reduce_scatter_tensor_312 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3105, 'avg', 128, '0'); convert_element_type_3105 = None + wait_tensor_915 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_312); reduce_scatter_tensor_312 = None + all_to_all_single_122 = torch.ops._c10d_functional.all_to_all_single.default(index_96, [_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63], [_local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55], '1033'); index_96 = None + wait_tensor_916 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_122); all_to_all_single_122 = None + full_480 = torch.ops.aten.full.default([sym_size_int_13, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_13 = None + slice_scatter_22 = torch.ops.aten.slice_scatter.default(full_480, wait_tensor_916, 0, 0, -1); wait_tensor_916 = None + index_97 = torch.ops.aten.index.Tensor(slice_scatter_22, [getitem_352]); slice_scatter_22 = None + permute_1520 = torch.ops.aten.permute.default(index_97, [1, 0]) + _grouped_mm_210 = torch.ops.aten._grouped_mm.default(permute_1520, mul_182, cumsum_11); permute_1520 = mul_182 = None + _grouped_mm_211 = torch.ops.aten._grouped_mm.default(index_97, permute_1522, cumsum_11); index_97 = permute_1522 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(_grouped_mm_9, torch.float32); _grouped_mm_9 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_238) + exp_11 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_236 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + div_19 = torch.ops.aten.div.Tensor(convert_element_type_238, add_236) + convert_element_type_239 = torch.ops.prims.convert_element_type.default(div_19, torch.bfloat16); div_19 = None + mul_2045 = torch.ops.aten.mul.Tensor(_grouped_mm_211, convert_element_type_239); convert_element_type_239 = None + mul_2046 = torch.ops.aten.mul.Tensor(_grouped_mm_211, _grouped_mm_10); _grouped_mm_211 = _grouped_mm_10 = None + permute_1524 = torch.ops.aten.permute.default(mul_2045, [1, 0]) + _grouped_mm_212 = torch.ops.aten._grouped_mm.default(permute_1524, index_7, cumsum_11); permute_1524 = None + _grouped_mm_213 = torch.ops.aten._grouped_mm.default(mul_2045, permute_1526, cumsum_11); mul_2045 = permute_1526 = None + convert_element_type_3106 = torch.ops.prims.convert_element_type.default(mul_2046, torch.float32); mul_2046 = None + reciprocal_45 = torch.ops.aten.reciprocal.default(add_236); add_236 = None + mul_2047 = torch.ops.aten.mul.Tensor(reciprocal_45, 1); reciprocal_45 = None + mul_2048 = torch.ops.aten.mul.Tensor(convert_element_type_3106, mul_2047); convert_element_type_3106 = None + sub_758 = torch.ops.aten.sub.Tensor(1, mul_2047); mul_2047 = None + mul_2049 = torch.ops.aten.mul.Tensor(convert_element_type_238, sub_758); convert_element_type_238 = sub_758 = None + add_2109 = torch.ops.aten.add.Tensor(mul_2049, 1); mul_2049 = None + mul_2050 = torch.ops.aten.mul.Tensor(mul_2048, add_2109); mul_2048 = add_2109 = None + convert_element_type_3108 = torch.ops.prims.convert_element_type.default(mul_2050, torch.bfloat16); mul_2050 = None + permute_1528 = torch.ops.aten.permute.default(convert_element_type_3108, [1, 0]) + _grouped_mm_214 = torch.ops.aten._grouped_mm.default(permute_1528, index_7, cumsum_11); permute_1528 = index_7 = None + _grouped_mm_215 = torch.ops.aten._grouped_mm.default(convert_element_type_3108, permute_1530, cumsum_11); convert_element_type_3108 = permute_1530 = cumsum_11 = None + add_2110 = torch.ops.aten.add.Tensor(_grouped_mm_213, _grouped_mm_215); _grouped_mm_213 = _grouped_mm_215 = None + convert_element_type_3109 = torch.ops.prims.convert_element_type.default(_grouped_mm_212, torch.float32); _grouped_mm_212 = None + div_264 = torch.ops.aten.div.Tensor(convert_element_type_3109, 128); convert_element_type_3109 = None + split_1323 = torch.ops.aten.split.Tensor(div_264, 88, 1); div_264 = None + getitem_24885 = split_1323[0] + getitem_24902 = split_1323[1] + getitem_24919 = split_1323[2] + getitem_24936 = split_1323[3] + getitem_24953 = split_1323[4] + getitem_24970 = split_1323[5] + getitem_24987 = split_1323[6] + getitem_25004 = split_1323[7] + getitem_25021 = split_1323[8] + getitem_25038 = split_1323[9] + getitem_25055 = split_1323[10] + getitem_25072 = split_1323[11] + getitem_25089 = split_1323[12] + getitem_25106 = split_1323[13] + getitem_25123 = split_1323[14] + getitem_25140 = split_1323[15]; split_1323 = None + cat_412 = torch.ops.aten.cat.default([getitem_24885, getitem_24902, getitem_24919, getitem_24936, getitem_24953, getitem_24970, getitem_24987, getitem_25004, getitem_25021, getitem_25038, getitem_25055, getitem_25072, getitem_25089, getitem_25106, getitem_25123, getitem_25140]); getitem_24885 = getitem_24902 = getitem_24919 = getitem_24936 = getitem_24953 = getitem_24970 = getitem_24987 = getitem_25004 = getitem_25021 = getitem_25038 = getitem_25055 = getitem_25072 = getitem_25089 = getitem_25106 = getitem_25123 = getitem_25140 = None + reduce_scatter_tensor_313 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_412, 'sum', 16, '1025'); cat_412 = None + wait_tensor_917 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_313); reduce_scatter_tensor_313 = None + convert_element_type_3110 = torch.ops.prims.convert_element_type.default(_grouped_mm_210, torch.float32); _grouped_mm_210 = None + div_265 = torch.ops.aten.div.Tensor(convert_element_type_3110, 128); convert_element_type_3110 = None + split_1340 = torch.ops.aten.split.Tensor(div_265, 128, 1); div_265 = None + getitem_25157 = split_1340[0] + getitem_25174 = split_1340[1] + getitem_25191 = split_1340[2] + getitem_25208 = split_1340[3] + getitem_25225 = split_1340[4] + getitem_25242 = split_1340[5] + getitem_25259 = split_1340[6] + getitem_25276 = split_1340[7] + getitem_25293 = split_1340[8] + getitem_25310 = split_1340[9] + getitem_25327 = split_1340[10] + getitem_25344 = split_1340[11] + getitem_25361 = split_1340[12] + getitem_25378 = split_1340[13] + getitem_25395 = split_1340[14] + getitem_25412 = split_1340[15]; split_1340 = None + cat_413 = torch.ops.aten.cat.default([getitem_25157, getitem_25174, getitem_25191, getitem_25208, getitem_25225, getitem_25242, getitem_25259, getitem_25276, getitem_25293, getitem_25310, getitem_25327, getitem_25344, getitem_25361, getitem_25378, getitem_25395, getitem_25412]); getitem_25157 = getitem_25174 = getitem_25191 = getitem_25208 = getitem_25225 = getitem_25242 = getitem_25259 = getitem_25276 = getitem_25293 = getitem_25310 = getitem_25327 = getitem_25344 = getitem_25361 = getitem_25378 = getitem_25395 = getitem_25412 = None + reduce_scatter_tensor_314 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_413, 'sum', 16, '1025'); cat_413 = None + wait_tensor_918 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_314); reduce_scatter_tensor_314 = None + convert_element_type_3111 = torch.ops.prims.convert_element_type.default(_grouped_mm_214, torch.float32); _grouped_mm_214 = None + div_266 = torch.ops.aten.div.Tensor(convert_element_type_3111, 128); convert_element_type_3111 = None + split_1357 = torch.ops.aten.split.Tensor(div_266, 88, 1); div_266 = None + getitem_25429 = split_1357[0] + getitem_25446 = split_1357[1] + getitem_25463 = split_1357[2] + getitem_25480 = split_1357[3] + getitem_25497 = split_1357[4] + getitem_25514 = split_1357[5] + getitem_25531 = split_1357[6] + getitem_25548 = split_1357[7] + getitem_25565 = split_1357[8] + getitem_25582 = split_1357[9] + getitem_25599 = split_1357[10] + getitem_25616 = split_1357[11] + getitem_25633 = split_1357[12] + getitem_25650 = split_1357[13] + getitem_25667 = split_1357[14] + getitem_25684 = split_1357[15]; split_1357 = None + cat_414 = torch.ops.aten.cat.default([getitem_25429, getitem_25446, getitem_25463, getitem_25480, getitem_25497, getitem_25514, getitem_25531, getitem_25548, getitem_25565, getitem_25582, getitem_25599, getitem_25616, getitem_25633, getitem_25650, getitem_25667, getitem_25684]); getitem_25429 = getitem_25446 = getitem_25463 = getitem_25480 = getitem_25497 = getitem_25514 = getitem_25531 = getitem_25548 = getitem_25565 = getitem_25582 = getitem_25599 = getitem_25616 = getitem_25633 = getitem_25650 = getitem_25667 = getitem_25684 = None + reduce_scatter_tensor_315 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_414, 'sum', 16, '1025'); cat_414 = None + wait_tensor_919 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_315); reduce_scatter_tensor_315 = None + index_put_96 = torch.ops.aten.index_put.default(full_480, [getitem_352], add_2110, True); full_480 = getitem_352 = add_2110 = None + slice_294 = torch.ops.aten.slice.Tensor(index_put_96, 0, 0, add_2111); index_put_96 = add_2111 = None + all_to_all_single_123 = torch.ops._c10d_functional.all_to_all_single.default(slice_294, [_local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55], [_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63], '1033'); slice_294 = _local_scalar_dense_48 = _local_scalar_dense_49 = _local_scalar_dense_50 = _local_scalar_dense_51 = _local_scalar_dense_52 = _local_scalar_dense_53 = _local_scalar_dense_54 = _local_scalar_dense_55 = _local_scalar_dense_56 = _local_scalar_dense_57 = _local_scalar_dense_58 = _local_scalar_dense_59 = _local_scalar_dense_60 = _local_scalar_dense_61 = _local_scalar_dense_62 = _local_scalar_dense_63 = None + wait_tensor_920 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_123); all_to_all_single_123 = None + index_put_97 = torch.ops.aten.index_put.default(full_default_52, [div_17], wait_tensor_920, True); div_17 = wait_tensor_920 = None + add_2115 = torch.ops.aten.add.Tensor(add_2107, index_put_97); add_2107 = index_put_97 = None + mul_2051 = torch.ops.aten.mul.Tensor(view_2201, 1.0); view_2201 = None + scatter_add_22 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_349, mul_2051); getitem_349 = mul_2051 = None + convert_element_type_227 = torch.ops.prims.convert_element_type.default(mm_35, torch.float32); mm_35 = None + sub_72 = torch.ops.aten.sub.Tensor(convert_element_type_227, amax_3); convert_element_type_227 = amax_3 = None + exp_10 = torch.ops.aten.exp.default(sub_72); sub_72 = None + div_16 = torch.ops.aten.div.Tensor(exp_10, sum_13); exp_10 = sum_13 = None + mul_2052 = torch.ops.aten.mul.Tensor(scatter_add_22, div_16); scatter_add_22 = None + sum_283 = torch.ops.aten.sum.dim_IntList(mul_2052, [1], True) + neg_121 = torch.ops.aten.neg.default(div_16); div_16 = None + fma_22 = torch.ops.prims.fma.default(neg_121, sum_283, mul_2052); neg_121 = sum_283 = mul_2052 = None + convert_element_type_3112 = torch.ops.prims.convert_element_type.default(fma_22, torch.bfloat16); fma_22 = None + permute_1532 = torch.ops.aten.permute.default(convert_element_type_3112, [1, 0]) + mm_576 = torch.ops.aten.mm.default(permute_1532, view_259); permute_1532 = view_259 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); primals_79 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_224, 128, '0'); convert_element_type_224 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + slice_27 = torch.ops.aten.slice.Tensor(wait_tensor_81, 0, 0, 64); wait_tensor_81 = None + permute_64 = torch.ops.aten.permute.default(slice_27, [1, 0]); slice_27 = None + permute_1534 = torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_577 = torch.ops.aten.mm.default(convert_element_type_3112, permute_1534); convert_element_type_3112 = permute_1534 = None + add_2116 = torch.ops.aten.add.Tensor(add_2115, mm_577); add_2115 = mm_577 = None + convert_element_type_3117 = torch.ops.prims.convert_element_type.default(mm_576, torch.float32); mm_576 = None + split_1373 = torch.ops.aten.split.Tensor(convert_element_type_3117, 1); convert_element_type_3117 = None + getitem_25685 = split_1373[0] + getitem_25686 = split_1373[1] + getitem_25687 = split_1373[2] + getitem_25688 = split_1373[3] + getitem_25689 = split_1373[4] + getitem_25690 = split_1373[5] + getitem_25691 = split_1373[6] + getitem_25692 = split_1373[7] + getitem_25693 = split_1373[8] + getitem_25694 = split_1373[9] + getitem_25695 = split_1373[10] + getitem_25696 = split_1373[11] + getitem_25697 = split_1373[12] + getitem_25698 = split_1373[13] + getitem_25699 = split_1373[14] + getitem_25700 = split_1373[15] + getitem_25701 = split_1373[16] + getitem_25702 = split_1373[17] + getitem_25703 = split_1373[18] + getitem_25704 = split_1373[19] + getitem_25705 = split_1373[20] + getitem_25706 = split_1373[21] + getitem_25707 = split_1373[22] + getitem_25708 = split_1373[23] + getitem_25709 = split_1373[24] + getitem_25710 = split_1373[25] + getitem_25711 = split_1373[26] + getitem_25712 = split_1373[27] + getitem_25713 = split_1373[28] + getitem_25714 = split_1373[29] + getitem_25715 = split_1373[30] + getitem_25716 = split_1373[31] + getitem_25717 = split_1373[32] + getitem_25718 = split_1373[33] + getitem_25719 = split_1373[34] + getitem_25720 = split_1373[35] + getitem_25721 = split_1373[36] + getitem_25722 = split_1373[37] + getitem_25723 = split_1373[38] + getitem_25724 = split_1373[39] + getitem_25725 = split_1373[40] + getitem_25726 = split_1373[41] + getitem_25727 = split_1373[42] + getitem_25728 = split_1373[43] + getitem_25729 = split_1373[44] + getitem_25730 = split_1373[45] + getitem_25731 = split_1373[46] + getitem_25732 = split_1373[47] + getitem_25733 = split_1373[48] + getitem_25734 = split_1373[49] + getitem_25735 = split_1373[50] + getitem_25736 = split_1373[51] + getitem_25737 = split_1373[52] + getitem_25738 = split_1373[53] + getitem_25739 = split_1373[54] + getitem_25740 = split_1373[55] + getitem_25741 = split_1373[56] + getitem_25742 = split_1373[57] + getitem_25743 = split_1373[58] + getitem_25744 = split_1373[59] + getitem_25745 = split_1373[60] + getitem_25746 = split_1373[61] + getitem_25747 = split_1373[62] + getitem_25748 = split_1373[63]; split_1373 = None + cat_415 = torch.ops.aten.cat.default([getitem_25685, getitem_25686, getitem_25687, getitem_25688, getitem_25689, getitem_25690, getitem_25691, getitem_25692, getitem_25693, getitem_25694, getitem_25695, getitem_25696, getitem_25697, getitem_25698, getitem_25699, getitem_25700, getitem_25701, getitem_25702, getitem_25703, getitem_25704, getitem_25705, getitem_25706, getitem_25707, getitem_25708, getitem_25709, getitem_25710, getitem_25711, getitem_25712, getitem_25713, getitem_25714, getitem_25715, getitem_25716, getitem_25717, getitem_25718, getitem_25719, getitem_25720, getitem_25721, getitem_25722, getitem_25723, getitem_25724, getitem_25725, getitem_25726, getitem_25727, getitem_25728, getitem_25729, getitem_25730, getitem_25731, getitem_25732, getitem_25733, getitem_25734, getitem_25735, getitem_25736, getitem_25737, getitem_25738, getitem_25739, getitem_25740, getitem_25741, getitem_25742, getitem_25743, getitem_25744, getitem_25745, getitem_25746, getitem_25747, getitem_25748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_25685 = getitem_25686 = getitem_25687 = getitem_25688 = getitem_25689 = getitem_25690 = getitem_25691 = getitem_25692 = getitem_25693 = getitem_25694 = getitem_25695 = getitem_25696 = getitem_25697 = getitem_25698 = getitem_25699 = getitem_25700 = getitem_25701 = getitem_25702 = getitem_25703 = getitem_25704 = getitem_25705 = getitem_25706 = getitem_25707 = getitem_25708 = getitem_25709 = getitem_25710 = getitem_25711 = getitem_25712 = getitem_25713 = getitem_25714 = getitem_25715 = getitem_25716 = getitem_25717 = getitem_25718 = getitem_25719 = getitem_25720 = getitem_25721 = getitem_25722 = getitem_25723 = getitem_25724 = getitem_25725 = getitem_25726 = getitem_25727 = getitem_25728 = getitem_25729 = getitem_25730 = getitem_25731 = getitem_25732 = getitem_25733 = getitem_25734 = getitem_25735 = getitem_25736 = getitem_25737 = getitem_25738 = getitem_25739 = getitem_25740 = getitem_25741 = getitem_25742 = getitem_25743 = getitem_25744 = getitem_25745 = getitem_25746 = getitem_25747 = getitem_25748 = None + reduce_scatter_tensor_316 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_415, 'avg', 128, '0'); cat_415 = None + wait_tensor_921 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_316); reduce_scatter_tensor_316 = None + view_2203 = torch.ops.aten.view.default(add_2116, [2, 4096, 2048]); add_2116 = None + convert_element_type_3118 = torch.ops.prims.convert_element_type.default(view_2203, torch.float32); view_2203 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 128, '0'); convert_element_type_221 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + convert_element_type_3120 = torch.ops.prims.convert_element_type.default(wait_tensor_80, torch.float32); wait_tensor_80 = None + mul_2053 = torch.ops.aten.mul.Tensor(convert_element_type_3118, convert_element_type_3120); convert_element_type_3120 = None + convert_element_type_222 = torch.ops.prims.convert_element_type.default(add_212, torch.float32); add_212 = None + mul_162 = torch.ops.aten.mul.Tensor(convert_element_type_222, rsqrt_14); convert_element_type_222 = None + mul_2055 = torch.ops.aten.mul.Tensor(mul_162, mul_2053) + sum_284 = torch.ops.aten.sum.dim_IntList(mul_2055, [2], True); mul_2055 = None + div_267 = torch.ops.aten.div.Tensor(mul_162, 2048) + mul_2056 = torch.ops.aten.mul.Tensor(div_267, sum_284); div_267 = sum_284 = None + sub_760 = torch.ops.aten.sub.Tensor(mul_2053, mul_2056); mul_2053 = mul_2056 = None + mul_2057 = torch.ops.aten.mul.Tensor(sub_760, rsqrt_14); sub_760 = rsqrt_14 = None + mul_2058 = torch.ops.aten.mul.Tensor(convert_element_type_3118, mul_162); convert_element_type_3118 = mul_162 = None + sum_285 = torch.ops.aten.sum.dim_IntList(mul_2058, [0, 1]); mul_2058 = None + convert_element_type_3121 = torch.ops.prims.convert_element_type.default(mul_2057, torch.bfloat16); mul_2057 = None + add_2117 = torch.ops.aten.add.Tensor(add_2104, convert_element_type_3121); add_2104 = convert_element_type_3121 = None + convert_element_type_default_15 = torch.ops.prims.convert_element_type.default(sum_285, torch.float32); sum_285 = None + reduce_scatter_tensor_317 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_15, 'avg', 128, '0'); convert_element_type_default_15 = None + wait_tensor_922 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_317); reduce_scatter_tensor_317 = None + view_2204 = torch.ops.aten.view.default(add_2117, [8192, 2048]) + permute_1536 = torch.ops.aten.permute.default(view_2204, [1, 0]) + permute_62 = torch.ops.aten.permute.default(getitem_345, [0, 2, 1, 3]) + view_254 = torch.ops.aten.view.default(permute_62, [2, 4096, -1]); permute_62 = None + view_256 = torch.ops.aten.view.default(view_254, [8192, 2048]); view_254 = None + mm_578 = torch.ops.aten.mm.default(permute_1536, view_256); permute_1536 = view_256 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16); primals_76 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 128, '0'); convert_element_type_218 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + permute_1538 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_579 = torch.ops.aten.mm.default(view_2204, permute_1538); view_2204 = permute_1538 = None + view_2205 = torch.ops.aten.view.default(mm_579, [2, 4096, 2048]); mm_579 = None + convert_element_type_3128 = torch.ops.prims.convert_element_type.default(mm_578, torch.float32); mm_578 = None + reduce_scatter_tensor_318 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3128, 'avg', 128, '0'); convert_element_type_3128 = None + wait_tensor_923 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_318); reduce_scatter_tensor_318 = None + view_2206 = torch.ops.aten.view.default(view_2205, [2, 4096, 16, 128]); view_2205 = None + permute_1540 = torch.ops.aten.permute.default(view_2206, [0, 2, 1, 3]); view_2206 = None + fw_graph22 = self.fw_graph22 + joint_graph22 = self.joint_graph22 + mask_graph22 = self.mask_graph22 + flex_attention_backward_22 = torch.ops.higher_order.flex_attention_backward(permute_59, permute_60, permute_61, getitem_345, getitem_346, permute_1540, None, fw_graph22, joint_graph22, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph22), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_59 = permute_60 = permute_61 = getitem_345 = getitem_346 = permute_1540 = fw_graph22 = joint_graph22 = mask_graph22 = None + getitem_25749 = flex_attention_backward_22[0] + getitem_25750 = flex_attention_backward_22[1] + getitem_25751 = flex_attention_backward_22[2]; flex_attention_backward_22 = None + permute_1541 = torch.ops.aten.permute.default(getitem_25751, [0, 2, 1, 3]); getitem_25751 = None + permute_1542 = torch.ops.aten.permute.default(getitem_25750, [0, 2, 1, 3]); getitem_25750 = None + permute_1543 = torch.ops.aten.permute.default(getitem_25749, [0, 2, 1, 3]); getitem_25749 = None + slice_296 = torch.ops.aten.slice.Tensor(permute_1542, 3, 0, 128) + slice_297 = torch.ops.aten.slice.Tensor(permute_1542, 3, 128, 192); permute_1542 = None + sum_286 = torch.ops.aten.sum.dim_IntList(slice_297, [2], True); slice_297 = None + cat_416 = torch.ops.aten.cat.default([slice_296, permute_1541], 3); slice_296 = permute_1541 = None + view_2207 = torch.ops.aten.view.default(cat_416, [2, 4096, 4096]); cat_416 = None + view_2208 = torch.ops.aten.view.default(view_2207, [8192, 4096]); view_2207 = None + permute_1544 = torch.ops.aten.permute.default(view_2208, [1, 0]) + mm_580 = torch.ops.aten.mm.default(permute_1544, view_251); permute_1544 = view_251 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 128, '0'); convert_element_type_215 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_58 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + permute_1546 = torch.ops.aten.permute.default(permute_58, [1, 0]); permute_58 = None + mm_581 = torch.ops.aten.mm.default(view_2208, permute_1546); view_2208 = permute_1546 = None + view_2209 = torch.ops.aten.view.default(mm_581, [2, 4096, 512]); mm_581 = None + convert_element_type_3133 = torch.ops.prims.convert_element_type.default(mm_580, torch.float32); mm_580 = None + reduce_scatter_tensor_319 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3133, 'avg', 128, '0'); convert_element_type_3133 = None + wait_tensor_924 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_319); reduce_scatter_tensor_319 = None + convert_element_type_3134 = torch.ops.prims.convert_element_type.default(view_2209, torch.float32); view_2209 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_212, 128, '0'); convert_element_type_212 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + convert_element_type_3136 = torch.ops.prims.convert_element_type.default(wait_tensor_77, torch.float32); wait_tensor_77 = None + mul_2059 = torch.ops.aten.mul.Tensor(convert_element_type_3134, convert_element_type_3136); convert_element_type_3136 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(getitem_341, torch.float32); getitem_341 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_213, rsqrt_13); convert_element_type_213 = None + mul_2061 = torch.ops.aten.mul.Tensor(mul_160, mul_2059) + sum_287 = torch.ops.aten.sum.dim_IntList(mul_2061, [2], True); mul_2061 = None + div_268 = torch.ops.aten.div.Tensor(mul_160, 512) + mul_2062 = torch.ops.aten.mul.Tensor(div_268, sum_287); div_268 = sum_287 = None + sub_761 = torch.ops.aten.sub.Tensor(mul_2059, mul_2062); mul_2059 = mul_2062 = None + mul_2063 = torch.ops.aten.mul.Tensor(sub_761, rsqrt_13); sub_761 = rsqrt_13 = None + mul_2064 = torch.ops.aten.mul.Tensor(convert_element_type_3134, mul_160); convert_element_type_3134 = mul_160 = None + sum_288 = torch.ops.aten.sum.dim_IntList(mul_2064, [0, 1]); mul_2064 = None + convert_element_type_3137 = torch.ops.prims.convert_element_type.default(mul_2063, torch.bfloat16); mul_2063 = None + convert_element_type_default_14 = torch.ops.prims.convert_element_type.default(sum_288, torch.float32); sum_288 = None + reduce_scatter_tensor_320 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_14, 'avg', 128, '0'); convert_element_type_default_14 = None + wait_tensor_925 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_320); reduce_scatter_tensor_320 = None + convert_element_type_3140 = torch.ops.prims.convert_element_type.default(sum_286, torch.float32); sum_286 = None + view_2210 = torch.ops.aten.view.default(convert_element_type_3140, [2, 4096, 1, 32, 2]); convert_element_type_3140 = None + view_as_complex_98 = torch.ops.aten.view_as_complex.default(view_2210); view_2210 = None + mul_2065 = torch.ops.aten.mul.Tensor(view_as_complex_98, clone_9); view_as_complex_98 = None + view_as_real_98 = torch.ops.aten.view_as_real.default(mul_2065); mul_2065 = None + view_2211 = torch.ops.aten.view.default(view_as_real_98, [2, 4096, 1, 64]); view_as_real_98 = None + convert_element_type_3141 = torch.ops.prims.convert_element_type.default(view_2211, torch.bfloat16); view_2211 = None + squeeze_48 = torch.ops.aten.squeeze.dim(convert_element_type_3141, 2); convert_element_type_3141 = None + cat_417 = torch.ops.aten.cat.default([convert_element_type_3137, squeeze_48], 2); convert_element_type_3137 = squeeze_48 = None + view_2212 = torch.ops.aten.view.default(cat_417, [8192, 576]); cat_417 = None + permute_1548 = torch.ops.aten.permute.default(view_2212, [1, 0]) + mm_582 = torch.ops.aten.mm.default(permute_1548, view_237); permute_1548 = None + convert_element_type_207 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_207, 128, '0'); convert_element_type_207 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + slice_25 = torch.ops.aten.slice.Tensor(wait_tensor_76, 0, 0, 576); wait_tensor_76 = None + permute_57 = torch.ops.aten.permute.default(slice_25, [1, 0]); slice_25 = None + permute_1550 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_583 = torch.ops.aten.mm.default(view_2212, permute_1550); view_2212 = permute_1550 = None + view_2213 = torch.ops.aten.view.default(mm_583, [2, 4096, 2048]); mm_583 = None + convert_element_type_3146 = torch.ops.prims.convert_element_type.default(mm_582, torch.float32); mm_582 = None + split_1374 = torch.ops.aten.split.Tensor(convert_element_type_3146, 5); convert_element_type_3146 = None + getitem_25753 = split_1374[0] + getitem_25754 = split_1374[1] + getitem_25755 = split_1374[2] + getitem_25756 = split_1374[3] + getitem_25757 = split_1374[4] + getitem_25758 = split_1374[5] + getitem_25759 = split_1374[6] + getitem_25760 = split_1374[7] + getitem_25761 = split_1374[8] + getitem_25762 = split_1374[9] + getitem_25763 = split_1374[10] + getitem_25764 = split_1374[11] + getitem_25765 = split_1374[12] + getitem_25766 = split_1374[13] + getitem_25767 = split_1374[14] + getitem_25768 = split_1374[15] + getitem_25769 = split_1374[16] + getitem_25770 = split_1374[17] + getitem_25771 = split_1374[18] + getitem_25772 = split_1374[19] + getitem_25773 = split_1374[20] + getitem_25774 = split_1374[21] + getitem_25775 = split_1374[22] + getitem_25776 = split_1374[23] + getitem_25777 = split_1374[24] + getitem_25778 = split_1374[25] + getitem_25779 = split_1374[26] + getitem_25780 = split_1374[27] + getitem_25781 = split_1374[28] + getitem_25782 = split_1374[29] + getitem_25783 = split_1374[30] + getitem_25784 = split_1374[31] + getitem_25785 = split_1374[32] + getitem_25786 = split_1374[33] + getitem_25787 = split_1374[34] + getitem_25788 = split_1374[35] + getitem_25789 = split_1374[36] + getitem_25790 = split_1374[37] + getitem_25791 = split_1374[38] + getitem_25792 = split_1374[39] + getitem_25793 = split_1374[40] + getitem_25794 = split_1374[41] + getitem_25795 = split_1374[42] + getitem_25796 = split_1374[43] + getitem_25797 = split_1374[44] + getitem_25798 = split_1374[45] + getitem_25799 = split_1374[46] + getitem_25800 = split_1374[47] + getitem_25801 = split_1374[48] + getitem_25802 = split_1374[49] + getitem_25803 = split_1374[50] + getitem_25804 = split_1374[51] + getitem_25805 = split_1374[52] + getitem_25806 = split_1374[53] + getitem_25807 = split_1374[54] + getitem_25808 = split_1374[55] + getitem_25809 = split_1374[56] + getitem_25810 = split_1374[57] + getitem_25811 = split_1374[58] + getitem_25812 = split_1374[59] + getitem_25813 = split_1374[60] + getitem_25814 = split_1374[61] + getitem_25815 = split_1374[62] + getitem_25816 = split_1374[63] + getitem_25817 = split_1374[64] + getitem_25818 = split_1374[65] + getitem_25819 = split_1374[66] + getitem_25820 = split_1374[67] + getitem_25821 = split_1374[68] + getitem_25822 = split_1374[69] + getitem_25823 = split_1374[70] + getitem_25824 = split_1374[71] + getitem_25825 = split_1374[72] + getitem_25826 = split_1374[73] + getitem_25827 = split_1374[74] + getitem_25828 = split_1374[75] + getitem_25829 = split_1374[76] + getitem_25830 = split_1374[77] + getitem_25831 = split_1374[78] + getitem_25832 = split_1374[79] + getitem_25833 = split_1374[80] + getitem_25834 = split_1374[81] + getitem_25835 = split_1374[82] + getitem_25836 = split_1374[83] + getitem_25837 = split_1374[84] + getitem_25838 = split_1374[85] + getitem_25839 = split_1374[86] + getitem_25840 = split_1374[87] + getitem_25841 = split_1374[88] + getitem_25842 = split_1374[89] + getitem_25843 = split_1374[90] + getitem_25844 = split_1374[91] + getitem_25845 = split_1374[92] + getitem_25846 = split_1374[93] + getitem_25847 = split_1374[94] + getitem_25848 = split_1374[95] + getitem_25849 = split_1374[96] + getitem_25850 = split_1374[97] + getitem_25851 = split_1374[98] + getitem_25852 = split_1374[99] + getitem_25853 = split_1374[100] + getitem_25854 = split_1374[101] + getitem_25855 = split_1374[102] + getitem_25856 = split_1374[103] + getitem_25857 = split_1374[104] + getitem_25858 = split_1374[105] + getitem_25859 = split_1374[106] + getitem_25860 = split_1374[107] + getitem_25861 = split_1374[108] + getitem_25862 = split_1374[109] + getitem_25863 = split_1374[110] + getitem_25864 = split_1374[111] + getitem_25865 = split_1374[112] + getitem_25866 = split_1374[113] + getitem_25867 = split_1374[114] + getitem_25868 = split_1374[115]; split_1374 = None + constant_pad_nd_1758 = torch.ops.aten.constant_pad_nd.default(getitem_25868, [0, 0, 0, 4], 0.0); getitem_25868 = None + cat_418 = torch.ops.aten.cat.default([getitem_25753, getitem_25754, getitem_25755, getitem_25756, getitem_25757, getitem_25758, getitem_25759, getitem_25760, getitem_25761, getitem_25762, getitem_25763, getitem_25764, getitem_25765, getitem_25766, getitem_25767, getitem_25768, getitem_25769, getitem_25770, getitem_25771, getitem_25772, getitem_25773, getitem_25774, getitem_25775, getitem_25776, getitem_25777, getitem_25778, getitem_25779, getitem_25780, getitem_25781, getitem_25782, getitem_25783, getitem_25784, getitem_25785, getitem_25786, getitem_25787, getitem_25788, getitem_25789, getitem_25790, getitem_25791, getitem_25792, getitem_25793, getitem_25794, getitem_25795, getitem_25796, getitem_25797, getitem_25798, getitem_25799, getitem_25800, getitem_25801, getitem_25802, getitem_25803, getitem_25804, getitem_25805, getitem_25806, getitem_25807, getitem_25808, getitem_25809, getitem_25810, getitem_25811, getitem_25812, getitem_25813, getitem_25814, getitem_25815, getitem_25816, getitem_25817, getitem_25818, getitem_25819, getitem_25820, getitem_25821, getitem_25822, getitem_25823, getitem_25824, getitem_25825, getitem_25826, getitem_25827, getitem_25828, getitem_25829, getitem_25830, getitem_25831, getitem_25832, getitem_25833, getitem_25834, getitem_25835, getitem_25836, getitem_25837, getitem_25838, getitem_25839, getitem_25840, getitem_25841, getitem_25842, getitem_25843, getitem_25844, getitem_25845, getitem_25846, getitem_25847, getitem_25848, getitem_25849, getitem_25850, getitem_25851, getitem_25852, getitem_25853, getitem_25854, getitem_25855, getitem_25856, getitem_25857, getitem_25858, getitem_25859, getitem_25860, getitem_25861, getitem_25862, getitem_25863, getitem_25864, getitem_25865, getitem_25866, getitem_25867, constant_pad_nd_1758, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_25753 = getitem_25754 = getitem_25755 = getitem_25756 = getitem_25757 = getitem_25758 = getitem_25759 = getitem_25760 = getitem_25761 = getitem_25762 = getitem_25763 = getitem_25764 = getitem_25765 = getitem_25766 = getitem_25767 = getitem_25768 = getitem_25769 = getitem_25770 = getitem_25771 = getitem_25772 = getitem_25773 = getitem_25774 = getitem_25775 = getitem_25776 = getitem_25777 = getitem_25778 = getitem_25779 = getitem_25780 = getitem_25781 = getitem_25782 = getitem_25783 = getitem_25784 = getitem_25785 = getitem_25786 = getitem_25787 = getitem_25788 = getitem_25789 = getitem_25790 = getitem_25791 = getitem_25792 = getitem_25793 = getitem_25794 = getitem_25795 = getitem_25796 = getitem_25797 = getitem_25798 = getitem_25799 = getitem_25800 = getitem_25801 = getitem_25802 = getitem_25803 = getitem_25804 = getitem_25805 = getitem_25806 = getitem_25807 = getitem_25808 = getitem_25809 = getitem_25810 = getitem_25811 = getitem_25812 = getitem_25813 = getitem_25814 = getitem_25815 = getitem_25816 = getitem_25817 = getitem_25818 = getitem_25819 = getitem_25820 = getitem_25821 = getitem_25822 = getitem_25823 = getitem_25824 = getitem_25825 = getitem_25826 = getitem_25827 = getitem_25828 = getitem_25829 = getitem_25830 = getitem_25831 = getitem_25832 = getitem_25833 = getitem_25834 = getitem_25835 = getitem_25836 = getitem_25837 = getitem_25838 = getitem_25839 = getitem_25840 = getitem_25841 = getitem_25842 = getitem_25843 = getitem_25844 = getitem_25845 = getitem_25846 = getitem_25847 = getitem_25848 = getitem_25849 = getitem_25850 = getitem_25851 = getitem_25852 = getitem_25853 = getitem_25854 = getitem_25855 = getitem_25856 = getitem_25857 = getitem_25858 = getitem_25859 = getitem_25860 = getitem_25861 = getitem_25862 = getitem_25863 = getitem_25864 = getitem_25865 = getitem_25866 = getitem_25867 = constant_pad_nd_1758 = None + reduce_scatter_tensor_321 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_418, 'avg', 128, '0'); cat_418 = None + wait_tensor_926 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_321); reduce_scatter_tensor_321 = None + slice_298 = torch.ops.aten.slice.Tensor(permute_1543, 3, 0, 128) + slice_299 = torch.ops.aten.slice.Tensor(permute_1543, 3, 128, 192); permute_1543 = None + convert_element_type_3147 = torch.ops.prims.convert_element_type.default(slice_299, torch.float32); slice_299 = None + view_2214 = torch.ops.aten.view.default(convert_element_type_3147, [2, 4096, 16, 32, 2]); convert_element_type_3147 = None + view_as_complex_99 = torch.ops.aten.view_as_complex.default(view_2214); view_2214 = None + mul_2066 = torch.ops.aten.mul.Tensor(view_as_complex_99, clone_9); view_as_complex_99 = None + view_as_real_99 = torch.ops.aten.view_as_real.default(mul_2066); mul_2066 = None + view_2215 = torch.ops.aten.view.default(view_as_real_99, [2, 4096, 16, 64]); view_as_real_99 = None + convert_element_type_3148 = torch.ops.prims.convert_element_type.default(view_2215, torch.bfloat16); view_2215 = None + cat_419 = torch.ops.aten.cat.default([slice_298, convert_element_type_3148], 3); slice_298 = convert_element_type_3148 = None + view_2216 = torch.ops.aten.view.default(cat_419, [2, 4096, 3072]); cat_419 = None + view_2217 = torch.ops.aten.view.default(view_2216, [8192, 3072]); view_2216 = None + permute_1552 = torch.ops.aten.permute.default(view_2217, [1, 0]) + mm_584 = torch.ops.aten.mm.default(permute_1552, view_237); permute_1552 = view_237 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 128, '0'); convert_element_type_202 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + permute_1554 = torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_585 = torch.ops.aten.mm.default(view_2217, permute_1554); view_2217 = permute_1554 = None + view_2218 = torch.ops.aten.view.default(mm_585, [2, 4096, 2048]); mm_585 = None + add_2118 = torch.ops.aten.add.Tensor(view_2213, view_2218); view_2213 = view_2218 = None + convert_element_type_3153 = torch.ops.prims.convert_element_type.default(mm_584, torch.float32); mm_584 = None + reduce_scatter_tensor_322 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3153, 'avg', 128, '0'); convert_element_type_3153 = None + wait_tensor_927 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_322); reduce_scatter_tensor_322 = None + convert_element_type_3154 = torch.ops.prims.convert_element_type.default(add_2118, torch.float32); add_2118 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 128, '0'); convert_element_type_199 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_3156 = torch.ops.prims.convert_element_type.default(wait_tensor_74, torch.float32); wait_tensor_74 = None + mul_2067 = torch.ops.aten.mul.Tensor(convert_element_type_3154, convert_element_type_3156); convert_element_type_3156 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_209, torch.float32); add_209 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_2069 = torch.ops.aten.mul.Tensor(mul_156, mul_2067) + sum_289 = torch.ops.aten.sum.dim_IntList(mul_2069, [2], True); mul_2069 = None + div_269 = torch.ops.aten.div.Tensor(mul_156, 2048) + mul_2070 = torch.ops.aten.mul.Tensor(div_269, sum_289); div_269 = sum_289 = None + sub_762 = torch.ops.aten.sub.Tensor(mul_2067, mul_2070); mul_2067 = mul_2070 = None + mul_2071 = torch.ops.aten.mul.Tensor(sub_762, rsqrt_12); sub_762 = rsqrt_12 = None + mul_2072 = torch.ops.aten.mul.Tensor(convert_element_type_3154, mul_156); convert_element_type_3154 = mul_156 = None + sum_290 = torch.ops.aten.sum.dim_IntList(mul_2072, [0, 1]); mul_2072 = None + convert_element_type_3157 = torch.ops.prims.convert_element_type.default(mul_2071, torch.bfloat16); mul_2071 = None + add_2119 = torch.ops.aten.add.Tensor(add_2117, convert_element_type_3157); add_2117 = convert_element_type_3157 = None + convert_element_type_default_13 = torch.ops.prims.convert_element_type.default(sum_290, torch.float32); sum_290 = None + reduce_scatter_tensor_323 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_13, 'avg', 128, '0'); convert_element_type_default_13 = None + wait_tensor_928 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_323); reduce_scatter_tensor_323 = None + view_2219 = torch.ops.aten.view.default(add_2119, [8192, 2048]) + unsqueeze_76 = torch.ops.aten.unsqueeze.default(view_2219, 1) + convert_element_type_3160 = torch.ops.prims.convert_element_type.default(unsqueeze_76, torch.float32); unsqueeze_76 = None + bmm_72 = torch.ops.aten.bmm.default(permute_1556, convert_element_type_3160); permute_1556 = None + bmm_73 = torch.ops.aten.bmm.default(convert_element_type_3160, permute_1557); convert_element_type_3160 = permute_1557 = None + convert_element_type_3161 = torch.ops.prims.convert_element_type.default(bmm_72, torch.bfloat16); bmm_72 = None + view_2220 = torch.ops.aten.view.default(bmm_73, [8192, 6]); bmm_73 = None + view_2221 = torch.ops.aten.view.default(convert_element_type_3161, [49152, 2048]); convert_element_type_3161 = None + index_98 = torch.ops.aten.index.Tensor(view_2221, [getitem_241]); view_2221 = getitem_241 = None + permute_1558 = torch.ops.aten.permute.default(view_2219, [1, 0]) + mm_586 = torch.ops.aten.mm.default(permute_1558, mul_153); permute_1558 = mul_153 = None + convert_element_type_194 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_194, 128, '0'); convert_element_type_194 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_73, [1, 0]); wait_tensor_73 = None + permute_1560 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_587 = torch.ops.aten.mm.default(view_2219, permute_1560); view_2219 = permute_1560 = None + convert_element_type_3166 = torch.ops.prims.convert_element_type.default(mm_586, torch.float32); mm_586 = None + reduce_scatter_tensor_324 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3166, 'avg', 128, '0'); convert_element_type_3166 = None + wait_tensor_929 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_324); reduce_scatter_tensor_324 = None + convert_element_type_189 = torch.ops.prims.convert_element_type.default(mm_28, torch.float32); mm_28 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_189) + exp_9 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_204 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + div_15 = torch.ops.aten.div.Tensor(convert_element_type_189, add_204) + convert_element_type_190 = torch.ops.prims.convert_element_type.default(div_15, torch.bfloat16); div_15 = None + mul_2073 = torch.ops.aten.mul.Tensor(mm_587, convert_element_type_190); convert_element_type_190 = None + mul_2074 = torch.ops.aten.mul.Tensor(mm_587, mm_29); mm_587 = mm_29 = None + permute_1562 = torch.ops.aten.permute.default(mul_2073, [1, 0]) + mm_588 = torch.ops.aten.mm.default(permute_1562, view_192); permute_1562 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_191, 128, '0'); convert_element_type_191 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_1564 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_589 = torch.ops.aten.mm.default(mul_2073, permute_1564); mul_2073 = permute_1564 = None + convert_element_type_3171 = torch.ops.prims.convert_element_type.default(mm_588, torch.float32); mm_588 = None + reduce_scatter_tensor_325 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3171, 'avg', 128, '0'); convert_element_type_3171 = None + wait_tensor_930 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_325); reduce_scatter_tensor_325 = None + convert_element_type_3172 = torch.ops.prims.convert_element_type.default(mul_2074, torch.float32); mul_2074 = None + reciprocal_46 = torch.ops.aten.reciprocal.default(add_204); add_204 = None + mul_2075 = torch.ops.aten.mul.Tensor(reciprocal_46, 1); reciprocal_46 = None + mul_2076 = torch.ops.aten.mul.Tensor(convert_element_type_3172, mul_2075); convert_element_type_3172 = None + sub_763 = torch.ops.aten.sub.Tensor(1, mul_2075); mul_2075 = None + mul_2077 = torch.ops.aten.mul.Tensor(convert_element_type_189, sub_763); convert_element_type_189 = sub_763 = None + add_2121 = torch.ops.aten.add.Tensor(mul_2077, 1); mul_2077 = None + mul_2078 = torch.ops.aten.mul.Tensor(mul_2076, add_2121); mul_2076 = add_2121 = None + convert_element_type_3174 = torch.ops.prims.convert_element_type.default(mul_2078, torch.bfloat16); mul_2078 = None + permute_1566 = torch.ops.aten.permute.default(convert_element_type_3174, [1, 0]) + mm_590 = torch.ops.aten.mm.default(permute_1566, view_192); permute_1566 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_186, 128, '0'); convert_element_type_186 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + permute_1568 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_591 = torch.ops.aten.mm.default(convert_element_type_3174, permute_1568); convert_element_type_3174 = permute_1568 = None + add_2122 = torch.ops.aten.add.Tensor(mm_589, mm_591); mm_589 = mm_591 = None + convert_element_type_3179 = torch.ops.prims.convert_element_type.default(mm_590, torch.float32); mm_590 = None + reduce_scatter_tensor_326 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3179, 'avg', 128, '0'); convert_element_type_3179 = None + wait_tensor_931 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_326); reduce_scatter_tensor_326 = None + all_to_all_single_124 = torch.ops._c10d_functional.all_to_all_single.default(index_98, [_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47], [_local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39], '1033'); index_98 = None + wait_tensor_932 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_124); all_to_all_single_124 = None + full_486 = torch.ops.aten.full.default([sym_size_int_9, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_9 = None + slice_scatter_23 = torch.ops.aten.slice_scatter.default(full_486, wait_tensor_932, 0, 0, -1); wait_tensor_932 = None + index_99 = torch.ops.aten.index.Tensor(slice_scatter_23, [getitem_242]); slice_scatter_23 = None + permute_1570 = torch.ops.aten.permute.default(index_99, [1, 0]) + _grouped_mm_216 = torch.ops.aten._grouped_mm.default(permute_1570, mul_133, cumsum_8); permute_1570 = mul_133 = None + _grouped_mm_217 = torch.ops.aten._grouped_mm.default(index_99, permute_1572, cumsum_8); index_99 = permute_1572 = None + convert_element_type_184 = torch.ops.prims.convert_element_type.default(_grouped_mm_6, torch.float32); _grouped_mm_6 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_184) + exp_8 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_168 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + div_14 = torch.ops.aten.div.Tensor(convert_element_type_184, add_168) + convert_element_type_185 = torch.ops.prims.convert_element_type.default(div_14, torch.bfloat16); div_14 = None + mul_2079 = torch.ops.aten.mul.Tensor(_grouped_mm_217, convert_element_type_185); convert_element_type_185 = None + mul_2080 = torch.ops.aten.mul.Tensor(_grouped_mm_217, _grouped_mm_7); _grouped_mm_217 = _grouped_mm_7 = None + permute_1574 = torch.ops.aten.permute.default(mul_2079, [1, 0]) + _grouped_mm_218 = torch.ops.aten._grouped_mm.default(permute_1574, index_5, cumsum_8); permute_1574 = None + _grouped_mm_219 = torch.ops.aten._grouped_mm.default(mul_2079, permute_1576, cumsum_8); mul_2079 = permute_1576 = None + convert_element_type_3180 = torch.ops.prims.convert_element_type.default(mul_2080, torch.float32); mul_2080 = None + reciprocal_47 = torch.ops.aten.reciprocal.default(add_168); add_168 = None + mul_2081 = torch.ops.aten.mul.Tensor(reciprocal_47, 1); reciprocal_47 = None + mul_2082 = torch.ops.aten.mul.Tensor(convert_element_type_3180, mul_2081); convert_element_type_3180 = None + sub_764 = torch.ops.aten.sub.Tensor(1, mul_2081); mul_2081 = None + mul_2083 = torch.ops.aten.mul.Tensor(convert_element_type_184, sub_764); convert_element_type_184 = sub_764 = None + add_2124 = torch.ops.aten.add.Tensor(mul_2083, 1); mul_2083 = None + mul_2084 = torch.ops.aten.mul.Tensor(mul_2082, add_2124); mul_2082 = add_2124 = None + convert_element_type_3182 = torch.ops.prims.convert_element_type.default(mul_2084, torch.bfloat16); mul_2084 = None + permute_1578 = torch.ops.aten.permute.default(convert_element_type_3182, [1, 0]) + _grouped_mm_220 = torch.ops.aten._grouped_mm.default(permute_1578, index_5, cumsum_8); permute_1578 = index_5 = None + _grouped_mm_221 = torch.ops.aten._grouped_mm.default(convert_element_type_3182, permute_1580, cumsum_8); convert_element_type_3182 = permute_1580 = cumsum_8 = None + add_2125 = torch.ops.aten.add.Tensor(_grouped_mm_219, _grouped_mm_221); _grouped_mm_219 = _grouped_mm_221 = None + convert_element_type_3183 = torch.ops.prims.convert_element_type.default(_grouped_mm_218, torch.float32); _grouped_mm_218 = None + div_270 = torch.ops.aten.div.Tensor(convert_element_type_3183, 128); convert_element_type_3183 = None + split_1376 = torch.ops.aten.split.Tensor(div_270, 88, 1); div_270 = None + getitem_25885 = split_1376[0] + getitem_25902 = split_1376[1] + getitem_25919 = split_1376[2] + getitem_25936 = split_1376[3] + getitem_25953 = split_1376[4] + getitem_25970 = split_1376[5] + getitem_25987 = split_1376[6] + getitem_26004 = split_1376[7] + getitem_26021 = split_1376[8] + getitem_26038 = split_1376[9] + getitem_26055 = split_1376[10] + getitem_26072 = split_1376[11] + getitem_26089 = split_1376[12] + getitem_26106 = split_1376[13] + getitem_26123 = split_1376[14] + getitem_26140 = split_1376[15]; split_1376 = None + cat_420 = torch.ops.aten.cat.default([getitem_25885, getitem_25902, getitem_25919, getitem_25936, getitem_25953, getitem_25970, getitem_25987, getitem_26004, getitem_26021, getitem_26038, getitem_26055, getitem_26072, getitem_26089, getitem_26106, getitem_26123, getitem_26140]); getitem_25885 = getitem_25902 = getitem_25919 = getitem_25936 = getitem_25953 = getitem_25970 = getitem_25987 = getitem_26004 = getitem_26021 = getitem_26038 = getitem_26055 = getitem_26072 = getitem_26089 = getitem_26106 = getitem_26123 = getitem_26140 = None + reduce_scatter_tensor_327 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_420, 'sum', 16, '1025'); cat_420 = None + wait_tensor_933 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_327); reduce_scatter_tensor_327 = None + convert_element_type_3184 = torch.ops.prims.convert_element_type.default(_grouped_mm_216, torch.float32); _grouped_mm_216 = None + div_271 = torch.ops.aten.div.Tensor(convert_element_type_3184, 128); convert_element_type_3184 = None + split_1393 = torch.ops.aten.split.Tensor(div_271, 128, 1); div_271 = None + getitem_26157 = split_1393[0] + getitem_26174 = split_1393[1] + getitem_26191 = split_1393[2] + getitem_26208 = split_1393[3] + getitem_26225 = split_1393[4] + getitem_26242 = split_1393[5] + getitem_26259 = split_1393[6] + getitem_26276 = split_1393[7] + getitem_26293 = split_1393[8] + getitem_26310 = split_1393[9] + getitem_26327 = split_1393[10] + getitem_26344 = split_1393[11] + getitem_26361 = split_1393[12] + getitem_26378 = split_1393[13] + getitem_26395 = split_1393[14] + getitem_26412 = split_1393[15]; split_1393 = None + cat_421 = torch.ops.aten.cat.default([getitem_26157, getitem_26174, getitem_26191, getitem_26208, getitem_26225, getitem_26242, getitem_26259, getitem_26276, getitem_26293, getitem_26310, getitem_26327, getitem_26344, getitem_26361, getitem_26378, getitem_26395, getitem_26412]); getitem_26157 = getitem_26174 = getitem_26191 = getitem_26208 = getitem_26225 = getitem_26242 = getitem_26259 = getitem_26276 = getitem_26293 = getitem_26310 = getitem_26327 = getitem_26344 = getitem_26361 = getitem_26378 = getitem_26395 = getitem_26412 = None + reduce_scatter_tensor_328 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_421, 'sum', 16, '1025'); cat_421 = None + wait_tensor_934 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_328); reduce_scatter_tensor_328 = None + convert_element_type_3185 = torch.ops.prims.convert_element_type.default(_grouped_mm_220, torch.float32); _grouped_mm_220 = None + div_272 = torch.ops.aten.div.Tensor(convert_element_type_3185, 128); convert_element_type_3185 = None + split_1410 = torch.ops.aten.split.Tensor(div_272, 88, 1); div_272 = None + getitem_26429 = split_1410[0] + getitem_26446 = split_1410[1] + getitem_26463 = split_1410[2] + getitem_26480 = split_1410[3] + getitem_26497 = split_1410[4] + getitem_26514 = split_1410[5] + getitem_26531 = split_1410[6] + getitem_26548 = split_1410[7] + getitem_26565 = split_1410[8] + getitem_26582 = split_1410[9] + getitem_26599 = split_1410[10] + getitem_26616 = split_1410[11] + getitem_26633 = split_1410[12] + getitem_26650 = split_1410[13] + getitem_26667 = split_1410[14] + getitem_26684 = split_1410[15]; split_1410 = None + cat_422 = torch.ops.aten.cat.default([getitem_26429, getitem_26446, getitem_26463, getitem_26480, getitem_26497, getitem_26514, getitem_26531, getitem_26548, getitem_26565, getitem_26582, getitem_26599, getitem_26616, getitem_26633, getitem_26650, getitem_26667, getitem_26684]); getitem_26429 = getitem_26446 = getitem_26463 = getitem_26480 = getitem_26497 = getitem_26514 = getitem_26531 = getitem_26548 = getitem_26565 = getitem_26582 = getitem_26599 = getitem_26616 = getitem_26633 = getitem_26650 = getitem_26667 = getitem_26684 = None + reduce_scatter_tensor_329 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_422, 'sum', 16, '1025'); cat_422 = None + wait_tensor_935 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_329); reduce_scatter_tensor_329 = None + index_put_98 = torch.ops.aten.index_put.default(full_486, [getitem_242], add_2125, True); full_486 = getitem_242 = add_2125 = None + slice_300 = torch.ops.aten.slice.Tensor(index_put_98, 0, 0, add_2126); index_put_98 = add_2126 = None + all_to_all_single_125 = torch.ops._c10d_functional.all_to_all_single.default(slice_300, [_local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39], [_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47], '1033'); slice_300 = _local_scalar_dense_32 = _local_scalar_dense_33 = _local_scalar_dense_34 = _local_scalar_dense_35 = _local_scalar_dense_36 = _local_scalar_dense_37 = _local_scalar_dense_38 = _local_scalar_dense_39 = _local_scalar_dense_40 = _local_scalar_dense_41 = _local_scalar_dense_42 = _local_scalar_dense_43 = _local_scalar_dense_44 = _local_scalar_dense_45 = _local_scalar_dense_46 = _local_scalar_dense_47 = None + wait_tensor_936 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_125); all_to_all_single_125 = None + index_put_99 = torch.ops.aten.index_put.default(full_default_52, [div_12], wait_tensor_936, True); div_12 = wait_tensor_936 = None + add_2130 = torch.ops.aten.add.Tensor(add_2122, index_put_99); add_2122 = index_put_99 = None + mul_2085 = torch.ops.aten.mul.Tensor(view_2220, 1.0); view_2220 = None + scatter_add_23 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_239, mul_2085); getitem_239 = mul_2085 = None + convert_element_type_173 = torch.ops.prims.convert_element_type.default(mm_27, torch.float32); mm_27 = None + sub_48 = torch.ops.aten.sub.Tensor(convert_element_type_173, amax_2); convert_element_type_173 = amax_2 = None + exp_7 = torch.ops.aten.exp.default(sub_48); sub_48 = None + div_11 = torch.ops.aten.div.Tensor(exp_7, sum_9); exp_7 = sum_9 = None + mul_2086 = torch.ops.aten.mul.Tensor(scatter_add_23, div_11); scatter_add_23 = None + sum_291 = torch.ops.aten.sum.dim_IntList(mul_2086, [1], True) + neg_124 = torch.ops.aten.neg.default(div_11); div_11 = None + fma_23 = torch.ops.prims.fma.default(neg_124, sum_291, mul_2086); neg_124 = sum_291 = mul_2086 = None + convert_element_type_3186 = torch.ops.prims.convert_element_type.default(fma_23, torch.bfloat16); fma_23 = None + permute_1582 = torch.ops.aten.permute.default(convert_element_type_3186, [1, 0]) + mm_592 = torch.ops.aten.mm.default(permute_1582, view_192); permute_1582 = view_192 = None + convert_element_type_170 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_170, 128, '0'); convert_element_type_170 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + slice_21 = torch.ops.aten.slice.Tensor(wait_tensor_60, 0, 0, 64); wait_tensor_60 = None + permute_49 = torch.ops.aten.permute.default(slice_21, [1, 0]); slice_21 = None + permute_1584 = torch.ops.aten.permute.default(permute_49, [1, 0]); permute_49 = None + mm_593 = torch.ops.aten.mm.default(convert_element_type_3186, permute_1584); convert_element_type_3186 = permute_1584 = None + add_2131 = torch.ops.aten.add.Tensor(add_2130, mm_593); add_2130 = mm_593 = None + convert_element_type_3191 = torch.ops.prims.convert_element_type.default(mm_592, torch.float32); mm_592 = None + split_1426 = torch.ops.aten.split.Tensor(convert_element_type_3191, 1); convert_element_type_3191 = None + getitem_26685 = split_1426[0] + getitem_26686 = split_1426[1] + getitem_26687 = split_1426[2] + getitem_26688 = split_1426[3] + getitem_26689 = split_1426[4] + getitem_26690 = split_1426[5] + getitem_26691 = split_1426[6] + getitem_26692 = split_1426[7] + getitem_26693 = split_1426[8] + getitem_26694 = split_1426[9] + getitem_26695 = split_1426[10] + getitem_26696 = split_1426[11] + getitem_26697 = split_1426[12] + getitem_26698 = split_1426[13] + getitem_26699 = split_1426[14] + getitem_26700 = split_1426[15] + getitem_26701 = split_1426[16] + getitem_26702 = split_1426[17] + getitem_26703 = split_1426[18] + getitem_26704 = split_1426[19] + getitem_26705 = split_1426[20] + getitem_26706 = split_1426[21] + getitem_26707 = split_1426[22] + getitem_26708 = split_1426[23] + getitem_26709 = split_1426[24] + getitem_26710 = split_1426[25] + getitem_26711 = split_1426[26] + getitem_26712 = split_1426[27] + getitem_26713 = split_1426[28] + getitem_26714 = split_1426[29] + getitem_26715 = split_1426[30] + getitem_26716 = split_1426[31] + getitem_26717 = split_1426[32] + getitem_26718 = split_1426[33] + getitem_26719 = split_1426[34] + getitem_26720 = split_1426[35] + getitem_26721 = split_1426[36] + getitem_26722 = split_1426[37] + getitem_26723 = split_1426[38] + getitem_26724 = split_1426[39] + getitem_26725 = split_1426[40] + getitem_26726 = split_1426[41] + getitem_26727 = split_1426[42] + getitem_26728 = split_1426[43] + getitem_26729 = split_1426[44] + getitem_26730 = split_1426[45] + getitem_26731 = split_1426[46] + getitem_26732 = split_1426[47] + getitem_26733 = split_1426[48] + getitem_26734 = split_1426[49] + getitem_26735 = split_1426[50] + getitem_26736 = split_1426[51] + getitem_26737 = split_1426[52] + getitem_26738 = split_1426[53] + getitem_26739 = split_1426[54] + getitem_26740 = split_1426[55] + getitem_26741 = split_1426[56] + getitem_26742 = split_1426[57] + getitem_26743 = split_1426[58] + getitem_26744 = split_1426[59] + getitem_26745 = split_1426[60] + getitem_26746 = split_1426[61] + getitem_26747 = split_1426[62] + getitem_26748 = split_1426[63]; split_1426 = None + cat_423 = torch.ops.aten.cat.default([getitem_26685, getitem_26686, getitem_26687, getitem_26688, getitem_26689, getitem_26690, getitem_26691, getitem_26692, getitem_26693, getitem_26694, getitem_26695, getitem_26696, getitem_26697, getitem_26698, getitem_26699, getitem_26700, getitem_26701, getitem_26702, getitem_26703, getitem_26704, getitem_26705, getitem_26706, getitem_26707, getitem_26708, getitem_26709, getitem_26710, getitem_26711, getitem_26712, getitem_26713, getitem_26714, getitem_26715, getitem_26716, getitem_26717, getitem_26718, getitem_26719, getitem_26720, getitem_26721, getitem_26722, getitem_26723, getitem_26724, getitem_26725, getitem_26726, getitem_26727, getitem_26728, getitem_26729, getitem_26730, getitem_26731, getitem_26732, getitem_26733, getitem_26734, getitem_26735, getitem_26736, getitem_26737, getitem_26738, getitem_26739, getitem_26740, getitem_26741, getitem_26742, getitem_26743, getitem_26744, getitem_26745, getitem_26746, getitem_26747, getitem_26748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_26685 = getitem_26686 = getitem_26687 = getitem_26688 = getitem_26689 = getitem_26690 = getitem_26691 = getitem_26692 = getitem_26693 = getitem_26694 = getitem_26695 = getitem_26696 = getitem_26697 = getitem_26698 = getitem_26699 = getitem_26700 = getitem_26701 = getitem_26702 = getitem_26703 = getitem_26704 = getitem_26705 = getitem_26706 = getitem_26707 = getitem_26708 = getitem_26709 = getitem_26710 = getitem_26711 = getitem_26712 = getitem_26713 = getitem_26714 = getitem_26715 = getitem_26716 = getitem_26717 = getitem_26718 = getitem_26719 = getitem_26720 = getitem_26721 = getitem_26722 = getitem_26723 = getitem_26724 = getitem_26725 = getitem_26726 = getitem_26727 = getitem_26728 = getitem_26729 = getitem_26730 = getitem_26731 = getitem_26732 = getitem_26733 = getitem_26734 = getitem_26735 = getitem_26736 = getitem_26737 = getitem_26738 = getitem_26739 = getitem_26740 = getitem_26741 = getitem_26742 = getitem_26743 = getitem_26744 = getitem_26745 = getitem_26746 = getitem_26747 = getitem_26748 = None + reduce_scatter_tensor_330 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_423, 'avg', 128, '0'); cat_423 = None + wait_tensor_937 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_330); reduce_scatter_tensor_330 = None + view_2222 = torch.ops.aten.view.default(add_2131, [2, 4096, 2048]); add_2131 = None + convert_element_type_3192 = torch.ops.prims.convert_element_type.default(view_2222, torch.float32); view_2222 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_167, 128, '0'); convert_element_type_167 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_3194 = torch.ops.prims.convert_element_type.default(wait_tensor_59, torch.float32); wait_tensor_59 = None + mul_2087 = torch.ops.aten.mul.Tensor(convert_element_type_3192, convert_element_type_3194); convert_element_type_3194 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(add_144, torch.float32); add_144 = None + mul_113 = torch.ops.aten.mul.Tensor(convert_element_type_168, rsqrt_11); convert_element_type_168 = None + mul_2089 = torch.ops.aten.mul.Tensor(mul_113, mul_2087) + sum_292 = torch.ops.aten.sum.dim_IntList(mul_2089, [2], True); mul_2089 = None + div_273 = torch.ops.aten.div.Tensor(mul_113, 2048) + mul_2090 = torch.ops.aten.mul.Tensor(div_273, sum_292); div_273 = sum_292 = None + sub_766 = torch.ops.aten.sub.Tensor(mul_2087, mul_2090); mul_2087 = mul_2090 = None + mul_2091 = torch.ops.aten.mul.Tensor(sub_766, rsqrt_11); sub_766 = rsqrt_11 = None + mul_2092 = torch.ops.aten.mul.Tensor(convert_element_type_3192, mul_113); convert_element_type_3192 = mul_113 = None + sum_293 = torch.ops.aten.sum.dim_IntList(mul_2092, [0, 1]); mul_2092 = None + convert_element_type_3195 = torch.ops.prims.convert_element_type.default(mul_2091, torch.bfloat16); mul_2091 = None + add_2132 = torch.ops.aten.add.Tensor(add_2119, convert_element_type_3195); add_2119 = convert_element_type_3195 = None + convert_element_type_default_12 = torch.ops.prims.convert_element_type.default(sum_293, torch.float32); sum_293 = None + reduce_scatter_tensor_331 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_12, 'avg', 128, '0'); convert_element_type_default_12 = None + wait_tensor_938 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_331); reduce_scatter_tensor_331 = None + view_2223 = torch.ops.aten.view.default(add_2132, [8192, 2048]) + permute_1586 = torch.ops.aten.permute.default(view_2223, [1, 0]) + permute_47 = torch.ops.aten.permute.default(getitem_235, [0, 2, 1, 3]) + view_187 = torch.ops.aten.view.default(permute_47, [2, 4096, -1]); permute_47 = None + view_189 = torch.ops.aten.view.default(view_187, [8192, 2048]); view_187 = None + mm_594 = torch.ops.aten.mm.default(permute_1586, view_189); permute_1586 = view_189 = None + convert_element_type_164 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_164, 128, '0'); convert_element_type_164 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_48 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_1588 = torch.ops.aten.permute.default(permute_48, [1, 0]); permute_48 = None + mm_595 = torch.ops.aten.mm.default(view_2223, permute_1588); view_2223 = permute_1588 = None + view_2224 = torch.ops.aten.view.default(mm_595, [2, 4096, 2048]); mm_595 = None + convert_element_type_3202 = torch.ops.prims.convert_element_type.default(mm_594, torch.float32); mm_594 = None + reduce_scatter_tensor_332 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3202, 'avg', 128, '0'); convert_element_type_3202 = None + wait_tensor_939 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_332); reduce_scatter_tensor_332 = None + view_2225 = torch.ops.aten.view.default(view_2224, [2, 4096, 16, 128]); view_2224 = None + permute_1590 = torch.ops.aten.permute.default(view_2225, [0, 2, 1, 3]); view_2225 = None + fw_graph23 = self.fw_graph23 + joint_graph23 = self.joint_graph23 + mask_graph23 = self.mask_graph23 + flex_attention_backward_23 = torch.ops.higher_order.flex_attention_backward(permute_44, permute_45, permute_46, getitem_235, getitem_236, permute_1590, None, fw_graph23, joint_graph23, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph23), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_44 = permute_45 = permute_46 = getitem_235 = getitem_236 = permute_1590 = fw_graph23 = joint_graph23 = mask_graph23 = None + getitem_26749 = flex_attention_backward_23[0] + getitem_26750 = flex_attention_backward_23[1] + getitem_26751 = flex_attention_backward_23[2]; flex_attention_backward_23 = None + permute_1591 = torch.ops.aten.permute.default(getitem_26751, [0, 2, 1, 3]); getitem_26751 = None + permute_1592 = torch.ops.aten.permute.default(getitem_26750, [0, 2, 1, 3]); getitem_26750 = None + permute_1593 = torch.ops.aten.permute.default(getitem_26749, [0, 2, 1, 3]); getitem_26749 = None + slice_302 = torch.ops.aten.slice.Tensor(permute_1592, 3, 0, 128) + slice_303 = torch.ops.aten.slice.Tensor(permute_1592, 3, 128, 192); permute_1592 = None + sum_294 = torch.ops.aten.sum.dim_IntList(slice_303, [2], True); slice_303 = None + cat_424 = torch.ops.aten.cat.default([slice_302, permute_1591], 3); slice_302 = permute_1591 = None + view_2226 = torch.ops.aten.view.default(cat_424, [2, 4096, 4096]); cat_424 = None + view_2227 = torch.ops.aten.view.default(view_2226, [8192, 4096]); view_2226 = None + permute_1594 = torch.ops.aten.permute.default(view_2227, [1, 0]) + mm_596 = torch.ops.aten.mm.default(permute_1594, view_184); permute_1594 = view_184 = None + convert_element_type_161 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); primals_59 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_161, 128, '0'); convert_element_type_161 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + permute_1596 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_597 = torch.ops.aten.mm.default(view_2227, permute_1596); view_2227 = permute_1596 = None + view_2228 = torch.ops.aten.view.default(mm_597, [2, 4096, 512]); mm_597 = None + convert_element_type_3207 = torch.ops.prims.convert_element_type.default(mm_596, torch.float32); mm_596 = None + reduce_scatter_tensor_333 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3207, 'avg', 128, '0'); convert_element_type_3207 = None + wait_tensor_940 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_333); reduce_scatter_tensor_333 = None + convert_element_type_3208 = torch.ops.prims.convert_element_type.default(view_2228, torch.float32); view_2228 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_158, 128, '0'); convert_element_type_158 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + convert_element_type_3210 = torch.ops.prims.convert_element_type.default(wait_tensor_56, torch.float32); wait_tensor_56 = None + mul_2093 = torch.ops.aten.mul.Tensor(convert_element_type_3208, convert_element_type_3210); convert_element_type_3210 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(getitem_231, torch.float32); getitem_231 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_159, rsqrt_10); convert_element_type_159 = None + mul_2095 = torch.ops.aten.mul.Tensor(mul_111, mul_2093) + sum_295 = torch.ops.aten.sum.dim_IntList(mul_2095, [2], True); mul_2095 = None + div_274 = torch.ops.aten.div.Tensor(mul_111, 512) + mul_2096 = torch.ops.aten.mul.Tensor(div_274, sum_295); div_274 = sum_295 = None + sub_767 = torch.ops.aten.sub.Tensor(mul_2093, mul_2096); mul_2093 = mul_2096 = None + mul_2097 = torch.ops.aten.mul.Tensor(sub_767, rsqrt_10); sub_767 = rsqrt_10 = None + mul_2098 = torch.ops.aten.mul.Tensor(convert_element_type_3208, mul_111); convert_element_type_3208 = mul_111 = None + sum_296 = torch.ops.aten.sum.dim_IntList(mul_2098, [0, 1]); mul_2098 = None + convert_element_type_3211 = torch.ops.prims.convert_element_type.default(mul_2097, torch.bfloat16); mul_2097 = None + convert_element_type_default_11 = torch.ops.prims.convert_element_type.default(sum_296, torch.float32); sum_296 = None + reduce_scatter_tensor_334 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_11, 'avg', 128, '0'); convert_element_type_default_11 = None + wait_tensor_941 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_334); reduce_scatter_tensor_334 = None + convert_element_type_3214 = torch.ops.prims.convert_element_type.default(sum_294, torch.float32); sum_294 = None + view_2229 = torch.ops.aten.view.default(convert_element_type_3214, [2, 4096, 1, 32, 2]); convert_element_type_3214 = None + view_as_complex_100 = torch.ops.aten.view_as_complex.default(view_2229); view_2229 = None + mul_2099 = torch.ops.aten.mul.Tensor(view_as_complex_100, clone_9); view_as_complex_100 = None + view_as_real_100 = torch.ops.aten.view_as_real.default(mul_2099); mul_2099 = None + view_2230 = torch.ops.aten.view.default(view_as_real_100, [2, 4096, 1, 64]); view_as_real_100 = None + convert_element_type_3215 = torch.ops.prims.convert_element_type.default(view_2230, torch.bfloat16); view_2230 = None + squeeze_49 = torch.ops.aten.squeeze.dim(convert_element_type_3215, 2); convert_element_type_3215 = None + cat_425 = torch.ops.aten.cat.default([convert_element_type_3211, squeeze_49], 2); convert_element_type_3211 = squeeze_49 = None + view_2231 = torch.ops.aten.view.default(cat_425, [8192, 576]); cat_425 = None + permute_1598 = torch.ops.aten.permute.default(view_2231, [1, 0]) + mm_598 = torch.ops.aten.mm.default(permute_1598, view_170); permute_1598 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_153, 128, '0'); convert_element_type_153 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + slice_19 = torch.ops.aten.slice.Tensor(wait_tensor_55, 0, 0, 576); wait_tensor_55 = None + permute_42 = torch.ops.aten.permute.default(slice_19, [1, 0]); slice_19 = None + permute_1600 = torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_599 = torch.ops.aten.mm.default(view_2231, permute_1600); view_2231 = permute_1600 = None + view_2232 = torch.ops.aten.view.default(mm_599, [2, 4096, 2048]); mm_599 = None + convert_element_type_3220 = torch.ops.prims.convert_element_type.default(mm_598, torch.float32); mm_598 = None + split_1427 = torch.ops.aten.split.Tensor(convert_element_type_3220, 5); convert_element_type_3220 = None + getitem_26753 = split_1427[0] + getitem_26754 = split_1427[1] + getitem_26755 = split_1427[2] + getitem_26756 = split_1427[3] + getitem_26757 = split_1427[4] + getitem_26758 = split_1427[5] + getitem_26759 = split_1427[6] + getitem_26760 = split_1427[7] + getitem_26761 = split_1427[8] + getitem_26762 = split_1427[9] + getitem_26763 = split_1427[10] + getitem_26764 = split_1427[11] + getitem_26765 = split_1427[12] + getitem_26766 = split_1427[13] + getitem_26767 = split_1427[14] + getitem_26768 = split_1427[15] + getitem_26769 = split_1427[16] + getitem_26770 = split_1427[17] + getitem_26771 = split_1427[18] + getitem_26772 = split_1427[19] + getitem_26773 = split_1427[20] + getitem_26774 = split_1427[21] + getitem_26775 = split_1427[22] + getitem_26776 = split_1427[23] + getitem_26777 = split_1427[24] + getitem_26778 = split_1427[25] + getitem_26779 = split_1427[26] + getitem_26780 = split_1427[27] + getitem_26781 = split_1427[28] + getitem_26782 = split_1427[29] + getitem_26783 = split_1427[30] + getitem_26784 = split_1427[31] + getitem_26785 = split_1427[32] + getitem_26786 = split_1427[33] + getitem_26787 = split_1427[34] + getitem_26788 = split_1427[35] + getitem_26789 = split_1427[36] + getitem_26790 = split_1427[37] + getitem_26791 = split_1427[38] + getitem_26792 = split_1427[39] + getitem_26793 = split_1427[40] + getitem_26794 = split_1427[41] + getitem_26795 = split_1427[42] + getitem_26796 = split_1427[43] + getitem_26797 = split_1427[44] + getitem_26798 = split_1427[45] + getitem_26799 = split_1427[46] + getitem_26800 = split_1427[47] + getitem_26801 = split_1427[48] + getitem_26802 = split_1427[49] + getitem_26803 = split_1427[50] + getitem_26804 = split_1427[51] + getitem_26805 = split_1427[52] + getitem_26806 = split_1427[53] + getitem_26807 = split_1427[54] + getitem_26808 = split_1427[55] + getitem_26809 = split_1427[56] + getitem_26810 = split_1427[57] + getitem_26811 = split_1427[58] + getitem_26812 = split_1427[59] + getitem_26813 = split_1427[60] + getitem_26814 = split_1427[61] + getitem_26815 = split_1427[62] + getitem_26816 = split_1427[63] + getitem_26817 = split_1427[64] + getitem_26818 = split_1427[65] + getitem_26819 = split_1427[66] + getitem_26820 = split_1427[67] + getitem_26821 = split_1427[68] + getitem_26822 = split_1427[69] + getitem_26823 = split_1427[70] + getitem_26824 = split_1427[71] + getitem_26825 = split_1427[72] + getitem_26826 = split_1427[73] + getitem_26827 = split_1427[74] + getitem_26828 = split_1427[75] + getitem_26829 = split_1427[76] + getitem_26830 = split_1427[77] + getitem_26831 = split_1427[78] + getitem_26832 = split_1427[79] + getitem_26833 = split_1427[80] + getitem_26834 = split_1427[81] + getitem_26835 = split_1427[82] + getitem_26836 = split_1427[83] + getitem_26837 = split_1427[84] + getitem_26838 = split_1427[85] + getitem_26839 = split_1427[86] + getitem_26840 = split_1427[87] + getitem_26841 = split_1427[88] + getitem_26842 = split_1427[89] + getitem_26843 = split_1427[90] + getitem_26844 = split_1427[91] + getitem_26845 = split_1427[92] + getitem_26846 = split_1427[93] + getitem_26847 = split_1427[94] + getitem_26848 = split_1427[95] + getitem_26849 = split_1427[96] + getitem_26850 = split_1427[97] + getitem_26851 = split_1427[98] + getitem_26852 = split_1427[99] + getitem_26853 = split_1427[100] + getitem_26854 = split_1427[101] + getitem_26855 = split_1427[102] + getitem_26856 = split_1427[103] + getitem_26857 = split_1427[104] + getitem_26858 = split_1427[105] + getitem_26859 = split_1427[106] + getitem_26860 = split_1427[107] + getitem_26861 = split_1427[108] + getitem_26862 = split_1427[109] + getitem_26863 = split_1427[110] + getitem_26864 = split_1427[111] + getitem_26865 = split_1427[112] + getitem_26866 = split_1427[113] + getitem_26867 = split_1427[114] + getitem_26868 = split_1427[115]; split_1427 = None + constant_pad_nd_1835 = torch.ops.aten.constant_pad_nd.default(getitem_26868, [0, 0, 0, 4], 0.0); getitem_26868 = None + cat_426 = torch.ops.aten.cat.default([getitem_26753, getitem_26754, getitem_26755, getitem_26756, getitem_26757, getitem_26758, getitem_26759, getitem_26760, getitem_26761, getitem_26762, getitem_26763, getitem_26764, getitem_26765, getitem_26766, getitem_26767, getitem_26768, getitem_26769, getitem_26770, getitem_26771, getitem_26772, getitem_26773, getitem_26774, getitem_26775, getitem_26776, getitem_26777, getitem_26778, getitem_26779, getitem_26780, getitem_26781, getitem_26782, getitem_26783, getitem_26784, getitem_26785, getitem_26786, getitem_26787, getitem_26788, getitem_26789, getitem_26790, getitem_26791, getitem_26792, getitem_26793, getitem_26794, getitem_26795, getitem_26796, getitem_26797, getitem_26798, getitem_26799, getitem_26800, getitem_26801, getitem_26802, getitem_26803, getitem_26804, getitem_26805, getitem_26806, getitem_26807, getitem_26808, getitem_26809, getitem_26810, getitem_26811, getitem_26812, getitem_26813, getitem_26814, getitem_26815, getitem_26816, getitem_26817, getitem_26818, getitem_26819, getitem_26820, getitem_26821, getitem_26822, getitem_26823, getitem_26824, getitem_26825, getitem_26826, getitem_26827, getitem_26828, getitem_26829, getitem_26830, getitem_26831, getitem_26832, getitem_26833, getitem_26834, getitem_26835, getitem_26836, getitem_26837, getitem_26838, getitem_26839, getitem_26840, getitem_26841, getitem_26842, getitem_26843, getitem_26844, getitem_26845, getitem_26846, getitem_26847, getitem_26848, getitem_26849, getitem_26850, getitem_26851, getitem_26852, getitem_26853, getitem_26854, getitem_26855, getitem_26856, getitem_26857, getitem_26858, getitem_26859, getitem_26860, getitem_26861, getitem_26862, getitem_26863, getitem_26864, getitem_26865, getitem_26866, getitem_26867, constant_pad_nd_1835, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_26753 = getitem_26754 = getitem_26755 = getitem_26756 = getitem_26757 = getitem_26758 = getitem_26759 = getitem_26760 = getitem_26761 = getitem_26762 = getitem_26763 = getitem_26764 = getitem_26765 = getitem_26766 = getitem_26767 = getitem_26768 = getitem_26769 = getitem_26770 = getitem_26771 = getitem_26772 = getitem_26773 = getitem_26774 = getitem_26775 = getitem_26776 = getitem_26777 = getitem_26778 = getitem_26779 = getitem_26780 = getitem_26781 = getitem_26782 = getitem_26783 = getitem_26784 = getitem_26785 = getitem_26786 = getitem_26787 = getitem_26788 = getitem_26789 = getitem_26790 = getitem_26791 = getitem_26792 = getitem_26793 = getitem_26794 = getitem_26795 = getitem_26796 = getitem_26797 = getitem_26798 = getitem_26799 = getitem_26800 = getitem_26801 = getitem_26802 = getitem_26803 = getitem_26804 = getitem_26805 = getitem_26806 = getitem_26807 = getitem_26808 = getitem_26809 = getitem_26810 = getitem_26811 = getitem_26812 = getitem_26813 = getitem_26814 = getitem_26815 = getitem_26816 = getitem_26817 = getitem_26818 = getitem_26819 = getitem_26820 = getitem_26821 = getitem_26822 = getitem_26823 = getitem_26824 = getitem_26825 = getitem_26826 = getitem_26827 = getitem_26828 = getitem_26829 = getitem_26830 = getitem_26831 = getitem_26832 = getitem_26833 = getitem_26834 = getitem_26835 = getitem_26836 = getitem_26837 = getitem_26838 = getitem_26839 = getitem_26840 = getitem_26841 = getitem_26842 = getitem_26843 = getitem_26844 = getitem_26845 = getitem_26846 = getitem_26847 = getitem_26848 = getitem_26849 = getitem_26850 = getitem_26851 = getitem_26852 = getitem_26853 = getitem_26854 = getitem_26855 = getitem_26856 = getitem_26857 = getitem_26858 = getitem_26859 = getitem_26860 = getitem_26861 = getitem_26862 = getitem_26863 = getitem_26864 = getitem_26865 = getitem_26866 = getitem_26867 = constant_pad_nd_1835 = None + reduce_scatter_tensor_335 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_426, 'avg', 128, '0'); cat_426 = None + wait_tensor_942 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_335); reduce_scatter_tensor_335 = None + slice_304 = torch.ops.aten.slice.Tensor(permute_1593, 3, 0, 128) + slice_305 = torch.ops.aten.slice.Tensor(permute_1593, 3, 128, 192); permute_1593 = None + convert_element_type_3221 = torch.ops.prims.convert_element_type.default(slice_305, torch.float32); slice_305 = None + view_2233 = torch.ops.aten.view.default(convert_element_type_3221, [2, 4096, 16, 32, 2]); convert_element_type_3221 = None + view_as_complex_101 = torch.ops.aten.view_as_complex.default(view_2233); view_2233 = None + mul_2100 = torch.ops.aten.mul.Tensor(view_as_complex_101, clone_9); view_as_complex_101 = None + view_as_real_101 = torch.ops.aten.view_as_real.default(mul_2100); mul_2100 = None + view_2234 = torch.ops.aten.view.default(view_as_real_101, [2, 4096, 16, 64]); view_as_real_101 = None + convert_element_type_3222 = torch.ops.prims.convert_element_type.default(view_2234, torch.bfloat16); view_2234 = None + cat_427 = torch.ops.aten.cat.default([slice_304, convert_element_type_3222], 3); slice_304 = convert_element_type_3222 = None + view_2235 = torch.ops.aten.view.default(cat_427, [2, 4096, 3072]); cat_427 = None + view_2236 = torch.ops.aten.view.default(view_2235, [8192, 3072]); view_2235 = None + permute_1602 = torch.ops.aten.permute.default(view_2236, [1, 0]) + mm_600 = torch.ops.aten.mm.default(permute_1602, view_170); permute_1602 = view_170 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_148, 128, '0'); convert_element_type_148 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + permute_1604 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_601 = torch.ops.aten.mm.default(view_2236, permute_1604); view_2236 = permute_1604 = None + view_2237 = torch.ops.aten.view.default(mm_601, [2, 4096, 2048]); mm_601 = None + add_2133 = torch.ops.aten.add.Tensor(view_2232, view_2237); view_2232 = view_2237 = None + convert_element_type_3227 = torch.ops.prims.convert_element_type.default(mm_600, torch.float32); mm_600 = None + reduce_scatter_tensor_336 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3227, 'avg', 128, '0'); convert_element_type_3227 = None + wait_tensor_943 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_336); reduce_scatter_tensor_336 = None + convert_element_type_3228 = torch.ops.prims.convert_element_type.default(add_2133, torch.float32); add_2133 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_145, 128, '0'); convert_element_type_145 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_3230 = torch.ops.prims.convert_element_type.default(wait_tensor_53, torch.float32); wait_tensor_53 = None + mul_2101 = torch.ops.aten.mul.Tensor(convert_element_type_3228, convert_element_type_3230); convert_element_type_3230 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(add_141, torch.float32); add_141 = None + mul_107 = torch.ops.aten.mul.Tensor(convert_element_type_146, rsqrt_9); convert_element_type_146 = None + mul_2103 = torch.ops.aten.mul.Tensor(mul_107, mul_2101) + sum_297 = torch.ops.aten.sum.dim_IntList(mul_2103, [2], True); mul_2103 = None + div_275 = torch.ops.aten.div.Tensor(mul_107, 2048) + mul_2104 = torch.ops.aten.mul.Tensor(div_275, sum_297); div_275 = sum_297 = None + sub_768 = torch.ops.aten.sub.Tensor(mul_2101, mul_2104); mul_2101 = mul_2104 = None + mul_2105 = torch.ops.aten.mul.Tensor(sub_768, rsqrt_9); sub_768 = rsqrt_9 = None + mul_2106 = torch.ops.aten.mul.Tensor(convert_element_type_3228, mul_107); convert_element_type_3228 = mul_107 = None + sum_298 = torch.ops.aten.sum.dim_IntList(mul_2106, [0, 1]); mul_2106 = None + convert_element_type_3231 = torch.ops.prims.convert_element_type.default(mul_2105, torch.bfloat16); mul_2105 = None + add_2134 = torch.ops.aten.add.Tensor(add_2132, convert_element_type_3231); add_2132 = convert_element_type_3231 = None + convert_element_type_default_10 = torch.ops.prims.convert_element_type.default(sum_298, torch.float32); sum_298 = None + reduce_scatter_tensor_337 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_10, 'avg', 128, '0'); convert_element_type_default_10 = None + wait_tensor_944 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_337); reduce_scatter_tensor_337 = None + view_2238 = torch.ops.aten.view.default(add_2134, [8192, 2048]) + unsqueeze_77 = torch.ops.aten.unsqueeze.default(view_2238, 1) + convert_element_type_3234 = torch.ops.prims.convert_element_type.default(unsqueeze_77, torch.float32); unsqueeze_77 = None + bmm_74 = torch.ops.aten.bmm.default(permute_1606, convert_element_type_3234); permute_1606 = None + bmm_75 = torch.ops.aten.bmm.default(convert_element_type_3234, permute_1607); convert_element_type_3234 = permute_1607 = None + convert_element_type_3235 = torch.ops.prims.convert_element_type.default(bmm_74, torch.bfloat16); bmm_74 = None + view_2239 = torch.ops.aten.view.default(bmm_75, [8192, 6]); bmm_75 = None + view_2240 = torch.ops.aten.view.default(convert_element_type_3235, [49152, 2048]); convert_element_type_3235 = None + index_100 = torch.ops.aten.index.Tensor(view_2240, [getitem_131]); view_2240 = getitem_131 = None + permute_1608 = torch.ops.aten.permute.default(view_2238, [1, 0]) + mm_602 = torch.ops.aten.mm.default(permute_1608, mul_104); permute_1608 = mul_104 = None + convert_element_type_140 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_140, 128, '0'); convert_element_type_140 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_1610 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_603 = torch.ops.aten.mm.default(view_2238, permute_1610); view_2238 = permute_1610 = None + convert_element_type_3240 = torch.ops.prims.convert_element_type.default(mm_602, torch.float32); mm_602 = None + reduce_scatter_tensor_338 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3240, 'avg', 128, '0'); convert_element_type_3240 = None + wait_tensor_945 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_338); reduce_scatter_tensor_338 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mm_20, torch.float32); mm_20 = None + neg_4 = torch.ops.aten.neg.default(convert_element_type_135) + exp_6 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_136 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + div_10 = torch.ops.aten.div.Tensor(convert_element_type_135, add_136) + convert_element_type_136 = torch.ops.prims.convert_element_type.default(div_10, torch.bfloat16); div_10 = None + mul_2107 = torch.ops.aten.mul.Tensor(mm_603, convert_element_type_136); convert_element_type_136 = None + mul_2108 = torch.ops.aten.mul.Tensor(mm_603, mm_21); mm_603 = mm_21 = None + permute_1612 = torch.ops.aten.permute.default(mul_2107, [1, 0]) + mm_604 = torch.ops.aten.mm.default(permute_1612, view_125); permute_1612 = None + convert_element_type_137 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_137, 128, '0'); convert_element_type_137 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_39 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + permute_1614 = torch.ops.aten.permute.default(permute_39, [1, 0]); permute_39 = None + mm_605 = torch.ops.aten.mm.default(mul_2107, permute_1614); mul_2107 = permute_1614 = None + convert_element_type_3245 = torch.ops.prims.convert_element_type.default(mm_604, torch.float32); mm_604 = None + reduce_scatter_tensor_339 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3245, 'avg', 128, '0'); convert_element_type_3245 = None + wait_tensor_946 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_339); reduce_scatter_tensor_339 = None + convert_element_type_3246 = torch.ops.prims.convert_element_type.default(mul_2108, torch.float32); mul_2108 = None + reciprocal_48 = torch.ops.aten.reciprocal.default(add_136); add_136 = None + mul_2109 = torch.ops.aten.mul.Tensor(reciprocal_48, 1); reciprocal_48 = None + mul_2110 = torch.ops.aten.mul.Tensor(convert_element_type_3246, mul_2109); convert_element_type_3246 = None + sub_769 = torch.ops.aten.sub.Tensor(1, mul_2109); mul_2109 = None + mul_2111 = torch.ops.aten.mul.Tensor(convert_element_type_135, sub_769); convert_element_type_135 = sub_769 = None + add_2136 = torch.ops.aten.add.Tensor(mul_2111, 1); mul_2111 = None + mul_2112 = torch.ops.aten.mul.Tensor(mul_2110, add_2136); mul_2110 = add_2136 = None + convert_element_type_3248 = torch.ops.prims.convert_element_type.default(mul_2112, torch.bfloat16); mul_2112 = None + permute_1616 = torch.ops.aten.permute.default(convert_element_type_3248, [1, 0]) + mm_606 = torch.ops.aten.mm.default(permute_1616, view_125); permute_1616 = None + convert_element_type_132 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_132, 128, '0'); convert_element_type_132 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_38 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + permute_1618 = torch.ops.aten.permute.default(permute_38, [1, 0]); permute_38 = None + mm_607 = torch.ops.aten.mm.default(convert_element_type_3248, permute_1618); convert_element_type_3248 = permute_1618 = None + add_2137 = torch.ops.aten.add.Tensor(mm_605, mm_607); mm_605 = mm_607 = None + convert_element_type_3253 = torch.ops.prims.convert_element_type.default(mm_606, torch.float32); mm_606 = None + reduce_scatter_tensor_340 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3253, 'avg', 128, '0'); convert_element_type_3253 = None + wait_tensor_947 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_340); reduce_scatter_tensor_340 = None + all_to_all_single_126 = torch.ops._c10d_functional.all_to_all_single.default(index_100, [_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31], [_local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23], '1033'); index_100 = None + wait_tensor_948 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_126); all_to_all_single_126 = None + full_492 = torch.ops.aten.full.default([sym_size_int_5, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_5 = None + slice_scatter_24 = torch.ops.aten.slice_scatter.default(full_492, wait_tensor_948, 0, 0, -1); wait_tensor_948 = None + index_101 = torch.ops.aten.index.Tensor(slice_scatter_24, [getitem_132]); slice_scatter_24 = None + permute_1620 = torch.ops.aten.permute.default(index_101, [1, 0]) + _grouped_mm_222 = torch.ops.aten._grouped_mm.default(permute_1620, mul_84, cumsum_5); permute_1620 = mul_84 = None + _grouped_mm_223 = torch.ops.aten._grouped_mm.default(index_101, permute_1622, cumsum_5); index_101 = permute_1622 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(_grouped_mm_3, torch.float32); _grouped_mm_3 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_130) + exp_5 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_100 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + div_9 = torch.ops.aten.div.Tensor(convert_element_type_130, add_100) + convert_element_type_131 = torch.ops.prims.convert_element_type.default(div_9, torch.bfloat16); div_9 = None + mul_2113 = torch.ops.aten.mul.Tensor(_grouped_mm_223, convert_element_type_131); convert_element_type_131 = None + mul_2114 = torch.ops.aten.mul.Tensor(_grouped_mm_223, _grouped_mm_4); _grouped_mm_223 = _grouped_mm_4 = None + permute_1624 = torch.ops.aten.permute.default(mul_2113, [1, 0]) + _grouped_mm_224 = torch.ops.aten._grouped_mm.default(permute_1624, index_3, cumsum_5); permute_1624 = None + _grouped_mm_225 = torch.ops.aten._grouped_mm.default(mul_2113, permute_1626, cumsum_5); mul_2113 = permute_1626 = None + convert_element_type_3254 = torch.ops.prims.convert_element_type.default(mul_2114, torch.float32); mul_2114 = None + reciprocal_49 = torch.ops.aten.reciprocal.default(add_100); add_100 = None + mul_2115 = torch.ops.aten.mul.Tensor(reciprocal_49, 1); reciprocal_49 = None + mul_2116 = torch.ops.aten.mul.Tensor(convert_element_type_3254, mul_2115); convert_element_type_3254 = None + sub_770 = torch.ops.aten.sub.Tensor(1, mul_2115); mul_2115 = None + mul_2117 = torch.ops.aten.mul.Tensor(convert_element_type_130, sub_770); convert_element_type_130 = sub_770 = None + add_2139 = torch.ops.aten.add.Tensor(mul_2117, 1); mul_2117 = None + mul_2118 = torch.ops.aten.mul.Tensor(mul_2116, add_2139); mul_2116 = add_2139 = None + convert_element_type_3256 = torch.ops.prims.convert_element_type.default(mul_2118, torch.bfloat16); mul_2118 = None + permute_1628 = torch.ops.aten.permute.default(convert_element_type_3256, [1, 0]) + _grouped_mm_226 = torch.ops.aten._grouped_mm.default(permute_1628, index_3, cumsum_5); permute_1628 = index_3 = None + _grouped_mm_227 = torch.ops.aten._grouped_mm.default(convert_element_type_3256, permute_1630, cumsum_5); convert_element_type_3256 = permute_1630 = cumsum_5 = None + add_2140 = torch.ops.aten.add.Tensor(_grouped_mm_225, _grouped_mm_227); _grouped_mm_225 = _grouped_mm_227 = None + convert_element_type_3257 = torch.ops.prims.convert_element_type.default(_grouped_mm_224, torch.float32); _grouped_mm_224 = None + div_276 = torch.ops.aten.div.Tensor(convert_element_type_3257, 128); convert_element_type_3257 = None + split_1429 = torch.ops.aten.split.Tensor(div_276, 88, 1); div_276 = None + getitem_26885 = split_1429[0] + getitem_26902 = split_1429[1] + getitem_26919 = split_1429[2] + getitem_26936 = split_1429[3] + getitem_26953 = split_1429[4] + getitem_26970 = split_1429[5] + getitem_26987 = split_1429[6] + getitem_27004 = split_1429[7] + getitem_27021 = split_1429[8] + getitem_27038 = split_1429[9] + getitem_27055 = split_1429[10] + getitem_27072 = split_1429[11] + getitem_27089 = split_1429[12] + getitem_27106 = split_1429[13] + getitem_27123 = split_1429[14] + getitem_27140 = split_1429[15]; split_1429 = None + cat_428 = torch.ops.aten.cat.default([getitem_26885, getitem_26902, getitem_26919, getitem_26936, getitem_26953, getitem_26970, getitem_26987, getitem_27004, getitem_27021, getitem_27038, getitem_27055, getitem_27072, getitem_27089, getitem_27106, getitem_27123, getitem_27140]); getitem_26885 = getitem_26902 = getitem_26919 = getitem_26936 = getitem_26953 = getitem_26970 = getitem_26987 = getitem_27004 = getitem_27021 = getitem_27038 = getitem_27055 = getitem_27072 = getitem_27089 = getitem_27106 = getitem_27123 = getitem_27140 = None + reduce_scatter_tensor_341 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_428, 'sum', 16, '1025'); cat_428 = None + wait_tensor_949 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_341); reduce_scatter_tensor_341 = None + convert_element_type_3258 = torch.ops.prims.convert_element_type.default(_grouped_mm_222, torch.float32); _grouped_mm_222 = None + div_277 = torch.ops.aten.div.Tensor(convert_element_type_3258, 128); convert_element_type_3258 = None + split_1446 = torch.ops.aten.split.Tensor(div_277, 128, 1); div_277 = None + getitem_27157 = split_1446[0] + getitem_27174 = split_1446[1] + getitem_27191 = split_1446[2] + getitem_27208 = split_1446[3] + getitem_27225 = split_1446[4] + getitem_27242 = split_1446[5] + getitem_27259 = split_1446[6] + getitem_27276 = split_1446[7] + getitem_27293 = split_1446[8] + getitem_27310 = split_1446[9] + getitem_27327 = split_1446[10] + getitem_27344 = split_1446[11] + getitem_27361 = split_1446[12] + getitem_27378 = split_1446[13] + getitem_27395 = split_1446[14] + getitem_27412 = split_1446[15]; split_1446 = None + cat_429 = torch.ops.aten.cat.default([getitem_27157, getitem_27174, getitem_27191, getitem_27208, getitem_27225, getitem_27242, getitem_27259, getitem_27276, getitem_27293, getitem_27310, getitem_27327, getitem_27344, getitem_27361, getitem_27378, getitem_27395, getitem_27412]); getitem_27157 = getitem_27174 = getitem_27191 = getitem_27208 = getitem_27225 = getitem_27242 = getitem_27259 = getitem_27276 = getitem_27293 = getitem_27310 = getitem_27327 = getitem_27344 = getitem_27361 = getitem_27378 = getitem_27395 = getitem_27412 = None + reduce_scatter_tensor_342 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_429, 'sum', 16, '1025'); cat_429 = None + wait_tensor_950 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_342); reduce_scatter_tensor_342 = None + convert_element_type_3259 = torch.ops.prims.convert_element_type.default(_grouped_mm_226, torch.float32); _grouped_mm_226 = None + div_278 = torch.ops.aten.div.Tensor(convert_element_type_3259, 128); convert_element_type_3259 = None + split_1463 = torch.ops.aten.split.Tensor(div_278, 88, 1); div_278 = None + getitem_27429 = split_1463[0] + getitem_27446 = split_1463[1] + getitem_27463 = split_1463[2] + getitem_27480 = split_1463[3] + getitem_27497 = split_1463[4] + getitem_27514 = split_1463[5] + getitem_27531 = split_1463[6] + getitem_27548 = split_1463[7] + getitem_27565 = split_1463[8] + getitem_27582 = split_1463[9] + getitem_27599 = split_1463[10] + getitem_27616 = split_1463[11] + getitem_27633 = split_1463[12] + getitem_27650 = split_1463[13] + getitem_27667 = split_1463[14] + getitem_27684 = split_1463[15]; split_1463 = None + cat_430 = torch.ops.aten.cat.default([getitem_27429, getitem_27446, getitem_27463, getitem_27480, getitem_27497, getitem_27514, getitem_27531, getitem_27548, getitem_27565, getitem_27582, getitem_27599, getitem_27616, getitem_27633, getitem_27650, getitem_27667, getitem_27684]); getitem_27429 = getitem_27446 = getitem_27463 = getitem_27480 = getitem_27497 = getitem_27514 = getitem_27531 = getitem_27548 = getitem_27565 = getitem_27582 = getitem_27599 = getitem_27616 = getitem_27633 = getitem_27650 = getitem_27667 = getitem_27684 = None + reduce_scatter_tensor_343 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_430, 'sum', 16, '1025'); cat_430 = None + wait_tensor_951 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_343); reduce_scatter_tensor_343 = None + index_put_100 = torch.ops.aten.index_put.default(full_492, [getitem_132], add_2140, True); full_492 = getitem_132 = add_2140 = None + slice_306 = torch.ops.aten.slice.Tensor(index_put_100, 0, 0, add_2141); index_put_100 = add_2141 = None + all_to_all_single_127 = torch.ops._c10d_functional.all_to_all_single.default(slice_306, [_local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23], [_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31], '1033'); slice_306 = _local_scalar_dense_16 = _local_scalar_dense_17 = _local_scalar_dense_18 = _local_scalar_dense_19 = _local_scalar_dense_20 = _local_scalar_dense_21 = _local_scalar_dense_22 = _local_scalar_dense_23 = _local_scalar_dense_24 = _local_scalar_dense_25 = _local_scalar_dense_26 = _local_scalar_dense_27 = _local_scalar_dense_28 = _local_scalar_dense_29 = _local_scalar_dense_30 = _local_scalar_dense_31 = None + wait_tensor_952 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_127); all_to_all_single_127 = None + index_put_101 = torch.ops.aten.index_put.default(full_default_52, [div_7], wait_tensor_952, True); div_7 = wait_tensor_952 = None + add_2145 = torch.ops.aten.add.Tensor(add_2137, index_put_101); add_2137 = index_put_101 = None + mul_2119 = torch.ops.aten.mul.Tensor(view_2239, 1.0); view_2239 = None + scatter_add_24 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_129, mul_2119); getitem_129 = mul_2119 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(mm_19, torch.float32); mm_19 = None + sub_24 = torch.ops.aten.sub.Tensor(convert_element_type_119, amax_1); convert_element_type_119 = amax_1 = None + exp_4 = torch.ops.aten.exp.default(sub_24); sub_24 = None + div_6 = torch.ops.aten.div.Tensor(exp_4, sum_5); exp_4 = sum_5 = None + mul_2120 = torch.ops.aten.mul.Tensor(scatter_add_24, div_6); scatter_add_24 = None + sum_299 = torch.ops.aten.sum.dim_IntList(mul_2120, [1], True) + neg_127 = torch.ops.aten.neg.default(div_6); div_6 = None + fma_24 = torch.ops.prims.fma.default(neg_127, sum_299, mul_2120); neg_127 = sum_299 = mul_2120 = None + convert_element_type_3260 = torch.ops.prims.convert_element_type.default(fma_24, torch.bfloat16); fma_24 = None + permute_1632 = torch.ops.aten.permute.default(convert_element_type_3260, [1, 0]) + mm_608 = torch.ops.aten.mm.default(permute_1632, view_125); permute_1632 = view_125 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 128, '0'); convert_element_type_116 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + slice_15 = torch.ops.aten.slice.Tensor(wait_tensor_39, 0, 0, 64); wait_tensor_39 = None + permute_34 = torch.ops.aten.permute.default(slice_15, [1, 0]); slice_15 = None + permute_1634 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_609 = torch.ops.aten.mm.default(convert_element_type_3260, permute_1634); convert_element_type_3260 = permute_1634 = None + add_2146 = torch.ops.aten.add.Tensor(add_2145, mm_609); add_2145 = mm_609 = None + convert_element_type_3265 = torch.ops.prims.convert_element_type.default(mm_608, torch.float32); mm_608 = None + split_1479 = torch.ops.aten.split.Tensor(convert_element_type_3265, 1); convert_element_type_3265 = None + getitem_27685 = split_1479[0] + getitem_27686 = split_1479[1] + getitem_27687 = split_1479[2] + getitem_27688 = split_1479[3] + getitem_27689 = split_1479[4] + getitem_27690 = split_1479[5] + getitem_27691 = split_1479[6] + getitem_27692 = split_1479[7] + getitem_27693 = split_1479[8] + getitem_27694 = split_1479[9] + getitem_27695 = split_1479[10] + getitem_27696 = split_1479[11] + getitem_27697 = split_1479[12] + getitem_27698 = split_1479[13] + getitem_27699 = split_1479[14] + getitem_27700 = split_1479[15] + getitem_27701 = split_1479[16] + getitem_27702 = split_1479[17] + getitem_27703 = split_1479[18] + getitem_27704 = split_1479[19] + getitem_27705 = split_1479[20] + getitem_27706 = split_1479[21] + getitem_27707 = split_1479[22] + getitem_27708 = split_1479[23] + getitem_27709 = split_1479[24] + getitem_27710 = split_1479[25] + getitem_27711 = split_1479[26] + getitem_27712 = split_1479[27] + getitem_27713 = split_1479[28] + getitem_27714 = split_1479[29] + getitem_27715 = split_1479[30] + getitem_27716 = split_1479[31] + getitem_27717 = split_1479[32] + getitem_27718 = split_1479[33] + getitem_27719 = split_1479[34] + getitem_27720 = split_1479[35] + getitem_27721 = split_1479[36] + getitem_27722 = split_1479[37] + getitem_27723 = split_1479[38] + getitem_27724 = split_1479[39] + getitem_27725 = split_1479[40] + getitem_27726 = split_1479[41] + getitem_27727 = split_1479[42] + getitem_27728 = split_1479[43] + getitem_27729 = split_1479[44] + getitem_27730 = split_1479[45] + getitem_27731 = split_1479[46] + getitem_27732 = split_1479[47] + getitem_27733 = split_1479[48] + getitem_27734 = split_1479[49] + getitem_27735 = split_1479[50] + getitem_27736 = split_1479[51] + getitem_27737 = split_1479[52] + getitem_27738 = split_1479[53] + getitem_27739 = split_1479[54] + getitem_27740 = split_1479[55] + getitem_27741 = split_1479[56] + getitem_27742 = split_1479[57] + getitem_27743 = split_1479[58] + getitem_27744 = split_1479[59] + getitem_27745 = split_1479[60] + getitem_27746 = split_1479[61] + getitem_27747 = split_1479[62] + getitem_27748 = split_1479[63]; split_1479 = None + cat_431 = torch.ops.aten.cat.default([getitem_27685, getitem_27686, getitem_27687, getitem_27688, getitem_27689, getitem_27690, getitem_27691, getitem_27692, getitem_27693, getitem_27694, getitem_27695, getitem_27696, getitem_27697, getitem_27698, getitem_27699, getitem_27700, getitem_27701, getitem_27702, getitem_27703, getitem_27704, getitem_27705, getitem_27706, getitem_27707, getitem_27708, getitem_27709, getitem_27710, getitem_27711, getitem_27712, getitem_27713, getitem_27714, getitem_27715, getitem_27716, getitem_27717, getitem_27718, getitem_27719, getitem_27720, getitem_27721, getitem_27722, getitem_27723, getitem_27724, getitem_27725, getitem_27726, getitem_27727, getitem_27728, getitem_27729, getitem_27730, getitem_27731, getitem_27732, getitem_27733, getitem_27734, getitem_27735, getitem_27736, getitem_27737, getitem_27738, getitem_27739, getitem_27740, getitem_27741, getitem_27742, getitem_27743, getitem_27744, getitem_27745, getitem_27746, getitem_27747, getitem_27748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_27685 = getitem_27686 = getitem_27687 = getitem_27688 = getitem_27689 = getitem_27690 = getitem_27691 = getitem_27692 = getitem_27693 = getitem_27694 = getitem_27695 = getitem_27696 = getitem_27697 = getitem_27698 = getitem_27699 = getitem_27700 = getitem_27701 = getitem_27702 = getitem_27703 = getitem_27704 = getitem_27705 = getitem_27706 = getitem_27707 = getitem_27708 = getitem_27709 = getitem_27710 = getitem_27711 = getitem_27712 = getitem_27713 = getitem_27714 = getitem_27715 = getitem_27716 = getitem_27717 = getitem_27718 = getitem_27719 = getitem_27720 = getitem_27721 = getitem_27722 = getitem_27723 = getitem_27724 = getitem_27725 = getitem_27726 = getitem_27727 = getitem_27728 = getitem_27729 = getitem_27730 = getitem_27731 = getitem_27732 = getitem_27733 = getitem_27734 = getitem_27735 = getitem_27736 = getitem_27737 = getitem_27738 = getitem_27739 = getitem_27740 = getitem_27741 = getitem_27742 = getitem_27743 = getitem_27744 = getitem_27745 = getitem_27746 = getitem_27747 = getitem_27748 = None + reduce_scatter_tensor_344 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_431, 'avg', 128, '0'); cat_431 = None + wait_tensor_953 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_344); reduce_scatter_tensor_344 = None + view_2241 = torch.ops.aten.view.default(add_2146, [2, 4096, 2048]); add_2146 = None + convert_element_type_3266 = torch.ops.prims.convert_element_type.default(view_2241, torch.float32); view_2241 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_113, 128, '0'); convert_element_type_113 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_3268 = torch.ops.prims.convert_element_type.default(wait_tensor_38, torch.float32); wait_tensor_38 = None + mul_2121 = torch.ops.aten.mul.Tensor(convert_element_type_3266, convert_element_type_3268); convert_element_type_3268 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(add_76, torch.float32); add_76 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_114, rsqrt_8); convert_element_type_114 = None + mul_2123 = torch.ops.aten.mul.Tensor(mul_64, mul_2121) + sum_300 = torch.ops.aten.sum.dim_IntList(mul_2123, [2], True); mul_2123 = None + div_279 = torch.ops.aten.div.Tensor(mul_64, 2048) + mul_2124 = torch.ops.aten.mul.Tensor(div_279, sum_300); div_279 = sum_300 = None + sub_772 = torch.ops.aten.sub.Tensor(mul_2121, mul_2124); mul_2121 = mul_2124 = None + mul_2125 = torch.ops.aten.mul.Tensor(sub_772, rsqrt_8); sub_772 = rsqrt_8 = None + mul_2126 = torch.ops.aten.mul.Tensor(convert_element_type_3266, mul_64); convert_element_type_3266 = mul_64 = None + sum_301 = torch.ops.aten.sum.dim_IntList(mul_2126, [0, 1]); mul_2126 = None + convert_element_type_3269 = torch.ops.prims.convert_element_type.default(mul_2125, torch.bfloat16); mul_2125 = None + add_2147 = torch.ops.aten.add.Tensor(add_2134, convert_element_type_3269); add_2134 = convert_element_type_3269 = None + convert_element_type_default_9 = torch.ops.prims.convert_element_type.default(sum_301, torch.float32); sum_301 = None + reduce_scatter_tensor_345 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_9, 'avg', 128, '0'); convert_element_type_default_9 = None + wait_tensor_954 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_345); reduce_scatter_tensor_345 = None + view_2242 = torch.ops.aten.view.default(add_2147, [8192, 2048]) + permute_1636 = torch.ops.aten.permute.default(view_2242, [1, 0]) + permute_32 = torch.ops.aten.permute.default(getitem_125, [0, 2, 1, 3]) + view_120 = torch.ops.aten.view.default(permute_32, [2, 4096, -1]); permute_32 = None + view_122 = torch.ops.aten.view.default(view_120, [8192, 2048]); view_120 = None + mm_610 = torch.ops.aten.mm.default(permute_1636, view_122); permute_1636 = view_122 = None + convert_element_type_110 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_110, 128, '0'); convert_element_type_110 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + permute_1638 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_611 = torch.ops.aten.mm.default(view_2242, permute_1638); view_2242 = permute_1638 = None + view_2243 = torch.ops.aten.view.default(mm_611, [2, 4096, 2048]); mm_611 = None + convert_element_type_3276 = torch.ops.prims.convert_element_type.default(mm_610, torch.float32); mm_610 = None + reduce_scatter_tensor_346 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3276, 'avg', 128, '0'); convert_element_type_3276 = None + wait_tensor_955 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_346); reduce_scatter_tensor_346 = None + view_2244 = torch.ops.aten.view.default(view_2243, [2, 4096, 16, 128]); view_2243 = None + permute_1640 = torch.ops.aten.permute.default(view_2244, [0, 2, 1, 3]); view_2244 = None + fw_graph24 = self.fw_graph24 + joint_graph24 = self.joint_graph24 + mask_graph24 = self.mask_graph24 + flex_attention_backward_24 = torch.ops.higher_order.flex_attention_backward(permute_29, permute_30, permute_31, getitem_125, getitem_126, permute_1640, None, fw_graph24, joint_graph24, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph24), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_29 = permute_30 = permute_31 = getitem_125 = getitem_126 = permute_1640 = fw_graph24 = joint_graph24 = mask_graph24 = None + getitem_27749 = flex_attention_backward_24[0] + getitem_27750 = flex_attention_backward_24[1] + getitem_27751 = flex_attention_backward_24[2]; flex_attention_backward_24 = None + permute_1641 = torch.ops.aten.permute.default(getitem_27751, [0, 2, 1, 3]); getitem_27751 = None + permute_1642 = torch.ops.aten.permute.default(getitem_27750, [0, 2, 1, 3]); getitem_27750 = None + permute_1643 = torch.ops.aten.permute.default(getitem_27749, [0, 2, 1, 3]); getitem_27749 = None + slice_308 = torch.ops.aten.slice.Tensor(permute_1642, 3, 0, 128) + slice_309 = torch.ops.aten.slice.Tensor(permute_1642, 3, 128, 192); permute_1642 = None + sum_302 = torch.ops.aten.sum.dim_IntList(slice_309, [2], True); slice_309 = None + cat_432 = torch.ops.aten.cat.default([slice_308, permute_1641], 3); slice_308 = permute_1641 = None + view_2245 = torch.ops.aten.view.default(cat_432, [2, 4096, 4096]); cat_432 = None + view_2246 = torch.ops.aten.view.default(view_2245, [8192, 4096]); view_2245 = None + permute_1644 = torch.ops.aten.permute.default(view_2246, [1, 0]) + mm_612 = torch.ops.aten.mm.default(permute_1644, view_117); permute_1644 = view_117 = None + convert_element_type_107 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_107, 128, '0'); convert_element_type_107 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_28 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + permute_1646 = torch.ops.aten.permute.default(permute_28, [1, 0]); permute_28 = None + mm_613 = torch.ops.aten.mm.default(view_2246, permute_1646); view_2246 = permute_1646 = None + view_2247 = torch.ops.aten.view.default(mm_613, [2, 4096, 512]); mm_613 = None + convert_element_type_3281 = torch.ops.prims.convert_element_type.default(mm_612, torch.float32); mm_612 = None + reduce_scatter_tensor_347 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3281, 'avg', 128, '0'); convert_element_type_3281 = None + wait_tensor_956 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_347); reduce_scatter_tensor_347 = None + convert_element_type_3282 = torch.ops.prims.convert_element_type.default(view_2247, torch.float32); view_2247 = None + convert_element_type_104 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_104, 128, '0'); convert_element_type_104 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + convert_element_type_3284 = torch.ops.prims.convert_element_type.default(wait_tensor_35, torch.float32); wait_tensor_35 = None + mul_2127 = torch.ops.aten.mul.Tensor(convert_element_type_3282, convert_element_type_3284); convert_element_type_3284 = None + convert_element_type_105 = torch.ops.prims.convert_element_type.default(getitem_121, torch.float32); getitem_121 = None + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_105, rsqrt_7); convert_element_type_105 = None + mul_2129 = torch.ops.aten.mul.Tensor(mul_62, mul_2127) + sum_303 = torch.ops.aten.sum.dim_IntList(mul_2129, [2], True); mul_2129 = None + div_280 = torch.ops.aten.div.Tensor(mul_62, 512) + mul_2130 = torch.ops.aten.mul.Tensor(div_280, sum_303); div_280 = sum_303 = None + sub_773 = torch.ops.aten.sub.Tensor(mul_2127, mul_2130); mul_2127 = mul_2130 = None + mul_2131 = torch.ops.aten.mul.Tensor(sub_773, rsqrt_7); sub_773 = rsqrt_7 = None + mul_2132 = torch.ops.aten.mul.Tensor(convert_element_type_3282, mul_62); convert_element_type_3282 = mul_62 = None + sum_304 = torch.ops.aten.sum.dim_IntList(mul_2132, [0, 1]); mul_2132 = None + convert_element_type_3285 = torch.ops.prims.convert_element_type.default(mul_2131, torch.bfloat16); mul_2131 = None + convert_element_type_default_8 = torch.ops.prims.convert_element_type.default(sum_304, torch.float32); sum_304 = None + reduce_scatter_tensor_348 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_8, 'avg', 128, '0'); convert_element_type_default_8 = None + wait_tensor_957 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_348); reduce_scatter_tensor_348 = None + convert_element_type_3288 = torch.ops.prims.convert_element_type.default(sum_302, torch.float32); sum_302 = None + view_2248 = torch.ops.aten.view.default(convert_element_type_3288, [2, 4096, 1, 32, 2]); convert_element_type_3288 = None + view_as_complex_102 = torch.ops.aten.view_as_complex.default(view_2248); view_2248 = None + mul_2133 = torch.ops.aten.mul.Tensor(view_as_complex_102, clone_9); view_as_complex_102 = None + view_as_real_102 = torch.ops.aten.view_as_real.default(mul_2133); mul_2133 = None + view_2249 = torch.ops.aten.view.default(view_as_real_102, [2, 4096, 1, 64]); view_as_real_102 = None + convert_element_type_3289 = torch.ops.prims.convert_element_type.default(view_2249, torch.bfloat16); view_2249 = None + squeeze_50 = torch.ops.aten.squeeze.dim(convert_element_type_3289, 2); convert_element_type_3289 = None + cat_433 = torch.ops.aten.cat.default([convert_element_type_3285, squeeze_50], 2); convert_element_type_3285 = squeeze_50 = None + view_2250 = torch.ops.aten.view.default(cat_433, [8192, 576]); cat_433 = None + permute_1648 = torch.ops.aten.permute.default(view_2250, [1, 0]) + mm_614 = torch.ops.aten.mm.default(permute_1648, view_103); permute_1648 = None + convert_element_type_99 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_99, 128, '0'); convert_element_type_99 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + slice_13 = torch.ops.aten.slice.Tensor(wait_tensor_34, 0, 0, 576); wait_tensor_34 = None + permute_27 = torch.ops.aten.permute.default(slice_13, [1, 0]); slice_13 = None + permute_1650 = torch.ops.aten.permute.default(permute_27, [1, 0]); permute_27 = None + mm_615 = torch.ops.aten.mm.default(view_2250, permute_1650); view_2250 = permute_1650 = None + view_2251 = torch.ops.aten.view.default(mm_615, [2, 4096, 2048]); mm_615 = None + convert_element_type_3294 = torch.ops.prims.convert_element_type.default(mm_614, torch.float32); mm_614 = None + split_1480 = torch.ops.aten.split.Tensor(convert_element_type_3294, 5); convert_element_type_3294 = None + getitem_27753 = split_1480[0] + getitem_27754 = split_1480[1] + getitem_27755 = split_1480[2] + getitem_27756 = split_1480[3] + getitem_27757 = split_1480[4] + getitem_27758 = split_1480[5] + getitem_27759 = split_1480[6] + getitem_27760 = split_1480[7] + getitem_27761 = split_1480[8] + getitem_27762 = split_1480[9] + getitem_27763 = split_1480[10] + getitem_27764 = split_1480[11] + getitem_27765 = split_1480[12] + getitem_27766 = split_1480[13] + getitem_27767 = split_1480[14] + getitem_27768 = split_1480[15] + getitem_27769 = split_1480[16] + getitem_27770 = split_1480[17] + getitem_27771 = split_1480[18] + getitem_27772 = split_1480[19] + getitem_27773 = split_1480[20] + getitem_27774 = split_1480[21] + getitem_27775 = split_1480[22] + getitem_27776 = split_1480[23] + getitem_27777 = split_1480[24] + getitem_27778 = split_1480[25] + getitem_27779 = split_1480[26] + getitem_27780 = split_1480[27] + getitem_27781 = split_1480[28] + getitem_27782 = split_1480[29] + getitem_27783 = split_1480[30] + getitem_27784 = split_1480[31] + getitem_27785 = split_1480[32] + getitem_27786 = split_1480[33] + getitem_27787 = split_1480[34] + getitem_27788 = split_1480[35] + getitem_27789 = split_1480[36] + getitem_27790 = split_1480[37] + getitem_27791 = split_1480[38] + getitem_27792 = split_1480[39] + getitem_27793 = split_1480[40] + getitem_27794 = split_1480[41] + getitem_27795 = split_1480[42] + getitem_27796 = split_1480[43] + getitem_27797 = split_1480[44] + getitem_27798 = split_1480[45] + getitem_27799 = split_1480[46] + getitem_27800 = split_1480[47] + getitem_27801 = split_1480[48] + getitem_27802 = split_1480[49] + getitem_27803 = split_1480[50] + getitem_27804 = split_1480[51] + getitem_27805 = split_1480[52] + getitem_27806 = split_1480[53] + getitem_27807 = split_1480[54] + getitem_27808 = split_1480[55] + getitem_27809 = split_1480[56] + getitem_27810 = split_1480[57] + getitem_27811 = split_1480[58] + getitem_27812 = split_1480[59] + getitem_27813 = split_1480[60] + getitem_27814 = split_1480[61] + getitem_27815 = split_1480[62] + getitem_27816 = split_1480[63] + getitem_27817 = split_1480[64] + getitem_27818 = split_1480[65] + getitem_27819 = split_1480[66] + getitem_27820 = split_1480[67] + getitem_27821 = split_1480[68] + getitem_27822 = split_1480[69] + getitem_27823 = split_1480[70] + getitem_27824 = split_1480[71] + getitem_27825 = split_1480[72] + getitem_27826 = split_1480[73] + getitem_27827 = split_1480[74] + getitem_27828 = split_1480[75] + getitem_27829 = split_1480[76] + getitem_27830 = split_1480[77] + getitem_27831 = split_1480[78] + getitem_27832 = split_1480[79] + getitem_27833 = split_1480[80] + getitem_27834 = split_1480[81] + getitem_27835 = split_1480[82] + getitem_27836 = split_1480[83] + getitem_27837 = split_1480[84] + getitem_27838 = split_1480[85] + getitem_27839 = split_1480[86] + getitem_27840 = split_1480[87] + getitem_27841 = split_1480[88] + getitem_27842 = split_1480[89] + getitem_27843 = split_1480[90] + getitem_27844 = split_1480[91] + getitem_27845 = split_1480[92] + getitem_27846 = split_1480[93] + getitem_27847 = split_1480[94] + getitem_27848 = split_1480[95] + getitem_27849 = split_1480[96] + getitem_27850 = split_1480[97] + getitem_27851 = split_1480[98] + getitem_27852 = split_1480[99] + getitem_27853 = split_1480[100] + getitem_27854 = split_1480[101] + getitem_27855 = split_1480[102] + getitem_27856 = split_1480[103] + getitem_27857 = split_1480[104] + getitem_27858 = split_1480[105] + getitem_27859 = split_1480[106] + getitem_27860 = split_1480[107] + getitem_27861 = split_1480[108] + getitem_27862 = split_1480[109] + getitem_27863 = split_1480[110] + getitem_27864 = split_1480[111] + getitem_27865 = split_1480[112] + getitem_27866 = split_1480[113] + getitem_27867 = split_1480[114] + getitem_27868 = split_1480[115]; split_1480 = None + constant_pad_nd_1912 = torch.ops.aten.constant_pad_nd.default(getitem_27868, [0, 0, 0, 4], 0.0); getitem_27868 = None + cat_434 = torch.ops.aten.cat.default([getitem_27753, getitem_27754, getitem_27755, getitem_27756, getitem_27757, getitem_27758, getitem_27759, getitem_27760, getitem_27761, getitem_27762, getitem_27763, getitem_27764, getitem_27765, getitem_27766, getitem_27767, getitem_27768, getitem_27769, getitem_27770, getitem_27771, getitem_27772, getitem_27773, getitem_27774, getitem_27775, getitem_27776, getitem_27777, getitem_27778, getitem_27779, getitem_27780, getitem_27781, getitem_27782, getitem_27783, getitem_27784, getitem_27785, getitem_27786, getitem_27787, getitem_27788, getitem_27789, getitem_27790, getitem_27791, getitem_27792, getitem_27793, getitem_27794, getitem_27795, getitem_27796, getitem_27797, getitem_27798, getitem_27799, getitem_27800, getitem_27801, getitem_27802, getitem_27803, getitem_27804, getitem_27805, getitem_27806, getitem_27807, getitem_27808, getitem_27809, getitem_27810, getitem_27811, getitem_27812, getitem_27813, getitem_27814, getitem_27815, getitem_27816, getitem_27817, getitem_27818, getitem_27819, getitem_27820, getitem_27821, getitem_27822, getitem_27823, getitem_27824, getitem_27825, getitem_27826, getitem_27827, getitem_27828, getitem_27829, getitem_27830, getitem_27831, getitem_27832, getitem_27833, getitem_27834, getitem_27835, getitem_27836, getitem_27837, getitem_27838, getitem_27839, getitem_27840, getitem_27841, getitem_27842, getitem_27843, getitem_27844, getitem_27845, getitem_27846, getitem_27847, getitem_27848, getitem_27849, getitem_27850, getitem_27851, getitem_27852, getitem_27853, getitem_27854, getitem_27855, getitem_27856, getitem_27857, getitem_27858, getitem_27859, getitem_27860, getitem_27861, getitem_27862, getitem_27863, getitem_27864, getitem_27865, getitem_27866, getitem_27867, constant_pad_nd_1912, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_27753 = getitem_27754 = getitem_27755 = getitem_27756 = getitem_27757 = getitem_27758 = getitem_27759 = getitem_27760 = getitem_27761 = getitem_27762 = getitem_27763 = getitem_27764 = getitem_27765 = getitem_27766 = getitem_27767 = getitem_27768 = getitem_27769 = getitem_27770 = getitem_27771 = getitem_27772 = getitem_27773 = getitem_27774 = getitem_27775 = getitem_27776 = getitem_27777 = getitem_27778 = getitem_27779 = getitem_27780 = getitem_27781 = getitem_27782 = getitem_27783 = getitem_27784 = getitem_27785 = getitem_27786 = getitem_27787 = getitem_27788 = getitem_27789 = getitem_27790 = getitem_27791 = getitem_27792 = getitem_27793 = getitem_27794 = getitem_27795 = getitem_27796 = getitem_27797 = getitem_27798 = getitem_27799 = getitem_27800 = getitem_27801 = getitem_27802 = getitem_27803 = getitem_27804 = getitem_27805 = getitem_27806 = getitem_27807 = getitem_27808 = getitem_27809 = getitem_27810 = getitem_27811 = getitem_27812 = getitem_27813 = getitem_27814 = getitem_27815 = getitem_27816 = getitem_27817 = getitem_27818 = getitem_27819 = getitem_27820 = getitem_27821 = getitem_27822 = getitem_27823 = getitem_27824 = getitem_27825 = getitem_27826 = getitem_27827 = getitem_27828 = getitem_27829 = getitem_27830 = getitem_27831 = getitem_27832 = getitem_27833 = getitem_27834 = getitem_27835 = getitem_27836 = getitem_27837 = getitem_27838 = getitem_27839 = getitem_27840 = getitem_27841 = getitem_27842 = getitem_27843 = getitem_27844 = getitem_27845 = getitem_27846 = getitem_27847 = getitem_27848 = getitem_27849 = getitem_27850 = getitem_27851 = getitem_27852 = getitem_27853 = getitem_27854 = getitem_27855 = getitem_27856 = getitem_27857 = getitem_27858 = getitem_27859 = getitem_27860 = getitem_27861 = getitem_27862 = getitem_27863 = getitem_27864 = getitem_27865 = getitem_27866 = getitem_27867 = constant_pad_nd_1912 = None + reduce_scatter_tensor_349 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_434, 'avg', 128, '0'); cat_434 = None + wait_tensor_958 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_349); reduce_scatter_tensor_349 = None + slice_310 = torch.ops.aten.slice.Tensor(permute_1643, 3, 0, 128) + slice_311 = torch.ops.aten.slice.Tensor(permute_1643, 3, 128, 192); permute_1643 = None + convert_element_type_3295 = torch.ops.prims.convert_element_type.default(slice_311, torch.float32); slice_311 = None + view_2252 = torch.ops.aten.view.default(convert_element_type_3295, [2, 4096, 16, 32, 2]); convert_element_type_3295 = None + view_as_complex_103 = torch.ops.aten.view_as_complex.default(view_2252); view_2252 = None + mul_2134 = torch.ops.aten.mul.Tensor(view_as_complex_103, clone_9); view_as_complex_103 = None + view_as_real_103 = torch.ops.aten.view_as_real.default(mul_2134); mul_2134 = None + view_2253 = torch.ops.aten.view.default(view_as_real_103, [2, 4096, 16, 64]); view_as_real_103 = None + convert_element_type_3296 = torch.ops.prims.convert_element_type.default(view_2253, torch.bfloat16); view_2253 = None + cat_435 = torch.ops.aten.cat.default([slice_310, convert_element_type_3296], 3); slice_310 = convert_element_type_3296 = None + view_2254 = torch.ops.aten.view.default(cat_435, [2, 4096, 3072]); cat_435 = None + view_2255 = torch.ops.aten.view.default(view_2254, [8192, 3072]); view_2254 = None + permute_1652 = torch.ops.aten.permute.default(view_2255, [1, 0]) + mm_616 = torch.ops.aten.mm.default(permute_1652, view_103); permute_1652 = view_103 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 128, '0'); convert_element_type_94 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_26 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + permute_1654 = torch.ops.aten.permute.default(permute_26, [1, 0]); permute_26 = None + mm_617 = torch.ops.aten.mm.default(view_2255, permute_1654); view_2255 = permute_1654 = None + view_2256 = torch.ops.aten.view.default(mm_617, [2, 4096, 2048]); mm_617 = None + add_2148 = torch.ops.aten.add.Tensor(view_2251, view_2256); view_2251 = view_2256 = None + convert_element_type_3301 = torch.ops.prims.convert_element_type.default(mm_616, torch.float32); mm_616 = None + reduce_scatter_tensor_350 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3301, 'avg', 128, '0'); convert_element_type_3301 = None + wait_tensor_959 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_350); reduce_scatter_tensor_350 = None + convert_element_type_3302 = torch.ops.prims.convert_element_type.default(add_2148, torch.float32); add_2148 = None + convert_element_type_91 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_91, 128, '0'); convert_element_type_91 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_3304 = torch.ops.prims.convert_element_type.default(wait_tensor_32, torch.float32); wait_tensor_32 = None + mul_2135 = torch.ops.aten.mul.Tensor(convert_element_type_3302, convert_element_type_3304); convert_element_type_3304 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(add_73, torch.float32); add_73 = None + mul_58 = torch.ops.aten.mul.Tensor(convert_element_type_92, rsqrt_6); convert_element_type_92 = None + mul_2137 = torch.ops.aten.mul.Tensor(mul_58, mul_2135) + sum_305 = torch.ops.aten.sum.dim_IntList(mul_2137, [2], True); mul_2137 = None + div_281 = torch.ops.aten.div.Tensor(mul_58, 2048) + mul_2138 = torch.ops.aten.mul.Tensor(div_281, sum_305); div_281 = sum_305 = None + sub_774 = torch.ops.aten.sub.Tensor(mul_2135, mul_2138); mul_2135 = mul_2138 = None + mul_2139 = torch.ops.aten.mul.Tensor(sub_774, rsqrt_6); sub_774 = rsqrt_6 = None + mul_2140 = torch.ops.aten.mul.Tensor(convert_element_type_3302, mul_58); convert_element_type_3302 = mul_58 = None + sum_306 = torch.ops.aten.sum.dim_IntList(mul_2140, [0, 1]); mul_2140 = None + convert_element_type_3305 = torch.ops.prims.convert_element_type.default(mul_2139, torch.bfloat16); mul_2139 = None + add_2149 = torch.ops.aten.add.Tensor(add_2147, convert_element_type_3305); add_2147 = convert_element_type_3305 = None + convert_element_type_default_7 = torch.ops.prims.convert_element_type.default(sum_306, torch.float32); sum_306 = None + reduce_scatter_tensor_351 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_7, 'avg', 128, '0'); convert_element_type_default_7 = None + wait_tensor_960 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_351); reduce_scatter_tensor_351 = None + view_2257 = torch.ops.aten.view.default(add_2149, [8192, 2048]) + unsqueeze_78 = torch.ops.aten.unsqueeze.default(view_2257, 1) + convert_element_type_3308 = torch.ops.prims.convert_element_type.default(unsqueeze_78, torch.float32); unsqueeze_78 = None + bmm_76 = torch.ops.aten.bmm.default(permute_1656, convert_element_type_3308); permute_1656 = None + bmm_77 = torch.ops.aten.bmm.default(convert_element_type_3308, permute_1657); convert_element_type_3308 = permute_1657 = None + convert_element_type_3309 = torch.ops.prims.convert_element_type.default(bmm_76, torch.bfloat16); bmm_76 = None + view_2258 = torch.ops.aten.view.default(bmm_77, [8192, 6]); bmm_77 = None + view_2259 = torch.ops.aten.view.default(convert_element_type_3309, [49152, 2048]); convert_element_type_3309 = None + index_102 = torch.ops.aten.index.Tensor(view_2259, [getitem_21]); view_2259 = getitem_21 = None + permute_1658 = torch.ops.aten.permute.default(view_2257, [1, 0]) + mm_618 = torch.ops.aten.mm.default(permute_1658, mul_55); permute_1658 = mul_55 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 128, '0'); convert_element_type_86 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_25 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + permute_1660 = torch.ops.aten.permute.default(permute_25, [1, 0]); permute_25 = None + mm_619 = torch.ops.aten.mm.default(view_2257, permute_1660); view_2257 = permute_1660 = None + convert_element_type_3314 = torch.ops.prims.convert_element_type.default(mm_618, torch.float32); mm_618 = None + reduce_scatter_tensor_352 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3314, 'avg', 128, '0'); convert_element_type_3314 = None + wait_tensor_961 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_352); reduce_scatter_tensor_352 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(mm_12, torch.float32); mm_12 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_81) + exp_3 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_68 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + div_5 = torch.ops.aten.div.Tensor(convert_element_type_81, add_68) + convert_element_type_82 = torch.ops.prims.convert_element_type.default(div_5, torch.bfloat16); div_5 = None + mul_2141 = torch.ops.aten.mul.Tensor(mm_619, convert_element_type_82); convert_element_type_82 = None + mul_2142 = torch.ops.aten.mul.Tensor(mm_619, mm_13); mm_619 = mm_13 = None + permute_1662 = torch.ops.aten.permute.default(mul_2141, [1, 0]) + mm_620 = torch.ops.aten.mm.default(permute_1662, view_58); permute_1662 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 128, '0'); convert_element_type_83 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + permute_1664 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_621 = torch.ops.aten.mm.default(mul_2141, permute_1664); mul_2141 = permute_1664 = None + convert_element_type_3319 = torch.ops.prims.convert_element_type.default(mm_620, torch.float32); mm_620 = None + reduce_scatter_tensor_353 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3319, 'avg', 128, '0'); convert_element_type_3319 = None + wait_tensor_962 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_353); reduce_scatter_tensor_353 = None + convert_element_type_3320 = torch.ops.prims.convert_element_type.default(mul_2142, torch.float32); mul_2142 = None + reciprocal_50 = torch.ops.aten.reciprocal.default(add_68); add_68 = None + mul_2143 = torch.ops.aten.mul.Tensor(reciprocal_50, 1); reciprocal_50 = None + mul_2144 = torch.ops.aten.mul.Tensor(convert_element_type_3320, mul_2143); convert_element_type_3320 = None + sub_775 = torch.ops.aten.sub.Tensor(1, mul_2143); mul_2143 = None + mul_2145 = torch.ops.aten.mul.Tensor(convert_element_type_81, sub_775); convert_element_type_81 = sub_775 = None + add_2151 = torch.ops.aten.add.Tensor(mul_2145, 1); mul_2145 = None + mul_2146 = torch.ops.aten.mul.Tensor(mul_2144, add_2151); mul_2144 = add_2151 = None + convert_element_type_3322 = torch.ops.prims.convert_element_type.default(mul_2146, torch.bfloat16); mul_2146 = None + permute_1666 = torch.ops.aten.permute.default(convert_element_type_3322, [1, 0]) + mm_622 = torch.ops.aten.mm.default(permute_1666, view_58); permute_1666 = None + convert_element_type_78 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_78, 128, '0'); convert_element_type_78 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + permute_1668 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_623 = torch.ops.aten.mm.default(convert_element_type_3322, permute_1668); convert_element_type_3322 = permute_1668 = None + add_2152 = torch.ops.aten.add.Tensor(mm_621, mm_623); mm_621 = mm_623 = None + convert_element_type_3327 = torch.ops.prims.convert_element_type.default(mm_622, torch.float32); mm_622 = None + reduce_scatter_tensor_354 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3327, 'avg', 128, '0'); convert_element_type_3327 = None + wait_tensor_963 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_354); reduce_scatter_tensor_354 = None + all_to_all_single_128 = torch.ops._c10d_functional.all_to_all_single.default(index_102, [_local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15], [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7], '1033'); index_102 = None + wait_tensor_964 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_128); all_to_all_single_128 = None + full_498 = torch.ops.aten.full.default([sym_size_int_1, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_1 = None + slice_scatter_25 = torch.ops.aten.slice_scatter.default(full_498, wait_tensor_964, 0, 0, -1); wait_tensor_964 = None + index_103 = torch.ops.aten.index.Tensor(slice_scatter_25, [getitem_22]); slice_scatter_25 = None + permute_1670 = torch.ops.aten.permute.default(index_103, [1, 0]) + _grouped_mm_228 = torch.ops.aten._grouped_mm.default(permute_1670, mul_35, cumsum_2); permute_1670 = mul_35 = None + _grouped_mm_229 = torch.ops.aten._grouped_mm.default(index_103, permute_1672, cumsum_2); index_103 = permute_1672 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(_grouped_mm, torch.float32); _grouped_mm = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_76) + exp_2 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_32 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + div_4 = torch.ops.aten.div.Tensor(convert_element_type_76, add_32) + convert_element_type_77 = torch.ops.prims.convert_element_type.default(div_4, torch.bfloat16); div_4 = None + mul_2147 = torch.ops.aten.mul.Tensor(_grouped_mm_229, convert_element_type_77); convert_element_type_77 = None + mul_2148 = torch.ops.aten.mul.Tensor(_grouped_mm_229, _grouped_mm_1); _grouped_mm_229 = _grouped_mm_1 = None + permute_1674 = torch.ops.aten.permute.default(mul_2147, [1, 0]) + _grouped_mm_230 = torch.ops.aten._grouped_mm.default(permute_1674, index_1, cumsum_2); permute_1674 = None + _grouped_mm_231 = torch.ops.aten._grouped_mm.default(mul_2147, permute_1676, cumsum_2); mul_2147 = permute_1676 = None + convert_element_type_3328 = torch.ops.prims.convert_element_type.default(mul_2148, torch.float32); mul_2148 = None + reciprocal_51 = torch.ops.aten.reciprocal.default(add_32); add_32 = None + mul_2149 = torch.ops.aten.mul.Tensor(reciprocal_51, 1); reciprocal_51 = None + mul_2150 = torch.ops.aten.mul.Tensor(convert_element_type_3328, mul_2149); convert_element_type_3328 = None + sub_776 = torch.ops.aten.sub.Tensor(1, mul_2149); mul_2149 = None + mul_2151 = torch.ops.aten.mul.Tensor(convert_element_type_76, sub_776); convert_element_type_76 = sub_776 = None + add_2154 = torch.ops.aten.add.Tensor(mul_2151, 1); mul_2151 = None + mul_2152 = torch.ops.aten.mul.Tensor(mul_2150, add_2154); mul_2150 = add_2154 = None + convert_element_type_3330 = torch.ops.prims.convert_element_type.default(mul_2152, torch.bfloat16); mul_2152 = None + permute_1678 = torch.ops.aten.permute.default(convert_element_type_3330, [1, 0]) + _grouped_mm_232 = torch.ops.aten._grouped_mm.default(permute_1678, index_1, cumsum_2); permute_1678 = index_1 = None + _grouped_mm_233 = torch.ops.aten._grouped_mm.default(convert_element_type_3330, permute_1680, cumsum_2); convert_element_type_3330 = permute_1680 = cumsum_2 = None + add_2155 = torch.ops.aten.add.Tensor(_grouped_mm_231, _grouped_mm_233); _grouped_mm_231 = _grouped_mm_233 = None + convert_element_type_3331 = torch.ops.prims.convert_element_type.default(_grouped_mm_230, torch.float32); _grouped_mm_230 = None + div_282 = torch.ops.aten.div.Tensor(convert_element_type_3331, 128); convert_element_type_3331 = None + split_1482 = torch.ops.aten.split.Tensor(div_282, 88, 1); div_282 = None + getitem_27885 = split_1482[0] + getitem_27902 = split_1482[1] + getitem_27919 = split_1482[2] + getitem_27936 = split_1482[3] + getitem_27953 = split_1482[4] + getitem_27970 = split_1482[5] + getitem_27987 = split_1482[6] + getitem_28004 = split_1482[7] + getitem_28021 = split_1482[8] + getitem_28038 = split_1482[9] + getitem_28055 = split_1482[10] + getitem_28072 = split_1482[11] + getitem_28089 = split_1482[12] + getitem_28106 = split_1482[13] + getitem_28123 = split_1482[14] + getitem_28140 = split_1482[15]; split_1482 = None + cat_436 = torch.ops.aten.cat.default([getitem_27885, getitem_27902, getitem_27919, getitem_27936, getitem_27953, getitem_27970, getitem_27987, getitem_28004, getitem_28021, getitem_28038, getitem_28055, getitem_28072, getitem_28089, getitem_28106, getitem_28123, getitem_28140]); getitem_27885 = getitem_27902 = getitem_27919 = getitem_27936 = getitem_27953 = getitem_27970 = getitem_27987 = getitem_28004 = getitem_28021 = getitem_28038 = getitem_28055 = getitem_28072 = getitem_28089 = getitem_28106 = getitem_28123 = getitem_28140 = None + reduce_scatter_tensor_355 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_436, 'sum', 16, '1025'); cat_436 = None + wait_tensor_965 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_355); reduce_scatter_tensor_355 = None + convert_element_type_3332 = torch.ops.prims.convert_element_type.default(_grouped_mm_228, torch.float32); _grouped_mm_228 = None + div_283 = torch.ops.aten.div.Tensor(convert_element_type_3332, 128); convert_element_type_3332 = None + split_1499 = torch.ops.aten.split.Tensor(div_283, 128, 1); div_283 = None + getitem_28157 = split_1499[0] + getitem_28174 = split_1499[1] + getitem_28191 = split_1499[2] + getitem_28208 = split_1499[3] + getitem_28225 = split_1499[4] + getitem_28242 = split_1499[5] + getitem_28259 = split_1499[6] + getitem_28276 = split_1499[7] + getitem_28293 = split_1499[8] + getitem_28310 = split_1499[9] + getitem_28327 = split_1499[10] + getitem_28344 = split_1499[11] + getitem_28361 = split_1499[12] + getitem_28378 = split_1499[13] + getitem_28395 = split_1499[14] + getitem_28412 = split_1499[15]; split_1499 = None + cat_437 = torch.ops.aten.cat.default([getitem_28157, getitem_28174, getitem_28191, getitem_28208, getitem_28225, getitem_28242, getitem_28259, getitem_28276, getitem_28293, getitem_28310, getitem_28327, getitem_28344, getitem_28361, getitem_28378, getitem_28395, getitem_28412]); getitem_28157 = getitem_28174 = getitem_28191 = getitem_28208 = getitem_28225 = getitem_28242 = getitem_28259 = getitem_28276 = getitem_28293 = getitem_28310 = getitem_28327 = getitem_28344 = getitem_28361 = getitem_28378 = getitem_28395 = getitem_28412 = None + reduce_scatter_tensor_356 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_437, 'sum', 16, '1025'); cat_437 = None + wait_tensor_966 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_356); reduce_scatter_tensor_356 = None + convert_element_type_3333 = torch.ops.prims.convert_element_type.default(_grouped_mm_232, torch.float32); _grouped_mm_232 = None + div_284 = torch.ops.aten.div.Tensor(convert_element_type_3333, 128); convert_element_type_3333 = None + split_1516 = torch.ops.aten.split.Tensor(div_284, 88, 1); div_284 = None + getitem_28429 = split_1516[0] + getitem_28446 = split_1516[1] + getitem_28463 = split_1516[2] + getitem_28480 = split_1516[3] + getitem_28497 = split_1516[4] + getitem_28514 = split_1516[5] + getitem_28531 = split_1516[6] + getitem_28548 = split_1516[7] + getitem_28565 = split_1516[8] + getitem_28582 = split_1516[9] + getitem_28599 = split_1516[10] + getitem_28616 = split_1516[11] + getitem_28633 = split_1516[12] + getitem_28650 = split_1516[13] + getitem_28667 = split_1516[14] + getitem_28684 = split_1516[15]; split_1516 = None + cat_438 = torch.ops.aten.cat.default([getitem_28429, getitem_28446, getitem_28463, getitem_28480, getitem_28497, getitem_28514, getitem_28531, getitem_28548, getitem_28565, getitem_28582, getitem_28599, getitem_28616, getitem_28633, getitem_28650, getitem_28667, getitem_28684]); getitem_28429 = getitem_28446 = getitem_28463 = getitem_28480 = getitem_28497 = getitem_28514 = getitem_28531 = getitem_28548 = getitem_28565 = getitem_28582 = getitem_28599 = getitem_28616 = getitem_28633 = getitem_28650 = getitem_28667 = getitem_28684 = None + reduce_scatter_tensor_357 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_438, 'sum', 16, '1025'); cat_438 = None + wait_tensor_967 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_357); reduce_scatter_tensor_357 = None + index_put_102 = torch.ops.aten.index_put.default(full_498, [getitem_22], add_2155, True); full_498 = getitem_22 = add_2155 = None + slice_312 = torch.ops.aten.slice.Tensor(index_put_102, 0, 0, add_2156); index_put_102 = add_2156 = None + all_to_all_single_129 = torch.ops._c10d_functional.all_to_all_single.default(slice_312, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7], [_local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15], '1033'); slice_312 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = _local_scalar_dense_3 = _local_scalar_dense_4 = _local_scalar_dense_5 = _local_scalar_dense_6 = _local_scalar_dense_7 = _local_scalar_dense_8 = _local_scalar_dense_9 = _local_scalar_dense_10 = _local_scalar_dense_11 = _local_scalar_dense_12 = _local_scalar_dense_13 = _local_scalar_dense_14 = _local_scalar_dense_15 = None + wait_tensor_968 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_129); all_to_all_single_129 = None + index_put_103 = torch.ops.aten.index_put.default(full_default_52, [div_2], wait_tensor_968, True); full_default_52 = div_2 = wait_tensor_968 = None + add_2160 = torch.ops.aten.add.Tensor(add_2152, index_put_103); add_2152 = index_put_103 = None + mul_2153 = torch.ops.aten.mul.Tensor(view_2258, 1.0); view_2258 = None + scatter_add_25 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_19, mul_2153); full_default_53 = getitem_19 = mul_2153 = None + convert_element_type_65 = torch.ops.prims.convert_element_type.default(mm_11, torch.float32); mm_11 = None + sub = torch.ops.aten.sub.Tensor(convert_element_type_65, amax); convert_element_type_65 = amax = None + exp_1 = torch.ops.aten.exp.default(sub); sub = None + div_1 = torch.ops.aten.div.Tensor(exp_1, sum_1); exp_1 = sum_1 = None + mul_2154 = torch.ops.aten.mul.Tensor(scatter_add_25, div_1); scatter_add_25 = None + sum_307 = torch.ops.aten.sum.dim_IntList(mul_2154, [1], True) + neg_130 = torch.ops.aten.neg.default(div_1); div_1 = None + fma_25 = torch.ops.prims.fma.default(neg_130, sum_307, mul_2154); neg_130 = sum_307 = mul_2154 = None + convert_element_type_3334 = torch.ops.prims.convert_element_type.default(fma_25, torch.bfloat16); fma_25 = None + permute_1682 = torch.ops.aten.permute.default(convert_element_type_3334, [1, 0]) + mm_624 = torch.ops.aten.mm.default(permute_1682, view_58); permute_1682 = view_58 = None + convert_element_type_62 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_62, 128, '0'); convert_element_type_62 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + slice_9 = torch.ops.aten.slice.Tensor(wait_tensor_18, 0, 0, 64); wait_tensor_18 = None + permute_19 = torch.ops.aten.permute.default(slice_9, [1, 0]); slice_9 = None + permute_1684 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_625 = torch.ops.aten.mm.default(convert_element_type_3334, permute_1684); convert_element_type_3334 = permute_1684 = None + add_2161 = torch.ops.aten.add.Tensor(add_2160, mm_625); add_2160 = mm_625 = None + convert_element_type_3339 = torch.ops.prims.convert_element_type.default(mm_624, torch.float32); mm_624 = None + split_1532 = torch.ops.aten.split.Tensor(convert_element_type_3339, 1); convert_element_type_3339 = None + getitem_28685 = split_1532[0] + getitem_28686 = split_1532[1] + getitem_28687 = split_1532[2] + getitem_28688 = split_1532[3] + getitem_28689 = split_1532[4] + getitem_28690 = split_1532[5] + getitem_28691 = split_1532[6] + getitem_28692 = split_1532[7] + getitem_28693 = split_1532[8] + getitem_28694 = split_1532[9] + getitem_28695 = split_1532[10] + getitem_28696 = split_1532[11] + getitem_28697 = split_1532[12] + getitem_28698 = split_1532[13] + getitem_28699 = split_1532[14] + getitem_28700 = split_1532[15] + getitem_28701 = split_1532[16] + getitem_28702 = split_1532[17] + getitem_28703 = split_1532[18] + getitem_28704 = split_1532[19] + getitem_28705 = split_1532[20] + getitem_28706 = split_1532[21] + getitem_28707 = split_1532[22] + getitem_28708 = split_1532[23] + getitem_28709 = split_1532[24] + getitem_28710 = split_1532[25] + getitem_28711 = split_1532[26] + getitem_28712 = split_1532[27] + getitem_28713 = split_1532[28] + getitem_28714 = split_1532[29] + getitem_28715 = split_1532[30] + getitem_28716 = split_1532[31] + getitem_28717 = split_1532[32] + getitem_28718 = split_1532[33] + getitem_28719 = split_1532[34] + getitem_28720 = split_1532[35] + getitem_28721 = split_1532[36] + getitem_28722 = split_1532[37] + getitem_28723 = split_1532[38] + getitem_28724 = split_1532[39] + getitem_28725 = split_1532[40] + getitem_28726 = split_1532[41] + getitem_28727 = split_1532[42] + getitem_28728 = split_1532[43] + getitem_28729 = split_1532[44] + getitem_28730 = split_1532[45] + getitem_28731 = split_1532[46] + getitem_28732 = split_1532[47] + getitem_28733 = split_1532[48] + getitem_28734 = split_1532[49] + getitem_28735 = split_1532[50] + getitem_28736 = split_1532[51] + getitem_28737 = split_1532[52] + getitem_28738 = split_1532[53] + getitem_28739 = split_1532[54] + getitem_28740 = split_1532[55] + getitem_28741 = split_1532[56] + getitem_28742 = split_1532[57] + getitem_28743 = split_1532[58] + getitem_28744 = split_1532[59] + getitem_28745 = split_1532[60] + getitem_28746 = split_1532[61] + getitem_28747 = split_1532[62] + getitem_28748 = split_1532[63]; split_1532 = None + cat_439 = torch.ops.aten.cat.default([getitem_28685, getitem_28686, getitem_28687, getitem_28688, getitem_28689, getitem_28690, getitem_28691, getitem_28692, getitem_28693, getitem_28694, getitem_28695, getitem_28696, getitem_28697, getitem_28698, getitem_28699, getitem_28700, getitem_28701, getitem_28702, getitem_28703, getitem_28704, getitem_28705, getitem_28706, getitem_28707, getitem_28708, getitem_28709, getitem_28710, getitem_28711, getitem_28712, getitem_28713, getitem_28714, getitem_28715, getitem_28716, getitem_28717, getitem_28718, getitem_28719, getitem_28720, getitem_28721, getitem_28722, getitem_28723, getitem_28724, getitem_28725, getitem_28726, getitem_28727, getitem_28728, getitem_28729, getitem_28730, getitem_28731, getitem_28732, getitem_28733, getitem_28734, getitem_28735, getitem_28736, getitem_28737, getitem_28738, getitem_28739, getitem_28740, getitem_28741, getitem_28742, getitem_28743, getitem_28744, getitem_28745, getitem_28746, getitem_28747, getitem_28748, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd, constant_pad_nd]); getitem_28685 = getitem_28686 = getitem_28687 = getitem_28688 = getitem_28689 = getitem_28690 = getitem_28691 = getitem_28692 = getitem_28693 = getitem_28694 = getitem_28695 = getitem_28696 = getitem_28697 = getitem_28698 = getitem_28699 = getitem_28700 = getitem_28701 = getitem_28702 = getitem_28703 = getitem_28704 = getitem_28705 = getitem_28706 = getitem_28707 = getitem_28708 = getitem_28709 = getitem_28710 = getitem_28711 = getitem_28712 = getitem_28713 = getitem_28714 = getitem_28715 = getitem_28716 = getitem_28717 = getitem_28718 = getitem_28719 = getitem_28720 = getitem_28721 = getitem_28722 = getitem_28723 = getitem_28724 = getitem_28725 = getitem_28726 = getitem_28727 = getitem_28728 = getitem_28729 = getitem_28730 = getitem_28731 = getitem_28732 = getitem_28733 = getitem_28734 = getitem_28735 = getitem_28736 = getitem_28737 = getitem_28738 = getitem_28739 = getitem_28740 = getitem_28741 = getitem_28742 = getitem_28743 = getitem_28744 = getitem_28745 = getitem_28746 = getitem_28747 = getitem_28748 = constant_pad_nd = None + reduce_scatter_tensor_358 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_439, 'avg', 128, '0'); cat_439 = None + wait_tensor_969 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_358); reduce_scatter_tensor_358 = None + view_2260 = torch.ops.aten.view.default(add_2161, [2, 4096, 2048]); add_2161 = None + convert_element_type_3340 = torch.ops.prims.convert_element_type.default(view_2260, torch.float32); view_2260 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_59, 128, '0'); convert_element_type_59 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + convert_element_type_3342 = torch.ops.prims.convert_element_type.default(wait_tensor_17, torch.float32); wait_tensor_17 = None + mul_2155 = torch.ops.aten.mul.Tensor(convert_element_type_3340, convert_element_type_3342); convert_element_type_3342 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(add_8, torch.float32); add_8 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, rsqrt_5); convert_element_type_60 = None + mul_2157 = torch.ops.aten.mul.Tensor(mul_15, mul_2155) + sum_308 = torch.ops.aten.sum.dim_IntList(mul_2157, [2], True); mul_2157 = None + div_285 = torch.ops.aten.div.Tensor(mul_15, 2048) + mul_2158 = torch.ops.aten.mul.Tensor(div_285, sum_308); div_285 = sum_308 = None + sub_778 = torch.ops.aten.sub.Tensor(mul_2155, mul_2158); mul_2155 = mul_2158 = None + mul_2159 = torch.ops.aten.mul.Tensor(sub_778, rsqrt_5); sub_778 = rsqrt_5 = None + mul_2160 = torch.ops.aten.mul.Tensor(convert_element_type_3340, mul_15); convert_element_type_3340 = mul_15 = None + sum_309 = torch.ops.aten.sum.dim_IntList(mul_2160, [0, 1]); mul_2160 = None + convert_element_type_3343 = torch.ops.prims.convert_element_type.default(mul_2159, torch.bfloat16); mul_2159 = None + add_2162 = torch.ops.aten.add.Tensor(add_2149, convert_element_type_3343); add_2149 = convert_element_type_3343 = None + convert_element_type_default_6 = torch.ops.prims.convert_element_type.default(sum_309, torch.float32); sum_309 = None + reduce_scatter_tensor_359 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_6, 'avg', 128, '0'); convert_element_type_default_6 = None + wait_tensor_970 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_359); reduce_scatter_tensor_359 = None + view_2261 = torch.ops.aten.view.default(add_2162, [8192, 2048]) + permute_1686 = torch.ops.aten.permute.default(view_2261, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_15, [0, 2, 1, 3]) + view_53 = torch.ops.aten.view.default(permute_17, [2, 4096, -1]); permute_17 = None + view_55 = torch.ops.aten.view.default(view_53, [8192, 2048]); view_53 = None + mm_626 = torch.ops.aten.mm.default(permute_1686, view_55); permute_1686 = view_55 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 128, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + permute_1688 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_627 = torch.ops.aten.mm.default(view_2261, permute_1688); view_2261 = permute_1688 = None + view_2262 = torch.ops.aten.view.default(mm_627, [2, 4096, 2048]); mm_627 = None + convert_element_type_3350 = torch.ops.prims.convert_element_type.default(mm_626, torch.float32); mm_626 = None + reduce_scatter_tensor_360 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3350, 'avg', 128, '0'); convert_element_type_3350 = None + wait_tensor_971 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_360); reduce_scatter_tensor_360 = None + view_2263 = torch.ops.aten.view.default(view_2262, [2, 4096, 16, 128]); view_2262 = None + permute_1690 = torch.ops.aten.permute.default(view_2263, [0, 2, 1, 3]); view_2263 = None + fw_graph25 = self.fw_graph25 + joint_graph25 = self.joint_graph25 + mask_graph25 = self.mask_graph25 + flex_attention_backward_25 = torch.ops.higher_order.flex_attention_backward(permute_14, permute_15, permute_16, getitem_15, getitem_16, permute_1690, None, fw_graph25, joint_graph25, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph25), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_14 = permute_15 = permute_16 = getitem_15 = getitem_16 = permute_1690 = fw_graph25 = joint_graph25 = mask_graph25 = None + getitem_28749 = flex_attention_backward_25[0] + getitem_28750 = flex_attention_backward_25[1] + getitem_28751 = flex_attention_backward_25[2]; flex_attention_backward_25 = None + permute_1691 = torch.ops.aten.permute.default(getitem_28751, [0, 2, 1, 3]); getitem_28751 = None + permute_1692 = torch.ops.aten.permute.default(getitem_28750, [0, 2, 1, 3]); getitem_28750 = None + permute_1693 = torch.ops.aten.permute.default(getitem_28749, [0, 2, 1, 3]); getitem_28749 = None + slice_314 = torch.ops.aten.slice.Tensor(permute_1692, 3, 0, 128) + slice_315 = torch.ops.aten.slice.Tensor(permute_1692, 3, 128, 192); permute_1692 = None + sum_310 = torch.ops.aten.sum.dim_IntList(slice_315, [2], True); slice_315 = None + cat_440 = torch.ops.aten.cat.default([slice_314, permute_1691], 3); slice_314 = permute_1691 = None + view_2264 = torch.ops.aten.view.default(cat_440, [2, 4096, 4096]); cat_440 = None + view_2265 = torch.ops.aten.view.default(view_2264, [8192, 4096]); view_2264 = None + permute_1694 = torch.ops.aten.permute.default(view_2265, [1, 0]) + mm_628 = torch.ops.aten.mm.default(permute_1694, view_50); permute_1694 = view_50 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 128, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_15, [1, 0]); wait_tensor_15 = None + permute_1696 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_629 = torch.ops.aten.mm.default(view_2265, permute_1696); view_2265 = permute_1696 = None + view_2266 = torch.ops.aten.view.default(mm_629, [2, 4096, 512]); mm_629 = None + convert_element_type_3355 = torch.ops.prims.convert_element_type.default(mm_628, torch.float32); mm_628 = None + reduce_scatter_tensor_361 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3355, 'avg', 128, '0'); convert_element_type_3355 = None + wait_tensor_972 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_361); reduce_scatter_tensor_361 = None + convert_element_type_3356 = torch.ops.prims.convert_element_type.default(view_2266, torch.float32); view_2266 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 128, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + convert_element_type_3358 = torch.ops.prims.convert_element_type.default(wait_tensor_14, torch.float32); wait_tensor_14 = None + mul_2161 = torch.ops.aten.mul.Tensor(convert_element_type_3356, convert_element_type_3358); convert_element_type_3358 = None + convert_element_type_51 = torch.ops.prims.convert_element_type.default(getitem_11, torch.float32); getitem_11 = None + mul_13 = torch.ops.aten.mul.Tensor(convert_element_type_51, rsqrt_4); convert_element_type_51 = None + mul_2163 = torch.ops.aten.mul.Tensor(mul_13, mul_2161) + sum_311 = torch.ops.aten.sum.dim_IntList(mul_2163, [2], True); mul_2163 = None + div_286 = torch.ops.aten.div.Tensor(mul_13, 512) + mul_2164 = torch.ops.aten.mul.Tensor(div_286, sum_311); div_286 = sum_311 = None + sub_779 = torch.ops.aten.sub.Tensor(mul_2161, mul_2164); mul_2161 = mul_2164 = None + mul_2165 = torch.ops.aten.mul.Tensor(sub_779, rsqrt_4); sub_779 = rsqrt_4 = None + mul_2166 = torch.ops.aten.mul.Tensor(convert_element_type_3356, mul_13); convert_element_type_3356 = mul_13 = None + sum_312 = torch.ops.aten.sum.dim_IntList(mul_2166, [0, 1]); mul_2166 = None + convert_element_type_3359 = torch.ops.prims.convert_element_type.default(mul_2165, torch.bfloat16); mul_2165 = None + convert_element_type_default_5 = torch.ops.prims.convert_element_type.default(sum_312, torch.float32); sum_312 = None + reduce_scatter_tensor_362 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_5, 'avg', 128, '0'); convert_element_type_default_5 = None + wait_tensor_973 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_362); reduce_scatter_tensor_362 = None + convert_element_type_3362 = torch.ops.prims.convert_element_type.default(sum_310, torch.float32); sum_310 = None + view_2267 = torch.ops.aten.view.default(convert_element_type_3362, [2, 4096, 1, 32, 2]); convert_element_type_3362 = None + view_as_complex_104 = torch.ops.aten.view_as_complex.default(view_2267); view_2267 = None + mul_2167 = torch.ops.aten.mul.Tensor(view_as_complex_104, clone_9); view_as_complex_104 = None + view_as_real_104 = torch.ops.aten.view_as_real.default(mul_2167); mul_2167 = None + view_2268 = torch.ops.aten.view.default(view_as_real_104, [2, 4096, 1, 64]); view_as_real_104 = None + convert_element_type_3363 = torch.ops.prims.convert_element_type.default(view_2268, torch.bfloat16); view_2268 = None + squeeze_51 = torch.ops.aten.squeeze.dim(convert_element_type_3363, 2); convert_element_type_3363 = None + cat_441 = torch.ops.aten.cat.default([convert_element_type_3359, squeeze_51], 2); convert_element_type_3359 = squeeze_51 = None + view_2269 = torch.ops.aten.view.default(cat_441, [8192, 576]); cat_441 = None + permute_1698 = torch.ops.aten.permute.default(view_2269, [1, 0]) + mm_630 = torch.ops.aten.mm.default(permute_1698, view_36); permute_1698 = None + convert_element_type_45 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_45, 128, '0'); convert_element_type_45 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + slice_7 = torch.ops.aten.slice.Tensor(wait_tensor_13, 0, 0, 576); wait_tensor_13 = None + permute_12 = torch.ops.aten.permute.default(slice_7, [1, 0]); slice_7 = None + permute_1700 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_631 = torch.ops.aten.mm.default(view_2269, permute_1700); view_2269 = permute_1700 = None + view_2270 = torch.ops.aten.view.default(mm_631, [2, 4096, 2048]); mm_631 = None + convert_element_type_3368 = torch.ops.prims.convert_element_type.default(mm_630, torch.float32); mm_630 = None + split_1533 = torch.ops.aten.split.Tensor(convert_element_type_3368, 5); convert_element_type_3368 = None + getitem_28753 = split_1533[0] + getitem_28754 = split_1533[1] + getitem_28755 = split_1533[2] + getitem_28756 = split_1533[3] + getitem_28757 = split_1533[4] + getitem_28758 = split_1533[5] + getitem_28759 = split_1533[6] + getitem_28760 = split_1533[7] + getitem_28761 = split_1533[8] + getitem_28762 = split_1533[9] + getitem_28763 = split_1533[10] + getitem_28764 = split_1533[11] + getitem_28765 = split_1533[12] + getitem_28766 = split_1533[13] + getitem_28767 = split_1533[14] + getitem_28768 = split_1533[15] + getitem_28769 = split_1533[16] + getitem_28770 = split_1533[17] + getitem_28771 = split_1533[18] + getitem_28772 = split_1533[19] + getitem_28773 = split_1533[20] + getitem_28774 = split_1533[21] + getitem_28775 = split_1533[22] + getitem_28776 = split_1533[23] + getitem_28777 = split_1533[24] + getitem_28778 = split_1533[25] + getitem_28779 = split_1533[26] + getitem_28780 = split_1533[27] + getitem_28781 = split_1533[28] + getitem_28782 = split_1533[29] + getitem_28783 = split_1533[30] + getitem_28784 = split_1533[31] + getitem_28785 = split_1533[32] + getitem_28786 = split_1533[33] + getitem_28787 = split_1533[34] + getitem_28788 = split_1533[35] + getitem_28789 = split_1533[36] + getitem_28790 = split_1533[37] + getitem_28791 = split_1533[38] + getitem_28792 = split_1533[39] + getitem_28793 = split_1533[40] + getitem_28794 = split_1533[41] + getitem_28795 = split_1533[42] + getitem_28796 = split_1533[43] + getitem_28797 = split_1533[44] + getitem_28798 = split_1533[45] + getitem_28799 = split_1533[46] + getitem_28800 = split_1533[47] + getitem_28801 = split_1533[48] + getitem_28802 = split_1533[49] + getitem_28803 = split_1533[50] + getitem_28804 = split_1533[51] + getitem_28805 = split_1533[52] + getitem_28806 = split_1533[53] + getitem_28807 = split_1533[54] + getitem_28808 = split_1533[55] + getitem_28809 = split_1533[56] + getitem_28810 = split_1533[57] + getitem_28811 = split_1533[58] + getitem_28812 = split_1533[59] + getitem_28813 = split_1533[60] + getitem_28814 = split_1533[61] + getitem_28815 = split_1533[62] + getitem_28816 = split_1533[63] + getitem_28817 = split_1533[64] + getitem_28818 = split_1533[65] + getitem_28819 = split_1533[66] + getitem_28820 = split_1533[67] + getitem_28821 = split_1533[68] + getitem_28822 = split_1533[69] + getitem_28823 = split_1533[70] + getitem_28824 = split_1533[71] + getitem_28825 = split_1533[72] + getitem_28826 = split_1533[73] + getitem_28827 = split_1533[74] + getitem_28828 = split_1533[75] + getitem_28829 = split_1533[76] + getitem_28830 = split_1533[77] + getitem_28831 = split_1533[78] + getitem_28832 = split_1533[79] + getitem_28833 = split_1533[80] + getitem_28834 = split_1533[81] + getitem_28835 = split_1533[82] + getitem_28836 = split_1533[83] + getitem_28837 = split_1533[84] + getitem_28838 = split_1533[85] + getitem_28839 = split_1533[86] + getitem_28840 = split_1533[87] + getitem_28841 = split_1533[88] + getitem_28842 = split_1533[89] + getitem_28843 = split_1533[90] + getitem_28844 = split_1533[91] + getitem_28845 = split_1533[92] + getitem_28846 = split_1533[93] + getitem_28847 = split_1533[94] + getitem_28848 = split_1533[95] + getitem_28849 = split_1533[96] + getitem_28850 = split_1533[97] + getitem_28851 = split_1533[98] + getitem_28852 = split_1533[99] + getitem_28853 = split_1533[100] + getitem_28854 = split_1533[101] + getitem_28855 = split_1533[102] + getitem_28856 = split_1533[103] + getitem_28857 = split_1533[104] + getitem_28858 = split_1533[105] + getitem_28859 = split_1533[106] + getitem_28860 = split_1533[107] + getitem_28861 = split_1533[108] + getitem_28862 = split_1533[109] + getitem_28863 = split_1533[110] + getitem_28864 = split_1533[111] + getitem_28865 = split_1533[112] + getitem_28866 = split_1533[113] + getitem_28867 = split_1533[114] + getitem_28868 = split_1533[115]; split_1533 = None + constant_pad_nd_1989 = torch.ops.aten.constant_pad_nd.default(getitem_28868, [0, 0, 0, 4], 0.0); getitem_28868 = None + cat_442 = torch.ops.aten.cat.default([getitem_28753, getitem_28754, getitem_28755, getitem_28756, getitem_28757, getitem_28758, getitem_28759, getitem_28760, getitem_28761, getitem_28762, getitem_28763, getitem_28764, getitem_28765, getitem_28766, getitem_28767, getitem_28768, getitem_28769, getitem_28770, getitem_28771, getitem_28772, getitem_28773, getitem_28774, getitem_28775, getitem_28776, getitem_28777, getitem_28778, getitem_28779, getitem_28780, getitem_28781, getitem_28782, getitem_28783, getitem_28784, getitem_28785, getitem_28786, getitem_28787, getitem_28788, getitem_28789, getitem_28790, getitem_28791, getitem_28792, getitem_28793, getitem_28794, getitem_28795, getitem_28796, getitem_28797, getitem_28798, getitem_28799, getitem_28800, getitem_28801, getitem_28802, getitem_28803, getitem_28804, getitem_28805, getitem_28806, getitem_28807, getitem_28808, getitem_28809, getitem_28810, getitem_28811, getitem_28812, getitem_28813, getitem_28814, getitem_28815, getitem_28816, getitem_28817, getitem_28818, getitem_28819, getitem_28820, getitem_28821, getitem_28822, getitem_28823, getitem_28824, getitem_28825, getitem_28826, getitem_28827, getitem_28828, getitem_28829, getitem_28830, getitem_28831, getitem_28832, getitem_28833, getitem_28834, getitem_28835, getitem_28836, getitem_28837, getitem_28838, getitem_28839, getitem_28840, getitem_28841, getitem_28842, getitem_28843, getitem_28844, getitem_28845, getitem_28846, getitem_28847, getitem_28848, getitem_28849, getitem_28850, getitem_28851, getitem_28852, getitem_28853, getitem_28854, getitem_28855, getitem_28856, getitem_28857, getitem_28858, getitem_28859, getitem_28860, getitem_28861, getitem_28862, getitem_28863, getitem_28864, getitem_28865, getitem_28866, getitem_28867, constant_pad_nd_1989, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_28753 = getitem_28754 = getitem_28755 = getitem_28756 = getitem_28757 = getitem_28758 = getitem_28759 = getitem_28760 = getitem_28761 = getitem_28762 = getitem_28763 = getitem_28764 = getitem_28765 = getitem_28766 = getitem_28767 = getitem_28768 = getitem_28769 = getitem_28770 = getitem_28771 = getitem_28772 = getitem_28773 = getitem_28774 = getitem_28775 = getitem_28776 = getitem_28777 = getitem_28778 = getitem_28779 = getitem_28780 = getitem_28781 = getitem_28782 = getitem_28783 = getitem_28784 = getitem_28785 = getitem_28786 = getitem_28787 = getitem_28788 = getitem_28789 = getitem_28790 = getitem_28791 = getitem_28792 = getitem_28793 = getitem_28794 = getitem_28795 = getitem_28796 = getitem_28797 = getitem_28798 = getitem_28799 = getitem_28800 = getitem_28801 = getitem_28802 = getitem_28803 = getitem_28804 = getitem_28805 = getitem_28806 = getitem_28807 = getitem_28808 = getitem_28809 = getitem_28810 = getitem_28811 = getitem_28812 = getitem_28813 = getitem_28814 = getitem_28815 = getitem_28816 = getitem_28817 = getitem_28818 = getitem_28819 = getitem_28820 = getitem_28821 = getitem_28822 = getitem_28823 = getitem_28824 = getitem_28825 = getitem_28826 = getitem_28827 = getitem_28828 = getitem_28829 = getitem_28830 = getitem_28831 = getitem_28832 = getitem_28833 = getitem_28834 = getitem_28835 = getitem_28836 = getitem_28837 = getitem_28838 = getitem_28839 = getitem_28840 = getitem_28841 = getitem_28842 = getitem_28843 = getitem_28844 = getitem_28845 = getitem_28846 = getitem_28847 = getitem_28848 = getitem_28849 = getitem_28850 = getitem_28851 = getitem_28852 = getitem_28853 = getitem_28854 = getitem_28855 = getitem_28856 = getitem_28857 = getitem_28858 = getitem_28859 = getitem_28860 = getitem_28861 = getitem_28862 = getitem_28863 = getitem_28864 = getitem_28865 = getitem_28866 = getitem_28867 = constant_pad_nd_1989 = None + reduce_scatter_tensor_363 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_442, 'avg', 128, '0'); cat_442 = None + wait_tensor_974 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_363); reduce_scatter_tensor_363 = None + slice_316 = torch.ops.aten.slice.Tensor(permute_1693, 3, 0, 128) + slice_317 = torch.ops.aten.slice.Tensor(permute_1693, 3, 128, 192); permute_1693 = None + convert_element_type_3369 = torch.ops.prims.convert_element_type.default(slice_317, torch.float32); slice_317 = None + view_2271 = torch.ops.aten.view.default(convert_element_type_3369, [2, 4096, 16, 32, 2]); convert_element_type_3369 = None + view_as_complex_105 = torch.ops.aten.view_as_complex.default(view_2271); view_2271 = None + mul_2168 = torch.ops.aten.mul.Tensor(view_as_complex_105, clone_9); view_as_complex_105 = None + view_as_real_105 = torch.ops.aten.view_as_real.default(mul_2168); mul_2168 = None + view_2272 = torch.ops.aten.view.default(view_as_real_105, [2, 4096, 16, 64]); view_as_real_105 = None + convert_element_type_3370 = torch.ops.prims.convert_element_type.default(view_2272, torch.bfloat16); view_2272 = None + cat_443 = torch.ops.aten.cat.default([slice_316, convert_element_type_3370], 3); slice_316 = convert_element_type_3370 = None + view_2273 = torch.ops.aten.view.default(cat_443, [2, 4096, 3072]); cat_443 = None + view_2274 = torch.ops.aten.view.default(view_2273, [8192, 3072]); view_2273 = None + permute_1702 = torch.ops.aten.permute.default(view_2274, [1, 0]) + mm_632 = torch.ops.aten.mm.default(permute_1702, view_36); permute_1702 = view_36 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 128, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + permute_1704 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_633 = torch.ops.aten.mm.default(view_2274, permute_1704); view_2274 = permute_1704 = None + view_2275 = torch.ops.aten.view.default(mm_633, [2, 4096, 2048]); mm_633 = None + add_2163 = torch.ops.aten.add.Tensor(view_2270, view_2275); view_2270 = view_2275 = None + convert_element_type_3375 = torch.ops.prims.convert_element_type.default(mm_632, torch.float32); mm_632 = None + reduce_scatter_tensor_364 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3375, 'avg', 128, '0'); convert_element_type_3375 = None + wait_tensor_975 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_364); reduce_scatter_tensor_364 = None + convert_element_type_3376 = torch.ops.prims.convert_element_type.default(add_2163, torch.float32); add_2163 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 128, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + convert_element_type_3378 = torch.ops.prims.convert_element_type.default(wait_tensor_11, torch.float32); wait_tensor_11 = None + mul_2169 = torch.ops.aten.mul.Tensor(convert_element_type_3376, convert_element_type_3378); convert_element_type_3378 = None + convert_element_type_38 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + mul_9 = torch.ops.aten.mul.Tensor(convert_element_type_38, rsqrt_3); convert_element_type_38 = None + mul_2171 = torch.ops.aten.mul.Tensor(mul_9, mul_2169) + sum_313 = torch.ops.aten.sum.dim_IntList(mul_2171, [2], True); mul_2171 = None + div_287 = torch.ops.aten.div.Tensor(mul_9, 2048) + mul_2172 = torch.ops.aten.mul.Tensor(div_287, sum_313); div_287 = sum_313 = None + sub_780 = torch.ops.aten.sub.Tensor(mul_2169, mul_2172); mul_2169 = mul_2172 = None + mul_2173 = torch.ops.aten.mul.Tensor(sub_780, rsqrt_3); sub_780 = rsqrt_3 = None + mul_2174 = torch.ops.aten.mul.Tensor(convert_element_type_3376, mul_9); convert_element_type_3376 = mul_9 = None + sum_314 = torch.ops.aten.sum.dim_IntList(mul_2174, [0, 1]); mul_2174 = None + convert_element_type_3379 = torch.ops.prims.convert_element_type.default(mul_2173, torch.bfloat16); mul_2173 = None + add_2164 = torch.ops.aten.add.Tensor(add_2162, convert_element_type_3379); add_2162 = convert_element_type_3379 = None + convert_element_type_default_4 = torch.ops.prims.convert_element_type.default(sum_314, torch.float32); sum_314 = None + reduce_scatter_tensor_365 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_4, 'avg', 128, '0'); convert_element_type_default_4 = None + wait_tensor_976 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_365); reduce_scatter_tensor_365 = None + view_2276 = torch.ops.aten.view.default(add_2164, [8192, 2048]) + permute_1706 = torch.ops.aten.permute.default(view_2276, [1, 0]) + mm_634 = torch.ops.aten.mm.default(permute_1706, view_32); permute_1706 = view_32 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 128, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_10, [1, 0]); wait_tensor_10 = None + permute_1708 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_635 = torch.ops.aten.mm.default(view_2276, permute_1708); view_2276 = permute_1708 = None + view_2277 = torch.ops.aten.view.default(mm_635, [2, 4096, 10944]); mm_635 = None + convert_element_type_3386 = torch.ops.prims.convert_element_type.default(mm_634, torch.float32); mm_634 = None + reduce_scatter_tensor_366 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3386, 'avg', 128, '0'); convert_element_type_3386 = None + wait_tensor_977 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_366); reduce_scatter_tensor_366 = None + view_27 = torch.ops.aten.view.default(mm_4, [2, 4096, 10944]); mm_4 = None + convert_element_type_29 = torch.ops.prims.convert_element_type.default(view_27, torch.float32); view_27 = None + neg = torch.ops.aten.neg.default(convert_element_type_29) + exp = torch.ops.aten.exp.default(neg); neg = None + add_4 = torch.ops.aten.add.Tensor(exp, 1); exp = None + div = torch.ops.aten.div.Tensor(convert_element_type_29, add_4) + convert_element_type_30 = torch.ops.prims.convert_element_type.default(div, torch.bfloat16); div = None + mul_2175 = torch.ops.aten.mul.Tensor(view_2277, convert_element_type_30); convert_element_type_30 = None + view_30 = torch.ops.aten.view.default(mm_5, [2, 4096, 10944]); mm_5 = None + mul_2176 = torch.ops.aten.mul.Tensor(view_2277, view_30); view_2277 = view_30 = None + view_2278 = torch.ops.aten.view.default(mul_2175, [8192, 10944]); mul_2175 = None + permute_1710 = torch.ops.aten.permute.default(view_2278, [1, 0]) + mm_636 = torch.ops.aten.mm.default(permute_1710, view_26); permute_1710 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 128, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + slice_5 = torch.ops.aten.slice.Tensor(wait_tensor_9, 0, 0, 10944); wait_tensor_9 = None + permute_9 = torch.ops.aten.permute.default(slice_5, [1, 0]); slice_5 = None + permute_1712 = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_637 = torch.ops.aten.mm.default(view_2278, permute_1712); view_2278 = permute_1712 = None + view_2279 = torch.ops.aten.view.default(mm_637, [2, 4096, 2048]); mm_637 = None + convert_element_type_3391 = torch.ops.prims.convert_element_type.default(mm_636, torch.float32); mm_636 = None + split_1534 = torch.ops.aten.split.Tensor(convert_element_type_3391, 86); convert_element_type_3391 = None + getitem_28869 = split_1534[0] + getitem_28870 = split_1534[1] + getitem_28871 = split_1534[2] + getitem_28872 = split_1534[3] + getitem_28873 = split_1534[4] + getitem_28874 = split_1534[5] + getitem_28875 = split_1534[6] + getitem_28876 = split_1534[7] + getitem_28877 = split_1534[8] + getitem_28878 = split_1534[9] + getitem_28879 = split_1534[10] + getitem_28880 = split_1534[11] + getitem_28881 = split_1534[12] + getitem_28882 = split_1534[13] + getitem_28883 = split_1534[14] + getitem_28884 = split_1534[15] + getitem_28885 = split_1534[16] + getitem_28886 = split_1534[17] + getitem_28887 = split_1534[18] + getitem_28888 = split_1534[19] + getitem_28889 = split_1534[20] + getitem_28890 = split_1534[21] + getitem_28891 = split_1534[22] + getitem_28892 = split_1534[23] + getitem_28893 = split_1534[24] + getitem_28894 = split_1534[25] + getitem_28895 = split_1534[26] + getitem_28896 = split_1534[27] + getitem_28897 = split_1534[28] + getitem_28898 = split_1534[29] + getitem_28899 = split_1534[30] + getitem_28900 = split_1534[31] + getitem_28901 = split_1534[32] + getitem_28902 = split_1534[33] + getitem_28903 = split_1534[34] + getitem_28904 = split_1534[35] + getitem_28905 = split_1534[36] + getitem_28906 = split_1534[37] + getitem_28907 = split_1534[38] + getitem_28908 = split_1534[39] + getitem_28909 = split_1534[40] + getitem_28910 = split_1534[41] + getitem_28911 = split_1534[42] + getitem_28912 = split_1534[43] + getitem_28913 = split_1534[44] + getitem_28914 = split_1534[45] + getitem_28915 = split_1534[46] + getitem_28916 = split_1534[47] + getitem_28917 = split_1534[48] + getitem_28918 = split_1534[49] + getitem_28919 = split_1534[50] + getitem_28920 = split_1534[51] + getitem_28921 = split_1534[52] + getitem_28922 = split_1534[53] + getitem_28923 = split_1534[54] + getitem_28924 = split_1534[55] + getitem_28925 = split_1534[56] + getitem_28926 = split_1534[57] + getitem_28927 = split_1534[58] + getitem_28928 = split_1534[59] + getitem_28929 = split_1534[60] + getitem_28930 = split_1534[61] + getitem_28931 = split_1534[62] + getitem_28932 = split_1534[63] + getitem_28933 = split_1534[64] + getitem_28934 = split_1534[65] + getitem_28935 = split_1534[66] + getitem_28936 = split_1534[67] + getitem_28937 = split_1534[68] + getitem_28938 = split_1534[69] + getitem_28939 = split_1534[70] + getitem_28940 = split_1534[71] + getitem_28941 = split_1534[72] + getitem_28942 = split_1534[73] + getitem_28943 = split_1534[74] + getitem_28944 = split_1534[75] + getitem_28945 = split_1534[76] + getitem_28946 = split_1534[77] + getitem_28947 = split_1534[78] + getitem_28948 = split_1534[79] + getitem_28949 = split_1534[80] + getitem_28950 = split_1534[81] + getitem_28951 = split_1534[82] + getitem_28952 = split_1534[83] + getitem_28953 = split_1534[84] + getitem_28954 = split_1534[85] + getitem_28955 = split_1534[86] + getitem_28956 = split_1534[87] + getitem_28957 = split_1534[88] + getitem_28958 = split_1534[89] + getitem_28959 = split_1534[90] + getitem_28960 = split_1534[91] + getitem_28961 = split_1534[92] + getitem_28962 = split_1534[93] + getitem_28963 = split_1534[94] + getitem_28964 = split_1534[95] + getitem_28965 = split_1534[96] + getitem_28966 = split_1534[97] + getitem_28967 = split_1534[98] + getitem_28968 = split_1534[99] + getitem_28969 = split_1534[100] + getitem_28970 = split_1534[101] + getitem_28971 = split_1534[102] + getitem_28972 = split_1534[103] + getitem_28973 = split_1534[104] + getitem_28974 = split_1534[105] + getitem_28975 = split_1534[106] + getitem_28976 = split_1534[107] + getitem_28977 = split_1534[108] + getitem_28978 = split_1534[109] + getitem_28979 = split_1534[110] + getitem_28980 = split_1534[111] + getitem_28981 = split_1534[112] + getitem_28982 = split_1534[113] + getitem_28983 = split_1534[114] + getitem_28984 = split_1534[115] + getitem_28985 = split_1534[116] + getitem_28986 = split_1534[117] + getitem_28987 = split_1534[118] + getitem_28988 = split_1534[119] + getitem_28989 = split_1534[120] + getitem_28990 = split_1534[121] + getitem_28991 = split_1534[122] + getitem_28992 = split_1534[123] + getitem_28993 = split_1534[124] + getitem_28994 = split_1534[125] + getitem_28995 = split_1534[126] + getitem_28996 = split_1534[127]; split_1534 = None + constant_pad_nd_2002 = torch.ops.aten.constant_pad_nd.default(getitem_28996, [0, 0, 0, 64], 0.0); getitem_28996 = None + cat_444 = torch.ops.aten.cat.default([getitem_28869, getitem_28870, getitem_28871, getitem_28872, getitem_28873, getitem_28874, getitem_28875, getitem_28876, getitem_28877, getitem_28878, getitem_28879, getitem_28880, getitem_28881, getitem_28882, getitem_28883, getitem_28884, getitem_28885, getitem_28886, getitem_28887, getitem_28888, getitem_28889, getitem_28890, getitem_28891, getitem_28892, getitem_28893, getitem_28894, getitem_28895, getitem_28896, getitem_28897, getitem_28898, getitem_28899, getitem_28900, getitem_28901, getitem_28902, getitem_28903, getitem_28904, getitem_28905, getitem_28906, getitem_28907, getitem_28908, getitem_28909, getitem_28910, getitem_28911, getitem_28912, getitem_28913, getitem_28914, getitem_28915, getitem_28916, getitem_28917, getitem_28918, getitem_28919, getitem_28920, getitem_28921, getitem_28922, getitem_28923, getitem_28924, getitem_28925, getitem_28926, getitem_28927, getitem_28928, getitem_28929, getitem_28930, getitem_28931, getitem_28932, getitem_28933, getitem_28934, getitem_28935, getitem_28936, getitem_28937, getitem_28938, getitem_28939, getitem_28940, getitem_28941, getitem_28942, getitem_28943, getitem_28944, getitem_28945, getitem_28946, getitem_28947, getitem_28948, getitem_28949, getitem_28950, getitem_28951, getitem_28952, getitem_28953, getitem_28954, getitem_28955, getitem_28956, getitem_28957, getitem_28958, getitem_28959, getitem_28960, getitem_28961, getitem_28962, getitem_28963, getitem_28964, getitem_28965, getitem_28966, getitem_28967, getitem_28968, getitem_28969, getitem_28970, getitem_28971, getitem_28972, getitem_28973, getitem_28974, getitem_28975, getitem_28976, getitem_28977, getitem_28978, getitem_28979, getitem_28980, getitem_28981, getitem_28982, getitem_28983, getitem_28984, getitem_28985, getitem_28986, getitem_28987, getitem_28988, getitem_28989, getitem_28990, getitem_28991, getitem_28992, getitem_28993, getitem_28994, getitem_28995, constant_pad_nd_2002]); getitem_28869 = getitem_28870 = getitem_28871 = getitem_28872 = getitem_28873 = getitem_28874 = getitem_28875 = getitem_28876 = getitem_28877 = getitem_28878 = getitem_28879 = getitem_28880 = getitem_28881 = getitem_28882 = getitem_28883 = getitem_28884 = getitem_28885 = getitem_28886 = getitem_28887 = getitem_28888 = getitem_28889 = getitem_28890 = getitem_28891 = getitem_28892 = getitem_28893 = getitem_28894 = getitem_28895 = getitem_28896 = getitem_28897 = getitem_28898 = getitem_28899 = getitem_28900 = getitem_28901 = getitem_28902 = getitem_28903 = getitem_28904 = getitem_28905 = getitem_28906 = getitem_28907 = getitem_28908 = getitem_28909 = getitem_28910 = getitem_28911 = getitem_28912 = getitem_28913 = getitem_28914 = getitem_28915 = getitem_28916 = getitem_28917 = getitem_28918 = getitem_28919 = getitem_28920 = getitem_28921 = getitem_28922 = getitem_28923 = getitem_28924 = getitem_28925 = getitem_28926 = getitem_28927 = getitem_28928 = getitem_28929 = getitem_28930 = getitem_28931 = getitem_28932 = getitem_28933 = getitem_28934 = getitem_28935 = getitem_28936 = getitem_28937 = getitem_28938 = getitem_28939 = getitem_28940 = getitem_28941 = getitem_28942 = getitem_28943 = getitem_28944 = getitem_28945 = getitem_28946 = getitem_28947 = getitem_28948 = getitem_28949 = getitem_28950 = getitem_28951 = getitem_28952 = getitem_28953 = getitem_28954 = getitem_28955 = getitem_28956 = getitem_28957 = getitem_28958 = getitem_28959 = getitem_28960 = getitem_28961 = getitem_28962 = getitem_28963 = getitem_28964 = getitem_28965 = getitem_28966 = getitem_28967 = getitem_28968 = getitem_28969 = getitem_28970 = getitem_28971 = getitem_28972 = getitem_28973 = getitem_28974 = getitem_28975 = getitem_28976 = getitem_28977 = getitem_28978 = getitem_28979 = getitem_28980 = getitem_28981 = getitem_28982 = getitem_28983 = getitem_28984 = getitem_28985 = getitem_28986 = getitem_28987 = getitem_28988 = getitem_28989 = getitem_28990 = getitem_28991 = getitem_28992 = getitem_28993 = getitem_28994 = getitem_28995 = constant_pad_nd_2002 = None + reduce_scatter_tensor_367 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_444, 'avg', 128, '0'); cat_444 = None + wait_tensor_978 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_367); reduce_scatter_tensor_367 = None + convert_element_type_3392 = torch.ops.prims.convert_element_type.default(mul_2176, torch.float32); mul_2176 = None + reciprocal_52 = torch.ops.aten.reciprocal.default(add_4); add_4 = None + mul_2177 = torch.ops.aten.mul.Tensor(reciprocal_52, 1); reciprocal_52 = None + mul_2178 = torch.ops.aten.mul.Tensor(convert_element_type_3392, mul_2177); convert_element_type_3392 = None + sub_781 = torch.ops.aten.sub.Tensor(1, mul_2177); mul_2177 = None + mul_2179 = torch.ops.aten.mul.Tensor(convert_element_type_29, sub_781); convert_element_type_29 = sub_781 = None + add_2166 = torch.ops.aten.add.Tensor(mul_2179, 1); mul_2179 = None + mul_2180 = torch.ops.aten.mul.Tensor(mul_2178, add_2166); mul_2178 = add_2166 = None + convert_element_type_3394 = torch.ops.prims.convert_element_type.default(mul_2180, torch.bfloat16); mul_2180 = None + view_2280 = torch.ops.aten.view.default(convert_element_type_3394, [8192, 10944]); convert_element_type_3394 = None + permute_1714 = torch.ops.aten.permute.default(view_2280, [1, 0]) + mm_638 = torch.ops.aten.mm.default(permute_1714, view_26); permute_1714 = view_26 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_26, 128, '0'); convert_element_type_26 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + slice_4 = torch.ops.aten.slice.Tensor(wait_tensor_8, 0, 0, 10944); wait_tensor_8 = None + permute_8 = torch.ops.aten.permute.default(slice_4, [1, 0]); slice_4 = None + permute_1716 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_639 = torch.ops.aten.mm.default(view_2280, permute_1716); view_2280 = permute_1716 = None + view_2281 = torch.ops.aten.view.default(mm_639, [2, 4096, 2048]); mm_639 = None + add_2167 = torch.ops.aten.add.Tensor(view_2279, view_2281); view_2279 = view_2281 = None + convert_element_type_3399 = torch.ops.prims.convert_element_type.default(mm_638, torch.float32); mm_638 = None + split_1535 = torch.ops.aten.split.Tensor(convert_element_type_3399, 86); convert_element_type_3399 = None + getitem_28997 = split_1535[0] + getitem_28998 = split_1535[1] + getitem_28999 = split_1535[2] + getitem_29000 = split_1535[3] + getitem_29001 = split_1535[4] + getitem_29002 = split_1535[5] + getitem_29003 = split_1535[6] + getitem_29004 = split_1535[7] + getitem_29005 = split_1535[8] + getitem_29006 = split_1535[9] + getitem_29007 = split_1535[10] + getitem_29008 = split_1535[11] + getitem_29009 = split_1535[12] + getitem_29010 = split_1535[13] + getitem_29011 = split_1535[14] + getitem_29012 = split_1535[15] + getitem_29013 = split_1535[16] + getitem_29014 = split_1535[17] + getitem_29015 = split_1535[18] + getitem_29016 = split_1535[19] + getitem_29017 = split_1535[20] + getitem_29018 = split_1535[21] + getitem_29019 = split_1535[22] + getitem_29020 = split_1535[23] + getitem_29021 = split_1535[24] + getitem_29022 = split_1535[25] + getitem_29023 = split_1535[26] + getitem_29024 = split_1535[27] + getitem_29025 = split_1535[28] + getitem_29026 = split_1535[29] + getitem_29027 = split_1535[30] + getitem_29028 = split_1535[31] + getitem_29029 = split_1535[32] + getitem_29030 = split_1535[33] + getitem_29031 = split_1535[34] + getitem_29032 = split_1535[35] + getitem_29033 = split_1535[36] + getitem_29034 = split_1535[37] + getitem_29035 = split_1535[38] + getitem_29036 = split_1535[39] + getitem_29037 = split_1535[40] + getitem_29038 = split_1535[41] + getitem_29039 = split_1535[42] + getitem_29040 = split_1535[43] + getitem_29041 = split_1535[44] + getitem_29042 = split_1535[45] + getitem_29043 = split_1535[46] + getitem_29044 = split_1535[47] + getitem_29045 = split_1535[48] + getitem_29046 = split_1535[49] + getitem_29047 = split_1535[50] + getitem_29048 = split_1535[51] + getitem_29049 = split_1535[52] + getitem_29050 = split_1535[53] + getitem_29051 = split_1535[54] + getitem_29052 = split_1535[55] + getitem_29053 = split_1535[56] + getitem_29054 = split_1535[57] + getitem_29055 = split_1535[58] + getitem_29056 = split_1535[59] + getitem_29057 = split_1535[60] + getitem_29058 = split_1535[61] + getitem_29059 = split_1535[62] + getitem_29060 = split_1535[63] + getitem_29061 = split_1535[64] + getitem_29062 = split_1535[65] + getitem_29063 = split_1535[66] + getitem_29064 = split_1535[67] + getitem_29065 = split_1535[68] + getitem_29066 = split_1535[69] + getitem_29067 = split_1535[70] + getitem_29068 = split_1535[71] + getitem_29069 = split_1535[72] + getitem_29070 = split_1535[73] + getitem_29071 = split_1535[74] + getitem_29072 = split_1535[75] + getitem_29073 = split_1535[76] + getitem_29074 = split_1535[77] + getitem_29075 = split_1535[78] + getitem_29076 = split_1535[79] + getitem_29077 = split_1535[80] + getitem_29078 = split_1535[81] + getitem_29079 = split_1535[82] + getitem_29080 = split_1535[83] + getitem_29081 = split_1535[84] + getitem_29082 = split_1535[85] + getitem_29083 = split_1535[86] + getitem_29084 = split_1535[87] + getitem_29085 = split_1535[88] + getitem_29086 = split_1535[89] + getitem_29087 = split_1535[90] + getitem_29088 = split_1535[91] + getitem_29089 = split_1535[92] + getitem_29090 = split_1535[93] + getitem_29091 = split_1535[94] + getitem_29092 = split_1535[95] + getitem_29093 = split_1535[96] + getitem_29094 = split_1535[97] + getitem_29095 = split_1535[98] + getitem_29096 = split_1535[99] + getitem_29097 = split_1535[100] + getitem_29098 = split_1535[101] + getitem_29099 = split_1535[102] + getitem_29100 = split_1535[103] + getitem_29101 = split_1535[104] + getitem_29102 = split_1535[105] + getitem_29103 = split_1535[106] + getitem_29104 = split_1535[107] + getitem_29105 = split_1535[108] + getitem_29106 = split_1535[109] + getitem_29107 = split_1535[110] + getitem_29108 = split_1535[111] + getitem_29109 = split_1535[112] + getitem_29110 = split_1535[113] + getitem_29111 = split_1535[114] + getitem_29112 = split_1535[115] + getitem_29113 = split_1535[116] + getitem_29114 = split_1535[117] + getitem_29115 = split_1535[118] + getitem_29116 = split_1535[119] + getitem_29117 = split_1535[120] + getitem_29118 = split_1535[121] + getitem_29119 = split_1535[122] + getitem_29120 = split_1535[123] + getitem_29121 = split_1535[124] + getitem_29122 = split_1535[125] + getitem_29123 = split_1535[126] + getitem_29124 = split_1535[127]; split_1535 = None + constant_pad_nd_2003 = torch.ops.aten.constant_pad_nd.default(getitem_29124, [0, 0, 0, 64], 0.0); getitem_29124 = None + cat_445 = torch.ops.aten.cat.default([getitem_28997, getitem_28998, getitem_28999, getitem_29000, getitem_29001, getitem_29002, getitem_29003, getitem_29004, getitem_29005, getitem_29006, getitem_29007, getitem_29008, getitem_29009, getitem_29010, getitem_29011, getitem_29012, getitem_29013, getitem_29014, getitem_29015, getitem_29016, getitem_29017, getitem_29018, getitem_29019, getitem_29020, getitem_29021, getitem_29022, getitem_29023, getitem_29024, getitem_29025, getitem_29026, getitem_29027, getitem_29028, getitem_29029, getitem_29030, getitem_29031, getitem_29032, getitem_29033, getitem_29034, getitem_29035, getitem_29036, getitem_29037, getitem_29038, getitem_29039, getitem_29040, getitem_29041, getitem_29042, getitem_29043, getitem_29044, getitem_29045, getitem_29046, getitem_29047, getitem_29048, getitem_29049, getitem_29050, getitem_29051, getitem_29052, getitem_29053, getitem_29054, getitem_29055, getitem_29056, getitem_29057, getitem_29058, getitem_29059, getitem_29060, getitem_29061, getitem_29062, getitem_29063, getitem_29064, getitem_29065, getitem_29066, getitem_29067, getitem_29068, getitem_29069, getitem_29070, getitem_29071, getitem_29072, getitem_29073, getitem_29074, getitem_29075, getitem_29076, getitem_29077, getitem_29078, getitem_29079, getitem_29080, getitem_29081, getitem_29082, getitem_29083, getitem_29084, getitem_29085, getitem_29086, getitem_29087, getitem_29088, getitem_29089, getitem_29090, getitem_29091, getitem_29092, getitem_29093, getitem_29094, getitem_29095, getitem_29096, getitem_29097, getitem_29098, getitem_29099, getitem_29100, getitem_29101, getitem_29102, getitem_29103, getitem_29104, getitem_29105, getitem_29106, getitem_29107, getitem_29108, getitem_29109, getitem_29110, getitem_29111, getitem_29112, getitem_29113, getitem_29114, getitem_29115, getitem_29116, getitem_29117, getitem_29118, getitem_29119, getitem_29120, getitem_29121, getitem_29122, getitem_29123, constant_pad_nd_2003]); getitem_28997 = getitem_28998 = getitem_28999 = getitem_29000 = getitem_29001 = getitem_29002 = getitem_29003 = getitem_29004 = getitem_29005 = getitem_29006 = getitem_29007 = getitem_29008 = getitem_29009 = getitem_29010 = getitem_29011 = getitem_29012 = getitem_29013 = getitem_29014 = getitem_29015 = getitem_29016 = getitem_29017 = getitem_29018 = getitem_29019 = getitem_29020 = getitem_29021 = getitem_29022 = getitem_29023 = getitem_29024 = getitem_29025 = getitem_29026 = getitem_29027 = getitem_29028 = getitem_29029 = getitem_29030 = getitem_29031 = getitem_29032 = getitem_29033 = getitem_29034 = getitem_29035 = getitem_29036 = getitem_29037 = getitem_29038 = getitem_29039 = getitem_29040 = getitem_29041 = getitem_29042 = getitem_29043 = getitem_29044 = getitem_29045 = getitem_29046 = getitem_29047 = getitem_29048 = getitem_29049 = getitem_29050 = getitem_29051 = getitem_29052 = getitem_29053 = getitem_29054 = getitem_29055 = getitem_29056 = getitem_29057 = getitem_29058 = getitem_29059 = getitem_29060 = getitem_29061 = getitem_29062 = getitem_29063 = getitem_29064 = getitem_29065 = getitem_29066 = getitem_29067 = getitem_29068 = getitem_29069 = getitem_29070 = getitem_29071 = getitem_29072 = getitem_29073 = getitem_29074 = getitem_29075 = getitem_29076 = getitem_29077 = getitem_29078 = getitem_29079 = getitem_29080 = getitem_29081 = getitem_29082 = getitem_29083 = getitem_29084 = getitem_29085 = getitem_29086 = getitem_29087 = getitem_29088 = getitem_29089 = getitem_29090 = getitem_29091 = getitem_29092 = getitem_29093 = getitem_29094 = getitem_29095 = getitem_29096 = getitem_29097 = getitem_29098 = getitem_29099 = getitem_29100 = getitem_29101 = getitem_29102 = getitem_29103 = getitem_29104 = getitem_29105 = getitem_29106 = getitem_29107 = getitem_29108 = getitem_29109 = getitem_29110 = getitem_29111 = getitem_29112 = getitem_29113 = getitem_29114 = getitem_29115 = getitem_29116 = getitem_29117 = getitem_29118 = getitem_29119 = getitem_29120 = getitem_29121 = getitem_29122 = getitem_29123 = constant_pad_nd_2003 = None + reduce_scatter_tensor_368 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_445, 'avg', 128, '0'); cat_445 = None + wait_tensor_979 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_368); reduce_scatter_tensor_368 = None + convert_element_type_3400 = torch.ops.prims.convert_element_type.default(add_2167, torch.float32); add_2167 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 128, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_3402 = torch.ops.prims.convert_element_type.default(wait_tensor_7, torch.float32); wait_tensor_7 = None + mul_2181 = torch.ops.aten.mul.Tensor(convert_element_type_3400, convert_element_type_3402); convert_element_type_3402 = None + view_23 = torch.ops.aten.view.default(mm_3, [2, 4096, 2048]); mm_3 = None + add_2 = torch.ops.aten.add.Tensor(embedding, view_23); view_23 = None + convert_element_type_24 = torch.ops.prims.convert_element_type.default(add_2, torch.float32); add_2 = None + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_24, rsqrt_2); convert_element_type_24 = None + mul_2183 = torch.ops.aten.mul.Tensor(mul_6, mul_2181) + sum_315 = torch.ops.aten.sum.dim_IntList(mul_2183, [2], True); mul_2183 = None + div_288 = torch.ops.aten.div.Tensor(mul_6, 2048) + mul_2184 = torch.ops.aten.mul.Tensor(div_288, sum_315); div_288 = sum_315 = None + sub_782 = torch.ops.aten.sub.Tensor(mul_2181, mul_2184); mul_2181 = mul_2184 = None + mul_2185 = torch.ops.aten.mul.Tensor(sub_782, rsqrt_2); sub_782 = rsqrt_2 = None + mul_2186 = torch.ops.aten.mul.Tensor(convert_element_type_3400, mul_6); convert_element_type_3400 = mul_6 = None + sum_316 = torch.ops.aten.sum.dim_IntList(mul_2186, [0, 1]); mul_2186 = None + convert_element_type_3403 = torch.ops.prims.convert_element_type.default(mul_2185, torch.bfloat16); mul_2185 = None + add_2168 = torch.ops.aten.add.Tensor(add_2164, convert_element_type_3403); add_2164 = convert_element_type_3403 = None + convert_element_type_default_3 = torch.ops.prims.convert_element_type.default(sum_316, torch.float32); sum_316 = None + reduce_scatter_tensor_369 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_3, 'avg', 128, '0'); convert_element_type_default_3 = None + wait_tensor_980 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_369); reduce_scatter_tensor_369 = None + view_2282 = torch.ops.aten.view.default(add_2168, [8192, 2048]) + permute_1718 = torch.ops.aten.permute.default(view_2282, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem_6, [0, 2, 1, 3]) + view_20 = torch.ops.aten.view.default(permute_6, [2, 4096, -1]); permute_6 = None + view_22 = torch.ops.aten.view.default(view_20, [8192, 2048]); view_20 = None + mm_640 = torch.ops.aten.mm.default(permute_1718, view_22); permute_1718 = view_22 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 128, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + permute_1720 = torch.ops.aten.permute.default(permute_7, [1, 0]); permute_7 = None + mm_641 = torch.ops.aten.mm.default(view_2282, permute_1720); view_2282 = permute_1720 = None + view_2283 = torch.ops.aten.view.default(mm_641, [2, 4096, 2048]); mm_641 = None + convert_element_type_3410 = torch.ops.prims.convert_element_type.default(mm_640, torch.float32); mm_640 = None + reduce_scatter_tensor_370 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3410, 'avg', 128, '0'); convert_element_type_3410 = None + wait_tensor_981 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_370); reduce_scatter_tensor_370 = None + view_2284 = torch.ops.aten.view.default(view_2283, [2, 4096, 16, 128]); view_2283 = None + permute_1722 = torch.ops.aten.permute.default(view_2284, [0, 2, 1, 3]); view_2284 = None + fw_graph26 = self.fw_graph26 + joint_graph26 = self.joint_graph26 + mask_graph26 = self.mask_graph26 + flex_attention_backward_26 = torch.ops.higher_order.flex_attention_backward(permute_3, permute_4, permute_5, getitem_6, getitem_7, permute_1722, None, fw_graph26, joint_graph26, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph26), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_3 = permute_4 = permute_5 = getitem_6 = getitem_7 = permute_1722 = fw_graph26 = joint_graph26 = primals_10 = primals_9 = primals_12 = primals_13 = primals_14 = primals_15 = primals_16 = primals_17 = mask_graph26 = primals_11 = None + getitem_29125 = flex_attention_backward_26[0] + getitem_29126 = flex_attention_backward_26[1] + getitem_29127 = flex_attention_backward_26[2]; flex_attention_backward_26 = None + permute_1723 = torch.ops.aten.permute.default(getitem_29127, [0, 2, 1, 3]); getitem_29127 = None + permute_1724 = torch.ops.aten.permute.default(getitem_29126, [0, 2, 1, 3]); getitem_29126 = None + permute_1725 = torch.ops.aten.permute.default(getitem_29125, [0, 2, 1, 3]); getitem_29125 = None + slice_318 = torch.ops.aten.slice.Tensor(permute_1724, 3, 0, 128) + slice_319 = torch.ops.aten.slice.Tensor(permute_1724, 3, 128, 192); permute_1724 = None + sum_317 = torch.ops.aten.sum.dim_IntList(slice_319, [2], True); slice_319 = None + cat_446 = torch.ops.aten.cat.default([slice_318, permute_1723], 3); slice_318 = permute_1723 = None + view_2285 = torch.ops.aten.view.default(cat_446, [2, 4096, 4096]); cat_446 = None + view_2286 = torch.ops.aten.view.default(view_2285, [8192, 4096]); view_2285 = None + permute_1726 = torch.ops.aten.permute.default(view_2286, [1, 0]) + mm_642 = torch.ops.aten.mm.default(permute_1726, view_17); permute_1726 = view_17 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 128, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + permute_1728 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_643 = torch.ops.aten.mm.default(view_2286, permute_1728); view_2286 = permute_1728 = None + view_2287 = torch.ops.aten.view.default(mm_643, [2, 4096, 512]); mm_643 = None + convert_element_type_3415 = torch.ops.prims.convert_element_type.default(mm_642, torch.float32); mm_642 = None + reduce_scatter_tensor_371 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3415, 'avg', 128, '0'); convert_element_type_3415 = None + wait_tensor_982 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_371); reduce_scatter_tensor_371 = None + convert_element_type_3416 = torch.ops.prims.convert_element_type.default(view_2287, torch.float32); view_2287 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_14, 128, '0'); convert_element_type_14 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + convert_element_type_3418 = torch.ops.prims.convert_element_type.default(wait_tensor_4, torch.float32); wait_tensor_4 = None + mul_2187 = torch.ops.aten.mul.Tensor(convert_element_type_3416, convert_element_type_3418); convert_element_type_3418 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(getitem_2, torch.float32); getitem_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_15, rsqrt_1); convert_element_type_15 = None + mul_2189 = torch.ops.aten.mul.Tensor(mul_4, mul_2187) + sum_318 = torch.ops.aten.sum.dim_IntList(mul_2189, [2], True); mul_2189 = None + div_289 = torch.ops.aten.div.Tensor(mul_4, 512) + mul_2190 = torch.ops.aten.mul.Tensor(div_289, sum_318); div_289 = sum_318 = None + sub_783 = torch.ops.aten.sub.Tensor(mul_2187, mul_2190); mul_2187 = mul_2190 = None + mul_2191 = torch.ops.aten.mul.Tensor(sub_783, rsqrt_1); sub_783 = rsqrt_1 = None + mul_2192 = torch.ops.aten.mul.Tensor(convert_element_type_3416, mul_4); convert_element_type_3416 = mul_4 = None + sum_319 = torch.ops.aten.sum.dim_IntList(mul_2192, [0, 1]); mul_2192 = None + convert_element_type_3419 = torch.ops.prims.convert_element_type.default(mul_2191, torch.bfloat16); mul_2191 = None + convert_element_type_default_2 = torch.ops.prims.convert_element_type.default(sum_319, torch.float32); sum_319 = None + reduce_scatter_tensor_372 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_2, 'avg', 128, '0'); convert_element_type_default_2 = None + wait_tensor_983 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_372); reduce_scatter_tensor_372 = None + convert_element_type_3422 = torch.ops.prims.convert_element_type.default(sum_317, torch.float32); sum_317 = None + view_2288 = torch.ops.aten.view.default(convert_element_type_3422, [2, 4096, 1, 32, 2]); convert_element_type_3422 = None + view_as_complex_106 = torch.ops.aten.view_as_complex.default(view_2288); view_2288 = None + mul_2193 = torch.ops.aten.mul.Tensor(view_as_complex_106, clone_9); view_as_complex_106 = None + view_as_real_106 = torch.ops.aten.view_as_real.default(mul_2193); mul_2193 = None + view_2289 = torch.ops.aten.view.default(view_as_real_106, [2, 4096, 1, 64]); view_as_real_106 = None + convert_element_type_3423 = torch.ops.prims.convert_element_type.default(view_2289, torch.bfloat16); view_2289 = None + squeeze_52 = torch.ops.aten.squeeze.dim(convert_element_type_3423, 2); convert_element_type_3423 = None + cat_447 = torch.ops.aten.cat.default([convert_element_type_3419, squeeze_52], 2); convert_element_type_3419 = squeeze_52 = None + view_2290 = torch.ops.aten.view.default(cat_447, [8192, 576]); cat_447 = None + permute_1730 = torch.ops.aten.permute.default(view_2290, [1, 0]) + mm_644 = torch.ops.aten.mm.default(permute_1730, view_3); permute_1730 = None + convert_element_type_9 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_9, 128, '0'); convert_element_type_9 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + slice_2 = torch.ops.aten.slice.Tensor(wait_tensor_3, 0, 0, 576); wait_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(slice_2, [1, 0]); slice_2 = None + permute_1732 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_645 = torch.ops.aten.mm.default(view_2290, permute_1732); view_2290 = permute_1732 = None + view_2291 = torch.ops.aten.view.default(mm_645, [2, 4096, 2048]); mm_645 = None + convert_element_type_3428 = torch.ops.prims.convert_element_type.default(mm_644, torch.float32); mm_644 = None + split_1536 = torch.ops.aten.split.Tensor(convert_element_type_3428, 5); convert_element_type_3428 = None + getitem_29129 = split_1536[0] + getitem_29130 = split_1536[1] + getitem_29131 = split_1536[2] + getitem_29132 = split_1536[3] + getitem_29133 = split_1536[4] + getitem_29134 = split_1536[5] + getitem_29135 = split_1536[6] + getitem_29136 = split_1536[7] + getitem_29137 = split_1536[8] + getitem_29138 = split_1536[9] + getitem_29139 = split_1536[10] + getitem_29140 = split_1536[11] + getitem_29141 = split_1536[12] + getitem_29142 = split_1536[13] + getitem_29143 = split_1536[14] + getitem_29144 = split_1536[15] + getitem_29145 = split_1536[16] + getitem_29146 = split_1536[17] + getitem_29147 = split_1536[18] + getitem_29148 = split_1536[19] + getitem_29149 = split_1536[20] + getitem_29150 = split_1536[21] + getitem_29151 = split_1536[22] + getitem_29152 = split_1536[23] + getitem_29153 = split_1536[24] + getitem_29154 = split_1536[25] + getitem_29155 = split_1536[26] + getitem_29156 = split_1536[27] + getitem_29157 = split_1536[28] + getitem_29158 = split_1536[29] + getitem_29159 = split_1536[30] + getitem_29160 = split_1536[31] + getitem_29161 = split_1536[32] + getitem_29162 = split_1536[33] + getitem_29163 = split_1536[34] + getitem_29164 = split_1536[35] + getitem_29165 = split_1536[36] + getitem_29166 = split_1536[37] + getitem_29167 = split_1536[38] + getitem_29168 = split_1536[39] + getitem_29169 = split_1536[40] + getitem_29170 = split_1536[41] + getitem_29171 = split_1536[42] + getitem_29172 = split_1536[43] + getitem_29173 = split_1536[44] + getitem_29174 = split_1536[45] + getitem_29175 = split_1536[46] + getitem_29176 = split_1536[47] + getitem_29177 = split_1536[48] + getitem_29178 = split_1536[49] + getitem_29179 = split_1536[50] + getitem_29180 = split_1536[51] + getitem_29181 = split_1536[52] + getitem_29182 = split_1536[53] + getitem_29183 = split_1536[54] + getitem_29184 = split_1536[55] + getitem_29185 = split_1536[56] + getitem_29186 = split_1536[57] + getitem_29187 = split_1536[58] + getitem_29188 = split_1536[59] + getitem_29189 = split_1536[60] + getitem_29190 = split_1536[61] + getitem_29191 = split_1536[62] + getitem_29192 = split_1536[63] + getitem_29193 = split_1536[64] + getitem_29194 = split_1536[65] + getitem_29195 = split_1536[66] + getitem_29196 = split_1536[67] + getitem_29197 = split_1536[68] + getitem_29198 = split_1536[69] + getitem_29199 = split_1536[70] + getitem_29200 = split_1536[71] + getitem_29201 = split_1536[72] + getitem_29202 = split_1536[73] + getitem_29203 = split_1536[74] + getitem_29204 = split_1536[75] + getitem_29205 = split_1536[76] + getitem_29206 = split_1536[77] + getitem_29207 = split_1536[78] + getitem_29208 = split_1536[79] + getitem_29209 = split_1536[80] + getitem_29210 = split_1536[81] + getitem_29211 = split_1536[82] + getitem_29212 = split_1536[83] + getitem_29213 = split_1536[84] + getitem_29214 = split_1536[85] + getitem_29215 = split_1536[86] + getitem_29216 = split_1536[87] + getitem_29217 = split_1536[88] + getitem_29218 = split_1536[89] + getitem_29219 = split_1536[90] + getitem_29220 = split_1536[91] + getitem_29221 = split_1536[92] + getitem_29222 = split_1536[93] + getitem_29223 = split_1536[94] + getitem_29224 = split_1536[95] + getitem_29225 = split_1536[96] + getitem_29226 = split_1536[97] + getitem_29227 = split_1536[98] + getitem_29228 = split_1536[99] + getitem_29229 = split_1536[100] + getitem_29230 = split_1536[101] + getitem_29231 = split_1536[102] + getitem_29232 = split_1536[103] + getitem_29233 = split_1536[104] + getitem_29234 = split_1536[105] + getitem_29235 = split_1536[106] + getitem_29236 = split_1536[107] + getitem_29237 = split_1536[108] + getitem_29238 = split_1536[109] + getitem_29239 = split_1536[110] + getitem_29240 = split_1536[111] + getitem_29241 = split_1536[112] + getitem_29242 = split_1536[113] + getitem_29243 = split_1536[114] + getitem_29244 = split_1536[115]; split_1536 = None + constant_pad_nd_2004 = torch.ops.aten.constant_pad_nd.default(getitem_29244, [0, 0, 0, 4], 0.0); getitem_29244 = None + cat_448 = torch.ops.aten.cat.default([getitem_29129, getitem_29130, getitem_29131, getitem_29132, getitem_29133, getitem_29134, getitem_29135, getitem_29136, getitem_29137, getitem_29138, getitem_29139, getitem_29140, getitem_29141, getitem_29142, getitem_29143, getitem_29144, getitem_29145, getitem_29146, getitem_29147, getitem_29148, getitem_29149, getitem_29150, getitem_29151, getitem_29152, getitem_29153, getitem_29154, getitem_29155, getitem_29156, getitem_29157, getitem_29158, getitem_29159, getitem_29160, getitem_29161, getitem_29162, getitem_29163, getitem_29164, getitem_29165, getitem_29166, getitem_29167, getitem_29168, getitem_29169, getitem_29170, getitem_29171, getitem_29172, getitem_29173, getitem_29174, getitem_29175, getitem_29176, getitem_29177, getitem_29178, getitem_29179, getitem_29180, getitem_29181, getitem_29182, getitem_29183, getitem_29184, getitem_29185, getitem_29186, getitem_29187, getitem_29188, getitem_29189, getitem_29190, getitem_29191, getitem_29192, getitem_29193, getitem_29194, getitem_29195, getitem_29196, getitem_29197, getitem_29198, getitem_29199, getitem_29200, getitem_29201, getitem_29202, getitem_29203, getitem_29204, getitem_29205, getitem_29206, getitem_29207, getitem_29208, getitem_29209, getitem_29210, getitem_29211, getitem_29212, getitem_29213, getitem_29214, getitem_29215, getitem_29216, getitem_29217, getitem_29218, getitem_29219, getitem_29220, getitem_29221, getitem_29222, getitem_29223, getitem_29224, getitem_29225, getitem_29226, getitem_29227, getitem_29228, getitem_29229, getitem_29230, getitem_29231, getitem_29232, getitem_29233, getitem_29234, getitem_29235, getitem_29236, getitem_29237, getitem_29238, getitem_29239, getitem_29240, getitem_29241, getitem_29242, getitem_29243, constant_pad_nd_2004, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65, constant_pad_nd_65]); getitem_29129 = getitem_29130 = getitem_29131 = getitem_29132 = getitem_29133 = getitem_29134 = getitem_29135 = getitem_29136 = getitem_29137 = getitem_29138 = getitem_29139 = getitem_29140 = getitem_29141 = getitem_29142 = getitem_29143 = getitem_29144 = getitem_29145 = getitem_29146 = getitem_29147 = getitem_29148 = getitem_29149 = getitem_29150 = getitem_29151 = getitem_29152 = getitem_29153 = getitem_29154 = getitem_29155 = getitem_29156 = getitem_29157 = getitem_29158 = getitem_29159 = getitem_29160 = getitem_29161 = getitem_29162 = getitem_29163 = getitem_29164 = getitem_29165 = getitem_29166 = getitem_29167 = getitem_29168 = getitem_29169 = getitem_29170 = getitem_29171 = getitem_29172 = getitem_29173 = getitem_29174 = getitem_29175 = getitem_29176 = getitem_29177 = getitem_29178 = getitem_29179 = getitem_29180 = getitem_29181 = getitem_29182 = getitem_29183 = getitem_29184 = getitem_29185 = getitem_29186 = getitem_29187 = getitem_29188 = getitem_29189 = getitem_29190 = getitem_29191 = getitem_29192 = getitem_29193 = getitem_29194 = getitem_29195 = getitem_29196 = getitem_29197 = getitem_29198 = getitem_29199 = getitem_29200 = getitem_29201 = getitem_29202 = getitem_29203 = getitem_29204 = getitem_29205 = getitem_29206 = getitem_29207 = getitem_29208 = getitem_29209 = getitem_29210 = getitem_29211 = getitem_29212 = getitem_29213 = getitem_29214 = getitem_29215 = getitem_29216 = getitem_29217 = getitem_29218 = getitem_29219 = getitem_29220 = getitem_29221 = getitem_29222 = getitem_29223 = getitem_29224 = getitem_29225 = getitem_29226 = getitem_29227 = getitem_29228 = getitem_29229 = getitem_29230 = getitem_29231 = getitem_29232 = getitem_29233 = getitem_29234 = getitem_29235 = getitem_29236 = getitem_29237 = getitem_29238 = getitem_29239 = getitem_29240 = getitem_29241 = getitem_29242 = getitem_29243 = constant_pad_nd_2004 = constant_pad_nd_65 = None + reduce_scatter_tensor_373 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_448, 'avg', 128, '0'); cat_448 = None + wait_tensor_984 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_373); reduce_scatter_tensor_373 = None + slice_320 = torch.ops.aten.slice.Tensor(permute_1725, 3, 0, 128) + slice_321 = torch.ops.aten.slice.Tensor(permute_1725, 3, 128, 192); permute_1725 = None + convert_element_type_3429 = torch.ops.prims.convert_element_type.default(slice_321, torch.float32); slice_321 = None + view_2292 = torch.ops.aten.view.default(convert_element_type_3429, [2, 4096, 16, 32, 2]); convert_element_type_3429 = None + view_as_complex_107 = torch.ops.aten.view_as_complex.default(view_2292); view_2292 = None + mul_2194 = torch.ops.aten.mul.Tensor(view_as_complex_107, clone_9); view_as_complex_107 = clone_9 = None + view_as_real_107 = torch.ops.aten.view_as_real.default(mul_2194); mul_2194 = None + view_2293 = torch.ops.aten.view.default(view_as_real_107, [2, 4096, 16, 64]); view_as_real_107 = None + convert_element_type_3430 = torch.ops.prims.convert_element_type.default(view_2293, torch.bfloat16); view_2293 = None + cat_449 = torch.ops.aten.cat.default([slice_320, convert_element_type_3430], 3); slice_320 = convert_element_type_3430 = None + view_2294 = torch.ops.aten.view.default(cat_449, [2, 4096, 3072]); cat_449 = None + view_2295 = torch.ops.aten.view.default(view_2294, [8192, 3072]); view_2294 = None + permute_1734 = torch.ops.aten.permute.default(view_2295, [1, 0]) + mm_646 = torch.ops.aten.mm.default(permute_1734, view_3); permute_1734 = view_3 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 128, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + permute_1736 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_647 = torch.ops.aten.mm.default(view_2295, permute_1736); view_2295 = permute_1736 = None + view_2296 = torch.ops.aten.view.default(mm_647, [2, 4096, 2048]); mm_647 = None + add_2169 = torch.ops.aten.add.Tensor(view_2291, view_2296); view_2291 = view_2296 = None + convert_element_type_3435 = torch.ops.prims.convert_element_type.default(mm_646, torch.float32); mm_646 = None + reduce_scatter_tensor_374 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3435, 'avg', 128, '0'); convert_element_type_3435 = None + wait_tensor_985 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_374); reduce_scatter_tensor_374 = None + convert_element_type_3436 = torch.ops.prims.convert_element_type.default(add_2169, torch.float32); add_2169 = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 128, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_3438 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + mul_2195 = torch.ops.aten.mul.Tensor(convert_element_type_3436, convert_element_type_3438); convert_element_type_3438 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32); embedding = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_2197 = torch.ops.aten.mul.Tensor(mul, mul_2195) + sum_320 = torch.ops.aten.sum.dim_IntList(mul_2197, [2], True); mul_2197 = None + div_290 = torch.ops.aten.div.Tensor(mul, 2048) + mul_2198 = torch.ops.aten.mul.Tensor(div_290, sum_320); div_290 = sum_320 = None + sub_784 = torch.ops.aten.sub.Tensor(mul_2195, mul_2198); mul_2195 = mul_2198 = None + mul_2199 = torch.ops.aten.mul.Tensor(sub_784, rsqrt); sub_784 = rsqrt = None + mul_2200 = torch.ops.aten.mul.Tensor(convert_element_type_3436, mul); convert_element_type_3436 = mul = None + sum_321 = torch.ops.aten.sum.dim_IntList(mul_2200, [0, 1]); mul_2200 = None + convert_element_type_3439 = torch.ops.prims.convert_element_type.default(mul_2199, torch.bfloat16); mul_2199 = None + add_2170 = torch.ops.aten.add.Tensor(add_2168, convert_element_type_3439); add_2168 = convert_element_type_3439 = None + convert_element_type_default_1 = torch.ops.prims.convert_element_type.default(sum_321, torch.float32); sum_321 = None + reduce_scatter_tensor_375 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_1, 'avg', 128, '0'); convert_element_type_default_1 = None + wait_tensor_986 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_375); reduce_scatter_tensor_375 = None + convert_element_type_3442 = torch.ops.prims.convert_element_type.default(add_2170, torch.float32); add_2170 = None + eq_572 = torch.ops.aten.eq.Scalar(primals_2, -1) + unsqueeze_79 = torch.ops.aten.unsqueeze.default(eq_572, -1); eq_572 = None + full_default_157 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_79, full_default_157, convert_element_type_3442); unsqueeze_79 = full_default_157 = convert_element_type_3442 = None + full_default_158 = torch.ops.aten.full.default([102400, 2048], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_104 = torch.ops.aten.index_put.default(full_default_158, [primals_2], where, True); full_default_158 = primals_2 = where = None + convert_element_type_default = torch.ops.prims.convert_element_type.default(index_put_104, torch.float32); index_put_104 = None + reduce_scatter_tensor_376 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default, 'avg', 128, '0'); convert_element_type_default = None + wait_tensor_987 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_376); reduce_scatter_tensor_376 = None + return (wait_tensor_987, None, None, wait_tensor_986, wait_tensor_985, wait_tensor_984, wait_tensor_983, wait_tensor_982, None, None, None, None, None, None, None, None, None, wait_tensor_981, wait_tensor_980, wait_tensor_979, wait_tensor_978, wait_tensor_977, wait_tensor_976, wait_tensor_975, wait_tensor_974, wait_tensor_973, wait_tensor_972, wait_tensor_971, wait_tensor_970, None, wait_tensor_969, None, wait_tensor_967, wait_tensor_966, wait_tensor_965, wait_tensor_963, wait_tensor_962, wait_tensor_961, wait_tensor_960, wait_tensor_959, wait_tensor_958, wait_tensor_957, wait_tensor_956, wait_tensor_955, wait_tensor_954, None, wait_tensor_953, None, wait_tensor_951, wait_tensor_950, wait_tensor_949, wait_tensor_947, wait_tensor_946, wait_tensor_945, wait_tensor_944, wait_tensor_943, wait_tensor_942, wait_tensor_941, wait_tensor_940, wait_tensor_939, wait_tensor_938, None, wait_tensor_937, None, wait_tensor_935, wait_tensor_934, wait_tensor_933, wait_tensor_931, wait_tensor_930, wait_tensor_929, wait_tensor_928, wait_tensor_927, wait_tensor_926, wait_tensor_925, wait_tensor_924, wait_tensor_923, wait_tensor_922, None, wait_tensor_921, None, wait_tensor_919, wait_tensor_918, wait_tensor_917, wait_tensor_915, wait_tensor_914, wait_tensor_913, wait_tensor_912, wait_tensor_911, wait_tensor_910, wait_tensor_909, wait_tensor_908, wait_tensor_907, wait_tensor_906, None, wait_tensor_905, None, wait_tensor_903, wait_tensor_902, wait_tensor_901, wait_tensor_899, wait_tensor_898, wait_tensor_897, wait_tensor_896, wait_tensor_895, wait_tensor_894, wait_tensor_893, wait_tensor_892, wait_tensor_891, wait_tensor_890, None, wait_tensor_889, None, wait_tensor_887, wait_tensor_886, wait_tensor_885, wait_tensor_883, wait_tensor_882, wait_tensor_881, wait_tensor_880, wait_tensor_879, wait_tensor_878, wait_tensor_877, wait_tensor_876, wait_tensor_875, wait_tensor_874, None, wait_tensor_873, None, wait_tensor_871, wait_tensor_870, wait_tensor_869, wait_tensor_867, wait_tensor_866, wait_tensor_865, wait_tensor_864, wait_tensor_863, wait_tensor_862, wait_tensor_861, wait_tensor_860, wait_tensor_859, wait_tensor_858, None, wait_tensor_857, None, wait_tensor_855, wait_tensor_854, wait_tensor_853, wait_tensor_851, wait_tensor_850, wait_tensor_849, wait_tensor_848, wait_tensor_847, wait_tensor_846, wait_tensor_845, wait_tensor_844, wait_tensor_843, wait_tensor_842, None, wait_tensor_841, None, wait_tensor_839, wait_tensor_838, wait_tensor_837, wait_tensor_835, wait_tensor_834, wait_tensor_833, wait_tensor_832, wait_tensor_831, wait_tensor_830, wait_tensor_829, wait_tensor_828, wait_tensor_827, wait_tensor_826, None, wait_tensor_825, None, wait_tensor_823, wait_tensor_822, wait_tensor_821, wait_tensor_819, wait_tensor_818, wait_tensor_817, wait_tensor_816, wait_tensor_815, wait_tensor_814, wait_tensor_813, wait_tensor_812, wait_tensor_811, wait_tensor_810, None, wait_tensor_809, None, wait_tensor_807, wait_tensor_806, wait_tensor_805, wait_tensor_803, wait_tensor_802, wait_tensor_801, wait_tensor_800, wait_tensor_799, wait_tensor_798, wait_tensor_797, wait_tensor_796, wait_tensor_795, wait_tensor_794, None, wait_tensor_793, None, wait_tensor_791, wait_tensor_790, wait_tensor_789, wait_tensor_787, wait_tensor_786, wait_tensor_785, wait_tensor_784, wait_tensor_783, wait_tensor_782, wait_tensor_781, wait_tensor_780, wait_tensor_779, wait_tensor_778, None, wait_tensor_777, None, wait_tensor_775, wait_tensor_774, wait_tensor_773, wait_tensor_771, wait_tensor_770, wait_tensor_769, wait_tensor_768, wait_tensor_767, wait_tensor_766, wait_tensor_765, wait_tensor_764, wait_tensor_763, wait_tensor_762, None, wait_tensor_761, None, wait_tensor_759, wait_tensor_758, wait_tensor_757, wait_tensor_755, wait_tensor_754, wait_tensor_753, wait_tensor_752, wait_tensor_751, wait_tensor_750, wait_tensor_749, wait_tensor_748, wait_tensor_747, wait_tensor_746, None, wait_tensor_745, None, wait_tensor_743, wait_tensor_742, wait_tensor_741, wait_tensor_739, wait_tensor_738, wait_tensor_737, wait_tensor_736, wait_tensor_735, wait_tensor_734, wait_tensor_733, wait_tensor_732, wait_tensor_731, wait_tensor_730, None, wait_tensor_729, None, wait_tensor_727, wait_tensor_726, wait_tensor_725, wait_tensor_723, wait_tensor_722, wait_tensor_721, wait_tensor_720, wait_tensor_719, wait_tensor_718, wait_tensor_717, wait_tensor_716, wait_tensor_715, wait_tensor_714, None, wait_tensor_713, None, wait_tensor_711, wait_tensor_710, wait_tensor_709, wait_tensor_707, wait_tensor_706, wait_tensor_705, wait_tensor_704, wait_tensor_703, wait_tensor_702, wait_tensor_701, wait_tensor_700, wait_tensor_699, wait_tensor_698, None, wait_tensor_697, None, wait_tensor_695, wait_tensor_694, wait_tensor_693, wait_tensor_691, wait_tensor_690, wait_tensor_689, wait_tensor_688, wait_tensor_687, wait_tensor_686, wait_tensor_685, wait_tensor_684, wait_tensor_683, wait_tensor_682, None, wait_tensor_681, None, wait_tensor_679, wait_tensor_678, wait_tensor_677, wait_tensor_675, wait_tensor_674, wait_tensor_673, wait_tensor_672, wait_tensor_671, wait_tensor_670, wait_tensor_669, wait_tensor_668, wait_tensor_667, wait_tensor_666, None, wait_tensor_665, None, wait_tensor_663, wait_tensor_662, wait_tensor_661, wait_tensor_659, wait_tensor_658, wait_tensor_657, wait_tensor_656, wait_tensor_655, wait_tensor_654, wait_tensor_653, wait_tensor_652, wait_tensor_651, wait_tensor_650, None, wait_tensor_649, None, wait_tensor_647, wait_tensor_646, wait_tensor_645, wait_tensor_643, wait_tensor_642, wait_tensor_641, wait_tensor_640, wait_tensor_639, wait_tensor_638, wait_tensor_637, wait_tensor_636, wait_tensor_635, wait_tensor_634, None, wait_tensor_633, None, wait_tensor_631, wait_tensor_630, wait_tensor_629, wait_tensor_627, wait_tensor_626, wait_tensor_625, wait_tensor_624, wait_tensor_623, wait_tensor_622, wait_tensor_621, wait_tensor_620, wait_tensor_619, wait_tensor_618, None, wait_tensor_617, None, wait_tensor_615, wait_tensor_614, wait_tensor_613, wait_tensor_611, wait_tensor_610, wait_tensor_609, wait_tensor_608, wait_tensor_607, wait_tensor_606, wait_tensor_605, wait_tensor_604, wait_tensor_603, wait_tensor_602, None, wait_tensor_601, None, wait_tensor_599, wait_tensor_598, wait_tensor_597, wait_tensor_595, wait_tensor_594, wait_tensor_593, wait_tensor_592, wait_tensor_591, wait_tensor_590, wait_tensor_589, wait_tensor_588, wait_tensor_587, wait_tensor_586, None, wait_tensor_585, None, wait_tensor_583, wait_tensor_582, wait_tensor_581, wait_tensor_579, wait_tensor_578, wait_tensor_577, wait_tensor_576, wait_tensor_575, wait_tensor_574, wait_tensor_573, wait_tensor_572, wait_tensor_571, wait_tensor_570, None, wait_tensor_569, None, wait_tensor_567, wait_tensor_566, wait_tensor_565, wait_tensor_563, wait_tensor_562, wait_tensor_561, wait_tensor_560, wait_tensor_559) + +def load_args(reader): + # MoE expert token counts (approximate uniform distribution) + u8 = u9 = u10 = u11 = u12 = u13 = u14 = u15 = u24 = u25 = u26 = u27 = u28 = u29 = u30 = u31 = u40 = u41 = u42 = u43 = u44 = u45 = u46 = u47 = u56 = u57 = u58 = u59 = u60 = u61 = u62 = u63 = u72 = u73 = u74 = u75 = u76 = u77 = u78 = u79 = u88 = u89 = u90 = u91 = u92 = u93 = u94 = u95 = u104 = u105 = u106 = u107 = u108 = u109 = u110 = u111 = u120 = u121 = u122 = u123 = u124 = u125 = u126 = u127 = u136 = u137 = u138 = u139 = u140 = u141 = u142 = u143 = u152 = u153 = u154 = u155 = u156 = u157 = u158 = u159 = u168 = u169 = u170 = u171 = u172 = u173 = u174 = u175 = u184 = u185 = u186 = u187 = u188 = u189 = u190 = u191 = u200 = u201 = u202 = u203 = u204 = u205 = u206 = u207 = u216 = u217 = u218 = u219 = u220 = u221 = u222 = u223 = u232 = u233 = u234 = u235 = u236 = u237 = u238 = u239 = u248 = u249 = u250 = u251 = u252 = u253 = u254 = u255 = u264 = u265 = u266 = u267 = u268 = u269 = u270 = u271 = u280 = u281 = u282 = u283 = u284 = u285 = u286 = u287 = u296 = u297 = u298 = u299 = u300 = u301 = u302 = u303 = u312 = u313 = u314 = u315 = u316 = u317 = u318 = u319 = u328 = u329 = u330 = u331 = u332 = u333 = u334 = u335 = u344 = u345 = u346 = u347 = u348 = u349 = u350 = u351 = u360 = u361 = u362 = u363 = u364 = u365 = u366 = u367 = u376 = u377 = u378 = u379 = u380 = u381 = u382 = u383 = u392 = u393 = u394 = u395 = u396 = u397 = u398 = u399 = u408 = u409 = u410 = u411 = u412 = u413 = u414 = u415 = 512 + reader.symint(512) # _local_scalar_dense + reader.symint(512) # _local_scalar_dense_1 + reader.symint(512) # _local_scalar_dense_2 + reader.symint(512) # _local_scalar_dense_3 + reader.symint(512) # _local_scalar_dense_4 + reader.symint(512) # _local_scalar_dense_5 + reader.symint(512) # _local_scalar_dense_6 + reader.symint(512) # _local_scalar_dense_7 + reader.symint(512) # _local_scalar_dense_8 + reader.symint(512) # _local_scalar_dense_9 + reader.symint(512) # _local_scalar_dense_10 + reader.symint(512) # _local_scalar_dense_11 + reader.symint(512) # _local_scalar_dense_12 + reader.symint(512) # _local_scalar_dense_13 + reader.symint(512) # _local_scalar_dense_14 + reader.symint(512) # _local_scalar_dense_15 + reader.symint(512) # _local_scalar_dense_16 + reader.symint(512) # _local_scalar_dense_17 + reader.symint(512) # _local_scalar_dense_18 + reader.symint(512) # _local_scalar_dense_19 + reader.symint(512) # _local_scalar_dense_20 + reader.symint(512) # _local_scalar_dense_21 + reader.symint(512) # _local_scalar_dense_22 + reader.symint(512) # _local_scalar_dense_23 + reader.symint(512) # _local_scalar_dense_24 + reader.symint(512) # _local_scalar_dense_25 + reader.symint(512) # _local_scalar_dense_26 + reader.symint(512) # _local_scalar_dense_27 + reader.symint(512) # _local_scalar_dense_28 + reader.symint(512) # _local_scalar_dense_29 + reader.symint(512) # _local_scalar_dense_30 + reader.symint(512) # _local_scalar_dense_31 + reader.symint(512) # _local_scalar_dense_32 + reader.symint(512) # _local_scalar_dense_33 + reader.symint(512) # _local_scalar_dense_34 + reader.symint(512) # _local_scalar_dense_35 + reader.symint(512) # _local_scalar_dense_36 + reader.symint(512) # _local_scalar_dense_37 + reader.symint(512) # _local_scalar_dense_38 + reader.symint(512) # _local_scalar_dense_39 + reader.symint(512) # _local_scalar_dense_40 + reader.symint(512) # _local_scalar_dense_41 + reader.symint(512) # _local_scalar_dense_42 + reader.symint(512) # _local_scalar_dense_43 + reader.symint(512) # _local_scalar_dense_44 + reader.symint(512) # _local_scalar_dense_45 + reader.symint(512) # _local_scalar_dense_46 + reader.symint(512) # _local_scalar_dense_47 + reader.symint(512) # _local_scalar_dense_48 + reader.symint(512) # _local_scalar_dense_49 + reader.symint(512) # _local_scalar_dense_50 + reader.symint(512) # _local_scalar_dense_51 + reader.symint(512) # _local_scalar_dense_52 + reader.symint(512) # _local_scalar_dense_53 + reader.symint(512) # _local_scalar_dense_54 + reader.symint(512) # _local_scalar_dense_55 + reader.symint(512) # _local_scalar_dense_56 + reader.symint(512) # _local_scalar_dense_57 + reader.symint(512) # _local_scalar_dense_58 + reader.symint(512) # _local_scalar_dense_59 + reader.symint(512) # _local_scalar_dense_60 + reader.symint(512) # _local_scalar_dense_61 + reader.symint(512) # _local_scalar_dense_62 + reader.symint(512) # _local_scalar_dense_63 + reader.symint(512) # _local_scalar_dense_64 + reader.symint(512) # _local_scalar_dense_65 + reader.symint(512) # _local_scalar_dense_66 + reader.symint(512) # _local_scalar_dense_67 + reader.symint(512) # _local_scalar_dense_68 + reader.symint(512) # _local_scalar_dense_69 + reader.symint(512) # _local_scalar_dense_70 + reader.symint(512) # _local_scalar_dense_71 + reader.symint(512) # _local_scalar_dense_72 + reader.symint(512) # _local_scalar_dense_73 + reader.symint(512) # _local_scalar_dense_74 + reader.symint(512) # _local_scalar_dense_75 + reader.symint(512) # _local_scalar_dense_76 + reader.symint(512) # _local_scalar_dense_77 + reader.symint(512) # _local_scalar_dense_78 + reader.symint(512) # _local_scalar_dense_79 + reader.symint(512) # _local_scalar_dense_80 + reader.symint(512) # _local_scalar_dense_81 + reader.symint(512) # _local_scalar_dense_82 + reader.symint(512) # _local_scalar_dense_83 + reader.symint(512) # _local_scalar_dense_84 + reader.symint(512) # _local_scalar_dense_85 + reader.symint(512) # _local_scalar_dense_86 + reader.symint(512) # _local_scalar_dense_87 + reader.symint(512) # _local_scalar_dense_88 + reader.symint(512) # _local_scalar_dense_89 + reader.symint(512) # _local_scalar_dense_90 + reader.symint(512) # _local_scalar_dense_91 + reader.symint(512) # _local_scalar_dense_92 + reader.symint(512) # _local_scalar_dense_93 + reader.symint(512) # _local_scalar_dense_94 + reader.symint(512) # _local_scalar_dense_95 + reader.symint(512) # _local_scalar_dense_96 + reader.symint(512) # _local_scalar_dense_97 + reader.symint(512) # _local_scalar_dense_98 + reader.symint(512) # _local_scalar_dense_99 + reader.symint(512) # _local_scalar_dense_100 + reader.symint(512) # _local_scalar_dense_101 + reader.symint(512) # _local_scalar_dense_102 + reader.symint(512) # _local_scalar_dense_103 + reader.symint(512) # _local_scalar_dense_104 + reader.symint(512) # _local_scalar_dense_105 + reader.symint(512) # _local_scalar_dense_106 + reader.symint(512) # _local_scalar_dense_107 + reader.symint(512) # _local_scalar_dense_108 + reader.symint(512) # _local_scalar_dense_109 + reader.symint(512) # _local_scalar_dense_110 + reader.symint(512) # _local_scalar_dense_111 + reader.symint(512) # _local_scalar_dense_112 + reader.symint(512) # _local_scalar_dense_113 + reader.symint(512) # _local_scalar_dense_114 + reader.symint(512) # _local_scalar_dense_115 + reader.symint(512) # _local_scalar_dense_116 + reader.symint(512) # _local_scalar_dense_117 + reader.symint(512) # _local_scalar_dense_118 + reader.symint(512) # _local_scalar_dense_119 + reader.symint(512) # _local_scalar_dense_120 + reader.symint(512) # _local_scalar_dense_121 + reader.symint(512) # _local_scalar_dense_122 + reader.symint(512) # _local_scalar_dense_123 + reader.symint(512) # _local_scalar_dense_124 + reader.symint(512) # _local_scalar_dense_125 + reader.symint(512) # _local_scalar_dense_126 + reader.symint(512) # _local_scalar_dense_127 + reader.symint(512) # _local_scalar_dense_128 + reader.symint(512) # _local_scalar_dense_129 + reader.symint(512) # _local_scalar_dense_130 + reader.symint(512) # _local_scalar_dense_131 + reader.symint(512) # _local_scalar_dense_132 + reader.symint(512) # _local_scalar_dense_133 + reader.symint(512) # _local_scalar_dense_134 + reader.symint(512) # _local_scalar_dense_135 + reader.symint(512) # _local_scalar_dense_136 + reader.symint(512) # _local_scalar_dense_137 + reader.symint(512) # _local_scalar_dense_138 + reader.symint(512) # _local_scalar_dense_139 + reader.symint(512) # _local_scalar_dense_140 + reader.symint(512) # _local_scalar_dense_141 + reader.symint(512) # _local_scalar_dense_142 + reader.symint(512) # _local_scalar_dense_143 + reader.symint(512) # _local_scalar_dense_144 + reader.symint(512) # _local_scalar_dense_145 + reader.symint(512) # _local_scalar_dense_146 + reader.symint(512) # _local_scalar_dense_147 + reader.symint(512) # _local_scalar_dense_148 + reader.symint(512) # _local_scalar_dense_149 + reader.symint(512) # _local_scalar_dense_150 + reader.symint(512) # _local_scalar_dense_151 + reader.symint(512) # _local_scalar_dense_152 + reader.symint(512) # _local_scalar_dense_153 + reader.symint(512) # _local_scalar_dense_154 + reader.symint(512) # _local_scalar_dense_155 + reader.symint(512) # _local_scalar_dense_156 + reader.symint(512) # _local_scalar_dense_157 + reader.symint(512) # _local_scalar_dense_158 + reader.symint(512) # _local_scalar_dense_159 + reader.symint(512) # _local_scalar_dense_160 + reader.symint(512) # _local_scalar_dense_161 + reader.symint(512) # _local_scalar_dense_162 + reader.symint(512) # _local_scalar_dense_163 + reader.symint(512) # _local_scalar_dense_164 + reader.symint(512) # _local_scalar_dense_165 + reader.symint(512) # _local_scalar_dense_166 + reader.symint(512) # _local_scalar_dense_167 + reader.symint(512) # _local_scalar_dense_168 + reader.symint(512) # _local_scalar_dense_169 + reader.symint(512) # _local_scalar_dense_170 + reader.symint(512) # _local_scalar_dense_171 + reader.symint(512) # _local_scalar_dense_172 + reader.symint(512) # _local_scalar_dense_173 + reader.symint(512) # _local_scalar_dense_174 + reader.symint(512) # _local_scalar_dense_175 + reader.symint(512) # _local_scalar_dense_176 + reader.symint(512) # _local_scalar_dense_177 + reader.symint(512) # _local_scalar_dense_178 + reader.symint(512) # _local_scalar_dense_179 + reader.symint(512) # _local_scalar_dense_180 + reader.symint(512) # _local_scalar_dense_181 + reader.symint(512) # _local_scalar_dense_182 + reader.symint(512) # _local_scalar_dense_183 + reader.symint(512) # _local_scalar_dense_184 + reader.symint(512) # _local_scalar_dense_185 + reader.symint(512) # _local_scalar_dense_186 + reader.symint(512) # _local_scalar_dense_187 + reader.symint(512) # _local_scalar_dense_188 + reader.symint(512) # _local_scalar_dense_189 + reader.symint(512) # _local_scalar_dense_190 + reader.symint(512) # _local_scalar_dense_191 + reader.symint(512) # _local_scalar_dense_192 + reader.symint(512) # _local_scalar_dense_193 + reader.symint(512) # _local_scalar_dense_194 + reader.symint(512) # _local_scalar_dense_195 + reader.symint(512) # _local_scalar_dense_196 + reader.symint(512) # _local_scalar_dense_197 + reader.symint(512) # _local_scalar_dense_198 + reader.symint(512) # _local_scalar_dense_199 + reader.symint(512) # _local_scalar_dense_200 + reader.symint(512) # _local_scalar_dense_201 + reader.symint(512) # _local_scalar_dense_202 + reader.symint(512) # _local_scalar_dense_203 + reader.symint(512) # _local_scalar_dense_204 + reader.symint(512) # _local_scalar_dense_205 + reader.symint(512) # _local_scalar_dense_206 + reader.symint(512) # _local_scalar_dense_207 + reader.symint(512) # _local_scalar_dense_208 + reader.symint(512) # _local_scalar_dense_209 + reader.symint(512) # _local_scalar_dense_210 + reader.symint(512) # _local_scalar_dense_211 + reader.symint(512) # _local_scalar_dense_212 + reader.symint(512) # _local_scalar_dense_213 + reader.symint(512) # _local_scalar_dense_214 + reader.symint(512) # _local_scalar_dense_215 + reader.symint(512) # _local_scalar_dense_216 + reader.symint(512) # _local_scalar_dense_217 + reader.symint(512) # _local_scalar_dense_218 + reader.symint(512) # _local_scalar_dense_219 + reader.symint(512) # _local_scalar_dense_220 + reader.symint(512) # _local_scalar_dense_221 + reader.symint(512) # _local_scalar_dense_222 + reader.symint(512) # _local_scalar_dense_223 + reader.symint(512) # _local_scalar_dense_224 + reader.symint(512) # _local_scalar_dense_225 + reader.symint(512) # _local_scalar_dense_226 + reader.symint(512) # _local_scalar_dense_227 + reader.symint(512) # _local_scalar_dense_228 + reader.symint(512) # _local_scalar_dense_229 + reader.symint(512) # _local_scalar_dense_230 + reader.symint(512) # _local_scalar_dense_231 + reader.symint(512) # _local_scalar_dense_232 + reader.symint(512) # _local_scalar_dense_233 + reader.symint(512) # _local_scalar_dense_234 + reader.symint(512) # _local_scalar_dense_235 + reader.symint(512) # _local_scalar_dense_236 + reader.symint(512) # _local_scalar_dense_237 + reader.symint(512) # _local_scalar_dense_238 + reader.symint(512) # _local_scalar_dense_239 + reader.symint(512) # _local_scalar_dense_240 + reader.symint(512) # _local_scalar_dense_241 + reader.symint(512) # _local_scalar_dense_242 + reader.symint(512) # _local_scalar_dense_243 + reader.symint(512) # _local_scalar_dense_244 + reader.symint(512) # _local_scalar_dense_245 + reader.symint(512) # _local_scalar_dense_246 + reader.symint(512) # _local_scalar_dense_247 + reader.symint(512) # _local_scalar_dense_248 + reader.symint(512) # _local_scalar_dense_249 + reader.symint(512) # _local_scalar_dense_250 + reader.symint(512) # _local_scalar_dense_251 + reader.symint(512) # _local_scalar_dense_252 + reader.symint(512) # _local_scalar_dense_253 + reader.symint(512) # _local_scalar_dense_254 + reader.symint(512) # _local_scalar_dense_255 + reader.symint(512) # _local_scalar_dense_256 + reader.symint(512) # _local_scalar_dense_257 + reader.symint(512) # _local_scalar_dense_258 + reader.symint(512) # _local_scalar_dense_259 + reader.symint(512) # _local_scalar_dense_260 + reader.symint(512) # _local_scalar_dense_261 + reader.symint(512) # _local_scalar_dense_262 + reader.symint(512) # _local_scalar_dense_263 + reader.symint(512) # _local_scalar_dense_264 + reader.symint(512) # _local_scalar_dense_265 + reader.symint(512) # _local_scalar_dense_266 + reader.symint(512) # _local_scalar_dense_267 + reader.symint(512) # _local_scalar_dense_268 + reader.symint(512) # _local_scalar_dense_269 + reader.symint(512) # _local_scalar_dense_270 + reader.symint(512) # _local_scalar_dense_271 + reader.symint(512) # _local_scalar_dense_272 + reader.symint(512) # _local_scalar_dense_273 + reader.symint(512) # _local_scalar_dense_274 + reader.symint(512) # _local_scalar_dense_275 + reader.symint(512) # _local_scalar_dense_276 + reader.symint(512) # _local_scalar_dense_277 + reader.symint(512) # _local_scalar_dense_278 + reader.symint(512) # _local_scalar_dense_279 + reader.symint(512) # _local_scalar_dense_280 + reader.symint(512) # _local_scalar_dense_281 + reader.symint(512) # _local_scalar_dense_282 + reader.symint(512) # _local_scalar_dense_283 + reader.symint(512) # _local_scalar_dense_284 + reader.symint(512) # _local_scalar_dense_285 + reader.symint(512) # _local_scalar_dense_286 + reader.symint(512) # _local_scalar_dense_287 + reader.symint(512) # _local_scalar_dense_288 + reader.symint(512) # _local_scalar_dense_289 + reader.symint(512) # _local_scalar_dense_290 + reader.symint(512) # _local_scalar_dense_291 + reader.symint(512) # _local_scalar_dense_292 + reader.symint(512) # _local_scalar_dense_293 + reader.symint(512) # _local_scalar_dense_294 + reader.symint(512) # _local_scalar_dense_295 + reader.symint(512) # _local_scalar_dense_296 + reader.symint(512) # _local_scalar_dense_297 + reader.symint(512) # _local_scalar_dense_298 + reader.symint(512) # _local_scalar_dense_299 + reader.symint(512) # _local_scalar_dense_300 + reader.symint(512) # _local_scalar_dense_301 + reader.symint(512) # _local_scalar_dense_302 + reader.symint(512) # _local_scalar_dense_303 + reader.symint(512) # _local_scalar_dense_304 + reader.symint(512) # _local_scalar_dense_305 + reader.symint(512) # _local_scalar_dense_306 + reader.symint(512) # _local_scalar_dense_307 + reader.symint(512) # _local_scalar_dense_308 + reader.symint(512) # _local_scalar_dense_309 + reader.symint(512) # _local_scalar_dense_310 + reader.symint(512) # _local_scalar_dense_311 + reader.symint(512) # _local_scalar_dense_312 + reader.symint(512) # _local_scalar_dense_313 + reader.symint(512) # _local_scalar_dense_314 + reader.symint(512) # _local_scalar_dense_315 + reader.symint(512) # _local_scalar_dense_316 + reader.symint(512) # _local_scalar_dense_317 + reader.symint(512) # _local_scalar_dense_318 + reader.symint(512) # _local_scalar_dense_319 + reader.symint(512) # _local_scalar_dense_320 + reader.symint(512) # _local_scalar_dense_321 + reader.symint(512) # _local_scalar_dense_322 + reader.symint(512) # _local_scalar_dense_323 + reader.symint(512) # _local_scalar_dense_324 + reader.symint(512) # _local_scalar_dense_325 + reader.symint(512) # _local_scalar_dense_326 + reader.symint(512) # _local_scalar_dense_327 + reader.symint(512) # _local_scalar_dense_328 + reader.symint(512) # _local_scalar_dense_329 + reader.symint(512) # _local_scalar_dense_330 + reader.symint(512) # _local_scalar_dense_331 + reader.symint(512) # _local_scalar_dense_332 + reader.symint(512) # _local_scalar_dense_333 + reader.symint(512) # _local_scalar_dense_334 + reader.symint(512) # _local_scalar_dense_335 + reader.symint(512) # _local_scalar_dense_336 + reader.symint(512) # _local_scalar_dense_337 + reader.symint(512) # _local_scalar_dense_338 + reader.symint(512) # _local_scalar_dense_339 + reader.symint(512) # _local_scalar_dense_340 + reader.symint(512) # _local_scalar_dense_341 + reader.symint(512) # _local_scalar_dense_342 + reader.symint(512) # _local_scalar_dense_343 + reader.symint(512) # _local_scalar_dense_344 + reader.symint(512) # _local_scalar_dense_345 + reader.symint(512) # _local_scalar_dense_346 + reader.symint(512) # _local_scalar_dense_347 + reader.symint(512) # _local_scalar_dense_348 + reader.symint(512) # _local_scalar_dense_349 + reader.symint(512) # _local_scalar_dense_350 + reader.symint(512) # _local_scalar_dense_351 + reader.symint(512) # _local_scalar_dense_352 + reader.symint(512) # _local_scalar_dense_353 + reader.symint(512) # _local_scalar_dense_354 + reader.symint(512) # _local_scalar_dense_355 + reader.symint(512) # _local_scalar_dense_356 + reader.symint(512) # _local_scalar_dense_357 + reader.symint(512) # _local_scalar_dense_358 + reader.symint(512) # _local_scalar_dense_359 + reader.symint(512) # _local_scalar_dense_360 + reader.symint(512) # _local_scalar_dense_361 + reader.symint(512) # _local_scalar_dense_362 + reader.symint(512) # _local_scalar_dense_363 + reader.symint(512) # _local_scalar_dense_364 + reader.symint(512) # _local_scalar_dense_365 + reader.symint(512) # _local_scalar_dense_366 + reader.symint(512) # _local_scalar_dense_367 + reader.symint(512) # _local_scalar_dense_368 + reader.symint(512) # _local_scalar_dense_369 + reader.symint(512) # _local_scalar_dense_370 + reader.symint(512) # _local_scalar_dense_371 + reader.symint(512) # _local_scalar_dense_372 + reader.symint(512) # _local_scalar_dense_373 + reader.symint(512) # _local_scalar_dense_374 + reader.symint(512) # _local_scalar_dense_375 + reader.symint(512) # _local_scalar_dense_376 + reader.symint(512) # _local_scalar_dense_377 + reader.symint(512) # _local_scalar_dense_378 + reader.symint(512) # _local_scalar_dense_379 + reader.symint(512) # _local_scalar_dense_380 + reader.symint(512) # _local_scalar_dense_381 + reader.symint(512) # _local_scalar_dense_382 + reader.symint(512) # _local_scalar_dense_383 + reader.symint(512) # _local_scalar_dense_384 + reader.symint(512) # _local_scalar_dense_385 + reader.symint(512) # _local_scalar_dense_386 + reader.symint(512) # _local_scalar_dense_387 + reader.symint(512) # _local_scalar_dense_388 + reader.symint(512) # _local_scalar_dense_389 + reader.symint(512) # _local_scalar_dense_390 + reader.symint(512) # _local_scalar_dense_391 + reader.symint(512) # _local_scalar_dense_392 + reader.symint(512) # _local_scalar_dense_393 + reader.symint(512) # _local_scalar_dense_394 + reader.symint(512) # _local_scalar_dense_395 + reader.symint(512) # _local_scalar_dense_396 + reader.symint(512) # _local_scalar_dense_397 + reader.symint(512) # _local_scalar_dense_398 + reader.symint(512) # _local_scalar_dense_399 + reader.symint(512) # _local_scalar_dense_400 + reader.symint(512) # _local_scalar_dense_401 + reader.symint(512) # _local_scalar_dense_402 + reader.symint(512) # _local_scalar_dense_403 + reader.symint(512) # _local_scalar_dense_404 + reader.symint(512) # _local_scalar_dense_405 + reader.symint(512) # _local_scalar_dense_406 + reader.symint(512) # _local_scalar_dense_407 + reader.symint(512) # _local_scalar_dense_408 + reader.symint(512) # _local_scalar_dense_409 + reader.symint(512) # _local_scalar_dense_410 + reader.symint(512) # _local_scalar_dense_411 + reader.symint(512) # _local_scalar_dense_412 + reader.symint(512) # _local_scalar_dense_413 + reader.symint(512) # _local_scalar_dense_414 + reader.symint(512) # _local_scalar_dense_415 + reader.symint(512) # sym_size_int_1 + reader.symint(512) # sym_size_int_5 + reader.symint(512) # sym_size_int_9 + reader.symint(512) # sym_size_int_13 + reader.symint(512) # sym_size_int_17 + reader.symint(512) # sym_size_int_21 + reader.symint(512) # sym_size_int_25 + reader.symint(512) # sym_size_int_29 + reader.symint(512) # sym_size_int_33 + reader.symint(512) # sym_size_int_37 + reader.symint(512) # sym_size_int_41 + reader.symint(512) # sym_size_int_45 + reader.symint(512) # sym_size_int_49 + reader.symint(512) # sym_size_int_53 + reader.symint(512) # sym_size_int_57 + reader.symint(512) # sym_size_int_61 + reader.symint(512) # sym_size_int_65 + reader.symint(512) # sym_size_int_69 + reader.symint(512) # sym_size_int_73 + reader.symint(512) # sym_size_int_77 + reader.symint(512) # sym_size_int_81 + reader.symint(512) # sym_size_int_85 + reader.symint(512) # sym_size_int_89 + reader.symint(512) # sym_size_int_93 + reader.symint(512) # sym_size_int_97 + reader.symint(512) # sym_size_int_101 + reader.symint(512) # add_1781 + reader.symint(512) # add_1796 + reader.symint(512) # add_1811 + reader.symint(512) # add_1826 + reader.symint(512) # add_1841 + reader.symint(512) # add_1856 + reader.symint(512) # add_1871 + reader.symint(512) # add_1886 + reader.symint(512) # add_1901 + reader.symint(512) # add_1916 + reader.symint(512) # add_1931 + reader.symint(512) # add_1946 + reader.symint(512) # add_1961 + reader.symint(512) # add_1976 + reader.symint(512) # add_1991 + reader.symint(512) # add_2006 + reader.symint(512) # add_2021 + reader.symint(512) # add_2036 + reader.symint(512) # add_2051 + reader.symint(512) # add_2066 + reader.symint(512) # add_2081 + reader.symint(512) # add_2096 + reader.symint(512) # add_2111 + reader.symint(512) # add_2126 + reader.symint(512) # add_2141 + reader.symint(512) # add_2156 + buf0 = reader.storage(None, 6553600, device=device(type='cuda', index=0)) + reader.tensor(buf0, (800, 2048), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 65536, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 4096), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (4096, 32), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf3, (16,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf4, (24, 2048), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf5, (5, 2048), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4,), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf7, (32, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf8, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_9 + buf9 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf9, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_10 + buf10 = reader.storage(None, 32768, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf10, (2, 4096), dtype=torch.int32, is_leaf=True) # primals_11 + buf11 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf11, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_12 + buf12 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf12, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_13 + buf13 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf13, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_14 + buf14 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf14, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_15 + buf15 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf15, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_16 + buf16 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf16, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_17 + buf17 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf17, (16, 2048), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf18, (16,), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 704512, device=device(type='cuda', index=0)) + reader.tensor(buf19, (86, 2048), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 704512, device=device(type='cuda', index=0)) + reader.tensor(buf20, (86, 2048), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 700416, device=device(type='cuda', index=0)) + reader.tensor(buf21, (16, 10944), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16,), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf23, (24, 2048), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf24, (5, 2048), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf25, (4,), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf26, (32, 512), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf27, (16, 2048), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf28, (16,), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf29, (1, 2048), is_leaf=True) # primals_31 + buf30 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf30, (8, 88, 2048), is_leaf=True) # primals_33 + buf31 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf31, (8, 128, 1408), is_leaf=True) # primals_34 + buf32 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf32, (8, 88, 2048), is_leaf=True) # primals_35 + buf33 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf33, (22, 2048), is_leaf=True) # primals_36 + buf34 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf34, (22, 2048), is_leaf=True) # primals_37 + buf35 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf35, (16, 2816), is_leaf=True) # primals_38 + buf36 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf36, (16,), is_leaf=True) # primals_39 + buf37 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf37, (24, 2048), is_leaf=True) # primals_40 + buf38 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf38, (5, 2048), is_leaf=True) # primals_41 + buf39 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf39, (4,), is_leaf=True) # primals_42 + buf40 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf40, (32, 512), is_leaf=True) # primals_43 + buf41 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf41, (16, 2048), is_leaf=True) # primals_44 + buf42 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf42, (16,), is_leaf=True) # primals_45 + buf43 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf43, (1, 2048), is_leaf=True) # primals_47 + buf44 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf44, (8, 88, 2048), is_leaf=True) # primals_49 + buf45 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf45, (8, 128, 1408), is_leaf=True) # primals_50 + buf46 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf46, (8, 88, 2048), is_leaf=True) # primals_51 + buf47 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf47, (22, 2048), is_leaf=True) # primals_52 + buf48 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf48, (22, 2048), is_leaf=True) # primals_53 + buf49 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf49, (16, 2816), is_leaf=True) # primals_54 + buf50 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf50, (16,), is_leaf=True) # primals_55 + buf51 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf51, (24, 2048), is_leaf=True) # primals_56 + buf52 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf52, (5, 2048), is_leaf=True) # primals_57 + buf53 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf53, (4,), is_leaf=True) # primals_58 + buf54 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf54, (32, 512), is_leaf=True) # primals_59 + buf55 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf55, (16, 2048), is_leaf=True) # primals_60 + buf56 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf56, (16,), is_leaf=True) # primals_61 + buf57 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf57, (1, 2048), is_leaf=True) # primals_63 + buf58 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf58, (8, 88, 2048), is_leaf=True) # primals_65 + buf59 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf59, (8, 128, 1408), is_leaf=True) # primals_66 + buf60 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf60, (8, 88, 2048), is_leaf=True) # primals_67 + buf61 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf61, (22, 2048), is_leaf=True) # primals_68 + buf62 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf62, (22, 2048), is_leaf=True) # primals_69 + buf63 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf63, (16, 2816), is_leaf=True) # primals_70 + buf64 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf64, (16,), is_leaf=True) # primals_71 + buf65 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf65, (24, 2048), is_leaf=True) # primals_72 + buf66 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf66, (5, 2048), is_leaf=True) # primals_73 + buf67 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf67, (4,), is_leaf=True) # primals_74 + buf68 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf68, (32, 512), is_leaf=True) # primals_75 + buf69 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 2048), is_leaf=True) # primals_76 + buf70 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf70, (16,), is_leaf=True) # primals_77 + buf71 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf71, (1, 2048), is_leaf=True) # primals_79 + buf72 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf72, (8, 88, 2048), is_leaf=True) # primals_81 + buf73 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf73, (8, 128, 1408), is_leaf=True) # primals_82 + buf74 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf74, (8, 88, 2048), is_leaf=True) # primals_83 + buf75 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf75, (22, 2048), is_leaf=True) # primals_84 + buf76 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf76, (22, 2048), is_leaf=True) # primals_85 + buf77 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf77, (16, 2816), is_leaf=True) # primals_86 + buf78 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf78, (16,), is_leaf=True) # primals_87 + buf79 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf79, (24, 2048), is_leaf=True) # primals_88 + buf80 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf80, (5, 2048), is_leaf=True) # primals_89 + buf81 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf81, (4,), is_leaf=True) # primals_90 + buf82 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf82, (32, 512), is_leaf=True) # primals_91 + buf83 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf83, (16, 2048), is_leaf=True) # primals_92 + buf84 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf84, (16,), is_leaf=True) # primals_93 + buf85 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf85, (1, 2048), is_leaf=True) # primals_95 + buf86 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf86, (8, 88, 2048), is_leaf=True) # primals_97 + buf87 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf87, (8, 128, 1408), is_leaf=True) # primals_98 + buf88 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf88, (8, 88, 2048), is_leaf=True) # primals_99 + buf89 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf89, (22, 2048), is_leaf=True) # primals_100 + buf90 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf90, (22, 2048), is_leaf=True) # primals_101 + buf91 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf91, (16, 2816), is_leaf=True) # primals_102 + buf92 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf92, (16,), is_leaf=True) # primals_103 + buf93 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf93, (24, 2048), is_leaf=True) # primals_104 + buf94 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf94, (5, 2048), is_leaf=True) # primals_105 + buf95 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf95, (4,), is_leaf=True) # primals_106 + buf96 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf96, (32, 512), is_leaf=True) # primals_107 + buf97 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf97, (16, 2048), is_leaf=True) # primals_108 + buf98 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf98, (16,), is_leaf=True) # primals_109 + buf99 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf99, (1, 2048), is_leaf=True) # primals_111 + buf100 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf100, (8, 88, 2048), is_leaf=True) # primals_113 + buf101 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf101, (8, 128, 1408), is_leaf=True) # primals_114 + buf102 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf102, (8, 88, 2048), is_leaf=True) # primals_115 + buf103 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf103, (22, 2048), is_leaf=True) # primals_116 + buf104 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf104, (22, 2048), is_leaf=True) # primals_117 + buf105 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf105, (16, 2816), is_leaf=True) # primals_118 + buf106 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf106, (16,), is_leaf=True) # primals_119 + buf107 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf107, (24, 2048), is_leaf=True) # primals_120 + buf108 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf108, (5, 2048), is_leaf=True) # primals_121 + buf109 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf109, (4,), is_leaf=True) # primals_122 + buf110 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf110, (32, 512), is_leaf=True) # primals_123 + buf111 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf111, (16, 2048), is_leaf=True) # primals_124 + buf112 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf112, (16,), is_leaf=True) # primals_125 + buf113 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf113, (1, 2048), is_leaf=True) # primals_127 + buf114 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf114, (8, 88, 2048), is_leaf=True) # primals_129 + buf115 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf115, (8, 128, 1408), is_leaf=True) # primals_130 + buf116 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf116, (8, 88, 2048), is_leaf=True) # primals_131 + buf117 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf117, (22, 2048), is_leaf=True) # primals_132 + buf118 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf118, (22, 2048), is_leaf=True) # primals_133 + buf119 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf119, (16, 2816), is_leaf=True) # primals_134 + buf120 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf120, (16,), is_leaf=True) # primals_135 + buf121 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf121, (24, 2048), is_leaf=True) # primals_136 + buf122 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf122, (5, 2048), is_leaf=True) # primals_137 + buf123 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf123, (4,), is_leaf=True) # primals_138 + buf124 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf124, (32, 512), is_leaf=True) # primals_139 + buf125 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf125, (16, 2048), is_leaf=True) # primals_140 + buf126 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf126, (16,), is_leaf=True) # primals_141 + buf127 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf127, (1, 2048), is_leaf=True) # primals_143 + buf128 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf128, (8, 88, 2048), is_leaf=True) # primals_145 + buf129 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf129, (8, 128, 1408), is_leaf=True) # primals_146 + buf130 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf130, (8, 88, 2048), is_leaf=True) # primals_147 + buf131 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf131, (22, 2048), is_leaf=True) # primals_148 + buf132 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf132, (22, 2048), is_leaf=True) # primals_149 + buf133 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf133, (16, 2816), is_leaf=True) # primals_150 + buf134 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf134, (16,), is_leaf=True) # primals_151 + buf135 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf135, (24, 2048), is_leaf=True) # primals_152 + buf136 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf136, (5, 2048), is_leaf=True) # primals_153 + buf137 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf137, (4,), is_leaf=True) # primals_154 + buf138 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf138, (32, 512), is_leaf=True) # primals_155 + buf139 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 2048), is_leaf=True) # primals_156 + buf140 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16,), is_leaf=True) # primals_157 + buf141 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf141, (1, 2048), is_leaf=True) # primals_159 + buf142 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf142, (8, 88, 2048), is_leaf=True) # primals_161 + buf143 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf143, (8, 128, 1408), is_leaf=True) # primals_162 + buf144 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf144, (8, 88, 2048), is_leaf=True) # primals_163 + buf145 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf145, (22, 2048), is_leaf=True) # primals_164 + buf146 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf146, (22, 2048), is_leaf=True) # primals_165 + buf147 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf147, (16, 2816), is_leaf=True) # primals_166 + buf148 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf148, (16,), is_leaf=True) # primals_167 + buf149 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf149, (24, 2048), is_leaf=True) # primals_168 + buf150 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf150, (5, 2048), is_leaf=True) # primals_169 + buf151 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf151, (4,), is_leaf=True) # primals_170 + buf152 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf152, (32, 512), is_leaf=True) # primals_171 + buf153 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf153, (16, 2048), is_leaf=True) # primals_172 + buf154 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf154, (16,), is_leaf=True) # primals_173 + buf155 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf155, (1, 2048), is_leaf=True) # primals_175 + buf156 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf156, (8, 88, 2048), is_leaf=True) # primals_177 + buf157 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf157, (8, 128, 1408), is_leaf=True) # primals_178 + buf158 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf158, (8, 88, 2048), is_leaf=True) # primals_179 + buf159 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf159, (22, 2048), is_leaf=True) # primals_180 + buf160 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf160, (22, 2048), is_leaf=True) # primals_181 + buf161 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf161, (16, 2816), is_leaf=True) # primals_182 + buf162 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf162, (16,), is_leaf=True) # primals_183 + buf163 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf163, (24, 2048), is_leaf=True) # primals_184 + buf164 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf164, (5, 2048), is_leaf=True) # primals_185 + buf165 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf165, (4,), is_leaf=True) # primals_186 + buf166 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf166, (32, 512), is_leaf=True) # primals_187 + buf167 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf167, (16, 2048), is_leaf=True) # primals_188 + buf168 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf168, (16,), is_leaf=True) # primals_189 + buf169 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf169, (1, 2048), is_leaf=True) # primals_191 + buf170 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf170, (8, 88, 2048), is_leaf=True) # primals_193 + buf171 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf171, (8, 128, 1408), is_leaf=True) # primals_194 + buf172 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf172, (8, 88, 2048), is_leaf=True) # primals_195 + buf173 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf173, (22, 2048), is_leaf=True) # primals_196 + buf174 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf174, (22, 2048), is_leaf=True) # primals_197 + buf175 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf175, (16, 2816), is_leaf=True) # primals_198 + buf176 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf176, (16,), is_leaf=True) # primals_199 + buf177 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf177, (24, 2048), is_leaf=True) # primals_200 + buf178 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf178, (5, 2048), is_leaf=True) # primals_201 + buf179 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf179, (4,), is_leaf=True) # primals_202 + buf180 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf180, (32, 512), is_leaf=True) # primals_203 + buf181 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf181, (16, 2048), is_leaf=True) # primals_204 + buf182 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf182, (16,), is_leaf=True) # primals_205 + buf183 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf183, (1, 2048), is_leaf=True) # primals_207 + buf184 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf184, (8, 88, 2048), is_leaf=True) # primals_209 + buf185 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf185, (8, 128, 1408), is_leaf=True) # primals_210 + buf186 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf186, (8, 88, 2048), is_leaf=True) # primals_211 + buf187 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf187, (22, 2048), is_leaf=True) # primals_212 + buf188 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf188, (22, 2048), is_leaf=True) # primals_213 + buf189 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf189, (16, 2816), is_leaf=True) # primals_214 + buf190 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf190, (16,), is_leaf=True) # primals_215 + buf191 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf191, (24, 2048), is_leaf=True) # primals_216 + buf192 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf192, (5, 2048), is_leaf=True) # primals_217 + buf193 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf193, (4,), is_leaf=True) # primals_218 + buf194 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf194, (32, 512), is_leaf=True) # primals_219 + buf195 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf195, (16, 2048), is_leaf=True) # primals_220 + buf196 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf196, (16,), is_leaf=True) # primals_221 + buf197 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf197, (1, 2048), is_leaf=True) # primals_223 + buf198 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf198, (8, 88, 2048), is_leaf=True) # primals_225 + buf199 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf199, (8, 128, 1408), is_leaf=True) # primals_226 + buf200 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf200, (8, 88, 2048), is_leaf=True) # primals_227 + buf201 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf201, (22, 2048), is_leaf=True) # primals_228 + buf202 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf202, (22, 2048), is_leaf=True) # primals_229 + buf203 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf203, (16, 2816), is_leaf=True) # primals_230 + buf204 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf204, (16,), is_leaf=True) # primals_231 + buf205 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf205, (24, 2048), is_leaf=True) # primals_232 + buf206 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf206, (5, 2048), is_leaf=True) # primals_233 + buf207 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf207, (4,), is_leaf=True) # primals_234 + buf208 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf208, (32, 512), is_leaf=True) # primals_235 + buf209 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf209, (16, 2048), is_leaf=True) # primals_236 + buf210 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf210, (16,), is_leaf=True) # primals_237 + buf211 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf211, (1, 2048), is_leaf=True) # primals_239 + buf212 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf212, (8, 88, 2048), is_leaf=True) # primals_241 + buf213 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf213, (8, 128, 1408), is_leaf=True) # primals_242 + buf214 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf214, (8, 88, 2048), is_leaf=True) # primals_243 + buf215 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf215, (22, 2048), is_leaf=True) # primals_244 + buf216 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf216, (22, 2048), is_leaf=True) # primals_245 + buf217 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf217, (16, 2816), is_leaf=True) # primals_246 + buf218 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf218, (16,), is_leaf=True) # primals_247 + buf219 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf219, (24, 2048), is_leaf=True) # primals_248 + buf220 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf220, (5, 2048), is_leaf=True) # primals_249 + buf221 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf221, (4,), is_leaf=True) # primals_250 + buf222 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf222, (32, 512), is_leaf=True) # primals_251 + buf223 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf223, (16, 2048), is_leaf=True) # primals_252 + buf224 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf224, (16,), is_leaf=True) # primals_253 + buf225 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf225, (1, 2048), is_leaf=True) # primals_255 + buf226 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf226, (8, 88, 2048), is_leaf=True) # primals_257 + buf227 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf227, (8, 128, 1408), is_leaf=True) # primals_258 + buf228 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf228, (8, 88, 2048), is_leaf=True) # primals_259 + buf229 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf229, (22, 2048), is_leaf=True) # primals_260 + buf230 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf230, (22, 2048), is_leaf=True) # primals_261 + buf231 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf231, (16, 2816), is_leaf=True) # primals_262 + buf232 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf232, (16,), is_leaf=True) # primals_263 + buf233 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf233, (24, 2048), is_leaf=True) # primals_264 + buf234 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf234, (5, 2048), is_leaf=True) # primals_265 + buf235 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf235, (4,), is_leaf=True) # primals_266 + buf236 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf236, (32, 512), is_leaf=True) # primals_267 + buf237 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf237, (16, 2048), is_leaf=True) # primals_268 + buf238 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf238, (16,), is_leaf=True) # primals_269 + buf239 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf239, (1, 2048), is_leaf=True) # primals_271 + buf240 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf240, (8, 88, 2048), is_leaf=True) # primals_273 + buf241 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf241, (8, 128, 1408), is_leaf=True) # primals_274 + buf242 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf242, (8, 88, 2048), is_leaf=True) # primals_275 + buf243 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf243, (22, 2048), is_leaf=True) # primals_276 + buf244 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf244, (22, 2048), is_leaf=True) # primals_277 + buf245 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf245, (16, 2816), is_leaf=True) # primals_278 + buf246 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf246, (16,), is_leaf=True) # primals_279 + buf247 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf247, (24, 2048), is_leaf=True) # primals_280 + buf248 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf248, (5, 2048), is_leaf=True) # primals_281 + buf249 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4,), is_leaf=True) # primals_282 + buf250 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf250, (32, 512), is_leaf=True) # primals_283 + buf251 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf251, (16, 2048), is_leaf=True) # primals_284 + buf252 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf252, (16,), is_leaf=True) # primals_285 + buf253 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf253, (1, 2048), is_leaf=True) # primals_287 + buf254 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf254, (8, 88, 2048), is_leaf=True) # primals_289 + buf255 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf255, (8, 128, 1408), is_leaf=True) # primals_290 + buf256 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf256, (8, 88, 2048), is_leaf=True) # primals_291 + buf257 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf257, (22, 2048), is_leaf=True) # primals_292 + buf258 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf258, (22, 2048), is_leaf=True) # primals_293 + buf259 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf259, (16, 2816), is_leaf=True) # primals_294 + buf260 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf260, (16,), is_leaf=True) # primals_295 + buf261 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf261, (24, 2048), is_leaf=True) # primals_296 + buf262 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf262, (5, 2048), is_leaf=True) # primals_297 + buf263 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf263, (4,), is_leaf=True) # primals_298 + buf264 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf264, (32, 512), is_leaf=True) # primals_299 + buf265 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf265, (16, 2048), is_leaf=True) # primals_300 + buf266 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf266, (16,), is_leaf=True) # primals_301 + buf267 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf267, (1, 2048), is_leaf=True) # primals_303 + buf268 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf268, (8, 88, 2048), is_leaf=True) # primals_305 + buf269 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf269, (8, 128, 1408), is_leaf=True) # primals_306 + buf270 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf270, (8, 88, 2048), is_leaf=True) # primals_307 + buf271 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf271, (22, 2048), is_leaf=True) # primals_308 + buf272 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf272, (22, 2048), is_leaf=True) # primals_309 + buf273 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf273, (16, 2816), is_leaf=True) # primals_310 + buf274 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf274, (16,), is_leaf=True) # primals_311 + buf275 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf275, (24, 2048), is_leaf=True) # primals_312 + buf276 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf276, (5, 2048), is_leaf=True) # primals_313 + buf277 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf277, (4,), is_leaf=True) # primals_314 + buf278 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf278, (32, 512), is_leaf=True) # primals_315 + buf279 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf279, (16, 2048), is_leaf=True) # primals_316 + buf280 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf280, (16,), is_leaf=True) # primals_317 + buf281 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf281, (1, 2048), is_leaf=True) # primals_319 + buf282 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf282, (8, 88, 2048), is_leaf=True) # primals_321 + buf283 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf283, (8, 128, 1408), is_leaf=True) # primals_322 + buf284 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf284, (8, 88, 2048), is_leaf=True) # primals_323 + buf285 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf285, (22, 2048), is_leaf=True) # primals_324 + buf286 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf286, (22, 2048), is_leaf=True) # primals_325 + buf287 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf287, (16, 2816), is_leaf=True) # primals_326 + buf288 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf288, (16,), is_leaf=True) # primals_327 + buf289 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf289, (24, 2048), is_leaf=True) # primals_328 + buf290 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf290, (5, 2048), is_leaf=True) # primals_329 + buf291 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf291, (4,), is_leaf=True) # primals_330 + buf292 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf292, (32, 512), is_leaf=True) # primals_331 + buf293 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf293, (16, 2048), is_leaf=True) # primals_332 + buf294 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf294, (16,), is_leaf=True) # primals_333 + buf295 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf295, (1, 2048), is_leaf=True) # primals_335 + buf296 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf296, (8, 88, 2048), is_leaf=True) # primals_337 + buf297 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf297, (8, 128, 1408), is_leaf=True) # primals_338 + buf298 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf298, (8, 88, 2048), is_leaf=True) # primals_339 + buf299 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf299, (22, 2048), is_leaf=True) # primals_340 + buf300 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf300, (22, 2048), is_leaf=True) # primals_341 + buf301 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf301, (16, 2816), is_leaf=True) # primals_342 + buf302 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf302, (16,), is_leaf=True) # primals_343 + buf303 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf303, (24, 2048), is_leaf=True) # primals_344 + buf304 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf304, (5, 2048), is_leaf=True) # primals_345 + buf305 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf305, (4,), is_leaf=True) # primals_346 + buf306 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf306, (32, 512), is_leaf=True) # primals_347 + buf307 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf307, (16, 2048), is_leaf=True) # primals_348 + buf308 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf308, (16,), is_leaf=True) # primals_349 + buf309 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf309, (1, 2048), is_leaf=True) # primals_351 + buf310 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf310, (8, 88, 2048), is_leaf=True) # primals_353 + buf311 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf311, (8, 128, 1408), is_leaf=True) # primals_354 + buf312 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf312, (8, 88, 2048), is_leaf=True) # primals_355 + buf313 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf313, (22, 2048), is_leaf=True) # primals_356 + buf314 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf314, (22, 2048), is_leaf=True) # primals_357 + buf315 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf315, (16, 2816), is_leaf=True) # primals_358 + buf316 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf316, (16,), is_leaf=True) # primals_359 + buf317 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf317, (24, 2048), is_leaf=True) # primals_360 + buf318 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf318, (5, 2048), is_leaf=True) # primals_361 + buf319 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf319, (4,), is_leaf=True) # primals_362 + buf320 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf320, (32, 512), is_leaf=True) # primals_363 + buf321 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf321, (16, 2048), is_leaf=True) # primals_364 + buf322 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf322, (16,), is_leaf=True) # primals_365 + buf323 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf323, (1, 2048), is_leaf=True) # primals_367 + buf324 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf324, (8, 88, 2048), is_leaf=True) # primals_369 + buf325 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf325, (8, 128, 1408), is_leaf=True) # primals_370 + buf326 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf326, (8, 88, 2048), is_leaf=True) # primals_371 + buf327 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf327, (22, 2048), is_leaf=True) # primals_372 + buf328 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf328, (22, 2048), is_leaf=True) # primals_373 + buf329 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf329, (16, 2816), is_leaf=True) # primals_374 + buf330 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf330, (16,), is_leaf=True) # primals_375 + buf331 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf331, (24, 2048), is_leaf=True) # primals_376 + buf332 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf332, (5, 2048), is_leaf=True) # primals_377 + buf333 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf333, (4,), is_leaf=True) # primals_378 + buf334 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf334, (32, 512), is_leaf=True) # primals_379 + buf335 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf335, (16, 2048), is_leaf=True) # primals_380 + buf336 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf336, (16,), is_leaf=True) # primals_381 + buf337 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf337, (1, 2048), is_leaf=True) # primals_383 + buf338 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf338, (8, 88, 2048), is_leaf=True) # primals_385 + buf339 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf339, (8, 128, 1408), is_leaf=True) # primals_386 + buf340 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf340, (8, 88, 2048), is_leaf=True) # primals_387 + buf341 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf341, (22, 2048), is_leaf=True) # primals_388 + buf342 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf342, (22, 2048), is_leaf=True) # primals_389 + buf343 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf343, (16, 2816), is_leaf=True) # primals_390 + buf344 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf344, (16,), is_leaf=True) # primals_391 + buf345 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf345, (24, 2048), is_leaf=True) # primals_392 + buf346 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf346, (5, 2048), is_leaf=True) # primals_393 + buf347 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf347, (4,), is_leaf=True) # primals_394 + buf348 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf348, (32, 512), is_leaf=True) # primals_395 + buf349 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf349, (16, 2048), is_leaf=True) # primals_396 + buf350 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf350, (16,), is_leaf=True) # primals_397 + buf351 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf351, (1, 2048), is_leaf=True) # primals_399 + buf352 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf352, (8, 88, 2048), is_leaf=True) # primals_401 + buf353 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf353, (8, 128, 1408), is_leaf=True) # primals_402 + buf354 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf354, (8, 88, 2048), is_leaf=True) # primals_403 + buf355 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf355, (22, 2048), is_leaf=True) # primals_404 + buf356 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf356, (22, 2048), is_leaf=True) # primals_405 + buf357 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf357, (16, 2816), is_leaf=True) # primals_406 + buf358 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf358, (16,), is_leaf=True) # primals_407 + buf359 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf359, (24, 2048), is_leaf=True) # primals_408 + buf360 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf360, (5, 2048), is_leaf=True) # primals_409 + buf361 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf361, (4,), is_leaf=True) # primals_410 + buf362 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf362, (32, 512), is_leaf=True) # primals_411 + buf363 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf363, (16, 2048), is_leaf=True) # primals_412 + buf364 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf364, (16,), is_leaf=True) # primals_413 + buf365 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf365, (1, 2048), is_leaf=True) # primals_415 + buf366 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf366, (8, 88, 2048), is_leaf=True) # primals_417 + buf367 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf367, (8, 128, 1408), is_leaf=True) # primals_418 + buf368 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf368, (8, 88, 2048), is_leaf=True) # primals_419 + buf369 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf369, (22, 2048), is_leaf=True) # primals_420 + buf370 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf370, (22, 2048), is_leaf=True) # primals_421 + buf371 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf371, (16, 2816), is_leaf=True) # primals_422 + buf372 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf372, (16,), is_leaf=True) # primals_423 + buf373 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf373, (24, 2048), is_leaf=True) # primals_424 + buf374 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf374, (5, 2048), is_leaf=True) # primals_425 + buf375 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf375, (4,), is_leaf=True) # primals_426 + buf376 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf376, (32, 512), is_leaf=True) # primals_427 + buf377 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf377, (16, 2048), is_leaf=True) # primals_428 + buf378 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf378, (16,), is_leaf=True) # primals_429 + buf379 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf379, (1, 2048), is_leaf=True) # primals_431 + buf380 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf380, (8, 88, 2048), is_leaf=True) # primals_433 + buf381 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf381, (8, 128, 1408), is_leaf=True) # primals_434 + buf382 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf382, (8, 88, 2048), is_leaf=True) # primals_435 + buf383 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf383, (22, 2048), is_leaf=True) # primals_436 + buf384 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf384, (22, 2048), is_leaf=True) # primals_437 + buf385 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf385, (16, 2816), is_leaf=True) # primals_438 + buf386 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf386, (16,), is_leaf=True) # primals_439 + buf387 = reader.storage(None, 6553600, device=device(type='cuda', index=0)) + reader.tensor(buf387, (800, 2048), is_leaf=True) # primals_440 + buf388 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf388, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # embedding + buf389 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf389, (2, 4096, 1), is_leaf=True) # rsqrt + buf390 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf390, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_3 + buf391 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf391, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2 + buf392 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf392, (2, 4096, 1), is_leaf=True) # rsqrt_1 + buf393 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf393, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_17 + buf394 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf394, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_3 + buf395 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf395, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_4 + buf396 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf396, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_5 + buf397 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf397, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_6 + buf398 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf398, (2, 16, 4096), is_leaf=True) # getitem_7 + buf399 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf399, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # mm_3 + buf400 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf400, (2, 4096, 1), is_leaf=True) # rsqrt_2 + buf401 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf401, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_26 + buf402 = reader.storage(None, 179306496, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf402, (8192, 10944), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf403 = reader.storage(None, 179306496, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf403, (8192, 10944), dtype=torch.bfloat16, is_leaf=True) # mm_5 + buf404 = reader.storage(None, 179306496, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf404, (8192, 10944), dtype=torch.bfloat16, is_leaf=True) # view_32 + buf405 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf405, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_5 + buf406 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf406, (2, 4096, 1), is_leaf=True) # rsqrt_3 + buf407 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf407, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_36 + buf408 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf408, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_11 + buf409 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf409, (2, 4096, 1), is_leaf=True) # rsqrt_4 + buf410 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf410, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_50 + buf411 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf411, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_14 + buf412 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf412, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_15 + buf413 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf413, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_16 + buf414 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf414, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_15 + buf415 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf415, (2, 16, 4096), is_leaf=True) # getitem_16 + buf416 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf416, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_8 + buf417 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf417, (2, 4096, 1), is_leaf=True) # rsqrt_5 + buf418 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf418, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_58 + buf419 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf419, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf420 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf420, (8192, 1), is_leaf=True) # amax + buf421 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf421, (8192, 1), is_leaf=True) # sum_1 + buf422 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf422, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_19 + buf423 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf423, (49152,), dtype=torch.int64, is_leaf=True) # getitem_21 + buf424 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf424, (49152,), dtype=torch.int64, is_leaf=True) # div_2 + buf425 = reader.storage(None, 32*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf425, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_22 + buf426 = reader.storage(None, 32768*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf426, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_1 + buf427 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf427, (8,), dtype=torch.int32, is_leaf=True) # cumsum_2 + buf428 = reader.storage(None, 22528*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf428, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm + buf429 = reader.storage(None, 22528*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf429, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_1 + buf430 = reader.storage(None, 22528*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf430, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_35 + buf431 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf431, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_12 + buf432 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf432, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_13 + buf433 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf433, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_55 + buf434 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf434, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_73 + buf435 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf435, (2, 4096, 1), is_leaf=True) # rsqrt_6 + buf436 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf436, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_103 + buf437 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf437, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_121 + buf438 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf438, (2, 4096, 1), is_leaf=True) # rsqrt_7 + buf439 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf439, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_117 + buf440 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf440, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_29 + buf441 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf441, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_30 + buf442 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf442, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_31 + buf443 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf443, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_125 + buf444 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf444, (2, 16, 4096), is_leaf=True) # getitem_126 + buf445 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf445, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_76 + buf446 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf446, (2, 4096, 1), is_leaf=True) # rsqrt_8 + buf447 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf447, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_125 + buf448 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf448, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_19 + buf449 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf449, (8192, 1), is_leaf=True) # amax_1 + buf450 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf450, (8192, 1), is_leaf=True) # sum_5 + buf451 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf451, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_129 + buf452 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf452, (49152,), dtype=torch.int64, is_leaf=True) # getitem_131 + buf453 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf453, (49152,), dtype=torch.int64, is_leaf=True) # div_7 + buf454 = reader.storage(None, 32*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf454, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_132 + buf455 = reader.storage(None, 32768*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf455, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_3 + buf456 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf456, (8,), dtype=torch.int32, is_leaf=True) # cumsum_5 + buf457 = reader.storage(None, 22528*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf457, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_3 + buf458 = reader.storage(None, 22528*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf458, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_4 + buf459 = reader.storage(None, 22528*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf459, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_84 + buf460 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf460, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_20 + buf461 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf461, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf462 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf462, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_104 + buf463 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf463, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_141 + buf464 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf464, (2, 4096, 1), is_leaf=True) # rsqrt_9 + buf465 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf465, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_170 + buf466 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf466, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_231 + buf467 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf467, (2, 4096, 1), is_leaf=True) # rsqrt_10 + buf468 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf468, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_184 + buf469 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf469, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_44 + buf470 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf470, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_45 + buf471 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf471, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_46 + buf472 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf472, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_235 + buf473 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf473, (2, 16, 4096), is_leaf=True) # getitem_236 + buf474 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf474, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_144 + buf475 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf475, (2, 4096, 1), is_leaf=True) # rsqrt_11 + buf476 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf476, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_192 + buf477 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf477, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_27 + buf478 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf478, (8192, 1), is_leaf=True) # amax_2 + buf479 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf479, (8192, 1), is_leaf=True) # sum_9 + buf480 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf480, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_239 + buf481 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf481, (49152,), dtype=torch.int64, is_leaf=True) # getitem_241 + buf482 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf482, (49152,), dtype=torch.int64, is_leaf=True) # div_12 + buf483 = reader.storage(None, 32*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf483, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_242 + buf484 = reader.storage(None, 32768*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf484, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_5 + buf485 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf485, (8,), dtype=torch.int32, is_leaf=True) # cumsum_8 + buf486 = reader.storage(None, 22528*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf486, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_6 + buf487 = reader.storage(None, 22528*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf487, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_7 + buf488 = reader.storage(None, 22528*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf488, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_133 + buf489 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf489, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf490 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf490, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_29 + buf491 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf491, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_153 + buf492 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf492, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_209 + buf493 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf493, (2, 4096, 1), is_leaf=True) # rsqrt_12 + buf494 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf494, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_237 + buf495 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf495, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_341 + buf496 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf496, (2, 4096, 1), is_leaf=True) # rsqrt_13 + buf497 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf497, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_251 + buf498 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf498, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_59 + buf499 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf499, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_60 + buf500 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf500, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_61 + buf501 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf501, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_345 + buf502 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf502, (2, 16, 4096), is_leaf=True) # getitem_346 + buf503 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf503, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_212 + buf504 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf504, (2, 4096, 1), is_leaf=True) # rsqrt_14 + buf505 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf505, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_259 + buf506 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf506, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf507 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf507, (8192, 1), is_leaf=True) # amax_3 + buf508 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf508, (8192, 1), is_leaf=True) # sum_13 + buf509 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf509, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_349 + buf510 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf510, (49152,), dtype=torch.int64, is_leaf=True) # getitem_351 + buf511 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf511, (49152,), dtype=torch.int64, is_leaf=True) # div_17 + buf512 = reader.storage(None, 32*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf512, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_352 + buf513 = reader.storage(None, 32768*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf513, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_7 + buf514 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf514, (8,), dtype=torch.int32, is_leaf=True) # cumsum_11 + buf515 = reader.storage(None, 22528*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf515, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_9 + buf516 = reader.storage(None, 22528*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf516, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_10 + buf517 = reader.storage(None, 22528*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf517, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_182 + buf518 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf518, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_36 + buf519 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf519, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf520 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf520, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_202 + buf521 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf521, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_277 + buf522 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf522, (2, 4096, 1), is_leaf=True) # rsqrt_15 + buf523 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf523, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_304 + buf524 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf524, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_451 + buf525 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf525, (2, 4096, 1), is_leaf=True) # rsqrt_16 + buf526 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf526, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_318 + buf527 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf527, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_74 + buf528 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf528, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_75 + buf529 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf529, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_76 + buf530 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf530, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_455 + buf531 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf531, (2, 16, 4096), is_leaf=True) # getitem_456 + buf532 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf532, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_280 + buf533 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf533, (2, 4096, 1), is_leaf=True) # rsqrt_17 + buf534 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf534, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_326 + buf535 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf535, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_43 + buf536 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf536, (8192, 1), is_leaf=True) # amax_4 + buf537 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf537, (8192, 1), is_leaf=True) # sum_17 + buf538 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf538, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_459 + buf539 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf539, (49152,), dtype=torch.int64, is_leaf=True) # getitem_461 + buf540 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf540, (49152,), dtype=torch.int64, is_leaf=True) # div_22 + buf541 = reader.storage(None, 32*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf541, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_462 + buf542 = reader.storage(None, 32768*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf542, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_9 + buf543 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf543, (8,), dtype=torch.int32, is_leaf=True) # cumsum_14 + buf544 = reader.storage(None, 22528*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf544, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_12 + buf545 = reader.storage(None, 22528*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf545, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_13 + buf546 = reader.storage(None, 22528*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf546, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_231 + buf547 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf547, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf548 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf548, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_45 + buf549 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf549, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_251 + buf550 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf550, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_345 + buf551 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf551, (2, 4096, 1), is_leaf=True) # rsqrt_18 + buf552 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf552, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_371 + buf553 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf553, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_561 + buf554 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf554, (2, 4096, 1), is_leaf=True) # rsqrt_19 + buf555 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf555, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_385 + buf556 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf556, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_89 + buf557 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf557, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_90 + buf558 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf558, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_91 + buf559 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf559, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_565 + buf560 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf560, (2, 16, 4096), is_leaf=True) # getitem_566 + buf561 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf561, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_348 + buf562 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf562, (2, 4096, 1), is_leaf=True) # rsqrt_20 + buf563 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf563, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_393 + buf564 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf564, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf565 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf565, (8192, 1), is_leaf=True) # amax_5 + buf566 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf566, (8192, 1), is_leaf=True) # sum_21 + buf567 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf567, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_569 + buf568 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf568, (49152,), dtype=torch.int64, is_leaf=True) # getitem_571 + buf569 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf569, (49152,), dtype=torch.int64, is_leaf=True) # div_27 + buf570 = reader.storage(None, 32*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf570, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_572 + buf571 = reader.storage(None, 32768*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf571, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_11 + buf572 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf572, (8,), dtype=torch.int32, is_leaf=True) # cumsum_17 + buf573 = reader.storage(None, 22528*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf573, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_15 + buf574 = reader.storage(None, 22528*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf574, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_16 + buf575 = reader.storage(None, 22528*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf575, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_280 + buf576 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf576, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_52 + buf577 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf577, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf578 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf578, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_300 + buf579 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf579, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_413 + buf580 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf580, (2, 4096, 1), is_leaf=True) # rsqrt_21 + buf581 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf581, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_438 + buf582 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf582, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_671 + buf583 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf583, (2, 4096, 1), is_leaf=True) # rsqrt_22 + buf584 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf584, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_452 + buf585 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf585, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_104 + buf586 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf586, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_105 + buf587 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf587, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_106 + buf588 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf588, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_675 + buf589 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf589, (2, 16, 4096), is_leaf=True) # getitem_676 + buf590 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf590, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_416 + buf591 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf591, (2, 4096, 1), is_leaf=True) # rsqrt_23 + buf592 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf592, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_460 + buf593 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf593, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_59 + buf594 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf594, (8192, 1), is_leaf=True) # amax_6 + buf595 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf595, (8192, 1), is_leaf=True) # sum_25 + buf596 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf596, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_679 + buf597 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf597, (49152,), dtype=torch.int64, is_leaf=True) # getitem_681 + buf598 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf598, (49152,), dtype=torch.int64, is_leaf=True) # div_32 + buf599 = reader.storage(None, 32*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf599, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_682 + buf600 = reader.storage(None, 32768*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf600, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_13 + buf601 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf601, (8,), dtype=torch.int32, is_leaf=True) # cumsum_20 + buf602 = reader.storage(None, 22528*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf602, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_18 + buf603 = reader.storage(None, 22528*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf603, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_19 + buf604 = reader.storage(None, 22528*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf604, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_329 + buf605 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf605, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf606 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf606, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_61 + buf607 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf607, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_349 + buf608 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf608, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_481 + buf609 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf609, (2, 4096, 1), is_leaf=True) # rsqrt_24 + buf610 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf610, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_505 + buf611 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf611, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_781 + buf612 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf612, (2, 4096, 1), is_leaf=True) # rsqrt_25 + buf613 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf613, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_519 + buf614 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf614, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_119 + buf615 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf615, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_120 + buf616 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf616, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_121 + buf617 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf617, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_785 + buf618 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf618, (2, 16, 4096), is_leaf=True) # getitem_786 + buf619 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf619, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_484 + buf620 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf620, (2, 4096, 1), is_leaf=True) # rsqrt_26 + buf621 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf621, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_527 + buf622 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf622, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf623 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf623, (8192, 1), is_leaf=True) # amax_7 + buf624 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf624, (8192, 1), is_leaf=True) # sum_29 + buf625 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf625, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_789 + buf626 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf626, (49152,), dtype=torch.int64, is_leaf=True) # getitem_791 + buf627 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf627, (49152,), dtype=torch.int64, is_leaf=True) # div_37 + buf628 = reader.storage(None, 32*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf628, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_792 + buf629 = reader.storage(None, 32768*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf629, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_15 + buf630 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf630, (8,), dtype=torch.int32, is_leaf=True) # cumsum_23 + buf631 = reader.storage(None, 22528*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf631, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_21 + buf632 = reader.storage(None, 22528*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf632, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_22 + buf633 = reader.storage(None, 22528*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf633, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_378 + buf634 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf634, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_68 + buf635 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf635, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_69 + buf636 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf636, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_398 + buf637 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf637, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_549 + buf638 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf638, (2, 4096, 1), is_leaf=True) # rsqrt_27 + buf639 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf639, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_572 + buf640 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf640, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_891 + buf641 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf641, (2, 4096, 1), is_leaf=True) # rsqrt_28 + buf642 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf642, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_586 + buf643 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf643, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_134 + buf644 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf644, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_135 + buf645 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf645, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_136 + buf646 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf646, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_895 + buf647 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf647, (2, 16, 4096), is_leaf=True) # getitem_896 + buf648 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf648, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_552 + buf649 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf649, (2, 4096, 1), is_leaf=True) # rsqrt_29 + buf650 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf650, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_594 + buf651 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf651, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_75 + buf652 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf652, (8192, 1), is_leaf=True) # amax_8 + buf653 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf653, (8192, 1), is_leaf=True) # sum_33 + buf654 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf654, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_899 + buf655 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf655, (49152,), dtype=torch.int64, is_leaf=True) # getitem_901 + buf656 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf656, (49152,), dtype=torch.int64, is_leaf=True) # div_42 + buf657 = reader.storage(None, 32*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf657, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_902 + buf658 = reader.storage(None, 32768*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf658, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_17 + buf659 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf659, (8,), dtype=torch.int32, is_leaf=True) # cumsum_26 + buf660 = reader.storage(None, 22528*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf660, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_24 + buf661 = reader.storage(None, 22528*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf661, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_25 + buf662 = reader.storage(None, 22528*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf662, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_427 + buf663 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf663, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_76 + buf664 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf664, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf665 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf665, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_447 + buf666 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf666, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_617 + buf667 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf667, (2, 4096, 1), is_leaf=True) # rsqrt_30 + buf668 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf668, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_639 + buf669 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf669, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1001 + buf670 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf670, (2, 4096, 1), is_leaf=True) # rsqrt_31 + buf671 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf671, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_653 + buf672 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf672, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_149 + buf673 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf673, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_150 + buf674 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf674, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_151 + buf675 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf675, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1005 + buf676 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf676, (2, 16, 4096), is_leaf=True) # getitem_1006 + buf677 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf677, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_620 + buf678 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf678, (2, 4096, 1), is_leaf=True) # rsqrt_32 + buf679 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf679, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_661 + buf680 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf680, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_83 + buf681 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf681, (8192, 1), is_leaf=True) # amax_9 + buf682 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf682, (8192, 1), is_leaf=True) # sum_37 + buf683 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf683, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1009 + buf684 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf684, (49152,), dtype=torch.int64, is_leaf=True) # getitem_1011 + buf685 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf685, (49152,), dtype=torch.int64, is_leaf=True) # div_47 + buf686 = reader.storage(None, 32*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf686, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_1012 + buf687 = reader.storage(None, 32768*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf687, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_19 + buf688 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf688, (8,), dtype=torch.int32, is_leaf=True) # cumsum_29 + buf689 = reader.storage(None, 22528*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf689, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_27 + buf690 = reader.storage(None, 22528*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf690, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_28 + buf691 = reader.storage(None, 22528*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf691, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_476 + buf692 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf692, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf693 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf693, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_85 + buf694 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf694, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_496 + buf695 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf695, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_685 + buf696 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf696, (2, 4096, 1), is_leaf=True) # rsqrt_33 + buf697 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf697, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_706 + buf698 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf698, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1111 + buf699 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf699, (2, 4096, 1), is_leaf=True) # rsqrt_34 + buf700 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf700, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_720 + buf701 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf701, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_164 + buf702 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf702, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_165 + buf703 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf703, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_166 + buf704 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf704, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1115 + buf705 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf705, (2, 16, 4096), is_leaf=True) # getitem_1116 + buf706 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf706, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_688 + buf707 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf707, (2, 4096, 1), is_leaf=True) # rsqrt_35 + buf708 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf708, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_728 + buf709 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf709, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_91 + buf710 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf710, (8192, 1), is_leaf=True) # amax_10 + buf711 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf711, (8192, 1), is_leaf=True) # sum_41 + buf712 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf712, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1119 + buf713 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf713, (49152,), dtype=torch.int64, is_leaf=True) # getitem_1121 + buf714 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf714, (49152,), dtype=torch.int64, is_leaf=True) # div_52 + buf715 = reader.storage(None, 32*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf715, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_1122 + buf716 = reader.storage(None, 32768*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf716, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_21 + buf717 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf717, (8,), dtype=torch.int32, is_leaf=True) # cumsum_32 + buf718 = reader.storage(None, 22528*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf718, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_30 + buf719 = reader.storage(None, 22528*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf719, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_31 + buf720 = reader.storage(None, 22528*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf720, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_525 + buf721 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf721, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_92 + buf722 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf722, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf723 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf723, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_545 + buf724 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf724, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_753 + buf725 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf725, (2, 4096, 1), is_leaf=True) # rsqrt_36 + buf726 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf726, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_773 + buf727 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf727, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1221 + buf728 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf728, (2, 4096, 1), is_leaf=True) # rsqrt_37 + buf729 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf729, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_787 + buf730 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf730, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_179 + buf731 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf731, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_180 + buf732 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf732, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_181 + buf733 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf733, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1225 + buf734 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf734, (2, 16, 4096), is_leaf=True) # getitem_1226 + buf735 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf735, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_756 + buf736 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf736, (2, 4096, 1), is_leaf=True) # rsqrt_38 + buf737 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf737, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_795 + buf738 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf738, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_99 + buf739 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf739, (8192, 1), is_leaf=True) # amax_11 + buf740 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf740, (8192, 1), is_leaf=True) # sum_45 + buf741 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf741, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1229 + buf742 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf742, (49152,), dtype=torch.int64, is_leaf=True) # getitem_1231 + buf743 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf743, (49152,), dtype=torch.int64, is_leaf=True) # div_57 + buf744 = reader.storage(None, 32*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf744, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_1232 + buf745 = reader.storage(None, 32768*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf745, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_23 + buf746 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf746, (8,), dtype=torch.int32, is_leaf=True) # cumsum_35 + buf747 = reader.storage(None, 22528*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf747, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_33 + buf748 = reader.storage(None, 22528*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf748, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_34 + buf749 = reader.storage(None, 22528*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf749, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_574 + buf750 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf750, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf751 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf751, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_101 + buf752 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf752, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_594 + buf753 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf753, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_821 + buf754 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf754, (2, 4096, 1), is_leaf=True) # rsqrt_39 + buf755 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf755, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_840 + buf756 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf756, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1331 + buf757 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf757, (2, 4096, 1), is_leaf=True) # rsqrt_40 + buf758 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf758, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_854 + buf759 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf759, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_194 + buf760 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf760, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_195 + buf761 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf761, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_196 + buf762 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf762, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1335 + buf763 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf763, (2, 16, 4096), is_leaf=True) # getitem_1336 + buf764 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf764, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_824 + buf765 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf765, (2, 4096, 1), is_leaf=True) # rsqrt_41 + buf766 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf766, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_862 + buf767 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf767, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf768 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf768, (8192, 1), is_leaf=True) # amax_12 + buf769 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf769, (8192, 1), is_leaf=True) # sum_49 + buf770 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf770, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1339 + buf771 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf771, (49152,), dtype=torch.int64, is_leaf=True) # getitem_1341 + buf772 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf772, (49152,), dtype=torch.int64, is_leaf=True) # div_62 + buf773 = reader.storage(None, 32*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf773, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_1342 + buf774 = reader.storage(None, 32768*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf774, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_25 + buf775 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf775, (8,), dtype=torch.int32, is_leaf=True) # cumsum_38 + buf776 = reader.storage(None, 22528*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf776, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_36 + buf777 = reader.storage(None, 22528*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf777, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_37 + buf778 = reader.storage(None, 22528*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf778, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_623 + buf779 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf779, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_108 + buf780 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf780, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf781 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf781, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_643 + buf782 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf782, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_889 + buf783 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf783, (2, 4096, 1), is_leaf=True) # rsqrt_42 + buf784 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf784, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_907 + buf785 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf785, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1441 + buf786 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf786, (2, 4096, 1), is_leaf=True) # rsqrt_43 + buf787 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf787, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_921 + buf788 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf788, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_209 + buf789 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf789, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_210 + buf790 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf790, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_211 + buf791 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf791, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1445 + buf792 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf792, (2, 16, 4096), is_leaf=True) # getitem_1446 + buf793 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf793, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_892 + buf794 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf794, (2, 4096, 1), is_leaf=True) # rsqrt_44 + buf795 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf795, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_929 + buf796 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf796, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_115 + buf797 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf797, (8192, 1), is_leaf=True) # amax_13 + buf798 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf798, (8192, 1), is_leaf=True) # sum_53 + buf799 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf799, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1449 + buf800 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf800, (49152,), dtype=torch.int64, is_leaf=True) # getitem_1451 + buf801 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf801, (49152,), dtype=torch.int64, is_leaf=True) # div_67 + buf802 = reader.storage(None, 32*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf802, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_1452 + buf803 = reader.storage(None, 32768*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf803, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_27 + buf804 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf804, (8,), dtype=torch.int32, is_leaf=True) # cumsum_41 + buf805 = reader.storage(None, 22528*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf805, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_39 + buf806 = reader.storage(None, 22528*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf806, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_40 + buf807 = reader.storage(None, 22528*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf807, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_672 + buf808 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf808, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_116 + buf809 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf809, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_117 + buf810 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf810, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_692 + buf811 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf811, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_957 + buf812 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf812, (2, 4096, 1), is_leaf=True) # rsqrt_45 + buf813 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf813, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_974 + buf814 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf814, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1551 + buf815 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf815, (2, 4096, 1), is_leaf=True) # rsqrt_46 + buf816 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf816, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_988 + buf817 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf817, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_224 + buf818 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf818, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_225 + buf819 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf819, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_226 + buf820 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf820, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1555 + buf821 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf821, (2, 16, 4096), is_leaf=True) # getitem_1556 + buf822 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf822, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_960 + buf823 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf823, (2, 4096, 1), is_leaf=True) # rsqrt_47 + buf824 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf824, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_996 + buf825 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf825, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_123 + buf826 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf826, (8192, 1), is_leaf=True) # amax_14 + buf827 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf827, (8192, 1), is_leaf=True) # sum_57 + buf828 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf828, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1559 + buf829 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf829, (49152,), dtype=torch.int64, is_leaf=True) # getitem_1561 + buf830 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf830, (49152,), dtype=torch.int64, is_leaf=True) # div_72 + buf831 = reader.storage(None, 32*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf831, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_1562 + buf832 = reader.storage(None, 32768*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf832, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_29 + buf833 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf833, (8,), dtype=torch.int32, is_leaf=True) # cumsum_44 + buf834 = reader.storage(None, 22528*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf834, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_42 + buf835 = reader.storage(None, 22528*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf835, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_43 + buf836 = reader.storage(None, 22528*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf836, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_721 + buf837 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf837, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_124 + buf838 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf838, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_125 + buf839 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf839, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_741 + buf840 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf840, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1025 + buf841 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf841, (2, 4096, 1), is_leaf=True) # rsqrt_48 + buf842 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf842, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1041 + buf843 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf843, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1661 + buf844 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf844, (2, 4096, 1), is_leaf=True) # rsqrt_49 + buf845 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf845, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1055 + buf846 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf846, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_239 + buf847 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf847, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_240 + buf848 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf848, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_241 + buf849 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf849, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1665 + buf850 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf850, (2, 16, 4096), is_leaf=True) # getitem_1666 + buf851 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf851, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1028 + buf852 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf852, (2, 4096, 1), is_leaf=True) # rsqrt_50 + buf853 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf853, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1063 + buf854 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf854, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_131 + buf855 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf855, (8192, 1), is_leaf=True) # amax_15 + buf856 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf856, (8192, 1), is_leaf=True) # sum_61 + buf857 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf857, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1669 + buf858 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf858, (49152,), dtype=torch.int64, is_leaf=True) # getitem_1671 + buf859 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf859, (49152,), dtype=torch.int64, is_leaf=True) # div_77 + buf860 = reader.storage(None, 32*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf860, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_1672 + buf861 = reader.storage(None, 32768*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf861, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_31 + buf862 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf862, (8,), dtype=torch.int32, is_leaf=True) # cumsum_47 + buf863 = reader.storage(None, 22528*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf863, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_45 + buf864 = reader.storage(None, 22528*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf864, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_46 + buf865 = reader.storage(None, 22528*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf865, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_770 + buf866 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf866, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_132 + buf867 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf867, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_133 + buf868 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf868, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_790 + buf869 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf869, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1093 + buf870 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf870, (2, 4096, 1), is_leaf=True) # rsqrt_51 + buf871 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf871, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1108 + buf872 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf872, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1771 + buf873 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf873, (2, 4096, 1), is_leaf=True) # rsqrt_52 + buf874 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf874, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1122 + buf875 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf875, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_254 + buf876 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf876, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_255 + buf877 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf877, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_256 + buf878 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf878, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1775 + buf879 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf879, (2, 16, 4096), is_leaf=True) # getitem_1776 + buf880 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf880, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1096 + buf881 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf881, (2, 4096, 1), is_leaf=True) # rsqrt_53 + buf882 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf882, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1130 + buf883 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf883, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_139 + buf884 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf884, (8192, 1), is_leaf=True) # amax_16 + buf885 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf885, (8192, 1), is_leaf=True) # sum_65 + buf886 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf886, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1779 + buf887 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf887, (49152,), dtype=torch.int64, is_leaf=True) # getitem_1781 + buf888 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf888, (49152,), dtype=torch.int64, is_leaf=True) # div_82 + buf889 = reader.storage(None, 32*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf889, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_1782 + buf890 = reader.storage(None, 32768*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf890, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_33 + buf891 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf891, (8,), dtype=torch.int32, is_leaf=True) # cumsum_50 + buf892 = reader.storage(None, 22528*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf892, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_48 + buf893 = reader.storage(None, 22528*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf893, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_49 + buf894 = reader.storage(None, 22528*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf894, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_819 + buf895 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf895, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_140 + buf896 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf896, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_141 + buf897 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf897, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_839 + buf898 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf898, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1161 + buf899 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf899, (2, 4096, 1), is_leaf=True) # rsqrt_54 + buf900 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf900, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1175 + buf901 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf901, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1881 + buf902 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf902, (2, 4096, 1), is_leaf=True) # rsqrt_55 + buf903 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf903, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1189 + buf904 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf904, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_269 + buf905 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf905, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_270 + buf906 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf906, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_271 + buf907 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf907, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1885 + buf908 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf908, (2, 16, 4096), is_leaf=True) # getitem_1886 + buf909 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf909, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1164 + buf910 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf910, (2, 4096, 1), is_leaf=True) # rsqrt_56 + buf911 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf911, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1197 + buf912 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf912, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_147 + buf913 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf913, (8192, 1), is_leaf=True) # amax_17 + buf914 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf914, (8192, 1), is_leaf=True) # sum_69 + buf915 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf915, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1889 + buf916 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf916, (49152,), dtype=torch.int64, is_leaf=True) # getitem_1891 + buf917 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf917, (49152,), dtype=torch.int64, is_leaf=True) # div_87 + buf918 = reader.storage(None, 32*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf918, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_1892 + buf919 = reader.storage(None, 32768*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf919, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_35 + buf920 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf920, (8,), dtype=torch.int32, is_leaf=True) # cumsum_53 + buf921 = reader.storage(None, 22528*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf921, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_51 + buf922 = reader.storage(None, 22528*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf922, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_52 + buf923 = reader.storage(None, 22528*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf923, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_868 + buf924 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf924, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_148 + buf925 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf925, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_149 + buf926 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf926, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_888 + buf927 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf927, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1229 + buf928 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf928, (2, 4096, 1), is_leaf=True) # rsqrt_57 + buf929 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf929, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1242 + buf930 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf930, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1991 + buf931 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf931, (2, 4096, 1), is_leaf=True) # rsqrt_58 + buf932 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf932, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1256 + buf933 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf933, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_284 + buf934 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf934, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_285 + buf935 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf935, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_286 + buf936 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf936, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1995 + buf937 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf937, (2, 16, 4096), is_leaf=True) # getitem_1996 + buf938 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf938, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1232 + buf939 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf939, (2, 4096, 1), is_leaf=True) # rsqrt_59 + buf940 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf940, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1264 + buf941 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf941, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_155 + buf942 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf942, (8192, 1), is_leaf=True) # amax_18 + buf943 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf943, (8192, 1), is_leaf=True) # sum_73 + buf944 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf944, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_1999 + buf945 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf945, (49152,), dtype=torch.int64, is_leaf=True) # getitem_2001 + buf946 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf946, (49152,), dtype=torch.int64, is_leaf=True) # div_92 + buf947 = reader.storage(None, 32*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf947, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_2002 + buf948 = reader.storage(None, 32768*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf948, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_37 + buf949 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf949, (8,), dtype=torch.int32, is_leaf=True) # cumsum_56 + buf950 = reader.storage(None, 22528*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf950, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_54 + buf951 = reader.storage(None, 22528*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf951, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_55 + buf952 = reader.storage(None, 22528*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf952, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_917 + buf953 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf953, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_156 + buf954 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf954, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_157 + buf955 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf955, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_937 + buf956 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf956, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1297 + buf957 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf957, (2, 4096, 1), is_leaf=True) # rsqrt_60 + buf958 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf958, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1309 + buf959 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf959, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2101 + buf960 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf960, (2, 4096, 1), is_leaf=True) # rsqrt_61 + buf961 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf961, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1323 + buf962 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf962, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_299 + buf963 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf963, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_300 + buf964 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf964, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_301 + buf965 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf965, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2105 + buf966 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf966, (2, 16, 4096), is_leaf=True) # getitem_2106 + buf967 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf967, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1300 + buf968 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf968, (2, 4096, 1), is_leaf=True) # rsqrt_62 + buf969 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf969, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1331 + buf970 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf970, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_163 + buf971 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf971, (8192, 1), is_leaf=True) # amax_19 + buf972 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf972, (8192, 1), is_leaf=True) # sum_77 + buf973 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf973, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_2109 + buf974 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf974, (49152,), dtype=torch.int64, is_leaf=True) # getitem_2111 + buf975 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf975, (49152,), dtype=torch.int64, is_leaf=True) # div_97 + buf976 = reader.storage(None, 32*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf976, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_2112 + buf977 = reader.storage(None, 32768*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf977, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_39 + buf978 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf978, (8,), dtype=torch.int32, is_leaf=True) # cumsum_59 + buf979 = reader.storage(None, 22528*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf979, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_57 + buf980 = reader.storage(None, 22528*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf980, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_58 + buf981 = reader.storage(None, 22528*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf981, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_966 + buf982 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf982, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_164 + buf983 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf983, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_165 + buf984 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf984, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_986 + buf985 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf985, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1365 + buf986 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf986, (2, 4096, 1), is_leaf=True) # rsqrt_63 + buf987 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf987, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1376 + buf988 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf988, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2211 + buf989 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf989, (2, 4096, 1), is_leaf=True) # rsqrt_64 + buf990 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf990, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1390 + buf991 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf991, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_314 + buf992 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf992, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_315 + buf993 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf993, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_316 + buf994 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf994, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2215 + buf995 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf995, (2, 16, 4096), is_leaf=True) # getitem_2216 + buf996 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf996, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1368 + buf997 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf997, (2, 4096, 1), is_leaf=True) # rsqrt_65 + buf998 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf998, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1398 + buf999 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf999, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_171 + buf1000 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1000, (8192, 1), is_leaf=True) # amax_20 + buf1001 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1001, (8192, 1), is_leaf=True) # sum_81 + buf1002 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1002, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_2219 + buf1003 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1003, (49152,), dtype=torch.int64, is_leaf=True) # getitem_2221 + buf1004 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1004, (49152,), dtype=torch.int64, is_leaf=True) # div_102 + buf1005 = reader.storage(None, 32*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1005, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_2222 + buf1006 = reader.storage(None, 32768*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1006, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_41 + buf1007 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1007, (8,), dtype=torch.int32, is_leaf=True) # cumsum_62 + buf1008 = reader.storage(None, 22528*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1008, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_60 + buf1009 = reader.storage(None, 22528*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1009, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_61 + buf1010 = reader.storage(None, 22528*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1010, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1015 + buf1011 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1011, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_172 + buf1012 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1012, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_173 + buf1013 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1013, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1035 + buf1014 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1014, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1433 + buf1015 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1015, (2, 4096, 1), is_leaf=True) # rsqrt_66 + buf1016 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1016, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1443 + buf1017 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1017, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2321 + buf1018 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1018, (2, 4096, 1), is_leaf=True) # rsqrt_67 + buf1019 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1019, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1457 + buf1020 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1020, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_329 + buf1021 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1021, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_330 + buf1022 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1022, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_331 + buf1023 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1023, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2325 + buf1024 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1024, (2, 16, 4096), is_leaf=True) # getitem_2326 + buf1025 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1025, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1436 + buf1026 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1026, (2, 4096, 1), is_leaf=True) # rsqrt_68 + buf1027 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1027, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1465 + buf1028 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1028, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_179 + buf1029 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1029, (8192, 1), is_leaf=True) # amax_21 + buf1030 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1030, (8192, 1), is_leaf=True) # sum_85 + buf1031 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1031, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_2329 + buf1032 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1032, (49152,), dtype=torch.int64, is_leaf=True) # getitem_2331 + buf1033 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1033, (49152,), dtype=torch.int64, is_leaf=True) # div_107 + buf1034 = reader.storage(None, 32*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1034, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_2332 + buf1035 = reader.storage(None, 32768*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1035, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_43 + buf1036 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1036, (8,), dtype=torch.int32, is_leaf=True) # cumsum_65 + buf1037 = reader.storage(None, 22528*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1037, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_63 + buf1038 = reader.storage(None, 22528*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1038, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_64 + buf1039 = reader.storage(None, 22528*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1039, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1064 + buf1040 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1040, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_180 + buf1041 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1041, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_181 + buf1042 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1042, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1084 + buf1043 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1043, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1501 + buf1044 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1044, (2, 4096, 1), is_leaf=True) # rsqrt_69 + buf1045 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1045, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1510 + buf1046 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1046, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2431 + buf1047 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1047, (2, 4096, 1), is_leaf=True) # rsqrt_70 + buf1048 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1048, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1524 + buf1049 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1049, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_344 + buf1050 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1050, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_345 + buf1051 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1051, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_346 + buf1052 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1052, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2435 + buf1053 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1053, (2, 16, 4096), is_leaf=True) # getitem_2436 + buf1054 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1054, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1504 + buf1055 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1055, (2, 4096, 1), is_leaf=True) # rsqrt_71 + buf1056 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1056, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1532 + buf1057 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1057, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_187 + buf1058 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1058, (8192, 1), is_leaf=True) # amax_22 + buf1059 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1059, (8192, 1), is_leaf=True) # sum_89 + buf1060 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1060, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_2439 + buf1061 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1061, (49152,), dtype=torch.int64, is_leaf=True) # getitem_2441 + buf1062 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1062, (49152,), dtype=torch.int64, is_leaf=True) # div_112 + buf1063 = reader.storage(None, 32*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1063, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_2442 + buf1064 = reader.storage(None, 32768*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1064, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_45 + buf1065 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1065, (8,), dtype=torch.int32, is_leaf=True) # cumsum_68 + buf1066 = reader.storage(None, 22528*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1066, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_66 + buf1067 = reader.storage(None, 22528*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1067, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_67 + buf1068 = reader.storage(None, 22528*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1068, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1113 + buf1069 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1069, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_188 + buf1070 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1070, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_189 + buf1071 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1071, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1133 + buf1072 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1072, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1569 + buf1073 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1073, (2, 4096, 1), is_leaf=True) # rsqrt_72 + buf1074 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1074, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1577 + buf1075 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1075, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2541 + buf1076 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1076, (2, 4096, 1), is_leaf=True) # rsqrt_73 + buf1077 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1077, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1591 + buf1078 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1078, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_359 + buf1079 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1079, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_360 + buf1080 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1080, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_361 + buf1081 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1081, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2545 + buf1082 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1082, (2, 16, 4096), is_leaf=True) # getitem_2546 + buf1083 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1083, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1572 + buf1084 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1084, (2, 4096, 1), is_leaf=True) # rsqrt_74 + buf1085 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1085, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1599 + buf1086 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1086, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_195 + buf1087 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1087, (8192, 1), is_leaf=True) # amax_23 + buf1088 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1088, (8192, 1), is_leaf=True) # sum_93 + buf1089 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1089, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_2549 + buf1090 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1090, (49152,), dtype=torch.int64, is_leaf=True) # getitem_2551 + buf1091 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1091, (49152,), dtype=torch.int64, is_leaf=True) # div_117 + buf1092 = reader.storage(None, 32*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1092, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_2552 + buf1093 = reader.storage(None, 32768*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1093, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_47 + buf1094 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1094, (8,), dtype=torch.int32, is_leaf=True) # cumsum_71 + buf1095 = reader.storage(None, 22528*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1095, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_69 + buf1096 = reader.storage(None, 22528*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1096, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_70 + buf1097 = reader.storage(None, 22528*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1097, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1162 + buf1098 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1098, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_196 + buf1099 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1099, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_197 + buf1100 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1100, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1182 + buf1101 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1101, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1637 + buf1102 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1102, (2, 4096, 1), is_leaf=True) # rsqrt_75 + buf1103 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1103, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1644 + buf1104 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1104, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2651 + buf1105 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1105, (2, 4096, 1), is_leaf=True) # rsqrt_76 + buf1106 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1106, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1658 + buf1107 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1107, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_374 + buf1108 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1108, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_375 + buf1109 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1109, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_376 + buf1110 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1110, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2655 + buf1111 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1111, (2, 16, 4096), is_leaf=True) # getitem_2656 + buf1112 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1112, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1640 + buf1113 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1113, (2, 4096, 1), is_leaf=True) # rsqrt_77 + buf1114 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1114, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1666 + buf1115 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1115, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_203 + buf1116 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1116, (8192, 1), is_leaf=True) # amax_24 + buf1117 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1117, (8192, 1), is_leaf=True) # sum_97 + buf1118 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1118, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_2659 + buf1119 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1119, (49152,), dtype=torch.int64, is_leaf=True) # getitem_2661 + buf1120 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1120, (49152,), dtype=torch.int64, is_leaf=True) # div_122 + buf1121 = reader.storage(None, 32*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1121, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_2662 + buf1122 = reader.storage(None, 32768*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1122, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_49 + buf1123 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1123, (8,), dtype=torch.int32, is_leaf=True) # cumsum_74 + buf1124 = reader.storage(None, 22528*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1124, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_72 + buf1125 = reader.storage(None, 22528*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1125, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_73 + buf1126 = reader.storage(None, 22528*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1126, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1211 + buf1127 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1127, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_204 + buf1128 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1128, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_205 + buf1129 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1129, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1231 + buf1130 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1130, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1705 + buf1131 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1131, (2, 4096, 1), is_leaf=True) # rsqrt_78 + buf1132 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1132, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1711 + buf1133 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1133, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2761 + buf1134 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1134, (2, 4096, 1), is_leaf=True) # rsqrt_79 + buf1135 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1135, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1725 + buf1136 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1136, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_389 + buf1137 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1137, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_390 + buf1138 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1138, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_391 + buf1139 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1139, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2765 + buf1140 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1140, (2, 16, 4096), is_leaf=True) # getitem_2766 + buf1141 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1141, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1708 + buf1142 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1142, (2, 4096, 1), is_leaf=True) # rsqrt_80 + buf1143 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1143, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1733 + buf1144 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1144, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_211 + buf1145 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1145, (8192, 1), is_leaf=True) # amax_25 + buf1146 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1146, (8192, 1), is_leaf=True) # sum_101 + buf1147 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1147, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_2769 + buf1148 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1148, (49152,), dtype=torch.int64, is_leaf=True) # getitem_2771 + buf1149 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1149, (49152,), dtype=torch.int64, is_leaf=True) # div_127 + buf1150 = reader.storage(None, 32*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1150, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_2772 + buf1151 = reader.storage(None, 32768*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1151, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_51 + buf1152 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1152, (8,), dtype=torch.int32, is_leaf=True) # cumsum_77 + buf1153 = reader.storage(None, 22528*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1153, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_75 + buf1154 = reader.storage(None, 22528*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1154, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_76 + buf1155 = reader.storage(None, 22528*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1155, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1260 + buf1156 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1156, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_212 + buf1157 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1157, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_213 + buf1158 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1158, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1280 + buf1159 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1159, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1773 + buf1160 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1160, (2, 4096, 1), is_leaf=True) # rsqrt_81 + buf1161 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1161, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1778 + buf1162 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1162, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_406 + buf1163 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1163, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_407 + buf1164 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1164, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_422 + buf1165 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1165, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_426 + buf1166 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1166, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_430 + buf1167 = reader.storage(None, 0, device=device(type='cuda', index=0)) + reader.tensor(buf1167, (0, 2048), is_leaf=True) # full_default_54 + buf1168 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1168, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_456 + buf1169 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1169, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_457 + buf1170 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1170, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_472 + buf1171 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1171, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_476 + buf1172 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1172, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_480 + buf1173 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1173, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_506 + buf1174 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1174, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_507 + buf1175 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1175, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_522 + buf1176 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1176, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_526 + buf1177 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1177, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_530 + buf1178 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1178, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_556 + buf1179 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1179, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_557 + buf1180 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1180, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_572 + buf1181 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1181, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_576 + buf1182 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1182, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_580 + buf1183 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1183, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_606 + buf1184 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1184, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_607 + buf1185 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1185, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_622 + buf1186 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1186, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_626 + buf1187 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1187, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_630 + buf1188 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1188, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_656 + buf1189 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1189, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_657 + buf1190 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1190, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_672 + buf1191 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1191, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_676 + buf1192 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1192, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_680 + buf1193 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1193, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_706 + buf1194 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1194, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_707 + buf1195 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1195, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_722 + buf1196 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1196, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_726 + buf1197 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1197, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_730 + buf1198 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1198, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_756 + buf1199 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1199, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_757 + buf1200 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1200, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_772 + buf1201 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1201, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_776 + buf1202 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1202, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_780 + buf1203 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1203, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_806 + buf1204 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1204, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_807 + buf1205 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1205, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_822 + buf1206 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1206, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_826 + buf1207 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1207, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_830 + buf1208 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1208, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_856 + buf1209 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1209, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_857 + buf1210 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1210, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_872 + buf1211 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1211, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_876 + buf1212 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1212, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_880 + buf1213 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1213, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_906 + buf1214 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1214, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_907 + buf1215 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1215, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_922 + buf1216 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1216, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_926 + buf1217 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1217, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_930 + buf1218 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1218, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_956 + buf1219 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1219, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_957 + buf1220 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1220, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_972 + buf1221 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1221, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_976 + buf1222 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1222, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_980 + buf1223 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1223, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1006 + buf1224 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1224, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1007 + buf1225 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1225, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1022 + buf1226 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1226, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1026 + buf1227 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1227, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1030 + buf1228 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1228, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1056 + buf1229 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1229, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1057 + buf1230 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1230, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1072 + buf1231 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1231, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1076 + buf1232 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1232, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1080 + buf1233 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1233, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1106 + buf1234 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1234, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1107 + buf1235 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1235, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1122 + buf1236 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1236, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1126 + buf1237 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1237, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1130 + buf1238 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1238, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1156 + buf1239 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1239, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1157 + buf1240 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1240, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1172 + buf1241 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1241, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1176 + buf1242 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1242, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1180 + buf1243 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1243, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1206 + buf1244 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1244, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1207 + buf1245 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1245, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1222 + buf1246 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1246, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1226 + buf1247 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1247, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1230 + buf1248 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1248, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1256 + buf1249 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1249, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1257 + buf1250 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1250, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1272 + buf1251 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1251, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1276 + buf1252 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1252, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1280 + buf1253 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1253, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1306 + buf1254 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1254, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1307 + buf1255 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1255, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1322 + buf1256 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1256, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1326 + buf1257 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1257, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1330 + buf1258 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1258, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1356 + buf1259 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1259, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1357 + buf1260 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1260, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1372 + buf1261 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1261, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1376 + buf1262 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1262, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1380 + buf1263 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1263, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1406 + buf1264 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1264, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1407 + buf1265 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1265, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1422 + buf1266 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1266, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1426 + buf1267 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1267, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1430 + buf1268 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1268, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1456 + buf1269 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1269, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1457 + buf1270 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1270, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1472 + buf1271 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1271, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1476 + buf1272 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1272, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1480 + buf1273 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1273, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1506 + buf1274 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1274, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1507 + buf1275 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1275, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1522 + buf1276 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1276, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1526 + buf1277 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1277, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1530 + buf1278 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1278, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1556 + buf1279 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1279, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1557 + buf1280 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1280, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1572 + buf1281 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1281, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1576 + buf1282 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1282, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1580 + buf1283 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1283, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1606 + buf1284 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1284, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1607 + buf1285 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1285, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1622 + buf1286 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1286, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1626 + buf1287 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1287, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1630 + buf1288 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1288, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1656 + buf1289 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1289, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1657 + buf1290 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1290, (8, 2048, 1408), dtype=torch.bfloat16, is_leaf=True) # permute_1672 + buf1291 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1291, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1676 + buf1292 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1292, (8, 1408, 2048), dtype=torch.bfloat16, is_leaf=True) # permute_1680 + buf1293 = reader.storage(None, 1677721600, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1293, (2, 4096, 102400), dtype=torch.bfloat16, is_leaf=True) # tangents_1 +load_args._version = 0 +mod = Repro() +if __name__ == '__main__': + from torch._dynamo.repro.after_aot import run_repro + from torch._dynamo.repro.after_aot import setup_fake_process_groups + setup_fake_process_groups({'0': {'size': 128, 'rank': 0}, '1033': {'size': 8, 'rank': 0}, '1025': {'size': 16, 'rank': 0}}) + with torch.no_grad(): + run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='symbolic', check_str=None) + # To run it separately, do + # mod, args = run_repro(mod, load_args, accuracy=False, command='get_args', save_dir=None, tracing_mode='symbolic', check_str=None) + # mod(*args) + dist.destroy_process_group() + +# Helper functions for overlap simulator +def get_pg_config(): + """DSv3 128 GPUs: FSDP=128, TP=1, EP=8.""" + return {'0': {'size': 128, 'rank': 0}, '1025': {'size': 16, 'rank': 0}, '1033': {'size': 8, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls16_8.table" + +def get_colls_group_mapping(): + # FSDP "0" → internode (table group "0"), all other groups → intranode (table group "1") + return {'0': '0', '1025': '1', '1033': '1'} diff --git a/autoparallel/tools/overlap_simulator/repro_dsv3_bw_64.py b/autoparallel/tools/overlap_simulator/repro_dsv3_bw_64.py new file mode 100644 index 00000000..03fe2da5 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_dsv3_bw_64.py @@ -0,0 +1,11332 @@ +# fmt: off +# flake8: noqa +# isort: skip_file + +import os +os.environ['PYTORCH_KERNEL_CACHE_PATH'] = '/mnt/mffuse/.cache/torch/kernels' +os.environ['TORCH_DISABLE_ADDR2LINE'] = '1' +os.environ['TORCH_TRACE'] = '/mnt/mffuse/outputs/sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr/torch_trace/' +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' +os.environ['TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE'] = '[${role_name}${rank}|${local_rank}]:' +os.environ['TORCHELASTIC_MAX_RESTARTS'] = '0' +os.environ['TORCHX_INTERNAL_SESSION_ID'] = 'a7cb45e8-8435-4d98-8768-5273c1f06ab2' +os.environ['TORCHX_RUN_PYTHONPATH'] = '' +os.environ['TORCHELASTIC_ERROR_FILE'] = '/tmp/torchelastic_rm4e8tdn/sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr_eyl0q0pn/attempt_0/0/error.json' +os.environ['TORCH_ADDR2LINE_BINARY'] = '/packages/folly.symbolizer/folly-addr2line' +os.environ['TORCHX_JOB_ID'] = 'mast_conda://torchx/sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr' +os.environ['TORCH_NCCL_ASYNC_ERROR_HANDLING'] = '3' +os.environ['TORCHELASTIC_SIGNALS_TO_HANDLE'] = 'SIGTERM,SIGINT,SIGHUP,SIGQUIT' +os.environ['TORCHELASTIC_RUN_ID'] = 'sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr' +os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1' +os.environ['TORCHELASTIC_RESTART_COUNT'] = '0' +os.environ['TORCHELASTIC_USE_AGENT_STORE'] = 'False' +os.environ['PYTORCH_DDP_USE_SIDE_STREAM'] = '0' +os.environ['TORCHINDUCTOR_CACHE_DIR'] = '/tmp/torchinductor_root' +os.environ['TORCH_FR_BUFFER_SIZE'] = '20000' +os.environ['TORCH_NCCL_DUMP_ON_TIMEOUT'] = '1' +os.environ['TORCH_FR_DUMP_TEMP_FILE'] = '/mnt/mffuse_nccl_trace/nccl_trace/sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr/v_0/attempt_0/nccl_trace_rank_' +os.environ['TRITON_CACHE_DIR'] = '/tmp/torchinductor_root/triton/0' + +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims +import torch.distributed as dist +from torch.testing._internal.distributed.fake_pg import FakeStore +import triton +import triton.language as tl + +import torch._dynamo.config +import torch._inductor.config +import torch._functorch.config +import torch.fx.experimental._config +torch._dynamo.config.capture_scalar_outputs = True +torch._inductor.config.allow_buffer_reuse = False +torch._inductor.config.reorder_for_compute_comm_overlap = False +torch._inductor.config.reorder_for_peak_memory = False +torch._inductor.config.max_autotune = False +torch._inductor.config.coordinate_descent_tuning = False +torch._inductor.config.deterministic = False +torch._inductor.config.aten_distributed_optimizations.collective_bucketing = True +torch._inductor.config.aten_distributed_optimizations.insert_overlap_deps = True +torch._inductor.config.wrap_inductor_compiled_regions = False +torch._inductor.config.triton.cudagraphs = False +torch._inductor.config.triton.store_cubin = False +torch._inductor.config.test_configs.runtime_triton_dtype_assert = False +torch._functorch.config.functionalize_rng_ops = False +torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = True +torch._functorch.config.unlift_effect_tokens = True +torch._functorch.config.selective_decompose = False + + + +isolate_fails_code_str = None + + + + + +if "__compile_source__" in globals(): + import inspect as __after_aot_inspect + import linecache as __after_aot_linecache + __after_aot_filename = __after_aot_inspect.currentframe().f_code.co_filename + __after_aot_linecache.cache[__after_aot_filename] = ( + len(__compile_source__), + None, + __compile_source__.splitlines(True), + __after_aot_filename, + ) +# torch version: 2.11.0a0+git5ac4d4b +# torch cuda version: 12.4 +# torch git version: 5ac4d4bf3f85e15fdd6676f46b090568ea91e47e + + +# CUDA Info: +# nvcc not found +# GPU Hardware Info: +# NVIDIA H100 80GB HBM3 : 8 + +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.reset_table() + +@triton.jit +def _fill_indices_kernel_0( + tokens_per_expert_group_ptr, + start_index_values_ptr, + write_offsets_ptr, + output_ptr, + experts_per_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # Number of threads per block +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # map programs (blocks) to the experts and loop (grid stride) if needed + for expert_id in range(pid, experts_per_rank, num_programs): + # read this experts write offset + write_offset = tl.load(write_offsets_ptr + expert_id) + + for r in range(num_ranks): + # index into tokens_per_expert_group array + i = r * experts_per_rank + expert_id + + # load start index and number of tokens for this expert-rank pair + start_index = tl.load(start_index_values_ptr + i) + length = tl.load(tokens_per_expert_group_ptr + i) + + # each thread in block processes tokens in parallel + offsets = tl.arange(0, BLOCK_SIZE) + + # tokens are processed in chunks of BLOCK_SIZE + for chunk_start in range(0, length, BLOCK_SIZE): + chunk_offsets = chunk_start + offsets + + # mask valid indices + mask = chunk_offsets < length + + values = start_index + chunk_offsets + + # destination + dest_indices = write_offset + chunk_offsets + + # store + tl.store(output_ptr + dest_indices, values, mask=mask) + + # update write offset for next rank + write_offset += length + +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(_fill_indices_kernel_0) +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.constant_args={0: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 1: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 2: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 3: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 4: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 5: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 6: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 7: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 8: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 9: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 10: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 11: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 12: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 13: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 14: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 15: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 16: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 17: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 18: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 19: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 20: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 21: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 22: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 23: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 24: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 25: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}} + +from torch.nn import * +# Stub for submodules referenced in backward graph +class GraphModule(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, *args, **kwargs): + pass +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fw_graph0 = GraphModule() + self.joint_graph0 = GraphModule() + self.mask_graph0 = GraphModule() + self.fw_graph1 = GraphModule() + self.joint_graph1 = GraphModule() + self.mask_graph1 = GraphModule() + self.fw_graph2 = GraphModule() + self.joint_graph2 = GraphModule() + self.mask_graph2 = GraphModule() + self.fw_graph3 = GraphModule() + self.joint_graph3 = GraphModule() + self.mask_graph3 = GraphModule() + self.fw_graph4 = GraphModule() + self.joint_graph4 = GraphModule() + self.mask_graph4 = GraphModule() + self.fw_graph5 = GraphModule() + self.joint_graph5 = GraphModule() + self.mask_graph5 = GraphModule() + self.fw_graph6 = GraphModule() + self.joint_graph6 = GraphModule() + self.mask_graph6 = GraphModule() + self.fw_graph7 = GraphModule() + self.joint_graph7 = GraphModule() + self.mask_graph7 = GraphModule() + self.fw_graph8 = GraphModule() + self.joint_graph8 = GraphModule() + self.mask_graph8 = GraphModule() + self.fw_graph9 = GraphModule() + self.joint_graph9 = GraphModule() + self.mask_graph9 = GraphModule() + self.fw_graph10 = GraphModule() + self.joint_graph10 = GraphModule() + self.mask_graph10 = GraphModule() + self.fw_graph11 = GraphModule() + self.joint_graph11 = GraphModule() + self.mask_graph11 = GraphModule() + self.fw_graph12 = GraphModule() + self.joint_graph12 = GraphModule() + self.mask_graph12 = GraphModule() + self.fw_graph13 = GraphModule() + self.joint_graph13 = GraphModule() + self.mask_graph13 = GraphModule() + self.fw_graph14 = GraphModule() + self.joint_graph14 = GraphModule() + self.mask_graph14 = GraphModule() + self.fw_graph15 = GraphModule() + self.joint_graph15 = GraphModule() + self.mask_graph15 = GraphModule() + self.fw_graph16 = GraphModule() + self.joint_graph16 = GraphModule() + self.mask_graph16 = GraphModule() + self.fw_graph17 = GraphModule() + self.joint_graph17 = GraphModule() + self.mask_graph17 = GraphModule() + self.fw_graph18 = GraphModule() + self.joint_graph18 = GraphModule() + self.mask_graph18 = GraphModule() + self.fw_graph19 = GraphModule() + self.joint_graph19 = GraphModule() + self.mask_graph19 = GraphModule() + self.fw_graph20 = GraphModule() + self.joint_graph20 = GraphModule() + self.mask_graph20 = GraphModule() + self.fw_graph21 = GraphModule() + self.joint_graph21 = GraphModule() + self.mask_graph21 = GraphModule() + self.fw_graph22 = GraphModule() + self.joint_graph22 = GraphModule() + self.mask_graph22 = GraphModule() + self.fw_graph23 = GraphModule() + self.joint_graph23 = GraphModule() + self.mask_graph23 = GraphModule() + self.fw_graph24 = GraphModule() + self.joint_graph24 = GraphModule() + self.mask_graph24 = GraphModule() + self.fw_graph25 = GraphModule() + self.joint_graph25 = GraphModule() + self.mask_graph25 = GraphModule() + self.fw_graph26 = GraphModule() + self.joint_graph26 = GraphModule() + self.mask_graph26 = GraphModule() + + + + def forward(self, _local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7, _local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23, _local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31, _local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39, _local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47, _local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55, _local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63, _local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71, _local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79, _local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87, _local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95, _local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103, _local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111, _local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119, _local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127, _local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135, _local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143, _local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151, _local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159, _local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167, _local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175, _local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183, _local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191, _local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199, _local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207, _local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215, _local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223, _local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231, _local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239, _local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247, _local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255, _local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263, _local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271, _local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279, _local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287, _local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295, _local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303, _local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311, _local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319, _local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327, _local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335, _local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343, _local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351, _local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359, _local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367, _local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375, _local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383, _local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391, _local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399, _local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407, _local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415, sym_size_int_1, sym_size_int_5, sym_size_int_9, sym_size_int_13, sym_size_int_17, sym_size_int_21, sym_size_int_25, sym_size_int_29, sym_size_int_33, sym_size_int_37, sym_size_int_41, sym_size_int_45, sym_size_int_49, sym_size_int_53, sym_size_int_57, sym_size_int_61, sym_size_int_65, sym_size_int_69, sym_size_int_73, sym_size_int_77, sym_size_int_81, sym_size_int_85, sym_size_int_89, sym_size_int_93, sym_size_int_97, sym_size_int_101, add_1781, add_1796, add_1811, add_1826, add_1841, add_1856, add_1871, add_1886, add_1901, add_1916, add_1931, add_1946, add_1961, add_1976, add_1991, add_2006, add_2021, add_2036, add_2051, add_2066, add_2081, add_2096, add_2111, add_2126, add_2141, add_2156, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_31, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_47, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_63, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_79, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_95, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_111, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_127, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_143, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_159, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_175, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_191, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_207, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_223, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_239, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_255, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_271, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_287, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_298, primals_299, primals_300, primals_301, primals_303, primals_305, primals_306, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_316, primals_317, primals_319, primals_321, primals_322, primals_323, primals_324, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, primals_335, primals_337, primals_338, primals_339, primals_340, primals_341, primals_342, primals_343, primals_344, primals_345, primals_346, primals_347, primals_348, primals_349, primals_351, primals_353, primals_354, primals_355, primals_356, primals_357, primals_358, primals_359, primals_360, primals_361, primals_362, primals_363, primals_364, primals_365, primals_367, primals_369, primals_370, primals_371, primals_372, primals_373, primals_374, primals_375, primals_376, primals_377, primals_378, primals_379, primals_380, primals_381, primals_383, primals_385, primals_386, primals_387, primals_388, primals_389, primals_390, primals_391, primals_392, primals_393, primals_394, primals_395, primals_396, primals_397, primals_399, primals_401, primals_402, primals_403, primals_404, primals_405, primals_406, primals_407, primals_408, primals_409, primals_410, primals_411, primals_412, primals_413, primals_415, primals_417, primals_418, primals_419, primals_420, primals_421, primals_422, primals_423, primals_424, primals_425, primals_426, primals_427, primals_428, primals_429, primals_431, primals_433, primals_434, primals_435, primals_436, primals_437, primals_438, primals_439, primals_440, embedding, rsqrt, view_3, getitem_2, rsqrt_1, view_17, permute_3, permute_4, permute_5, getitem_6, getitem_7, mm_3, rsqrt_2, view_26, mm_4, mm_5, view_32, add_5, rsqrt_3, view_36, getitem_11, rsqrt_4, view_50, permute_14, permute_15, permute_16, getitem_15, getitem_16, add_8, rsqrt_5, view_58, mm_11, amax, sum_1, getitem_19, getitem_21, div_2, getitem_22, index_1, cumsum_2, _grouped_mm, _grouped_mm_1, mul_35, mm_12, mm_13, mul_55, add_73, rsqrt_6, view_103, getitem_25, rsqrt_7, view_117, permute_29, permute_30, permute_31, getitem_29, getitem_30, add_76, rsqrt_8, view_125, mm_19, amax_1, sum_5, getitem_33, getitem_35, div_7, getitem_36, index_3, cumsum_5, _grouped_mm_3, _grouped_mm_4, mul_84, mm_20, mm_21, mul_104, add_141, rsqrt_9, view_170, getitem_39, rsqrt_10, view_184, permute_44, permute_45, permute_46, getitem_43, getitem_44, add_144, rsqrt_11, view_192, mm_27, amax_2, sum_9, getitem_47, getitem_49, div_12, getitem_50, index_5, cumsum_8, _grouped_mm_6, _grouped_mm_7, mul_133, mm_28, mm_29, mul_153, add_209, rsqrt_12, view_237, getitem_53, rsqrt_13, view_251, permute_59, permute_60, permute_61, getitem_57, getitem_58, add_212, rsqrt_14, view_259, mm_35, amax_3, sum_13, getitem_61, getitem_63, div_17, getitem_64, index_7, cumsum_11, _grouped_mm_9, _grouped_mm_10, mul_182, mm_36, mm_37, mul_202, add_277, rsqrt_15, view_304, getitem_67, rsqrt_16, view_318, permute_74, permute_75, permute_76, getitem_71, getitem_72, add_280, rsqrt_17, view_326, mm_43, amax_4, sum_17, getitem_75, getitem_77, div_22, getitem_78, index_9, cumsum_14, _grouped_mm_12, _grouped_mm_13, mul_231, mm_44, mm_45, mul_251, add_345, rsqrt_18, view_371, getitem_81, rsqrt_19, view_385, permute_89, permute_90, permute_91, getitem_85, getitem_86, add_348, rsqrt_20, view_393, mm_51, amax_5, sum_21, getitem_89, getitem_91, div_27, getitem_92, index_11, cumsum_17, _grouped_mm_15, _grouped_mm_16, mul_280, mm_52, mm_53, mul_300, add_413, rsqrt_21, view_438, getitem_95, rsqrt_22, view_452, permute_104, permute_105, permute_106, getitem_99, getitem_100, add_416, rsqrt_23, view_460, mm_59, amax_6, sum_25, getitem_103, getitem_105, div_32, getitem_106, index_13, cumsum_20, _grouped_mm_18, _grouped_mm_19, mul_329, mm_60, mm_61, mul_349, add_481, rsqrt_24, view_505, getitem_109, rsqrt_25, view_519, permute_119, permute_120, permute_121, getitem_113, getitem_114, add_484, rsqrt_26, view_527, mm_67, amax_7, sum_29, getitem_117, getitem_119, div_37, getitem_120, index_15, cumsum_23, _grouped_mm_21, _grouped_mm_22, mul_378, mm_68, mm_69, mul_398, add_549, rsqrt_27, view_572, getitem_123, rsqrt_28, view_586, permute_134, permute_135, permute_136, getitem_127, getitem_128, add_552, rsqrt_29, view_594, mm_75, amax_8, sum_33, getitem_131, getitem_133, div_42, getitem_134, index_17, cumsum_26, _grouped_mm_24, _grouped_mm_25, mul_427, mm_76, mm_77, mul_447, add_617, rsqrt_30, view_639, getitem_137, rsqrt_31, view_653, permute_149, permute_150, permute_151, getitem_141, getitem_142, add_620, rsqrt_32, view_661, mm_83, amax_9, sum_37, getitem_145, getitem_147, div_47, getitem_148, index_19, cumsum_29, _grouped_mm_27, _grouped_mm_28, mul_476, mm_84, mm_85, mul_496, add_685, rsqrt_33, view_706, getitem_151, rsqrt_34, view_720, permute_164, permute_165, permute_166, getitem_155, getitem_156, add_688, rsqrt_35, view_728, mm_91, amax_10, sum_41, getitem_159, getitem_161, div_52, getitem_162, index_21, cumsum_32, _grouped_mm_30, _grouped_mm_31, mul_525, mm_92, mm_93, mul_545, add_753, rsqrt_36, view_773, getitem_165, rsqrt_37, view_787, permute_179, permute_180, permute_181, getitem_169, getitem_170, add_756, rsqrt_38, view_795, mm_99, amax_11, sum_45, getitem_173, getitem_175, div_57, getitem_176, index_23, cumsum_35, _grouped_mm_33, _grouped_mm_34, mul_574, mm_100, mm_101, mul_594, add_821, rsqrt_39, view_840, getitem_179, rsqrt_40, view_854, permute_194, permute_195, permute_196, getitem_183, getitem_184, add_824, rsqrt_41, view_862, mm_107, amax_12, sum_49, getitem_187, getitem_189, div_62, getitem_190, index_25, cumsum_38, _grouped_mm_36, _grouped_mm_37, mul_623, mm_108, mm_109, mul_643, add_889, rsqrt_42, view_907, getitem_193, rsqrt_43, view_921, permute_209, permute_210, permute_211, getitem_197, getitem_198, add_892, rsqrt_44, view_929, mm_115, amax_13, sum_53, getitem_201, getitem_203, div_67, getitem_204, index_27, cumsum_41, _grouped_mm_39, _grouped_mm_40, mul_672, mm_116, mm_117, mul_692, add_957, rsqrt_45, view_974, getitem_207, rsqrt_46, view_988, permute_224, permute_225, permute_226, getitem_211, getitem_212, add_960, rsqrt_47, view_996, mm_123, amax_14, sum_57, getitem_215, getitem_217, div_72, getitem_218, index_29, cumsum_44, _grouped_mm_42, _grouped_mm_43, mul_721, mm_124, mm_125, mul_741, add_1025, rsqrt_48, view_1041, getitem_221, rsqrt_49, view_1055, permute_239, permute_240, permute_241, getitem_225, getitem_226, add_1028, rsqrt_50, view_1063, mm_131, amax_15, sum_61, getitem_229, getitem_231, div_77, getitem_232, index_31, cumsum_47, _grouped_mm_45, _grouped_mm_46, mul_770, mm_132, mm_133, mul_790, add_1093, rsqrt_51, view_1108, getitem_235, rsqrt_52, view_1122, permute_254, permute_255, permute_256, getitem_239, getitem_240, add_1096, rsqrt_53, view_1130, mm_139, amax_16, sum_65, getitem_243, getitem_245, div_82, getitem_246, index_33, cumsum_50, _grouped_mm_48, _grouped_mm_49, mul_819, mm_140, mm_141, mul_839, add_1161, rsqrt_54, view_1175, getitem_249, rsqrt_55, view_1189, permute_269, permute_270, permute_271, getitem_253, getitem_254, add_1164, rsqrt_56, view_1197, mm_147, amax_17, sum_69, getitem_257, getitem_259, div_87, getitem_260, index_35, cumsum_53, _grouped_mm_51, _grouped_mm_52, mul_868, mm_148, mm_149, mul_888, add_1229, rsqrt_57, view_1242, getitem_263, rsqrt_58, view_1256, permute_284, permute_285, permute_286, getitem_267, getitem_268, add_1232, rsqrt_59, view_1264, mm_155, amax_18, sum_73, getitem_271, getitem_273, div_92, getitem_274, index_37, cumsum_56, _grouped_mm_54, _grouped_mm_55, mul_917, mm_156, mm_157, mul_937, add_1297, rsqrt_60, view_1309, getitem_277, rsqrt_61, view_1323, permute_299, permute_300, permute_301, getitem_281, getitem_282, add_1300, rsqrt_62, view_1331, mm_163, amax_19, sum_77, getitem_285, getitem_287, div_97, getitem_288, index_39, cumsum_59, _grouped_mm_57, _grouped_mm_58, mul_966, mm_164, mm_165, mul_986, add_1365, rsqrt_63, view_1376, getitem_291, rsqrt_64, view_1390, permute_314, permute_315, permute_316, getitem_295, getitem_296, add_1368, rsqrt_65, view_1398, mm_171, amax_20, sum_81, getitem_299, getitem_301, div_102, getitem_302, index_41, cumsum_62, _grouped_mm_60, _grouped_mm_61, mul_1015, mm_172, mm_173, mul_1035, add_1433, rsqrt_66, view_1443, getitem_305, rsqrt_67, view_1457, permute_329, permute_330, permute_331, getitem_309, getitem_310, add_1436, rsqrt_68, view_1465, mm_179, amax_21, sum_85, getitem_313, getitem_315, div_107, getitem_316, index_43, cumsum_65, _grouped_mm_63, _grouped_mm_64, mul_1064, mm_180, mm_181, mul_1084, add_1501, rsqrt_69, view_1510, getitem_319, rsqrt_70, view_1524, permute_344, permute_345, permute_346, getitem_323, getitem_324, add_1504, rsqrt_71, view_1532, mm_187, amax_22, sum_89, getitem_327, getitem_329, div_112, getitem_330, index_45, cumsum_68, _grouped_mm_66, _grouped_mm_67, mul_1113, mm_188, mm_189, mul_1133, add_1569, rsqrt_72, view_1577, getitem_333, rsqrt_73, view_1591, permute_359, permute_360, permute_361, getitem_337, getitem_338, add_1572, rsqrt_74, view_1599, mm_195, amax_23, sum_93, getitem_341, getitem_343, div_117, getitem_344, index_47, cumsum_71, _grouped_mm_69, _grouped_mm_70, mul_1162, mm_196, mm_197, mul_1182, add_1637, rsqrt_75, view_1644, getitem_347, rsqrt_76, view_1658, permute_374, permute_375, permute_376, getitem_351, getitem_352, add_1640, rsqrt_77, view_1666, mm_203, amax_24, sum_97, getitem_355, getitem_357, div_122, getitem_358, index_49, cumsum_74, _grouped_mm_72, _grouped_mm_73, mul_1211, mm_204, mm_205, mul_1231, add_1705, rsqrt_78, view_1711, getitem_361, rsqrt_79, view_1725, permute_389, permute_390, permute_391, getitem_365, getitem_366, add_1708, rsqrt_80, view_1733, mm_211, amax_25, sum_101, getitem_369, getitem_371, div_127, getitem_372, index_51, cumsum_77, _grouped_mm_75, _grouped_mm_76, mul_1260, mm_212, mm_213, mul_1280, add_1773, rsqrt_81, view_1778, permute_406, permute_407, permute_456, permute_457, permute_506, permute_507, permute_556, permute_557, permute_606, permute_607, permute_656, permute_657, permute_706, permute_707, permute_756, permute_757, permute_806, permute_807, permute_856, permute_857, permute_906, permute_907, permute_956, permute_957, permute_1006, permute_1007, permute_1056, permute_1057, permute_1106, permute_1107, permute_1156, permute_1157, permute_1206, permute_1207, permute_1256, permute_1257, permute_1306, permute_1307, permute_1356, permute_1357, permute_1406, permute_1407, permute_1456, permute_1457, permute_1506, permute_1507, permute_1556, permute_1557, permute_1606, permute_1607, permute_1656, permute_1657, tangents_1): + view_1780 = torch.ops.aten.view.default(tangents_1, [8192, 102400]); tangents_1 = None + permute_402 = torch.ops.aten.permute.default(view_1780, [1, 0]) + mm_216 = torch.ops.aten.mm.default(permute_402, view_1778); permute_402 = view_1778 = None + convert_element_type_1444 = torch.ops.prims.convert_element_type.default(primals_440, torch.bfloat16); primals_440 = None + all_gather_into_tensor_454 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1444, 64, '0'); convert_element_type_1444 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_454); all_gather_into_tensor_454 = None + permute_401 = torch.ops.aten.permute.default(wait_tensor_558, [1, 0]); wait_tensor_558 = None + permute_404 = torch.ops.aten.permute.default(permute_401, [1, 0]); permute_401 = None + mm_217 = torch.ops.aten.mm.default(view_1780, permute_404); view_1780 = permute_404 = None + view_1781 = torch.ops.aten.view.default(mm_217, [2, 4096, 2048]); mm_217 = None + convert_element_type_1451 = torch.ops.prims.convert_element_type.default(mm_216, torch.float32); mm_216 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1451, 'avg', 64, '0'); convert_element_type_1451 = None + wait_tensor_559 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1452 = torch.ops.prims.convert_element_type.default(view_1781, torch.float32); view_1781 = None + convert_element_type_1441 = torch.ops.prims.convert_element_type.default(primals_439, torch.bfloat16); primals_439 = None + all_gather_into_tensor_453 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1441, 64, '0'); convert_element_type_1441 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_453); all_gather_into_tensor_453 = None + convert_element_type_1454 = torch.ops.prims.convert_element_type.default(wait_tensor_557, torch.float32); wait_tensor_557 = None + mul_1285 = torch.ops.aten.mul.Tensor(convert_element_type_1452, convert_element_type_1454); convert_element_type_1454 = None + convert_element_type_1442 = torch.ops.prims.convert_element_type.default(add_1773, torch.float32); add_1773 = None + mul_1283 = torch.ops.aten.mul.Tensor(convert_element_type_1442, rsqrt_81); convert_element_type_1442 = None + mul_1287 = torch.ops.aten.mul.Tensor(mul_1283, mul_1285) + sum_105 = torch.ops.aten.sum.dim_IntList(mul_1287, [2], True); mul_1287 = None + div_131 = torch.ops.aten.div.Tensor(mul_1283, 2048) + mul_1288 = torch.ops.aten.mul.Tensor(div_131, sum_105); div_131 = sum_105 = None + sub_624 = torch.ops.aten.sub.Tensor(mul_1285, mul_1288); mul_1285 = mul_1288 = None + mul_1289 = torch.ops.aten.mul.Tensor(sub_624, rsqrt_81); sub_624 = rsqrt_81 = None + mul_1290 = torch.ops.aten.mul.Tensor(convert_element_type_1452, mul_1283); convert_element_type_1452 = mul_1283 = None + sum_106 = torch.ops.aten.sum.dim_IntList(mul_1290, [0, 1]); mul_1290 = None + convert_element_type_1455 = torch.ops.prims.convert_element_type.default(mul_1289, torch.bfloat16); mul_1289 = None + convert_element_type_default_82 = torch.ops.prims.convert_element_type.default(sum_106, torch.float32); sum_106 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_82, 'avg', 64, '0'); convert_element_type_default_82 = None + wait_tensor_560 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + view_1782 = torch.ops.aten.view.default(convert_element_type_1455, [8192, 2048]) + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_1782, 1) + convert_element_type_1458 = torch.ops.prims.convert_element_type.default(unsqueeze_53, torch.float32); unsqueeze_53 = None + bmm_26 = torch.ops.aten.bmm.default(permute_406, convert_element_type_1458); permute_406 = None + bmm_27 = torch.ops.aten.bmm.default(convert_element_type_1458, permute_407); convert_element_type_1458 = permute_407 = None + convert_element_type_1459 = torch.ops.prims.convert_element_type.default(bmm_26, torch.bfloat16); bmm_26 = None + view_1783 = torch.ops.aten.view.default(bmm_27, [8192, 6]); bmm_27 = None + view_1784 = torch.ops.aten.view.default(convert_element_type_1459, [49152, 2048]); convert_element_type_1459 = None + index_52 = torch.ops.aten.index.Tensor(view_1784, [getitem_371]); view_1784 = getitem_371 = None + permute_408 = torch.ops.aten.permute.default(view_1782, [1, 0]) + mm_218 = torch.ops.aten.mm.default(permute_408, mul_1280); permute_408 = mul_1280 = None + convert_element_type_1436 = torch.ops.prims.convert_element_type.default(primals_438, torch.bfloat16); primals_438 = None + all_gather_into_tensor_452 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1436, 64, '0'); convert_element_type_1436 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_452); all_gather_into_tensor_452 = None + permute_400 = torch.ops.aten.permute.default(wait_tensor_556, [1, 0]); wait_tensor_556 = None + permute_410 = torch.ops.aten.permute.default(permute_400, [1, 0]); permute_400 = None + mm_219 = torch.ops.aten.mm.default(view_1782, permute_410); view_1782 = permute_410 = None + convert_element_type_1464 = torch.ops.prims.convert_element_type.default(mm_218, torch.float32); mm_218 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1464, 'avg', 64, '0'); convert_element_type_1464 = None + wait_tensor_561 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None + convert_element_type_1431 = torch.ops.prims.convert_element_type.default(mm_212, torch.float32); mm_212 = None + neg_52 = torch.ops.aten.neg.default(convert_element_type_1431) + exp_78 = torch.ops.aten.exp.default(neg_52); neg_52 = None + add_1768 = torch.ops.aten.add.Tensor(exp_78, 1); exp_78 = None + div_130 = torch.ops.aten.div.Tensor(convert_element_type_1431, add_1768) + convert_element_type_1432 = torch.ops.prims.convert_element_type.default(div_130, torch.bfloat16); div_130 = None + mul_1291 = torch.ops.aten.mul.Tensor(mm_219, convert_element_type_1432); convert_element_type_1432 = None + mul_1292 = torch.ops.aten.mul.Tensor(mm_219, mm_213); mm_219 = mm_213 = None + permute_412 = torch.ops.aten.permute.default(mul_1291, [1, 0]) + mm_220 = torch.ops.aten.mm.default(permute_412, view_1733); permute_412 = None + convert_element_type_1433 = torch.ops.prims.convert_element_type.default(primals_437, torch.bfloat16); primals_437 = None + all_gather_into_tensor_451 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1433, 64, '0'); convert_element_type_1433 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_451); all_gather_into_tensor_451 = None + permute_399 = torch.ops.aten.permute.default(wait_tensor_555, [1, 0]); wait_tensor_555 = None + permute_414 = torch.ops.aten.permute.default(permute_399, [1, 0]); permute_399 = None + mm_221 = torch.ops.aten.mm.default(mul_1291, permute_414); mul_1291 = permute_414 = None + convert_element_type_1469 = torch.ops.prims.convert_element_type.default(mm_220, torch.float32); mm_220 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1469, 'avg', 64, '0'); convert_element_type_1469 = None + wait_tensor_562 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); reduce_scatter_tensor_3 = None + convert_element_type_1470 = torch.ops.prims.convert_element_type.default(mul_1292, torch.float32); mul_1292 = None + reciprocal = torch.ops.aten.reciprocal.default(add_1768); add_1768 = None + mul_1293 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_1294 = torch.ops.aten.mul.Tensor(convert_element_type_1470, mul_1293); convert_element_type_1470 = None + sub_625 = torch.ops.aten.sub.Tensor(1, mul_1293); mul_1293 = None + mul_1295 = torch.ops.aten.mul.Tensor(convert_element_type_1431, sub_625); convert_element_type_1431 = sub_625 = None + add_1776 = torch.ops.aten.add.Tensor(mul_1295, 1); mul_1295 = None + mul_1296 = torch.ops.aten.mul.Tensor(mul_1294, add_1776); mul_1294 = add_1776 = None + convert_element_type_1472 = torch.ops.prims.convert_element_type.default(mul_1296, torch.bfloat16); mul_1296 = None + permute_416 = torch.ops.aten.permute.default(convert_element_type_1472, [1, 0]) + mm_222 = torch.ops.aten.mm.default(permute_416, view_1733); permute_416 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(primals_436, torch.bfloat16); primals_436 = None + all_gather_into_tensor_450 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1428, 64, '0'); convert_element_type_1428 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_450); all_gather_into_tensor_450 = None + permute_398 = torch.ops.aten.permute.default(wait_tensor_554, [1, 0]); wait_tensor_554 = None + permute_418 = torch.ops.aten.permute.default(permute_398, [1, 0]); permute_398 = None + mm_223 = torch.ops.aten.mm.default(convert_element_type_1472, permute_418); convert_element_type_1472 = permute_418 = None + add_1777 = torch.ops.aten.add.Tensor(mm_221, mm_223); mm_221 = mm_223 = None + convert_element_type_1477 = torch.ops.prims.convert_element_type.default(mm_222, torch.float32); mm_222 = None + reduce_scatter_tensor_4 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1477, 'avg', 64, '0'); convert_element_type_1477 = None + wait_tensor_563 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + all_to_all_single_78 = torch.ops._c10d_functional.all_to_all_single.default(index_52, [_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415], [_local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407], '521'); index_52 = None + wait_tensor_564 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_78); all_to_all_single_78 = None + full_348 = torch.ops.aten.full.default([sym_size_int_101, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_101 = None + slice_scatter = torch.ops.aten.slice_scatter.default(full_348, wait_tensor_564, 0, 0, -1); wait_tensor_564 = None + index_53 = torch.ops.aten.index.Tensor(slice_scatter, [getitem_372]); slice_scatter = None + permute_420 = torch.ops.aten.permute.default(index_53, [1, 0]) + _grouped_mm_78 = torch.ops.aten._grouped_mm.default(permute_420, mul_1260, cumsum_77); permute_420 = mul_1260 = None + convert_element_type_1422 = torch.ops.prims.convert_element_type.default(primals_434, torch.bfloat16); primals_434 = None + all_gather_into_tensor_446 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1422, 8, '513'); convert_element_type_1422 = None + wait_tensor_549 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_446); all_gather_into_tensor_446 = None + permute_397 = torch.ops.aten.permute.default(wait_tensor_549, [0, 2, 1]); wait_tensor_549 = None + permute_422 = torch.ops.aten.permute.default(permute_397, [0, 2, 1]); permute_397 = None + _grouped_mm_79 = torch.ops.aten._grouped_mm.default(index_53, permute_422, cumsum_77); index_53 = permute_422 = None + convert_element_type_1426 = torch.ops.prims.convert_element_type.default(_grouped_mm_75, torch.float32); _grouped_mm_75 = None + neg_51 = torch.ops.aten.neg.default(convert_element_type_1426) + exp_77 = torch.ops.aten.exp.default(neg_51); neg_51 = None + add_1732 = torch.ops.aten.add.Tensor(exp_77, 1); exp_77 = None + div_129 = torch.ops.aten.div.Tensor(convert_element_type_1426, add_1732) + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(div_129, torch.bfloat16); div_129 = None + mul_1297 = torch.ops.aten.mul.Tensor(_grouped_mm_79, convert_element_type_1427); convert_element_type_1427 = None + mul_1298 = torch.ops.aten.mul.Tensor(_grouped_mm_79, _grouped_mm_76); _grouped_mm_79 = _grouped_mm_76 = None + permute_424 = torch.ops.aten.permute.default(mul_1297, [1, 0]) + _grouped_mm_80 = torch.ops.aten._grouped_mm.default(permute_424, index_51, cumsum_77); permute_424 = None + convert_element_type_1423 = torch.ops.prims.convert_element_type.default(primals_435, torch.bfloat16); primals_435 = None + all_gather_into_tensor_447 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1423, 8, '513'); convert_element_type_1423 = None + wait_tensor_550 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_447); all_gather_into_tensor_447 = None + permute_396 = torch.ops.aten.permute.default(wait_tensor_550, [0, 2, 1]); wait_tensor_550 = None + permute_426 = torch.ops.aten.permute.default(permute_396, [0, 2, 1]); permute_396 = None + _grouped_mm_81 = torch.ops.aten._grouped_mm.default(mul_1297, permute_426, cumsum_77); mul_1297 = permute_426 = None + convert_element_type_1478 = torch.ops.prims.convert_element_type.default(mul_1298, torch.float32); mul_1298 = None + reciprocal_1 = torch.ops.aten.reciprocal.default(add_1732); add_1732 = None + mul_1299 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None + mul_1300 = torch.ops.aten.mul.Tensor(convert_element_type_1478, mul_1299); convert_element_type_1478 = None + sub_626 = torch.ops.aten.sub.Tensor(1, mul_1299); mul_1299 = None + mul_1301 = torch.ops.aten.mul.Tensor(convert_element_type_1426, sub_626); convert_element_type_1426 = sub_626 = None + add_1779 = torch.ops.aten.add.Tensor(mul_1301, 1); mul_1301 = None + mul_1302 = torch.ops.aten.mul.Tensor(mul_1300, add_1779); mul_1300 = add_1779 = None + convert_element_type_1480 = torch.ops.prims.convert_element_type.default(mul_1302, torch.bfloat16); mul_1302 = None + permute_428 = torch.ops.aten.permute.default(convert_element_type_1480, [1, 0]) + _grouped_mm_82 = torch.ops.aten._grouped_mm.default(permute_428, index_51, cumsum_77); permute_428 = index_51 = None + convert_element_type_1420 = torch.ops.prims.convert_element_type.default(primals_433, torch.bfloat16); primals_433 = None + all_gather_into_tensor_444 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1420, 8, '513'); convert_element_type_1420 = None + wait_tensor_547 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_444); all_gather_into_tensor_444 = None + permute_395 = torch.ops.aten.permute.default(wait_tensor_547, [0, 2, 1]); wait_tensor_547 = None + permute_430 = torch.ops.aten.permute.default(permute_395, [0, 2, 1]); permute_395 = None + _grouped_mm_83 = torch.ops.aten._grouped_mm.default(convert_element_type_1480, permute_430, cumsum_77); convert_element_type_1480 = permute_430 = cumsum_77 = None + add_1780 = torch.ops.aten.add.Tensor(_grouped_mm_81, _grouped_mm_83); _grouped_mm_81 = _grouped_mm_83 = None + convert_element_type_1481 = torch.ops.prims.convert_element_type.default(_grouped_mm_80, torch.float32); _grouped_mm_80 = None + div_132 = torch.ops.aten.div.Tensor(convert_element_type_1481, 64); convert_element_type_1481 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_132, 'sum', 8, '513'); div_132 = None + wait_tensor_565 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + convert_element_type_1482 = torch.ops.prims.convert_element_type.default(_grouped_mm_78, torch.float32); _grouped_mm_78 = None + div_133 = torch.ops.aten.div.Tensor(convert_element_type_1482, 64); convert_element_type_1482 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_133, 'sum', 8, '513'); div_133 = None + wait_tensor_566 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + convert_element_type_1483 = torch.ops.prims.convert_element_type.default(_grouped_mm_82, torch.float32); _grouped_mm_82 = None + div_134 = torch.ops.aten.div.Tensor(convert_element_type_1483, 64); convert_element_type_1483 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_134, 'sum', 8, '513'); div_134 = None + wait_tensor_567 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + index_put_52 = torch.ops.aten.index_put.default(full_348, [getitem_372], add_1780, True); full_348 = getitem_372 = add_1780 = None + slice_107 = torch.ops.aten.slice.Tensor(index_put_52, 0, 0, add_1781); index_put_52 = add_1781 = None + all_to_all_single_79 = torch.ops._c10d_functional.all_to_all_single.default(slice_107, [_local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407], [_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415], '521'); slice_107 = _local_scalar_dense_400 = _local_scalar_dense_401 = _local_scalar_dense_402 = _local_scalar_dense_403 = _local_scalar_dense_404 = _local_scalar_dense_405 = _local_scalar_dense_406 = _local_scalar_dense_407 = _local_scalar_dense_408 = _local_scalar_dense_409 = _local_scalar_dense_410 = _local_scalar_dense_411 = _local_scalar_dense_412 = _local_scalar_dense_413 = _local_scalar_dense_414 = _local_scalar_dense_415 = None + wait_tensor_568 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_79); all_to_all_single_79 = None + full_default_52 = torch.ops.aten.full.default([8192, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_53 = torch.ops.aten.index_put.default(full_default_52, [div_127], wait_tensor_568, True); div_127 = wait_tensor_568 = None + add_1785 = torch.ops.aten.add.Tensor(add_1777, index_put_53); add_1777 = index_put_53 = None + mul_1303 = torch.ops.aten.mul.Tensor(view_1783, 1.0); view_1783 = None + full_default_53 = torch.ops.aten.full.default([8192, 64], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + scatter_add = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_369, mul_1303); getitem_369 = mul_1303 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_211, torch.float32); mm_211 = None + sub_600 = torch.ops.aten.sub.Tensor(convert_element_type_1415, amax_25); convert_element_type_1415 = amax_25 = None + exp_76 = torch.ops.aten.exp.default(sub_600); sub_600 = None + div_126 = torch.ops.aten.div.Tensor(exp_76, sum_101); exp_76 = sum_101 = None + mul_1304 = torch.ops.aten.mul.Tensor(scatter_add, div_126); scatter_add = None + sum_107 = torch.ops.aten.sum.dim_IntList(mul_1304, [1], True) + neg_55 = torch.ops.aten.neg.default(div_126); div_126 = None + fma = torch.ops.prims.fma.default(neg_55, sum_107, mul_1304); neg_55 = sum_107 = mul_1304 = None + convert_element_type_1484 = torch.ops.prims.convert_element_type.default(fma, torch.bfloat16); fma = None + permute_432 = torch.ops.aten.permute.default(convert_element_type_1484, [1, 0]) + mm_224 = torch.ops.aten.mm.default(permute_432, view_1733); permute_432 = view_1733 = None + convert_element_type_1412 = torch.ops.prims.convert_element_type.default(primals_431, torch.bfloat16); primals_431 = None + all_gather_into_tensor_443 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1412, 64, '0'); convert_element_type_1412 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_443); all_gather_into_tensor_443 = None + permute_394 = torch.ops.aten.permute.default(wait_tensor_543, [1, 0]); wait_tensor_543 = None + permute_434 = torch.ops.aten.permute.default(permute_394, [1, 0]); permute_394 = None + mm_225 = torch.ops.aten.mm.default(convert_element_type_1484, permute_434); convert_element_type_1484 = permute_434 = None + add_1786 = torch.ops.aten.add.Tensor(add_1785, mm_225); add_1785 = mm_225 = None + convert_element_type_1489 = torch.ops.prims.convert_element_type.default(mm_224, torch.float32); mm_224 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1489, 'avg', 64, '0'); convert_element_type_1489 = None + wait_tensor_569 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + view_1785 = torch.ops.aten.view.default(add_1786, [2, 4096, 2048]); add_1786 = None + convert_element_type_1490 = torch.ops.prims.convert_element_type.default(view_1785, torch.float32); view_1785 = None + convert_element_type_1409 = torch.ops.prims.convert_element_type.default(primals_429, torch.bfloat16); primals_429 = None + all_gather_into_tensor_442 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1409, 64, '0'); convert_element_type_1409 = None + wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_442); all_gather_into_tensor_442 = None + convert_element_type_1492 = torch.ops.prims.convert_element_type.default(wait_tensor_542, torch.float32); wait_tensor_542 = None + mul_1305 = torch.ops.aten.mul.Tensor(convert_element_type_1490, convert_element_type_1492); convert_element_type_1492 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(add_1708, torch.float32); add_1708 = None + mul_1240 = torch.ops.aten.mul.Tensor(convert_element_type_1410, rsqrt_80); convert_element_type_1410 = None + mul_1307 = torch.ops.aten.mul.Tensor(mul_1240, mul_1305) + sum_108 = torch.ops.aten.sum.dim_IntList(mul_1307, [2], True); mul_1307 = None + div_135 = torch.ops.aten.div.Tensor(mul_1240, 2048) + mul_1308 = torch.ops.aten.mul.Tensor(div_135, sum_108); div_135 = sum_108 = None + sub_628 = torch.ops.aten.sub.Tensor(mul_1305, mul_1308); mul_1305 = mul_1308 = None + mul_1309 = torch.ops.aten.mul.Tensor(sub_628, rsqrt_80); sub_628 = rsqrt_80 = None + mul_1310 = torch.ops.aten.mul.Tensor(convert_element_type_1490, mul_1240); convert_element_type_1490 = mul_1240 = None + sum_109 = torch.ops.aten.sum.dim_IntList(mul_1310, [0, 1]); mul_1310 = None + convert_element_type_1493 = torch.ops.prims.convert_element_type.default(mul_1309, torch.bfloat16); mul_1309 = None + add_1787 = torch.ops.aten.add.Tensor(convert_element_type_1455, convert_element_type_1493); convert_element_type_1455 = convert_element_type_1493 = None + convert_element_type_default_81 = torch.ops.prims.convert_element_type.default(sum_109, torch.float32); sum_109 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_81, 'avg', 64, '0'); convert_element_type_default_81 = None + wait_tensor_570 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + view_1786 = torch.ops.aten.view.default(add_1787, [8192, 2048]) + permute_436 = torch.ops.aten.permute.default(view_1786, [1, 0]) + permute_392 = torch.ops.aten.permute.default(getitem_365, [0, 2, 1, 3]) + view_1728 = torch.ops.aten.view.default(permute_392, [2, 4096, -1]); permute_392 = None + view_1730 = torch.ops.aten.view.default(view_1728, [8192, 2048]); view_1728 = None + mm_226 = torch.ops.aten.mm.default(permute_436, view_1730); permute_436 = view_1730 = None + convert_element_type_1406 = torch.ops.prims.convert_element_type.default(primals_428, torch.bfloat16); primals_428 = None + all_gather_into_tensor_441 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1406, 64, '0'); convert_element_type_1406 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_441); all_gather_into_tensor_441 = None + permute_393 = torch.ops.aten.permute.default(wait_tensor_541, [1, 0]); wait_tensor_541 = None + permute_438 = torch.ops.aten.permute.default(permute_393, [1, 0]); permute_393 = None + mm_227 = torch.ops.aten.mm.default(view_1786, permute_438); view_1786 = permute_438 = None + view_1787 = torch.ops.aten.view.default(mm_227, [2, 4096, 2048]); mm_227 = None + convert_element_type_1500 = torch.ops.prims.convert_element_type.default(mm_226, torch.float32); mm_226 = None + reduce_scatter_tensor_10 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1500, 'avg', 64, '0'); convert_element_type_1500 = None + wait_tensor_571 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + view_1788 = torch.ops.aten.view.default(view_1787, [2, 4096, 16, 128]); view_1787 = None + permute_440 = torch.ops.aten.permute.default(view_1788, [0, 2, 1, 3]); view_1788 = None + fw_graph0 = self.fw_graph0 + joint_graph0 = self.joint_graph0 + mask_graph0 = self.mask_graph0 + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(permute_389, permute_390, permute_391, getitem_365, getitem_366, permute_440, None, fw_graph0, joint_graph0, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph0), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_389 = permute_390 = permute_391 = getitem_365 = getitem_366 = permute_440 = fw_graph0 = joint_graph0 = mask_graph0 = None + getitem_373 = flex_attention_backward[0] + getitem_374 = flex_attention_backward[1] + getitem_375 = flex_attention_backward[2]; flex_attention_backward = None + permute_441 = torch.ops.aten.permute.default(getitem_375, [0, 2, 1, 3]); getitem_375 = None + permute_442 = torch.ops.aten.permute.default(getitem_374, [0, 2, 1, 3]); getitem_374 = None + permute_443 = torch.ops.aten.permute.default(getitem_373, [0, 2, 1, 3]); getitem_373 = None + slice_109 = torch.ops.aten.slice.Tensor(permute_442, 3, 0, 128) + slice_110 = torch.ops.aten.slice.Tensor(permute_442, 3, 128, 192); permute_442 = None + sum_110 = torch.ops.aten.sum.dim_IntList(slice_110, [2], True); slice_110 = None + cat_80 = torch.ops.aten.cat.default([slice_109, permute_441], 3); slice_109 = permute_441 = None + view_1789 = torch.ops.aten.view.default(cat_80, [2, 4096, 4096]); cat_80 = None + view_1790 = torch.ops.aten.view.default(view_1789, [8192, 4096]); view_1789 = None + permute_444 = torch.ops.aten.permute.default(view_1790, [1, 0]) + mm_228 = torch.ops.aten.mm.default(permute_444, view_1725); permute_444 = view_1725 = None + convert_element_type_1403 = torch.ops.prims.convert_element_type.default(primals_427, torch.bfloat16); primals_427 = None + all_gather_into_tensor_440 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1403, 64, '0'); convert_element_type_1403 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_440); all_gather_into_tensor_440 = None + permute_388 = torch.ops.aten.permute.default(wait_tensor_540, [1, 0]); wait_tensor_540 = None + permute_446 = torch.ops.aten.permute.default(permute_388, [1, 0]); permute_388 = None + mm_229 = torch.ops.aten.mm.default(view_1790, permute_446); view_1790 = permute_446 = None + view_1791 = torch.ops.aten.view.default(mm_229, [2, 4096, 512]); mm_229 = None + convert_element_type_1505 = torch.ops.prims.convert_element_type.default(mm_228, torch.float32); mm_228 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1505, 'avg', 64, '0'); convert_element_type_1505 = None + wait_tensor_572 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + convert_element_type_1506 = torch.ops.prims.convert_element_type.default(view_1791, torch.float32); view_1791 = None + convert_element_type_1400 = torch.ops.prims.convert_element_type.default(primals_426, torch.bfloat16); primals_426 = None + all_gather_into_tensor_439 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1400, 64, '0'); convert_element_type_1400 = None + wait_tensor_539 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_439); all_gather_into_tensor_439 = None + convert_element_type_1508 = torch.ops.prims.convert_element_type.default(wait_tensor_539, torch.float32); wait_tensor_539 = None + mul_1311 = torch.ops.aten.mul.Tensor(convert_element_type_1506, convert_element_type_1508); convert_element_type_1508 = None + convert_element_type_1401 = torch.ops.prims.convert_element_type.default(getitem_361, torch.float32); getitem_361 = None + mul_1238 = torch.ops.aten.mul.Tensor(convert_element_type_1401, rsqrt_79); convert_element_type_1401 = None + mul_1313 = torch.ops.aten.mul.Tensor(mul_1238, mul_1311) + sum_111 = torch.ops.aten.sum.dim_IntList(mul_1313, [2], True); mul_1313 = None + div_136 = torch.ops.aten.div.Tensor(mul_1238, 512) + mul_1314 = torch.ops.aten.mul.Tensor(div_136, sum_111); div_136 = sum_111 = None + sub_629 = torch.ops.aten.sub.Tensor(mul_1311, mul_1314); mul_1311 = mul_1314 = None + mul_1315 = torch.ops.aten.mul.Tensor(sub_629, rsqrt_79); sub_629 = rsqrt_79 = None + mul_1316 = torch.ops.aten.mul.Tensor(convert_element_type_1506, mul_1238); convert_element_type_1506 = mul_1238 = None + sum_112 = torch.ops.aten.sum.dim_IntList(mul_1316, [0, 1]); mul_1316 = None + convert_element_type_1509 = torch.ops.prims.convert_element_type.default(mul_1315, torch.bfloat16); mul_1315 = None + convert_element_type_default_80 = torch.ops.prims.convert_element_type.default(sum_112, torch.float32); sum_112 = None + reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_80, 'avg', 64, '0'); convert_element_type_default_80 = None + wait_tensor_573 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + convert_element_type_1512 = torch.ops.prims.convert_element_type.default(sum_110, torch.float32); sum_110 = None + view_1792 = torch.ops.aten.view.default(convert_element_type_1512, [2, 4096, 1, 32, 2]); convert_element_type_1512 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1792); view_1792 = None + view_7 = torch.ops.aten.view.default(primals_3, [1, 4096, 1, 32]); primals_3 = None + _conj = torch.ops.aten._conj.default(view_7); view_7 = None + clone_9 = torch.ops.aten.clone.default(_conj); _conj = None + mul_1317 = torch.ops.aten.mul.Tensor(view_as_complex_54, clone_9); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_1317); mul_1317 = None + view_1793 = torch.ops.aten.view.default(view_as_real_54, [2, 4096, 1, 64]); view_as_real_54 = None + convert_element_type_1513 = torch.ops.prims.convert_element_type.default(view_1793, torch.bfloat16); view_1793 = None + squeeze_26 = torch.ops.aten.squeeze.dim(convert_element_type_1513, 2); convert_element_type_1513 = None + cat_81 = torch.ops.aten.cat.default([convert_element_type_1509, squeeze_26], 2); convert_element_type_1509 = squeeze_26 = None + view_1794 = torch.ops.aten.view.default(cat_81, [8192, 576]); cat_81 = None + permute_448 = torch.ops.aten.permute.default(view_1794, [1, 0]) + mm_230 = torch.ops.aten.mm.default(permute_448, view_1711); permute_448 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(primals_425, torch.bfloat16); primals_425 = None + all_gather_into_tensor_438 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1395, 64, '0'); convert_element_type_1395 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_438); all_gather_into_tensor_438 = None + permute_387 = torch.ops.aten.permute.default(wait_tensor_538, [1, 0]); wait_tensor_538 = None + permute_450 = torch.ops.aten.permute.default(permute_387, [1, 0]); permute_387 = None + mm_231 = torch.ops.aten.mm.default(view_1794, permute_450); view_1794 = permute_450 = None + view_1795 = torch.ops.aten.view.default(mm_231, [2, 4096, 2048]); mm_231 = None + convert_element_type_1518 = torch.ops.prims.convert_element_type.default(mm_230, torch.float32); mm_230 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1518, 'avg', 64, '0'); convert_element_type_1518 = None + wait_tensor_574 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + slice_111 = torch.ops.aten.slice.Tensor(permute_443, 3, 0, 128) + slice_112 = torch.ops.aten.slice.Tensor(permute_443, 3, 128, 192); permute_443 = None + convert_element_type_1519 = torch.ops.prims.convert_element_type.default(slice_112, torch.float32); slice_112 = None + view_1796 = torch.ops.aten.view.default(convert_element_type_1519, [2, 4096, 16, 32, 2]); convert_element_type_1519 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1796); view_1796 = None + mul_1318 = torch.ops.aten.mul.Tensor(view_as_complex_55, clone_9); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_1318); mul_1318 = None + view_1797 = torch.ops.aten.view.default(view_as_real_55, [2, 4096, 16, 64]); view_as_real_55 = None + convert_element_type_1520 = torch.ops.prims.convert_element_type.default(view_1797, torch.bfloat16); view_1797 = None + cat_82 = torch.ops.aten.cat.default([slice_111, convert_element_type_1520], 3); slice_111 = convert_element_type_1520 = None + view_1798 = torch.ops.aten.view.default(cat_82, [2, 4096, 3072]); cat_82 = None + view_1799 = torch.ops.aten.view.default(view_1798, [8192, 3072]); view_1798 = None + permute_452 = torch.ops.aten.permute.default(view_1799, [1, 0]) + mm_232 = torch.ops.aten.mm.default(permute_452, view_1711); permute_452 = view_1711 = None + convert_element_type_1390 = torch.ops.prims.convert_element_type.default(primals_424, torch.bfloat16); primals_424 = None + all_gather_into_tensor_437 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1390, 64, '0'); convert_element_type_1390 = None + wait_tensor_537 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_437); all_gather_into_tensor_437 = None + permute_386 = torch.ops.aten.permute.default(wait_tensor_537, [1, 0]); wait_tensor_537 = None + permute_454 = torch.ops.aten.permute.default(permute_386, [1, 0]); permute_386 = None + mm_233 = torch.ops.aten.mm.default(view_1799, permute_454); view_1799 = permute_454 = None + view_1800 = torch.ops.aten.view.default(mm_233, [2, 4096, 2048]); mm_233 = None + add_1788 = torch.ops.aten.add.Tensor(view_1795, view_1800); view_1795 = view_1800 = None + convert_element_type_1525 = torch.ops.prims.convert_element_type.default(mm_232, torch.float32); mm_232 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1525, 'avg', 64, '0'); convert_element_type_1525 = None + wait_tensor_575 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + convert_element_type_1526 = torch.ops.prims.convert_element_type.default(add_1788, torch.float32); add_1788 = None + convert_element_type_1387 = torch.ops.prims.convert_element_type.default(primals_423, torch.bfloat16); primals_423 = None + all_gather_into_tensor_436 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1387, 64, '0'); convert_element_type_1387 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_436); all_gather_into_tensor_436 = None + convert_element_type_1528 = torch.ops.prims.convert_element_type.default(wait_tensor_536, torch.float32); wait_tensor_536 = None + mul_1319 = torch.ops.aten.mul.Tensor(convert_element_type_1526, convert_element_type_1528); convert_element_type_1528 = None + convert_element_type_1388 = torch.ops.prims.convert_element_type.default(add_1705, torch.float32); add_1705 = None + mul_1234 = torch.ops.aten.mul.Tensor(convert_element_type_1388, rsqrt_78); convert_element_type_1388 = None + mul_1321 = torch.ops.aten.mul.Tensor(mul_1234, mul_1319) + sum_113 = torch.ops.aten.sum.dim_IntList(mul_1321, [2], True); mul_1321 = None + div_137 = torch.ops.aten.div.Tensor(mul_1234, 2048) + mul_1322 = torch.ops.aten.mul.Tensor(div_137, sum_113); div_137 = sum_113 = None + sub_630 = torch.ops.aten.sub.Tensor(mul_1319, mul_1322); mul_1319 = mul_1322 = None + mul_1323 = torch.ops.aten.mul.Tensor(sub_630, rsqrt_78); sub_630 = rsqrt_78 = None + mul_1324 = torch.ops.aten.mul.Tensor(convert_element_type_1526, mul_1234); convert_element_type_1526 = mul_1234 = None + sum_114 = torch.ops.aten.sum.dim_IntList(mul_1324, [0, 1]); mul_1324 = None + convert_element_type_1529 = torch.ops.prims.convert_element_type.default(mul_1323, torch.bfloat16); mul_1323 = None + add_1789 = torch.ops.aten.add.Tensor(add_1787, convert_element_type_1529); add_1787 = convert_element_type_1529 = None + convert_element_type_default_79 = torch.ops.prims.convert_element_type.default(sum_114, torch.float32); sum_114 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_79, 'avg', 64, '0'); convert_element_type_default_79 = None + wait_tensor_576 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + view_1801 = torch.ops.aten.view.default(add_1789, [8192, 2048]) + unsqueeze_54 = torch.ops.aten.unsqueeze.default(view_1801, 1) + convert_element_type_1532 = torch.ops.prims.convert_element_type.default(unsqueeze_54, torch.float32); unsqueeze_54 = None + bmm_28 = torch.ops.aten.bmm.default(permute_456, convert_element_type_1532); permute_456 = None + bmm_29 = torch.ops.aten.bmm.default(convert_element_type_1532, permute_457); convert_element_type_1532 = permute_457 = None + convert_element_type_1533 = torch.ops.prims.convert_element_type.default(bmm_28, torch.bfloat16); bmm_28 = None + view_1802 = torch.ops.aten.view.default(bmm_29, [8192, 6]); bmm_29 = None + view_1803 = torch.ops.aten.view.default(convert_element_type_1533, [49152, 2048]); convert_element_type_1533 = None + index_54 = torch.ops.aten.index.Tensor(view_1803, [getitem_357]); view_1803 = getitem_357 = None + permute_458 = torch.ops.aten.permute.default(view_1801, [1, 0]) + mm_234 = torch.ops.aten.mm.default(permute_458, mul_1231); permute_458 = mul_1231 = None + convert_element_type_1382 = torch.ops.prims.convert_element_type.default(primals_422, torch.bfloat16); primals_422 = None + all_gather_into_tensor_435 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1382, 64, '0'); convert_element_type_1382 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_435); all_gather_into_tensor_435 = None + permute_385 = torch.ops.aten.permute.default(wait_tensor_535, [1, 0]); wait_tensor_535 = None + permute_460 = torch.ops.aten.permute.default(permute_385, [1, 0]); permute_385 = None + mm_235 = torch.ops.aten.mm.default(view_1801, permute_460); view_1801 = permute_460 = None + convert_element_type_1538 = torch.ops.prims.convert_element_type.default(mm_234, torch.float32); mm_234 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1538, 'avg', 64, '0'); convert_element_type_1538 = None + wait_tensor_577 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + convert_element_type_1377 = torch.ops.prims.convert_element_type.default(mm_204, torch.float32); mm_204 = None + neg_50 = torch.ops.aten.neg.default(convert_element_type_1377) + exp_75 = torch.ops.aten.exp.default(neg_50); neg_50 = None + add_1700 = torch.ops.aten.add.Tensor(exp_75, 1); exp_75 = None + div_125 = torch.ops.aten.div.Tensor(convert_element_type_1377, add_1700) + convert_element_type_1378 = torch.ops.prims.convert_element_type.default(div_125, torch.bfloat16); div_125 = None + mul_1325 = torch.ops.aten.mul.Tensor(mm_235, convert_element_type_1378); convert_element_type_1378 = None + mul_1326 = torch.ops.aten.mul.Tensor(mm_235, mm_205); mm_235 = mm_205 = None + permute_462 = torch.ops.aten.permute.default(mul_1325, [1, 0]) + mm_236 = torch.ops.aten.mm.default(permute_462, view_1666); permute_462 = None + convert_element_type_1379 = torch.ops.prims.convert_element_type.default(primals_421, torch.bfloat16); primals_421 = None + all_gather_into_tensor_434 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1379, 64, '0'); convert_element_type_1379 = None + wait_tensor_534 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_434); all_gather_into_tensor_434 = None + permute_384 = torch.ops.aten.permute.default(wait_tensor_534, [1, 0]); wait_tensor_534 = None + permute_464 = torch.ops.aten.permute.default(permute_384, [1, 0]); permute_384 = None + mm_237 = torch.ops.aten.mm.default(mul_1325, permute_464); mul_1325 = permute_464 = None + convert_element_type_1543 = torch.ops.prims.convert_element_type.default(mm_236, torch.float32); mm_236 = None + reduce_scatter_tensor_17 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1543, 'avg', 64, '0'); convert_element_type_1543 = None + wait_tensor_578 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + convert_element_type_1544 = torch.ops.prims.convert_element_type.default(mul_1326, torch.float32); mul_1326 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_1700); add_1700 = None + mul_1327 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_1328 = torch.ops.aten.mul.Tensor(convert_element_type_1544, mul_1327); convert_element_type_1544 = None + sub_631 = torch.ops.aten.sub.Tensor(1, mul_1327); mul_1327 = None + mul_1329 = torch.ops.aten.mul.Tensor(convert_element_type_1377, sub_631); convert_element_type_1377 = sub_631 = None + add_1791 = torch.ops.aten.add.Tensor(mul_1329, 1); mul_1329 = None + mul_1330 = torch.ops.aten.mul.Tensor(mul_1328, add_1791); mul_1328 = add_1791 = None + convert_element_type_1546 = torch.ops.prims.convert_element_type.default(mul_1330, torch.bfloat16); mul_1330 = None + permute_466 = torch.ops.aten.permute.default(convert_element_type_1546, [1, 0]) + mm_238 = torch.ops.aten.mm.default(permute_466, view_1666); permute_466 = None + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(primals_420, torch.bfloat16); primals_420 = None + all_gather_into_tensor_433 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1374, 64, '0'); convert_element_type_1374 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_433); all_gather_into_tensor_433 = None + permute_383 = torch.ops.aten.permute.default(wait_tensor_533, [1, 0]); wait_tensor_533 = None + permute_468 = torch.ops.aten.permute.default(permute_383, [1, 0]); permute_383 = None + mm_239 = torch.ops.aten.mm.default(convert_element_type_1546, permute_468); convert_element_type_1546 = permute_468 = None + add_1792 = torch.ops.aten.add.Tensor(mm_237, mm_239); mm_237 = mm_239 = None + convert_element_type_1551 = torch.ops.prims.convert_element_type.default(mm_238, torch.float32); mm_238 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1551, 'avg', 64, '0'); convert_element_type_1551 = None + wait_tensor_579 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + all_to_all_single_80 = torch.ops._c10d_functional.all_to_all_single.default(index_54, [_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399], [_local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391], '521'); index_54 = None + wait_tensor_580 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_80); all_to_all_single_80 = None + full_352 = torch.ops.aten.full.default([sym_size_int_97, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_97 = None + slice_scatter_1 = torch.ops.aten.slice_scatter.default(full_352, wait_tensor_580, 0, 0, -1); wait_tensor_580 = None + index_55 = torch.ops.aten.index.Tensor(slice_scatter_1, [getitem_358]); slice_scatter_1 = None + permute_470 = torch.ops.aten.permute.default(index_55, [1, 0]) + _grouped_mm_84 = torch.ops.aten._grouped_mm.default(permute_470, mul_1211, cumsum_74); permute_470 = mul_1211 = None + convert_element_type_1368 = torch.ops.prims.convert_element_type.default(primals_418, torch.bfloat16); primals_418 = None + all_gather_into_tensor_429 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1368, 8, '513'); convert_element_type_1368 = None + wait_tensor_528 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_429); all_gather_into_tensor_429 = None + permute_382 = torch.ops.aten.permute.default(wait_tensor_528, [0, 2, 1]); wait_tensor_528 = None + permute_472 = torch.ops.aten.permute.default(permute_382, [0, 2, 1]); permute_382 = None + _grouped_mm_85 = torch.ops.aten._grouped_mm.default(index_55, permute_472, cumsum_74); index_55 = permute_472 = None + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(_grouped_mm_72, torch.float32); _grouped_mm_72 = None + neg_49 = torch.ops.aten.neg.default(convert_element_type_1372) + exp_74 = torch.ops.aten.exp.default(neg_49); neg_49 = None + add_1664 = torch.ops.aten.add.Tensor(exp_74, 1); exp_74 = None + div_124 = torch.ops.aten.div.Tensor(convert_element_type_1372, add_1664) + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(div_124, torch.bfloat16); div_124 = None + mul_1331 = torch.ops.aten.mul.Tensor(_grouped_mm_85, convert_element_type_1373); convert_element_type_1373 = None + mul_1332 = torch.ops.aten.mul.Tensor(_grouped_mm_85, _grouped_mm_73); _grouped_mm_85 = _grouped_mm_73 = None + permute_474 = torch.ops.aten.permute.default(mul_1331, [1, 0]) + _grouped_mm_86 = torch.ops.aten._grouped_mm.default(permute_474, index_49, cumsum_74); permute_474 = None + convert_element_type_1369 = torch.ops.prims.convert_element_type.default(primals_419, torch.bfloat16); primals_419 = None + all_gather_into_tensor_430 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1369, 8, '513'); convert_element_type_1369 = None + wait_tensor_529 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_430); all_gather_into_tensor_430 = None + permute_381 = torch.ops.aten.permute.default(wait_tensor_529, [0, 2, 1]); wait_tensor_529 = None + permute_476 = torch.ops.aten.permute.default(permute_381, [0, 2, 1]); permute_381 = None + _grouped_mm_87 = torch.ops.aten._grouped_mm.default(mul_1331, permute_476, cumsum_74); mul_1331 = permute_476 = None + convert_element_type_1552 = torch.ops.prims.convert_element_type.default(mul_1332, torch.float32); mul_1332 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_1664); add_1664 = None + mul_1333 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_1334 = torch.ops.aten.mul.Tensor(convert_element_type_1552, mul_1333); convert_element_type_1552 = None + sub_632 = torch.ops.aten.sub.Tensor(1, mul_1333); mul_1333 = None + mul_1335 = torch.ops.aten.mul.Tensor(convert_element_type_1372, sub_632); convert_element_type_1372 = sub_632 = None + add_1794 = torch.ops.aten.add.Tensor(mul_1335, 1); mul_1335 = None + mul_1336 = torch.ops.aten.mul.Tensor(mul_1334, add_1794); mul_1334 = add_1794 = None + convert_element_type_1554 = torch.ops.prims.convert_element_type.default(mul_1336, torch.bfloat16); mul_1336 = None + permute_478 = torch.ops.aten.permute.default(convert_element_type_1554, [1, 0]) + _grouped_mm_88 = torch.ops.aten._grouped_mm.default(permute_478, index_49, cumsum_74); permute_478 = index_49 = None + convert_element_type_1366 = torch.ops.prims.convert_element_type.default(primals_417, torch.bfloat16); primals_417 = None + all_gather_into_tensor_427 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1366, 8, '513'); convert_element_type_1366 = None + wait_tensor_526 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_427); all_gather_into_tensor_427 = None + permute_380 = torch.ops.aten.permute.default(wait_tensor_526, [0, 2, 1]); wait_tensor_526 = None + permute_480 = torch.ops.aten.permute.default(permute_380, [0, 2, 1]); permute_380 = None + _grouped_mm_89 = torch.ops.aten._grouped_mm.default(convert_element_type_1554, permute_480, cumsum_74); convert_element_type_1554 = permute_480 = cumsum_74 = None + add_1795 = torch.ops.aten.add.Tensor(_grouped_mm_87, _grouped_mm_89); _grouped_mm_87 = _grouped_mm_89 = None + convert_element_type_1555 = torch.ops.prims.convert_element_type.default(_grouped_mm_86, torch.float32); _grouped_mm_86 = None + div_138 = torch.ops.aten.div.Tensor(convert_element_type_1555, 64); convert_element_type_1555 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_138, 'sum', 8, '513'); div_138 = None + wait_tensor_581 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + convert_element_type_1556 = torch.ops.prims.convert_element_type.default(_grouped_mm_84, torch.float32); _grouped_mm_84 = None + div_139 = torch.ops.aten.div.Tensor(convert_element_type_1556, 64); convert_element_type_1556 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_139, 'sum', 8, '513'); div_139 = None + wait_tensor_582 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + convert_element_type_1557 = torch.ops.prims.convert_element_type.default(_grouped_mm_88, torch.float32); _grouped_mm_88 = None + div_140 = torch.ops.aten.div.Tensor(convert_element_type_1557, 64); convert_element_type_1557 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_140, 'sum', 8, '513'); div_140 = None + wait_tensor_583 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + index_put_54 = torch.ops.aten.index_put.default(full_352, [getitem_358], add_1795, True); full_352 = getitem_358 = add_1795 = None + slice_113 = torch.ops.aten.slice.Tensor(index_put_54, 0, 0, add_1796); index_put_54 = add_1796 = None + all_to_all_single_81 = torch.ops._c10d_functional.all_to_all_single.default(slice_113, [_local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391], [_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399], '521'); slice_113 = _local_scalar_dense_384 = _local_scalar_dense_385 = _local_scalar_dense_386 = _local_scalar_dense_387 = _local_scalar_dense_388 = _local_scalar_dense_389 = _local_scalar_dense_390 = _local_scalar_dense_391 = _local_scalar_dense_392 = _local_scalar_dense_393 = _local_scalar_dense_394 = _local_scalar_dense_395 = _local_scalar_dense_396 = _local_scalar_dense_397 = _local_scalar_dense_398 = _local_scalar_dense_399 = None + wait_tensor_584 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_81); all_to_all_single_81 = None + index_put_55 = torch.ops.aten.index_put.default(full_default_52, [div_122], wait_tensor_584, True); div_122 = wait_tensor_584 = None + add_1800 = torch.ops.aten.add.Tensor(add_1792, index_put_55); add_1792 = index_put_55 = None + mul_1337 = torch.ops.aten.mul.Tensor(view_1802, 1.0); view_1802 = None + scatter_add_1 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_355, mul_1337); getitem_355 = mul_1337 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_203, torch.float32); mm_203 = None + sub_576 = torch.ops.aten.sub.Tensor(convert_element_type_1361, amax_24); convert_element_type_1361 = amax_24 = None + exp_73 = torch.ops.aten.exp.default(sub_576); sub_576 = None + div_121 = torch.ops.aten.div.Tensor(exp_73, sum_97); exp_73 = sum_97 = None + mul_1338 = torch.ops.aten.mul.Tensor(scatter_add_1, div_121); scatter_add_1 = None + sum_115 = torch.ops.aten.sum.dim_IntList(mul_1338, [1], True) + neg_58 = torch.ops.aten.neg.default(div_121); div_121 = None + fma_1 = torch.ops.prims.fma.default(neg_58, sum_115, mul_1338); neg_58 = sum_115 = mul_1338 = None + convert_element_type_1558 = torch.ops.prims.convert_element_type.default(fma_1, torch.bfloat16); fma_1 = None + permute_482 = torch.ops.aten.permute.default(convert_element_type_1558, [1, 0]) + mm_240 = torch.ops.aten.mm.default(permute_482, view_1666); permute_482 = view_1666 = None + convert_element_type_1358 = torch.ops.prims.convert_element_type.default(primals_415, torch.bfloat16); primals_415 = None + all_gather_into_tensor_426 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1358, 64, '0'); convert_element_type_1358 = None + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_426); all_gather_into_tensor_426 = None + permute_379 = torch.ops.aten.permute.default(wait_tensor_522, [1, 0]); wait_tensor_522 = None + permute_484 = torch.ops.aten.permute.default(permute_379, [1, 0]); permute_379 = None + mm_241 = torch.ops.aten.mm.default(convert_element_type_1558, permute_484); convert_element_type_1558 = permute_484 = None + add_1801 = torch.ops.aten.add.Tensor(add_1800, mm_241); add_1800 = mm_241 = None + convert_element_type_1563 = torch.ops.prims.convert_element_type.default(mm_240, torch.float32); mm_240 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1563, 'avg', 64, '0'); convert_element_type_1563 = None + wait_tensor_585 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); reduce_scatter_tensor_22 = None + view_1804 = torch.ops.aten.view.default(add_1801, [2, 4096, 2048]); add_1801 = None + convert_element_type_1564 = torch.ops.prims.convert_element_type.default(view_1804, torch.float32); view_1804 = None + convert_element_type_1355 = torch.ops.prims.convert_element_type.default(primals_413, torch.bfloat16); primals_413 = None + all_gather_into_tensor_425 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1355, 64, '0'); convert_element_type_1355 = None + wait_tensor_521 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_425); all_gather_into_tensor_425 = None + convert_element_type_1566 = torch.ops.prims.convert_element_type.default(wait_tensor_521, torch.float32); wait_tensor_521 = None + mul_1339 = torch.ops.aten.mul.Tensor(convert_element_type_1564, convert_element_type_1566); convert_element_type_1566 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(add_1640, torch.float32); add_1640 = None + mul_1191 = torch.ops.aten.mul.Tensor(convert_element_type_1356, rsqrt_77); convert_element_type_1356 = None + mul_1341 = torch.ops.aten.mul.Tensor(mul_1191, mul_1339) + sum_116 = torch.ops.aten.sum.dim_IntList(mul_1341, [2], True); mul_1341 = None + div_141 = torch.ops.aten.div.Tensor(mul_1191, 2048) + mul_1342 = torch.ops.aten.mul.Tensor(div_141, sum_116); div_141 = sum_116 = None + sub_634 = torch.ops.aten.sub.Tensor(mul_1339, mul_1342); mul_1339 = mul_1342 = None + mul_1343 = torch.ops.aten.mul.Tensor(sub_634, rsqrt_77); sub_634 = rsqrt_77 = None + mul_1344 = torch.ops.aten.mul.Tensor(convert_element_type_1564, mul_1191); convert_element_type_1564 = mul_1191 = None + sum_117 = torch.ops.aten.sum.dim_IntList(mul_1344, [0, 1]); mul_1344 = None + convert_element_type_1567 = torch.ops.prims.convert_element_type.default(mul_1343, torch.bfloat16); mul_1343 = None + add_1802 = torch.ops.aten.add.Tensor(add_1789, convert_element_type_1567); add_1789 = convert_element_type_1567 = None + convert_element_type_default_78 = torch.ops.prims.convert_element_type.default(sum_117, torch.float32); sum_117 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_78, 'avg', 64, '0'); convert_element_type_default_78 = None + wait_tensor_586 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = None + view_1805 = torch.ops.aten.view.default(add_1802, [8192, 2048]) + permute_486 = torch.ops.aten.permute.default(view_1805, [1, 0]) + permute_377 = torch.ops.aten.permute.default(getitem_351, [0, 2, 1, 3]) + view_1661 = torch.ops.aten.view.default(permute_377, [2, 4096, -1]); permute_377 = None + view_1663 = torch.ops.aten.view.default(view_1661, [8192, 2048]); view_1661 = None + mm_242 = torch.ops.aten.mm.default(permute_486, view_1663); permute_486 = view_1663 = None + convert_element_type_1352 = torch.ops.prims.convert_element_type.default(primals_412, torch.bfloat16); primals_412 = None + all_gather_into_tensor_424 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1352, 64, '0'); convert_element_type_1352 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_424); all_gather_into_tensor_424 = None + permute_378 = torch.ops.aten.permute.default(wait_tensor_520, [1, 0]); wait_tensor_520 = None + permute_488 = torch.ops.aten.permute.default(permute_378, [1, 0]); permute_378 = None + mm_243 = torch.ops.aten.mm.default(view_1805, permute_488); view_1805 = permute_488 = None + view_1806 = torch.ops.aten.view.default(mm_243, [2, 4096, 2048]); mm_243 = None + convert_element_type_1574 = torch.ops.prims.convert_element_type.default(mm_242, torch.float32); mm_242 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1574, 'avg', 64, '0'); convert_element_type_1574 = None + wait_tensor_587 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + view_1807 = torch.ops.aten.view.default(view_1806, [2, 4096, 16, 128]); view_1806 = None + permute_490 = torch.ops.aten.permute.default(view_1807, [0, 2, 1, 3]); view_1807 = None + fw_graph1 = self.fw_graph1 + joint_graph1 = self.joint_graph1 + mask_graph1 = self.mask_graph1 + flex_attention_backward_1 = torch.ops.higher_order.flex_attention_backward(permute_374, permute_375, permute_376, getitem_351, getitem_352, permute_490, None, fw_graph1, joint_graph1, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph1), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_374 = permute_375 = permute_376 = getitem_351 = getitem_352 = permute_490 = fw_graph1 = joint_graph1 = mask_graph1 = None + getitem_377 = flex_attention_backward_1[0] + getitem_378 = flex_attention_backward_1[1] + getitem_379 = flex_attention_backward_1[2]; flex_attention_backward_1 = None + permute_491 = torch.ops.aten.permute.default(getitem_379, [0, 2, 1, 3]); getitem_379 = None + permute_492 = torch.ops.aten.permute.default(getitem_378, [0, 2, 1, 3]); getitem_378 = None + permute_493 = torch.ops.aten.permute.default(getitem_377, [0, 2, 1, 3]); getitem_377 = None + slice_115 = torch.ops.aten.slice.Tensor(permute_492, 3, 0, 128) + slice_116 = torch.ops.aten.slice.Tensor(permute_492, 3, 128, 192); permute_492 = None + sum_118 = torch.ops.aten.sum.dim_IntList(slice_116, [2], True); slice_116 = None + cat_83 = torch.ops.aten.cat.default([slice_115, permute_491], 3); slice_115 = permute_491 = None + view_1808 = torch.ops.aten.view.default(cat_83, [2, 4096, 4096]); cat_83 = None + view_1809 = torch.ops.aten.view.default(view_1808, [8192, 4096]); view_1808 = None + permute_494 = torch.ops.aten.permute.default(view_1809, [1, 0]) + mm_244 = torch.ops.aten.mm.default(permute_494, view_1658); permute_494 = view_1658 = None + convert_element_type_1349 = torch.ops.prims.convert_element_type.default(primals_411, torch.bfloat16); primals_411 = None + all_gather_into_tensor_423 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1349, 64, '0'); convert_element_type_1349 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_423); all_gather_into_tensor_423 = None + permute_373 = torch.ops.aten.permute.default(wait_tensor_519, [1, 0]); wait_tensor_519 = None + permute_496 = torch.ops.aten.permute.default(permute_373, [1, 0]); permute_373 = None + mm_245 = torch.ops.aten.mm.default(view_1809, permute_496); view_1809 = permute_496 = None + view_1810 = torch.ops.aten.view.default(mm_245, [2, 4096, 512]); mm_245 = None + convert_element_type_1579 = torch.ops.prims.convert_element_type.default(mm_244, torch.float32); mm_244 = None + reduce_scatter_tensor_25 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1579, 'avg', 64, '0'); convert_element_type_1579 = None + wait_tensor_588 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + convert_element_type_1580 = torch.ops.prims.convert_element_type.default(view_1810, torch.float32); view_1810 = None + convert_element_type_1346 = torch.ops.prims.convert_element_type.default(primals_410, torch.bfloat16); primals_410 = None + all_gather_into_tensor_422 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1346, 64, '0'); convert_element_type_1346 = None + wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_422); all_gather_into_tensor_422 = None + convert_element_type_1582 = torch.ops.prims.convert_element_type.default(wait_tensor_518, torch.float32); wait_tensor_518 = None + mul_1345 = torch.ops.aten.mul.Tensor(convert_element_type_1580, convert_element_type_1582); convert_element_type_1582 = None + convert_element_type_1347 = torch.ops.prims.convert_element_type.default(getitem_347, torch.float32); getitem_347 = None + mul_1189 = torch.ops.aten.mul.Tensor(convert_element_type_1347, rsqrt_76); convert_element_type_1347 = None + mul_1347 = torch.ops.aten.mul.Tensor(mul_1189, mul_1345) + sum_119 = torch.ops.aten.sum.dim_IntList(mul_1347, [2], True); mul_1347 = None + div_142 = torch.ops.aten.div.Tensor(mul_1189, 512) + mul_1348 = torch.ops.aten.mul.Tensor(div_142, sum_119); div_142 = sum_119 = None + sub_635 = torch.ops.aten.sub.Tensor(mul_1345, mul_1348); mul_1345 = mul_1348 = None + mul_1349 = torch.ops.aten.mul.Tensor(sub_635, rsqrt_76); sub_635 = rsqrt_76 = None + mul_1350 = torch.ops.aten.mul.Tensor(convert_element_type_1580, mul_1189); convert_element_type_1580 = mul_1189 = None + sum_120 = torch.ops.aten.sum.dim_IntList(mul_1350, [0, 1]); mul_1350 = None + convert_element_type_1583 = torch.ops.prims.convert_element_type.default(mul_1349, torch.bfloat16); mul_1349 = None + convert_element_type_default_77 = torch.ops.prims.convert_element_type.default(sum_120, torch.float32); sum_120 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_77, 'avg', 64, '0'); convert_element_type_default_77 = None + wait_tensor_589 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + convert_element_type_1586 = torch.ops.prims.convert_element_type.default(sum_118, torch.float32); sum_118 = None + view_1811 = torch.ops.aten.view.default(convert_element_type_1586, [2, 4096, 1, 32, 2]); convert_element_type_1586 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_1811); view_1811 = None + mul_1351 = torch.ops.aten.mul.Tensor(view_as_complex_56, clone_9); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_1351); mul_1351 = None + view_1812 = torch.ops.aten.view.default(view_as_real_56, [2, 4096, 1, 64]); view_as_real_56 = None + convert_element_type_1587 = torch.ops.prims.convert_element_type.default(view_1812, torch.bfloat16); view_1812 = None + squeeze_27 = torch.ops.aten.squeeze.dim(convert_element_type_1587, 2); convert_element_type_1587 = None + cat_84 = torch.ops.aten.cat.default([convert_element_type_1583, squeeze_27], 2); convert_element_type_1583 = squeeze_27 = None + view_1813 = torch.ops.aten.view.default(cat_84, [8192, 576]); cat_84 = None + permute_498 = torch.ops.aten.permute.default(view_1813, [1, 0]) + mm_246 = torch.ops.aten.mm.default(permute_498, view_1644); permute_498 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(primals_409, torch.bfloat16); primals_409 = None + all_gather_into_tensor_421 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1341, 64, '0'); convert_element_type_1341 = None + wait_tensor_517 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_421); all_gather_into_tensor_421 = None + permute_372 = torch.ops.aten.permute.default(wait_tensor_517, [1, 0]); wait_tensor_517 = None + permute_500 = torch.ops.aten.permute.default(permute_372, [1, 0]); permute_372 = None + mm_247 = torch.ops.aten.mm.default(view_1813, permute_500); view_1813 = permute_500 = None + view_1814 = torch.ops.aten.view.default(mm_247, [2, 4096, 2048]); mm_247 = None + convert_element_type_1592 = torch.ops.prims.convert_element_type.default(mm_246, torch.float32); mm_246 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1592, 'avg', 64, '0'); convert_element_type_1592 = None + wait_tensor_590 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + slice_117 = torch.ops.aten.slice.Tensor(permute_493, 3, 0, 128) + slice_118 = torch.ops.aten.slice.Tensor(permute_493, 3, 128, 192); permute_493 = None + convert_element_type_1593 = torch.ops.prims.convert_element_type.default(slice_118, torch.float32); slice_118 = None + view_1815 = torch.ops.aten.view.default(convert_element_type_1593, [2, 4096, 16, 32, 2]); convert_element_type_1593 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_1815); view_1815 = None + mul_1352 = torch.ops.aten.mul.Tensor(view_as_complex_57, clone_9); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_1352); mul_1352 = None + view_1816 = torch.ops.aten.view.default(view_as_real_57, [2, 4096, 16, 64]); view_as_real_57 = None + convert_element_type_1594 = torch.ops.prims.convert_element_type.default(view_1816, torch.bfloat16); view_1816 = None + cat_85 = torch.ops.aten.cat.default([slice_117, convert_element_type_1594], 3); slice_117 = convert_element_type_1594 = None + view_1817 = torch.ops.aten.view.default(cat_85, [2, 4096, 3072]); cat_85 = None + view_1818 = torch.ops.aten.view.default(view_1817, [8192, 3072]); view_1817 = None + permute_502 = torch.ops.aten.permute.default(view_1818, [1, 0]) + mm_248 = torch.ops.aten.mm.default(permute_502, view_1644); permute_502 = view_1644 = None + convert_element_type_1336 = torch.ops.prims.convert_element_type.default(primals_408, torch.bfloat16); primals_408 = None + all_gather_into_tensor_420 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1336, 64, '0'); convert_element_type_1336 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_420); all_gather_into_tensor_420 = None + permute_371 = torch.ops.aten.permute.default(wait_tensor_516, [1, 0]); wait_tensor_516 = None + permute_504 = torch.ops.aten.permute.default(permute_371, [1, 0]); permute_371 = None + mm_249 = torch.ops.aten.mm.default(view_1818, permute_504); view_1818 = permute_504 = None + view_1819 = torch.ops.aten.view.default(mm_249, [2, 4096, 2048]); mm_249 = None + add_1803 = torch.ops.aten.add.Tensor(view_1814, view_1819); view_1814 = view_1819 = None + convert_element_type_1599 = torch.ops.prims.convert_element_type.default(mm_248, torch.float32); mm_248 = None + reduce_scatter_tensor_28 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1599, 'avg', 64, '0'); convert_element_type_1599 = None + wait_tensor_591 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + convert_element_type_1600 = torch.ops.prims.convert_element_type.default(add_1803, torch.float32); add_1803 = None + convert_element_type_1333 = torch.ops.prims.convert_element_type.default(primals_407, torch.bfloat16); primals_407 = None + all_gather_into_tensor_419 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1333, 64, '0'); convert_element_type_1333 = None + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_419); all_gather_into_tensor_419 = None + convert_element_type_1602 = torch.ops.prims.convert_element_type.default(wait_tensor_515, torch.float32); wait_tensor_515 = None + mul_1353 = torch.ops.aten.mul.Tensor(convert_element_type_1600, convert_element_type_1602); convert_element_type_1602 = None + convert_element_type_1334 = torch.ops.prims.convert_element_type.default(add_1637, torch.float32); add_1637 = None + mul_1185 = torch.ops.aten.mul.Tensor(convert_element_type_1334, rsqrt_75); convert_element_type_1334 = None + mul_1355 = torch.ops.aten.mul.Tensor(mul_1185, mul_1353) + sum_121 = torch.ops.aten.sum.dim_IntList(mul_1355, [2], True); mul_1355 = None + div_143 = torch.ops.aten.div.Tensor(mul_1185, 2048) + mul_1356 = torch.ops.aten.mul.Tensor(div_143, sum_121); div_143 = sum_121 = None + sub_636 = torch.ops.aten.sub.Tensor(mul_1353, mul_1356); mul_1353 = mul_1356 = None + mul_1357 = torch.ops.aten.mul.Tensor(sub_636, rsqrt_75); sub_636 = rsqrt_75 = None + mul_1358 = torch.ops.aten.mul.Tensor(convert_element_type_1600, mul_1185); convert_element_type_1600 = mul_1185 = None + sum_122 = torch.ops.aten.sum.dim_IntList(mul_1358, [0, 1]); mul_1358 = None + convert_element_type_1603 = torch.ops.prims.convert_element_type.default(mul_1357, torch.bfloat16); mul_1357 = None + add_1804 = torch.ops.aten.add.Tensor(add_1802, convert_element_type_1603); add_1802 = convert_element_type_1603 = None + convert_element_type_default_76 = torch.ops.prims.convert_element_type.default(sum_122, torch.float32); sum_122 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_76, 'avg', 64, '0'); convert_element_type_default_76 = None + wait_tensor_592 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + view_1820 = torch.ops.aten.view.default(add_1804, [8192, 2048]) + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_1820, 1) + convert_element_type_1606 = torch.ops.prims.convert_element_type.default(unsqueeze_55, torch.float32); unsqueeze_55 = None + bmm_30 = torch.ops.aten.bmm.default(permute_506, convert_element_type_1606); permute_506 = None + bmm_31 = torch.ops.aten.bmm.default(convert_element_type_1606, permute_507); convert_element_type_1606 = permute_507 = None + convert_element_type_1607 = torch.ops.prims.convert_element_type.default(bmm_30, torch.bfloat16); bmm_30 = None + view_1821 = torch.ops.aten.view.default(bmm_31, [8192, 6]); bmm_31 = None + view_1822 = torch.ops.aten.view.default(convert_element_type_1607, [49152, 2048]); convert_element_type_1607 = None + index_56 = torch.ops.aten.index.Tensor(view_1822, [getitem_343]); view_1822 = getitem_343 = None + permute_508 = torch.ops.aten.permute.default(view_1820, [1, 0]) + mm_250 = torch.ops.aten.mm.default(permute_508, mul_1182); permute_508 = mul_1182 = None + convert_element_type_1328 = torch.ops.prims.convert_element_type.default(primals_406, torch.bfloat16); primals_406 = None + all_gather_into_tensor_418 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1328, 64, '0'); convert_element_type_1328 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_418); all_gather_into_tensor_418 = None + permute_370 = torch.ops.aten.permute.default(wait_tensor_514, [1, 0]); wait_tensor_514 = None + permute_510 = torch.ops.aten.permute.default(permute_370, [1, 0]); permute_370 = None + mm_251 = torch.ops.aten.mm.default(view_1820, permute_510); view_1820 = permute_510 = None + convert_element_type_1612 = torch.ops.prims.convert_element_type.default(mm_250, torch.float32); mm_250 = None + reduce_scatter_tensor_30 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1612, 'avg', 64, '0'); convert_element_type_1612 = None + wait_tensor_593 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + convert_element_type_1323 = torch.ops.prims.convert_element_type.default(mm_196, torch.float32); mm_196 = None + neg_48 = torch.ops.aten.neg.default(convert_element_type_1323) + exp_72 = torch.ops.aten.exp.default(neg_48); neg_48 = None + add_1632 = torch.ops.aten.add.Tensor(exp_72, 1); exp_72 = None + div_120 = torch.ops.aten.div.Tensor(convert_element_type_1323, add_1632) + convert_element_type_1324 = torch.ops.prims.convert_element_type.default(div_120, torch.bfloat16); div_120 = None + mul_1359 = torch.ops.aten.mul.Tensor(mm_251, convert_element_type_1324); convert_element_type_1324 = None + mul_1360 = torch.ops.aten.mul.Tensor(mm_251, mm_197); mm_251 = mm_197 = None + permute_512 = torch.ops.aten.permute.default(mul_1359, [1, 0]) + mm_252 = torch.ops.aten.mm.default(permute_512, view_1599); permute_512 = None + convert_element_type_1325 = torch.ops.prims.convert_element_type.default(primals_405, torch.bfloat16); primals_405 = None + all_gather_into_tensor_417 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1325, 64, '0'); convert_element_type_1325 = None + wait_tensor_513 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_417); all_gather_into_tensor_417 = None + permute_369 = torch.ops.aten.permute.default(wait_tensor_513, [1, 0]); wait_tensor_513 = None + permute_514 = torch.ops.aten.permute.default(permute_369, [1, 0]); permute_369 = None + mm_253 = torch.ops.aten.mm.default(mul_1359, permute_514); mul_1359 = permute_514 = None + convert_element_type_1617 = torch.ops.prims.convert_element_type.default(mm_252, torch.float32); mm_252 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1617, 'avg', 64, '0'); convert_element_type_1617 = None + wait_tensor_594 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + convert_element_type_1618 = torch.ops.prims.convert_element_type.default(mul_1360, torch.float32); mul_1360 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_1632); add_1632 = None + mul_1361 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_1362 = torch.ops.aten.mul.Tensor(convert_element_type_1618, mul_1361); convert_element_type_1618 = None + sub_637 = torch.ops.aten.sub.Tensor(1, mul_1361); mul_1361 = None + mul_1363 = torch.ops.aten.mul.Tensor(convert_element_type_1323, sub_637); convert_element_type_1323 = sub_637 = None + add_1806 = torch.ops.aten.add.Tensor(mul_1363, 1); mul_1363 = None + mul_1364 = torch.ops.aten.mul.Tensor(mul_1362, add_1806); mul_1362 = add_1806 = None + convert_element_type_1620 = torch.ops.prims.convert_element_type.default(mul_1364, torch.bfloat16); mul_1364 = None + permute_516 = torch.ops.aten.permute.default(convert_element_type_1620, [1, 0]) + mm_254 = torch.ops.aten.mm.default(permute_516, view_1599); permute_516 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(primals_404, torch.bfloat16); primals_404 = None + all_gather_into_tensor_416 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1320, 64, '0'); convert_element_type_1320 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_416); all_gather_into_tensor_416 = None + permute_368 = torch.ops.aten.permute.default(wait_tensor_512, [1, 0]); wait_tensor_512 = None + permute_518 = torch.ops.aten.permute.default(permute_368, [1, 0]); permute_368 = None + mm_255 = torch.ops.aten.mm.default(convert_element_type_1620, permute_518); convert_element_type_1620 = permute_518 = None + add_1807 = torch.ops.aten.add.Tensor(mm_253, mm_255); mm_253 = mm_255 = None + convert_element_type_1625 = torch.ops.prims.convert_element_type.default(mm_254, torch.float32); mm_254 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1625, 'avg', 64, '0'); convert_element_type_1625 = None + wait_tensor_595 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + all_to_all_single_82 = torch.ops._c10d_functional.all_to_all_single.default(index_56, [_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383], [_local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375], '521'); index_56 = None + wait_tensor_596 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_82); all_to_all_single_82 = None + full_356 = torch.ops.aten.full.default([sym_size_int_93, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_93 = None + slice_scatter_2 = torch.ops.aten.slice_scatter.default(full_356, wait_tensor_596, 0, 0, -1); wait_tensor_596 = None + index_57 = torch.ops.aten.index.Tensor(slice_scatter_2, [getitem_344]); slice_scatter_2 = None + permute_520 = torch.ops.aten.permute.default(index_57, [1, 0]) + _grouped_mm_90 = torch.ops.aten._grouped_mm.default(permute_520, mul_1162, cumsum_71); permute_520 = mul_1162 = None + convert_element_type_1314 = torch.ops.prims.convert_element_type.default(primals_402, torch.bfloat16); primals_402 = None + all_gather_into_tensor_412 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1314, 8, '513'); convert_element_type_1314 = None + wait_tensor_507 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_412); all_gather_into_tensor_412 = None + permute_367 = torch.ops.aten.permute.default(wait_tensor_507, [0, 2, 1]); wait_tensor_507 = None + permute_522 = torch.ops.aten.permute.default(permute_367, [0, 2, 1]); permute_367 = None + _grouped_mm_91 = torch.ops.aten._grouped_mm.default(index_57, permute_522, cumsum_71); index_57 = permute_522 = None + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(_grouped_mm_69, torch.float32); _grouped_mm_69 = None + neg_47 = torch.ops.aten.neg.default(convert_element_type_1318) + exp_71 = torch.ops.aten.exp.default(neg_47); neg_47 = None + add_1596 = torch.ops.aten.add.Tensor(exp_71, 1); exp_71 = None + div_119 = torch.ops.aten.div.Tensor(convert_element_type_1318, add_1596) + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(div_119, torch.bfloat16); div_119 = None + mul_1365 = torch.ops.aten.mul.Tensor(_grouped_mm_91, convert_element_type_1319); convert_element_type_1319 = None + mul_1366 = torch.ops.aten.mul.Tensor(_grouped_mm_91, _grouped_mm_70); _grouped_mm_91 = _grouped_mm_70 = None + permute_524 = torch.ops.aten.permute.default(mul_1365, [1, 0]) + _grouped_mm_92 = torch.ops.aten._grouped_mm.default(permute_524, index_47, cumsum_71); permute_524 = None + convert_element_type_1315 = torch.ops.prims.convert_element_type.default(primals_403, torch.bfloat16); primals_403 = None + all_gather_into_tensor_413 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1315, 8, '513'); convert_element_type_1315 = None + wait_tensor_508 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_413); all_gather_into_tensor_413 = None + permute_366 = torch.ops.aten.permute.default(wait_tensor_508, [0, 2, 1]); wait_tensor_508 = None + permute_526 = torch.ops.aten.permute.default(permute_366, [0, 2, 1]); permute_366 = None + _grouped_mm_93 = torch.ops.aten._grouped_mm.default(mul_1365, permute_526, cumsum_71); mul_1365 = permute_526 = None + convert_element_type_1626 = torch.ops.prims.convert_element_type.default(mul_1366, torch.float32); mul_1366 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_1596); add_1596 = None + mul_1367 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_1368 = torch.ops.aten.mul.Tensor(convert_element_type_1626, mul_1367); convert_element_type_1626 = None + sub_638 = torch.ops.aten.sub.Tensor(1, mul_1367); mul_1367 = None + mul_1369 = torch.ops.aten.mul.Tensor(convert_element_type_1318, sub_638); convert_element_type_1318 = sub_638 = None + add_1809 = torch.ops.aten.add.Tensor(mul_1369, 1); mul_1369 = None + mul_1370 = torch.ops.aten.mul.Tensor(mul_1368, add_1809); mul_1368 = add_1809 = None + convert_element_type_1628 = torch.ops.prims.convert_element_type.default(mul_1370, torch.bfloat16); mul_1370 = None + permute_528 = torch.ops.aten.permute.default(convert_element_type_1628, [1, 0]) + _grouped_mm_94 = torch.ops.aten._grouped_mm.default(permute_528, index_47, cumsum_71); permute_528 = index_47 = None + convert_element_type_1312 = torch.ops.prims.convert_element_type.default(primals_401, torch.bfloat16); primals_401 = None + all_gather_into_tensor_410 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1312, 8, '513'); convert_element_type_1312 = None + wait_tensor_505 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_410); all_gather_into_tensor_410 = None + permute_365 = torch.ops.aten.permute.default(wait_tensor_505, [0, 2, 1]); wait_tensor_505 = None + permute_530 = torch.ops.aten.permute.default(permute_365, [0, 2, 1]); permute_365 = None + _grouped_mm_95 = torch.ops.aten._grouped_mm.default(convert_element_type_1628, permute_530, cumsum_71); convert_element_type_1628 = permute_530 = cumsum_71 = None + add_1810 = torch.ops.aten.add.Tensor(_grouped_mm_93, _grouped_mm_95); _grouped_mm_93 = _grouped_mm_95 = None + convert_element_type_1629 = torch.ops.prims.convert_element_type.default(_grouped_mm_92, torch.float32); _grouped_mm_92 = None + div_144 = torch.ops.aten.div.Tensor(convert_element_type_1629, 64); convert_element_type_1629 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_144, 'sum', 8, '513'); div_144 = None + wait_tensor_597 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + convert_element_type_1630 = torch.ops.prims.convert_element_type.default(_grouped_mm_90, torch.float32); _grouped_mm_90 = None + div_145 = torch.ops.aten.div.Tensor(convert_element_type_1630, 64); convert_element_type_1630 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_145, 'sum', 8, '513'); div_145 = None + wait_tensor_598 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + convert_element_type_1631 = torch.ops.prims.convert_element_type.default(_grouped_mm_94, torch.float32); _grouped_mm_94 = None + div_146 = torch.ops.aten.div.Tensor(convert_element_type_1631, 64); convert_element_type_1631 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_146, 'sum', 8, '513'); div_146 = None + wait_tensor_599 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + index_put_56 = torch.ops.aten.index_put.default(full_356, [getitem_344], add_1810, True); full_356 = getitem_344 = add_1810 = None + slice_119 = torch.ops.aten.slice.Tensor(index_put_56, 0, 0, add_1811); index_put_56 = add_1811 = None + all_to_all_single_83 = torch.ops._c10d_functional.all_to_all_single.default(slice_119, [_local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375], [_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383], '521'); slice_119 = _local_scalar_dense_368 = _local_scalar_dense_369 = _local_scalar_dense_370 = _local_scalar_dense_371 = _local_scalar_dense_372 = _local_scalar_dense_373 = _local_scalar_dense_374 = _local_scalar_dense_375 = _local_scalar_dense_376 = _local_scalar_dense_377 = _local_scalar_dense_378 = _local_scalar_dense_379 = _local_scalar_dense_380 = _local_scalar_dense_381 = _local_scalar_dense_382 = _local_scalar_dense_383 = None + wait_tensor_600 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_83); all_to_all_single_83 = None + index_put_57 = torch.ops.aten.index_put.default(full_default_52, [div_117], wait_tensor_600, True); div_117 = wait_tensor_600 = None + add_1815 = torch.ops.aten.add.Tensor(add_1807, index_put_57); add_1807 = index_put_57 = None + mul_1371 = torch.ops.aten.mul.Tensor(view_1821, 1.0); view_1821 = None + scatter_add_2 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_341, mul_1371); getitem_341 = mul_1371 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_195, torch.float32); mm_195 = None + sub_552 = torch.ops.aten.sub.Tensor(convert_element_type_1307, amax_23); convert_element_type_1307 = amax_23 = None + exp_70 = torch.ops.aten.exp.default(sub_552); sub_552 = None + div_116 = torch.ops.aten.div.Tensor(exp_70, sum_93); exp_70 = sum_93 = None + mul_1372 = torch.ops.aten.mul.Tensor(scatter_add_2, div_116); scatter_add_2 = None + sum_123 = torch.ops.aten.sum.dim_IntList(mul_1372, [1], True) + neg_61 = torch.ops.aten.neg.default(div_116); div_116 = None + fma_2 = torch.ops.prims.fma.default(neg_61, sum_123, mul_1372); neg_61 = sum_123 = mul_1372 = None + convert_element_type_1632 = torch.ops.prims.convert_element_type.default(fma_2, torch.bfloat16); fma_2 = None + permute_532 = torch.ops.aten.permute.default(convert_element_type_1632, [1, 0]) + mm_256 = torch.ops.aten.mm.default(permute_532, view_1599); permute_532 = view_1599 = None + convert_element_type_1304 = torch.ops.prims.convert_element_type.default(primals_399, torch.bfloat16); primals_399 = None + all_gather_into_tensor_409 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1304, 64, '0'); convert_element_type_1304 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_409); all_gather_into_tensor_409 = None + permute_364 = torch.ops.aten.permute.default(wait_tensor_501, [1, 0]); wait_tensor_501 = None + permute_534 = torch.ops.aten.permute.default(permute_364, [1, 0]); permute_364 = None + mm_257 = torch.ops.aten.mm.default(convert_element_type_1632, permute_534); convert_element_type_1632 = permute_534 = None + add_1816 = torch.ops.aten.add.Tensor(add_1815, mm_257); add_1815 = mm_257 = None + convert_element_type_1637 = torch.ops.prims.convert_element_type.default(mm_256, torch.float32); mm_256 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1637, 'avg', 64, '0'); convert_element_type_1637 = None + wait_tensor_601 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + view_1823 = torch.ops.aten.view.default(add_1816, [2, 4096, 2048]); add_1816 = None + convert_element_type_1638 = torch.ops.prims.convert_element_type.default(view_1823, torch.float32); view_1823 = None + convert_element_type_1301 = torch.ops.prims.convert_element_type.default(primals_397, torch.bfloat16); primals_397 = None + all_gather_into_tensor_408 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1301, 64, '0'); convert_element_type_1301 = None + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_408); all_gather_into_tensor_408 = None + convert_element_type_1640 = torch.ops.prims.convert_element_type.default(wait_tensor_500, torch.float32); wait_tensor_500 = None + mul_1373 = torch.ops.aten.mul.Tensor(convert_element_type_1638, convert_element_type_1640); convert_element_type_1640 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(add_1572, torch.float32); add_1572 = None + mul_1142 = torch.ops.aten.mul.Tensor(convert_element_type_1302, rsqrt_74); convert_element_type_1302 = None + mul_1375 = torch.ops.aten.mul.Tensor(mul_1142, mul_1373) + sum_124 = torch.ops.aten.sum.dim_IntList(mul_1375, [2], True); mul_1375 = None + div_147 = torch.ops.aten.div.Tensor(mul_1142, 2048) + mul_1376 = torch.ops.aten.mul.Tensor(div_147, sum_124); div_147 = sum_124 = None + sub_640 = torch.ops.aten.sub.Tensor(mul_1373, mul_1376); mul_1373 = mul_1376 = None + mul_1377 = torch.ops.aten.mul.Tensor(sub_640, rsqrt_74); sub_640 = rsqrt_74 = None + mul_1378 = torch.ops.aten.mul.Tensor(convert_element_type_1638, mul_1142); convert_element_type_1638 = mul_1142 = None + sum_125 = torch.ops.aten.sum.dim_IntList(mul_1378, [0, 1]); mul_1378 = None + convert_element_type_1641 = torch.ops.prims.convert_element_type.default(mul_1377, torch.bfloat16); mul_1377 = None + add_1817 = torch.ops.aten.add.Tensor(add_1804, convert_element_type_1641); add_1804 = convert_element_type_1641 = None + convert_element_type_default_75 = torch.ops.prims.convert_element_type.default(sum_125, torch.float32); sum_125 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_75, 'avg', 64, '0'); convert_element_type_default_75 = None + wait_tensor_602 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + view_1824 = torch.ops.aten.view.default(add_1817, [8192, 2048]) + permute_536 = torch.ops.aten.permute.default(view_1824, [1, 0]) + permute_362 = torch.ops.aten.permute.default(getitem_337, [0, 2, 1, 3]) + view_1594 = torch.ops.aten.view.default(permute_362, [2, 4096, -1]); permute_362 = None + view_1596 = torch.ops.aten.view.default(view_1594, [8192, 2048]); view_1594 = None + mm_258 = torch.ops.aten.mm.default(permute_536, view_1596); permute_536 = view_1596 = None + convert_element_type_1298 = torch.ops.prims.convert_element_type.default(primals_396, torch.bfloat16); primals_396 = None + all_gather_into_tensor_407 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1298, 64, '0'); convert_element_type_1298 = None + wait_tensor_499 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_407); all_gather_into_tensor_407 = None + permute_363 = torch.ops.aten.permute.default(wait_tensor_499, [1, 0]); wait_tensor_499 = None + permute_538 = torch.ops.aten.permute.default(permute_363, [1, 0]); permute_363 = None + mm_259 = torch.ops.aten.mm.default(view_1824, permute_538); view_1824 = permute_538 = None + view_1825 = torch.ops.aten.view.default(mm_259, [2, 4096, 2048]); mm_259 = None + convert_element_type_1648 = torch.ops.prims.convert_element_type.default(mm_258, torch.float32); mm_258 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1648, 'avg', 64, '0'); convert_element_type_1648 = None + wait_tensor_603 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + view_1826 = torch.ops.aten.view.default(view_1825, [2, 4096, 16, 128]); view_1825 = None + permute_540 = torch.ops.aten.permute.default(view_1826, [0, 2, 1, 3]); view_1826 = None + fw_graph2 = self.fw_graph2 + joint_graph2 = self.joint_graph2 + mask_graph2 = self.mask_graph2 + flex_attention_backward_2 = torch.ops.higher_order.flex_attention_backward(permute_359, permute_360, permute_361, getitem_337, getitem_338, permute_540, None, fw_graph2, joint_graph2, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph2), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_359 = permute_360 = permute_361 = getitem_337 = getitem_338 = permute_540 = fw_graph2 = joint_graph2 = mask_graph2 = None + getitem_381 = flex_attention_backward_2[0] + getitem_382 = flex_attention_backward_2[1] + getitem_383 = flex_attention_backward_2[2]; flex_attention_backward_2 = None + permute_541 = torch.ops.aten.permute.default(getitem_383, [0, 2, 1, 3]); getitem_383 = None + permute_542 = torch.ops.aten.permute.default(getitem_382, [0, 2, 1, 3]); getitem_382 = None + permute_543 = torch.ops.aten.permute.default(getitem_381, [0, 2, 1, 3]); getitem_381 = None + slice_121 = torch.ops.aten.slice.Tensor(permute_542, 3, 0, 128) + slice_122 = torch.ops.aten.slice.Tensor(permute_542, 3, 128, 192); permute_542 = None + sum_126 = torch.ops.aten.sum.dim_IntList(slice_122, [2], True); slice_122 = None + cat_86 = torch.ops.aten.cat.default([slice_121, permute_541], 3); slice_121 = permute_541 = None + view_1827 = torch.ops.aten.view.default(cat_86, [2, 4096, 4096]); cat_86 = None + view_1828 = torch.ops.aten.view.default(view_1827, [8192, 4096]); view_1827 = None + permute_544 = torch.ops.aten.permute.default(view_1828, [1, 0]) + mm_260 = torch.ops.aten.mm.default(permute_544, view_1591); permute_544 = view_1591 = None + convert_element_type_1295 = torch.ops.prims.convert_element_type.default(primals_395, torch.bfloat16); primals_395 = None + all_gather_into_tensor_406 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1295, 64, '0'); convert_element_type_1295 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_406); all_gather_into_tensor_406 = None + permute_358 = torch.ops.aten.permute.default(wait_tensor_498, [1, 0]); wait_tensor_498 = None + permute_546 = torch.ops.aten.permute.default(permute_358, [1, 0]); permute_358 = None + mm_261 = torch.ops.aten.mm.default(view_1828, permute_546); view_1828 = permute_546 = None + view_1829 = torch.ops.aten.view.default(mm_261, [2, 4096, 512]); mm_261 = None + convert_element_type_1653 = torch.ops.prims.convert_element_type.default(mm_260, torch.float32); mm_260 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1653, 'avg', 64, '0'); convert_element_type_1653 = None + wait_tensor_604 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + convert_element_type_1654 = torch.ops.prims.convert_element_type.default(view_1829, torch.float32); view_1829 = None + convert_element_type_1292 = torch.ops.prims.convert_element_type.default(primals_394, torch.bfloat16); primals_394 = None + all_gather_into_tensor_405 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1292, 64, '0'); convert_element_type_1292 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_405); all_gather_into_tensor_405 = None + convert_element_type_1656 = torch.ops.prims.convert_element_type.default(wait_tensor_497, torch.float32); wait_tensor_497 = None + mul_1379 = torch.ops.aten.mul.Tensor(convert_element_type_1654, convert_element_type_1656); convert_element_type_1656 = None + convert_element_type_1293 = torch.ops.prims.convert_element_type.default(getitem_333, torch.float32); getitem_333 = None + mul_1140 = torch.ops.aten.mul.Tensor(convert_element_type_1293, rsqrt_73); convert_element_type_1293 = None + mul_1381 = torch.ops.aten.mul.Tensor(mul_1140, mul_1379) + sum_127 = torch.ops.aten.sum.dim_IntList(mul_1381, [2], True); mul_1381 = None + div_148 = torch.ops.aten.div.Tensor(mul_1140, 512) + mul_1382 = torch.ops.aten.mul.Tensor(div_148, sum_127); div_148 = sum_127 = None + sub_641 = torch.ops.aten.sub.Tensor(mul_1379, mul_1382); mul_1379 = mul_1382 = None + mul_1383 = torch.ops.aten.mul.Tensor(sub_641, rsqrt_73); sub_641 = rsqrt_73 = None + mul_1384 = torch.ops.aten.mul.Tensor(convert_element_type_1654, mul_1140); convert_element_type_1654 = mul_1140 = None + sum_128 = torch.ops.aten.sum.dim_IntList(mul_1384, [0, 1]); mul_1384 = None + convert_element_type_1657 = torch.ops.prims.convert_element_type.default(mul_1383, torch.bfloat16); mul_1383 = None + convert_element_type_default_74 = torch.ops.prims.convert_element_type.default(sum_128, torch.float32); sum_128 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_74, 'avg', 64, '0'); convert_element_type_default_74 = None + wait_tensor_605 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + convert_element_type_1660 = torch.ops.prims.convert_element_type.default(sum_126, torch.float32); sum_126 = None + view_1830 = torch.ops.aten.view.default(convert_element_type_1660, [2, 4096, 1, 32, 2]); convert_element_type_1660 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1830); view_1830 = None + mul_1385 = torch.ops.aten.mul.Tensor(view_as_complex_58, clone_9); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_1385); mul_1385 = None + view_1831 = torch.ops.aten.view.default(view_as_real_58, [2, 4096, 1, 64]); view_as_real_58 = None + convert_element_type_1661 = torch.ops.prims.convert_element_type.default(view_1831, torch.bfloat16); view_1831 = None + squeeze_28 = torch.ops.aten.squeeze.dim(convert_element_type_1661, 2); convert_element_type_1661 = None + cat_87 = torch.ops.aten.cat.default([convert_element_type_1657, squeeze_28], 2); convert_element_type_1657 = squeeze_28 = None + view_1832 = torch.ops.aten.view.default(cat_87, [8192, 576]); cat_87 = None + permute_548 = torch.ops.aten.permute.default(view_1832, [1, 0]) + mm_262 = torch.ops.aten.mm.default(permute_548, view_1577); permute_548 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(primals_393, torch.bfloat16); primals_393 = None + all_gather_into_tensor_404 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1287, 64, '0'); convert_element_type_1287 = None + wait_tensor_496 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_404); all_gather_into_tensor_404 = None + permute_357 = torch.ops.aten.permute.default(wait_tensor_496, [1, 0]); wait_tensor_496 = None + permute_550 = torch.ops.aten.permute.default(permute_357, [1, 0]); permute_357 = None + mm_263 = torch.ops.aten.mm.default(view_1832, permute_550); view_1832 = permute_550 = None + view_1833 = torch.ops.aten.view.default(mm_263, [2, 4096, 2048]); mm_263 = None + convert_element_type_1666 = torch.ops.prims.convert_element_type.default(mm_262, torch.float32); mm_262 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1666, 'avg', 64, '0'); convert_element_type_1666 = None + wait_tensor_606 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + slice_123 = torch.ops.aten.slice.Tensor(permute_543, 3, 0, 128) + slice_124 = torch.ops.aten.slice.Tensor(permute_543, 3, 128, 192); permute_543 = None + convert_element_type_1667 = torch.ops.prims.convert_element_type.default(slice_124, torch.float32); slice_124 = None + view_1834 = torch.ops.aten.view.default(convert_element_type_1667, [2, 4096, 16, 32, 2]); convert_element_type_1667 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1834); view_1834 = None + mul_1386 = torch.ops.aten.mul.Tensor(view_as_complex_59, clone_9); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_1386); mul_1386 = None + view_1835 = torch.ops.aten.view.default(view_as_real_59, [2, 4096, 16, 64]); view_as_real_59 = None + convert_element_type_1668 = torch.ops.prims.convert_element_type.default(view_1835, torch.bfloat16); view_1835 = None + cat_88 = torch.ops.aten.cat.default([slice_123, convert_element_type_1668], 3); slice_123 = convert_element_type_1668 = None + view_1836 = torch.ops.aten.view.default(cat_88, [2, 4096, 3072]); cat_88 = None + view_1837 = torch.ops.aten.view.default(view_1836, [8192, 3072]); view_1836 = None + permute_552 = torch.ops.aten.permute.default(view_1837, [1, 0]) + mm_264 = torch.ops.aten.mm.default(permute_552, view_1577); permute_552 = view_1577 = None + convert_element_type_1282 = torch.ops.prims.convert_element_type.default(primals_392, torch.bfloat16); primals_392 = None + all_gather_into_tensor_403 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1282, 64, '0'); convert_element_type_1282 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_403); all_gather_into_tensor_403 = None + permute_356 = torch.ops.aten.permute.default(wait_tensor_495, [1, 0]); wait_tensor_495 = None + permute_554 = torch.ops.aten.permute.default(permute_356, [1, 0]); permute_356 = None + mm_265 = torch.ops.aten.mm.default(view_1837, permute_554); view_1837 = permute_554 = None + view_1838 = torch.ops.aten.view.default(mm_265, [2, 4096, 2048]); mm_265 = None + add_1818 = torch.ops.aten.add.Tensor(view_1833, view_1838); view_1833 = view_1838 = None + convert_element_type_1673 = torch.ops.prims.convert_element_type.default(mm_264, torch.float32); mm_264 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1673, 'avg', 64, '0'); convert_element_type_1673 = None + wait_tensor_607 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + convert_element_type_1674 = torch.ops.prims.convert_element_type.default(add_1818, torch.float32); add_1818 = None + convert_element_type_1279 = torch.ops.prims.convert_element_type.default(primals_391, torch.bfloat16); primals_391 = None + all_gather_into_tensor_402 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1279, 64, '0'); convert_element_type_1279 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_402); all_gather_into_tensor_402 = None + convert_element_type_1676 = torch.ops.prims.convert_element_type.default(wait_tensor_494, torch.float32); wait_tensor_494 = None + mul_1387 = torch.ops.aten.mul.Tensor(convert_element_type_1674, convert_element_type_1676); convert_element_type_1676 = None + convert_element_type_1280 = torch.ops.prims.convert_element_type.default(add_1569, torch.float32); add_1569 = None + mul_1136 = torch.ops.aten.mul.Tensor(convert_element_type_1280, rsqrt_72); convert_element_type_1280 = None + mul_1389 = torch.ops.aten.mul.Tensor(mul_1136, mul_1387) + sum_129 = torch.ops.aten.sum.dim_IntList(mul_1389, [2], True); mul_1389 = None + div_149 = torch.ops.aten.div.Tensor(mul_1136, 2048) + mul_1390 = torch.ops.aten.mul.Tensor(div_149, sum_129); div_149 = sum_129 = None + sub_642 = torch.ops.aten.sub.Tensor(mul_1387, mul_1390); mul_1387 = mul_1390 = None + mul_1391 = torch.ops.aten.mul.Tensor(sub_642, rsqrt_72); sub_642 = rsqrt_72 = None + mul_1392 = torch.ops.aten.mul.Tensor(convert_element_type_1674, mul_1136); convert_element_type_1674 = mul_1136 = None + sum_130 = torch.ops.aten.sum.dim_IntList(mul_1392, [0, 1]); mul_1392 = None + convert_element_type_1677 = torch.ops.prims.convert_element_type.default(mul_1391, torch.bfloat16); mul_1391 = None + add_1819 = torch.ops.aten.add.Tensor(add_1817, convert_element_type_1677); add_1817 = convert_element_type_1677 = None + convert_element_type_default_73 = torch.ops.prims.convert_element_type.default(sum_130, torch.float32); sum_130 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_73, 'avg', 64, '0'); convert_element_type_default_73 = None + wait_tensor_608 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + view_1839 = torch.ops.aten.view.default(add_1819, [8192, 2048]) + unsqueeze_56 = torch.ops.aten.unsqueeze.default(view_1839, 1) + convert_element_type_1680 = torch.ops.prims.convert_element_type.default(unsqueeze_56, torch.float32); unsqueeze_56 = None + bmm_32 = torch.ops.aten.bmm.default(permute_556, convert_element_type_1680); permute_556 = None + bmm_33 = torch.ops.aten.bmm.default(convert_element_type_1680, permute_557); convert_element_type_1680 = permute_557 = None + convert_element_type_1681 = torch.ops.prims.convert_element_type.default(bmm_32, torch.bfloat16); bmm_32 = None + view_1840 = torch.ops.aten.view.default(bmm_33, [8192, 6]); bmm_33 = None + view_1841 = torch.ops.aten.view.default(convert_element_type_1681, [49152, 2048]); convert_element_type_1681 = None + index_58 = torch.ops.aten.index.Tensor(view_1841, [getitem_329]); view_1841 = getitem_329 = None + permute_558 = torch.ops.aten.permute.default(view_1839, [1, 0]) + mm_266 = torch.ops.aten.mm.default(permute_558, mul_1133); permute_558 = mul_1133 = None + convert_element_type_1274 = torch.ops.prims.convert_element_type.default(primals_390, torch.bfloat16); primals_390 = None + all_gather_into_tensor_401 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1274, 64, '0'); convert_element_type_1274 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_401); all_gather_into_tensor_401 = None + permute_355 = torch.ops.aten.permute.default(wait_tensor_493, [1, 0]); wait_tensor_493 = None + permute_560 = torch.ops.aten.permute.default(permute_355, [1, 0]); permute_355 = None + mm_267 = torch.ops.aten.mm.default(view_1839, permute_560); view_1839 = permute_560 = None + convert_element_type_1686 = torch.ops.prims.convert_element_type.default(mm_266, torch.float32); mm_266 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1686, 'avg', 64, '0'); convert_element_type_1686 = None + wait_tensor_609 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + convert_element_type_1269 = torch.ops.prims.convert_element_type.default(mm_188, torch.float32); mm_188 = None + neg_46 = torch.ops.aten.neg.default(convert_element_type_1269) + exp_69 = torch.ops.aten.exp.default(neg_46); neg_46 = None + add_1564 = torch.ops.aten.add.Tensor(exp_69, 1); exp_69 = None + div_115 = torch.ops.aten.div.Tensor(convert_element_type_1269, add_1564) + convert_element_type_1270 = torch.ops.prims.convert_element_type.default(div_115, torch.bfloat16); div_115 = None + mul_1393 = torch.ops.aten.mul.Tensor(mm_267, convert_element_type_1270); convert_element_type_1270 = None + mul_1394 = torch.ops.aten.mul.Tensor(mm_267, mm_189); mm_267 = mm_189 = None + permute_562 = torch.ops.aten.permute.default(mul_1393, [1, 0]) + mm_268 = torch.ops.aten.mm.default(permute_562, view_1532); permute_562 = None + convert_element_type_1271 = torch.ops.prims.convert_element_type.default(primals_389, torch.bfloat16); primals_389 = None + all_gather_into_tensor_400 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1271, 64, '0'); convert_element_type_1271 = None + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_400); all_gather_into_tensor_400 = None + permute_354 = torch.ops.aten.permute.default(wait_tensor_492, [1, 0]); wait_tensor_492 = None + permute_564 = torch.ops.aten.permute.default(permute_354, [1, 0]); permute_354 = None + mm_269 = torch.ops.aten.mm.default(mul_1393, permute_564); mul_1393 = permute_564 = None + convert_element_type_1691 = torch.ops.prims.convert_element_type.default(mm_268, torch.float32); mm_268 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1691, 'avg', 64, '0'); convert_element_type_1691 = None + wait_tensor_610 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); reduce_scatter_tensor_45 = None + convert_element_type_1692 = torch.ops.prims.convert_element_type.default(mul_1394, torch.float32); mul_1394 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_1564); add_1564 = None + mul_1395 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); reciprocal_6 = None + mul_1396 = torch.ops.aten.mul.Tensor(convert_element_type_1692, mul_1395); convert_element_type_1692 = None + sub_643 = torch.ops.aten.sub.Tensor(1, mul_1395); mul_1395 = None + mul_1397 = torch.ops.aten.mul.Tensor(convert_element_type_1269, sub_643); convert_element_type_1269 = sub_643 = None + add_1821 = torch.ops.aten.add.Tensor(mul_1397, 1); mul_1397 = None + mul_1398 = torch.ops.aten.mul.Tensor(mul_1396, add_1821); mul_1396 = add_1821 = None + convert_element_type_1694 = torch.ops.prims.convert_element_type.default(mul_1398, torch.bfloat16); mul_1398 = None + permute_566 = torch.ops.aten.permute.default(convert_element_type_1694, [1, 0]) + mm_270 = torch.ops.aten.mm.default(permute_566, view_1532); permute_566 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(primals_388, torch.bfloat16); primals_388 = None + all_gather_into_tensor_399 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1266, 64, '0'); convert_element_type_1266 = None + wait_tensor_491 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_399); all_gather_into_tensor_399 = None + permute_353 = torch.ops.aten.permute.default(wait_tensor_491, [1, 0]); wait_tensor_491 = None + permute_568 = torch.ops.aten.permute.default(permute_353, [1, 0]); permute_353 = None + mm_271 = torch.ops.aten.mm.default(convert_element_type_1694, permute_568); convert_element_type_1694 = permute_568 = None + add_1822 = torch.ops.aten.add.Tensor(mm_269, mm_271); mm_269 = mm_271 = None + convert_element_type_1699 = torch.ops.prims.convert_element_type.default(mm_270, torch.float32); mm_270 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1699, 'avg', 64, '0'); convert_element_type_1699 = None + wait_tensor_611 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + all_to_all_single_84 = torch.ops._c10d_functional.all_to_all_single.default(index_58, [_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367], [_local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359], '521'); index_58 = None + wait_tensor_612 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_84); all_to_all_single_84 = None + full_360 = torch.ops.aten.full.default([sym_size_int_89, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_89 = None + slice_scatter_3 = torch.ops.aten.slice_scatter.default(full_360, wait_tensor_612, 0, 0, -1); wait_tensor_612 = None + index_59 = torch.ops.aten.index.Tensor(slice_scatter_3, [getitem_330]); slice_scatter_3 = None + permute_570 = torch.ops.aten.permute.default(index_59, [1, 0]) + _grouped_mm_96 = torch.ops.aten._grouped_mm.default(permute_570, mul_1113, cumsum_68); permute_570 = mul_1113 = None + convert_element_type_1260 = torch.ops.prims.convert_element_type.default(primals_386, torch.bfloat16); primals_386 = None + all_gather_into_tensor_395 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1260, 8, '513'); convert_element_type_1260 = None + wait_tensor_486 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_395); all_gather_into_tensor_395 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_486, [0, 2, 1]); wait_tensor_486 = None + permute_572 = torch.ops.aten.permute.default(permute_352, [0, 2, 1]); permute_352 = None + _grouped_mm_97 = torch.ops.aten._grouped_mm.default(index_59, permute_572, cumsum_68); index_59 = permute_572 = None + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(_grouped_mm_66, torch.float32); _grouped_mm_66 = None + neg_45 = torch.ops.aten.neg.default(convert_element_type_1264) + exp_68 = torch.ops.aten.exp.default(neg_45); neg_45 = None + add_1528 = torch.ops.aten.add.Tensor(exp_68, 1); exp_68 = None + div_114 = torch.ops.aten.div.Tensor(convert_element_type_1264, add_1528) + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(div_114, torch.bfloat16); div_114 = None + mul_1399 = torch.ops.aten.mul.Tensor(_grouped_mm_97, convert_element_type_1265); convert_element_type_1265 = None + mul_1400 = torch.ops.aten.mul.Tensor(_grouped_mm_97, _grouped_mm_67); _grouped_mm_97 = _grouped_mm_67 = None + permute_574 = torch.ops.aten.permute.default(mul_1399, [1, 0]) + _grouped_mm_98 = torch.ops.aten._grouped_mm.default(permute_574, index_45, cumsum_68); permute_574 = None + convert_element_type_1261 = torch.ops.prims.convert_element_type.default(primals_387, torch.bfloat16); primals_387 = None + all_gather_into_tensor_396 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1261, 8, '513'); convert_element_type_1261 = None + wait_tensor_487 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_396); all_gather_into_tensor_396 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_487, [0, 2, 1]); wait_tensor_487 = None + permute_576 = torch.ops.aten.permute.default(permute_351, [0, 2, 1]); permute_351 = None + _grouped_mm_99 = torch.ops.aten._grouped_mm.default(mul_1399, permute_576, cumsum_68); mul_1399 = permute_576 = None + convert_element_type_1700 = torch.ops.prims.convert_element_type.default(mul_1400, torch.float32); mul_1400 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_1528); add_1528 = None + mul_1401 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_1402 = torch.ops.aten.mul.Tensor(convert_element_type_1700, mul_1401); convert_element_type_1700 = None + sub_644 = torch.ops.aten.sub.Tensor(1, mul_1401); mul_1401 = None + mul_1403 = torch.ops.aten.mul.Tensor(convert_element_type_1264, sub_644); convert_element_type_1264 = sub_644 = None + add_1824 = torch.ops.aten.add.Tensor(mul_1403, 1); mul_1403 = None + mul_1404 = torch.ops.aten.mul.Tensor(mul_1402, add_1824); mul_1402 = add_1824 = None + convert_element_type_1702 = torch.ops.prims.convert_element_type.default(mul_1404, torch.bfloat16); mul_1404 = None + permute_578 = torch.ops.aten.permute.default(convert_element_type_1702, [1, 0]) + _grouped_mm_100 = torch.ops.aten._grouped_mm.default(permute_578, index_45, cumsum_68); permute_578 = index_45 = None + convert_element_type_1258 = torch.ops.prims.convert_element_type.default(primals_385, torch.bfloat16); primals_385 = None + all_gather_into_tensor_393 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1258, 8, '513'); convert_element_type_1258 = None + wait_tensor_484 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_393); all_gather_into_tensor_393 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_484, [0, 2, 1]); wait_tensor_484 = None + permute_580 = torch.ops.aten.permute.default(permute_350, [0, 2, 1]); permute_350 = None + _grouped_mm_101 = torch.ops.aten._grouped_mm.default(convert_element_type_1702, permute_580, cumsum_68); convert_element_type_1702 = permute_580 = cumsum_68 = None + add_1825 = torch.ops.aten.add.Tensor(_grouped_mm_99, _grouped_mm_101); _grouped_mm_99 = _grouped_mm_101 = None + convert_element_type_1703 = torch.ops.prims.convert_element_type.default(_grouped_mm_98, torch.float32); _grouped_mm_98 = None + div_150 = torch.ops.aten.div.Tensor(convert_element_type_1703, 64); convert_element_type_1703 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_150, 'sum', 8, '513'); div_150 = None + wait_tensor_613 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + convert_element_type_1704 = torch.ops.prims.convert_element_type.default(_grouped_mm_96, torch.float32); _grouped_mm_96 = None + div_151 = torch.ops.aten.div.Tensor(convert_element_type_1704, 64); convert_element_type_1704 = None + reduce_scatter_tensor_48 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_151, 'sum', 8, '513'); div_151 = None + wait_tensor_614 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + convert_element_type_1705 = torch.ops.prims.convert_element_type.default(_grouped_mm_100, torch.float32); _grouped_mm_100 = None + div_152 = torch.ops.aten.div.Tensor(convert_element_type_1705, 64); convert_element_type_1705 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_152, 'sum', 8, '513'); div_152 = None + wait_tensor_615 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + index_put_58 = torch.ops.aten.index_put.default(full_360, [getitem_330], add_1825, True); full_360 = getitem_330 = add_1825 = None + slice_125 = torch.ops.aten.slice.Tensor(index_put_58, 0, 0, add_1826); index_put_58 = add_1826 = None + all_to_all_single_85 = torch.ops._c10d_functional.all_to_all_single.default(slice_125, [_local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359], [_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367], '521'); slice_125 = _local_scalar_dense_352 = _local_scalar_dense_353 = _local_scalar_dense_354 = _local_scalar_dense_355 = _local_scalar_dense_356 = _local_scalar_dense_357 = _local_scalar_dense_358 = _local_scalar_dense_359 = _local_scalar_dense_360 = _local_scalar_dense_361 = _local_scalar_dense_362 = _local_scalar_dense_363 = _local_scalar_dense_364 = _local_scalar_dense_365 = _local_scalar_dense_366 = _local_scalar_dense_367 = None + wait_tensor_616 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_85); all_to_all_single_85 = None + index_put_59 = torch.ops.aten.index_put.default(full_default_52, [div_112], wait_tensor_616, True); div_112 = wait_tensor_616 = None + add_1830 = torch.ops.aten.add.Tensor(add_1822, index_put_59); add_1822 = index_put_59 = None + mul_1405 = torch.ops.aten.mul.Tensor(view_1840, 1.0); view_1840 = None + scatter_add_3 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_327, mul_1405); getitem_327 = mul_1405 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_187, torch.float32); mm_187 = None + sub_528 = torch.ops.aten.sub.Tensor(convert_element_type_1253, amax_22); convert_element_type_1253 = amax_22 = None + exp_67 = torch.ops.aten.exp.default(sub_528); sub_528 = None + div_111 = torch.ops.aten.div.Tensor(exp_67, sum_89); exp_67 = sum_89 = None + mul_1406 = torch.ops.aten.mul.Tensor(scatter_add_3, div_111); scatter_add_3 = None + sum_131 = torch.ops.aten.sum.dim_IntList(mul_1406, [1], True) + neg_64 = torch.ops.aten.neg.default(div_111); div_111 = None + fma_3 = torch.ops.prims.fma.default(neg_64, sum_131, mul_1406); neg_64 = sum_131 = mul_1406 = None + convert_element_type_1706 = torch.ops.prims.convert_element_type.default(fma_3, torch.bfloat16); fma_3 = None + permute_582 = torch.ops.aten.permute.default(convert_element_type_1706, [1, 0]) + mm_272 = torch.ops.aten.mm.default(permute_582, view_1532); permute_582 = view_1532 = None + convert_element_type_1250 = torch.ops.prims.convert_element_type.default(primals_383, torch.bfloat16); primals_383 = None + all_gather_into_tensor_392 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1250, 64, '0'); convert_element_type_1250 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_392); all_gather_into_tensor_392 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_480, [1, 0]); wait_tensor_480 = None + permute_584 = torch.ops.aten.permute.default(permute_349, [1, 0]); permute_349 = None + mm_273 = torch.ops.aten.mm.default(convert_element_type_1706, permute_584); convert_element_type_1706 = permute_584 = None + add_1831 = torch.ops.aten.add.Tensor(add_1830, mm_273); add_1830 = mm_273 = None + convert_element_type_1711 = torch.ops.prims.convert_element_type.default(mm_272, torch.float32); mm_272 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1711, 'avg', 64, '0'); convert_element_type_1711 = None + wait_tensor_617 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + view_1842 = torch.ops.aten.view.default(add_1831, [2, 4096, 2048]); add_1831 = None + convert_element_type_1712 = torch.ops.prims.convert_element_type.default(view_1842, torch.float32); view_1842 = None + convert_element_type_1247 = torch.ops.prims.convert_element_type.default(primals_381, torch.bfloat16); primals_381 = None + all_gather_into_tensor_391 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1247, 64, '0'); convert_element_type_1247 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_391); all_gather_into_tensor_391 = None + convert_element_type_1714 = torch.ops.prims.convert_element_type.default(wait_tensor_479, torch.float32); wait_tensor_479 = None + mul_1407 = torch.ops.aten.mul.Tensor(convert_element_type_1712, convert_element_type_1714); convert_element_type_1714 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(add_1504, torch.float32); add_1504 = None + mul_1093 = torch.ops.aten.mul.Tensor(convert_element_type_1248, rsqrt_71); convert_element_type_1248 = None + mul_1409 = torch.ops.aten.mul.Tensor(mul_1093, mul_1407) + sum_132 = torch.ops.aten.sum.dim_IntList(mul_1409, [2], True); mul_1409 = None + div_153 = torch.ops.aten.div.Tensor(mul_1093, 2048) + mul_1410 = torch.ops.aten.mul.Tensor(div_153, sum_132); div_153 = sum_132 = None + sub_646 = torch.ops.aten.sub.Tensor(mul_1407, mul_1410); mul_1407 = mul_1410 = None + mul_1411 = torch.ops.aten.mul.Tensor(sub_646, rsqrt_71); sub_646 = rsqrt_71 = None + mul_1412 = torch.ops.aten.mul.Tensor(convert_element_type_1712, mul_1093); convert_element_type_1712 = mul_1093 = None + sum_133 = torch.ops.aten.sum.dim_IntList(mul_1412, [0, 1]); mul_1412 = None + convert_element_type_1715 = torch.ops.prims.convert_element_type.default(mul_1411, torch.bfloat16); mul_1411 = None + add_1832 = torch.ops.aten.add.Tensor(add_1819, convert_element_type_1715); add_1819 = convert_element_type_1715 = None + convert_element_type_default_72 = torch.ops.prims.convert_element_type.default(sum_133, torch.float32); sum_133 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_72, 'avg', 64, '0'); convert_element_type_default_72 = None + wait_tensor_618 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + view_1843 = torch.ops.aten.view.default(add_1832, [8192, 2048]) + permute_586 = torch.ops.aten.permute.default(view_1843, [1, 0]) + permute_347 = torch.ops.aten.permute.default(getitem_323, [0, 2, 1, 3]) + view_1527 = torch.ops.aten.view.default(permute_347, [2, 4096, -1]); permute_347 = None + view_1529 = torch.ops.aten.view.default(view_1527, [8192, 2048]); view_1527 = None + mm_274 = torch.ops.aten.mm.default(permute_586, view_1529); permute_586 = view_1529 = None + convert_element_type_1244 = torch.ops.prims.convert_element_type.default(primals_380, torch.bfloat16); primals_380 = None + all_gather_into_tensor_390 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1244, 64, '0'); convert_element_type_1244 = None + wait_tensor_478 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_390); all_gather_into_tensor_390 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_478, [1, 0]); wait_tensor_478 = None + permute_588 = torch.ops.aten.permute.default(permute_348, [1, 0]); permute_348 = None + mm_275 = torch.ops.aten.mm.default(view_1843, permute_588); view_1843 = permute_588 = None + view_1844 = torch.ops.aten.view.default(mm_275, [2, 4096, 2048]); mm_275 = None + convert_element_type_1722 = torch.ops.prims.convert_element_type.default(mm_274, torch.float32); mm_274 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1722, 'avg', 64, '0'); convert_element_type_1722 = None + wait_tensor_619 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + view_1845 = torch.ops.aten.view.default(view_1844, [2, 4096, 16, 128]); view_1844 = None + permute_590 = torch.ops.aten.permute.default(view_1845, [0, 2, 1, 3]); view_1845 = None + fw_graph3 = self.fw_graph3 + joint_graph3 = self.joint_graph3 + mask_graph3 = self.mask_graph3 + flex_attention_backward_3 = torch.ops.higher_order.flex_attention_backward(permute_344, permute_345, permute_346, getitem_323, getitem_324, permute_590, None, fw_graph3, joint_graph3, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph3), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_344 = permute_345 = permute_346 = getitem_323 = getitem_324 = permute_590 = fw_graph3 = joint_graph3 = mask_graph3 = None + getitem_385 = flex_attention_backward_3[0] + getitem_386 = flex_attention_backward_3[1] + getitem_387 = flex_attention_backward_3[2]; flex_attention_backward_3 = None + permute_591 = torch.ops.aten.permute.default(getitem_387, [0, 2, 1, 3]); getitem_387 = None + permute_592 = torch.ops.aten.permute.default(getitem_386, [0, 2, 1, 3]); getitem_386 = None + permute_593 = torch.ops.aten.permute.default(getitem_385, [0, 2, 1, 3]); getitem_385 = None + slice_127 = torch.ops.aten.slice.Tensor(permute_592, 3, 0, 128) + slice_128 = torch.ops.aten.slice.Tensor(permute_592, 3, 128, 192); permute_592 = None + sum_134 = torch.ops.aten.sum.dim_IntList(slice_128, [2], True); slice_128 = None + cat_89 = torch.ops.aten.cat.default([slice_127, permute_591], 3); slice_127 = permute_591 = None + view_1846 = torch.ops.aten.view.default(cat_89, [2, 4096, 4096]); cat_89 = None + view_1847 = torch.ops.aten.view.default(view_1846, [8192, 4096]); view_1846 = None + permute_594 = torch.ops.aten.permute.default(view_1847, [1, 0]) + mm_276 = torch.ops.aten.mm.default(permute_594, view_1524); permute_594 = view_1524 = None + convert_element_type_1241 = torch.ops.prims.convert_element_type.default(primals_379, torch.bfloat16); primals_379 = None + all_gather_into_tensor_389 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1241, 64, '0'); convert_element_type_1241 = None + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_389); all_gather_into_tensor_389 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_477, [1, 0]); wait_tensor_477 = None + permute_596 = torch.ops.aten.permute.default(permute_343, [1, 0]); permute_343 = None + mm_277 = torch.ops.aten.mm.default(view_1847, permute_596); view_1847 = permute_596 = None + view_1848 = torch.ops.aten.view.default(mm_277, [2, 4096, 512]); mm_277 = None + convert_element_type_1727 = torch.ops.prims.convert_element_type.default(mm_276, torch.float32); mm_276 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1727, 'avg', 64, '0'); convert_element_type_1727 = None + wait_tensor_620 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + convert_element_type_1728 = torch.ops.prims.convert_element_type.default(view_1848, torch.float32); view_1848 = None + convert_element_type_1238 = torch.ops.prims.convert_element_type.default(primals_378, torch.bfloat16); primals_378 = None + all_gather_into_tensor_388 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1238, 64, '0'); convert_element_type_1238 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_388); all_gather_into_tensor_388 = None + convert_element_type_1730 = torch.ops.prims.convert_element_type.default(wait_tensor_476, torch.float32); wait_tensor_476 = None + mul_1413 = torch.ops.aten.mul.Tensor(convert_element_type_1728, convert_element_type_1730); convert_element_type_1730 = None + convert_element_type_1239 = torch.ops.prims.convert_element_type.default(getitem_319, torch.float32); getitem_319 = None + mul_1091 = torch.ops.aten.mul.Tensor(convert_element_type_1239, rsqrt_70); convert_element_type_1239 = None + mul_1415 = torch.ops.aten.mul.Tensor(mul_1091, mul_1413) + sum_135 = torch.ops.aten.sum.dim_IntList(mul_1415, [2], True); mul_1415 = None + div_154 = torch.ops.aten.div.Tensor(mul_1091, 512) + mul_1416 = torch.ops.aten.mul.Tensor(div_154, sum_135); div_154 = sum_135 = None + sub_647 = torch.ops.aten.sub.Tensor(mul_1413, mul_1416); mul_1413 = mul_1416 = None + mul_1417 = torch.ops.aten.mul.Tensor(sub_647, rsqrt_70); sub_647 = rsqrt_70 = None + mul_1418 = torch.ops.aten.mul.Tensor(convert_element_type_1728, mul_1091); convert_element_type_1728 = mul_1091 = None + sum_136 = torch.ops.aten.sum.dim_IntList(mul_1418, [0, 1]); mul_1418 = None + convert_element_type_1731 = torch.ops.prims.convert_element_type.default(mul_1417, torch.bfloat16); mul_1417 = None + convert_element_type_default_71 = torch.ops.prims.convert_element_type.default(sum_136, torch.float32); sum_136 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_71, 'avg', 64, '0'); convert_element_type_default_71 = None + wait_tensor_621 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + convert_element_type_1734 = torch.ops.prims.convert_element_type.default(sum_134, torch.float32); sum_134 = None + view_1849 = torch.ops.aten.view.default(convert_element_type_1734, [2, 4096, 1, 32, 2]); convert_element_type_1734 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1849); view_1849 = None + mul_1419 = torch.ops.aten.mul.Tensor(view_as_complex_60, clone_9); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_1419); mul_1419 = None + view_1850 = torch.ops.aten.view.default(view_as_real_60, [2, 4096, 1, 64]); view_as_real_60 = None + convert_element_type_1735 = torch.ops.prims.convert_element_type.default(view_1850, torch.bfloat16); view_1850 = None + squeeze_29 = torch.ops.aten.squeeze.dim(convert_element_type_1735, 2); convert_element_type_1735 = None + cat_90 = torch.ops.aten.cat.default([convert_element_type_1731, squeeze_29], 2); convert_element_type_1731 = squeeze_29 = None + view_1851 = torch.ops.aten.view.default(cat_90, [8192, 576]); cat_90 = None + permute_598 = torch.ops.aten.permute.default(view_1851, [1, 0]) + mm_278 = torch.ops.aten.mm.default(permute_598, view_1510); permute_598 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(primals_377, torch.bfloat16); primals_377 = None + all_gather_into_tensor_387 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1233, 64, '0'); convert_element_type_1233 = None + wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_387); all_gather_into_tensor_387 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_475, [1, 0]); wait_tensor_475 = None + permute_600 = torch.ops.aten.permute.default(permute_342, [1, 0]); permute_342 = None + mm_279 = torch.ops.aten.mm.default(view_1851, permute_600); view_1851 = permute_600 = None + view_1852 = torch.ops.aten.view.default(mm_279, [2, 4096, 2048]); mm_279 = None + convert_element_type_1740 = torch.ops.prims.convert_element_type.default(mm_278, torch.float32); mm_278 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1740, 'avg', 64, '0'); convert_element_type_1740 = None + wait_tensor_622 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); reduce_scatter_tensor_55 = None + slice_129 = torch.ops.aten.slice.Tensor(permute_593, 3, 0, 128) + slice_130 = torch.ops.aten.slice.Tensor(permute_593, 3, 128, 192); permute_593 = None + convert_element_type_1741 = torch.ops.prims.convert_element_type.default(slice_130, torch.float32); slice_130 = None + view_1853 = torch.ops.aten.view.default(convert_element_type_1741, [2, 4096, 16, 32, 2]); convert_element_type_1741 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1853); view_1853 = None + mul_1420 = torch.ops.aten.mul.Tensor(view_as_complex_61, clone_9); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_1420); mul_1420 = None + view_1854 = torch.ops.aten.view.default(view_as_real_61, [2, 4096, 16, 64]); view_as_real_61 = None + convert_element_type_1742 = torch.ops.prims.convert_element_type.default(view_1854, torch.bfloat16); view_1854 = None + cat_91 = torch.ops.aten.cat.default([slice_129, convert_element_type_1742], 3); slice_129 = convert_element_type_1742 = None + view_1855 = torch.ops.aten.view.default(cat_91, [2, 4096, 3072]); cat_91 = None + view_1856 = torch.ops.aten.view.default(view_1855, [8192, 3072]); view_1855 = None + permute_602 = torch.ops.aten.permute.default(view_1856, [1, 0]) + mm_280 = torch.ops.aten.mm.default(permute_602, view_1510); permute_602 = view_1510 = None + convert_element_type_1228 = torch.ops.prims.convert_element_type.default(primals_376, torch.bfloat16); primals_376 = None + all_gather_into_tensor_386 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1228, 64, '0'); convert_element_type_1228 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_386); all_gather_into_tensor_386 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_474, [1, 0]); wait_tensor_474 = None + permute_604 = torch.ops.aten.permute.default(permute_341, [1, 0]); permute_341 = None + mm_281 = torch.ops.aten.mm.default(view_1856, permute_604); view_1856 = permute_604 = None + view_1857 = torch.ops.aten.view.default(mm_281, [2, 4096, 2048]); mm_281 = None + add_1833 = torch.ops.aten.add.Tensor(view_1852, view_1857); view_1852 = view_1857 = None + convert_element_type_1747 = torch.ops.prims.convert_element_type.default(mm_280, torch.float32); mm_280 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1747, 'avg', 64, '0'); convert_element_type_1747 = None + wait_tensor_623 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + convert_element_type_1748 = torch.ops.prims.convert_element_type.default(add_1833, torch.float32); add_1833 = None + convert_element_type_1225 = torch.ops.prims.convert_element_type.default(primals_375, torch.bfloat16); primals_375 = None + all_gather_into_tensor_385 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1225, 64, '0'); convert_element_type_1225 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_385); all_gather_into_tensor_385 = None + convert_element_type_1750 = torch.ops.prims.convert_element_type.default(wait_tensor_473, torch.float32); wait_tensor_473 = None + mul_1421 = torch.ops.aten.mul.Tensor(convert_element_type_1748, convert_element_type_1750); convert_element_type_1750 = None + convert_element_type_1226 = torch.ops.prims.convert_element_type.default(add_1501, torch.float32); add_1501 = None + mul_1087 = torch.ops.aten.mul.Tensor(convert_element_type_1226, rsqrt_69); convert_element_type_1226 = None + mul_1423 = torch.ops.aten.mul.Tensor(mul_1087, mul_1421) + sum_137 = torch.ops.aten.sum.dim_IntList(mul_1423, [2], True); mul_1423 = None + div_155 = torch.ops.aten.div.Tensor(mul_1087, 2048) + mul_1424 = torch.ops.aten.mul.Tensor(div_155, sum_137); div_155 = sum_137 = None + sub_648 = torch.ops.aten.sub.Tensor(mul_1421, mul_1424); mul_1421 = mul_1424 = None + mul_1425 = torch.ops.aten.mul.Tensor(sub_648, rsqrt_69); sub_648 = rsqrt_69 = None + mul_1426 = torch.ops.aten.mul.Tensor(convert_element_type_1748, mul_1087); convert_element_type_1748 = mul_1087 = None + sum_138 = torch.ops.aten.sum.dim_IntList(mul_1426, [0, 1]); mul_1426 = None + convert_element_type_1751 = torch.ops.prims.convert_element_type.default(mul_1425, torch.bfloat16); mul_1425 = None + add_1834 = torch.ops.aten.add.Tensor(add_1832, convert_element_type_1751); add_1832 = convert_element_type_1751 = None + convert_element_type_default_70 = torch.ops.prims.convert_element_type.default(sum_138, torch.float32); sum_138 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_70, 'avg', 64, '0'); convert_element_type_default_70 = None + wait_tensor_624 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + view_1858 = torch.ops.aten.view.default(add_1834, [8192, 2048]) + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_1858, 1) + convert_element_type_1754 = torch.ops.prims.convert_element_type.default(unsqueeze_57, torch.float32); unsqueeze_57 = None + bmm_34 = torch.ops.aten.bmm.default(permute_606, convert_element_type_1754); permute_606 = None + bmm_35 = torch.ops.aten.bmm.default(convert_element_type_1754, permute_607); convert_element_type_1754 = permute_607 = None + convert_element_type_1755 = torch.ops.prims.convert_element_type.default(bmm_34, torch.bfloat16); bmm_34 = None + view_1859 = torch.ops.aten.view.default(bmm_35, [8192, 6]); bmm_35 = None + view_1860 = torch.ops.aten.view.default(convert_element_type_1755, [49152, 2048]); convert_element_type_1755 = None + index_60 = torch.ops.aten.index.Tensor(view_1860, [getitem_315]); view_1860 = getitem_315 = None + permute_608 = torch.ops.aten.permute.default(view_1858, [1, 0]) + mm_282 = torch.ops.aten.mm.default(permute_608, mul_1084); permute_608 = mul_1084 = None + convert_element_type_1220 = torch.ops.prims.convert_element_type.default(primals_374, torch.bfloat16); primals_374 = None + all_gather_into_tensor_384 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1220, 64, '0'); convert_element_type_1220 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_384); all_gather_into_tensor_384 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_472, [1, 0]); wait_tensor_472 = None + permute_610 = torch.ops.aten.permute.default(permute_340, [1, 0]); permute_340 = None + mm_283 = torch.ops.aten.mm.default(view_1858, permute_610); view_1858 = permute_610 = None + convert_element_type_1760 = torch.ops.prims.convert_element_type.default(mm_282, torch.float32); mm_282 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1760, 'avg', 64, '0'); convert_element_type_1760 = None + wait_tensor_625 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + convert_element_type_1215 = torch.ops.prims.convert_element_type.default(mm_180, torch.float32); mm_180 = None + neg_44 = torch.ops.aten.neg.default(convert_element_type_1215) + exp_66 = torch.ops.aten.exp.default(neg_44); neg_44 = None + add_1496 = torch.ops.aten.add.Tensor(exp_66, 1); exp_66 = None + div_110 = torch.ops.aten.div.Tensor(convert_element_type_1215, add_1496) + convert_element_type_1216 = torch.ops.prims.convert_element_type.default(div_110, torch.bfloat16); div_110 = None + mul_1427 = torch.ops.aten.mul.Tensor(mm_283, convert_element_type_1216); convert_element_type_1216 = None + mul_1428 = torch.ops.aten.mul.Tensor(mm_283, mm_181); mm_283 = mm_181 = None + permute_612 = torch.ops.aten.permute.default(mul_1427, [1, 0]) + mm_284 = torch.ops.aten.mm.default(permute_612, view_1465); permute_612 = None + convert_element_type_1217 = torch.ops.prims.convert_element_type.default(primals_373, torch.bfloat16); primals_373 = None + all_gather_into_tensor_383 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1217, 64, '0'); convert_element_type_1217 = None + wait_tensor_471 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_383); all_gather_into_tensor_383 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_471, [1, 0]); wait_tensor_471 = None + permute_614 = torch.ops.aten.permute.default(permute_339, [1, 0]); permute_339 = None + mm_285 = torch.ops.aten.mm.default(mul_1427, permute_614); mul_1427 = permute_614 = None + convert_element_type_1765 = torch.ops.prims.convert_element_type.default(mm_284, torch.float32); mm_284 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1765, 'avg', 64, '0'); convert_element_type_1765 = None + wait_tensor_626 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + convert_element_type_1766 = torch.ops.prims.convert_element_type.default(mul_1428, torch.float32); mul_1428 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_1496); add_1496 = None + mul_1429 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_1430 = torch.ops.aten.mul.Tensor(convert_element_type_1766, mul_1429); convert_element_type_1766 = None + sub_649 = torch.ops.aten.sub.Tensor(1, mul_1429); mul_1429 = None + mul_1431 = torch.ops.aten.mul.Tensor(convert_element_type_1215, sub_649); convert_element_type_1215 = sub_649 = None + add_1836 = torch.ops.aten.add.Tensor(mul_1431, 1); mul_1431 = None + mul_1432 = torch.ops.aten.mul.Tensor(mul_1430, add_1836); mul_1430 = add_1836 = None + convert_element_type_1768 = torch.ops.prims.convert_element_type.default(mul_1432, torch.bfloat16); mul_1432 = None + permute_616 = torch.ops.aten.permute.default(convert_element_type_1768, [1, 0]) + mm_286 = torch.ops.aten.mm.default(permute_616, view_1465); permute_616 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(primals_372, torch.bfloat16); primals_372 = None + all_gather_into_tensor_382 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1212, 64, '0'); convert_element_type_1212 = None + wait_tensor_470 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_382); all_gather_into_tensor_382 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_470, [1, 0]); wait_tensor_470 = None + permute_618 = torch.ops.aten.permute.default(permute_338, [1, 0]); permute_338 = None + mm_287 = torch.ops.aten.mm.default(convert_element_type_1768, permute_618); convert_element_type_1768 = permute_618 = None + add_1837 = torch.ops.aten.add.Tensor(mm_285, mm_287); mm_285 = mm_287 = None + convert_element_type_1773 = torch.ops.prims.convert_element_type.default(mm_286, torch.float32); mm_286 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1773, 'avg', 64, '0'); convert_element_type_1773 = None + wait_tensor_627 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + all_to_all_single_86 = torch.ops._c10d_functional.all_to_all_single.default(index_60, [_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351], [_local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343], '521'); index_60 = None + wait_tensor_628 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_86); all_to_all_single_86 = None + full_364 = torch.ops.aten.full.default([sym_size_int_85, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_85 = None + slice_scatter_4 = torch.ops.aten.slice_scatter.default(full_364, wait_tensor_628, 0, 0, -1); wait_tensor_628 = None + index_61 = torch.ops.aten.index.Tensor(slice_scatter_4, [getitem_316]); slice_scatter_4 = None + permute_620 = torch.ops.aten.permute.default(index_61, [1, 0]) + _grouped_mm_102 = torch.ops.aten._grouped_mm.default(permute_620, mul_1064, cumsum_65); permute_620 = mul_1064 = None + convert_element_type_1206 = torch.ops.prims.convert_element_type.default(primals_370, torch.bfloat16); primals_370 = None + all_gather_into_tensor_378 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1206, 8, '513'); convert_element_type_1206 = None + wait_tensor_465 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_378); all_gather_into_tensor_378 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_465, [0, 2, 1]); wait_tensor_465 = None + permute_622 = torch.ops.aten.permute.default(permute_337, [0, 2, 1]); permute_337 = None + _grouped_mm_103 = torch.ops.aten._grouped_mm.default(index_61, permute_622, cumsum_65); index_61 = permute_622 = None + convert_element_type_1210 = torch.ops.prims.convert_element_type.default(_grouped_mm_63, torch.float32); _grouped_mm_63 = None + neg_43 = torch.ops.aten.neg.default(convert_element_type_1210) + exp_65 = torch.ops.aten.exp.default(neg_43); neg_43 = None + add_1460 = torch.ops.aten.add.Tensor(exp_65, 1); exp_65 = None + div_109 = torch.ops.aten.div.Tensor(convert_element_type_1210, add_1460) + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(div_109, torch.bfloat16); div_109 = None + mul_1433 = torch.ops.aten.mul.Tensor(_grouped_mm_103, convert_element_type_1211); convert_element_type_1211 = None + mul_1434 = torch.ops.aten.mul.Tensor(_grouped_mm_103, _grouped_mm_64); _grouped_mm_103 = _grouped_mm_64 = None + permute_624 = torch.ops.aten.permute.default(mul_1433, [1, 0]) + _grouped_mm_104 = torch.ops.aten._grouped_mm.default(permute_624, index_43, cumsum_65); permute_624 = None + convert_element_type_1207 = torch.ops.prims.convert_element_type.default(primals_371, torch.bfloat16); primals_371 = None + all_gather_into_tensor_379 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1207, 8, '513'); convert_element_type_1207 = None + wait_tensor_466 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_379); all_gather_into_tensor_379 = None + permute_336 = torch.ops.aten.permute.default(wait_tensor_466, [0, 2, 1]); wait_tensor_466 = None + permute_626 = torch.ops.aten.permute.default(permute_336, [0, 2, 1]); permute_336 = None + _grouped_mm_105 = torch.ops.aten._grouped_mm.default(mul_1433, permute_626, cumsum_65); mul_1433 = permute_626 = None + convert_element_type_1774 = torch.ops.prims.convert_element_type.default(mul_1434, torch.float32); mul_1434 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_1460); add_1460 = None + mul_1435 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_1436 = torch.ops.aten.mul.Tensor(convert_element_type_1774, mul_1435); convert_element_type_1774 = None + sub_650 = torch.ops.aten.sub.Tensor(1, mul_1435); mul_1435 = None + mul_1437 = torch.ops.aten.mul.Tensor(convert_element_type_1210, sub_650); convert_element_type_1210 = sub_650 = None + add_1839 = torch.ops.aten.add.Tensor(mul_1437, 1); mul_1437 = None + mul_1438 = torch.ops.aten.mul.Tensor(mul_1436, add_1839); mul_1436 = add_1839 = None + convert_element_type_1776 = torch.ops.prims.convert_element_type.default(mul_1438, torch.bfloat16); mul_1438 = None + permute_628 = torch.ops.aten.permute.default(convert_element_type_1776, [1, 0]) + _grouped_mm_106 = torch.ops.aten._grouped_mm.default(permute_628, index_43, cumsum_65); permute_628 = index_43 = None + convert_element_type_1204 = torch.ops.prims.convert_element_type.default(primals_369, torch.bfloat16); primals_369 = None + all_gather_into_tensor_376 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1204, 8, '513'); convert_element_type_1204 = None + wait_tensor_463 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_376); all_gather_into_tensor_376 = None + permute_335 = torch.ops.aten.permute.default(wait_tensor_463, [0, 2, 1]); wait_tensor_463 = None + permute_630 = torch.ops.aten.permute.default(permute_335, [0, 2, 1]); permute_335 = None + _grouped_mm_107 = torch.ops.aten._grouped_mm.default(convert_element_type_1776, permute_630, cumsum_65); convert_element_type_1776 = permute_630 = cumsum_65 = None + add_1840 = torch.ops.aten.add.Tensor(_grouped_mm_105, _grouped_mm_107); _grouped_mm_105 = _grouped_mm_107 = None + convert_element_type_1777 = torch.ops.prims.convert_element_type.default(_grouped_mm_104, torch.float32); _grouped_mm_104 = None + div_156 = torch.ops.aten.div.Tensor(convert_element_type_1777, 64); convert_element_type_1777 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_156, 'sum', 8, '513'); div_156 = None + wait_tensor_629 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + convert_element_type_1778 = torch.ops.prims.convert_element_type.default(_grouped_mm_102, torch.float32); _grouped_mm_102 = None + div_157 = torch.ops.aten.div.Tensor(convert_element_type_1778, 64); convert_element_type_1778 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_157, 'sum', 8, '513'); div_157 = None + wait_tensor_630 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + convert_element_type_1779 = torch.ops.prims.convert_element_type.default(_grouped_mm_106, torch.float32); _grouped_mm_106 = None + div_158 = torch.ops.aten.div.Tensor(convert_element_type_1779, 64); convert_element_type_1779 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_158, 'sum', 8, '513'); div_158 = None + wait_tensor_631 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + index_put_60 = torch.ops.aten.index_put.default(full_364, [getitem_316], add_1840, True); full_364 = getitem_316 = add_1840 = None + slice_131 = torch.ops.aten.slice.Tensor(index_put_60, 0, 0, add_1841); index_put_60 = add_1841 = None + all_to_all_single_87 = torch.ops._c10d_functional.all_to_all_single.default(slice_131, [_local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343], [_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351], '521'); slice_131 = _local_scalar_dense_336 = _local_scalar_dense_337 = _local_scalar_dense_338 = _local_scalar_dense_339 = _local_scalar_dense_340 = _local_scalar_dense_341 = _local_scalar_dense_342 = _local_scalar_dense_343 = _local_scalar_dense_344 = _local_scalar_dense_345 = _local_scalar_dense_346 = _local_scalar_dense_347 = _local_scalar_dense_348 = _local_scalar_dense_349 = _local_scalar_dense_350 = _local_scalar_dense_351 = None + wait_tensor_632 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_87); all_to_all_single_87 = None + index_put_61 = torch.ops.aten.index_put.default(full_default_52, [div_107], wait_tensor_632, True); div_107 = wait_tensor_632 = None + add_1845 = torch.ops.aten.add.Tensor(add_1837, index_put_61); add_1837 = index_put_61 = None + mul_1439 = torch.ops.aten.mul.Tensor(view_1859, 1.0); view_1859 = None + scatter_add_4 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_313, mul_1439); getitem_313 = mul_1439 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_179, torch.float32); mm_179 = None + sub_504 = torch.ops.aten.sub.Tensor(convert_element_type_1199, amax_21); convert_element_type_1199 = amax_21 = None + exp_64 = torch.ops.aten.exp.default(sub_504); sub_504 = None + div_106 = torch.ops.aten.div.Tensor(exp_64, sum_85); exp_64 = sum_85 = None + mul_1440 = torch.ops.aten.mul.Tensor(scatter_add_4, div_106); scatter_add_4 = None + sum_139 = torch.ops.aten.sum.dim_IntList(mul_1440, [1], True) + neg_67 = torch.ops.aten.neg.default(div_106); div_106 = None + fma_4 = torch.ops.prims.fma.default(neg_67, sum_139, mul_1440); neg_67 = sum_139 = mul_1440 = None + convert_element_type_1780 = torch.ops.prims.convert_element_type.default(fma_4, torch.bfloat16); fma_4 = None + permute_632 = torch.ops.aten.permute.default(convert_element_type_1780, [1, 0]) + mm_288 = torch.ops.aten.mm.default(permute_632, view_1465); permute_632 = view_1465 = None + convert_element_type_1196 = torch.ops.prims.convert_element_type.default(primals_367, torch.bfloat16); primals_367 = None + all_gather_into_tensor_375 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1196, 64, '0'); convert_element_type_1196 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_375); all_gather_into_tensor_375 = None + permute_334 = torch.ops.aten.permute.default(wait_tensor_459, [1, 0]); wait_tensor_459 = None + permute_634 = torch.ops.aten.permute.default(permute_334, [1, 0]); permute_334 = None + mm_289 = torch.ops.aten.mm.default(convert_element_type_1780, permute_634); convert_element_type_1780 = permute_634 = None + add_1846 = torch.ops.aten.add.Tensor(add_1845, mm_289); add_1845 = mm_289 = None + convert_element_type_1785 = torch.ops.prims.convert_element_type.default(mm_288, torch.float32); mm_288 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1785, 'avg', 64, '0'); convert_element_type_1785 = None + wait_tensor_633 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + view_1861 = torch.ops.aten.view.default(add_1846, [2, 4096, 2048]); add_1846 = None + convert_element_type_1786 = torch.ops.prims.convert_element_type.default(view_1861, torch.float32); view_1861 = None + convert_element_type_1193 = torch.ops.prims.convert_element_type.default(primals_365, torch.bfloat16); primals_365 = None + all_gather_into_tensor_374 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1193, 64, '0'); convert_element_type_1193 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_374); all_gather_into_tensor_374 = None + convert_element_type_1788 = torch.ops.prims.convert_element_type.default(wait_tensor_458, torch.float32); wait_tensor_458 = None + mul_1441 = torch.ops.aten.mul.Tensor(convert_element_type_1786, convert_element_type_1788); convert_element_type_1788 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(add_1436, torch.float32); add_1436 = None + mul_1044 = torch.ops.aten.mul.Tensor(convert_element_type_1194, rsqrt_68); convert_element_type_1194 = None + mul_1443 = torch.ops.aten.mul.Tensor(mul_1044, mul_1441) + sum_140 = torch.ops.aten.sum.dim_IntList(mul_1443, [2], True); mul_1443 = None + div_159 = torch.ops.aten.div.Tensor(mul_1044, 2048) + mul_1444 = torch.ops.aten.mul.Tensor(div_159, sum_140); div_159 = sum_140 = None + sub_652 = torch.ops.aten.sub.Tensor(mul_1441, mul_1444); mul_1441 = mul_1444 = None + mul_1445 = torch.ops.aten.mul.Tensor(sub_652, rsqrt_68); sub_652 = rsqrt_68 = None + mul_1446 = torch.ops.aten.mul.Tensor(convert_element_type_1786, mul_1044); convert_element_type_1786 = mul_1044 = None + sum_141 = torch.ops.aten.sum.dim_IntList(mul_1446, [0, 1]); mul_1446 = None + convert_element_type_1789 = torch.ops.prims.convert_element_type.default(mul_1445, torch.bfloat16); mul_1445 = None + add_1847 = torch.ops.aten.add.Tensor(add_1834, convert_element_type_1789); add_1834 = convert_element_type_1789 = None + convert_element_type_default_69 = torch.ops.prims.convert_element_type.default(sum_141, torch.float32); sum_141 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_69, 'avg', 64, '0'); convert_element_type_default_69 = None + wait_tensor_634 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); reduce_scatter_tensor_65 = None + view_1862 = torch.ops.aten.view.default(add_1847, [8192, 2048]) + permute_636 = torch.ops.aten.permute.default(view_1862, [1, 0]) + permute_332 = torch.ops.aten.permute.default(getitem_309, [0, 2, 1, 3]) + view_1460 = torch.ops.aten.view.default(permute_332, [2, 4096, -1]); permute_332 = None + view_1462 = torch.ops.aten.view.default(view_1460, [8192, 2048]); view_1460 = None + mm_290 = torch.ops.aten.mm.default(permute_636, view_1462); permute_636 = view_1462 = None + convert_element_type_1190 = torch.ops.prims.convert_element_type.default(primals_364, torch.bfloat16); primals_364 = None + all_gather_into_tensor_373 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1190, 64, '0'); convert_element_type_1190 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_373); all_gather_into_tensor_373 = None + permute_333 = torch.ops.aten.permute.default(wait_tensor_457, [1, 0]); wait_tensor_457 = None + permute_638 = torch.ops.aten.permute.default(permute_333, [1, 0]); permute_333 = None + mm_291 = torch.ops.aten.mm.default(view_1862, permute_638); view_1862 = permute_638 = None + view_1863 = torch.ops.aten.view.default(mm_291, [2, 4096, 2048]); mm_291 = None + convert_element_type_1796 = torch.ops.prims.convert_element_type.default(mm_290, torch.float32); mm_290 = None + reduce_scatter_tensor_66 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1796, 'avg', 64, '0'); convert_element_type_1796 = None + wait_tensor_635 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + view_1864 = torch.ops.aten.view.default(view_1863, [2, 4096, 16, 128]); view_1863 = None + permute_640 = torch.ops.aten.permute.default(view_1864, [0, 2, 1, 3]); view_1864 = None + fw_graph4 = self.fw_graph4 + joint_graph4 = self.joint_graph4 + mask_graph4 = self.mask_graph4 + flex_attention_backward_4 = torch.ops.higher_order.flex_attention_backward(permute_329, permute_330, permute_331, getitem_309, getitem_310, permute_640, None, fw_graph4, joint_graph4, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph4), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_329 = permute_330 = permute_331 = getitem_309 = getitem_310 = permute_640 = fw_graph4 = joint_graph4 = mask_graph4 = None + getitem_389 = flex_attention_backward_4[0] + getitem_390 = flex_attention_backward_4[1] + getitem_391 = flex_attention_backward_4[2]; flex_attention_backward_4 = None + permute_641 = torch.ops.aten.permute.default(getitem_391, [0, 2, 1, 3]); getitem_391 = None + permute_642 = torch.ops.aten.permute.default(getitem_390, [0, 2, 1, 3]); getitem_390 = None + permute_643 = torch.ops.aten.permute.default(getitem_389, [0, 2, 1, 3]); getitem_389 = None + slice_133 = torch.ops.aten.slice.Tensor(permute_642, 3, 0, 128) + slice_134 = torch.ops.aten.slice.Tensor(permute_642, 3, 128, 192); permute_642 = None + sum_142 = torch.ops.aten.sum.dim_IntList(slice_134, [2], True); slice_134 = None + cat_92 = torch.ops.aten.cat.default([slice_133, permute_641], 3); slice_133 = permute_641 = None + view_1865 = torch.ops.aten.view.default(cat_92, [2, 4096, 4096]); cat_92 = None + view_1866 = torch.ops.aten.view.default(view_1865, [8192, 4096]); view_1865 = None + permute_644 = torch.ops.aten.permute.default(view_1866, [1, 0]) + mm_292 = torch.ops.aten.mm.default(permute_644, view_1457); permute_644 = view_1457 = None + convert_element_type_1187 = torch.ops.prims.convert_element_type.default(primals_363, torch.bfloat16); primals_363 = None + all_gather_into_tensor_372 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1187, 64, '0'); convert_element_type_1187 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_372); all_gather_into_tensor_372 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_456, [1, 0]); wait_tensor_456 = None + permute_646 = torch.ops.aten.permute.default(permute_328, [1, 0]); permute_328 = None + mm_293 = torch.ops.aten.mm.default(view_1866, permute_646); view_1866 = permute_646 = None + view_1867 = torch.ops.aten.view.default(mm_293, [2, 4096, 512]); mm_293 = None + convert_element_type_1801 = torch.ops.prims.convert_element_type.default(mm_292, torch.float32); mm_292 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1801, 'avg', 64, '0'); convert_element_type_1801 = None + wait_tensor_636 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + convert_element_type_1802 = torch.ops.prims.convert_element_type.default(view_1867, torch.float32); view_1867 = None + convert_element_type_1184 = torch.ops.prims.convert_element_type.default(primals_362, torch.bfloat16); primals_362 = None + all_gather_into_tensor_371 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1184, 64, '0'); convert_element_type_1184 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_371); all_gather_into_tensor_371 = None + convert_element_type_1804 = torch.ops.prims.convert_element_type.default(wait_tensor_455, torch.float32); wait_tensor_455 = None + mul_1447 = torch.ops.aten.mul.Tensor(convert_element_type_1802, convert_element_type_1804); convert_element_type_1804 = None + convert_element_type_1185 = torch.ops.prims.convert_element_type.default(getitem_305, torch.float32); getitem_305 = None + mul_1042 = torch.ops.aten.mul.Tensor(convert_element_type_1185, rsqrt_67); convert_element_type_1185 = None + mul_1449 = torch.ops.aten.mul.Tensor(mul_1042, mul_1447) + sum_143 = torch.ops.aten.sum.dim_IntList(mul_1449, [2], True); mul_1449 = None + div_160 = torch.ops.aten.div.Tensor(mul_1042, 512) + mul_1450 = torch.ops.aten.mul.Tensor(div_160, sum_143); div_160 = sum_143 = None + sub_653 = torch.ops.aten.sub.Tensor(mul_1447, mul_1450); mul_1447 = mul_1450 = None + mul_1451 = torch.ops.aten.mul.Tensor(sub_653, rsqrt_67); sub_653 = rsqrt_67 = None + mul_1452 = torch.ops.aten.mul.Tensor(convert_element_type_1802, mul_1042); convert_element_type_1802 = mul_1042 = None + sum_144 = torch.ops.aten.sum.dim_IntList(mul_1452, [0, 1]); mul_1452 = None + convert_element_type_1805 = torch.ops.prims.convert_element_type.default(mul_1451, torch.bfloat16); mul_1451 = None + convert_element_type_default_68 = torch.ops.prims.convert_element_type.default(sum_144, torch.float32); sum_144 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_68, 'avg', 64, '0'); convert_element_type_default_68 = None + wait_tensor_637 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + convert_element_type_1808 = torch.ops.prims.convert_element_type.default(sum_142, torch.float32); sum_142 = None + view_1868 = torch.ops.aten.view.default(convert_element_type_1808, [2, 4096, 1, 32, 2]); convert_element_type_1808 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1868); view_1868 = None + mul_1453 = torch.ops.aten.mul.Tensor(view_as_complex_62, clone_9); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_1453); mul_1453 = None + view_1869 = torch.ops.aten.view.default(view_as_real_62, [2, 4096, 1, 64]); view_as_real_62 = None + convert_element_type_1809 = torch.ops.prims.convert_element_type.default(view_1869, torch.bfloat16); view_1869 = None + squeeze_30 = torch.ops.aten.squeeze.dim(convert_element_type_1809, 2); convert_element_type_1809 = None + cat_93 = torch.ops.aten.cat.default([convert_element_type_1805, squeeze_30], 2); convert_element_type_1805 = squeeze_30 = None + view_1870 = torch.ops.aten.view.default(cat_93, [8192, 576]); cat_93 = None + permute_648 = torch.ops.aten.permute.default(view_1870, [1, 0]) + mm_294 = torch.ops.aten.mm.default(permute_648, view_1443); permute_648 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(primals_361, torch.bfloat16); primals_361 = None + all_gather_into_tensor_370 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1179, 64, '0'); convert_element_type_1179 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_370); all_gather_into_tensor_370 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_454, [1, 0]); wait_tensor_454 = None + permute_650 = torch.ops.aten.permute.default(permute_327, [1, 0]); permute_327 = None + mm_295 = torch.ops.aten.mm.default(view_1870, permute_650); view_1870 = permute_650 = None + view_1871 = torch.ops.aten.view.default(mm_295, [2, 4096, 2048]); mm_295 = None + convert_element_type_1814 = torch.ops.prims.convert_element_type.default(mm_294, torch.float32); mm_294 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1814, 'avg', 64, '0'); convert_element_type_1814 = None + wait_tensor_638 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + slice_135 = torch.ops.aten.slice.Tensor(permute_643, 3, 0, 128) + slice_136 = torch.ops.aten.slice.Tensor(permute_643, 3, 128, 192); permute_643 = None + convert_element_type_1815 = torch.ops.prims.convert_element_type.default(slice_136, torch.float32); slice_136 = None + view_1872 = torch.ops.aten.view.default(convert_element_type_1815, [2, 4096, 16, 32, 2]); convert_element_type_1815 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1872); view_1872 = None + mul_1454 = torch.ops.aten.mul.Tensor(view_as_complex_63, clone_9); view_as_complex_63 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_1454); mul_1454 = None + view_1873 = torch.ops.aten.view.default(view_as_real_63, [2, 4096, 16, 64]); view_as_real_63 = None + convert_element_type_1816 = torch.ops.prims.convert_element_type.default(view_1873, torch.bfloat16); view_1873 = None + cat_94 = torch.ops.aten.cat.default([slice_135, convert_element_type_1816], 3); slice_135 = convert_element_type_1816 = None + view_1874 = torch.ops.aten.view.default(cat_94, [2, 4096, 3072]); cat_94 = None + view_1875 = torch.ops.aten.view.default(view_1874, [8192, 3072]); view_1874 = None + permute_652 = torch.ops.aten.permute.default(view_1875, [1, 0]) + mm_296 = torch.ops.aten.mm.default(permute_652, view_1443); permute_652 = view_1443 = None + convert_element_type_1174 = torch.ops.prims.convert_element_type.default(primals_360, torch.bfloat16); primals_360 = None + all_gather_into_tensor_369 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1174, 64, '0'); convert_element_type_1174 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_369); all_gather_into_tensor_369 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_453, [1, 0]); wait_tensor_453 = None + permute_654 = torch.ops.aten.permute.default(permute_326, [1, 0]); permute_326 = None + mm_297 = torch.ops.aten.mm.default(view_1875, permute_654); view_1875 = permute_654 = None + view_1876 = torch.ops.aten.view.default(mm_297, [2, 4096, 2048]); mm_297 = None + add_1848 = torch.ops.aten.add.Tensor(view_1871, view_1876); view_1871 = view_1876 = None + convert_element_type_1821 = torch.ops.prims.convert_element_type.default(mm_296, torch.float32); mm_296 = None + reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1821, 'avg', 64, '0'); convert_element_type_1821 = None + wait_tensor_639 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + convert_element_type_1822 = torch.ops.prims.convert_element_type.default(add_1848, torch.float32); add_1848 = None + convert_element_type_1171 = torch.ops.prims.convert_element_type.default(primals_359, torch.bfloat16); primals_359 = None + all_gather_into_tensor_368 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1171, 64, '0'); convert_element_type_1171 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_368); all_gather_into_tensor_368 = None + convert_element_type_1824 = torch.ops.prims.convert_element_type.default(wait_tensor_452, torch.float32); wait_tensor_452 = None + mul_1455 = torch.ops.aten.mul.Tensor(convert_element_type_1822, convert_element_type_1824); convert_element_type_1824 = None + convert_element_type_1172 = torch.ops.prims.convert_element_type.default(add_1433, torch.float32); add_1433 = None + mul_1038 = torch.ops.aten.mul.Tensor(convert_element_type_1172, rsqrt_66); convert_element_type_1172 = None + mul_1457 = torch.ops.aten.mul.Tensor(mul_1038, mul_1455) + sum_145 = torch.ops.aten.sum.dim_IntList(mul_1457, [2], True); mul_1457 = None + div_161 = torch.ops.aten.div.Tensor(mul_1038, 2048) + mul_1458 = torch.ops.aten.mul.Tensor(div_161, sum_145); div_161 = sum_145 = None + sub_654 = torch.ops.aten.sub.Tensor(mul_1455, mul_1458); mul_1455 = mul_1458 = None + mul_1459 = torch.ops.aten.mul.Tensor(sub_654, rsqrt_66); sub_654 = rsqrt_66 = None + mul_1460 = torch.ops.aten.mul.Tensor(convert_element_type_1822, mul_1038); convert_element_type_1822 = mul_1038 = None + sum_146 = torch.ops.aten.sum.dim_IntList(mul_1460, [0, 1]); mul_1460 = None + convert_element_type_1825 = torch.ops.prims.convert_element_type.default(mul_1459, torch.bfloat16); mul_1459 = None + add_1849 = torch.ops.aten.add.Tensor(add_1847, convert_element_type_1825); add_1847 = convert_element_type_1825 = None + convert_element_type_default_67 = torch.ops.prims.convert_element_type.default(sum_146, torch.float32); sum_146 = None + reduce_scatter_tensor_71 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_67, 'avg', 64, '0'); convert_element_type_default_67 = None + wait_tensor_640 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + view_1877 = torch.ops.aten.view.default(add_1849, [8192, 2048]) + unsqueeze_58 = torch.ops.aten.unsqueeze.default(view_1877, 1) + convert_element_type_1828 = torch.ops.prims.convert_element_type.default(unsqueeze_58, torch.float32); unsqueeze_58 = None + bmm_36 = torch.ops.aten.bmm.default(permute_656, convert_element_type_1828); permute_656 = None + bmm_37 = torch.ops.aten.bmm.default(convert_element_type_1828, permute_657); convert_element_type_1828 = permute_657 = None + convert_element_type_1829 = torch.ops.prims.convert_element_type.default(bmm_36, torch.bfloat16); bmm_36 = None + view_1878 = torch.ops.aten.view.default(bmm_37, [8192, 6]); bmm_37 = None + view_1879 = torch.ops.aten.view.default(convert_element_type_1829, [49152, 2048]); convert_element_type_1829 = None + index_62 = torch.ops.aten.index.Tensor(view_1879, [getitem_301]); view_1879 = getitem_301 = None + permute_658 = torch.ops.aten.permute.default(view_1877, [1, 0]) + mm_298 = torch.ops.aten.mm.default(permute_658, mul_1035); permute_658 = mul_1035 = None + convert_element_type_1166 = torch.ops.prims.convert_element_type.default(primals_358, torch.bfloat16); primals_358 = None + all_gather_into_tensor_367 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1166, 64, '0'); convert_element_type_1166 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_367); all_gather_into_tensor_367 = None + permute_325 = torch.ops.aten.permute.default(wait_tensor_451, [1, 0]); wait_tensor_451 = None + permute_660 = torch.ops.aten.permute.default(permute_325, [1, 0]); permute_325 = None + mm_299 = torch.ops.aten.mm.default(view_1877, permute_660); view_1877 = permute_660 = None + convert_element_type_1834 = torch.ops.prims.convert_element_type.default(mm_298, torch.float32); mm_298 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1834, 'avg', 64, '0'); convert_element_type_1834 = None + wait_tensor_641 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + convert_element_type_1161 = torch.ops.prims.convert_element_type.default(mm_172, torch.float32); mm_172 = None + neg_42 = torch.ops.aten.neg.default(convert_element_type_1161) + exp_63 = torch.ops.aten.exp.default(neg_42); neg_42 = None + add_1428 = torch.ops.aten.add.Tensor(exp_63, 1); exp_63 = None + div_105 = torch.ops.aten.div.Tensor(convert_element_type_1161, add_1428) + convert_element_type_1162 = torch.ops.prims.convert_element_type.default(div_105, torch.bfloat16); div_105 = None + mul_1461 = torch.ops.aten.mul.Tensor(mm_299, convert_element_type_1162); convert_element_type_1162 = None + mul_1462 = torch.ops.aten.mul.Tensor(mm_299, mm_173); mm_299 = mm_173 = None + permute_662 = torch.ops.aten.permute.default(mul_1461, [1, 0]) + mm_300 = torch.ops.aten.mm.default(permute_662, view_1398); permute_662 = None + convert_element_type_1163 = torch.ops.prims.convert_element_type.default(primals_357, torch.bfloat16); primals_357 = None + all_gather_into_tensor_366 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1163, 64, '0'); convert_element_type_1163 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_366); all_gather_into_tensor_366 = None + permute_324 = torch.ops.aten.permute.default(wait_tensor_450, [1, 0]); wait_tensor_450 = None + permute_664 = torch.ops.aten.permute.default(permute_324, [1, 0]); permute_324 = None + mm_301 = torch.ops.aten.mm.default(mul_1461, permute_664); mul_1461 = permute_664 = None + convert_element_type_1839 = torch.ops.prims.convert_element_type.default(mm_300, torch.float32); mm_300 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1839, 'avg', 64, '0'); convert_element_type_1839 = None + wait_tensor_642 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + convert_element_type_1840 = torch.ops.prims.convert_element_type.default(mul_1462, torch.float32); mul_1462 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_1428); add_1428 = None + mul_1463 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_1464 = torch.ops.aten.mul.Tensor(convert_element_type_1840, mul_1463); convert_element_type_1840 = None + sub_655 = torch.ops.aten.sub.Tensor(1, mul_1463); mul_1463 = None + mul_1465 = torch.ops.aten.mul.Tensor(convert_element_type_1161, sub_655); convert_element_type_1161 = sub_655 = None + add_1851 = torch.ops.aten.add.Tensor(mul_1465, 1); mul_1465 = None + mul_1466 = torch.ops.aten.mul.Tensor(mul_1464, add_1851); mul_1464 = add_1851 = None + convert_element_type_1842 = torch.ops.prims.convert_element_type.default(mul_1466, torch.bfloat16); mul_1466 = None + permute_666 = torch.ops.aten.permute.default(convert_element_type_1842, [1, 0]) + mm_302 = torch.ops.aten.mm.default(permute_666, view_1398); permute_666 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(primals_356, torch.bfloat16); primals_356 = None + all_gather_into_tensor_365 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1158, 64, '0'); convert_element_type_1158 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_365); all_gather_into_tensor_365 = None + permute_323 = torch.ops.aten.permute.default(wait_tensor_449, [1, 0]); wait_tensor_449 = None + permute_668 = torch.ops.aten.permute.default(permute_323, [1, 0]); permute_323 = None + mm_303 = torch.ops.aten.mm.default(convert_element_type_1842, permute_668); convert_element_type_1842 = permute_668 = None + add_1852 = torch.ops.aten.add.Tensor(mm_301, mm_303); mm_301 = mm_303 = None + convert_element_type_1847 = torch.ops.prims.convert_element_type.default(mm_302, torch.float32); mm_302 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1847, 'avg', 64, '0'); convert_element_type_1847 = None + wait_tensor_643 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + all_to_all_single_88 = torch.ops._c10d_functional.all_to_all_single.default(index_62, [_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335], [_local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327], '521'); index_62 = None + wait_tensor_644 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_88); all_to_all_single_88 = None + full_368 = torch.ops.aten.full.default([sym_size_int_81, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_81 = None + slice_scatter_5 = torch.ops.aten.slice_scatter.default(full_368, wait_tensor_644, 0, 0, -1); wait_tensor_644 = None + index_63 = torch.ops.aten.index.Tensor(slice_scatter_5, [getitem_302]); slice_scatter_5 = None + permute_670 = torch.ops.aten.permute.default(index_63, [1, 0]) + _grouped_mm_108 = torch.ops.aten._grouped_mm.default(permute_670, mul_1015, cumsum_62); permute_670 = mul_1015 = None + convert_element_type_1152 = torch.ops.prims.convert_element_type.default(primals_354, torch.bfloat16); primals_354 = None + all_gather_into_tensor_361 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1152, 8, '513'); convert_element_type_1152 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_361); all_gather_into_tensor_361 = None + permute_322 = torch.ops.aten.permute.default(wait_tensor_444, [0, 2, 1]); wait_tensor_444 = None + permute_672 = torch.ops.aten.permute.default(permute_322, [0, 2, 1]); permute_322 = None + _grouped_mm_109 = torch.ops.aten._grouped_mm.default(index_63, permute_672, cumsum_62); index_63 = permute_672 = None + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(_grouped_mm_60, torch.float32); _grouped_mm_60 = None + neg_41 = torch.ops.aten.neg.default(convert_element_type_1156) + exp_62 = torch.ops.aten.exp.default(neg_41); neg_41 = None + add_1392 = torch.ops.aten.add.Tensor(exp_62, 1); exp_62 = None + div_104 = torch.ops.aten.div.Tensor(convert_element_type_1156, add_1392) + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(div_104, torch.bfloat16); div_104 = None + mul_1467 = torch.ops.aten.mul.Tensor(_grouped_mm_109, convert_element_type_1157); convert_element_type_1157 = None + mul_1468 = torch.ops.aten.mul.Tensor(_grouped_mm_109, _grouped_mm_61); _grouped_mm_109 = _grouped_mm_61 = None + permute_674 = torch.ops.aten.permute.default(mul_1467, [1, 0]) + _grouped_mm_110 = torch.ops.aten._grouped_mm.default(permute_674, index_41, cumsum_62); permute_674 = None + convert_element_type_1153 = torch.ops.prims.convert_element_type.default(primals_355, torch.bfloat16); primals_355 = None + all_gather_into_tensor_362 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1153, 8, '513'); convert_element_type_1153 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_362); all_gather_into_tensor_362 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_445, [0, 2, 1]); wait_tensor_445 = None + permute_676 = torch.ops.aten.permute.default(permute_321, [0, 2, 1]); permute_321 = None + _grouped_mm_111 = torch.ops.aten._grouped_mm.default(mul_1467, permute_676, cumsum_62); mul_1467 = permute_676 = None + convert_element_type_1848 = torch.ops.prims.convert_element_type.default(mul_1468, torch.float32); mul_1468 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_1392); add_1392 = None + mul_1469 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_1470 = torch.ops.aten.mul.Tensor(convert_element_type_1848, mul_1469); convert_element_type_1848 = None + sub_656 = torch.ops.aten.sub.Tensor(1, mul_1469); mul_1469 = None + mul_1471 = torch.ops.aten.mul.Tensor(convert_element_type_1156, sub_656); convert_element_type_1156 = sub_656 = None + add_1854 = torch.ops.aten.add.Tensor(mul_1471, 1); mul_1471 = None + mul_1472 = torch.ops.aten.mul.Tensor(mul_1470, add_1854); mul_1470 = add_1854 = None + convert_element_type_1850 = torch.ops.prims.convert_element_type.default(mul_1472, torch.bfloat16); mul_1472 = None + permute_678 = torch.ops.aten.permute.default(convert_element_type_1850, [1, 0]) + _grouped_mm_112 = torch.ops.aten._grouped_mm.default(permute_678, index_41, cumsum_62); permute_678 = index_41 = None + convert_element_type_1150 = torch.ops.prims.convert_element_type.default(primals_353, torch.bfloat16); primals_353 = None + all_gather_into_tensor_359 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1150, 8, '513'); convert_element_type_1150 = None + wait_tensor_442 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_359); all_gather_into_tensor_359 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_442, [0, 2, 1]); wait_tensor_442 = None + permute_680 = torch.ops.aten.permute.default(permute_320, [0, 2, 1]); permute_320 = None + _grouped_mm_113 = torch.ops.aten._grouped_mm.default(convert_element_type_1850, permute_680, cumsum_62); convert_element_type_1850 = permute_680 = cumsum_62 = None + add_1855 = torch.ops.aten.add.Tensor(_grouped_mm_111, _grouped_mm_113); _grouped_mm_111 = _grouped_mm_113 = None + convert_element_type_1851 = torch.ops.prims.convert_element_type.default(_grouped_mm_110, torch.float32); _grouped_mm_110 = None + div_162 = torch.ops.aten.div.Tensor(convert_element_type_1851, 64); convert_element_type_1851 = None + reduce_scatter_tensor_75 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_162, 'sum', 8, '513'); div_162 = None + wait_tensor_645 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + convert_element_type_1852 = torch.ops.prims.convert_element_type.default(_grouped_mm_108, torch.float32); _grouped_mm_108 = None + div_163 = torch.ops.aten.div.Tensor(convert_element_type_1852, 64); convert_element_type_1852 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_163, 'sum', 8, '513'); div_163 = None + wait_tensor_646 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + convert_element_type_1853 = torch.ops.prims.convert_element_type.default(_grouped_mm_112, torch.float32); _grouped_mm_112 = None + div_164 = torch.ops.aten.div.Tensor(convert_element_type_1853, 64); convert_element_type_1853 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_164, 'sum', 8, '513'); div_164 = None + wait_tensor_647 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + index_put_62 = torch.ops.aten.index_put.default(full_368, [getitem_302], add_1855, True); full_368 = getitem_302 = add_1855 = None + slice_137 = torch.ops.aten.slice.Tensor(index_put_62, 0, 0, add_1856); index_put_62 = add_1856 = None + all_to_all_single_89 = torch.ops._c10d_functional.all_to_all_single.default(slice_137, [_local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327], [_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335], '521'); slice_137 = _local_scalar_dense_320 = _local_scalar_dense_321 = _local_scalar_dense_322 = _local_scalar_dense_323 = _local_scalar_dense_324 = _local_scalar_dense_325 = _local_scalar_dense_326 = _local_scalar_dense_327 = _local_scalar_dense_328 = _local_scalar_dense_329 = _local_scalar_dense_330 = _local_scalar_dense_331 = _local_scalar_dense_332 = _local_scalar_dense_333 = _local_scalar_dense_334 = _local_scalar_dense_335 = None + wait_tensor_648 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_89); all_to_all_single_89 = None + index_put_63 = torch.ops.aten.index_put.default(full_default_52, [div_102], wait_tensor_648, True); div_102 = wait_tensor_648 = None + add_1860 = torch.ops.aten.add.Tensor(add_1852, index_put_63); add_1852 = index_put_63 = None + mul_1473 = torch.ops.aten.mul.Tensor(view_1878, 1.0); view_1878 = None + scatter_add_5 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_299, mul_1473); getitem_299 = mul_1473 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_171, torch.float32); mm_171 = None + sub_480 = torch.ops.aten.sub.Tensor(convert_element_type_1145, amax_20); convert_element_type_1145 = amax_20 = None + exp_61 = torch.ops.aten.exp.default(sub_480); sub_480 = None + div_101 = torch.ops.aten.div.Tensor(exp_61, sum_81); exp_61 = sum_81 = None + mul_1474 = torch.ops.aten.mul.Tensor(scatter_add_5, div_101); scatter_add_5 = None + sum_147 = torch.ops.aten.sum.dim_IntList(mul_1474, [1], True) + neg_70 = torch.ops.aten.neg.default(div_101); div_101 = None + fma_5 = torch.ops.prims.fma.default(neg_70, sum_147, mul_1474); neg_70 = sum_147 = mul_1474 = None + convert_element_type_1854 = torch.ops.prims.convert_element_type.default(fma_5, torch.bfloat16); fma_5 = None + permute_682 = torch.ops.aten.permute.default(convert_element_type_1854, [1, 0]) + mm_304 = torch.ops.aten.mm.default(permute_682, view_1398); permute_682 = view_1398 = None + convert_element_type_1142 = torch.ops.prims.convert_element_type.default(primals_351, torch.bfloat16); primals_351 = None + all_gather_into_tensor_358 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1142, 64, '0'); convert_element_type_1142 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_358); all_gather_into_tensor_358 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_438, [1, 0]); wait_tensor_438 = None + permute_684 = torch.ops.aten.permute.default(permute_319, [1, 0]); permute_319 = None + mm_305 = torch.ops.aten.mm.default(convert_element_type_1854, permute_684); convert_element_type_1854 = permute_684 = None + add_1861 = torch.ops.aten.add.Tensor(add_1860, mm_305); add_1860 = mm_305 = None + convert_element_type_1859 = torch.ops.prims.convert_element_type.default(mm_304, torch.float32); mm_304 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1859, 'avg', 64, '0'); convert_element_type_1859 = None + wait_tensor_649 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + view_1880 = torch.ops.aten.view.default(add_1861, [2, 4096, 2048]); add_1861 = None + convert_element_type_1860 = torch.ops.prims.convert_element_type.default(view_1880, torch.float32); view_1880 = None + convert_element_type_1139 = torch.ops.prims.convert_element_type.default(primals_349, torch.bfloat16); primals_349 = None + all_gather_into_tensor_357 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1139, 64, '0'); convert_element_type_1139 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_357); all_gather_into_tensor_357 = None + convert_element_type_1862 = torch.ops.prims.convert_element_type.default(wait_tensor_437, torch.float32); wait_tensor_437 = None + mul_1475 = torch.ops.aten.mul.Tensor(convert_element_type_1860, convert_element_type_1862); convert_element_type_1862 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(add_1368, torch.float32); add_1368 = None + mul_995 = torch.ops.aten.mul.Tensor(convert_element_type_1140, rsqrt_65); convert_element_type_1140 = None + mul_1477 = torch.ops.aten.mul.Tensor(mul_995, mul_1475) + sum_148 = torch.ops.aten.sum.dim_IntList(mul_1477, [2], True); mul_1477 = None + div_165 = torch.ops.aten.div.Tensor(mul_995, 2048) + mul_1478 = torch.ops.aten.mul.Tensor(div_165, sum_148); div_165 = sum_148 = None + sub_658 = torch.ops.aten.sub.Tensor(mul_1475, mul_1478); mul_1475 = mul_1478 = None + mul_1479 = torch.ops.aten.mul.Tensor(sub_658, rsqrt_65); sub_658 = rsqrt_65 = None + mul_1480 = torch.ops.aten.mul.Tensor(convert_element_type_1860, mul_995); convert_element_type_1860 = mul_995 = None + sum_149 = torch.ops.aten.sum.dim_IntList(mul_1480, [0, 1]); mul_1480 = None + convert_element_type_1863 = torch.ops.prims.convert_element_type.default(mul_1479, torch.bfloat16); mul_1479 = None + add_1862 = torch.ops.aten.add.Tensor(add_1849, convert_element_type_1863); add_1849 = convert_element_type_1863 = None + convert_element_type_default_66 = torch.ops.prims.convert_element_type.default(sum_149, torch.float32); sum_149 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_66, 'avg', 64, '0'); convert_element_type_default_66 = None + wait_tensor_650 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + view_1881 = torch.ops.aten.view.default(add_1862, [8192, 2048]) + permute_686 = torch.ops.aten.permute.default(view_1881, [1, 0]) + permute_317 = torch.ops.aten.permute.default(getitem_295, [0, 2, 1, 3]) + view_1393 = torch.ops.aten.view.default(permute_317, [2, 4096, -1]); permute_317 = None + view_1395 = torch.ops.aten.view.default(view_1393, [8192, 2048]); view_1393 = None + mm_306 = torch.ops.aten.mm.default(permute_686, view_1395); permute_686 = view_1395 = None + convert_element_type_1136 = torch.ops.prims.convert_element_type.default(primals_348, torch.bfloat16); primals_348 = None + all_gather_into_tensor_356 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1136, 64, '0'); convert_element_type_1136 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_356); all_gather_into_tensor_356 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_436, [1, 0]); wait_tensor_436 = None + permute_688 = torch.ops.aten.permute.default(permute_318, [1, 0]); permute_318 = None + mm_307 = torch.ops.aten.mm.default(view_1881, permute_688); view_1881 = permute_688 = None + view_1882 = torch.ops.aten.view.default(mm_307, [2, 4096, 2048]); mm_307 = None + convert_element_type_1870 = torch.ops.prims.convert_element_type.default(mm_306, torch.float32); mm_306 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1870, 'avg', 64, '0'); convert_element_type_1870 = None + wait_tensor_651 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + view_1883 = torch.ops.aten.view.default(view_1882, [2, 4096, 16, 128]); view_1882 = None + permute_690 = torch.ops.aten.permute.default(view_1883, [0, 2, 1, 3]); view_1883 = None + fw_graph5 = self.fw_graph5 + joint_graph5 = self.joint_graph5 + mask_graph5 = self.mask_graph5 + flex_attention_backward_5 = torch.ops.higher_order.flex_attention_backward(permute_314, permute_315, permute_316, getitem_295, getitem_296, permute_690, None, fw_graph5, joint_graph5, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph5), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_314 = permute_315 = permute_316 = getitem_295 = getitem_296 = permute_690 = fw_graph5 = joint_graph5 = mask_graph5 = None + getitem_393 = flex_attention_backward_5[0] + getitem_394 = flex_attention_backward_5[1] + getitem_395 = flex_attention_backward_5[2]; flex_attention_backward_5 = None + permute_691 = torch.ops.aten.permute.default(getitem_395, [0, 2, 1, 3]); getitem_395 = None + permute_692 = torch.ops.aten.permute.default(getitem_394, [0, 2, 1, 3]); getitem_394 = None + permute_693 = torch.ops.aten.permute.default(getitem_393, [0, 2, 1, 3]); getitem_393 = None + slice_139 = torch.ops.aten.slice.Tensor(permute_692, 3, 0, 128) + slice_140 = torch.ops.aten.slice.Tensor(permute_692, 3, 128, 192); permute_692 = None + sum_150 = torch.ops.aten.sum.dim_IntList(slice_140, [2], True); slice_140 = None + cat_95 = torch.ops.aten.cat.default([slice_139, permute_691], 3); slice_139 = permute_691 = None + view_1884 = torch.ops.aten.view.default(cat_95, [2, 4096, 4096]); cat_95 = None + view_1885 = torch.ops.aten.view.default(view_1884, [8192, 4096]); view_1884 = None + permute_694 = torch.ops.aten.permute.default(view_1885, [1, 0]) + mm_308 = torch.ops.aten.mm.default(permute_694, view_1390); permute_694 = view_1390 = None + convert_element_type_1133 = torch.ops.prims.convert_element_type.default(primals_347, torch.bfloat16); primals_347 = None + all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1133, 64, '0'); convert_element_type_1133 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_313 = torch.ops.aten.permute.default(wait_tensor_435, [1, 0]); wait_tensor_435 = None + permute_696 = torch.ops.aten.permute.default(permute_313, [1, 0]); permute_313 = None + mm_309 = torch.ops.aten.mm.default(view_1885, permute_696); view_1885 = permute_696 = None + view_1886 = torch.ops.aten.view.default(mm_309, [2, 4096, 512]); mm_309 = None + convert_element_type_1875 = torch.ops.prims.convert_element_type.default(mm_308, torch.float32); mm_308 = None + reduce_scatter_tensor_81 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1875, 'avg', 64, '0'); convert_element_type_1875 = None + wait_tensor_652 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + convert_element_type_1876 = torch.ops.prims.convert_element_type.default(view_1886, torch.float32); view_1886 = None + convert_element_type_1130 = torch.ops.prims.convert_element_type.default(primals_346, torch.bfloat16); primals_346 = None + all_gather_into_tensor_354 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1130, 64, '0'); convert_element_type_1130 = None + wait_tensor_434 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_354); all_gather_into_tensor_354 = None + convert_element_type_1878 = torch.ops.prims.convert_element_type.default(wait_tensor_434, torch.float32); wait_tensor_434 = None + mul_1481 = torch.ops.aten.mul.Tensor(convert_element_type_1876, convert_element_type_1878); convert_element_type_1878 = None + convert_element_type_1131 = torch.ops.prims.convert_element_type.default(getitem_291, torch.float32); getitem_291 = None + mul_993 = torch.ops.aten.mul.Tensor(convert_element_type_1131, rsqrt_64); convert_element_type_1131 = None + mul_1483 = torch.ops.aten.mul.Tensor(mul_993, mul_1481) + sum_151 = torch.ops.aten.sum.dim_IntList(mul_1483, [2], True); mul_1483 = None + div_166 = torch.ops.aten.div.Tensor(mul_993, 512) + mul_1484 = torch.ops.aten.mul.Tensor(div_166, sum_151); div_166 = sum_151 = None + sub_659 = torch.ops.aten.sub.Tensor(mul_1481, mul_1484); mul_1481 = mul_1484 = None + mul_1485 = torch.ops.aten.mul.Tensor(sub_659, rsqrt_64); sub_659 = rsqrt_64 = None + mul_1486 = torch.ops.aten.mul.Tensor(convert_element_type_1876, mul_993); convert_element_type_1876 = mul_993 = None + sum_152 = torch.ops.aten.sum.dim_IntList(mul_1486, [0, 1]); mul_1486 = None + convert_element_type_1879 = torch.ops.prims.convert_element_type.default(mul_1485, torch.bfloat16); mul_1485 = None + convert_element_type_default_65 = torch.ops.prims.convert_element_type.default(sum_152, torch.float32); sum_152 = None + reduce_scatter_tensor_82 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_65, 'avg', 64, '0'); convert_element_type_default_65 = None + wait_tensor_653 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + convert_element_type_1882 = torch.ops.prims.convert_element_type.default(sum_150, torch.float32); sum_150 = None + view_1887 = torch.ops.aten.view.default(convert_element_type_1882, [2, 4096, 1, 32, 2]); convert_element_type_1882 = None + view_as_complex_64 = torch.ops.aten.view_as_complex.default(view_1887); view_1887 = None + mul_1487 = torch.ops.aten.mul.Tensor(view_as_complex_64, clone_9); view_as_complex_64 = None + view_as_real_64 = torch.ops.aten.view_as_real.default(mul_1487); mul_1487 = None + view_1888 = torch.ops.aten.view.default(view_as_real_64, [2, 4096, 1, 64]); view_as_real_64 = None + convert_element_type_1883 = torch.ops.prims.convert_element_type.default(view_1888, torch.bfloat16); view_1888 = None + squeeze_31 = torch.ops.aten.squeeze.dim(convert_element_type_1883, 2); convert_element_type_1883 = None + cat_96 = torch.ops.aten.cat.default([convert_element_type_1879, squeeze_31], 2); convert_element_type_1879 = squeeze_31 = None + view_1889 = torch.ops.aten.view.default(cat_96, [8192, 576]); cat_96 = None + permute_698 = torch.ops.aten.permute.default(view_1889, [1, 0]) + mm_310 = torch.ops.aten.mm.default(permute_698, view_1376); permute_698 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(primals_345, torch.bfloat16); primals_345 = None + all_gather_into_tensor_353 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1125, 64, '0'); convert_element_type_1125 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + permute_312 = torch.ops.aten.permute.default(wait_tensor_433, [1, 0]); wait_tensor_433 = None + permute_700 = torch.ops.aten.permute.default(permute_312, [1, 0]); permute_312 = None + mm_311 = torch.ops.aten.mm.default(view_1889, permute_700); view_1889 = permute_700 = None + view_1890 = torch.ops.aten.view.default(mm_311, [2, 4096, 2048]); mm_311 = None + convert_element_type_1888 = torch.ops.prims.convert_element_type.default(mm_310, torch.float32); mm_310 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1888, 'avg', 64, '0'); convert_element_type_1888 = None + wait_tensor_654 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + slice_141 = torch.ops.aten.slice.Tensor(permute_693, 3, 0, 128) + slice_142 = torch.ops.aten.slice.Tensor(permute_693, 3, 128, 192); permute_693 = None + convert_element_type_1889 = torch.ops.prims.convert_element_type.default(slice_142, torch.float32); slice_142 = None + view_1891 = torch.ops.aten.view.default(convert_element_type_1889, [2, 4096, 16, 32, 2]); convert_element_type_1889 = None + view_as_complex_65 = torch.ops.aten.view_as_complex.default(view_1891); view_1891 = None + mul_1488 = torch.ops.aten.mul.Tensor(view_as_complex_65, clone_9); view_as_complex_65 = None + view_as_real_65 = torch.ops.aten.view_as_real.default(mul_1488); mul_1488 = None + view_1892 = torch.ops.aten.view.default(view_as_real_65, [2, 4096, 16, 64]); view_as_real_65 = None + convert_element_type_1890 = torch.ops.prims.convert_element_type.default(view_1892, torch.bfloat16); view_1892 = None + cat_97 = torch.ops.aten.cat.default([slice_141, convert_element_type_1890], 3); slice_141 = convert_element_type_1890 = None + view_1893 = torch.ops.aten.view.default(cat_97, [2, 4096, 3072]); cat_97 = None + view_1894 = torch.ops.aten.view.default(view_1893, [8192, 3072]); view_1893 = None + permute_702 = torch.ops.aten.permute.default(view_1894, [1, 0]) + mm_312 = torch.ops.aten.mm.default(permute_702, view_1376); permute_702 = view_1376 = None + convert_element_type_1120 = torch.ops.prims.convert_element_type.default(primals_344, torch.bfloat16); primals_344 = None + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1120, 64, '0'); convert_element_type_1120 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_311 = torch.ops.aten.permute.default(wait_tensor_432, [1, 0]); wait_tensor_432 = None + permute_704 = torch.ops.aten.permute.default(permute_311, [1, 0]); permute_311 = None + mm_313 = torch.ops.aten.mm.default(view_1894, permute_704); view_1894 = permute_704 = None + view_1895 = torch.ops.aten.view.default(mm_313, [2, 4096, 2048]); mm_313 = None + add_1863 = torch.ops.aten.add.Tensor(view_1890, view_1895); view_1890 = view_1895 = None + convert_element_type_1895 = torch.ops.prims.convert_element_type.default(mm_312, torch.float32); mm_312 = None + reduce_scatter_tensor_84 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1895, 'avg', 64, '0'); convert_element_type_1895 = None + wait_tensor_655 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + convert_element_type_1896 = torch.ops.prims.convert_element_type.default(add_1863, torch.float32); add_1863 = None + convert_element_type_1117 = torch.ops.prims.convert_element_type.default(primals_343, torch.bfloat16); primals_343 = None + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1117, 64, '0'); convert_element_type_1117 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); all_gather_into_tensor_351 = None + convert_element_type_1898 = torch.ops.prims.convert_element_type.default(wait_tensor_431, torch.float32); wait_tensor_431 = None + mul_1489 = torch.ops.aten.mul.Tensor(convert_element_type_1896, convert_element_type_1898); convert_element_type_1898 = None + convert_element_type_1118 = torch.ops.prims.convert_element_type.default(add_1365, torch.float32); add_1365 = None + mul_989 = torch.ops.aten.mul.Tensor(convert_element_type_1118, rsqrt_63); convert_element_type_1118 = None + mul_1491 = torch.ops.aten.mul.Tensor(mul_989, mul_1489) + sum_153 = torch.ops.aten.sum.dim_IntList(mul_1491, [2], True); mul_1491 = None + div_167 = torch.ops.aten.div.Tensor(mul_989, 2048) + mul_1492 = torch.ops.aten.mul.Tensor(div_167, sum_153); div_167 = sum_153 = None + sub_660 = torch.ops.aten.sub.Tensor(mul_1489, mul_1492); mul_1489 = mul_1492 = None + mul_1493 = torch.ops.aten.mul.Tensor(sub_660, rsqrt_63); sub_660 = rsqrt_63 = None + mul_1494 = torch.ops.aten.mul.Tensor(convert_element_type_1896, mul_989); convert_element_type_1896 = mul_989 = None + sum_154 = torch.ops.aten.sum.dim_IntList(mul_1494, [0, 1]); mul_1494 = None + convert_element_type_1899 = torch.ops.prims.convert_element_type.default(mul_1493, torch.bfloat16); mul_1493 = None + add_1864 = torch.ops.aten.add.Tensor(add_1862, convert_element_type_1899); add_1862 = convert_element_type_1899 = None + convert_element_type_default_64 = torch.ops.prims.convert_element_type.default(sum_154, torch.float32); sum_154 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_64, 'avg', 64, '0'); convert_element_type_default_64 = None + wait_tensor_656 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); reduce_scatter_tensor_85 = None + view_1896 = torch.ops.aten.view.default(add_1864, [8192, 2048]) + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_1896, 1) + convert_element_type_1902 = torch.ops.prims.convert_element_type.default(unsqueeze_59, torch.float32); unsqueeze_59 = None + bmm_38 = torch.ops.aten.bmm.default(permute_706, convert_element_type_1902); permute_706 = None + bmm_39 = torch.ops.aten.bmm.default(convert_element_type_1902, permute_707); convert_element_type_1902 = permute_707 = None + convert_element_type_1903 = torch.ops.prims.convert_element_type.default(bmm_38, torch.bfloat16); bmm_38 = None + view_1897 = torch.ops.aten.view.default(bmm_39, [8192, 6]); bmm_39 = None + view_1898 = torch.ops.aten.view.default(convert_element_type_1903, [49152, 2048]); convert_element_type_1903 = None + index_64 = torch.ops.aten.index.Tensor(view_1898, [getitem_287]); view_1898 = getitem_287 = None + permute_708 = torch.ops.aten.permute.default(view_1896, [1, 0]) + mm_314 = torch.ops.aten.mm.default(permute_708, mul_986); permute_708 = mul_986 = None + convert_element_type_1112 = torch.ops.prims.convert_element_type.default(primals_342, torch.bfloat16); primals_342 = None + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1112, 64, '0'); convert_element_type_1112 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_430, [1, 0]); wait_tensor_430 = None + permute_710 = torch.ops.aten.permute.default(permute_310, [1, 0]); permute_310 = None + mm_315 = torch.ops.aten.mm.default(view_1896, permute_710); view_1896 = permute_710 = None + convert_element_type_1908 = torch.ops.prims.convert_element_type.default(mm_314, torch.float32); mm_314 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1908, 'avg', 64, '0'); convert_element_type_1908 = None + wait_tensor_657 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + convert_element_type_1107 = torch.ops.prims.convert_element_type.default(mm_164, torch.float32); mm_164 = None + neg_40 = torch.ops.aten.neg.default(convert_element_type_1107) + exp_60 = torch.ops.aten.exp.default(neg_40); neg_40 = None + add_1360 = torch.ops.aten.add.Tensor(exp_60, 1); exp_60 = None + div_100 = torch.ops.aten.div.Tensor(convert_element_type_1107, add_1360) + convert_element_type_1108 = torch.ops.prims.convert_element_type.default(div_100, torch.bfloat16); div_100 = None + mul_1495 = torch.ops.aten.mul.Tensor(mm_315, convert_element_type_1108); convert_element_type_1108 = None + mul_1496 = torch.ops.aten.mul.Tensor(mm_315, mm_165); mm_315 = mm_165 = None + permute_712 = torch.ops.aten.permute.default(mul_1495, [1, 0]) + mm_316 = torch.ops.aten.mm.default(permute_712, view_1331); permute_712 = None + convert_element_type_1109 = torch.ops.prims.convert_element_type.default(primals_341, torch.bfloat16); primals_341 = None + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1109, 64, '0'); convert_element_type_1109 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_429, [1, 0]); wait_tensor_429 = None + permute_714 = torch.ops.aten.permute.default(permute_309, [1, 0]); permute_309 = None + mm_317 = torch.ops.aten.mm.default(mul_1495, permute_714); mul_1495 = permute_714 = None + convert_element_type_1913 = torch.ops.prims.convert_element_type.default(mm_316, torch.float32); mm_316 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1913, 'avg', 64, '0'); convert_element_type_1913 = None + wait_tensor_658 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + convert_element_type_1914 = torch.ops.prims.convert_element_type.default(mul_1496, torch.float32); mul_1496 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_1360); add_1360 = None + mul_1497 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_1498 = torch.ops.aten.mul.Tensor(convert_element_type_1914, mul_1497); convert_element_type_1914 = None + sub_661 = torch.ops.aten.sub.Tensor(1, mul_1497); mul_1497 = None + mul_1499 = torch.ops.aten.mul.Tensor(convert_element_type_1107, sub_661); convert_element_type_1107 = sub_661 = None + add_1866 = torch.ops.aten.add.Tensor(mul_1499, 1); mul_1499 = None + mul_1500 = torch.ops.aten.mul.Tensor(mul_1498, add_1866); mul_1498 = add_1866 = None + convert_element_type_1916 = torch.ops.prims.convert_element_type.default(mul_1500, torch.bfloat16); mul_1500 = None + permute_716 = torch.ops.aten.permute.default(convert_element_type_1916, [1, 0]) + mm_318 = torch.ops.aten.mm.default(permute_716, view_1331); permute_716 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(primals_340, torch.bfloat16); primals_340 = None + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1104, 64, '0'); convert_element_type_1104 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_428, [1, 0]); wait_tensor_428 = None + permute_718 = torch.ops.aten.permute.default(permute_308, [1, 0]); permute_308 = None + mm_319 = torch.ops.aten.mm.default(convert_element_type_1916, permute_718); convert_element_type_1916 = permute_718 = None + add_1867 = torch.ops.aten.add.Tensor(mm_317, mm_319); mm_317 = mm_319 = None + convert_element_type_1921 = torch.ops.prims.convert_element_type.default(mm_318, torch.float32); mm_318 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1921, 'avg', 64, '0'); convert_element_type_1921 = None + wait_tensor_659 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + all_to_all_single_90 = torch.ops._c10d_functional.all_to_all_single.default(index_64, [_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319], [_local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311], '521'); index_64 = None + wait_tensor_660 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_90); all_to_all_single_90 = None + full_372 = torch.ops.aten.full.default([sym_size_int_77, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_77 = None + slice_scatter_6 = torch.ops.aten.slice_scatter.default(full_372, wait_tensor_660, 0, 0, -1); wait_tensor_660 = None + index_65 = torch.ops.aten.index.Tensor(slice_scatter_6, [getitem_288]); slice_scatter_6 = None + permute_720 = torch.ops.aten.permute.default(index_65, [1, 0]) + _grouped_mm_114 = torch.ops.aten._grouped_mm.default(permute_720, mul_966, cumsum_59); permute_720 = mul_966 = None + convert_element_type_1098 = torch.ops.prims.convert_element_type.default(primals_338, torch.bfloat16); primals_338 = None + all_gather_into_tensor_344 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1098, 8, '513'); convert_element_type_1098 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_344); all_gather_into_tensor_344 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_423, [0, 2, 1]); wait_tensor_423 = None + permute_722 = torch.ops.aten.permute.default(permute_307, [0, 2, 1]); permute_307 = None + _grouped_mm_115 = torch.ops.aten._grouped_mm.default(index_65, permute_722, cumsum_59); index_65 = permute_722 = None + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(_grouped_mm_57, torch.float32); _grouped_mm_57 = None + neg_39 = torch.ops.aten.neg.default(convert_element_type_1102) + exp_59 = torch.ops.aten.exp.default(neg_39); neg_39 = None + add_1324 = torch.ops.aten.add.Tensor(exp_59, 1); exp_59 = None + div_99 = torch.ops.aten.div.Tensor(convert_element_type_1102, add_1324) + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(div_99, torch.bfloat16); div_99 = None + mul_1501 = torch.ops.aten.mul.Tensor(_grouped_mm_115, convert_element_type_1103); convert_element_type_1103 = None + mul_1502 = torch.ops.aten.mul.Tensor(_grouped_mm_115, _grouped_mm_58); _grouped_mm_115 = _grouped_mm_58 = None + permute_724 = torch.ops.aten.permute.default(mul_1501, [1, 0]) + _grouped_mm_116 = torch.ops.aten._grouped_mm.default(permute_724, index_39, cumsum_59); permute_724 = None + convert_element_type_1099 = torch.ops.prims.convert_element_type.default(primals_339, torch.bfloat16); primals_339 = None + all_gather_into_tensor_345 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1099, 8, '513'); convert_element_type_1099 = None + wait_tensor_424 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_345); all_gather_into_tensor_345 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_424, [0, 2, 1]); wait_tensor_424 = None + permute_726 = torch.ops.aten.permute.default(permute_306, [0, 2, 1]); permute_306 = None + _grouped_mm_117 = torch.ops.aten._grouped_mm.default(mul_1501, permute_726, cumsum_59); mul_1501 = permute_726 = None + convert_element_type_1922 = torch.ops.prims.convert_element_type.default(mul_1502, torch.float32); mul_1502 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_1324); add_1324 = None + mul_1503 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_1504 = torch.ops.aten.mul.Tensor(convert_element_type_1922, mul_1503); convert_element_type_1922 = None + sub_662 = torch.ops.aten.sub.Tensor(1, mul_1503); mul_1503 = None + mul_1505 = torch.ops.aten.mul.Tensor(convert_element_type_1102, sub_662); convert_element_type_1102 = sub_662 = None + add_1869 = torch.ops.aten.add.Tensor(mul_1505, 1); mul_1505 = None + mul_1506 = torch.ops.aten.mul.Tensor(mul_1504, add_1869); mul_1504 = add_1869 = None + convert_element_type_1924 = torch.ops.prims.convert_element_type.default(mul_1506, torch.bfloat16); mul_1506 = None + permute_728 = torch.ops.aten.permute.default(convert_element_type_1924, [1, 0]) + _grouped_mm_118 = torch.ops.aten._grouped_mm.default(permute_728, index_39, cumsum_59); permute_728 = index_39 = None + convert_element_type_1096 = torch.ops.prims.convert_element_type.default(primals_337, torch.bfloat16); primals_337 = None + all_gather_into_tensor_342 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1096, 8, '513'); convert_element_type_1096 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_342); all_gather_into_tensor_342 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_421, [0, 2, 1]); wait_tensor_421 = None + permute_730 = torch.ops.aten.permute.default(permute_305, [0, 2, 1]); permute_305 = None + _grouped_mm_119 = torch.ops.aten._grouped_mm.default(convert_element_type_1924, permute_730, cumsum_59); convert_element_type_1924 = permute_730 = cumsum_59 = None + add_1870 = torch.ops.aten.add.Tensor(_grouped_mm_117, _grouped_mm_119); _grouped_mm_117 = _grouped_mm_119 = None + convert_element_type_1925 = torch.ops.prims.convert_element_type.default(_grouped_mm_116, torch.float32); _grouped_mm_116 = None + div_168 = torch.ops.aten.div.Tensor(convert_element_type_1925, 64); convert_element_type_1925 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_168, 'sum', 8, '513'); div_168 = None + wait_tensor_661 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + convert_element_type_1926 = torch.ops.prims.convert_element_type.default(_grouped_mm_114, torch.float32); _grouped_mm_114 = None + div_169 = torch.ops.aten.div.Tensor(convert_element_type_1926, 64); convert_element_type_1926 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_169, 'sum', 8, '513'); div_169 = None + wait_tensor_662 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + convert_element_type_1927 = torch.ops.prims.convert_element_type.default(_grouped_mm_118, torch.float32); _grouped_mm_118 = None + div_170 = torch.ops.aten.div.Tensor(convert_element_type_1927, 64); convert_element_type_1927 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_170, 'sum', 8, '513'); div_170 = None + wait_tensor_663 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + index_put_64 = torch.ops.aten.index_put.default(full_372, [getitem_288], add_1870, True); full_372 = getitem_288 = add_1870 = None + slice_143 = torch.ops.aten.slice.Tensor(index_put_64, 0, 0, add_1871); index_put_64 = add_1871 = None + all_to_all_single_91 = torch.ops._c10d_functional.all_to_all_single.default(slice_143, [_local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311], [_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319], '521'); slice_143 = _local_scalar_dense_304 = _local_scalar_dense_305 = _local_scalar_dense_306 = _local_scalar_dense_307 = _local_scalar_dense_308 = _local_scalar_dense_309 = _local_scalar_dense_310 = _local_scalar_dense_311 = _local_scalar_dense_312 = _local_scalar_dense_313 = _local_scalar_dense_314 = _local_scalar_dense_315 = _local_scalar_dense_316 = _local_scalar_dense_317 = _local_scalar_dense_318 = _local_scalar_dense_319 = None + wait_tensor_664 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_91); all_to_all_single_91 = None + index_put_65 = torch.ops.aten.index_put.default(full_default_52, [div_97], wait_tensor_664, True); div_97 = wait_tensor_664 = None + add_1875 = torch.ops.aten.add.Tensor(add_1867, index_put_65); add_1867 = index_put_65 = None + mul_1507 = torch.ops.aten.mul.Tensor(view_1897, 1.0); view_1897 = None + scatter_add_6 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_285, mul_1507); getitem_285 = mul_1507 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_163, torch.float32); mm_163 = None + sub_456 = torch.ops.aten.sub.Tensor(convert_element_type_1091, amax_19); convert_element_type_1091 = amax_19 = None + exp_58 = torch.ops.aten.exp.default(sub_456); sub_456 = None + div_96 = torch.ops.aten.div.Tensor(exp_58, sum_77); exp_58 = sum_77 = None + mul_1508 = torch.ops.aten.mul.Tensor(scatter_add_6, div_96); scatter_add_6 = None + sum_155 = torch.ops.aten.sum.dim_IntList(mul_1508, [1], True) + neg_73 = torch.ops.aten.neg.default(div_96); div_96 = None + fma_6 = torch.ops.prims.fma.default(neg_73, sum_155, mul_1508); neg_73 = sum_155 = mul_1508 = None + convert_element_type_1928 = torch.ops.prims.convert_element_type.default(fma_6, torch.bfloat16); fma_6 = None + permute_732 = torch.ops.aten.permute.default(convert_element_type_1928, [1, 0]) + mm_320 = torch.ops.aten.mm.default(permute_732, view_1331); permute_732 = view_1331 = None + convert_element_type_1088 = torch.ops.prims.convert_element_type.default(primals_335, torch.bfloat16); primals_335 = None + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1088, 64, '0'); convert_element_type_1088 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_417, [1, 0]); wait_tensor_417 = None + permute_734 = torch.ops.aten.permute.default(permute_304, [1, 0]); permute_304 = None + mm_321 = torch.ops.aten.mm.default(convert_element_type_1928, permute_734); convert_element_type_1928 = permute_734 = None + add_1876 = torch.ops.aten.add.Tensor(add_1875, mm_321); add_1875 = mm_321 = None + convert_element_type_1933 = torch.ops.prims.convert_element_type.default(mm_320, torch.float32); mm_320 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1933, 'avg', 64, '0'); convert_element_type_1933 = None + wait_tensor_665 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + view_1899 = torch.ops.aten.view.default(add_1876, [2, 4096, 2048]); add_1876 = None + convert_element_type_1934 = torch.ops.prims.convert_element_type.default(view_1899, torch.float32); view_1899 = None + convert_element_type_1085 = torch.ops.prims.convert_element_type.default(primals_333, torch.bfloat16); primals_333 = None + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1085, 64, '0'); convert_element_type_1085 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + convert_element_type_1936 = torch.ops.prims.convert_element_type.default(wait_tensor_416, torch.float32); wait_tensor_416 = None + mul_1509 = torch.ops.aten.mul.Tensor(convert_element_type_1934, convert_element_type_1936); convert_element_type_1936 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(add_1300, torch.float32); add_1300 = None + mul_946 = torch.ops.aten.mul.Tensor(convert_element_type_1086, rsqrt_62); convert_element_type_1086 = None + mul_1511 = torch.ops.aten.mul.Tensor(mul_946, mul_1509) + sum_156 = torch.ops.aten.sum.dim_IntList(mul_1511, [2], True); mul_1511 = None + div_171 = torch.ops.aten.div.Tensor(mul_946, 2048) + mul_1512 = torch.ops.aten.mul.Tensor(div_171, sum_156); div_171 = sum_156 = None + sub_664 = torch.ops.aten.sub.Tensor(mul_1509, mul_1512); mul_1509 = mul_1512 = None + mul_1513 = torch.ops.aten.mul.Tensor(sub_664, rsqrt_62); sub_664 = rsqrt_62 = None + mul_1514 = torch.ops.aten.mul.Tensor(convert_element_type_1934, mul_946); convert_element_type_1934 = mul_946 = None + sum_157 = torch.ops.aten.sum.dim_IntList(mul_1514, [0, 1]); mul_1514 = None + convert_element_type_1937 = torch.ops.prims.convert_element_type.default(mul_1513, torch.bfloat16); mul_1513 = None + add_1877 = torch.ops.aten.add.Tensor(add_1864, convert_element_type_1937); add_1864 = convert_element_type_1937 = None + convert_element_type_default_63 = torch.ops.prims.convert_element_type.default(sum_157, torch.float32); sum_157 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_63, 'avg', 64, '0'); convert_element_type_default_63 = None + wait_tensor_666 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + view_1900 = torch.ops.aten.view.default(add_1877, [8192, 2048]) + permute_736 = torch.ops.aten.permute.default(view_1900, [1, 0]) + permute_302 = torch.ops.aten.permute.default(getitem_281, [0, 2, 1, 3]) + view_1326 = torch.ops.aten.view.default(permute_302, [2, 4096, -1]); permute_302 = None + view_1328 = torch.ops.aten.view.default(view_1326, [8192, 2048]); view_1326 = None + mm_322 = torch.ops.aten.mm.default(permute_736, view_1328); permute_736 = view_1328 = None + convert_element_type_1082 = torch.ops.prims.convert_element_type.default(primals_332, torch.bfloat16); primals_332 = None + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1082, 64, '0'); convert_element_type_1082 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + permute_303 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); wait_tensor_415 = None + permute_738 = torch.ops.aten.permute.default(permute_303, [1, 0]); permute_303 = None + mm_323 = torch.ops.aten.mm.default(view_1900, permute_738); view_1900 = permute_738 = None + view_1901 = torch.ops.aten.view.default(mm_323, [2, 4096, 2048]); mm_323 = None + convert_element_type_1944 = torch.ops.prims.convert_element_type.default(mm_322, torch.float32); mm_322 = None + reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1944, 'avg', 64, '0'); convert_element_type_1944 = None + wait_tensor_667 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + view_1902 = torch.ops.aten.view.default(view_1901, [2, 4096, 16, 128]); view_1901 = None + permute_740 = torch.ops.aten.permute.default(view_1902, [0, 2, 1, 3]); view_1902 = None + fw_graph6 = self.fw_graph6 + joint_graph6 = self.joint_graph6 + mask_graph6 = self.mask_graph6 + flex_attention_backward_6 = torch.ops.higher_order.flex_attention_backward(permute_299, permute_300, permute_301, getitem_281, getitem_282, permute_740, None, fw_graph6, joint_graph6, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph6), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_299 = permute_300 = permute_301 = getitem_281 = getitem_282 = permute_740 = fw_graph6 = joint_graph6 = mask_graph6 = None + getitem_397 = flex_attention_backward_6[0] + getitem_398 = flex_attention_backward_6[1] + getitem_399 = flex_attention_backward_6[2]; flex_attention_backward_6 = None + permute_741 = torch.ops.aten.permute.default(getitem_399, [0, 2, 1, 3]); getitem_399 = None + permute_742 = torch.ops.aten.permute.default(getitem_398, [0, 2, 1, 3]); getitem_398 = None + permute_743 = torch.ops.aten.permute.default(getitem_397, [0, 2, 1, 3]); getitem_397 = None + slice_145 = torch.ops.aten.slice.Tensor(permute_742, 3, 0, 128) + slice_146 = torch.ops.aten.slice.Tensor(permute_742, 3, 128, 192); permute_742 = None + sum_158 = torch.ops.aten.sum.dim_IntList(slice_146, [2], True); slice_146 = None + cat_98 = torch.ops.aten.cat.default([slice_145, permute_741], 3); slice_145 = permute_741 = None + view_1903 = torch.ops.aten.view.default(cat_98, [2, 4096, 4096]); cat_98 = None + view_1904 = torch.ops.aten.view.default(view_1903, [8192, 4096]); view_1903 = None + permute_744 = torch.ops.aten.permute.default(view_1904, [1, 0]) + mm_324 = torch.ops.aten.mm.default(permute_744, view_1323); permute_744 = view_1323 = None + convert_element_type_1079 = torch.ops.prims.convert_element_type.default(primals_331, torch.bfloat16); primals_331 = None + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1079, 64, '0'); convert_element_type_1079 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + permute_746 = torch.ops.aten.permute.default(permute_298, [1, 0]); permute_298 = None + mm_325 = torch.ops.aten.mm.default(view_1904, permute_746); view_1904 = permute_746 = None + view_1905 = torch.ops.aten.view.default(mm_325, [2, 4096, 512]); mm_325 = None + convert_element_type_1949 = torch.ops.prims.convert_element_type.default(mm_324, torch.float32); mm_324 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1949, 'avg', 64, '0'); convert_element_type_1949 = None + wait_tensor_668 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + convert_element_type_1950 = torch.ops.prims.convert_element_type.default(view_1905, torch.float32); view_1905 = None + convert_element_type_1076 = torch.ops.prims.convert_element_type.default(primals_330, torch.bfloat16); primals_330 = None + all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1076, 64, '0'); convert_element_type_1076 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1952 = torch.ops.prims.convert_element_type.default(wait_tensor_413, torch.float32); wait_tensor_413 = None + mul_1515 = torch.ops.aten.mul.Tensor(convert_element_type_1950, convert_element_type_1952); convert_element_type_1952 = None + convert_element_type_1077 = torch.ops.prims.convert_element_type.default(getitem_277, torch.float32); getitem_277 = None + mul_944 = torch.ops.aten.mul.Tensor(convert_element_type_1077, rsqrt_61); convert_element_type_1077 = None + mul_1517 = torch.ops.aten.mul.Tensor(mul_944, mul_1515) + sum_159 = torch.ops.aten.sum.dim_IntList(mul_1517, [2], True); mul_1517 = None + div_172 = torch.ops.aten.div.Tensor(mul_944, 512) + mul_1518 = torch.ops.aten.mul.Tensor(div_172, sum_159); div_172 = sum_159 = None + sub_665 = torch.ops.aten.sub.Tensor(mul_1515, mul_1518); mul_1515 = mul_1518 = None + mul_1519 = torch.ops.aten.mul.Tensor(sub_665, rsqrt_61); sub_665 = rsqrt_61 = None + mul_1520 = torch.ops.aten.mul.Tensor(convert_element_type_1950, mul_944); convert_element_type_1950 = mul_944 = None + sum_160 = torch.ops.aten.sum.dim_IntList(mul_1520, [0, 1]); mul_1520 = None + convert_element_type_1953 = torch.ops.prims.convert_element_type.default(mul_1519, torch.bfloat16); mul_1519 = None + convert_element_type_default_62 = torch.ops.prims.convert_element_type.default(sum_160, torch.float32); sum_160 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_62, 'avg', 64, '0'); convert_element_type_default_62 = None + wait_tensor_669 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + convert_element_type_1956 = torch.ops.prims.convert_element_type.default(sum_158, torch.float32); sum_158 = None + view_1906 = torch.ops.aten.view.default(convert_element_type_1956, [2, 4096, 1, 32, 2]); convert_element_type_1956 = None + view_as_complex_66 = torch.ops.aten.view_as_complex.default(view_1906); view_1906 = None + mul_1521 = torch.ops.aten.mul.Tensor(view_as_complex_66, clone_9); view_as_complex_66 = None + view_as_real_66 = torch.ops.aten.view_as_real.default(mul_1521); mul_1521 = None + view_1907 = torch.ops.aten.view.default(view_as_real_66, [2, 4096, 1, 64]); view_as_real_66 = None + convert_element_type_1957 = torch.ops.prims.convert_element_type.default(view_1907, torch.bfloat16); view_1907 = None + squeeze_32 = torch.ops.aten.squeeze.dim(convert_element_type_1957, 2); convert_element_type_1957 = None + cat_99 = torch.ops.aten.cat.default([convert_element_type_1953, squeeze_32], 2); convert_element_type_1953 = squeeze_32 = None + view_1908 = torch.ops.aten.view.default(cat_99, [8192, 576]); cat_99 = None + permute_748 = torch.ops.aten.permute.default(view_1908, [1, 0]) + mm_326 = torch.ops.aten.mm.default(permute_748, view_1309); permute_748 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(primals_329, torch.bfloat16); primals_329 = None + all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1071, 64, '0'); convert_element_type_1071 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_412, [1, 0]); wait_tensor_412 = None + permute_750 = torch.ops.aten.permute.default(permute_297, [1, 0]); permute_297 = None + mm_327 = torch.ops.aten.mm.default(view_1908, permute_750); view_1908 = permute_750 = None + view_1909 = torch.ops.aten.view.default(mm_327, [2, 4096, 2048]); mm_327 = None + convert_element_type_1962 = torch.ops.prims.convert_element_type.default(mm_326, torch.float32); mm_326 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1962, 'avg', 64, '0'); convert_element_type_1962 = None + wait_tensor_670 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + slice_147 = torch.ops.aten.slice.Tensor(permute_743, 3, 0, 128) + slice_148 = torch.ops.aten.slice.Tensor(permute_743, 3, 128, 192); permute_743 = None + convert_element_type_1963 = torch.ops.prims.convert_element_type.default(slice_148, torch.float32); slice_148 = None + view_1910 = torch.ops.aten.view.default(convert_element_type_1963, [2, 4096, 16, 32, 2]); convert_element_type_1963 = None + view_as_complex_67 = torch.ops.aten.view_as_complex.default(view_1910); view_1910 = None + mul_1522 = torch.ops.aten.mul.Tensor(view_as_complex_67, clone_9); view_as_complex_67 = None + view_as_real_67 = torch.ops.aten.view_as_real.default(mul_1522); mul_1522 = None + view_1911 = torch.ops.aten.view.default(view_as_real_67, [2, 4096, 16, 64]); view_as_real_67 = None + convert_element_type_1964 = torch.ops.prims.convert_element_type.default(view_1911, torch.bfloat16); view_1911 = None + cat_100 = torch.ops.aten.cat.default([slice_147, convert_element_type_1964], 3); slice_147 = convert_element_type_1964 = None + view_1912 = torch.ops.aten.view.default(cat_100, [2, 4096, 3072]); cat_100 = None + view_1913 = torch.ops.aten.view.default(view_1912, [8192, 3072]); view_1912 = None + permute_752 = torch.ops.aten.permute.default(view_1913, [1, 0]) + mm_328 = torch.ops.aten.mm.default(permute_752, view_1309); permute_752 = view_1309 = None + convert_element_type_1066 = torch.ops.prims.convert_element_type.default(primals_328, torch.bfloat16); primals_328 = None + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1066, 64, '0'); convert_element_type_1066 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_411, [1, 0]); wait_tensor_411 = None + permute_754 = torch.ops.aten.permute.default(permute_296, [1, 0]); permute_296 = None + mm_329 = torch.ops.aten.mm.default(view_1913, permute_754); view_1913 = permute_754 = None + view_1914 = torch.ops.aten.view.default(mm_329, [2, 4096, 2048]); mm_329 = None + add_1878 = torch.ops.aten.add.Tensor(view_1909, view_1914); view_1909 = view_1914 = None + convert_element_type_1969 = torch.ops.prims.convert_element_type.default(mm_328, torch.float32); mm_328 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1969, 'avg', 64, '0'); convert_element_type_1969 = None + wait_tensor_671 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + convert_element_type_1970 = torch.ops.prims.convert_element_type.default(add_1878, torch.float32); add_1878 = None + convert_element_type_1063 = torch.ops.prims.convert_element_type.default(primals_327, torch.bfloat16); primals_327 = None + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1063, 64, '0'); convert_element_type_1063 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + convert_element_type_1972 = torch.ops.prims.convert_element_type.default(wait_tensor_410, torch.float32); wait_tensor_410 = None + mul_1523 = torch.ops.aten.mul.Tensor(convert_element_type_1970, convert_element_type_1972); convert_element_type_1972 = None + convert_element_type_1064 = torch.ops.prims.convert_element_type.default(add_1297, torch.float32); add_1297 = None + mul_940 = torch.ops.aten.mul.Tensor(convert_element_type_1064, rsqrt_60); convert_element_type_1064 = None + mul_1525 = torch.ops.aten.mul.Tensor(mul_940, mul_1523) + sum_161 = torch.ops.aten.sum.dim_IntList(mul_1525, [2], True); mul_1525 = None + div_173 = torch.ops.aten.div.Tensor(mul_940, 2048) + mul_1526 = torch.ops.aten.mul.Tensor(div_173, sum_161); div_173 = sum_161 = None + sub_666 = torch.ops.aten.sub.Tensor(mul_1523, mul_1526); mul_1523 = mul_1526 = None + mul_1527 = torch.ops.aten.mul.Tensor(sub_666, rsqrt_60); sub_666 = rsqrt_60 = None + mul_1528 = torch.ops.aten.mul.Tensor(convert_element_type_1970, mul_940); convert_element_type_1970 = mul_940 = None + sum_162 = torch.ops.aten.sum.dim_IntList(mul_1528, [0, 1]); mul_1528 = None + convert_element_type_1973 = torch.ops.prims.convert_element_type.default(mul_1527, torch.bfloat16); mul_1527 = None + add_1879 = torch.ops.aten.add.Tensor(add_1877, convert_element_type_1973); add_1877 = convert_element_type_1973 = None + convert_element_type_default_61 = torch.ops.prims.convert_element_type.default(sum_162, torch.float32); sum_162 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_61, 'avg', 64, '0'); convert_element_type_default_61 = None + wait_tensor_672 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + view_1915 = torch.ops.aten.view.default(add_1879, [8192, 2048]) + unsqueeze_60 = torch.ops.aten.unsqueeze.default(view_1915, 1) + convert_element_type_1976 = torch.ops.prims.convert_element_type.default(unsqueeze_60, torch.float32); unsqueeze_60 = None + bmm_40 = torch.ops.aten.bmm.default(permute_756, convert_element_type_1976); permute_756 = None + bmm_41 = torch.ops.aten.bmm.default(convert_element_type_1976, permute_757); convert_element_type_1976 = permute_757 = None + convert_element_type_1977 = torch.ops.prims.convert_element_type.default(bmm_40, torch.bfloat16); bmm_40 = None + view_1916 = torch.ops.aten.view.default(bmm_41, [8192, 6]); bmm_41 = None + view_1917 = torch.ops.aten.view.default(convert_element_type_1977, [49152, 2048]); convert_element_type_1977 = None + index_66 = torch.ops.aten.index.Tensor(view_1917, [getitem_273]); view_1917 = getitem_273 = None + permute_758 = torch.ops.aten.permute.default(view_1915, [1, 0]) + mm_330 = torch.ops.aten.mm.default(permute_758, mul_937); permute_758 = mul_937 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(primals_326, torch.bfloat16); primals_326 = None + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1058, 64, '0'); convert_element_type_1058 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + permute_760 = torch.ops.aten.permute.default(permute_295, [1, 0]); permute_295 = None + mm_331 = torch.ops.aten.mm.default(view_1915, permute_760); view_1915 = permute_760 = None + convert_element_type_1982 = torch.ops.prims.convert_element_type.default(mm_330, torch.float32); mm_330 = None + reduce_scatter_tensor_100 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1982, 'avg', 64, '0'); convert_element_type_1982 = None + wait_tensor_673 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); reduce_scatter_tensor_100 = None + convert_element_type_1053 = torch.ops.prims.convert_element_type.default(mm_156, torch.float32); mm_156 = None + neg_38 = torch.ops.aten.neg.default(convert_element_type_1053) + exp_57 = torch.ops.aten.exp.default(neg_38); neg_38 = None + add_1292 = torch.ops.aten.add.Tensor(exp_57, 1); exp_57 = None + div_95 = torch.ops.aten.div.Tensor(convert_element_type_1053, add_1292) + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(div_95, torch.bfloat16); div_95 = None + mul_1529 = torch.ops.aten.mul.Tensor(mm_331, convert_element_type_1054); convert_element_type_1054 = None + mul_1530 = torch.ops.aten.mul.Tensor(mm_331, mm_157); mm_331 = mm_157 = None + permute_762 = torch.ops.aten.permute.default(mul_1529, [1, 0]) + mm_332 = torch.ops.aten.mm.default(permute_762, view_1264); permute_762 = None + convert_element_type_1055 = torch.ops.prims.convert_element_type.default(primals_325, torch.bfloat16); primals_325 = None + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1055, 64, '0'); convert_element_type_1055 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + permute_764 = torch.ops.aten.permute.default(permute_294, [1, 0]); permute_294 = None + mm_333 = torch.ops.aten.mm.default(mul_1529, permute_764); mul_1529 = permute_764 = None + convert_element_type_1987 = torch.ops.prims.convert_element_type.default(mm_332, torch.float32); mm_332 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1987, 'avg', 64, '0'); convert_element_type_1987 = None + wait_tensor_674 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + convert_element_type_1988 = torch.ops.prims.convert_element_type.default(mul_1530, torch.float32); mul_1530 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_1292); add_1292 = None + mul_1531 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_1532 = torch.ops.aten.mul.Tensor(convert_element_type_1988, mul_1531); convert_element_type_1988 = None + sub_667 = torch.ops.aten.sub.Tensor(1, mul_1531); mul_1531 = None + mul_1533 = torch.ops.aten.mul.Tensor(convert_element_type_1053, sub_667); convert_element_type_1053 = sub_667 = None + add_1881 = torch.ops.aten.add.Tensor(mul_1533, 1); mul_1533 = None + mul_1534 = torch.ops.aten.mul.Tensor(mul_1532, add_1881); mul_1532 = add_1881 = None + convert_element_type_1990 = torch.ops.prims.convert_element_type.default(mul_1534, torch.bfloat16); mul_1534 = None + permute_766 = torch.ops.aten.permute.default(convert_element_type_1990, [1, 0]) + mm_334 = torch.ops.aten.mm.default(permute_766, view_1264); permute_766 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(primals_324, torch.bfloat16); primals_324 = None + all_gather_into_tensor_331 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1050, 64, '0'); convert_element_type_1050 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + permute_768 = torch.ops.aten.permute.default(permute_293, [1, 0]); permute_293 = None + mm_335 = torch.ops.aten.mm.default(convert_element_type_1990, permute_768); convert_element_type_1990 = permute_768 = None + add_1882 = torch.ops.aten.add.Tensor(mm_333, mm_335); mm_333 = mm_335 = None + convert_element_type_1995 = torch.ops.prims.convert_element_type.default(mm_334, torch.float32); mm_334 = None + reduce_scatter_tensor_102 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1995, 'avg', 64, '0'); convert_element_type_1995 = None + wait_tensor_675 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); reduce_scatter_tensor_102 = None + all_to_all_single_92 = torch.ops._c10d_functional.all_to_all_single.default(index_66, [_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303], [_local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295], '521'); index_66 = None + wait_tensor_676 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_92); all_to_all_single_92 = None + full_376 = torch.ops.aten.full.default([sym_size_int_73, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_73 = None + slice_scatter_7 = torch.ops.aten.slice_scatter.default(full_376, wait_tensor_676, 0, 0, -1); wait_tensor_676 = None + index_67 = torch.ops.aten.index.Tensor(slice_scatter_7, [getitem_274]); slice_scatter_7 = None + permute_770 = torch.ops.aten.permute.default(index_67, [1, 0]) + _grouped_mm_120 = torch.ops.aten._grouped_mm.default(permute_770, mul_917, cumsum_56); permute_770 = mul_917 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(primals_322, torch.bfloat16); primals_322 = None + all_gather_into_tensor_327 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1044, 8, '513'); convert_element_type_1044 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_327); all_gather_into_tensor_327 = None + permute_292 = torch.ops.aten.permute.default(wait_tensor_402, [0, 2, 1]); wait_tensor_402 = None + permute_772 = torch.ops.aten.permute.default(permute_292, [0, 2, 1]); permute_292 = None + _grouped_mm_121 = torch.ops.aten._grouped_mm.default(index_67, permute_772, cumsum_56); index_67 = permute_772 = None + convert_element_type_1048 = torch.ops.prims.convert_element_type.default(_grouped_mm_54, torch.float32); _grouped_mm_54 = None + neg_37 = torch.ops.aten.neg.default(convert_element_type_1048) + exp_56 = torch.ops.aten.exp.default(neg_37); neg_37 = None + add_1256 = torch.ops.aten.add.Tensor(exp_56, 1); exp_56 = None + div_94 = torch.ops.aten.div.Tensor(convert_element_type_1048, add_1256) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(div_94, torch.bfloat16); div_94 = None + mul_1535 = torch.ops.aten.mul.Tensor(_grouped_mm_121, convert_element_type_1049); convert_element_type_1049 = None + mul_1536 = torch.ops.aten.mul.Tensor(_grouped_mm_121, _grouped_mm_55); _grouped_mm_121 = _grouped_mm_55 = None + permute_774 = torch.ops.aten.permute.default(mul_1535, [1, 0]) + _grouped_mm_122 = torch.ops.aten._grouped_mm.default(permute_774, index_37, cumsum_56); permute_774 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(primals_323, torch.bfloat16); primals_323 = None + all_gather_into_tensor_328 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1045, 8, '513'); convert_element_type_1045 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_328); all_gather_into_tensor_328 = None + permute_291 = torch.ops.aten.permute.default(wait_tensor_403, [0, 2, 1]); wait_tensor_403 = None + permute_776 = torch.ops.aten.permute.default(permute_291, [0, 2, 1]); permute_291 = None + _grouped_mm_123 = torch.ops.aten._grouped_mm.default(mul_1535, permute_776, cumsum_56); mul_1535 = permute_776 = None + convert_element_type_1996 = torch.ops.prims.convert_element_type.default(mul_1536, torch.float32); mul_1536 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_1256); add_1256 = None + mul_1537 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_1538 = torch.ops.aten.mul.Tensor(convert_element_type_1996, mul_1537); convert_element_type_1996 = None + sub_668 = torch.ops.aten.sub.Tensor(1, mul_1537); mul_1537 = None + mul_1539 = torch.ops.aten.mul.Tensor(convert_element_type_1048, sub_668); convert_element_type_1048 = sub_668 = None + add_1884 = torch.ops.aten.add.Tensor(mul_1539, 1); mul_1539 = None + mul_1540 = torch.ops.aten.mul.Tensor(mul_1538, add_1884); mul_1538 = add_1884 = None + convert_element_type_1998 = torch.ops.prims.convert_element_type.default(mul_1540, torch.bfloat16); mul_1540 = None + permute_778 = torch.ops.aten.permute.default(convert_element_type_1998, [1, 0]) + _grouped_mm_124 = torch.ops.aten._grouped_mm.default(permute_778, index_37, cumsum_56); permute_778 = index_37 = None + convert_element_type_1042 = torch.ops.prims.convert_element_type.default(primals_321, torch.bfloat16); primals_321 = None + all_gather_into_tensor_325 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1042, 8, '513'); convert_element_type_1042 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_325); all_gather_into_tensor_325 = None + permute_290 = torch.ops.aten.permute.default(wait_tensor_400, [0, 2, 1]); wait_tensor_400 = None + permute_780 = torch.ops.aten.permute.default(permute_290, [0, 2, 1]); permute_290 = None + _grouped_mm_125 = torch.ops.aten._grouped_mm.default(convert_element_type_1998, permute_780, cumsum_56); convert_element_type_1998 = permute_780 = cumsum_56 = None + add_1885 = torch.ops.aten.add.Tensor(_grouped_mm_123, _grouped_mm_125); _grouped_mm_123 = _grouped_mm_125 = None + convert_element_type_1999 = torch.ops.prims.convert_element_type.default(_grouped_mm_122, torch.float32); _grouped_mm_122 = None + div_174 = torch.ops.aten.div.Tensor(convert_element_type_1999, 64); convert_element_type_1999 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_174, 'sum', 8, '513'); div_174 = None + wait_tensor_677 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + convert_element_type_2000 = torch.ops.prims.convert_element_type.default(_grouped_mm_120, torch.float32); _grouped_mm_120 = None + div_175 = torch.ops.aten.div.Tensor(convert_element_type_2000, 64); convert_element_type_2000 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_175, 'sum', 8, '513'); div_175 = None + wait_tensor_678 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + convert_element_type_2001 = torch.ops.prims.convert_element_type.default(_grouped_mm_124, torch.float32); _grouped_mm_124 = None + div_176 = torch.ops.aten.div.Tensor(convert_element_type_2001, 64); convert_element_type_2001 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_176, 'sum', 8, '513'); div_176 = None + wait_tensor_679 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + index_put_66 = torch.ops.aten.index_put.default(full_376, [getitem_274], add_1885, True); full_376 = getitem_274 = add_1885 = None + slice_149 = torch.ops.aten.slice.Tensor(index_put_66, 0, 0, add_1886); index_put_66 = add_1886 = None + all_to_all_single_93 = torch.ops._c10d_functional.all_to_all_single.default(slice_149, [_local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295], [_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303], '521'); slice_149 = _local_scalar_dense_288 = _local_scalar_dense_289 = _local_scalar_dense_290 = _local_scalar_dense_291 = _local_scalar_dense_292 = _local_scalar_dense_293 = _local_scalar_dense_294 = _local_scalar_dense_295 = _local_scalar_dense_296 = _local_scalar_dense_297 = _local_scalar_dense_298 = _local_scalar_dense_299 = _local_scalar_dense_300 = _local_scalar_dense_301 = _local_scalar_dense_302 = _local_scalar_dense_303 = None + wait_tensor_680 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_93); all_to_all_single_93 = None + index_put_67 = torch.ops.aten.index_put.default(full_default_52, [div_92], wait_tensor_680, True); div_92 = wait_tensor_680 = None + add_1890 = torch.ops.aten.add.Tensor(add_1882, index_put_67); add_1882 = index_put_67 = None + mul_1541 = torch.ops.aten.mul.Tensor(view_1916, 1.0); view_1916 = None + scatter_add_7 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_271, mul_1541); getitem_271 = mul_1541 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(mm_155, torch.float32); mm_155 = None + sub_432 = torch.ops.aten.sub.Tensor(convert_element_type_1037, amax_18); convert_element_type_1037 = amax_18 = None + exp_55 = torch.ops.aten.exp.default(sub_432); sub_432 = None + div_91 = torch.ops.aten.div.Tensor(exp_55, sum_73); exp_55 = sum_73 = None + mul_1542 = torch.ops.aten.mul.Tensor(scatter_add_7, div_91); scatter_add_7 = None + sum_163 = torch.ops.aten.sum.dim_IntList(mul_1542, [1], True) + neg_76 = torch.ops.aten.neg.default(div_91); div_91 = None + fma_7 = torch.ops.prims.fma.default(neg_76, sum_163, mul_1542); neg_76 = sum_163 = mul_1542 = None + convert_element_type_2002 = torch.ops.prims.convert_element_type.default(fma_7, torch.bfloat16); fma_7 = None + permute_782 = torch.ops.aten.permute.default(convert_element_type_2002, [1, 0]) + mm_336 = torch.ops.aten.mm.default(permute_782, view_1264); permute_782 = view_1264 = None + convert_element_type_1034 = torch.ops.prims.convert_element_type.default(primals_319, torch.bfloat16); primals_319 = None + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1034, 64, '0'); convert_element_type_1034 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + permute_289 = torch.ops.aten.permute.default(wait_tensor_396, [1, 0]); wait_tensor_396 = None + permute_784 = torch.ops.aten.permute.default(permute_289, [1, 0]); permute_289 = None + mm_337 = torch.ops.aten.mm.default(convert_element_type_2002, permute_784); convert_element_type_2002 = permute_784 = None + add_1891 = torch.ops.aten.add.Tensor(add_1890, mm_337); add_1890 = mm_337 = None + convert_element_type_2007 = torch.ops.prims.convert_element_type.default(mm_336, torch.float32); mm_336 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2007, 'avg', 64, '0'); convert_element_type_2007 = None + wait_tensor_681 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + view_1918 = torch.ops.aten.view.default(add_1891, [2, 4096, 2048]); add_1891 = None + convert_element_type_2008 = torch.ops.prims.convert_element_type.default(view_1918, torch.float32); view_1918 = None + convert_element_type_1031 = torch.ops.prims.convert_element_type.default(primals_317, torch.bfloat16); primals_317 = None + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1031, 64, '0'); convert_element_type_1031 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + convert_element_type_2010 = torch.ops.prims.convert_element_type.default(wait_tensor_395, torch.float32); wait_tensor_395 = None + mul_1543 = torch.ops.aten.mul.Tensor(convert_element_type_2008, convert_element_type_2010); convert_element_type_2010 = None + convert_element_type_1032 = torch.ops.prims.convert_element_type.default(add_1232, torch.float32); add_1232 = None + mul_897 = torch.ops.aten.mul.Tensor(convert_element_type_1032, rsqrt_59); convert_element_type_1032 = None + mul_1545 = torch.ops.aten.mul.Tensor(mul_897, mul_1543) + sum_164 = torch.ops.aten.sum.dim_IntList(mul_1545, [2], True); mul_1545 = None + div_177 = torch.ops.aten.div.Tensor(mul_897, 2048) + mul_1546 = torch.ops.aten.mul.Tensor(div_177, sum_164); div_177 = sum_164 = None + sub_670 = torch.ops.aten.sub.Tensor(mul_1543, mul_1546); mul_1543 = mul_1546 = None + mul_1547 = torch.ops.aten.mul.Tensor(sub_670, rsqrt_59); sub_670 = rsqrt_59 = None + mul_1548 = torch.ops.aten.mul.Tensor(convert_element_type_2008, mul_897); convert_element_type_2008 = mul_897 = None + sum_165 = torch.ops.aten.sum.dim_IntList(mul_1548, [0, 1]); mul_1548 = None + convert_element_type_2011 = torch.ops.prims.convert_element_type.default(mul_1547, torch.bfloat16); mul_1547 = None + add_1892 = torch.ops.aten.add.Tensor(add_1879, convert_element_type_2011); add_1879 = convert_element_type_2011 = None + convert_element_type_default_60 = torch.ops.prims.convert_element_type.default(sum_165, torch.float32); sum_165 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_60, 'avg', 64, '0'); convert_element_type_default_60 = None + wait_tensor_682 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_1919 = torch.ops.aten.view.default(add_1892, [8192, 2048]) + permute_786 = torch.ops.aten.permute.default(view_1919, [1, 0]) + permute_287 = torch.ops.aten.permute.default(getitem_267, [0, 2, 1, 3]) + view_1259 = torch.ops.aten.view.default(permute_287, [2, 4096, -1]); permute_287 = None + view_1261 = torch.ops.aten.view.default(view_1259, [8192, 2048]); view_1259 = None + mm_338 = torch.ops.aten.mm.default(permute_786, view_1261); permute_786 = view_1261 = None + convert_element_type_1028 = torch.ops.prims.convert_element_type.default(primals_316, torch.bfloat16); primals_316 = None + all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1028, 64, '0'); convert_element_type_1028 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + permute_788 = torch.ops.aten.permute.default(permute_288, [1, 0]); permute_288 = None + mm_339 = torch.ops.aten.mm.default(view_1919, permute_788); view_1919 = permute_788 = None + view_1920 = torch.ops.aten.view.default(mm_339, [2, 4096, 2048]); mm_339 = None + convert_element_type_2018 = torch.ops.prims.convert_element_type.default(mm_338, torch.float32); mm_338 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2018, 'avg', 64, '0'); convert_element_type_2018 = None + wait_tensor_683 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + view_1921 = torch.ops.aten.view.default(view_1920, [2, 4096, 16, 128]); view_1920 = None + permute_790 = torch.ops.aten.permute.default(view_1921, [0, 2, 1, 3]); view_1921 = None + fw_graph7 = self.fw_graph7 + joint_graph7 = self.joint_graph7 + mask_graph7 = self.mask_graph7 + flex_attention_backward_7 = torch.ops.higher_order.flex_attention_backward(permute_284, permute_285, permute_286, getitem_267, getitem_268, permute_790, None, fw_graph7, joint_graph7, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph7), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_284 = permute_285 = permute_286 = getitem_267 = getitem_268 = permute_790 = fw_graph7 = joint_graph7 = mask_graph7 = None + getitem_401 = flex_attention_backward_7[0] + getitem_402 = flex_attention_backward_7[1] + getitem_403 = flex_attention_backward_7[2]; flex_attention_backward_7 = None + permute_791 = torch.ops.aten.permute.default(getitem_403, [0, 2, 1, 3]); getitem_403 = None + permute_792 = torch.ops.aten.permute.default(getitem_402, [0, 2, 1, 3]); getitem_402 = None + permute_793 = torch.ops.aten.permute.default(getitem_401, [0, 2, 1, 3]); getitem_401 = None + slice_151 = torch.ops.aten.slice.Tensor(permute_792, 3, 0, 128) + slice_152 = torch.ops.aten.slice.Tensor(permute_792, 3, 128, 192); permute_792 = None + sum_166 = torch.ops.aten.sum.dim_IntList(slice_152, [2], True); slice_152 = None + cat_101 = torch.ops.aten.cat.default([slice_151, permute_791], 3); slice_151 = permute_791 = None + view_1922 = torch.ops.aten.view.default(cat_101, [2, 4096, 4096]); cat_101 = None + view_1923 = torch.ops.aten.view.default(view_1922, [8192, 4096]); view_1922 = None + permute_794 = torch.ops.aten.permute.default(view_1923, [1, 0]) + mm_340 = torch.ops.aten.mm.default(permute_794, view_1256); permute_794 = view_1256 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(primals_315, torch.bfloat16); primals_315 = None + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1025, 64, '0'); convert_element_type_1025 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_393, [1, 0]); wait_tensor_393 = None + permute_796 = torch.ops.aten.permute.default(permute_283, [1, 0]); permute_283 = None + mm_341 = torch.ops.aten.mm.default(view_1923, permute_796); view_1923 = permute_796 = None + view_1924 = torch.ops.aten.view.default(mm_341, [2, 4096, 512]); mm_341 = None + convert_element_type_2023 = torch.ops.prims.convert_element_type.default(mm_340, torch.float32); mm_340 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2023, 'avg', 64, '0'); convert_element_type_2023 = None + wait_tensor_684 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 = None + convert_element_type_2024 = torch.ops.prims.convert_element_type.default(view_1924, torch.float32); view_1924 = None + convert_element_type_1022 = torch.ops.prims.convert_element_type.default(primals_314, torch.bfloat16); primals_314 = None + all_gather_into_tensor_320 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1022, 64, '0'); convert_element_type_1022 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_2026 = torch.ops.prims.convert_element_type.default(wait_tensor_392, torch.float32); wait_tensor_392 = None + mul_1549 = torch.ops.aten.mul.Tensor(convert_element_type_2024, convert_element_type_2026); convert_element_type_2026 = None + convert_element_type_1023 = torch.ops.prims.convert_element_type.default(getitem_263, torch.float32); getitem_263 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_1023, rsqrt_58); convert_element_type_1023 = None + mul_1551 = torch.ops.aten.mul.Tensor(mul_895, mul_1549) + sum_167 = torch.ops.aten.sum.dim_IntList(mul_1551, [2], True); mul_1551 = None + div_178 = torch.ops.aten.div.Tensor(mul_895, 512) + mul_1552 = torch.ops.aten.mul.Tensor(div_178, sum_167); div_178 = sum_167 = None + sub_671 = torch.ops.aten.sub.Tensor(mul_1549, mul_1552); mul_1549 = mul_1552 = None + mul_1553 = torch.ops.aten.mul.Tensor(sub_671, rsqrt_58); sub_671 = rsqrt_58 = None + mul_1554 = torch.ops.aten.mul.Tensor(convert_element_type_2024, mul_895); convert_element_type_2024 = mul_895 = None + sum_168 = torch.ops.aten.sum.dim_IntList(mul_1554, [0, 1]); mul_1554 = None + convert_element_type_2027 = torch.ops.prims.convert_element_type.default(mul_1553, torch.bfloat16); mul_1553 = None + convert_element_type_default_59 = torch.ops.prims.convert_element_type.default(sum_168, torch.float32); sum_168 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_59, 'avg', 64, '0'); convert_element_type_default_59 = None + wait_tensor_685 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + convert_element_type_2030 = torch.ops.prims.convert_element_type.default(sum_166, torch.float32); sum_166 = None + view_1925 = torch.ops.aten.view.default(convert_element_type_2030, [2, 4096, 1, 32, 2]); convert_element_type_2030 = None + view_as_complex_68 = torch.ops.aten.view_as_complex.default(view_1925); view_1925 = None + mul_1555 = torch.ops.aten.mul.Tensor(view_as_complex_68, clone_9); view_as_complex_68 = None + view_as_real_68 = torch.ops.aten.view_as_real.default(mul_1555); mul_1555 = None + view_1926 = torch.ops.aten.view.default(view_as_real_68, [2, 4096, 1, 64]); view_as_real_68 = None + convert_element_type_2031 = torch.ops.prims.convert_element_type.default(view_1926, torch.bfloat16); view_1926 = None + squeeze_33 = torch.ops.aten.squeeze.dim(convert_element_type_2031, 2); convert_element_type_2031 = None + cat_102 = torch.ops.aten.cat.default([convert_element_type_2027, squeeze_33], 2); convert_element_type_2027 = squeeze_33 = None + view_1927 = torch.ops.aten.view.default(cat_102, [8192, 576]); cat_102 = None + permute_798 = torch.ops.aten.permute.default(view_1927, [1, 0]) + mm_342 = torch.ops.aten.mm.default(permute_798, view_1242); permute_798 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(primals_313, torch.bfloat16); primals_313 = None + all_gather_into_tensor_319 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1017, 64, '0'); convert_element_type_1017 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_391, [1, 0]); wait_tensor_391 = None + permute_800 = torch.ops.aten.permute.default(permute_282, [1, 0]); permute_282 = None + mm_343 = torch.ops.aten.mm.default(view_1927, permute_800); view_1927 = permute_800 = None + view_1928 = torch.ops.aten.view.default(mm_343, [2, 4096, 2048]); mm_343 = None + convert_element_type_2036 = torch.ops.prims.convert_element_type.default(mm_342, torch.float32); mm_342 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2036, 'avg', 64, '0'); convert_element_type_2036 = None + wait_tensor_686 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + slice_153 = torch.ops.aten.slice.Tensor(permute_793, 3, 0, 128) + slice_154 = torch.ops.aten.slice.Tensor(permute_793, 3, 128, 192); permute_793 = None + convert_element_type_2037 = torch.ops.prims.convert_element_type.default(slice_154, torch.float32); slice_154 = None + view_1929 = torch.ops.aten.view.default(convert_element_type_2037, [2, 4096, 16, 32, 2]); convert_element_type_2037 = None + view_as_complex_69 = torch.ops.aten.view_as_complex.default(view_1929); view_1929 = None + mul_1556 = torch.ops.aten.mul.Tensor(view_as_complex_69, clone_9); view_as_complex_69 = None + view_as_real_69 = torch.ops.aten.view_as_real.default(mul_1556); mul_1556 = None + view_1930 = torch.ops.aten.view.default(view_as_real_69, [2, 4096, 16, 64]); view_as_real_69 = None + convert_element_type_2038 = torch.ops.prims.convert_element_type.default(view_1930, torch.bfloat16); view_1930 = None + cat_103 = torch.ops.aten.cat.default([slice_153, convert_element_type_2038], 3); slice_153 = convert_element_type_2038 = None + view_1931 = torch.ops.aten.view.default(cat_103, [2, 4096, 3072]); cat_103 = None + view_1932 = torch.ops.aten.view.default(view_1931, [8192, 3072]); view_1931 = None + permute_802 = torch.ops.aten.permute.default(view_1932, [1, 0]) + mm_344 = torch.ops.aten.mm.default(permute_802, view_1242); permute_802 = view_1242 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(primals_312, torch.bfloat16); primals_312 = None + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 64, '0'); convert_element_type_1012 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_281 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + permute_804 = torch.ops.aten.permute.default(permute_281, [1, 0]); permute_281 = None + mm_345 = torch.ops.aten.mm.default(view_1932, permute_804); view_1932 = permute_804 = None + view_1933 = torch.ops.aten.view.default(mm_345, [2, 4096, 2048]); mm_345 = None + add_1893 = torch.ops.aten.add.Tensor(view_1928, view_1933); view_1928 = view_1933 = None + convert_element_type_2043 = torch.ops.prims.convert_element_type.default(mm_344, torch.float32); mm_344 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2043, 'avg', 64, '0'); convert_element_type_2043 = None + wait_tensor_687 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + convert_element_type_2044 = torch.ops.prims.convert_element_type.default(add_1893, torch.float32); add_1893 = None + convert_element_type_1009 = torch.ops.prims.convert_element_type.default(primals_311, torch.bfloat16); primals_311 = None + all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1009, 64, '0'); convert_element_type_1009 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + convert_element_type_2046 = torch.ops.prims.convert_element_type.default(wait_tensor_389, torch.float32); wait_tensor_389 = None + mul_1557 = torch.ops.aten.mul.Tensor(convert_element_type_2044, convert_element_type_2046); convert_element_type_2046 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(add_1229, torch.float32); add_1229 = None + mul_891 = torch.ops.aten.mul.Tensor(convert_element_type_1010, rsqrt_57); convert_element_type_1010 = None + mul_1559 = torch.ops.aten.mul.Tensor(mul_891, mul_1557) + sum_169 = torch.ops.aten.sum.dim_IntList(mul_1559, [2], True); mul_1559 = None + div_179 = torch.ops.aten.div.Tensor(mul_891, 2048) + mul_1560 = torch.ops.aten.mul.Tensor(div_179, sum_169); div_179 = sum_169 = None + sub_672 = torch.ops.aten.sub.Tensor(mul_1557, mul_1560); mul_1557 = mul_1560 = None + mul_1561 = torch.ops.aten.mul.Tensor(sub_672, rsqrt_57); sub_672 = rsqrt_57 = None + mul_1562 = torch.ops.aten.mul.Tensor(convert_element_type_2044, mul_891); convert_element_type_2044 = mul_891 = None + sum_170 = torch.ops.aten.sum.dim_IntList(mul_1562, [0, 1]); mul_1562 = None + convert_element_type_2047 = torch.ops.prims.convert_element_type.default(mul_1561, torch.bfloat16); mul_1561 = None + add_1894 = torch.ops.aten.add.Tensor(add_1892, convert_element_type_2047); add_1892 = convert_element_type_2047 = None + convert_element_type_default_58 = torch.ops.prims.convert_element_type.default(sum_170, torch.float32); sum_170 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_58, 'avg', 64, '0'); convert_element_type_default_58 = None + wait_tensor_688 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + view_1934 = torch.ops.aten.view.default(add_1894, [8192, 2048]) + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1934, 1) + convert_element_type_2050 = torch.ops.prims.convert_element_type.default(unsqueeze_61, torch.float32); unsqueeze_61 = None + bmm_42 = torch.ops.aten.bmm.default(permute_806, convert_element_type_2050); permute_806 = None + bmm_43 = torch.ops.aten.bmm.default(convert_element_type_2050, permute_807); convert_element_type_2050 = permute_807 = None + convert_element_type_2051 = torch.ops.prims.convert_element_type.default(bmm_42, torch.bfloat16); bmm_42 = None + view_1935 = torch.ops.aten.view.default(bmm_43, [8192, 6]); bmm_43 = None + view_1936 = torch.ops.aten.view.default(convert_element_type_2051, [49152, 2048]); convert_element_type_2051 = None + index_68 = torch.ops.aten.index.Tensor(view_1936, [getitem_259]); view_1936 = getitem_259 = None + permute_808 = torch.ops.aten.permute.default(view_1934, [1, 0]) + mm_346 = torch.ops.aten.mm.default(permute_808, mul_888); permute_808 = mul_888 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(primals_310, torch.bfloat16); primals_310 = None + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1004, 64, '0'); convert_element_type_1004 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + permute_280 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + permute_810 = torch.ops.aten.permute.default(permute_280, [1, 0]); permute_280 = None + mm_347 = torch.ops.aten.mm.default(view_1934, permute_810); view_1934 = permute_810 = None + convert_element_type_2056 = torch.ops.prims.convert_element_type.default(mm_346, torch.float32); mm_346 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2056, 'avg', 64, '0'); convert_element_type_2056 = None + wait_tensor_689 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + convert_element_type_999 = torch.ops.prims.convert_element_type.default(mm_148, torch.float32); mm_148 = None + neg_36 = torch.ops.aten.neg.default(convert_element_type_999) + exp_54 = torch.ops.aten.exp.default(neg_36); neg_36 = None + add_1224 = torch.ops.aten.add.Tensor(exp_54, 1); exp_54 = None + div_90 = torch.ops.aten.div.Tensor(convert_element_type_999, add_1224) + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(div_90, torch.bfloat16); div_90 = None + mul_1563 = torch.ops.aten.mul.Tensor(mm_347, convert_element_type_1000); convert_element_type_1000 = None + mul_1564 = torch.ops.aten.mul.Tensor(mm_347, mm_149); mm_347 = mm_149 = None + permute_812 = torch.ops.aten.permute.default(mul_1563, [1, 0]) + mm_348 = torch.ops.aten.mm.default(permute_812, view_1197); permute_812 = None + convert_element_type_1001 = torch.ops.prims.convert_element_type.default(primals_309, torch.bfloat16); primals_309 = None + all_gather_into_tensor_315 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1001, 64, '0'); convert_element_type_1001 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + permute_279 = torch.ops.aten.permute.default(wait_tensor_387, [1, 0]); wait_tensor_387 = None + permute_814 = torch.ops.aten.permute.default(permute_279, [1, 0]); permute_279 = None + mm_349 = torch.ops.aten.mm.default(mul_1563, permute_814); mul_1563 = permute_814 = None + convert_element_type_2061 = torch.ops.prims.convert_element_type.default(mm_348, torch.float32); mm_348 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2061, 'avg', 64, '0'); convert_element_type_2061 = None + wait_tensor_690 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + convert_element_type_2062 = torch.ops.prims.convert_element_type.default(mul_1564, torch.float32); mul_1564 = None + reciprocal_16 = torch.ops.aten.reciprocal.default(add_1224); add_1224 = None + mul_1565 = torch.ops.aten.mul.Tensor(reciprocal_16, 1); reciprocal_16 = None + mul_1566 = torch.ops.aten.mul.Tensor(convert_element_type_2062, mul_1565); convert_element_type_2062 = None + sub_673 = torch.ops.aten.sub.Tensor(1, mul_1565); mul_1565 = None + mul_1567 = torch.ops.aten.mul.Tensor(convert_element_type_999, sub_673); convert_element_type_999 = sub_673 = None + add_1896 = torch.ops.aten.add.Tensor(mul_1567, 1); mul_1567 = None + mul_1568 = torch.ops.aten.mul.Tensor(mul_1566, add_1896); mul_1566 = add_1896 = None + convert_element_type_2064 = torch.ops.prims.convert_element_type.default(mul_1568, torch.bfloat16); mul_1568 = None + permute_816 = torch.ops.aten.permute.default(convert_element_type_2064, [1, 0]) + mm_350 = torch.ops.aten.mm.default(permute_816, view_1197); permute_816 = None + convert_element_type_996 = torch.ops.prims.convert_element_type.default(primals_308, torch.bfloat16); primals_308 = None + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_996, 64, '0'); convert_element_type_996 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_278 = torch.ops.aten.permute.default(wait_tensor_386, [1, 0]); wait_tensor_386 = None + permute_818 = torch.ops.aten.permute.default(permute_278, [1, 0]); permute_278 = None + mm_351 = torch.ops.aten.mm.default(convert_element_type_2064, permute_818); convert_element_type_2064 = permute_818 = None + add_1897 = torch.ops.aten.add.Tensor(mm_349, mm_351); mm_349 = mm_351 = None + convert_element_type_2069 = torch.ops.prims.convert_element_type.default(mm_350, torch.float32); mm_350 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2069, 'avg', 64, '0'); convert_element_type_2069 = None + wait_tensor_691 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + all_to_all_single_94 = torch.ops._c10d_functional.all_to_all_single.default(index_68, [_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287], [_local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279], '521'); index_68 = None + wait_tensor_692 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_94); all_to_all_single_94 = None + full_380 = torch.ops.aten.full.default([sym_size_int_69, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_69 = None + slice_scatter_8 = torch.ops.aten.slice_scatter.default(full_380, wait_tensor_692, 0, 0, -1); wait_tensor_692 = None + index_69 = torch.ops.aten.index.Tensor(slice_scatter_8, [getitem_260]); slice_scatter_8 = None + permute_820 = torch.ops.aten.permute.default(index_69, [1, 0]) + _grouped_mm_126 = torch.ops.aten._grouped_mm.default(permute_820, mul_868, cumsum_53); permute_820 = mul_868 = None + convert_element_type_990 = torch.ops.prims.convert_element_type.default(primals_306, torch.bfloat16); primals_306 = None + all_gather_into_tensor_310 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_990, 8, '513'); convert_element_type_990 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_310); all_gather_into_tensor_310 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_381, [0, 2, 1]); wait_tensor_381 = None + permute_822 = torch.ops.aten.permute.default(permute_277, [0, 2, 1]); permute_277 = None + _grouped_mm_127 = torch.ops.aten._grouped_mm.default(index_69, permute_822, cumsum_53); index_69 = permute_822 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(_grouped_mm_51, torch.float32); _grouped_mm_51 = None + neg_35 = torch.ops.aten.neg.default(convert_element_type_994) + exp_53 = torch.ops.aten.exp.default(neg_35); neg_35 = None + add_1188 = torch.ops.aten.add.Tensor(exp_53, 1); exp_53 = None + div_89 = torch.ops.aten.div.Tensor(convert_element_type_994, add_1188) + convert_element_type_995 = torch.ops.prims.convert_element_type.default(div_89, torch.bfloat16); div_89 = None + mul_1569 = torch.ops.aten.mul.Tensor(_grouped_mm_127, convert_element_type_995); convert_element_type_995 = None + mul_1570 = torch.ops.aten.mul.Tensor(_grouped_mm_127, _grouped_mm_52); _grouped_mm_127 = _grouped_mm_52 = None + permute_824 = torch.ops.aten.permute.default(mul_1569, [1, 0]) + _grouped_mm_128 = torch.ops.aten._grouped_mm.default(permute_824, index_35, cumsum_53); permute_824 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_307, torch.bfloat16); primals_307 = None + all_gather_into_tensor_311 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 8, '513'); convert_element_type_991 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_311); all_gather_into_tensor_311 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_382, [0, 2, 1]); wait_tensor_382 = None + permute_826 = torch.ops.aten.permute.default(permute_276, [0, 2, 1]); permute_276 = None + _grouped_mm_129 = torch.ops.aten._grouped_mm.default(mul_1569, permute_826, cumsum_53); mul_1569 = permute_826 = None + convert_element_type_2070 = torch.ops.prims.convert_element_type.default(mul_1570, torch.float32); mul_1570 = None + reciprocal_17 = torch.ops.aten.reciprocal.default(add_1188); add_1188 = None + mul_1571 = torch.ops.aten.mul.Tensor(reciprocal_17, 1); reciprocal_17 = None + mul_1572 = torch.ops.aten.mul.Tensor(convert_element_type_2070, mul_1571); convert_element_type_2070 = None + sub_674 = torch.ops.aten.sub.Tensor(1, mul_1571); mul_1571 = None + mul_1573 = torch.ops.aten.mul.Tensor(convert_element_type_994, sub_674); convert_element_type_994 = sub_674 = None + add_1899 = torch.ops.aten.add.Tensor(mul_1573, 1); mul_1573 = None + mul_1574 = torch.ops.aten.mul.Tensor(mul_1572, add_1899); mul_1572 = add_1899 = None + convert_element_type_2072 = torch.ops.prims.convert_element_type.default(mul_1574, torch.bfloat16); mul_1574 = None + permute_828 = torch.ops.aten.permute.default(convert_element_type_2072, [1, 0]) + _grouped_mm_130 = torch.ops.aten._grouped_mm.default(permute_828, index_35, cumsum_53); permute_828 = index_35 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_305, torch.bfloat16); primals_305 = None + all_gather_into_tensor_308 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 8, '513'); convert_element_type_988 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_308); all_gather_into_tensor_308 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_379, [0, 2, 1]); wait_tensor_379 = None + permute_830 = torch.ops.aten.permute.default(permute_275, [0, 2, 1]); permute_275 = None + _grouped_mm_131 = torch.ops.aten._grouped_mm.default(convert_element_type_2072, permute_830, cumsum_53); convert_element_type_2072 = permute_830 = cumsum_53 = None + add_1900 = torch.ops.aten.add.Tensor(_grouped_mm_129, _grouped_mm_131); _grouped_mm_129 = _grouped_mm_131 = None + convert_element_type_2073 = torch.ops.prims.convert_element_type.default(_grouped_mm_128, torch.float32); _grouped_mm_128 = None + div_180 = torch.ops.aten.div.Tensor(convert_element_type_2073, 64); convert_element_type_2073 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_180, 'sum', 8, '513'); div_180 = None + wait_tensor_693 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + convert_element_type_2074 = torch.ops.prims.convert_element_type.default(_grouped_mm_126, torch.float32); _grouped_mm_126 = None + div_181 = torch.ops.aten.div.Tensor(convert_element_type_2074, 64); convert_element_type_2074 = None + reduce_scatter_tensor_118 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_181, 'sum', 8, '513'); div_181 = None + wait_tensor_694 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + convert_element_type_2075 = torch.ops.prims.convert_element_type.default(_grouped_mm_130, torch.float32); _grouped_mm_130 = None + div_182 = torch.ops.aten.div.Tensor(convert_element_type_2075, 64); convert_element_type_2075 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_182, 'sum', 8, '513'); div_182 = None + wait_tensor_695 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + index_put_68 = torch.ops.aten.index_put.default(full_380, [getitem_260], add_1900, True); full_380 = getitem_260 = add_1900 = None + slice_155 = torch.ops.aten.slice.Tensor(index_put_68, 0, 0, add_1901); index_put_68 = add_1901 = None + all_to_all_single_95 = torch.ops._c10d_functional.all_to_all_single.default(slice_155, [_local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279], [_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287], '521'); slice_155 = _local_scalar_dense_272 = _local_scalar_dense_273 = _local_scalar_dense_274 = _local_scalar_dense_275 = _local_scalar_dense_276 = _local_scalar_dense_277 = _local_scalar_dense_278 = _local_scalar_dense_279 = _local_scalar_dense_280 = _local_scalar_dense_281 = _local_scalar_dense_282 = _local_scalar_dense_283 = _local_scalar_dense_284 = _local_scalar_dense_285 = _local_scalar_dense_286 = _local_scalar_dense_287 = None + wait_tensor_696 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_95); all_to_all_single_95 = None + index_put_69 = torch.ops.aten.index_put.default(full_default_52, [div_87], wait_tensor_696, True); div_87 = wait_tensor_696 = None + add_1905 = torch.ops.aten.add.Tensor(add_1897, index_put_69); add_1897 = index_put_69 = None + mul_1575 = torch.ops.aten.mul.Tensor(view_1935, 1.0); view_1935 = None + scatter_add_8 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_257, mul_1575); getitem_257 = mul_1575 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(mm_147, torch.float32); mm_147 = None + sub_408 = torch.ops.aten.sub.Tensor(convert_element_type_983, amax_17); convert_element_type_983 = amax_17 = None + exp_52 = torch.ops.aten.exp.default(sub_408); sub_408 = None + div_86 = torch.ops.aten.div.Tensor(exp_52, sum_69); exp_52 = sum_69 = None + mul_1576 = torch.ops.aten.mul.Tensor(scatter_add_8, div_86); scatter_add_8 = None + sum_171 = torch.ops.aten.sum.dim_IntList(mul_1576, [1], True) + neg_79 = torch.ops.aten.neg.default(div_86); div_86 = None + fma_8 = torch.ops.prims.fma.default(neg_79, sum_171, mul_1576); neg_79 = sum_171 = mul_1576 = None + convert_element_type_2076 = torch.ops.prims.convert_element_type.default(fma_8, torch.bfloat16); fma_8 = None + permute_832 = torch.ops.aten.permute.default(convert_element_type_2076, [1, 0]) + mm_352 = torch.ops.aten.mm.default(permute_832, view_1197); permute_832 = view_1197 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_303, torch.bfloat16); primals_303 = None + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 64, '0'); convert_element_type_980 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_375, [1, 0]); wait_tensor_375 = None + permute_834 = torch.ops.aten.permute.default(permute_274, [1, 0]); permute_274 = None + mm_353 = torch.ops.aten.mm.default(convert_element_type_2076, permute_834); convert_element_type_2076 = permute_834 = None + add_1906 = torch.ops.aten.add.Tensor(add_1905, mm_353); add_1905 = mm_353 = None + convert_element_type_2081 = torch.ops.prims.convert_element_type.default(mm_352, torch.float32); mm_352 = None + reduce_scatter_tensor_120 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2081, 'avg', 64, '0'); convert_element_type_2081 = None + wait_tensor_697 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + view_1937 = torch.ops.aten.view.default(add_1906, [2, 4096, 2048]); add_1906 = None + convert_element_type_2082 = torch.ops.prims.convert_element_type.default(view_1937, torch.float32); view_1937 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_301, torch.bfloat16); primals_301 = None + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 64, '0'); convert_element_type_977 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + convert_element_type_2084 = torch.ops.prims.convert_element_type.default(wait_tensor_374, torch.float32); wait_tensor_374 = None + mul_1577 = torch.ops.aten.mul.Tensor(convert_element_type_2082, convert_element_type_2084); convert_element_type_2084 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_1164, torch.float32); add_1164 = None + mul_848 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_56); convert_element_type_978 = None + mul_1579 = torch.ops.aten.mul.Tensor(mul_848, mul_1577) + sum_172 = torch.ops.aten.sum.dim_IntList(mul_1579, [2], True); mul_1579 = None + div_183 = torch.ops.aten.div.Tensor(mul_848, 2048) + mul_1580 = torch.ops.aten.mul.Tensor(div_183, sum_172); div_183 = sum_172 = None + sub_676 = torch.ops.aten.sub.Tensor(mul_1577, mul_1580); mul_1577 = mul_1580 = None + mul_1581 = torch.ops.aten.mul.Tensor(sub_676, rsqrt_56); sub_676 = rsqrt_56 = None + mul_1582 = torch.ops.aten.mul.Tensor(convert_element_type_2082, mul_848); convert_element_type_2082 = mul_848 = None + sum_173 = torch.ops.aten.sum.dim_IntList(mul_1582, [0, 1]); mul_1582 = None + convert_element_type_2085 = torch.ops.prims.convert_element_type.default(mul_1581, torch.bfloat16); mul_1581 = None + add_1907 = torch.ops.aten.add.Tensor(add_1894, convert_element_type_2085); add_1894 = convert_element_type_2085 = None + convert_element_type_default_57 = torch.ops.prims.convert_element_type.default(sum_173, torch.float32); sum_173 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_57, 'avg', 64, '0'); convert_element_type_default_57 = None + wait_tensor_698 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + view_1938 = torch.ops.aten.view.default(add_1907, [8192, 2048]) + permute_836 = torch.ops.aten.permute.default(view_1938, [1, 0]) + permute_272 = torch.ops.aten.permute.default(getitem_253, [0, 2, 1, 3]) + view_1192 = torch.ops.aten.view.default(permute_272, [2, 4096, -1]); permute_272 = None + view_1194 = torch.ops.aten.view.default(view_1192, [8192, 2048]); view_1192 = None + mm_354 = torch.ops.aten.mm.default(permute_836, view_1194); permute_836 = view_1194 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_300, torch.bfloat16); primals_300 = None + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 64, '0'); convert_element_type_974 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_373, [1, 0]); wait_tensor_373 = None + permute_838 = torch.ops.aten.permute.default(permute_273, [1, 0]); permute_273 = None + mm_355 = torch.ops.aten.mm.default(view_1938, permute_838); view_1938 = permute_838 = None + view_1939 = torch.ops.aten.view.default(mm_355, [2, 4096, 2048]); mm_355 = None + convert_element_type_2092 = torch.ops.prims.convert_element_type.default(mm_354, torch.float32); mm_354 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2092, 'avg', 64, '0'); convert_element_type_2092 = None + wait_tensor_699 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + view_1940 = torch.ops.aten.view.default(view_1939, [2, 4096, 16, 128]); view_1939 = None + permute_840 = torch.ops.aten.permute.default(view_1940, [0, 2, 1, 3]); view_1940 = None + fw_graph8 = self.fw_graph8 + joint_graph8 = self.joint_graph8 + mask_graph8 = self.mask_graph8 + flex_attention_backward_8 = torch.ops.higher_order.flex_attention_backward(permute_269, permute_270, permute_271, getitem_253, getitem_254, permute_840, None, fw_graph8, joint_graph8, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph8), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_269 = permute_270 = permute_271 = getitem_253 = getitem_254 = permute_840 = fw_graph8 = joint_graph8 = mask_graph8 = None + getitem_405 = flex_attention_backward_8[0] + getitem_406 = flex_attention_backward_8[1] + getitem_407 = flex_attention_backward_8[2]; flex_attention_backward_8 = None + permute_841 = torch.ops.aten.permute.default(getitem_407, [0, 2, 1, 3]); getitem_407 = None + permute_842 = torch.ops.aten.permute.default(getitem_406, [0, 2, 1, 3]); getitem_406 = None + permute_843 = torch.ops.aten.permute.default(getitem_405, [0, 2, 1, 3]); getitem_405 = None + slice_157 = torch.ops.aten.slice.Tensor(permute_842, 3, 0, 128) + slice_158 = torch.ops.aten.slice.Tensor(permute_842, 3, 128, 192); permute_842 = None + sum_174 = torch.ops.aten.sum.dim_IntList(slice_158, [2], True); slice_158 = None + cat_104 = torch.ops.aten.cat.default([slice_157, permute_841], 3); slice_157 = permute_841 = None + view_1941 = torch.ops.aten.view.default(cat_104, [2, 4096, 4096]); cat_104 = None + view_1942 = torch.ops.aten.view.default(view_1941, [8192, 4096]); view_1941 = None + permute_844 = torch.ops.aten.permute.default(view_1942, [1, 0]) + mm_356 = torch.ops.aten.mm.default(permute_844, view_1189); permute_844 = view_1189 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(primals_299, torch.bfloat16); primals_299 = None + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_971, 64, '0'); convert_element_type_971 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + permute_268 = torch.ops.aten.permute.default(wait_tensor_372, [1, 0]); wait_tensor_372 = None + permute_846 = torch.ops.aten.permute.default(permute_268, [1, 0]); permute_268 = None + mm_357 = torch.ops.aten.mm.default(view_1942, permute_846); view_1942 = permute_846 = None + view_1943 = torch.ops.aten.view.default(mm_357, [2, 4096, 512]); mm_357 = None + convert_element_type_2097 = torch.ops.prims.convert_element_type.default(mm_356, torch.float32); mm_356 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2097, 'avg', 64, '0'); convert_element_type_2097 = None + wait_tensor_700 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + convert_element_type_2098 = torch.ops.prims.convert_element_type.default(view_1943, torch.float32); view_1943 = None + convert_element_type_968 = torch.ops.prims.convert_element_type.default(primals_298, torch.bfloat16); primals_298 = None + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_968, 64, '0'); convert_element_type_968 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + convert_element_type_2100 = torch.ops.prims.convert_element_type.default(wait_tensor_371, torch.float32); wait_tensor_371 = None + mul_1583 = torch.ops.aten.mul.Tensor(convert_element_type_2098, convert_element_type_2100); convert_element_type_2100 = None + convert_element_type_969 = torch.ops.prims.convert_element_type.default(getitem_249, torch.float32); getitem_249 = None + mul_846 = torch.ops.aten.mul.Tensor(convert_element_type_969, rsqrt_55); convert_element_type_969 = None + mul_1585 = torch.ops.aten.mul.Tensor(mul_846, mul_1583) + sum_175 = torch.ops.aten.sum.dim_IntList(mul_1585, [2], True); mul_1585 = None + div_184 = torch.ops.aten.div.Tensor(mul_846, 512) + mul_1586 = torch.ops.aten.mul.Tensor(div_184, sum_175); div_184 = sum_175 = None + sub_677 = torch.ops.aten.sub.Tensor(mul_1583, mul_1586); mul_1583 = mul_1586 = None + mul_1587 = torch.ops.aten.mul.Tensor(sub_677, rsqrt_55); sub_677 = rsqrt_55 = None + mul_1588 = torch.ops.aten.mul.Tensor(convert_element_type_2098, mul_846); convert_element_type_2098 = mul_846 = None + sum_176 = torch.ops.aten.sum.dim_IntList(mul_1588, [0, 1]); mul_1588 = None + convert_element_type_2101 = torch.ops.prims.convert_element_type.default(mul_1587, torch.bfloat16); mul_1587 = None + convert_element_type_default_56 = torch.ops.prims.convert_element_type.default(sum_176, torch.float32); sum_176 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_56, 'avg', 64, '0'); convert_element_type_default_56 = None + wait_tensor_701 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + convert_element_type_2104 = torch.ops.prims.convert_element_type.default(sum_174, torch.float32); sum_174 = None + view_1944 = torch.ops.aten.view.default(convert_element_type_2104, [2, 4096, 1, 32, 2]); convert_element_type_2104 = None + view_as_complex_70 = torch.ops.aten.view_as_complex.default(view_1944); view_1944 = None + mul_1589 = torch.ops.aten.mul.Tensor(view_as_complex_70, clone_9); view_as_complex_70 = None + view_as_real_70 = torch.ops.aten.view_as_real.default(mul_1589); mul_1589 = None + view_1945 = torch.ops.aten.view.default(view_as_real_70, [2, 4096, 1, 64]); view_as_real_70 = None + convert_element_type_2105 = torch.ops.prims.convert_element_type.default(view_1945, torch.bfloat16); view_1945 = None + squeeze_34 = torch.ops.aten.squeeze.dim(convert_element_type_2105, 2); convert_element_type_2105 = None + cat_105 = torch.ops.aten.cat.default([convert_element_type_2101, squeeze_34], 2); convert_element_type_2101 = squeeze_34 = None + view_1946 = torch.ops.aten.view.default(cat_105, [8192, 576]); cat_105 = None + permute_848 = torch.ops.aten.permute.default(view_1946, [1, 0]) + mm_358 = torch.ops.aten.mm.default(permute_848, view_1175); permute_848 = None + convert_element_type_963 = torch.ops.prims.convert_element_type.default(primals_297, torch.bfloat16); primals_297 = None + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_963, 64, '0'); convert_element_type_963 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + permute_267 = torch.ops.aten.permute.default(wait_tensor_370, [1, 0]); wait_tensor_370 = None + permute_850 = torch.ops.aten.permute.default(permute_267, [1, 0]); permute_267 = None + mm_359 = torch.ops.aten.mm.default(view_1946, permute_850); view_1946 = permute_850 = None + view_1947 = torch.ops.aten.view.default(mm_359, [2, 4096, 2048]); mm_359 = None + convert_element_type_2110 = torch.ops.prims.convert_element_type.default(mm_358, torch.float32); mm_358 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2110, 'avg', 64, '0'); convert_element_type_2110 = None + wait_tensor_702 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + slice_159 = torch.ops.aten.slice.Tensor(permute_843, 3, 0, 128) + slice_160 = torch.ops.aten.slice.Tensor(permute_843, 3, 128, 192); permute_843 = None + convert_element_type_2111 = torch.ops.prims.convert_element_type.default(slice_160, torch.float32); slice_160 = None + view_1948 = torch.ops.aten.view.default(convert_element_type_2111, [2, 4096, 16, 32, 2]); convert_element_type_2111 = None + view_as_complex_71 = torch.ops.aten.view_as_complex.default(view_1948); view_1948 = None + mul_1590 = torch.ops.aten.mul.Tensor(view_as_complex_71, clone_9); view_as_complex_71 = None + view_as_real_71 = torch.ops.aten.view_as_real.default(mul_1590); mul_1590 = None + view_1949 = torch.ops.aten.view.default(view_as_real_71, [2, 4096, 16, 64]); view_as_real_71 = None + convert_element_type_2112 = torch.ops.prims.convert_element_type.default(view_1949, torch.bfloat16); view_1949 = None + cat_106 = torch.ops.aten.cat.default([slice_159, convert_element_type_2112], 3); slice_159 = convert_element_type_2112 = None + view_1950 = torch.ops.aten.view.default(cat_106, [2, 4096, 3072]); cat_106 = None + view_1951 = torch.ops.aten.view.default(view_1950, [8192, 3072]); view_1950 = None + permute_852 = torch.ops.aten.permute.default(view_1951, [1, 0]) + mm_360 = torch.ops.aten.mm.default(permute_852, view_1175); permute_852 = view_1175 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_296, torch.bfloat16); primals_296 = None + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 64, '0'); convert_element_type_958 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_369, [1, 0]); wait_tensor_369 = None + permute_854 = torch.ops.aten.permute.default(permute_266, [1, 0]); permute_266 = None + mm_361 = torch.ops.aten.mm.default(view_1951, permute_854); view_1951 = permute_854 = None + view_1952 = torch.ops.aten.view.default(mm_361, [2, 4096, 2048]); mm_361 = None + add_1908 = torch.ops.aten.add.Tensor(view_1947, view_1952); view_1947 = view_1952 = None + convert_element_type_2117 = torch.ops.prims.convert_element_type.default(mm_360, torch.float32); mm_360 = None + reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2117, 'avg', 64, '0'); convert_element_type_2117 = None + wait_tensor_703 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + convert_element_type_2118 = torch.ops.prims.convert_element_type.default(add_1908, torch.float32); add_1908 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_295, torch.bfloat16); primals_295 = None + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 64, '0'); convert_element_type_955 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + convert_element_type_2120 = torch.ops.prims.convert_element_type.default(wait_tensor_368, torch.float32); wait_tensor_368 = None + mul_1591 = torch.ops.aten.mul.Tensor(convert_element_type_2118, convert_element_type_2120); convert_element_type_2120 = None + convert_element_type_956 = torch.ops.prims.convert_element_type.default(add_1161, torch.float32); add_1161 = None + mul_842 = torch.ops.aten.mul.Tensor(convert_element_type_956, rsqrt_54); convert_element_type_956 = None + mul_1593 = torch.ops.aten.mul.Tensor(mul_842, mul_1591) + sum_177 = torch.ops.aten.sum.dim_IntList(mul_1593, [2], True); mul_1593 = None + div_185 = torch.ops.aten.div.Tensor(mul_842, 2048) + mul_1594 = torch.ops.aten.mul.Tensor(div_185, sum_177); div_185 = sum_177 = None + sub_678 = torch.ops.aten.sub.Tensor(mul_1591, mul_1594); mul_1591 = mul_1594 = None + mul_1595 = torch.ops.aten.mul.Tensor(sub_678, rsqrt_54); sub_678 = rsqrt_54 = None + mul_1596 = torch.ops.aten.mul.Tensor(convert_element_type_2118, mul_842); convert_element_type_2118 = mul_842 = None + sum_178 = torch.ops.aten.sum.dim_IntList(mul_1596, [0, 1]); mul_1596 = None + convert_element_type_2121 = torch.ops.prims.convert_element_type.default(mul_1595, torch.bfloat16); mul_1595 = None + add_1909 = torch.ops.aten.add.Tensor(add_1907, convert_element_type_2121); add_1907 = convert_element_type_2121 = None + convert_element_type_default_55 = torch.ops.prims.convert_element_type.default(sum_178, torch.float32); sum_178 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_55, 'avg', 64, '0'); convert_element_type_default_55 = None + wait_tensor_704 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); reduce_scatter_tensor_127 = None + view_1953 = torch.ops.aten.view.default(add_1909, [8192, 2048]) + unsqueeze_62 = torch.ops.aten.unsqueeze.default(view_1953, 1) + convert_element_type_2124 = torch.ops.prims.convert_element_type.default(unsqueeze_62, torch.float32); unsqueeze_62 = None + bmm_44 = torch.ops.aten.bmm.default(permute_856, convert_element_type_2124); permute_856 = None + bmm_45 = torch.ops.aten.bmm.default(convert_element_type_2124, permute_857); convert_element_type_2124 = permute_857 = None + convert_element_type_2125 = torch.ops.prims.convert_element_type.default(bmm_44, torch.bfloat16); bmm_44 = None + view_1954 = torch.ops.aten.view.default(bmm_45, [8192, 6]); bmm_45 = None + view_1955 = torch.ops.aten.view.default(convert_element_type_2125, [49152, 2048]); convert_element_type_2125 = None + index_70 = torch.ops.aten.index.Tensor(view_1955, [getitem_245]); view_1955 = getitem_245 = None + permute_858 = torch.ops.aten.permute.default(view_1953, [1, 0]) + mm_362 = torch.ops.aten.mm.default(permute_858, mul_839); permute_858 = mul_839 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(primals_294, torch.bfloat16); primals_294 = None + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_950, 64, '0'); convert_element_type_950 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_367, [1, 0]); wait_tensor_367 = None + permute_860 = torch.ops.aten.permute.default(permute_265, [1, 0]); permute_265 = None + mm_363 = torch.ops.aten.mm.default(view_1953, permute_860); view_1953 = permute_860 = None + convert_element_type_2130 = torch.ops.prims.convert_element_type.default(mm_362, torch.float32); mm_362 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2130, 'avg', 64, '0'); convert_element_type_2130 = None + wait_tensor_705 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(mm_140, torch.float32); mm_140 = None + neg_34 = torch.ops.aten.neg.default(convert_element_type_945) + exp_51 = torch.ops.aten.exp.default(neg_34); neg_34 = None + add_1156 = torch.ops.aten.add.Tensor(exp_51, 1); exp_51 = None + div_85 = torch.ops.aten.div.Tensor(convert_element_type_945, add_1156) + convert_element_type_946 = torch.ops.prims.convert_element_type.default(div_85, torch.bfloat16); div_85 = None + mul_1597 = torch.ops.aten.mul.Tensor(mm_363, convert_element_type_946); convert_element_type_946 = None + mul_1598 = torch.ops.aten.mul.Tensor(mm_363, mm_141); mm_363 = mm_141 = None + permute_862 = torch.ops.aten.permute.default(mul_1597, [1, 0]) + mm_364 = torch.ops.aten.mm.default(permute_862, view_1130); permute_862 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 64, '0'); convert_element_type_947 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_366, [1, 0]); wait_tensor_366 = None + permute_864 = torch.ops.aten.permute.default(permute_264, [1, 0]); permute_264 = None + mm_365 = torch.ops.aten.mm.default(mul_1597, permute_864); mul_1597 = permute_864 = None + convert_element_type_2135 = torch.ops.prims.convert_element_type.default(mm_364, torch.float32); mm_364 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2135, 'avg', 64, '0'); convert_element_type_2135 = None + wait_tensor_706 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + convert_element_type_2136 = torch.ops.prims.convert_element_type.default(mul_1598, torch.float32); mul_1598 = None + reciprocal_18 = torch.ops.aten.reciprocal.default(add_1156); add_1156 = None + mul_1599 = torch.ops.aten.mul.Tensor(reciprocal_18, 1); reciprocal_18 = None + mul_1600 = torch.ops.aten.mul.Tensor(convert_element_type_2136, mul_1599); convert_element_type_2136 = None + sub_679 = torch.ops.aten.sub.Tensor(1, mul_1599); mul_1599 = None + mul_1601 = torch.ops.aten.mul.Tensor(convert_element_type_945, sub_679); convert_element_type_945 = sub_679 = None + add_1911 = torch.ops.aten.add.Tensor(mul_1601, 1); mul_1601 = None + mul_1602 = torch.ops.aten.mul.Tensor(mul_1600, add_1911); mul_1600 = add_1911 = None + convert_element_type_2138 = torch.ops.prims.convert_element_type.default(mul_1602, torch.bfloat16); mul_1602 = None + permute_866 = torch.ops.aten.permute.default(convert_element_type_2138, [1, 0]) + mm_366 = torch.ops.aten.mm.default(permute_866, view_1130); permute_866 = None + convert_element_type_942 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_942, 64, '0'); convert_element_type_942 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_365, [1, 0]); wait_tensor_365 = None + permute_868 = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None + mm_367 = torch.ops.aten.mm.default(convert_element_type_2138, permute_868); convert_element_type_2138 = permute_868 = None + add_1912 = torch.ops.aten.add.Tensor(mm_365, mm_367); mm_365 = mm_367 = None + convert_element_type_2143 = torch.ops.prims.convert_element_type.default(mm_366, torch.float32); mm_366 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2143, 'avg', 64, '0'); convert_element_type_2143 = None + wait_tensor_707 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + all_to_all_single_96 = torch.ops._c10d_functional.all_to_all_single.default(index_70, [_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271], [_local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263], '521'); index_70 = None + wait_tensor_708 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_96); all_to_all_single_96 = None + full_384 = torch.ops.aten.full.default([sym_size_int_65, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_65 = None + slice_scatter_9 = torch.ops.aten.slice_scatter.default(full_384, wait_tensor_708, 0, 0, -1); wait_tensor_708 = None + index_71 = torch.ops.aten.index.Tensor(slice_scatter_9, [getitem_246]); slice_scatter_9 = None + permute_870 = torch.ops.aten.permute.default(index_71, [1, 0]) + _grouped_mm_132 = torch.ops.aten._grouped_mm.default(permute_870, mul_819, cumsum_50); permute_870 = mul_819 = None + convert_element_type_936 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16); primals_290 = None + all_gather_into_tensor_293 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_936, 8, '513'); convert_element_type_936 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_293); all_gather_into_tensor_293 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_360, [0, 2, 1]); wait_tensor_360 = None + permute_872 = torch.ops.aten.permute.default(permute_262, [0, 2, 1]); permute_262 = None + _grouped_mm_133 = torch.ops.aten._grouped_mm.default(index_71, permute_872, cumsum_50); index_71 = permute_872 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(_grouped_mm_48, torch.float32); _grouped_mm_48 = None + neg_33 = torch.ops.aten.neg.default(convert_element_type_940) + exp_50 = torch.ops.aten.exp.default(neg_33); neg_33 = None + add_1120 = torch.ops.aten.add.Tensor(exp_50, 1); exp_50 = None + div_84 = torch.ops.aten.div.Tensor(convert_element_type_940, add_1120) + convert_element_type_941 = torch.ops.prims.convert_element_type.default(div_84, torch.bfloat16); div_84 = None + mul_1603 = torch.ops.aten.mul.Tensor(_grouped_mm_133, convert_element_type_941); convert_element_type_941 = None + mul_1604 = torch.ops.aten.mul.Tensor(_grouped_mm_133, _grouped_mm_49); _grouped_mm_133 = _grouped_mm_49 = None + permute_874 = torch.ops.aten.permute.default(mul_1603, [1, 0]) + _grouped_mm_134 = torch.ops.aten._grouped_mm.default(permute_874, index_33, cumsum_50); permute_874 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16); primals_291 = None + all_gather_into_tensor_294 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_937, 8, '513'); convert_element_type_937 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_294); all_gather_into_tensor_294 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_361, [0, 2, 1]); wait_tensor_361 = None + permute_876 = torch.ops.aten.permute.default(permute_261, [0, 2, 1]); permute_261 = None + _grouped_mm_135 = torch.ops.aten._grouped_mm.default(mul_1603, permute_876, cumsum_50); mul_1603 = permute_876 = None + convert_element_type_2144 = torch.ops.prims.convert_element_type.default(mul_1604, torch.float32); mul_1604 = None + reciprocal_19 = torch.ops.aten.reciprocal.default(add_1120); add_1120 = None + mul_1605 = torch.ops.aten.mul.Tensor(reciprocal_19, 1); reciprocal_19 = None + mul_1606 = torch.ops.aten.mul.Tensor(convert_element_type_2144, mul_1605); convert_element_type_2144 = None + sub_680 = torch.ops.aten.sub.Tensor(1, mul_1605); mul_1605 = None + mul_1607 = torch.ops.aten.mul.Tensor(convert_element_type_940, sub_680); convert_element_type_940 = sub_680 = None + add_1914 = torch.ops.aten.add.Tensor(mul_1607, 1); mul_1607 = None + mul_1608 = torch.ops.aten.mul.Tensor(mul_1606, add_1914); mul_1606 = add_1914 = None + convert_element_type_2146 = torch.ops.prims.convert_element_type.default(mul_1608, torch.bfloat16); mul_1608 = None + permute_878 = torch.ops.aten.permute.default(convert_element_type_2146, [1, 0]) + _grouped_mm_136 = torch.ops.aten._grouped_mm.default(permute_878, index_33, cumsum_50); permute_878 = index_33 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16); primals_289 = None + all_gather_into_tensor_291 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 8, '513'); convert_element_type_934 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_291); all_gather_into_tensor_291 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_358, [0, 2, 1]); wait_tensor_358 = None + permute_880 = torch.ops.aten.permute.default(permute_260, [0, 2, 1]); permute_260 = None + _grouped_mm_137 = torch.ops.aten._grouped_mm.default(convert_element_type_2146, permute_880, cumsum_50); convert_element_type_2146 = permute_880 = cumsum_50 = None + add_1915 = torch.ops.aten.add.Tensor(_grouped_mm_135, _grouped_mm_137); _grouped_mm_135 = _grouped_mm_137 = None + convert_element_type_2147 = torch.ops.prims.convert_element_type.default(_grouped_mm_134, torch.float32); _grouped_mm_134 = None + div_186 = torch.ops.aten.div.Tensor(convert_element_type_2147, 64); convert_element_type_2147 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_186, 'sum', 8, '513'); div_186 = None + wait_tensor_709 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + convert_element_type_2148 = torch.ops.prims.convert_element_type.default(_grouped_mm_132, torch.float32); _grouped_mm_132 = None + div_187 = torch.ops.aten.div.Tensor(convert_element_type_2148, 64); convert_element_type_2148 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_187, 'sum', 8, '513'); div_187 = None + wait_tensor_710 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + convert_element_type_2149 = torch.ops.prims.convert_element_type.default(_grouped_mm_136, torch.float32); _grouped_mm_136 = None + div_188 = torch.ops.aten.div.Tensor(convert_element_type_2149, 64); convert_element_type_2149 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_188, 'sum', 8, '513'); div_188 = None + wait_tensor_711 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + index_put_70 = torch.ops.aten.index_put.default(full_384, [getitem_246], add_1915, True); full_384 = getitem_246 = add_1915 = None + slice_161 = torch.ops.aten.slice.Tensor(index_put_70, 0, 0, add_1916); index_put_70 = add_1916 = None + all_to_all_single_97 = torch.ops._c10d_functional.all_to_all_single.default(slice_161, [_local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263], [_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271], '521'); slice_161 = _local_scalar_dense_256 = _local_scalar_dense_257 = _local_scalar_dense_258 = _local_scalar_dense_259 = _local_scalar_dense_260 = _local_scalar_dense_261 = _local_scalar_dense_262 = _local_scalar_dense_263 = _local_scalar_dense_264 = _local_scalar_dense_265 = _local_scalar_dense_266 = _local_scalar_dense_267 = _local_scalar_dense_268 = _local_scalar_dense_269 = _local_scalar_dense_270 = _local_scalar_dense_271 = None + wait_tensor_712 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_97); all_to_all_single_97 = None + index_put_71 = torch.ops.aten.index_put.default(full_default_52, [div_82], wait_tensor_712, True); div_82 = wait_tensor_712 = None + add_1920 = torch.ops.aten.add.Tensor(add_1912, index_put_71); add_1912 = index_put_71 = None + mul_1609 = torch.ops.aten.mul.Tensor(view_1954, 1.0); view_1954 = None + scatter_add_9 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_243, mul_1609); getitem_243 = mul_1609 = None + convert_element_type_929 = torch.ops.prims.convert_element_type.default(mm_139, torch.float32); mm_139 = None + sub_384 = torch.ops.aten.sub.Tensor(convert_element_type_929, amax_16); convert_element_type_929 = amax_16 = None + exp_49 = torch.ops.aten.exp.default(sub_384); sub_384 = None + div_81 = torch.ops.aten.div.Tensor(exp_49, sum_65); exp_49 = sum_65 = None + mul_1610 = torch.ops.aten.mul.Tensor(scatter_add_9, div_81); scatter_add_9 = None + sum_179 = torch.ops.aten.sum.dim_IntList(mul_1610, [1], True) + neg_82 = torch.ops.aten.neg.default(div_81); div_81 = None + fma_9 = torch.ops.prims.fma.default(neg_82, sum_179, mul_1610); neg_82 = sum_179 = mul_1610 = None + convert_element_type_2150 = torch.ops.prims.convert_element_type.default(fma_9, torch.bfloat16); fma_9 = None + permute_882 = torch.ops.aten.permute.default(convert_element_type_2150, [1, 0]) + mm_368 = torch.ops.aten.mm.default(permute_882, view_1130); permute_882 = view_1130 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_926, 64, '0'); convert_element_type_926 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_259 = torch.ops.aten.permute.default(wait_tensor_354, [1, 0]); wait_tensor_354 = None + permute_884 = torch.ops.aten.permute.default(permute_259, [1, 0]); permute_259 = None + mm_369 = torch.ops.aten.mm.default(convert_element_type_2150, permute_884); convert_element_type_2150 = permute_884 = None + add_1921 = torch.ops.aten.add.Tensor(add_1920, mm_369); add_1920 = mm_369 = None + convert_element_type_2155 = torch.ops.prims.convert_element_type.default(mm_368, torch.float32); mm_368 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2155, 'avg', 64, '0'); convert_element_type_2155 = None + wait_tensor_713 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + view_1956 = torch.ops.aten.view.default(add_1921, [2, 4096, 2048]); add_1921 = None + convert_element_type_2156 = torch.ops.prims.convert_element_type.default(view_1956, torch.float32); view_1956 = None + convert_element_type_923 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_923, 64, '0'); convert_element_type_923 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_2158 = torch.ops.prims.convert_element_type.default(wait_tensor_353, torch.float32); wait_tensor_353 = None + mul_1611 = torch.ops.aten.mul.Tensor(convert_element_type_2156, convert_element_type_2158); convert_element_type_2158 = None + convert_element_type_924 = torch.ops.prims.convert_element_type.default(add_1096, torch.float32); add_1096 = None + mul_799 = torch.ops.aten.mul.Tensor(convert_element_type_924, rsqrt_53); convert_element_type_924 = None + mul_1613 = torch.ops.aten.mul.Tensor(mul_799, mul_1611) + sum_180 = torch.ops.aten.sum.dim_IntList(mul_1613, [2], True); mul_1613 = None + div_189 = torch.ops.aten.div.Tensor(mul_799, 2048) + mul_1614 = torch.ops.aten.mul.Tensor(div_189, sum_180); div_189 = sum_180 = None + sub_682 = torch.ops.aten.sub.Tensor(mul_1611, mul_1614); mul_1611 = mul_1614 = None + mul_1615 = torch.ops.aten.mul.Tensor(sub_682, rsqrt_53); sub_682 = rsqrt_53 = None + mul_1616 = torch.ops.aten.mul.Tensor(convert_element_type_2156, mul_799); convert_element_type_2156 = mul_799 = None + sum_181 = torch.ops.aten.sum.dim_IntList(mul_1616, [0, 1]); mul_1616 = None + convert_element_type_2159 = torch.ops.prims.convert_element_type.default(mul_1615, torch.bfloat16); mul_1615 = None + add_1922 = torch.ops.aten.add.Tensor(add_1909, convert_element_type_2159); add_1909 = convert_element_type_2159 = None + convert_element_type_default_54 = torch.ops.prims.convert_element_type.default(sum_181, torch.float32); sum_181 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_54, 'avg', 64, '0'); convert_element_type_default_54 = None + wait_tensor_714 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + view_1957 = torch.ops.aten.view.default(add_1922, [8192, 2048]) + permute_886 = torch.ops.aten.permute.default(view_1957, [1, 0]) + permute_257 = torch.ops.aten.permute.default(getitem_239, [0, 2, 1, 3]) + view_1125 = torch.ops.aten.view.default(permute_257, [2, 4096, -1]); permute_257 = None + view_1127 = torch.ops.aten.view.default(view_1125, [8192, 2048]); view_1125 = None + mm_370 = torch.ops.aten.mm.default(permute_886, view_1127); permute_886 = view_1127 = None + convert_element_type_920 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_920, 64, '0'); convert_element_type_920 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_258 = torch.ops.aten.permute.default(wait_tensor_352, [1, 0]); wait_tensor_352 = None + permute_888 = torch.ops.aten.permute.default(permute_258, [1, 0]); permute_258 = None + mm_371 = torch.ops.aten.mm.default(view_1957, permute_888); view_1957 = permute_888 = None + view_1958 = torch.ops.aten.view.default(mm_371, [2, 4096, 2048]); mm_371 = None + convert_element_type_2166 = torch.ops.prims.convert_element_type.default(mm_370, torch.float32); mm_370 = None + reduce_scatter_tensor_136 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2166, 'avg', 64, '0'); convert_element_type_2166 = None + wait_tensor_715 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + view_1959 = torch.ops.aten.view.default(view_1958, [2, 4096, 16, 128]); view_1958 = None + permute_890 = torch.ops.aten.permute.default(view_1959, [0, 2, 1, 3]); view_1959 = None + fw_graph9 = self.fw_graph9 + joint_graph9 = self.joint_graph9 + mask_graph9 = self.mask_graph9 + flex_attention_backward_9 = torch.ops.higher_order.flex_attention_backward(permute_254, permute_255, permute_256, getitem_239, getitem_240, permute_890, None, fw_graph9, joint_graph9, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph9), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_254 = permute_255 = permute_256 = getitem_239 = getitem_240 = permute_890 = fw_graph9 = joint_graph9 = mask_graph9 = None + getitem_409 = flex_attention_backward_9[0] + getitem_410 = flex_attention_backward_9[1] + getitem_411 = flex_attention_backward_9[2]; flex_attention_backward_9 = None + permute_891 = torch.ops.aten.permute.default(getitem_411, [0, 2, 1, 3]); getitem_411 = None + permute_892 = torch.ops.aten.permute.default(getitem_410, [0, 2, 1, 3]); getitem_410 = None + permute_893 = torch.ops.aten.permute.default(getitem_409, [0, 2, 1, 3]); getitem_409 = None + slice_163 = torch.ops.aten.slice.Tensor(permute_892, 3, 0, 128) + slice_164 = torch.ops.aten.slice.Tensor(permute_892, 3, 128, 192); permute_892 = None + sum_182 = torch.ops.aten.sum.dim_IntList(slice_164, [2], True); slice_164 = None + cat_107 = torch.ops.aten.cat.default([slice_163, permute_891], 3); slice_163 = permute_891 = None + view_1960 = torch.ops.aten.view.default(cat_107, [2, 4096, 4096]); cat_107 = None + view_1961 = torch.ops.aten.view.default(view_1960, [8192, 4096]); view_1960 = None + permute_894 = torch.ops.aten.permute.default(view_1961, [1, 0]) + mm_372 = torch.ops.aten.mm.default(permute_894, view_1122); permute_894 = view_1122 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_917, 64, '0'); convert_element_type_917 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + permute_896 = torch.ops.aten.permute.default(permute_253, [1, 0]); permute_253 = None + mm_373 = torch.ops.aten.mm.default(view_1961, permute_896); view_1961 = permute_896 = None + view_1962 = torch.ops.aten.view.default(mm_373, [2, 4096, 512]); mm_373 = None + convert_element_type_2171 = torch.ops.prims.convert_element_type.default(mm_372, torch.float32); mm_372 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2171, 'avg', 64, '0'); convert_element_type_2171 = None + wait_tensor_716 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + convert_element_type_2172 = torch.ops.prims.convert_element_type.default(view_1962, torch.float32); view_1962 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 64, '0'); convert_element_type_914 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + convert_element_type_2174 = torch.ops.prims.convert_element_type.default(wait_tensor_350, torch.float32); wait_tensor_350 = None + mul_1617 = torch.ops.aten.mul.Tensor(convert_element_type_2172, convert_element_type_2174); convert_element_type_2174 = None + convert_element_type_915 = torch.ops.prims.convert_element_type.default(getitem_235, torch.float32); getitem_235 = None + mul_797 = torch.ops.aten.mul.Tensor(convert_element_type_915, rsqrt_52); convert_element_type_915 = None + mul_1619 = torch.ops.aten.mul.Tensor(mul_797, mul_1617) + sum_183 = torch.ops.aten.sum.dim_IntList(mul_1619, [2], True); mul_1619 = None + div_190 = torch.ops.aten.div.Tensor(mul_797, 512) + mul_1620 = torch.ops.aten.mul.Tensor(div_190, sum_183); div_190 = sum_183 = None + sub_683 = torch.ops.aten.sub.Tensor(mul_1617, mul_1620); mul_1617 = mul_1620 = None + mul_1621 = torch.ops.aten.mul.Tensor(sub_683, rsqrt_52); sub_683 = rsqrt_52 = None + mul_1622 = torch.ops.aten.mul.Tensor(convert_element_type_2172, mul_797); convert_element_type_2172 = mul_797 = None + sum_184 = torch.ops.aten.sum.dim_IntList(mul_1622, [0, 1]); mul_1622 = None + convert_element_type_2175 = torch.ops.prims.convert_element_type.default(mul_1621, torch.bfloat16); mul_1621 = None + convert_element_type_default_53 = torch.ops.prims.convert_element_type.default(sum_184, torch.float32); sum_184 = None + reduce_scatter_tensor_138 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_53, 'avg', 64, '0'); convert_element_type_default_53 = None + wait_tensor_717 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + convert_element_type_2178 = torch.ops.prims.convert_element_type.default(sum_182, torch.float32); sum_182 = None + view_1963 = torch.ops.aten.view.default(convert_element_type_2178, [2, 4096, 1, 32, 2]); convert_element_type_2178 = None + view_as_complex_72 = torch.ops.aten.view_as_complex.default(view_1963); view_1963 = None + mul_1623 = torch.ops.aten.mul.Tensor(view_as_complex_72, clone_9); view_as_complex_72 = None + view_as_real_72 = torch.ops.aten.view_as_real.default(mul_1623); mul_1623 = None + view_1964 = torch.ops.aten.view.default(view_as_real_72, [2, 4096, 1, 64]); view_as_real_72 = None + convert_element_type_2179 = torch.ops.prims.convert_element_type.default(view_1964, torch.bfloat16); view_1964 = None + squeeze_35 = torch.ops.aten.squeeze.dim(convert_element_type_2179, 2); convert_element_type_2179 = None + cat_108 = torch.ops.aten.cat.default([convert_element_type_2175, squeeze_35], 2); convert_element_type_2175 = squeeze_35 = None + view_1965 = torch.ops.aten.view.default(cat_108, [8192, 576]); cat_108 = None + permute_898 = torch.ops.aten.permute.default(view_1965, [1, 0]) + mm_374 = torch.ops.aten.mm.default(permute_898, view_1108); permute_898 = None + convert_element_type_909 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_909, 64, '0'); convert_element_type_909 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_349, [1, 0]); wait_tensor_349 = None + permute_900 = torch.ops.aten.permute.default(permute_252, [1, 0]); permute_252 = None + mm_375 = torch.ops.aten.mm.default(view_1965, permute_900); view_1965 = permute_900 = None + view_1966 = torch.ops.aten.view.default(mm_375, [2, 4096, 2048]); mm_375 = None + convert_element_type_2184 = torch.ops.prims.convert_element_type.default(mm_374, torch.float32); mm_374 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2184, 'avg', 64, '0'); convert_element_type_2184 = None + wait_tensor_718 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + slice_165 = torch.ops.aten.slice.Tensor(permute_893, 3, 0, 128) + slice_166 = torch.ops.aten.slice.Tensor(permute_893, 3, 128, 192); permute_893 = None + convert_element_type_2185 = torch.ops.prims.convert_element_type.default(slice_166, torch.float32); slice_166 = None + view_1967 = torch.ops.aten.view.default(convert_element_type_2185, [2, 4096, 16, 32, 2]); convert_element_type_2185 = None + view_as_complex_73 = torch.ops.aten.view_as_complex.default(view_1967); view_1967 = None + mul_1624 = torch.ops.aten.mul.Tensor(view_as_complex_73, clone_9); view_as_complex_73 = None + view_as_real_73 = torch.ops.aten.view_as_real.default(mul_1624); mul_1624 = None + view_1968 = torch.ops.aten.view.default(view_as_real_73, [2, 4096, 16, 64]); view_as_real_73 = None + convert_element_type_2186 = torch.ops.prims.convert_element_type.default(view_1968, torch.bfloat16); view_1968 = None + cat_109 = torch.ops.aten.cat.default([slice_165, convert_element_type_2186], 3); slice_165 = convert_element_type_2186 = None + view_1969 = torch.ops.aten.view.default(cat_109, [2, 4096, 3072]); cat_109 = None + view_1970 = torch.ops.aten.view.default(view_1969, [8192, 3072]); view_1969 = None + permute_902 = torch.ops.aten.permute.default(view_1970, [1, 0]) + mm_376 = torch.ops.aten.mm.default(permute_902, view_1108); permute_902 = view_1108 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_904, 64, '0'); convert_element_type_904 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_348, [1, 0]); wait_tensor_348 = None + permute_904 = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None + mm_377 = torch.ops.aten.mm.default(view_1970, permute_904); view_1970 = permute_904 = None + view_1971 = torch.ops.aten.view.default(mm_377, [2, 4096, 2048]); mm_377 = None + add_1923 = torch.ops.aten.add.Tensor(view_1966, view_1971); view_1966 = view_1971 = None + convert_element_type_2191 = torch.ops.prims.convert_element_type.default(mm_376, torch.float32); mm_376 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2191, 'avg', 64, '0'); convert_element_type_2191 = None + wait_tensor_719 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + convert_element_type_2192 = torch.ops.prims.convert_element_type.default(add_1923, torch.float32); add_1923 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16); primals_279 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 64, '0'); convert_element_type_901 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + convert_element_type_2194 = torch.ops.prims.convert_element_type.default(wait_tensor_347, torch.float32); wait_tensor_347 = None + mul_1625 = torch.ops.aten.mul.Tensor(convert_element_type_2192, convert_element_type_2194); convert_element_type_2194 = None + convert_element_type_902 = torch.ops.prims.convert_element_type.default(add_1093, torch.float32); add_1093 = None + mul_793 = torch.ops.aten.mul.Tensor(convert_element_type_902, rsqrt_51); convert_element_type_902 = None + mul_1627 = torch.ops.aten.mul.Tensor(mul_793, mul_1625) + sum_185 = torch.ops.aten.sum.dim_IntList(mul_1627, [2], True); mul_1627 = None + div_191 = torch.ops.aten.div.Tensor(mul_793, 2048) + mul_1628 = torch.ops.aten.mul.Tensor(div_191, sum_185); div_191 = sum_185 = None + sub_684 = torch.ops.aten.sub.Tensor(mul_1625, mul_1628); mul_1625 = mul_1628 = None + mul_1629 = torch.ops.aten.mul.Tensor(sub_684, rsqrt_51); sub_684 = rsqrt_51 = None + mul_1630 = torch.ops.aten.mul.Tensor(convert_element_type_2192, mul_793); convert_element_type_2192 = mul_793 = None + sum_186 = torch.ops.aten.sum.dim_IntList(mul_1630, [0, 1]); mul_1630 = None + convert_element_type_2195 = torch.ops.prims.convert_element_type.default(mul_1629, torch.bfloat16); mul_1629 = None + add_1924 = torch.ops.aten.add.Tensor(add_1922, convert_element_type_2195); add_1922 = convert_element_type_2195 = None + convert_element_type_default_52 = torch.ops.prims.convert_element_type.default(sum_186, torch.float32); sum_186 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_52, 'avg', 64, '0'); convert_element_type_default_52 = None + wait_tensor_720 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_1972 = torch.ops.aten.view.default(add_1924, [8192, 2048]) + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1972, 1) + convert_element_type_2198 = torch.ops.prims.convert_element_type.default(unsqueeze_63, torch.float32); unsqueeze_63 = None + bmm_46 = torch.ops.aten.bmm.default(permute_906, convert_element_type_2198); permute_906 = None + bmm_47 = torch.ops.aten.bmm.default(convert_element_type_2198, permute_907); convert_element_type_2198 = permute_907 = None + convert_element_type_2199 = torch.ops.prims.convert_element_type.default(bmm_46, torch.bfloat16); bmm_46 = None + view_1973 = torch.ops.aten.view.default(bmm_47, [8192, 6]); bmm_47 = None + view_1974 = torch.ops.aten.view.default(convert_element_type_2199, [49152, 2048]); convert_element_type_2199 = None + index_72 = torch.ops.aten.index.Tensor(view_1974, [getitem_231]); view_1974 = getitem_231 = None + permute_908 = torch.ops.aten.permute.default(view_1972, [1, 0]) + mm_378 = torch.ops.aten.mm.default(permute_908, mul_790); permute_908 = mul_790 = None + convert_element_type_896 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16); primals_278 = None + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_896, 64, '0'); convert_element_type_896 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_346, [1, 0]); wait_tensor_346 = None + permute_910 = torch.ops.aten.permute.default(permute_250, [1, 0]); permute_250 = None + mm_379 = torch.ops.aten.mm.default(view_1972, permute_910); view_1972 = permute_910 = None + convert_element_type_2204 = torch.ops.prims.convert_element_type.default(mm_378, torch.float32); mm_378 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2204, 'avg', 64, '0'); convert_element_type_2204 = None + wait_tensor_721 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + convert_element_type_891 = torch.ops.prims.convert_element_type.default(mm_132, torch.float32); mm_132 = None + neg_32 = torch.ops.aten.neg.default(convert_element_type_891) + exp_48 = torch.ops.aten.exp.default(neg_32); neg_32 = None + add_1088 = torch.ops.aten.add.Tensor(exp_48, 1); exp_48 = None + div_80 = torch.ops.aten.div.Tensor(convert_element_type_891, add_1088) + convert_element_type_892 = torch.ops.prims.convert_element_type.default(div_80, torch.bfloat16); div_80 = None + mul_1631 = torch.ops.aten.mul.Tensor(mm_379, convert_element_type_892); convert_element_type_892 = None + mul_1632 = torch.ops.aten.mul.Tensor(mm_379, mm_133); mm_379 = mm_133 = None + permute_912 = torch.ops.aten.permute.default(mul_1631, [1, 0]) + mm_380 = torch.ops.aten.mm.default(permute_912, view_1063); permute_912 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16); primals_277 = None + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_893, 64, '0'); convert_element_type_893 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); wait_tensor_345 = None + permute_914 = torch.ops.aten.permute.default(permute_249, [1, 0]); permute_249 = None + mm_381 = torch.ops.aten.mm.default(mul_1631, permute_914); mul_1631 = permute_914 = None + convert_element_type_2209 = torch.ops.prims.convert_element_type.default(mm_380, torch.float32); mm_380 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2209, 'avg', 64, '0'); convert_element_type_2209 = None + wait_tensor_722 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + convert_element_type_2210 = torch.ops.prims.convert_element_type.default(mul_1632, torch.float32); mul_1632 = None + reciprocal_20 = torch.ops.aten.reciprocal.default(add_1088); add_1088 = None + mul_1633 = torch.ops.aten.mul.Tensor(reciprocal_20, 1); reciprocal_20 = None + mul_1634 = torch.ops.aten.mul.Tensor(convert_element_type_2210, mul_1633); convert_element_type_2210 = None + sub_685 = torch.ops.aten.sub.Tensor(1, mul_1633); mul_1633 = None + mul_1635 = torch.ops.aten.mul.Tensor(convert_element_type_891, sub_685); convert_element_type_891 = sub_685 = None + add_1926 = torch.ops.aten.add.Tensor(mul_1635, 1); mul_1635 = None + mul_1636 = torch.ops.aten.mul.Tensor(mul_1634, add_1926); mul_1634 = add_1926 = None + convert_element_type_2212 = torch.ops.prims.convert_element_type.default(mul_1636, torch.bfloat16); mul_1636 = None + permute_916 = torch.ops.aten.permute.default(convert_element_type_2212, [1, 0]) + mm_382 = torch.ops.aten.mm.default(permute_916, view_1063); permute_916 = None + convert_element_type_888 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16); primals_276 = None + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_888, 64, '0'); convert_element_type_888 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_248 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + permute_918 = torch.ops.aten.permute.default(permute_248, [1, 0]); permute_248 = None + mm_383 = torch.ops.aten.mm.default(convert_element_type_2212, permute_918); convert_element_type_2212 = permute_918 = None + add_1927 = torch.ops.aten.add.Tensor(mm_381, mm_383); mm_381 = mm_383 = None + convert_element_type_2217 = torch.ops.prims.convert_element_type.default(mm_382, torch.float32); mm_382 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2217, 'avg', 64, '0'); convert_element_type_2217 = None + wait_tensor_723 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + all_to_all_single_98 = torch.ops._c10d_functional.all_to_all_single.default(index_72, [_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255], [_local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247], '521'); index_72 = None + wait_tensor_724 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_98); all_to_all_single_98 = None + full_388 = torch.ops.aten.full.default([sym_size_int_61, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_61 = None + slice_scatter_10 = torch.ops.aten.slice_scatter.default(full_388, wait_tensor_724, 0, 0, -1); wait_tensor_724 = None + index_73 = torch.ops.aten.index.Tensor(slice_scatter_10, [getitem_232]); slice_scatter_10 = None + permute_920 = torch.ops.aten.permute.default(index_73, [1, 0]) + _grouped_mm_138 = torch.ops.aten._grouped_mm.default(permute_920, mul_770, cumsum_47); permute_920 = mul_770 = None + convert_element_type_882 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16); primals_274 = None + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_882, 8, '513'); convert_element_type_882 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + permute_247 = torch.ops.aten.permute.default(wait_tensor_339, [0, 2, 1]); wait_tensor_339 = None + permute_922 = torch.ops.aten.permute.default(permute_247, [0, 2, 1]); permute_247 = None + _grouped_mm_139 = torch.ops.aten._grouped_mm.default(index_73, permute_922, cumsum_47); index_73 = permute_922 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(_grouped_mm_45, torch.float32); _grouped_mm_45 = None + neg_31 = torch.ops.aten.neg.default(convert_element_type_886) + exp_47 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_1052 = torch.ops.aten.add.Tensor(exp_47, 1); exp_47 = None + div_79 = torch.ops.aten.div.Tensor(convert_element_type_886, add_1052) + convert_element_type_887 = torch.ops.prims.convert_element_type.default(div_79, torch.bfloat16); div_79 = None + mul_1637 = torch.ops.aten.mul.Tensor(_grouped_mm_139, convert_element_type_887); convert_element_type_887 = None + mul_1638 = torch.ops.aten.mul.Tensor(_grouped_mm_139, _grouped_mm_46); _grouped_mm_139 = _grouped_mm_46 = None + permute_924 = torch.ops.aten.permute.default(mul_1637, [1, 0]) + _grouped_mm_140 = torch.ops.aten._grouped_mm.default(permute_924, index_31, cumsum_47); permute_924 = None + convert_element_type_883 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16); primals_275 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_883, 8, '513'); convert_element_type_883 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + permute_246 = torch.ops.aten.permute.default(wait_tensor_340, [0, 2, 1]); wait_tensor_340 = None + permute_926 = torch.ops.aten.permute.default(permute_246, [0, 2, 1]); permute_246 = None + _grouped_mm_141 = torch.ops.aten._grouped_mm.default(mul_1637, permute_926, cumsum_47); mul_1637 = permute_926 = None + convert_element_type_2218 = torch.ops.prims.convert_element_type.default(mul_1638, torch.float32); mul_1638 = None + reciprocal_21 = torch.ops.aten.reciprocal.default(add_1052); add_1052 = None + mul_1639 = torch.ops.aten.mul.Tensor(reciprocal_21, 1); reciprocal_21 = None + mul_1640 = torch.ops.aten.mul.Tensor(convert_element_type_2218, mul_1639); convert_element_type_2218 = None + sub_686 = torch.ops.aten.sub.Tensor(1, mul_1639); mul_1639 = None + mul_1641 = torch.ops.aten.mul.Tensor(convert_element_type_886, sub_686); convert_element_type_886 = sub_686 = None + add_1929 = torch.ops.aten.add.Tensor(mul_1641, 1); mul_1641 = None + mul_1642 = torch.ops.aten.mul.Tensor(mul_1640, add_1929); mul_1640 = add_1929 = None + convert_element_type_2220 = torch.ops.prims.convert_element_type.default(mul_1642, torch.bfloat16); mul_1642 = None + permute_928 = torch.ops.aten.permute.default(convert_element_type_2220, [1, 0]) + _grouped_mm_142 = torch.ops.aten._grouped_mm.default(permute_928, index_31, cumsum_47); permute_928 = index_31 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16); primals_273 = None + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_880, 8, '513'); convert_element_type_880 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_245 = torch.ops.aten.permute.default(wait_tensor_337, [0, 2, 1]); wait_tensor_337 = None + permute_930 = torch.ops.aten.permute.default(permute_245, [0, 2, 1]); permute_245 = None + _grouped_mm_143 = torch.ops.aten._grouped_mm.default(convert_element_type_2220, permute_930, cumsum_47); convert_element_type_2220 = permute_930 = cumsum_47 = None + add_1930 = torch.ops.aten.add.Tensor(_grouped_mm_141, _grouped_mm_143); _grouped_mm_141 = _grouped_mm_143 = None + convert_element_type_2221 = torch.ops.prims.convert_element_type.default(_grouped_mm_140, torch.float32); _grouped_mm_140 = None + div_192 = torch.ops.aten.div.Tensor(convert_element_type_2221, 64); convert_element_type_2221 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_192, 'sum', 8, '513'); div_192 = None + wait_tensor_725 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + convert_element_type_2222 = torch.ops.prims.convert_element_type.default(_grouped_mm_138, torch.float32); _grouped_mm_138 = None + div_193 = torch.ops.aten.div.Tensor(convert_element_type_2222, 64); convert_element_type_2222 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_193, 'sum', 8, '513'); div_193 = None + wait_tensor_726 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + convert_element_type_2223 = torch.ops.prims.convert_element_type.default(_grouped_mm_142, torch.float32); _grouped_mm_142 = None + div_194 = torch.ops.aten.div.Tensor(convert_element_type_2223, 64); convert_element_type_2223 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_194, 'sum', 8, '513'); div_194 = None + wait_tensor_727 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + index_put_72 = torch.ops.aten.index_put.default(full_388, [getitem_232], add_1930, True); full_388 = getitem_232 = add_1930 = None + slice_167 = torch.ops.aten.slice.Tensor(index_put_72, 0, 0, add_1931); index_put_72 = add_1931 = None + all_to_all_single_99 = torch.ops._c10d_functional.all_to_all_single.default(slice_167, [_local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247], [_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255], '521'); slice_167 = _local_scalar_dense_240 = _local_scalar_dense_241 = _local_scalar_dense_242 = _local_scalar_dense_243 = _local_scalar_dense_244 = _local_scalar_dense_245 = _local_scalar_dense_246 = _local_scalar_dense_247 = _local_scalar_dense_248 = _local_scalar_dense_249 = _local_scalar_dense_250 = _local_scalar_dense_251 = _local_scalar_dense_252 = _local_scalar_dense_253 = _local_scalar_dense_254 = _local_scalar_dense_255 = None + wait_tensor_728 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_99); all_to_all_single_99 = None + index_put_73 = torch.ops.aten.index_put.default(full_default_52, [div_77], wait_tensor_728, True); div_77 = wait_tensor_728 = None + add_1935 = torch.ops.aten.add.Tensor(add_1927, index_put_73); add_1927 = index_put_73 = None + mul_1643 = torch.ops.aten.mul.Tensor(view_1973, 1.0); view_1973 = None + scatter_add_10 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_229, mul_1643); getitem_229 = mul_1643 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(mm_131, torch.float32); mm_131 = None + sub_360 = torch.ops.aten.sub.Tensor(convert_element_type_875, amax_15); convert_element_type_875 = amax_15 = None + exp_46 = torch.ops.aten.exp.default(sub_360); sub_360 = None + div_76 = torch.ops.aten.div.Tensor(exp_46, sum_61); exp_46 = sum_61 = None + mul_1644 = torch.ops.aten.mul.Tensor(scatter_add_10, div_76); scatter_add_10 = None + sum_187 = torch.ops.aten.sum.dim_IntList(mul_1644, [1], True) + neg_85 = torch.ops.aten.neg.default(div_76); div_76 = None + fma_10 = torch.ops.prims.fma.default(neg_85, sum_187, mul_1644); neg_85 = sum_187 = mul_1644 = None + convert_element_type_2224 = torch.ops.prims.convert_element_type.default(fma_10, torch.bfloat16); fma_10 = None + permute_932 = torch.ops.aten.permute.default(convert_element_type_2224, [1, 0]) + mm_384 = torch.ops.aten.mm.default(permute_932, view_1063); permute_932 = view_1063 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16); primals_271 = None + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_872, 64, '0'); convert_element_type_872 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_333, [1, 0]); wait_tensor_333 = None + permute_934 = torch.ops.aten.permute.default(permute_244, [1, 0]); permute_244 = None + mm_385 = torch.ops.aten.mm.default(convert_element_type_2224, permute_934); convert_element_type_2224 = permute_934 = None + add_1936 = torch.ops.aten.add.Tensor(add_1935, mm_385); add_1935 = mm_385 = None + convert_element_type_2229 = torch.ops.prims.convert_element_type.default(mm_384, torch.float32); mm_384 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2229, 'avg', 64, '0'); convert_element_type_2229 = None + wait_tensor_729 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + view_1975 = torch.ops.aten.view.default(add_1936, [2, 4096, 2048]); add_1936 = None + convert_element_type_2230 = torch.ops.prims.convert_element_type.default(view_1975, torch.float32); view_1975 = None + convert_element_type_869 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_869, 64, '0'); convert_element_type_869 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + convert_element_type_2232 = torch.ops.prims.convert_element_type.default(wait_tensor_332, torch.float32); wait_tensor_332 = None + mul_1645 = torch.ops.aten.mul.Tensor(convert_element_type_2230, convert_element_type_2232); convert_element_type_2232 = None + convert_element_type_870 = torch.ops.prims.convert_element_type.default(add_1028, torch.float32); add_1028 = None + mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_870, rsqrt_50); convert_element_type_870 = None + mul_1647 = torch.ops.aten.mul.Tensor(mul_750, mul_1645) + sum_188 = torch.ops.aten.sum.dim_IntList(mul_1647, [2], True); mul_1647 = None + div_195 = torch.ops.aten.div.Tensor(mul_750, 2048) + mul_1648 = torch.ops.aten.mul.Tensor(div_195, sum_188); div_195 = sum_188 = None + sub_688 = torch.ops.aten.sub.Tensor(mul_1645, mul_1648); mul_1645 = mul_1648 = None + mul_1649 = torch.ops.aten.mul.Tensor(sub_688, rsqrt_50); sub_688 = rsqrt_50 = None + mul_1650 = torch.ops.aten.mul.Tensor(convert_element_type_2230, mul_750); convert_element_type_2230 = mul_750 = None + sum_189 = torch.ops.aten.sum.dim_IntList(mul_1650, [0, 1]); mul_1650 = None + convert_element_type_2233 = torch.ops.prims.convert_element_type.default(mul_1649, torch.bfloat16); mul_1649 = None + add_1937 = torch.ops.aten.add.Tensor(add_1924, convert_element_type_2233); add_1924 = convert_element_type_2233 = None + convert_element_type_default_51 = torch.ops.prims.convert_element_type.default(sum_189, torch.float32); sum_189 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_51, 'avg', 64, '0'); convert_element_type_default_51 = None + wait_tensor_730 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + view_1976 = torch.ops.aten.view.default(add_1937, [8192, 2048]) + permute_936 = torch.ops.aten.permute.default(view_1976, [1, 0]) + permute_242 = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]) + view_1058 = torch.ops.aten.view.default(permute_242, [2, 4096, -1]); permute_242 = None + view_1060 = torch.ops.aten.view.default(view_1058, [8192, 2048]); view_1058 = None + mm_386 = torch.ops.aten.mm.default(permute_936, view_1060); permute_936 = view_1060 = None + convert_element_type_866 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_866, 64, '0'); convert_element_type_866 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + permute_938 = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None + mm_387 = torch.ops.aten.mm.default(view_1976, permute_938); view_1976 = permute_938 = None + view_1977 = torch.ops.aten.view.default(mm_387, [2, 4096, 2048]); mm_387 = None + convert_element_type_2240 = torch.ops.prims.convert_element_type.default(mm_386, torch.float32); mm_386 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2240, 'avg', 64, '0'); convert_element_type_2240 = None + wait_tensor_731 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = None + view_1978 = torch.ops.aten.view.default(view_1977, [2, 4096, 16, 128]); view_1977 = None + permute_940 = torch.ops.aten.permute.default(view_1978, [0, 2, 1, 3]); view_1978 = None + fw_graph10 = self.fw_graph10 + joint_graph10 = self.joint_graph10 + mask_graph10 = self.mask_graph10 + flex_attention_backward_10 = torch.ops.higher_order.flex_attention_backward(permute_239, permute_240, permute_241, getitem_225, getitem_226, permute_940, None, fw_graph10, joint_graph10, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph10), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_239 = permute_240 = permute_241 = getitem_225 = getitem_226 = permute_940 = fw_graph10 = joint_graph10 = mask_graph10 = None + getitem_413 = flex_attention_backward_10[0] + getitem_414 = flex_attention_backward_10[1] + getitem_415 = flex_attention_backward_10[2]; flex_attention_backward_10 = None + permute_941 = torch.ops.aten.permute.default(getitem_415, [0, 2, 1, 3]); getitem_415 = None + permute_942 = torch.ops.aten.permute.default(getitem_414, [0, 2, 1, 3]); getitem_414 = None + permute_943 = torch.ops.aten.permute.default(getitem_413, [0, 2, 1, 3]); getitem_413 = None + slice_169 = torch.ops.aten.slice.Tensor(permute_942, 3, 0, 128) + slice_170 = torch.ops.aten.slice.Tensor(permute_942, 3, 128, 192); permute_942 = None + sum_190 = torch.ops.aten.sum.dim_IntList(slice_170, [2], True); slice_170 = None + cat_110 = torch.ops.aten.cat.default([slice_169, permute_941], 3); slice_169 = permute_941 = None + view_1979 = torch.ops.aten.view.default(cat_110, [2, 4096, 4096]); cat_110 = None + view_1980 = torch.ops.aten.view.default(view_1979, [8192, 4096]); view_1979 = None + permute_944 = torch.ops.aten.permute.default(view_1980, [1, 0]) + mm_388 = torch.ops.aten.mm.default(permute_944, view_1055); permute_944 = view_1055 = None + convert_element_type_863 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_863, 64, '0'); convert_element_type_863 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + permute_946 = torch.ops.aten.permute.default(permute_238, [1, 0]); permute_238 = None + mm_389 = torch.ops.aten.mm.default(view_1980, permute_946); view_1980 = permute_946 = None + view_1981 = torch.ops.aten.view.default(mm_389, [2, 4096, 512]); mm_389 = None + convert_element_type_2245 = torch.ops.prims.convert_element_type.default(mm_388, torch.float32); mm_388 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2245, 'avg', 64, '0'); convert_element_type_2245 = None + wait_tensor_732 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + convert_element_type_2246 = torch.ops.prims.convert_element_type.default(view_1981, torch.float32); view_1981 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_860, 64, '0'); convert_element_type_860 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + convert_element_type_2248 = torch.ops.prims.convert_element_type.default(wait_tensor_329, torch.float32); wait_tensor_329 = None + mul_1651 = torch.ops.aten.mul.Tensor(convert_element_type_2246, convert_element_type_2248); convert_element_type_2248 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(getitem_221, torch.float32); getitem_221 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_861, rsqrt_49); convert_element_type_861 = None + mul_1653 = torch.ops.aten.mul.Tensor(mul_748, mul_1651) + sum_191 = torch.ops.aten.sum.dim_IntList(mul_1653, [2], True); mul_1653 = None + div_196 = torch.ops.aten.div.Tensor(mul_748, 512) + mul_1654 = torch.ops.aten.mul.Tensor(div_196, sum_191); div_196 = sum_191 = None + sub_689 = torch.ops.aten.sub.Tensor(mul_1651, mul_1654); mul_1651 = mul_1654 = None + mul_1655 = torch.ops.aten.mul.Tensor(sub_689, rsqrt_49); sub_689 = rsqrt_49 = None + mul_1656 = torch.ops.aten.mul.Tensor(convert_element_type_2246, mul_748); convert_element_type_2246 = mul_748 = None + sum_192 = torch.ops.aten.sum.dim_IntList(mul_1656, [0, 1]); mul_1656 = None + convert_element_type_2249 = torch.ops.prims.convert_element_type.default(mul_1655, torch.bfloat16); mul_1655 = None + convert_element_type_default_50 = torch.ops.prims.convert_element_type.default(sum_192, torch.float32); sum_192 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_50, 'avg', 64, '0'); convert_element_type_default_50 = None + wait_tensor_733 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + convert_element_type_2252 = torch.ops.prims.convert_element_type.default(sum_190, torch.float32); sum_190 = None + view_1982 = torch.ops.aten.view.default(convert_element_type_2252, [2, 4096, 1, 32, 2]); convert_element_type_2252 = None + view_as_complex_74 = torch.ops.aten.view_as_complex.default(view_1982); view_1982 = None + mul_1657 = torch.ops.aten.mul.Tensor(view_as_complex_74, clone_9); view_as_complex_74 = None + view_as_real_74 = torch.ops.aten.view_as_real.default(mul_1657); mul_1657 = None + view_1983 = torch.ops.aten.view.default(view_as_real_74, [2, 4096, 1, 64]); view_as_real_74 = None + convert_element_type_2253 = torch.ops.prims.convert_element_type.default(view_1983, torch.bfloat16); view_1983 = None + squeeze_36 = torch.ops.aten.squeeze.dim(convert_element_type_2253, 2); convert_element_type_2253 = None + cat_111 = torch.ops.aten.cat.default([convert_element_type_2249, squeeze_36], 2); convert_element_type_2249 = squeeze_36 = None + view_1984 = torch.ops.aten.view.default(cat_111, [8192, 576]); cat_111 = None + permute_948 = torch.ops.aten.permute.default(view_1984, [1, 0]) + mm_390 = torch.ops.aten.mm.default(permute_948, view_1041); permute_948 = None + convert_element_type_855 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_855, 64, '0'); convert_element_type_855 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_237 = torch.ops.aten.permute.default(wait_tensor_328, [1, 0]); wait_tensor_328 = None + permute_950 = torch.ops.aten.permute.default(permute_237, [1, 0]); permute_237 = None + mm_391 = torch.ops.aten.mm.default(view_1984, permute_950); view_1984 = permute_950 = None + view_1985 = torch.ops.aten.view.default(mm_391, [2, 4096, 2048]); mm_391 = None + convert_element_type_2258 = torch.ops.prims.convert_element_type.default(mm_390, torch.float32); mm_390 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2258, 'avg', 64, '0'); convert_element_type_2258 = None + wait_tensor_734 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + slice_171 = torch.ops.aten.slice.Tensor(permute_943, 3, 0, 128) + slice_172 = torch.ops.aten.slice.Tensor(permute_943, 3, 128, 192); permute_943 = None + convert_element_type_2259 = torch.ops.prims.convert_element_type.default(slice_172, torch.float32); slice_172 = None + view_1986 = torch.ops.aten.view.default(convert_element_type_2259, [2, 4096, 16, 32, 2]); convert_element_type_2259 = None + view_as_complex_75 = torch.ops.aten.view_as_complex.default(view_1986); view_1986 = None + mul_1658 = torch.ops.aten.mul.Tensor(view_as_complex_75, clone_9); view_as_complex_75 = None + view_as_real_75 = torch.ops.aten.view_as_real.default(mul_1658); mul_1658 = None + view_1987 = torch.ops.aten.view.default(view_as_real_75, [2, 4096, 16, 64]); view_as_real_75 = None + convert_element_type_2260 = torch.ops.prims.convert_element_type.default(view_1987, torch.bfloat16); view_1987 = None + cat_112 = torch.ops.aten.cat.default([slice_171, convert_element_type_2260], 3); slice_171 = convert_element_type_2260 = None + view_1988 = torch.ops.aten.view.default(cat_112, [2, 4096, 3072]); cat_112 = None + view_1989 = torch.ops.aten.view.default(view_1988, [8192, 3072]); view_1988 = None + permute_952 = torch.ops.aten.permute.default(view_1989, [1, 0]) + mm_392 = torch.ops.aten.mm.default(permute_952, view_1041); permute_952 = view_1041 = None + convert_element_type_850 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_850, 64, '0'); convert_element_type_850 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_236 = torch.ops.aten.permute.default(wait_tensor_327, [1, 0]); wait_tensor_327 = None + permute_954 = torch.ops.aten.permute.default(permute_236, [1, 0]); permute_236 = None + mm_393 = torch.ops.aten.mm.default(view_1989, permute_954); view_1989 = permute_954 = None + view_1990 = torch.ops.aten.view.default(mm_393, [2, 4096, 2048]); mm_393 = None + add_1938 = torch.ops.aten.add.Tensor(view_1985, view_1990); view_1985 = view_1990 = None + convert_element_type_2265 = torch.ops.prims.convert_element_type.default(mm_392, torch.float32); mm_392 = None + reduce_scatter_tensor_154 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2265, 'avg', 64, '0'); convert_element_type_2265 = None + wait_tensor_735 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + convert_element_type_2266 = torch.ops.prims.convert_element_type.default(add_1938, torch.float32); add_1938 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 64, '0'); convert_element_type_847 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + convert_element_type_2268 = torch.ops.prims.convert_element_type.default(wait_tensor_326, torch.float32); wait_tensor_326 = None + mul_1659 = torch.ops.aten.mul.Tensor(convert_element_type_2266, convert_element_type_2268); convert_element_type_2268 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(add_1025, torch.float32); add_1025 = None + mul_744 = torch.ops.aten.mul.Tensor(convert_element_type_848, rsqrt_48); convert_element_type_848 = None + mul_1661 = torch.ops.aten.mul.Tensor(mul_744, mul_1659) + sum_193 = torch.ops.aten.sum.dim_IntList(mul_1661, [2], True); mul_1661 = None + div_197 = torch.ops.aten.div.Tensor(mul_744, 2048) + mul_1662 = torch.ops.aten.mul.Tensor(div_197, sum_193); div_197 = sum_193 = None + sub_690 = torch.ops.aten.sub.Tensor(mul_1659, mul_1662); mul_1659 = mul_1662 = None + mul_1663 = torch.ops.aten.mul.Tensor(sub_690, rsqrt_48); sub_690 = rsqrt_48 = None + mul_1664 = torch.ops.aten.mul.Tensor(convert_element_type_2266, mul_744); convert_element_type_2266 = mul_744 = None + sum_194 = torch.ops.aten.sum.dim_IntList(mul_1664, [0, 1]); mul_1664 = None + convert_element_type_2269 = torch.ops.prims.convert_element_type.default(mul_1663, torch.bfloat16); mul_1663 = None + add_1939 = torch.ops.aten.add.Tensor(add_1937, convert_element_type_2269); add_1937 = convert_element_type_2269 = None + convert_element_type_default_49 = torch.ops.prims.convert_element_type.default(sum_194, torch.float32); sum_194 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_49, 'avg', 64, '0'); convert_element_type_default_49 = None + wait_tensor_736 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); reduce_scatter_tensor_155 = None + view_1991 = torch.ops.aten.view.default(add_1939, [8192, 2048]) + unsqueeze_64 = torch.ops.aten.unsqueeze.default(view_1991, 1) + convert_element_type_2272 = torch.ops.prims.convert_element_type.default(unsqueeze_64, torch.float32); unsqueeze_64 = None + bmm_48 = torch.ops.aten.bmm.default(permute_956, convert_element_type_2272); permute_956 = None + bmm_49 = torch.ops.aten.bmm.default(convert_element_type_2272, permute_957); convert_element_type_2272 = permute_957 = None + convert_element_type_2273 = torch.ops.prims.convert_element_type.default(bmm_48, torch.bfloat16); bmm_48 = None + view_1992 = torch.ops.aten.view.default(bmm_49, [8192, 6]); bmm_49 = None + view_1993 = torch.ops.aten.view.default(convert_element_type_2273, [49152, 2048]); convert_element_type_2273 = None + index_74 = torch.ops.aten.index.Tensor(view_1993, [getitem_217]); view_1993 = getitem_217 = None + permute_958 = torch.ops.aten.permute.default(view_1991, [1, 0]) + mm_394 = torch.ops.aten.mm.default(permute_958, mul_741); permute_958 = mul_741 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 64, '0'); convert_element_type_842 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_235 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + permute_960 = torch.ops.aten.permute.default(permute_235, [1, 0]); permute_235 = None + mm_395 = torch.ops.aten.mm.default(view_1991, permute_960); view_1991 = permute_960 = None + convert_element_type_2278 = torch.ops.prims.convert_element_type.default(mm_394, torch.float32); mm_394 = None + reduce_scatter_tensor_156 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2278, 'avg', 64, '0'); convert_element_type_2278 = None + wait_tensor_737 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + convert_element_type_837 = torch.ops.prims.convert_element_type.default(mm_124, torch.float32); mm_124 = None + neg_30 = torch.ops.aten.neg.default(convert_element_type_837) + exp_45 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_1020 = torch.ops.aten.add.Tensor(exp_45, 1); exp_45 = None + div_75 = torch.ops.aten.div.Tensor(convert_element_type_837, add_1020) + convert_element_type_838 = torch.ops.prims.convert_element_type.default(div_75, torch.bfloat16); div_75 = None + mul_1665 = torch.ops.aten.mul.Tensor(mm_395, convert_element_type_838); convert_element_type_838 = None + mul_1666 = torch.ops.aten.mul.Tensor(mm_395, mm_125); mm_395 = mm_125 = None + permute_962 = torch.ops.aten.permute.default(mul_1665, [1, 0]) + mm_396 = torch.ops.aten.mm.default(permute_962, view_996); permute_962 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16); primals_261 = None + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_839, 64, '0'); convert_element_type_839 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_234 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + permute_964 = torch.ops.aten.permute.default(permute_234, [1, 0]); permute_234 = None + mm_397 = torch.ops.aten.mm.default(mul_1665, permute_964); mul_1665 = permute_964 = None + convert_element_type_2283 = torch.ops.prims.convert_element_type.default(mm_396, torch.float32); mm_396 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2283, 'avg', 64, '0'); convert_element_type_2283 = None + wait_tensor_738 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + convert_element_type_2284 = torch.ops.prims.convert_element_type.default(mul_1666, torch.float32); mul_1666 = None + reciprocal_22 = torch.ops.aten.reciprocal.default(add_1020); add_1020 = None + mul_1667 = torch.ops.aten.mul.Tensor(reciprocal_22, 1); reciprocal_22 = None + mul_1668 = torch.ops.aten.mul.Tensor(convert_element_type_2284, mul_1667); convert_element_type_2284 = None + sub_691 = torch.ops.aten.sub.Tensor(1, mul_1667); mul_1667 = None + mul_1669 = torch.ops.aten.mul.Tensor(convert_element_type_837, sub_691); convert_element_type_837 = sub_691 = None + add_1941 = torch.ops.aten.add.Tensor(mul_1669, 1); mul_1669 = None + mul_1670 = torch.ops.aten.mul.Tensor(mul_1668, add_1941); mul_1668 = add_1941 = None + convert_element_type_2286 = torch.ops.prims.convert_element_type.default(mul_1670, torch.bfloat16); mul_1670 = None + permute_966 = torch.ops.aten.permute.default(convert_element_type_2286, [1, 0]) + mm_398 = torch.ops.aten.mm.default(permute_966, view_996); permute_966 = None + convert_element_type_834 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16); primals_260 = None + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_834, 64, '0'); convert_element_type_834 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + permute_968 = torch.ops.aten.permute.default(permute_233, [1, 0]); permute_233 = None + mm_399 = torch.ops.aten.mm.default(convert_element_type_2286, permute_968); convert_element_type_2286 = permute_968 = None + add_1942 = torch.ops.aten.add.Tensor(mm_397, mm_399); mm_397 = mm_399 = None + convert_element_type_2291 = torch.ops.prims.convert_element_type.default(mm_398, torch.float32); mm_398 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2291, 'avg', 64, '0'); convert_element_type_2291 = None + wait_tensor_739 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + all_to_all_single_100 = torch.ops._c10d_functional.all_to_all_single.default(index_74, [_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239], [_local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231], '521'); index_74 = None + wait_tensor_740 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_100); all_to_all_single_100 = None + full_392 = torch.ops.aten.full.default([sym_size_int_57, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_57 = None + slice_scatter_11 = torch.ops.aten.slice_scatter.default(full_392, wait_tensor_740, 0, 0, -1); wait_tensor_740 = None + index_75 = torch.ops.aten.index.Tensor(slice_scatter_11, [getitem_218]); slice_scatter_11 = None + permute_970 = torch.ops.aten.permute.default(index_75, [1, 0]) + _grouped_mm_144 = torch.ops.aten._grouped_mm.default(permute_970, mul_721, cumsum_44); permute_970 = mul_721 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16); primals_258 = None + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_828, 8, '513'); convert_element_type_828 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_318, [0, 2, 1]); wait_tensor_318 = None + permute_972 = torch.ops.aten.permute.default(permute_232, [0, 2, 1]); permute_232 = None + _grouped_mm_145 = torch.ops.aten._grouped_mm.default(index_75, permute_972, cumsum_44); index_75 = permute_972 = None + convert_element_type_832 = torch.ops.prims.convert_element_type.default(_grouped_mm_42, torch.float32); _grouped_mm_42 = None + neg_29 = torch.ops.aten.neg.default(convert_element_type_832) + exp_44 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_984 = torch.ops.aten.add.Tensor(exp_44, 1); exp_44 = None + div_74 = torch.ops.aten.div.Tensor(convert_element_type_832, add_984) + convert_element_type_833 = torch.ops.prims.convert_element_type.default(div_74, torch.bfloat16); div_74 = None + mul_1671 = torch.ops.aten.mul.Tensor(_grouped_mm_145, convert_element_type_833); convert_element_type_833 = None + mul_1672 = torch.ops.aten.mul.Tensor(_grouped_mm_145, _grouped_mm_43); _grouped_mm_145 = _grouped_mm_43 = None + permute_974 = torch.ops.aten.permute.default(mul_1671, [1, 0]) + _grouped_mm_146 = torch.ops.aten._grouped_mm.default(permute_974, index_29, cumsum_44); permute_974 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16); primals_259 = None + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 8, '513'); convert_element_type_829 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_319, [0, 2, 1]); wait_tensor_319 = None + permute_976 = torch.ops.aten.permute.default(permute_231, [0, 2, 1]); permute_231 = None + _grouped_mm_147 = torch.ops.aten._grouped_mm.default(mul_1671, permute_976, cumsum_44); mul_1671 = permute_976 = None + convert_element_type_2292 = torch.ops.prims.convert_element_type.default(mul_1672, torch.float32); mul_1672 = None + reciprocal_23 = torch.ops.aten.reciprocal.default(add_984); add_984 = None + mul_1673 = torch.ops.aten.mul.Tensor(reciprocal_23, 1); reciprocal_23 = None + mul_1674 = torch.ops.aten.mul.Tensor(convert_element_type_2292, mul_1673); convert_element_type_2292 = None + sub_692 = torch.ops.aten.sub.Tensor(1, mul_1673); mul_1673 = None + mul_1675 = torch.ops.aten.mul.Tensor(convert_element_type_832, sub_692); convert_element_type_832 = sub_692 = None + add_1944 = torch.ops.aten.add.Tensor(mul_1675, 1); mul_1675 = None + mul_1676 = torch.ops.aten.mul.Tensor(mul_1674, add_1944); mul_1674 = add_1944 = None + convert_element_type_2294 = torch.ops.prims.convert_element_type.default(mul_1676, torch.bfloat16); mul_1676 = None + permute_978 = torch.ops.aten.permute.default(convert_element_type_2294, [1, 0]) + _grouped_mm_148 = torch.ops.aten._grouped_mm.default(permute_978, index_29, cumsum_44); permute_978 = index_29 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16); primals_257 = None + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 8, '513'); convert_element_type_826 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_316, [0, 2, 1]); wait_tensor_316 = None + permute_980 = torch.ops.aten.permute.default(permute_230, [0, 2, 1]); permute_230 = None + _grouped_mm_149 = torch.ops.aten._grouped_mm.default(convert_element_type_2294, permute_980, cumsum_44); convert_element_type_2294 = permute_980 = cumsum_44 = None + add_1945 = torch.ops.aten.add.Tensor(_grouped_mm_147, _grouped_mm_149); _grouped_mm_147 = _grouped_mm_149 = None + convert_element_type_2295 = torch.ops.prims.convert_element_type.default(_grouped_mm_146, torch.float32); _grouped_mm_146 = None + div_198 = torch.ops.aten.div.Tensor(convert_element_type_2295, 64); convert_element_type_2295 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_198, 'sum', 8, '513'); div_198 = None + wait_tensor_741 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + convert_element_type_2296 = torch.ops.prims.convert_element_type.default(_grouped_mm_144, torch.float32); _grouped_mm_144 = None + div_199 = torch.ops.aten.div.Tensor(convert_element_type_2296, 64); convert_element_type_2296 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_199, 'sum', 8, '513'); div_199 = None + wait_tensor_742 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + convert_element_type_2297 = torch.ops.prims.convert_element_type.default(_grouped_mm_148, torch.float32); _grouped_mm_148 = None + div_200 = torch.ops.aten.div.Tensor(convert_element_type_2297, 64); convert_element_type_2297 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_200, 'sum', 8, '513'); div_200 = None + wait_tensor_743 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + index_put_74 = torch.ops.aten.index_put.default(full_392, [getitem_218], add_1945, True); full_392 = getitem_218 = add_1945 = None + slice_173 = torch.ops.aten.slice.Tensor(index_put_74, 0, 0, add_1946); index_put_74 = add_1946 = None + all_to_all_single_101 = torch.ops._c10d_functional.all_to_all_single.default(slice_173, [_local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231], [_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239], '521'); slice_173 = _local_scalar_dense_224 = _local_scalar_dense_225 = _local_scalar_dense_226 = _local_scalar_dense_227 = _local_scalar_dense_228 = _local_scalar_dense_229 = _local_scalar_dense_230 = _local_scalar_dense_231 = _local_scalar_dense_232 = _local_scalar_dense_233 = _local_scalar_dense_234 = _local_scalar_dense_235 = _local_scalar_dense_236 = _local_scalar_dense_237 = _local_scalar_dense_238 = _local_scalar_dense_239 = None + wait_tensor_744 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_101); all_to_all_single_101 = None + index_put_75 = torch.ops.aten.index_put.default(full_default_52, [div_72], wait_tensor_744, True); div_72 = wait_tensor_744 = None + add_1950 = torch.ops.aten.add.Tensor(add_1942, index_put_75); add_1942 = index_put_75 = None + mul_1677 = torch.ops.aten.mul.Tensor(view_1992, 1.0); view_1992 = None + scatter_add_11 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_215, mul_1677); getitem_215 = mul_1677 = None + convert_element_type_821 = torch.ops.prims.convert_element_type.default(mm_123, torch.float32); mm_123 = None + sub_336 = torch.ops.aten.sub.Tensor(convert_element_type_821, amax_14); convert_element_type_821 = amax_14 = None + exp_43 = torch.ops.aten.exp.default(sub_336); sub_336 = None + div_71 = torch.ops.aten.div.Tensor(exp_43, sum_57); exp_43 = sum_57 = None + mul_1678 = torch.ops.aten.mul.Tensor(scatter_add_11, div_71); scatter_add_11 = None + sum_195 = torch.ops.aten.sum.dim_IntList(mul_1678, [1], True) + neg_88 = torch.ops.aten.neg.default(div_71); div_71 = None + fma_11 = torch.ops.prims.fma.default(neg_88, sum_195, mul_1678); neg_88 = sum_195 = mul_1678 = None + convert_element_type_2298 = torch.ops.prims.convert_element_type.default(fma_11, torch.bfloat16); fma_11 = None + permute_982 = torch.ops.aten.permute.default(convert_element_type_2298, [1, 0]) + mm_400 = torch.ops.aten.mm.default(permute_982, view_996); permute_982 = view_996 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16); primals_255 = None + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_818, 64, '0'); convert_element_type_818 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_312, [1, 0]); wait_tensor_312 = None + permute_984 = torch.ops.aten.permute.default(permute_229, [1, 0]); permute_229 = None + mm_401 = torch.ops.aten.mm.default(convert_element_type_2298, permute_984); convert_element_type_2298 = permute_984 = None + add_1951 = torch.ops.aten.add.Tensor(add_1950, mm_401); add_1950 = mm_401 = None + convert_element_type_2303 = torch.ops.prims.convert_element_type.default(mm_400, torch.float32); mm_400 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2303, 'avg', 64, '0'); convert_element_type_2303 = None + wait_tensor_745 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + view_1994 = torch.ops.aten.view.default(add_1951, [2, 4096, 2048]); add_1951 = None + convert_element_type_2304 = torch.ops.prims.convert_element_type.default(view_1994, torch.float32); view_1994 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16); primals_253 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 64, '0'); convert_element_type_815 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + convert_element_type_2306 = torch.ops.prims.convert_element_type.default(wait_tensor_311, torch.float32); wait_tensor_311 = None + mul_1679 = torch.ops.aten.mul.Tensor(convert_element_type_2304, convert_element_type_2306); convert_element_type_2306 = None + convert_element_type_816 = torch.ops.prims.convert_element_type.default(add_960, torch.float32); add_960 = None + mul_701 = torch.ops.aten.mul.Tensor(convert_element_type_816, rsqrt_47); convert_element_type_816 = None + mul_1681 = torch.ops.aten.mul.Tensor(mul_701, mul_1679) + sum_196 = torch.ops.aten.sum.dim_IntList(mul_1681, [2], True); mul_1681 = None + div_201 = torch.ops.aten.div.Tensor(mul_701, 2048) + mul_1682 = torch.ops.aten.mul.Tensor(div_201, sum_196); div_201 = sum_196 = None + sub_694 = torch.ops.aten.sub.Tensor(mul_1679, mul_1682); mul_1679 = mul_1682 = None + mul_1683 = torch.ops.aten.mul.Tensor(sub_694, rsqrt_47); sub_694 = rsqrt_47 = None + mul_1684 = torch.ops.aten.mul.Tensor(convert_element_type_2304, mul_701); convert_element_type_2304 = mul_701 = None + sum_197 = torch.ops.aten.sum.dim_IntList(mul_1684, [0, 1]); mul_1684 = None + convert_element_type_2307 = torch.ops.prims.convert_element_type.default(mul_1683, torch.bfloat16); mul_1683 = None + add_1952 = torch.ops.aten.add.Tensor(add_1939, convert_element_type_2307); add_1939 = convert_element_type_2307 = None + convert_element_type_default_48 = torch.ops.prims.convert_element_type.default(sum_197, torch.float32); sum_197 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_48, 'avg', 64, '0'); convert_element_type_default_48 = None + wait_tensor_746 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_1995 = torch.ops.aten.view.default(add_1952, [8192, 2048]) + permute_986 = torch.ops.aten.permute.default(view_1995, [1, 0]) + permute_227 = torch.ops.aten.permute.default(getitem_211, [0, 2, 1, 3]) + view_991 = torch.ops.aten.view.default(permute_227, [2, 4096, -1]); permute_227 = None + view_993 = torch.ops.aten.view.default(view_991, [8192, 2048]); view_991 = None + mm_402 = torch.ops.aten.mm.default(permute_986, view_993); permute_986 = view_993 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 64, '0'); convert_element_type_812 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + permute_988 = torch.ops.aten.permute.default(permute_228, [1, 0]); permute_228 = None + mm_403 = torch.ops.aten.mm.default(view_1995, permute_988); view_1995 = permute_988 = None + view_1996 = torch.ops.aten.view.default(mm_403, [2, 4096, 2048]); mm_403 = None + convert_element_type_2314 = torch.ops.prims.convert_element_type.default(mm_402, torch.float32); mm_402 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2314, 'avg', 64, '0'); convert_element_type_2314 = None + wait_tensor_747 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 = None + view_1997 = torch.ops.aten.view.default(view_1996, [2, 4096, 16, 128]); view_1996 = None + permute_990 = torch.ops.aten.permute.default(view_1997, [0, 2, 1, 3]); view_1997 = None + fw_graph11 = self.fw_graph11 + joint_graph11 = self.joint_graph11 + mask_graph11 = self.mask_graph11 + flex_attention_backward_11 = torch.ops.higher_order.flex_attention_backward(permute_224, permute_225, permute_226, getitem_211, getitem_212, permute_990, None, fw_graph11, joint_graph11, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph11), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_224 = permute_225 = permute_226 = getitem_211 = getitem_212 = permute_990 = fw_graph11 = joint_graph11 = mask_graph11 = None + getitem_417 = flex_attention_backward_11[0] + getitem_418 = flex_attention_backward_11[1] + getitem_419 = flex_attention_backward_11[2]; flex_attention_backward_11 = None + permute_991 = torch.ops.aten.permute.default(getitem_419, [0, 2, 1, 3]); getitem_419 = None + permute_992 = torch.ops.aten.permute.default(getitem_418, [0, 2, 1, 3]); getitem_418 = None + permute_993 = torch.ops.aten.permute.default(getitem_417, [0, 2, 1, 3]); getitem_417 = None + slice_175 = torch.ops.aten.slice.Tensor(permute_992, 3, 0, 128) + slice_176 = torch.ops.aten.slice.Tensor(permute_992, 3, 128, 192); permute_992 = None + sum_198 = torch.ops.aten.sum.dim_IntList(slice_176, [2], True); slice_176 = None + cat_113 = torch.ops.aten.cat.default([slice_175, permute_991], 3); slice_175 = permute_991 = None + view_1998 = torch.ops.aten.view.default(cat_113, [2, 4096, 4096]); cat_113 = None + view_1999 = torch.ops.aten.view.default(view_1998, [8192, 4096]); view_1998 = None + permute_994 = torch.ops.aten.permute.default(view_1999, [1, 0]) + mm_404 = torch.ops.aten.mm.default(permute_994, view_988); permute_994 = view_988 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 64, '0'); convert_element_type_809 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_223 = torch.ops.aten.permute.default(wait_tensor_309, [1, 0]); wait_tensor_309 = None + permute_996 = torch.ops.aten.permute.default(permute_223, [1, 0]); permute_223 = None + mm_405 = torch.ops.aten.mm.default(view_1999, permute_996); view_1999 = permute_996 = None + view_2000 = torch.ops.aten.view.default(mm_405, [2, 4096, 512]); mm_405 = None + convert_element_type_2319 = torch.ops.prims.convert_element_type.default(mm_404, torch.float32); mm_404 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2319, 'avg', 64, '0'); convert_element_type_2319 = None + wait_tensor_748 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + convert_element_type_2320 = torch.ops.prims.convert_element_type.default(view_2000, torch.float32); view_2000 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_806, 64, '0'); convert_element_type_806 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + convert_element_type_2322 = torch.ops.prims.convert_element_type.default(wait_tensor_308, torch.float32); wait_tensor_308 = None + mul_1685 = torch.ops.aten.mul.Tensor(convert_element_type_2320, convert_element_type_2322); convert_element_type_2322 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(getitem_207, torch.float32); getitem_207 = None + mul_699 = torch.ops.aten.mul.Tensor(convert_element_type_807, rsqrt_46); convert_element_type_807 = None + mul_1687 = torch.ops.aten.mul.Tensor(mul_699, mul_1685) + sum_199 = torch.ops.aten.sum.dim_IntList(mul_1687, [2], True); mul_1687 = None + div_202 = torch.ops.aten.div.Tensor(mul_699, 512) + mul_1688 = torch.ops.aten.mul.Tensor(div_202, sum_199); div_202 = sum_199 = None + sub_695 = torch.ops.aten.sub.Tensor(mul_1685, mul_1688); mul_1685 = mul_1688 = None + mul_1689 = torch.ops.aten.mul.Tensor(sub_695, rsqrt_46); sub_695 = rsqrt_46 = None + mul_1690 = torch.ops.aten.mul.Tensor(convert_element_type_2320, mul_699); convert_element_type_2320 = mul_699 = None + sum_200 = torch.ops.aten.sum.dim_IntList(mul_1690, [0, 1]); mul_1690 = None + convert_element_type_2323 = torch.ops.prims.convert_element_type.default(mul_1689, torch.bfloat16); mul_1689 = None + convert_element_type_default_47 = torch.ops.prims.convert_element_type.default(sum_200, torch.float32); sum_200 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_47, 'avg', 64, '0'); convert_element_type_default_47 = None + wait_tensor_749 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + convert_element_type_2326 = torch.ops.prims.convert_element_type.default(sum_198, torch.float32); sum_198 = None + view_2001 = torch.ops.aten.view.default(convert_element_type_2326, [2, 4096, 1, 32, 2]); convert_element_type_2326 = None + view_as_complex_76 = torch.ops.aten.view_as_complex.default(view_2001); view_2001 = None + mul_1691 = torch.ops.aten.mul.Tensor(view_as_complex_76, clone_9); view_as_complex_76 = None + view_as_real_76 = torch.ops.aten.view_as_real.default(mul_1691); mul_1691 = None + view_2002 = torch.ops.aten.view.default(view_as_real_76, [2, 4096, 1, 64]); view_as_real_76 = None + convert_element_type_2327 = torch.ops.prims.convert_element_type.default(view_2002, torch.bfloat16); view_2002 = None + squeeze_37 = torch.ops.aten.squeeze.dim(convert_element_type_2327, 2); convert_element_type_2327 = None + cat_114 = torch.ops.aten.cat.default([convert_element_type_2323, squeeze_37], 2); convert_element_type_2323 = squeeze_37 = None + view_2003 = torch.ops.aten.view.default(cat_114, [8192, 576]); cat_114 = None + permute_998 = torch.ops.aten.permute.default(view_2003, [1, 0]) + mm_406 = torch.ops.aten.mm.default(permute_998, view_974); permute_998 = None + convert_element_type_801 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_801, 64, '0'); convert_element_type_801 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_307, [1, 0]); wait_tensor_307 = None + permute_1000 = torch.ops.aten.permute.default(permute_222, [1, 0]); permute_222 = None + mm_407 = torch.ops.aten.mm.default(view_2003, permute_1000); view_2003 = permute_1000 = None + view_2004 = torch.ops.aten.view.default(mm_407, [2, 4096, 2048]); mm_407 = None + convert_element_type_2332 = torch.ops.prims.convert_element_type.default(mm_406, torch.float32); mm_406 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2332, 'avg', 64, '0'); convert_element_type_2332 = None + wait_tensor_750 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + slice_177 = torch.ops.aten.slice.Tensor(permute_993, 3, 0, 128) + slice_178 = torch.ops.aten.slice.Tensor(permute_993, 3, 128, 192); permute_993 = None + convert_element_type_2333 = torch.ops.prims.convert_element_type.default(slice_178, torch.float32); slice_178 = None + view_2005 = torch.ops.aten.view.default(convert_element_type_2333, [2, 4096, 16, 32, 2]); convert_element_type_2333 = None + view_as_complex_77 = torch.ops.aten.view_as_complex.default(view_2005); view_2005 = None + mul_1692 = torch.ops.aten.mul.Tensor(view_as_complex_77, clone_9); view_as_complex_77 = None + view_as_real_77 = torch.ops.aten.view_as_real.default(mul_1692); mul_1692 = None + view_2006 = torch.ops.aten.view.default(view_as_real_77, [2, 4096, 16, 64]); view_as_real_77 = None + convert_element_type_2334 = torch.ops.prims.convert_element_type.default(view_2006, torch.bfloat16); view_2006 = None + cat_115 = torch.ops.aten.cat.default([slice_177, convert_element_type_2334], 3); slice_177 = convert_element_type_2334 = None + view_2007 = torch.ops.aten.view.default(cat_115, [2, 4096, 3072]); cat_115 = None + view_2008 = torch.ops.aten.view.default(view_2007, [8192, 3072]); view_2007 = None + permute_1002 = torch.ops.aten.permute.default(view_2008, [1, 0]) + mm_408 = torch.ops.aten.mm.default(permute_1002, view_974); permute_1002 = view_974 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 64, '0'); convert_element_type_796 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + permute_1004 = torch.ops.aten.permute.default(permute_221, [1, 0]); permute_221 = None + mm_409 = torch.ops.aten.mm.default(view_2008, permute_1004); view_2008 = permute_1004 = None + view_2009 = torch.ops.aten.view.default(mm_409, [2, 4096, 2048]); mm_409 = None + add_1953 = torch.ops.aten.add.Tensor(view_2004, view_2009); view_2004 = view_2009 = None + convert_element_type_2339 = torch.ops.prims.convert_element_type.default(mm_408, torch.float32); mm_408 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2339, 'avg', 64, '0'); convert_element_type_2339 = None + wait_tensor_751 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + convert_element_type_2340 = torch.ops.prims.convert_element_type.default(add_1953, torch.float32); add_1953 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 64, '0'); convert_element_type_793 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_2342 = torch.ops.prims.convert_element_type.default(wait_tensor_305, torch.float32); wait_tensor_305 = None + mul_1693 = torch.ops.aten.mul.Tensor(convert_element_type_2340, convert_element_type_2342); convert_element_type_2342 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_957, torch.float32); add_957 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_45); convert_element_type_794 = None + mul_1695 = torch.ops.aten.mul.Tensor(mul_695, mul_1693) + sum_201 = torch.ops.aten.sum.dim_IntList(mul_1695, [2], True); mul_1695 = None + div_203 = torch.ops.aten.div.Tensor(mul_695, 2048) + mul_1696 = torch.ops.aten.mul.Tensor(div_203, sum_201); div_203 = sum_201 = None + sub_696 = torch.ops.aten.sub.Tensor(mul_1693, mul_1696); mul_1693 = mul_1696 = None + mul_1697 = torch.ops.aten.mul.Tensor(sub_696, rsqrt_45); sub_696 = rsqrt_45 = None + mul_1698 = torch.ops.aten.mul.Tensor(convert_element_type_2340, mul_695); convert_element_type_2340 = mul_695 = None + sum_202 = torch.ops.aten.sum.dim_IntList(mul_1698, [0, 1]); mul_1698 = None + convert_element_type_2343 = torch.ops.prims.convert_element_type.default(mul_1697, torch.bfloat16); mul_1697 = None + add_1954 = torch.ops.aten.add.Tensor(add_1952, convert_element_type_2343); add_1952 = convert_element_type_2343 = None + convert_element_type_default_46 = torch.ops.prims.convert_element_type.default(sum_202, torch.float32); sum_202 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_46, 'avg', 64, '0'); convert_element_type_default_46 = None + wait_tensor_752 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + view_2010 = torch.ops.aten.view.default(add_1954, [8192, 2048]) + unsqueeze_65 = torch.ops.aten.unsqueeze.default(view_2010, 1) + convert_element_type_2346 = torch.ops.prims.convert_element_type.default(unsqueeze_65, torch.float32); unsqueeze_65 = None + bmm_50 = torch.ops.aten.bmm.default(permute_1006, convert_element_type_2346); permute_1006 = None + bmm_51 = torch.ops.aten.bmm.default(convert_element_type_2346, permute_1007); convert_element_type_2346 = permute_1007 = None + convert_element_type_2347 = torch.ops.prims.convert_element_type.default(bmm_50, torch.bfloat16); bmm_50 = None + view_2011 = torch.ops.aten.view.default(bmm_51, [8192, 6]); bmm_51 = None + view_2012 = torch.ops.aten.view.default(convert_element_type_2347, [49152, 2048]); convert_element_type_2347 = None + index_76 = torch.ops.aten.index.Tensor(view_2012, [getitem_203]); view_2012 = getitem_203 = None + permute_1008 = torch.ops.aten.permute.default(view_2010, [1, 0]) + mm_410 = torch.ops.aten.mm.default(permute_1008, mul_692); permute_1008 = mul_692 = None + convert_element_type_788 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_788, 64, '0'); convert_element_type_788 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); wait_tensor_304 = None + permute_1010 = torch.ops.aten.permute.default(permute_220, [1, 0]); permute_220 = None + mm_411 = torch.ops.aten.mm.default(view_2010, permute_1010); view_2010 = permute_1010 = None + convert_element_type_2352 = torch.ops.prims.convert_element_type.default(mm_410, torch.float32); mm_410 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2352, 'avg', 64, '0'); convert_element_type_2352 = None + wait_tensor_753 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + convert_element_type_783 = torch.ops.prims.convert_element_type.default(mm_116, torch.float32); mm_116 = None + neg_28 = torch.ops.aten.neg.default(convert_element_type_783) + exp_42 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_952 = torch.ops.aten.add.Tensor(exp_42, 1); exp_42 = None + div_70 = torch.ops.aten.div.Tensor(convert_element_type_783, add_952) + convert_element_type_784 = torch.ops.prims.convert_element_type.default(div_70, torch.bfloat16); div_70 = None + mul_1699 = torch.ops.aten.mul.Tensor(mm_411, convert_element_type_784); convert_element_type_784 = None + mul_1700 = torch.ops.aten.mul.Tensor(mm_411, mm_117); mm_411 = mm_117 = None + permute_1012 = torch.ops.aten.permute.default(mul_1699, [1, 0]) + mm_412 = torch.ops.aten.mm.default(permute_1012, view_929); permute_1012 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_785, 64, '0'); convert_element_type_785 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + permute_1014 = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None + mm_413 = torch.ops.aten.mm.default(mul_1699, permute_1014); mul_1699 = permute_1014 = None + convert_element_type_2357 = torch.ops.prims.convert_element_type.default(mm_412, torch.float32); mm_412 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2357, 'avg', 64, '0'); convert_element_type_2357 = None + wait_tensor_754 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + convert_element_type_2358 = torch.ops.prims.convert_element_type.default(mul_1700, torch.float32); mul_1700 = None + reciprocal_24 = torch.ops.aten.reciprocal.default(add_952); add_952 = None + mul_1701 = torch.ops.aten.mul.Tensor(reciprocal_24, 1); reciprocal_24 = None + mul_1702 = torch.ops.aten.mul.Tensor(convert_element_type_2358, mul_1701); convert_element_type_2358 = None + sub_697 = torch.ops.aten.sub.Tensor(1, mul_1701); mul_1701 = None + mul_1703 = torch.ops.aten.mul.Tensor(convert_element_type_783, sub_697); convert_element_type_783 = sub_697 = None + add_1956 = torch.ops.aten.add.Tensor(mul_1703, 1); mul_1703 = None + mul_1704 = torch.ops.aten.mul.Tensor(mul_1702, add_1956); mul_1702 = add_1956 = None + convert_element_type_2360 = torch.ops.prims.convert_element_type.default(mul_1704, torch.bfloat16); mul_1704 = None + permute_1016 = torch.ops.aten.permute.default(convert_element_type_2360, [1, 0]) + mm_414 = torch.ops.aten.mm.default(permute_1016, view_929); permute_1016 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_780, 64, '0'); convert_element_type_780 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_302, [1, 0]); wait_tensor_302 = None + permute_1018 = torch.ops.aten.permute.default(permute_218, [1, 0]); permute_218 = None + mm_415 = torch.ops.aten.mm.default(convert_element_type_2360, permute_1018); convert_element_type_2360 = permute_1018 = None + add_1957 = torch.ops.aten.add.Tensor(mm_413, mm_415); mm_413 = mm_415 = None + convert_element_type_2365 = torch.ops.prims.convert_element_type.default(mm_414, torch.float32); mm_414 = None + reduce_scatter_tensor_172 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2365, 'avg', 64, '0'); convert_element_type_2365 = None + wait_tensor_755 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + all_to_all_single_102 = torch.ops._c10d_functional.all_to_all_single.default(index_76, [_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223], [_local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215], '521'); index_76 = None + wait_tensor_756 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_102); all_to_all_single_102 = None + full_396 = torch.ops.aten.full.default([sym_size_int_53, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_53 = None + slice_scatter_12 = torch.ops.aten.slice_scatter.default(full_396, wait_tensor_756, 0, 0, -1); wait_tensor_756 = None + index_77 = torch.ops.aten.index.Tensor(slice_scatter_12, [getitem_204]); slice_scatter_12 = None + permute_1020 = torch.ops.aten.permute.default(index_77, [1, 0]) + _grouped_mm_150 = torch.ops.aten._grouped_mm.default(permute_1020, mul_672, cumsum_41); permute_1020 = mul_672 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16); primals_242 = None + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_774, 8, '513'); convert_element_type_774 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_297, [0, 2, 1]); wait_tensor_297 = None + permute_1022 = torch.ops.aten.permute.default(permute_217, [0, 2, 1]); permute_217 = None + _grouped_mm_151 = torch.ops.aten._grouped_mm.default(index_77, permute_1022, cumsum_41); index_77 = permute_1022 = None + convert_element_type_778 = torch.ops.prims.convert_element_type.default(_grouped_mm_39, torch.float32); _grouped_mm_39 = None + neg_27 = torch.ops.aten.neg.default(convert_element_type_778) + exp_41 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_916 = torch.ops.aten.add.Tensor(exp_41, 1); exp_41 = None + div_69 = torch.ops.aten.div.Tensor(convert_element_type_778, add_916) + convert_element_type_779 = torch.ops.prims.convert_element_type.default(div_69, torch.bfloat16); div_69 = None + mul_1705 = torch.ops.aten.mul.Tensor(_grouped_mm_151, convert_element_type_779); convert_element_type_779 = None + mul_1706 = torch.ops.aten.mul.Tensor(_grouped_mm_151, _grouped_mm_40); _grouped_mm_151 = _grouped_mm_40 = None + permute_1024 = torch.ops.aten.permute.default(mul_1705, [1, 0]) + _grouped_mm_152 = torch.ops.aten._grouped_mm.default(permute_1024, index_27, cumsum_41); permute_1024 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16); primals_243 = None + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_775, 8, '513'); convert_element_type_775 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_298, [0, 2, 1]); wait_tensor_298 = None + permute_1026 = torch.ops.aten.permute.default(permute_216, [0, 2, 1]); permute_216 = None + _grouped_mm_153 = torch.ops.aten._grouped_mm.default(mul_1705, permute_1026, cumsum_41); mul_1705 = permute_1026 = None + convert_element_type_2366 = torch.ops.prims.convert_element_type.default(mul_1706, torch.float32); mul_1706 = None + reciprocal_25 = torch.ops.aten.reciprocal.default(add_916); add_916 = None + mul_1707 = torch.ops.aten.mul.Tensor(reciprocal_25, 1); reciprocal_25 = None + mul_1708 = torch.ops.aten.mul.Tensor(convert_element_type_2366, mul_1707); convert_element_type_2366 = None + sub_698 = torch.ops.aten.sub.Tensor(1, mul_1707); mul_1707 = None + mul_1709 = torch.ops.aten.mul.Tensor(convert_element_type_778, sub_698); convert_element_type_778 = sub_698 = None + add_1959 = torch.ops.aten.add.Tensor(mul_1709, 1); mul_1709 = None + mul_1710 = torch.ops.aten.mul.Tensor(mul_1708, add_1959); mul_1708 = add_1959 = None + convert_element_type_2368 = torch.ops.prims.convert_element_type.default(mul_1710, torch.bfloat16); mul_1710 = None + permute_1028 = torch.ops.aten.permute.default(convert_element_type_2368, [1, 0]) + _grouped_mm_154 = torch.ops.aten._grouped_mm.default(permute_1028, index_27, cumsum_41); permute_1028 = index_27 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16); primals_241 = None + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_772, 8, '513'); convert_element_type_772 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + permute_215 = torch.ops.aten.permute.default(wait_tensor_295, [0, 2, 1]); wait_tensor_295 = None + permute_1030 = torch.ops.aten.permute.default(permute_215, [0, 2, 1]); permute_215 = None + _grouped_mm_155 = torch.ops.aten._grouped_mm.default(convert_element_type_2368, permute_1030, cumsum_41); convert_element_type_2368 = permute_1030 = cumsum_41 = None + add_1960 = torch.ops.aten.add.Tensor(_grouped_mm_153, _grouped_mm_155); _grouped_mm_153 = _grouped_mm_155 = None + convert_element_type_2369 = torch.ops.prims.convert_element_type.default(_grouped_mm_152, torch.float32); _grouped_mm_152 = None + div_204 = torch.ops.aten.div.Tensor(convert_element_type_2369, 64); convert_element_type_2369 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_204, 'sum', 8, '513'); div_204 = None + wait_tensor_757 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + convert_element_type_2370 = torch.ops.prims.convert_element_type.default(_grouped_mm_150, torch.float32); _grouped_mm_150 = None + div_205 = torch.ops.aten.div.Tensor(convert_element_type_2370, 64); convert_element_type_2370 = None + reduce_scatter_tensor_174 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_205, 'sum', 8, '513'); div_205 = None + wait_tensor_758 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + convert_element_type_2371 = torch.ops.prims.convert_element_type.default(_grouped_mm_154, torch.float32); _grouped_mm_154 = None + div_206 = torch.ops.aten.div.Tensor(convert_element_type_2371, 64); convert_element_type_2371 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_206, 'sum', 8, '513'); div_206 = None + wait_tensor_759 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + index_put_76 = torch.ops.aten.index_put.default(full_396, [getitem_204], add_1960, True); full_396 = getitem_204 = add_1960 = None + slice_179 = torch.ops.aten.slice.Tensor(index_put_76, 0, 0, add_1961); index_put_76 = add_1961 = None + all_to_all_single_103 = torch.ops._c10d_functional.all_to_all_single.default(slice_179, [_local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215], [_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223], '521'); slice_179 = _local_scalar_dense_208 = _local_scalar_dense_209 = _local_scalar_dense_210 = _local_scalar_dense_211 = _local_scalar_dense_212 = _local_scalar_dense_213 = _local_scalar_dense_214 = _local_scalar_dense_215 = _local_scalar_dense_216 = _local_scalar_dense_217 = _local_scalar_dense_218 = _local_scalar_dense_219 = _local_scalar_dense_220 = _local_scalar_dense_221 = _local_scalar_dense_222 = _local_scalar_dense_223 = None + wait_tensor_760 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_103); all_to_all_single_103 = None + index_put_77 = torch.ops.aten.index_put.default(full_default_52, [div_67], wait_tensor_760, True); div_67 = wait_tensor_760 = None + add_1965 = torch.ops.aten.add.Tensor(add_1957, index_put_77); add_1957 = index_put_77 = None + mul_1711 = torch.ops.aten.mul.Tensor(view_2011, 1.0); view_2011 = None + scatter_add_12 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_201, mul_1711); getitem_201 = mul_1711 = None + convert_element_type_767 = torch.ops.prims.convert_element_type.default(mm_115, torch.float32); mm_115 = None + sub_312 = torch.ops.aten.sub.Tensor(convert_element_type_767, amax_13); convert_element_type_767 = amax_13 = None + exp_40 = torch.ops.aten.exp.default(sub_312); sub_312 = None + div_66 = torch.ops.aten.div.Tensor(exp_40, sum_53); exp_40 = sum_53 = None + mul_1712 = torch.ops.aten.mul.Tensor(scatter_add_12, div_66); scatter_add_12 = None + sum_203 = torch.ops.aten.sum.dim_IntList(mul_1712, [1], True) + neg_91 = torch.ops.aten.neg.default(div_66); div_66 = None + fma_12 = torch.ops.prims.fma.default(neg_91, sum_203, mul_1712); neg_91 = sum_203 = mul_1712 = None + convert_element_type_2372 = torch.ops.prims.convert_element_type.default(fma_12, torch.bfloat16); fma_12 = None + permute_1032 = torch.ops.aten.permute.default(convert_element_type_2372, [1, 0]) + mm_416 = torch.ops.aten.mm.default(permute_1032, view_929); permute_1032 = view_929 = None + convert_element_type_764 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16); primals_239 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_764, 64, '0'); convert_element_type_764 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_214 = torch.ops.aten.permute.default(wait_tensor_291, [1, 0]); wait_tensor_291 = None + permute_1034 = torch.ops.aten.permute.default(permute_214, [1, 0]); permute_214 = None + mm_417 = torch.ops.aten.mm.default(convert_element_type_2372, permute_1034); convert_element_type_2372 = permute_1034 = None + add_1966 = torch.ops.aten.add.Tensor(add_1965, mm_417); add_1965 = mm_417 = None + convert_element_type_2377 = torch.ops.prims.convert_element_type.default(mm_416, torch.float32); mm_416 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2377, 'avg', 64, '0'); convert_element_type_2377 = None + wait_tensor_761 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + view_2013 = torch.ops.aten.view.default(add_1966, [2, 4096, 2048]); add_1966 = None + convert_element_type_2378 = torch.ops.prims.convert_element_type.default(view_2013, torch.float32); view_2013 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16); primals_237 = None + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_761, 64, '0'); convert_element_type_761 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_2380 = torch.ops.prims.convert_element_type.default(wait_tensor_290, torch.float32); wait_tensor_290 = None + mul_1713 = torch.ops.aten.mul.Tensor(convert_element_type_2378, convert_element_type_2380); convert_element_type_2380 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(add_892, torch.float32); add_892 = None + mul_652 = torch.ops.aten.mul.Tensor(convert_element_type_762, rsqrt_44); convert_element_type_762 = None + mul_1715 = torch.ops.aten.mul.Tensor(mul_652, mul_1713) + sum_204 = torch.ops.aten.sum.dim_IntList(mul_1715, [2], True); mul_1715 = None + div_207 = torch.ops.aten.div.Tensor(mul_652, 2048) + mul_1716 = torch.ops.aten.mul.Tensor(div_207, sum_204); div_207 = sum_204 = None + sub_700 = torch.ops.aten.sub.Tensor(mul_1713, mul_1716); mul_1713 = mul_1716 = None + mul_1717 = torch.ops.aten.mul.Tensor(sub_700, rsqrt_44); sub_700 = rsqrt_44 = None + mul_1718 = torch.ops.aten.mul.Tensor(convert_element_type_2378, mul_652); convert_element_type_2378 = mul_652 = None + sum_205 = torch.ops.aten.sum.dim_IntList(mul_1718, [0, 1]); mul_1718 = None + convert_element_type_2381 = torch.ops.prims.convert_element_type.default(mul_1717, torch.bfloat16); mul_1717 = None + add_1967 = torch.ops.aten.add.Tensor(add_1954, convert_element_type_2381); add_1954 = convert_element_type_2381 = None + convert_element_type_default_45 = torch.ops.prims.convert_element_type.default(sum_205, torch.float32); sum_205 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_45, 'avg', 64, '0'); convert_element_type_default_45 = None + wait_tensor_762 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + view_2014 = torch.ops.aten.view.default(add_1967, [8192, 2048]) + permute_1036 = torch.ops.aten.permute.default(view_2014, [1, 0]) + permute_212 = torch.ops.aten.permute.default(getitem_197, [0, 2, 1, 3]) + view_924 = torch.ops.aten.view.default(permute_212, [2, 4096, -1]); permute_212 = None + view_926 = torch.ops.aten.view.default(view_924, [8192, 2048]); view_924 = None + mm_418 = torch.ops.aten.mm.default(permute_1036, view_926); permute_1036 = view_926 = None + convert_element_type_758 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16); primals_236 = None + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_758, 64, '0'); convert_element_type_758 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_213 = torch.ops.aten.permute.default(wait_tensor_289, [1, 0]); wait_tensor_289 = None + permute_1038 = torch.ops.aten.permute.default(permute_213, [1, 0]); permute_213 = None + mm_419 = torch.ops.aten.mm.default(view_2014, permute_1038); view_2014 = permute_1038 = None + view_2015 = torch.ops.aten.view.default(mm_419, [2, 4096, 2048]); mm_419 = None + convert_element_type_2388 = torch.ops.prims.convert_element_type.default(mm_418, torch.float32); mm_418 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2388, 'avg', 64, '0'); convert_element_type_2388 = None + wait_tensor_763 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + view_2016 = torch.ops.aten.view.default(view_2015, [2, 4096, 16, 128]); view_2015 = None + permute_1040 = torch.ops.aten.permute.default(view_2016, [0, 2, 1, 3]); view_2016 = None + fw_graph12 = self.fw_graph12 + joint_graph12 = self.joint_graph12 + mask_graph12 = self.mask_graph12 + flex_attention_backward_12 = torch.ops.higher_order.flex_attention_backward(permute_209, permute_210, permute_211, getitem_197, getitem_198, permute_1040, None, fw_graph12, joint_graph12, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph12), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_209 = permute_210 = permute_211 = getitem_197 = getitem_198 = permute_1040 = fw_graph12 = joint_graph12 = mask_graph12 = None + getitem_421 = flex_attention_backward_12[0] + getitem_422 = flex_attention_backward_12[1] + getitem_423 = flex_attention_backward_12[2]; flex_attention_backward_12 = None + permute_1041 = torch.ops.aten.permute.default(getitem_423, [0, 2, 1, 3]); getitem_423 = None + permute_1042 = torch.ops.aten.permute.default(getitem_422, [0, 2, 1, 3]); getitem_422 = None + permute_1043 = torch.ops.aten.permute.default(getitem_421, [0, 2, 1, 3]); getitem_421 = None + slice_181 = torch.ops.aten.slice.Tensor(permute_1042, 3, 0, 128) + slice_182 = torch.ops.aten.slice.Tensor(permute_1042, 3, 128, 192); permute_1042 = None + sum_206 = torch.ops.aten.sum.dim_IntList(slice_182, [2], True); slice_182 = None + cat_116 = torch.ops.aten.cat.default([slice_181, permute_1041], 3); slice_181 = permute_1041 = None + view_2017 = torch.ops.aten.view.default(cat_116, [2, 4096, 4096]); cat_116 = None + view_2018 = torch.ops.aten.view.default(view_2017, [8192, 4096]); view_2017 = None + permute_1044 = torch.ops.aten.permute.default(view_2018, [1, 0]) + mm_420 = torch.ops.aten.mm.default(permute_1044, view_921); permute_1044 = view_921 = None + convert_element_type_755 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16); primals_235 = None + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_755, 64, '0'); convert_element_type_755 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + permute_1046 = torch.ops.aten.permute.default(permute_208, [1, 0]); permute_208 = None + mm_421 = torch.ops.aten.mm.default(view_2018, permute_1046); view_2018 = permute_1046 = None + view_2019 = torch.ops.aten.view.default(mm_421, [2, 4096, 512]); mm_421 = None + convert_element_type_2393 = torch.ops.prims.convert_element_type.default(mm_420, torch.float32); mm_420 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2393, 'avg', 64, '0'); convert_element_type_2393 = None + wait_tensor_764 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + convert_element_type_2394 = torch.ops.prims.convert_element_type.default(view_2019, torch.float32); view_2019 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_752, 64, '0'); convert_element_type_752 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_2396 = torch.ops.prims.convert_element_type.default(wait_tensor_287, torch.float32); wait_tensor_287 = None + mul_1719 = torch.ops.aten.mul.Tensor(convert_element_type_2394, convert_element_type_2396); convert_element_type_2396 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(getitem_193, torch.float32); getitem_193 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_753, rsqrt_43); convert_element_type_753 = None + mul_1721 = torch.ops.aten.mul.Tensor(mul_650, mul_1719) + sum_207 = torch.ops.aten.sum.dim_IntList(mul_1721, [2], True); mul_1721 = None + div_208 = torch.ops.aten.div.Tensor(mul_650, 512) + mul_1722 = torch.ops.aten.mul.Tensor(div_208, sum_207); div_208 = sum_207 = None + sub_701 = torch.ops.aten.sub.Tensor(mul_1719, mul_1722); mul_1719 = mul_1722 = None + mul_1723 = torch.ops.aten.mul.Tensor(sub_701, rsqrt_43); sub_701 = rsqrt_43 = None + mul_1724 = torch.ops.aten.mul.Tensor(convert_element_type_2394, mul_650); convert_element_type_2394 = mul_650 = None + sum_208 = torch.ops.aten.sum.dim_IntList(mul_1724, [0, 1]); mul_1724 = None + convert_element_type_2397 = torch.ops.prims.convert_element_type.default(mul_1723, torch.bfloat16); mul_1723 = None + convert_element_type_default_44 = torch.ops.prims.convert_element_type.default(sum_208, torch.float32); sum_208 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_44, 'avg', 64, '0'); convert_element_type_default_44 = None + wait_tensor_765 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + convert_element_type_2400 = torch.ops.prims.convert_element_type.default(sum_206, torch.float32); sum_206 = None + view_2020 = torch.ops.aten.view.default(convert_element_type_2400, [2, 4096, 1, 32, 2]); convert_element_type_2400 = None + view_as_complex_78 = torch.ops.aten.view_as_complex.default(view_2020); view_2020 = None + mul_1725 = torch.ops.aten.mul.Tensor(view_as_complex_78, clone_9); view_as_complex_78 = None + view_as_real_78 = torch.ops.aten.view_as_real.default(mul_1725); mul_1725 = None + view_2021 = torch.ops.aten.view.default(view_as_real_78, [2, 4096, 1, 64]); view_as_real_78 = None + convert_element_type_2401 = torch.ops.prims.convert_element_type.default(view_2021, torch.bfloat16); view_2021 = None + squeeze_38 = torch.ops.aten.squeeze.dim(convert_element_type_2401, 2); convert_element_type_2401 = None + cat_117 = torch.ops.aten.cat.default([convert_element_type_2397, squeeze_38], 2); convert_element_type_2397 = squeeze_38 = None + view_2022 = torch.ops.aten.view.default(cat_117, [8192, 576]); cat_117 = None + permute_1048 = torch.ops.aten.permute.default(view_2022, [1, 0]) + mm_422 = torch.ops.aten.mm.default(permute_1048, view_907); permute_1048 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_747, 64, '0'); convert_element_type_747 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + permute_1050 = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None + mm_423 = torch.ops.aten.mm.default(view_2022, permute_1050); view_2022 = permute_1050 = None + view_2023 = torch.ops.aten.view.default(mm_423, [2, 4096, 2048]); mm_423 = None + convert_element_type_2406 = torch.ops.prims.convert_element_type.default(mm_422, torch.float32); mm_422 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2406, 'avg', 64, '0'); convert_element_type_2406 = None + wait_tensor_766 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + slice_183 = torch.ops.aten.slice.Tensor(permute_1043, 3, 0, 128) + slice_184 = torch.ops.aten.slice.Tensor(permute_1043, 3, 128, 192); permute_1043 = None + convert_element_type_2407 = torch.ops.prims.convert_element_type.default(slice_184, torch.float32); slice_184 = None + view_2024 = torch.ops.aten.view.default(convert_element_type_2407, [2, 4096, 16, 32, 2]); convert_element_type_2407 = None + view_as_complex_79 = torch.ops.aten.view_as_complex.default(view_2024); view_2024 = None + mul_1726 = torch.ops.aten.mul.Tensor(view_as_complex_79, clone_9); view_as_complex_79 = None + view_as_real_79 = torch.ops.aten.view_as_real.default(mul_1726); mul_1726 = None + view_2025 = torch.ops.aten.view.default(view_as_real_79, [2, 4096, 16, 64]); view_as_real_79 = None + convert_element_type_2408 = torch.ops.prims.convert_element_type.default(view_2025, torch.bfloat16); view_2025 = None + cat_118 = torch.ops.aten.cat.default([slice_183, convert_element_type_2408], 3); slice_183 = convert_element_type_2408 = None + view_2026 = torch.ops.aten.view.default(cat_118, [2, 4096, 3072]); cat_118 = None + view_2027 = torch.ops.aten.view.default(view_2026, [8192, 3072]); view_2026 = None + permute_1052 = torch.ops.aten.permute.default(view_2027, [1, 0]) + mm_424 = torch.ops.aten.mm.default(permute_1052, view_907); permute_1052 = view_907 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_742, 64, '0'); convert_element_type_742 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + permute_1054 = torch.ops.aten.permute.default(permute_206, [1, 0]); permute_206 = None + mm_425 = torch.ops.aten.mm.default(view_2027, permute_1054); view_2027 = permute_1054 = None + view_2028 = torch.ops.aten.view.default(mm_425, [2, 4096, 2048]); mm_425 = None + add_1968 = torch.ops.aten.add.Tensor(view_2023, view_2028); view_2023 = view_2028 = None + convert_element_type_2413 = torch.ops.prims.convert_element_type.default(mm_424, torch.float32); mm_424 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2413, 'avg', 64, '0'); convert_element_type_2413 = None + wait_tensor_767 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + convert_element_type_2414 = torch.ops.prims.convert_element_type.default(add_1968, torch.float32); add_1968 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_739, 64, '0'); convert_element_type_739 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_2416 = torch.ops.prims.convert_element_type.default(wait_tensor_284, torch.float32); wait_tensor_284 = None + mul_1727 = torch.ops.aten.mul.Tensor(convert_element_type_2414, convert_element_type_2416); convert_element_type_2416 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(add_889, torch.float32); add_889 = None + mul_646 = torch.ops.aten.mul.Tensor(convert_element_type_740, rsqrt_42); convert_element_type_740 = None + mul_1729 = torch.ops.aten.mul.Tensor(mul_646, mul_1727) + sum_209 = torch.ops.aten.sum.dim_IntList(mul_1729, [2], True); mul_1729 = None + div_209 = torch.ops.aten.div.Tensor(mul_646, 2048) + mul_1730 = torch.ops.aten.mul.Tensor(div_209, sum_209); div_209 = sum_209 = None + sub_702 = torch.ops.aten.sub.Tensor(mul_1727, mul_1730); mul_1727 = mul_1730 = None + mul_1731 = torch.ops.aten.mul.Tensor(sub_702, rsqrt_42); sub_702 = rsqrt_42 = None + mul_1732 = torch.ops.aten.mul.Tensor(convert_element_type_2414, mul_646); convert_element_type_2414 = mul_646 = None + sum_210 = torch.ops.aten.sum.dim_IntList(mul_1732, [0, 1]); mul_1732 = None + convert_element_type_2417 = torch.ops.prims.convert_element_type.default(mul_1731, torch.bfloat16); mul_1731 = None + add_1969 = torch.ops.aten.add.Tensor(add_1967, convert_element_type_2417); add_1967 = convert_element_type_2417 = None + convert_element_type_default_43 = torch.ops.prims.convert_element_type.default(sum_210, torch.float32); sum_210 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_43, 'avg', 64, '0'); convert_element_type_default_43 = None + wait_tensor_768 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + view_2029 = torch.ops.aten.view.default(add_1969, [8192, 2048]) + unsqueeze_66 = torch.ops.aten.unsqueeze.default(view_2029, 1) + convert_element_type_2420 = torch.ops.prims.convert_element_type.default(unsqueeze_66, torch.float32); unsqueeze_66 = None + bmm_52 = torch.ops.aten.bmm.default(permute_1056, convert_element_type_2420); permute_1056 = None + bmm_53 = torch.ops.aten.bmm.default(convert_element_type_2420, permute_1057); convert_element_type_2420 = permute_1057 = None + convert_element_type_2421 = torch.ops.prims.convert_element_type.default(bmm_52, torch.bfloat16); bmm_52 = None + view_2030 = torch.ops.aten.view.default(bmm_53, [8192, 6]); bmm_53 = None + view_2031 = torch.ops.aten.view.default(convert_element_type_2421, [49152, 2048]); convert_element_type_2421 = None + index_78 = torch.ops.aten.index.Tensor(view_2031, [getitem_189]); view_2031 = getitem_189 = None + permute_1058 = torch.ops.aten.permute.default(view_2029, [1, 0]) + mm_426 = torch.ops.aten.mm.default(permute_1058, mul_643); permute_1058 = mul_643 = None + convert_element_type_734 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_734, 64, '0'); convert_element_type_734 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + permute_1060 = torch.ops.aten.permute.default(permute_205, [1, 0]); permute_205 = None + mm_427 = torch.ops.aten.mm.default(view_2029, permute_1060); view_2029 = permute_1060 = None + convert_element_type_2426 = torch.ops.prims.convert_element_type.default(mm_426, torch.float32); mm_426 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2426, 'avg', 64, '0'); convert_element_type_2426 = None + wait_tensor_769 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mm_108, torch.float32); mm_108 = None + neg_26 = torch.ops.aten.neg.default(convert_element_type_729) + exp_39 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_884 = torch.ops.aten.add.Tensor(exp_39, 1); exp_39 = None + div_65 = torch.ops.aten.div.Tensor(convert_element_type_729, add_884) + convert_element_type_730 = torch.ops.prims.convert_element_type.default(div_65, torch.bfloat16); div_65 = None + mul_1733 = torch.ops.aten.mul.Tensor(mm_427, convert_element_type_730); convert_element_type_730 = None + mul_1734 = torch.ops.aten.mul.Tensor(mm_427, mm_109); mm_427 = mm_109 = None + permute_1062 = torch.ops.aten.permute.default(mul_1733, [1, 0]) + mm_428 = torch.ops.aten.mm.default(permute_1062, view_862); permute_1062 = None + convert_element_type_731 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_731, 64, '0'); convert_element_type_731 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_204 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + permute_1064 = torch.ops.aten.permute.default(permute_204, [1, 0]); permute_204 = None + mm_429 = torch.ops.aten.mm.default(mul_1733, permute_1064); mul_1733 = permute_1064 = None + convert_element_type_2431 = torch.ops.prims.convert_element_type.default(mm_428, torch.float32); mm_428 = None + reduce_scatter_tensor_185 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2431, 'avg', 64, '0'); convert_element_type_2431 = None + wait_tensor_770 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + convert_element_type_2432 = torch.ops.prims.convert_element_type.default(mul_1734, torch.float32); mul_1734 = None + reciprocal_26 = torch.ops.aten.reciprocal.default(add_884); add_884 = None + mul_1735 = torch.ops.aten.mul.Tensor(reciprocal_26, 1); reciprocal_26 = None + mul_1736 = torch.ops.aten.mul.Tensor(convert_element_type_2432, mul_1735); convert_element_type_2432 = None + sub_703 = torch.ops.aten.sub.Tensor(1, mul_1735); mul_1735 = None + mul_1737 = torch.ops.aten.mul.Tensor(convert_element_type_729, sub_703); convert_element_type_729 = sub_703 = None + add_1971 = torch.ops.aten.add.Tensor(mul_1737, 1); mul_1737 = None + mul_1738 = torch.ops.aten.mul.Tensor(mul_1736, add_1971); mul_1736 = add_1971 = None + convert_element_type_2434 = torch.ops.prims.convert_element_type.default(mul_1738, torch.bfloat16); mul_1738 = None + permute_1066 = torch.ops.aten.permute.default(convert_element_type_2434, [1, 0]) + mm_430 = torch.ops.aten.mm.default(permute_1066, view_862); permute_1066 = None + convert_element_type_726 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_726, 64, '0'); convert_element_type_726 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_203 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + permute_1068 = torch.ops.aten.permute.default(permute_203, [1, 0]); permute_203 = None + mm_431 = torch.ops.aten.mm.default(convert_element_type_2434, permute_1068); convert_element_type_2434 = permute_1068 = None + add_1972 = torch.ops.aten.add.Tensor(mm_429, mm_431); mm_429 = mm_431 = None + convert_element_type_2439 = torch.ops.prims.convert_element_type.default(mm_430, torch.float32); mm_430 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2439, 'avg', 64, '0'); convert_element_type_2439 = None + wait_tensor_771 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + all_to_all_single_104 = torch.ops._c10d_functional.all_to_all_single.default(index_78, [_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207], [_local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199], '521'); index_78 = None + wait_tensor_772 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_104); all_to_all_single_104 = None + full_400 = torch.ops.aten.full.default([sym_size_int_49, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_49 = None + slice_scatter_13 = torch.ops.aten.slice_scatter.default(full_400, wait_tensor_772, 0, 0, -1); wait_tensor_772 = None + index_79 = torch.ops.aten.index.Tensor(slice_scatter_13, [getitem_190]); slice_scatter_13 = None + permute_1070 = torch.ops.aten.permute.default(index_79, [1, 0]) + _grouped_mm_156 = torch.ops.aten._grouped_mm.default(permute_1070, mul_623, cumsum_38); permute_1070 = mul_623 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16); primals_226 = None + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_720, 8, '513'); convert_element_type_720 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_202 = torch.ops.aten.permute.default(wait_tensor_276, [0, 2, 1]); wait_tensor_276 = None + permute_1072 = torch.ops.aten.permute.default(permute_202, [0, 2, 1]); permute_202 = None + _grouped_mm_157 = torch.ops.aten._grouped_mm.default(index_79, permute_1072, cumsum_38); index_79 = permute_1072 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(_grouped_mm_36, torch.float32); _grouped_mm_36 = None + neg_25 = torch.ops.aten.neg.default(convert_element_type_724) + exp_38 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_848 = torch.ops.aten.add.Tensor(exp_38, 1); exp_38 = None + div_64 = torch.ops.aten.div.Tensor(convert_element_type_724, add_848) + convert_element_type_725 = torch.ops.prims.convert_element_type.default(div_64, torch.bfloat16); div_64 = None + mul_1739 = torch.ops.aten.mul.Tensor(_grouped_mm_157, convert_element_type_725); convert_element_type_725 = None + mul_1740 = torch.ops.aten.mul.Tensor(_grouped_mm_157, _grouped_mm_37); _grouped_mm_157 = _grouped_mm_37 = None + permute_1074 = torch.ops.aten.permute.default(mul_1739, [1, 0]) + _grouped_mm_158 = torch.ops.aten._grouped_mm.default(permute_1074, index_25, cumsum_38); permute_1074 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16); primals_227 = None + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 8, '513'); convert_element_type_721 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + permute_201 = torch.ops.aten.permute.default(wait_tensor_277, [0, 2, 1]); wait_tensor_277 = None + permute_1076 = torch.ops.aten.permute.default(permute_201, [0, 2, 1]); permute_201 = None + _grouped_mm_159 = torch.ops.aten._grouped_mm.default(mul_1739, permute_1076, cumsum_38); mul_1739 = permute_1076 = None + convert_element_type_2440 = torch.ops.prims.convert_element_type.default(mul_1740, torch.float32); mul_1740 = None + reciprocal_27 = torch.ops.aten.reciprocal.default(add_848); add_848 = None + mul_1741 = torch.ops.aten.mul.Tensor(reciprocal_27, 1); reciprocal_27 = None + mul_1742 = torch.ops.aten.mul.Tensor(convert_element_type_2440, mul_1741); convert_element_type_2440 = None + sub_704 = torch.ops.aten.sub.Tensor(1, mul_1741); mul_1741 = None + mul_1743 = torch.ops.aten.mul.Tensor(convert_element_type_724, sub_704); convert_element_type_724 = sub_704 = None + add_1974 = torch.ops.aten.add.Tensor(mul_1743, 1); mul_1743 = None + mul_1744 = torch.ops.aten.mul.Tensor(mul_1742, add_1974); mul_1742 = add_1974 = None + convert_element_type_2442 = torch.ops.prims.convert_element_type.default(mul_1744, torch.bfloat16); mul_1744 = None + permute_1078 = torch.ops.aten.permute.default(convert_element_type_2442, [1, 0]) + _grouped_mm_160 = torch.ops.aten._grouped_mm.default(permute_1078, index_25, cumsum_38); permute_1078 = index_25 = None + convert_element_type_718 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16); primals_225 = None + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_718, 8, '513'); convert_element_type_718 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_274, [0, 2, 1]); wait_tensor_274 = None + permute_1080 = torch.ops.aten.permute.default(permute_200, [0, 2, 1]); permute_200 = None + _grouped_mm_161 = torch.ops.aten._grouped_mm.default(convert_element_type_2442, permute_1080, cumsum_38); convert_element_type_2442 = permute_1080 = cumsum_38 = None + add_1975 = torch.ops.aten.add.Tensor(_grouped_mm_159, _grouped_mm_161); _grouped_mm_159 = _grouped_mm_161 = None + convert_element_type_2443 = torch.ops.prims.convert_element_type.default(_grouped_mm_158, torch.float32); _grouped_mm_158 = None + div_210 = torch.ops.aten.div.Tensor(convert_element_type_2443, 64); convert_element_type_2443 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_210, 'sum', 8, '513'); div_210 = None + wait_tensor_773 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + convert_element_type_2444 = torch.ops.prims.convert_element_type.default(_grouped_mm_156, torch.float32); _grouped_mm_156 = None + div_211 = torch.ops.aten.div.Tensor(convert_element_type_2444, 64); convert_element_type_2444 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_211, 'sum', 8, '513'); div_211 = None + wait_tensor_774 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + convert_element_type_2445 = torch.ops.prims.convert_element_type.default(_grouped_mm_160, torch.float32); _grouped_mm_160 = None + div_212 = torch.ops.aten.div.Tensor(convert_element_type_2445, 64); convert_element_type_2445 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_212, 'sum', 8, '513'); div_212 = None + wait_tensor_775 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + index_put_78 = torch.ops.aten.index_put.default(full_400, [getitem_190], add_1975, True); full_400 = getitem_190 = add_1975 = None + slice_185 = torch.ops.aten.slice.Tensor(index_put_78, 0, 0, add_1976); index_put_78 = add_1976 = None + all_to_all_single_105 = torch.ops._c10d_functional.all_to_all_single.default(slice_185, [_local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199], [_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207], '521'); slice_185 = _local_scalar_dense_192 = _local_scalar_dense_193 = _local_scalar_dense_194 = _local_scalar_dense_195 = _local_scalar_dense_196 = _local_scalar_dense_197 = _local_scalar_dense_198 = _local_scalar_dense_199 = _local_scalar_dense_200 = _local_scalar_dense_201 = _local_scalar_dense_202 = _local_scalar_dense_203 = _local_scalar_dense_204 = _local_scalar_dense_205 = _local_scalar_dense_206 = _local_scalar_dense_207 = None + wait_tensor_776 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_105); all_to_all_single_105 = None + index_put_79 = torch.ops.aten.index_put.default(full_default_52, [div_62], wait_tensor_776, True); div_62 = wait_tensor_776 = None + add_1980 = torch.ops.aten.add.Tensor(add_1972, index_put_79); add_1972 = index_put_79 = None + mul_1745 = torch.ops.aten.mul.Tensor(view_2030, 1.0); view_2030 = None + scatter_add_13 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_187, mul_1745); getitem_187 = mul_1745 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(mm_107, torch.float32); mm_107 = None + sub_288 = torch.ops.aten.sub.Tensor(convert_element_type_713, amax_12); convert_element_type_713 = amax_12 = None + exp_37 = torch.ops.aten.exp.default(sub_288); sub_288 = None + div_61 = torch.ops.aten.div.Tensor(exp_37, sum_49); exp_37 = sum_49 = None + mul_1746 = torch.ops.aten.mul.Tensor(scatter_add_13, div_61); scatter_add_13 = None + sum_211 = torch.ops.aten.sum.dim_IntList(mul_1746, [1], True) + neg_94 = torch.ops.aten.neg.default(div_61); div_61 = None + fma_13 = torch.ops.prims.fma.default(neg_94, sum_211, mul_1746); neg_94 = sum_211 = mul_1746 = None + convert_element_type_2446 = torch.ops.prims.convert_element_type.default(fma_13, torch.bfloat16); fma_13 = None + permute_1082 = torch.ops.aten.permute.default(convert_element_type_2446, [1, 0]) + mm_432 = torch.ops.aten.mm.default(permute_1082, view_862); permute_1082 = view_862 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16); primals_223 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 64, '0'); convert_element_type_710 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + permute_1084 = torch.ops.aten.permute.default(permute_199, [1, 0]); permute_199 = None + mm_433 = torch.ops.aten.mm.default(convert_element_type_2446, permute_1084); convert_element_type_2446 = permute_1084 = None + add_1981 = torch.ops.aten.add.Tensor(add_1980, mm_433); add_1980 = mm_433 = None + convert_element_type_2451 = torch.ops.prims.convert_element_type.default(mm_432, torch.float32); mm_432 = None + reduce_scatter_tensor_190 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2451, 'avg', 64, '0'); convert_element_type_2451 = None + wait_tensor_777 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + view_2032 = torch.ops.aten.view.default(add_1981, [2, 4096, 2048]); add_1981 = None + convert_element_type_2452 = torch.ops.prims.convert_element_type.default(view_2032, torch.float32); view_2032 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16); primals_221 = None + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_707, 64, '0'); convert_element_type_707 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_2454 = torch.ops.prims.convert_element_type.default(wait_tensor_269, torch.float32); wait_tensor_269 = None + mul_1747 = torch.ops.aten.mul.Tensor(convert_element_type_2452, convert_element_type_2454); convert_element_type_2454 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(add_824, torch.float32); add_824 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_708, rsqrt_41); convert_element_type_708 = None + mul_1749 = torch.ops.aten.mul.Tensor(mul_603, mul_1747) + sum_212 = torch.ops.aten.sum.dim_IntList(mul_1749, [2], True); mul_1749 = None + div_213 = torch.ops.aten.div.Tensor(mul_603, 2048) + mul_1750 = torch.ops.aten.mul.Tensor(div_213, sum_212); div_213 = sum_212 = None + sub_706 = torch.ops.aten.sub.Tensor(mul_1747, mul_1750); mul_1747 = mul_1750 = None + mul_1751 = torch.ops.aten.mul.Tensor(sub_706, rsqrt_41); sub_706 = rsqrt_41 = None + mul_1752 = torch.ops.aten.mul.Tensor(convert_element_type_2452, mul_603); convert_element_type_2452 = mul_603 = None + sum_213 = torch.ops.aten.sum.dim_IntList(mul_1752, [0, 1]); mul_1752 = None + convert_element_type_2455 = torch.ops.prims.convert_element_type.default(mul_1751, torch.bfloat16); mul_1751 = None + add_1982 = torch.ops.aten.add.Tensor(add_1969, convert_element_type_2455); add_1969 = convert_element_type_2455 = None + convert_element_type_default_42 = torch.ops.prims.convert_element_type.default(sum_213, torch.float32); sum_213 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_42, 'avg', 64, '0'); convert_element_type_default_42 = None + wait_tensor_778 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + view_2033 = torch.ops.aten.view.default(add_1982, [8192, 2048]) + permute_1086 = torch.ops.aten.permute.default(view_2033, [1, 0]) + permute_197 = torch.ops.aten.permute.default(getitem_183, [0, 2, 1, 3]) + view_857 = torch.ops.aten.view.default(permute_197, [2, 4096, -1]); permute_197 = None + view_859 = torch.ops.aten.view.default(view_857, [8192, 2048]); view_857 = None + mm_434 = torch.ops.aten.mm.default(permute_1086, view_859); permute_1086 = view_859 = None + convert_element_type_704 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_704, 64, '0'); convert_element_type_704 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + permute_1088 = torch.ops.aten.permute.default(permute_198, [1, 0]); permute_198 = None + mm_435 = torch.ops.aten.mm.default(view_2033, permute_1088); view_2033 = permute_1088 = None + view_2034 = torch.ops.aten.view.default(mm_435, [2, 4096, 2048]); mm_435 = None + convert_element_type_2462 = torch.ops.prims.convert_element_type.default(mm_434, torch.float32); mm_434 = None + reduce_scatter_tensor_192 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2462, 'avg', 64, '0'); convert_element_type_2462 = None + wait_tensor_779 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + view_2035 = torch.ops.aten.view.default(view_2034, [2, 4096, 16, 128]); view_2034 = None + permute_1090 = torch.ops.aten.permute.default(view_2035, [0, 2, 1, 3]); view_2035 = None + fw_graph13 = self.fw_graph13 + joint_graph13 = self.joint_graph13 + mask_graph13 = self.mask_graph13 + flex_attention_backward_13 = torch.ops.higher_order.flex_attention_backward(permute_194, permute_195, permute_196, getitem_183, getitem_184, permute_1090, None, fw_graph13, joint_graph13, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph13), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_194 = permute_195 = permute_196 = getitem_183 = getitem_184 = permute_1090 = fw_graph13 = joint_graph13 = mask_graph13 = None + getitem_425 = flex_attention_backward_13[0] + getitem_426 = flex_attention_backward_13[1] + getitem_427 = flex_attention_backward_13[2]; flex_attention_backward_13 = None + permute_1091 = torch.ops.aten.permute.default(getitem_427, [0, 2, 1, 3]); getitem_427 = None + permute_1092 = torch.ops.aten.permute.default(getitem_426, [0, 2, 1, 3]); getitem_426 = None + permute_1093 = torch.ops.aten.permute.default(getitem_425, [0, 2, 1, 3]); getitem_425 = None + slice_187 = torch.ops.aten.slice.Tensor(permute_1092, 3, 0, 128) + slice_188 = torch.ops.aten.slice.Tensor(permute_1092, 3, 128, 192); permute_1092 = None + sum_214 = torch.ops.aten.sum.dim_IntList(slice_188, [2], True); slice_188 = None + cat_119 = torch.ops.aten.cat.default([slice_187, permute_1091], 3); slice_187 = permute_1091 = None + view_2036 = torch.ops.aten.view.default(cat_119, [2, 4096, 4096]); cat_119 = None + view_2037 = torch.ops.aten.view.default(view_2036, [8192, 4096]); view_2036 = None + permute_1094 = torch.ops.aten.permute.default(view_2037, [1, 0]) + mm_436 = torch.ops.aten.mm.default(permute_1094, view_854); permute_1094 = view_854 = None + convert_element_type_701 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_701, 64, '0'); convert_element_type_701 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_193 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + permute_1096 = torch.ops.aten.permute.default(permute_193, [1, 0]); permute_193 = None + mm_437 = torch.ops.aten.mm.default(view_2037, permute_1096); view_2037 = permute_1096 = None + view_2038 = torch.ops.aten.view.default(mm_437, [2, 4096, 512]); mm_437 = None + convert_element_type_2467 = torch.ops.prims.convert_element_type.default(mm_436, torch.float32); mm_436 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2467, 'avg', 64, '0'); convert_element_type_2467 = None + wait_tensor_780 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + convert_element_type_2468 = torch.ops.prims.convert_element_type.default(view_2038, torch.float32); view_2038 = None + convert_element_type_698 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16); primals_218 = None + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_698, 64, '0'); convert_element_type_698 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + convert_element_type_2470 = torch.ops.prims.convert_element_type.default(wait_tensor_266, torch.float32); wait_tensor_266 = None + mul_1753 = torch.ops.aten.mul.Tensor(convert_element_type_2468, convert_element_type_2470); convert_element_type_2470 = None + convert_element_type_699 = torch.ops.prims.convert_element_type.default(getitem_179, torch.float32); getitem_179 = None + mul_601 = torch.ops.aten.mul.Tensor(convert_element_type_699, rsqrt_40); convert_element_type_699 = None + mul_1755 = torch.ops.aten.mul.Tensor(mul_601, mul_1753) + sum_215 = torch.ops.aten.sum.dim_IntList(mul_1755, [2], True); mul_1755 = None + div_214 = torch.ops.aten.div.Tensor(mul_601, 512) + mul_1756 = torch.ops.aten.mul.Tensor(div_214, sum_215); div_214 = sum_215 = None + sub_707 = torch.ops.aten.sub.Tensor(mul_1753, mul_1756); mul_1753 = mul_1756 = None + mul_1757 = torch.ops.aten.mul.Tensor(sub_707, rsqrt_40); sub_707 = rsqrt_40 = None + mul_1758 = torch.ops.aten.mul.Tensor(convert_element_type_2468, mul_601); convert_element_type_2468 = mul_601 = None + sum_216 = torch.ops.aten.sum.dim_IntList(mul_1758, [0, 1]); mul_1758 = None + convert_element_type_2471 = torch.ops.prims.convert_element_type.default(mul_1757, torch.bfloat16); mul_1757 = None + convert_element_type_default_41 = torch.ops.prims.convert_element_type.default(sum_216, torch.float32); sum_216 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_41, 'avg', 64, '0'); convert_element_type_default_41 = None + wait_tensor_781 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + convert_element_type_2474 = torch.ops.prims.convert_element_type.default(sum_214, torch.float32); sum_214 = None + view_2039 = torch.ops.aten.view.default(convert_element_type_2474, [2, 4096, 1, 32, 2]); convert_element_type_2474 = None + view_as_complex_80 = torch.ops.aten.view_as_complex.default(view_2039); view_2039 = None + mul_1759 = torch.ops.aten.mul.Tensor(view_as_complex_80, clone_9); view_as_complex_80 = None + view_as_real_80 = torch.ops.aten.view_as_real.default(mul_1759); mul_1759 = None + view_2040 = torch.ops.aten.view.default(view_as_real_80, [2, 4096, 1, 64]); view_as_real_80 = None + convert_element_type_2475 = torch.ops.prims.convert_element_type.default(view_2040, torch.bfloat16); view_2040 = None + squeeze_39 = torch.ops.aten.squeeze.dim(convert_element_type_2475, 2); convert_element_type_2475 = None + cat_120 = torch.ops.aten.cat.default([convert_element_type_2471, squeeze_39], 2); convert_element_type_2471 = squeeze_39 = None + view_2041 = torch.ops.aten.view.default(cat_120, [8192, 576]); cat_120 = None + permute_1098 = torch.ops.aten.permute.default(view_2041, [1, 0]) + mm_438 = torch.ops.aten.mm.default(permute_1098, view_840); permute_1098 = None + convert_element_type_693 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16); primals_217 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_693, 64, '0'); convert_element_type_693 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + permute_192 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + permute_1100 = torch.ops.aten.permute.default(permute_192, [1, 0]); permute_192 = None + mm_439 = torch.ops.aten.mm.default(view_2041, permute_1100); view_2041 = permute_1100 = None + view_2042 = torch.ops.aten.view.default(mm_439, [2, 4096, 2048]); mm_439 = None + convert_element_type_2480 = torch.ops.prims.convert_element_type.default(mm_438, torch.float32); mm_438 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2480, 'avg', 64, '0'); convert_element_type_2480 = None + wait_tensor_782 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + slice_189 = torch.ops.aten.slice.Tensor(permute_1093, 3, 0, 128) + slice_190 = torch.ops.aten.slice.Tensor(permute_1093, 3, 128, 192); permute_1093 = None + convert_element_type_2481 = torch.ops.prims.convert_element_type.default(slice_190, torch.float32); slice_190 = None + view_2043 = torch.ops.aten.view.default(convert_element_type_2481, [2, 4096, 16, 32, 2]); convert_element_type_2481 = None + view_as_complex_81 = torch.ops.aten.view_as_complex.default(view_2043); view_2043 = None + mul_1760 = torch.ops.aten.mul.Tensor(view_as_complex_81, clone_9); view_as_complex_81 = None + view_as_real_81 = torch.ops.aten.view_as_real.default(mul_1760); mul_1760 = None + view_2044 = torch.ops.aten.view.default(view_as_real_81, [2, 4096, 16, 64]); view_as_real_81 = None + convert_element_type_2482 = torch.ops.prims.convert_element_type.default(view_2044, torch.bfloat16); view_2044 = None + cat_121 = torch.ops.aten.cat.default([slice_189, convert_element_type_2482], 3); slice_189 = convert_element_type_2482 = None + view_2045 = torch.ops.aten.view.default(cat_121, [2, 4096, 3072]); cat_121 = None + view_2046 = torch.ops.aten.view.default(view_2045, [8192, 3072]); view_2045 = None + permute_1102 = torch.ops.aten.permute.default(view_2046, [1, 0]) + mm_440 = torch.ops.aten.mm.default(permute_1102, view_840); permute_1102 = view_840 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 64, '0'); convert_element_type_688 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_191 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + permute_1104 = torch.ops.aten.permute.default(permute_191, [1, 0]); permute_191 = None + mm_441 = torch.ops.aten.mm.default(view_2046, permute_1104); view_2046 = permute_1104 = None + view_2047 = torch.ops.aten.view.default(mm_441, [2, 4096, 2048]); mm_441 = None + add_1983 = torch.ops.aten.add.Tensor(view_2042, view_2047); view_2042 = view_2047 = None + convert_element_type_2487 = torch.ops.prims.convert_element_type.default(mm_440, torch.float32); mm_440 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2487, 'avg', 64, '0'); convert_element_type_2487 = None + wait_tensor_783 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + convert_element_type_2488 = torch.ops.prims.convert_element_type.default(add_1983, torch.float32); add_1983 = None + convert_element_type_685 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_685, 64, '0'); convert_element_type_685 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + convert_element_type_2490 = torch.ops.prims.convert_element_type.default(wait_tensor_263, torch.float32); wait_tensor_263 = None + mul_1761 = torch.ops.aten.mul.Tensor(convert_element_type_2488, convert_element_type_2490); convert_element_type_2490 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(add_821, torch.float32); add_821 = None + mul_597 = torch.ops.aten.mul.Tensor(convert_element_type_686, rsqrt_39); convert_element_type_686 = None + mul_1763 = torch.ops.aten.mul.Tensor(mul_597, mul_1761) + sum_217 = torch.ops.aten.sum.dim_IntList(mul_1763, [2], True); mul_1763 = None + div_215 = torch.ops.aten.div.Tensor(mul_597, 2048) + mul_1764 = torch.ops.aten.mul.Tensor(div_215, sum_217); div_215 = sum_217 = None + sub_708 = torch.ops.aten.sub.Tensor(mul_1761, mul_1764); mul_1761 = mul_1764 = None + mul_1765 = torch.ops.aten.mul.Tensor(sub_708, rsqrt_39); sub_708 = rsqrt_39 = None + mul_1766 = torch.ops.aten.mul.Tensor(convert_element_type_2488, mul_597); convert_element_type_2488 = mul_597 = None + sum_218 = torch.ops.aten.sum.dim_IntList(mul_1766, [0, 1]); mul_1766 = None + convert_element_type_2491 = torch.ops.prims.convert_element_type.default(mul_1765, torch.bfloat16); mul_1765 = None + add_1984 = torch.ops.aten.add.Tensor(add_1982, convert_element_type_2491); add_1982 = convert_element_type_2491 = None + convert_element_type_default_40 = torch.ops.prims.convert_element_type.default(sum_218, torch.float32); sum_218 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_40, 'avg', 64, '0'); convert_element_type_default_40 = None + wait_tensor_784 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + view_2048 = torch.ops.aten.view.default(add_1984, [8192, 2048]) + unsqueeze_67 = torch.ops.aten.unsqueeze.default(view_2048, 1) + convert_element_type_2494 = torch.ops.prims.convert_element_type.default(unsqueeze_67, torch.float32); unsqueeze_67 = None + bmm_54 = torch.ops.aten.bmm.default(permute_1106, convert_element_type_2494); permute_1106 = None + bmm_55 = torch.ops.aten.bmm.default(convert_element_type_2494, permute_1107); convert_element_type_2494 = permute_1107 = None + convert_element_type_2495 = torch.ops.prims.convert_element_type.default(bmm_54, torch.bfloat16); bmm_54 = None + view_2049 = torch.ops.aten.view.default(bmm_55, [8192, 6]); bmm_55 = None + view_2050 = torch.ops.aten.view.default(convert_element_type_2495, [49152, 2048]); convert_element_type_2495 = None + index_80 = torch.ops.aten.index.Tensor(view_2050, [getitem_175]); view_2050 = getitem_175 = None + permute_1108 = torch.ops.aten.permute.default(view_2048, [1, 0]) + mm_442 = torch.ops.aten.mm.default(permute_1108, mul_594); permute_1108 = mul_594 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 64, '0'); convert_element_type_680 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_190 = torch.ops.aten.permute.default(wait_tensor_262, [1, 0]); wait_tensor_262 = None + permute_1110 = torch.ops.aten.permute.default(permute_190, [1, 0]); permute_190 = None + mm_443 = torch.ops.aten.mm.default(view_2048, permute_1110); view_2048 = permute_1110 = None + convert_element_type_2500 = torch.ops.prims.convert_element_type.default(mm_442, torch.float32); mm_442 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2500, 'avg', 64, '0'); convert_element_type_2500 = None + wait_tensor_785 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(mm_100, torch.float32); mm_100 = None + neg_24 = torch.ops.aten.neg.default(convert_element_type_675) + exp_36 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_816 = torch.ops.aten.add.Tensor(exp_36, 1); exp_36 = None + div_60 = torch.ops.aten.div.Tensor(convert_element_type_675, add_816) + convert_element_type_676 = torch.ops.prims.convert_element_type.default(div_60, torch.bfloat16); div_60 = None + mul_1767 = torch.ops.aten.mul.Tensor(mm_443, convert_element_type_676); convert_element_type_676 = None + mul_1768 = torch.ops.aten.mul.Tensor(mm_443, mm_101); mm_443 = mm_101 = None + permute_1112 = torch.ops.aten.permute.default(mul_1767, [1, 0]) + mm_444 = torch.ops.aten.mm.default(permute_1112, view_795); permute_1112 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 64, '0'); convert_element_type_677 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + permute_1114 = torch.ops.aten.permute.default(permute_189, [1, 0]); permute_189 = None + mm_445 = torch.ops.aten.mm.default(mul_1767, permute_1114); mul_1767 = permute_1114 = None + convert_element_type_2505 = torch.ops.prims.convert_element_type.default(mm_444, torch.float32); mm_444 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2505, 'avg', 64, '0'); convert_element_type_2505 = None + wait_tensor_786 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + convert_element_type_2506 = torch.ops.prims.convert_element_type.default(mul_1768, torch.float32); mul_1768 = None + reciprocal_28 = torch.ops.aten.reciprocal.default(add_816); add_816 = None + mul_1769 = torch.ops.aten.mul.Tensor(reciprocal_28, 1); reciprocal_28 = None + mul_1770 = torch.ops.aten.mul.Tensor(convert_element_type_2506, mul_1769); convert_element_type_2506 = None + sub_709 = torch.ops.aten.sub.Tensor(1, mul_1769); mul_1769 = None + mul_1771 = torch.ops.aten.mul.Tensor(convert_element_type_675, sub_709); convert_element_type_675 = sub_709 = None + add_1986 = torch.ops.aten.add.Tensor(mul_1771, 1); mul_1771 = None + mul_1772 = torch.ops.aten.mul.Tensor(mul_1770, add_1986); mul_1770 = add_1986 = None + convert_element_type_2508 = torch.ops.prims.convert_element_type.default(mul_1772, torch.bfloat16); mul_1772 = None + permute_1116 = torch.ops.aten.permute.default(convert_element_type_2508, [1, 0]) + mm_446 = torch.ops.aten.mm.default(permute_1116, view_795); permute_1116 = None + convert_element_type_672 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_672, 64, '0'); convert_element_type_672 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + permute_1118 = torch.ops.aten.permute.default(permute_188, [1, 0]); permute_188 = None + mm_447 = torch.ops.aten.mm.default(convert_element_type_2508, permute_1118); convert_element_type_2508 = permute_1118 = None + add_1987 = torch.ops.aten.add.Tensor(mm_445, mm_447); mm_445 = mm_447 = None + convert_element_type_2513 = torch.ops.prims.convert_element_type.default(mm_446, torch.float32); mm_446 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2513, 'avg', 64, '0'); convert_element_type_2513 = None + wait_tensor_787 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + all_to_all_single_106 = torch.ops._c10d_functional.all_to_all_single.default(index_80, [_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191], [_local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183], '521'); index_80 = None + wait_tensor_788 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_106); all_to_all_single_106 = None + full_404 = torch.ops.aten.full.default([sym_size_int_45, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_45 = None + slice_scatter_14 = torch.ops.aten.slice_scatter.default(full_404, wait_tensor_788, 0, 0, -1); wait_tensor_788 = None + index_81 = torch.ops.aten.index.Tensor(slice_scatter_14, [getitem_176]); slice_scatter_14 = None + permute_1120 = torch.ops.aten.permute.default(index_81, [1, 0]) + _grouped_mm_162 = torch.ops.aten._grouped_mm.default(permute_1120, mul_574, cumsum_35); permute_1120 = mul_574 = None + convert_element_type_666 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16); primals_210 = None + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_666, 8, '513'); convert_element_type_666 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_255, [0, 2, 1]); wait_tensor_255 = None + permute_1122 = torch.ops.aten.permute.default(permute_187, [0, 2, 1]); permute_187 = None + _grouped_mm_163 = torch.ops.aten._grouped_mm.default(index_81, permute_1122, cumsum_35); index_81 = permute_1122 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(_grouped_mm_33, torch.float32); _grouped_mm_33 = None + neg_23 = torch.ops.aten.neg.default(convert_element_type_670) + exp_35 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_780 = torch.ops.aten.add.Tensor(exp_35, 1); exp_35 = None + div_59 = torch.ops.aten.div.Tensor(convert_element_type_670, add_780) + convert_element_type_671 = torch.ops.prims.convert_element_type.default(div_59, torch.bfloat16); div_59 = None + mul_1773 = torch.ops.aten.mul.Tensor(_grouped_mm_163, convert_element_type_671); convert_element_type_671 = None + mul_1774 = torch.ops.aten.mul.Tensor(_grouped_mm_163, _grouped_mm_34); _grouped_mm_163 = _grouped_mm_34 = None + permute_1124 = torch.ops.aten.permute.default(mul_1773, [1, 0]) + _grouped_mm_164 = torch.ops.aten._grouped_mm.default(permute_1124, index_23, cumsum_35); permute_1124 = None + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16); primals_211 = None + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 8, '513'); convert_element_type_667 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_256, [0, 2, 1]); wait_tensor_256 = None + permute_1126 = torch.ops.aten.permute.default(permute_186, [0, 2, 1]); permute_186 = None + _grouped_mm_165 = torch.ops.aten._grouped_mm.default(mul_1773, permute_1126, cumsum_35); mul_1773 = permute_1126 = None + convert_element_type_2514 = torch.ops.prims.convert_element_type.default(mul_1774, torch.float32); mul_1774 = None + reciprocal_29 = torch.ops.aten.reciprocal.default(add_780); add_780 = None + mul_1775 = torch.ops.aten.mul.Tensor(reciprocal_29, 1); reciprocal_29 = None + mul_1776 = torch.ops.aten.mul.Tensor(convert_element_type_2514, mul_1775); convert_element_type_2514 = None + sub_710 = torch.ops.aten.sub.Tensor(1, mul_1775); mul_1775 = None + mul_1777 = torch.ops.aten.mul.Tensor(convert_element_type_670, sub_710); convert_element_type_670 = sub_710 = None + add_1989 = torch.ops.aten.add.Tensor(mul_1777, 1); mul_1777 = None + mul_1778 = torch.ops.aten.mul.Tensor(mul_1776, add_1989); mul_1776 = add_1989 = None + convert_element_type_2516 = torch.ops.prims.convert_element_type.default(mul_1778, torch.bfloat16); mul_1778 = None + permute_1128 = torch.ops.aten.permute.default(convert_element_type_2516, [1, 0]) + _grouped_mm_166 = torch.ops.aten._grouped_mm.default(permute_1128, index_23, cumsum_35); permute_1128 = index_23 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16); primals_209 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 8, '513'); convert_element_type_664 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_253, [0, 2, 1]); wait_tensor_253 = None + permute_1130 = torch.ops.aten.permute.default(permute_185, [0, 2, 1]); permute_185 = None + _grouped_mm_167 = torch.ops.aten._grouped_mm.default(convert_element_type_2516, permute_1130, cumsum_35); convert_element_type_2516 = permute_1130 = cumsum_35 = None + add_1990 = torch.ops.aten.add.Tensor(_grouped_mm_165, _grouped_mm_167); _grouped_mm_165 = _grouped_mm_167 = None + convert_element_type_2517 = torch.ops.prims.convert_element_type.default(_grouped_mm_164, torch.float32); _grouped_mm_164 = None + div_216 = torch.ops.aten.div.Tensor(convert_element_type_2517, 64); convert_element_type_2517 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_216, 'sum', 8, '513'); div_216 = None + wait_tensor_789 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + convert_element_type_2518 = torch.ops.prims.convert_element_type.default(_grouped_mm_162, torch.float32); _grouped_mm_162 = None + div_217 = torch.ops.aten.div.Tensor(convert_element_type_2518, 64); convert_element_type_2518 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_217, 'sum', 8, '513'); div_217 = None + wait_tensor_790 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + convert_element_type_2519 = torch.ops.prims.convert_element_type.default(_grouped_mm_166, torch.float32); _grouped_mm_166 = None + div_218 = torch.ops.aten.div.Tensor(convert_element_type_2519, 64); convert_element_type_2519 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_218, 'sum', 8, '513'); div_218 = None + wait_tensor_791 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + index_put_80 = torch.ops.aten.index_put.default(full_404, [getitem_176], add_1990, True); full_404 = getitem_176 = add_1990 = None + slice_191 = torch.ops.aten.slice.Tensor(index_put_80, 0, 0, add_1991); index_put_80 = add_1991 = None + all_to_all_single_107 = torch.ops._c10d_functional.all_to_all_single.default(slice_191, [_local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183], [_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191], '521'); slice_191 = _local_scalar_dense_176 = _local_scalar_dense_177 = _local_scalar_dense_178 = _local_scalar_dense_179 = _local_scalar_dense_180 = _local_scalar_dense_181 = _local_scalar_dense_182 = _local_scalar_dense_183 = _local_scalar_dense_184 = _local_scalar_dense_185 = _local_scalar_dense_186 = _local_scalar_dense_187 = _local_scalar_dense_188 = _local_scalar_dense_189 = _local_scalar_dense_190 = _local_scalar_dense_191 = None + wait_tensor_792 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_107); all_to_all_single_107 = None + index_put_81 = torch.ops.aten.index_put.default(full_default_52, [div_57], wait_tensor_792, True); div_57 = wait_tensor_792 = None + add_1995 = torch.ops.aten.add.Tensor(add_1987, index_put_81); add_1987 = index_put_81 = None + mul_1779 = torch.ops.aten.mul.Tensor(view_2049, 1.0); view_2049 = None + scatter_add_14 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_173, mul_1779); getitem_173 = mul_1779 = None + convert_element_type_659 = torch.ops.prims.convert_element_type.default(mm_99, torch.float32); mm_99 = None + sub_264 = torch.ops.aten.sub.Tensor(convert_element_type_659, amax_11); convert_element_type_659 = amax_11 = None + exp_34 = torch.ops.aten.exp.default(sub_264); sub_264 = None + div_56 = torch.ops.aten.div.Tensor(exp_34, sum_45); exp_34 = sum_45 = None + mul_1780 = torch.ops.aten.mul.Tensor(scatter_add_14, div_56); scatter_add_14 = None + sum_219 = torch.ops.aten.sum.dim_IntList(mul_1780, [1], True) + neg_97 = torch.ops.aten.neg.default(div_56); div_56 = None + fma_14 = torch.ops.prims.fma.default(neg_97, sum_219, mul_1780); neg_97 = sum_219 = mul_1780 = None + convert_element_type_2520 = torch.ops.prims.convert_element_type.default(fma_14, torch.bfloat16); fma_14 = None + permute_1132 = torch.ops.aten.permute.default(convert_element_type_2520, [1, 0]) + mm_448 = torch.ops.aten.mm.default(permute_1132, view_795); permute_1132 = view_795 = None + convert_element_type_656 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16); primals_207 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_656, 64, '0'); convert_element_type_656 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_249, [1, 0]); wait_tensor_249 = None + permute_1134 = torch.ops.aten.permute.default(permute_184, [1, 0]); permute_184 = None + mm_449 = torch.ops.aten.mm.default(convert_element_type_2520, permute_1134); convert_element_type_2520 = permute_1134 = None + add_1996 = torch.ops.aten.add.Tensor(add_1995, mm_449); add_1995 = mm_449 = None + convert_element_type_2525 = torch.ops.prims.convert_element_type.default(mm_448, torch.float32); mm_448 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2525, 'avg', 64, '0'); convert_element_type_2525 = None + wait_tensor_793 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + view_2051 = torch.ops.aten.view.default(add_1996, [2, 4096, 2048]); add_1996 = None + convert_element_type_2526 = torch.ops.prims.convert_element_type.default(view_2051, torch.float32); view_2051 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16); primals_205 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_653, 64, '0'); convert_element_type_653 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_2528 = torch.ops.prims.convert_element_type.default(wait_tensor_248, torch.float32); wait_tensor_248 = None + mul_1781 = torch.ops.aten.mul.Tensor(convert_element_type_2526, convert_element_type_2528); convert_element_type_2528 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(add_756, torch.float32); add_756 = None + mul_554 = torch.ops.aten.mul.Tensor(convert_element_type_654, rsqrt_38); convert_element_type_654 = None + mul_1783 = torch.ops.aten.mul.Tensor(mul_554, mul_1781) + sum_220 = torch.ops.aten.sum.dim_IntList(mul_1783, [2], True); mul_1783 = None + div_219 = torch.ops.aten.div.Tensor(mul_554, 2048) + mul_1784 = torch.ops.aten.mul.Tensor(div_219, sum_220); div_219 = sum_220 = None + sub_712 = torch.ops.aten.sub.Tensor(mul_1781, mul_1784); mul_1781 = mul_1784 = None + mul_1785 = torch.ops.aten.mul.Tensor(sub_712, rsqrt_38); sub_712 = rsqrt_38 = None + mul_1786 = torch.ops.aten.mul.Tensor(convert_element_type_2526, mul_554); convert_element_type_2526 = mul_554 = None + sum_221 = torch.ops.aten.sum.dim_IntList(mul_1786, [0, 1]); mul_1786 = None + convert_element_type_2529 = torch.ops.prims.convert_element_type.default(mul_1785, torch.bfloat16); mul_1785 = None + add_1997 = torch.ops.aten.add.Tensor(add_1984, convert_element_type_2529); add_1984 = convert_element_type_2529 = None + convert_element_type_default_39 = torch.ops.prims.convert_element_type.default(sum_221, torch.float32); sum_221 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_39, 'avg', 64, '0'); convert_element_type_default_39 = None + wait_tensor_794 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + view_2052 = torch.ops.aten.view.default(add_1997, [8192, 2048]) + permute_1136 = torch.ops.aten.permute.default(view_2052, [1, 0]) + permute_182 = torch.ops.aten.permute.default(getitem_169, [0, 2, 1, 3]) + view_790 = torch.ops.aten.view.default(permute_182, [2, 4096, -1]); permute_182 = None + view_792 = torch.ops.aten.view.default(view_790, [8192, 2048]); view_790 = None + mm_450 = torch.ops.aten.mm.default(permute_1136, view_792); permute_1136 = view_792 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16); primals_204 = None + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 64, '0'); convert_element_type_650 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + permute_1138 = torch.ops.aten.permute.default(permute_183, [1, 0]); permute_183 = None + mm_451 = torch.ops.aten.mm.default(view_2052, permute_1138); view_2052 = permute_1138 = None + view_2053 = torch.ops.aten.view.default(mm_451, [2, 4096, 2048]); mm_451 = None + convert_element_type_2536 = torch.ops.prims.convert_element_type.default(mm_450, torch.float32); mm_450 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2536, 'avg', 64, '0'); convert_element_type_2536 = None + wait_tensor_795 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_2054 = torch.ops.aten.view.default(view_2053, [2, 4096, 16, 128]); view_2053 = None + permute_1140 = torch.ops.aten.permute.default(view_2054, [0, 2, 1, 3]); view_2054 = None + fw_graph14 = self.fw_graph14 + joint_graph14 = self.joint_graph14 + mask_graph14 = self.mask_graph14 + flex_attention_backward_14 = torch.ops.higher_order.flex_attention_backward(permute_179, permute_180, permute_181, getitem_169, getitem_170, permute_1140, None, fw_graph14, joint_graph14, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph14), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_179 = permute_180 = permute_181 = getitem_169 = getitem_170 = permute_1140 = fw_graph14 = joint_graph14 = mask_graph14 = None + getitem_429 = flex_attention_backward_14[0] + getitem_430 = flex_attention_backward_14[1] + getitem_431 = flex_attention_backward_14[2]; flex_attention_backward_14 = None + permute_1141 = torch.ops.aten.permute.default(getitem_431, [0, 2, 1, 3]); getitem_431 = None + permute_1142 = torch.ops.aten.permute.default(getitem_430, [0, 2, 1, 3]); getitem_430 = None + permute_1143 = torch.ops.aten.permute.default(getitem_429, [0, 2, 1, 3]); getitem_429 = None + slice_193 = torch.ops.aten.slice.Tensor(permute_1142, 3, 0, 128) + slice_194 = torch.ops.aten.slice.Tensor(permute_1142, 3, 128, 192); permute_1142 = None + sum_222 = torch.ops.aten.sum.dim_IntList(slice_194, [2], True); slice_194 = None + cat_122 = torch.ops.aten.cat.default([slice_193, permute_1141], 3); slice_193 = permute_1141 = None + view_2055 = torch.ops.aten.view.default(cat_122, [2, 4096, 4096]); cat_122 = None + view_2056 = torch.ops.aten.view.default(view_2055, [8192, 4096]); view_2055 = None + permute_1144 = torch.ops.aten.permute.default(view_2056, [1, 0]) + mm_452 = torch.ops.aten.mm.default(permute_1144, view_787); permute_1144 = view_787 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16); primals_203 = None + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 64, '0'); convert_element_type_647 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + permute_1146 = torch.ops.aten.permute.default(permute_178, [1, 0]); permute_178 = None + mm_453 = torch.ops.aten.mm.default(view_2056, permute_1146); view_2056 = permute_1146 = None + view_2057 = torch.ops.aten.view.default(mm_453, [2, 4096, 512]); mm_453 = None + convert_element_type_2541 = torch.ops.prims.convert_element_type.default(mm_452, torch.float32); mm_452 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2541, 'avg', 64, '0'); convert_element_type_2541 = None + wait_tensor_796 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + convert_element_type_2542 = torch.ops.prims.convert_element_type.default(view_2057, torch.float32); view_2057 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16); primals_202 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 64, '0'); convert_element_type_644 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + convert_element_type_2544 = torch.ops.prims.convert_element_type.default(wait_tensor_245, torch.float32); wait_tensor_245 = None + mul_1787 = torch.ops.aten.mul.Tensor(convert_element_type_2542, convert_element_type_2544); convert_element_type_2544 = None + convert_element_type_645 = torch.ops.prims.convert_element_type.default(getitem_165, torch.float32); getitem_165 = None + mul_552 = torch.ops.aten.mul.Tensor(convert_element_type_645, rsqrt_37); convert_element_type_645 = None + mul_1789 = torch.ops.aten.mul.Tensor(mul_552, mul_1787) + sum_223 = torch.ops.aten.sum.dim_IntList(mul_1789, [2], True); mul_1789 = None + div_220 = torch.ops.aten.div.Tensor(mul_552, 512) + mul_1790 = torch.ops.aten.mul.Tensor(div_220, sum_223); div_220 = sum_223 = None + sub_713 = torch.ops.aten.sub.Tensor(mul_1787, mul_1790); mul_1787 = mul_1790 = None + mul_1791 = torch.ops.aten.mul.Tensor(sub_713, rsqrt_37); sub_713 = rsqrt_37 = None + mul_1792 = torch.ops.aten.mul.Tensor(convert_element_type_2542, mul_552); convert_element_type_2542 = mul_552 = None + sum_224 = torch.ops.aten.sum.dim_IntList(mul_1792, [0, 1]); mul_1792 = None + convert_element_type_2545 = torch.ops.prims.convert_element_type.default(mul_1791, torch.bfloat16); mul_1791 = None + convert_element_type_default_38 = torch.ops.prims.convert_element_type.default(sum_224, torch.float32); sum_224 = None + reduce_scatter_tensor_208 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_38, 'avg', 64, '0'); convert_element_type_default_38 = None + wait_tensor_797 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + convert_element_type_2548 = torch.ops.prims.convert_element_type.default(sum_222, torch.float32); sum_222 = None + view_2058 = torch.ops.aten.view.default(convert_element_type_2548, [2, 4096, 1, 32, 2]); convert_element_type_2548 = None + view_as_complex_82 = torch.ops.aten.view_as_complex.default(view_2058); view_2058 = None + mul_1793 = torch.ops.aten.mul.Tensor(view_as_complex_82, clone_9); view_as_complex_82 = None + view_as_real_82 = torch.ops.aten.view_as_real.default(mul_1793); mul_1793 = None + view_2059 = torch.ops.aten.view.default(view_as_real_82, [2, 4096, 1, 64]); view_as_real_82 = None + convert_element_type_2549 = torch.ops.prims.convert_element_type.default(view_2059, torch.bfloat16); view_2059 = None + squeeze_40 = torch.ops.aten.squeeze.dim(convert_element_type_2549, 2); convert_element_type_2549 = None + cat_123 = torch.ops.aten.cat.default([convert_element_type_2545, squeeze_40], 2); convert_element_type_2545 = squeeze_40 = None + view_2060 = torch.ops.aten.view.default(cat_123, [8192, 576]); cat_123 = None + permute_1148 = torch.ops.aten.permute.default(view_2060, [1, 0]) + mm_454 = torch.ops.aten.mm.default(permute_1148, view_773); permute_1148 = None + convert_element_type_639 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16); primals_201 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_639, 64, '0'); convert_element_type_639 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_244, [1, 0]); wait_tensor_244 = None + permute_1150 = torch.ops.aten.permute.default(permute_177, [1, 0]); permute_177 = None + mm_455 = torch.ops.aten.mm.default(view_2060, permute_1150); view_2060 = permute_1150 = None + view_2061 = torch.ops.aten.view.default(mm_455, [2, 4096, 2048]); mm_455 = None + convert_element_type_2554 = torch.ops.prims.convert_element_type.default(mm_454, torch.float32); mm_454 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2554, 'avg', 64, '0'); convert_element_type_2554 = None + wait_tensor_798 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + slice_195 = torch.ops.aten.slice.Tensor(permute_1143, 3, 0, 128) + slice_196 = torch.ops.aten.slice.Tensor(permute_1143, 3, 128, 192); permute_1143 = None + convert_element_type_2555 = torch.ops.prims.convert_element_type.default(slice_196, torch.float32); slice_196 = None + view_2062 = torch.ops.aten.view.default(convert_element_type_2555, [2, 4096, 16, 32, 2]); convert_element_type_2555 = None + view_as_complex_83 = torch.ops.aten.view_as_complex.default(view_2062); view_2062 = None + mul_1794 = torch.ops.aten.mul.Tensor(view_as_complex_83, clone_9); view_as_complex_83 = None + view_as_real_83 = torch.ops.aten.view_as_real.default(mul_1794); mul_1794 = None + view_2063 = torch.ops.aten.view.default(view_as_real_83, [2, 4096, 16, 64]); view_as_real_83 = None + convert_element_type_2556 = torch.ops.prims.convert_element_type.default(view_2063, torch.bfloat16); view_2063 = None + cat_124 = torch.ops.aten.cat.default([slice_195, convert_element_type_2556], 3); slice_195 = convert_element_type_2556 = None + view_2064 = torch.ops.aten.view.default(cat_124, [2, 4096, 3072]); cat_124 = None + view_2065 = torch.ops.aten.view.default(view_2064, [8192, 3072]); view_2064 = None + permute_1152 = torch.ops.aten.permute.default(view_2065, [1, 0]) + mm_456 = torch.ops.aten.mm.default(permute_1152, view_773); permute_1152 = view_773 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16); primals_200 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 64, '0'); convert_element_type_634 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + permute_1154 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_457 = torch.ops.aten.mm.default(view_2065, permute_1154); view_2065 = permute_1154 = None + view_2066 = torch.ops.aten.view.default(mm_457, [2, 4096, 2048]); mm_457 = None + add_1998 = torch.ops.aten.add.Tensor(view_2061, view_2066); view_2061 = view_2066 = None + convert_element_type_2561 = torch.ops.prims.convert_element_type.default(mm_456, torch.float32); mm_456 = None + reduce_scatter_tensor_210 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2561, 'avg', 64, '0'); convert_element_type_2561 = None + wait_tensor_799 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); reduce_scatter_tensor_210 = None + convert_element_type_2562 = torch.ops.prims.convert_element_type.default(add_1998, torch.float32); add_1998 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16); primals_199 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 64, '0'); convert_element_type_631 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + convert_element_type_2564 = torch.ops.prims.convert_element_type.default(wait_tensor_242, torch.float32); wait_tensor_242 = None + mul_1795 = torch.ops.aten.mul.Tensor(convert_element_type_2562, convert_element_type_2564); convert_element_type_2564 = None + convert_element_type_632 = torch.ops.prims.convert_element_type.default(add_753, torch.float32); add_753 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_632, rsqrt_36); convert_element_type_632 = None + mul_1797 = torch.ops.aten.mul.Tensor(mul_548, mul_1795) + sum_225 = torch.ops.aten.sum.dim_IntList(mul_1797, [2], True); mul_1797 = None + div_221 = torch.ops.aten.div.Tensor(mul_548, 2048) + mul_1798 = torch.ops.aten.mul.Tensor(div_221, sum_225); div_221 = sum_225 = None + sub_714 = torch.ops.aten.sub.Tensor(mul_1795, mul_1798); mul_1795 = mul_1798 = None + mul_1799 = torch.ops.aten.mul.Tensor(sub_714, rsqrt_36); sub_714 = rsqrt_36 = None + mul_1800 = torch.ops.aten.mul.Tensor(convert_element_type_2562, mul_548); convert_element_type_2562 = mul_548 = None + sum_226 = torch.ops.aten.sum.dim_IntList(mul_1800, [0, 1]); mul_1800 = None + convert_element_type_2565 = torch.ops.prims.convert_element_type.default(mul_1799, torch.bfloat16); mul_1799 = None + add_1999 = torch.ops.aten.add.Tensor(add_1997, convert_element_type_2565); add_1997 = convert_element_type_2565 = None + convert_element_type_default_37 = torch.ops.prims.convert_element_type.default(sum_226, torch.float32); sum_226 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_37, 'avg', 64, '0'); convert_element_type_default_37 = None + wait_tensor_800 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + view_2067 = torch.ops.aten.view.default(add_1999, [8192, 2048]) + unsqueeze_68 = torch.ops.aten.unsqueeze.default(view_2067, 1) + convert_element_type_2568 = torch.ops.prims.convert_element_type.default(unsqueeze_68, torch.float32); unsqueeze_68 = None + bmm_56 = torch.ops.aten.bmm.default(permute_1156, convert_element_type_2568); permute_1156 = None + bmm_57 = torch.ops.aten.bmm.default(convert_element_type_2568, permute_1157); convert_element_type_2568 = permute_1157 = None + convert_element_type_2569 = torch.ops.prims.convert_element_type.default(bmm_56, torch.bfloat16); bmm_56 = None + view_2068 = torch.ops.aten.view.default(bmm_57, [8192, 6]); bmm_57 = None + view_2069 = torch.ops.aten.view.default(convert_element_type_2569, [49152, 2048]); convert_element_type_2569 = None + index_82 = torch.ops.aten.index.Tensor(view_2069, [getitem_161]); view_2069 = getitem_161 = None + permute_1158 = torch.ops.aten.permute.default(view_2067, [1, 0]) + mm_458 = torch.ops.aten.mm.default(permute_1158, mul_545); permute_1158 = mul_545 = None + convert_element_type_626 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_626, 64, '0'); convert_element_type_626 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + permute_1160 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_459 = torch.ops.aten.mm.default(view_2067, permute_1160); view_2067 = permute_1160 = None + convert_element_type_2574 = torch.ops.prims.convert_element_type.default(mm_458, torch.float32); mm_458 = None + reduce_scatter_tensor_212 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2574, 'avg', 64, '0'); convert_element_type_2574 = None + wait_tensor_801 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mm_92, torch.float32); mm_92 = None + neg_22 = torch.ops.aten.neg.default(convert_element_type_621) + exp_33 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_748 = torch.ops.aten.add.Tensor(exp_33, 1); exp_33 = None + div_55 = torch.ops.aten.div.Tensor(convert_element_type_621, add_748) + convert_element_type_622 = torch.ops.prims.convert_element_type.default(div_55, torch.bfloat16); div_55 = None + mul_1801 = torch.ops.aten.mul.Tensor(mm_459, convert_element_type_622); convert_element_type_622 = None + mul_1802 = torch.ops.aten.mul.Tensor(mm_459, mm_93); mm_459 = mm_93 = None + permute_1162 = torch.ops.aten.permute.default(mul_1801, [1, 0]) + mm_460 = torch.ops.aten.mm.default(permute_1162, view_728); permute_1162 = None + convert_element_type_623 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_623, 64, '0'); convert_element_type_623 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + permute_1164 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_461 = torch.ops.aten.mm.default(mul_1801, permute_1164); mul_1801 = permute_1164 = None + convert_element_type_2579 = torch.ops.prims.convert_element_type.default(mm_460, torch.float32); mm_460 = None + reduce_scatter_tensor_213 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2579, 'avg', 64, '0'); convert_element_type_2579 = None + wait_tensor_802 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_213); reduce_scatter_tensor_213 = None + convert_element_type_2580 = torch.ops.prims.convert_element_type.default(mul_1802, torch.float32); mul_1802 = None + reciprocal_30 = torch.ops.aten.reciprocal.default(add_748); add_748 = None + mul_1803 = torch.ops.aten.mul.Tensor(reciprocal_30, 1); reciprocal_30 = None + mul_1804 = torch.ops.aten.mul.Tensor(convert_element_type_2580, mul_1803); convert_element_type_2580 = None + sub_715 = torch.ops.aten.sub.Tensor(1, mul_1803); mul_1803 = None + mul_1805 = torch.ops.aten.mul.Tensor(convert_element_type_621, sub_715); convert_element_type_621 = sub_715 = None + add_2001 = torch.ops.aten.add.Tensor(mul_1805, 1); mul_1805 = None + mul_1806 = torch.ops.aten.mul.Tensor(mul_1804, add_2001); mul_1804 = add_2001 = None + convert_element_type_2582 = torch.ops.prims.convert_element_type.default(mul_1806, torch.bfloat16); mul_1806 = None + permute_1166 = torch.ops.aten.permute.default(convert_element_type_2582, [1, 0]) + mm_462 = torch.ops.aten.mm.default(permute_1166, view_728); permute_1166 = None + convert_element_type_618 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_618, 64, '0'); convert_element_type_618 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + permute_1168 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_463 = torch.ops.aten.mm.default(convert_element_type_2582, permute_1168); convert_element_type_2582 = permute_1168 = None + add_2002 = torch.ops.aten.add.Tensor(mm_461, mm_463); mm_461 = mm_463 = None + convert_element_type_2587 = torch.ops.prims.convert_element_type.default(mm_462, torch.float32); mm_462 = None + reduce_scatter_tensor_214 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2587, 'avg', 64, '0'); convert_element_type_2587 = None + wait_tensor_803 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_214); reduce_scatter_tensor_214 = None + all_to_all_single_108 = torch.ops._c10d_functional.all_to_all_single.default(index_82, [_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175], [_local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167], '521'); index_82 = None + wait_tensor_804 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_108); all_to_all_single_108 = None + full_408 = torch.ops.aten.full.default([sym_size_int_41, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_41 = None + slice_scatter_15 = torch.ops.aten.slice_scatter.default(full_408, wait_tensor_804, 0, 0, -1); wait_tensor_804 = None + index_83 = torch.ops.aten.index.Tensor(slice_scatter_15, [getitem_162]); slice_scatter_15 = None + permute_1170 = torch.ops.aten.permute.default(index_83, [1, 0]) + _grouped_mm_168 = torch.ops.aten._grouped_mm.default(permute_1170, mul_525, cumsum_32); permute_1170 = mul_525 = None + convert_element_type_612 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16); primals_194 = None + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_612, 8, '513'); convert_element_type_612 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_234, [0, 2, 1]); wait_tensor_234 = None + permute_1172 = torch.ops.aten.permute.default(permute_172, [0, 2, 1]); permute_172 = None + _grouped_mm_169 = torch.ops.aten._grouped_mm.default(index_83, permute_1172, cumsum_32); index_83 = permute_1172 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(_grouped_mm_30, torch.float32); _grouped_mm_30 = None + neg_21 = torch.ops.aten.neg.default(convert_element_type_616) + exp_32 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_712 = torch.ops.aten.add.Tensor(exp_32, 1); exp_32 = None + div_54 = torch.ops.aten.div.Tensor(convert_element_type_616, add_712) + convert_element_type_617 = torch.ops.prims.convert_element_type.default(div_54, torch.bfloat16); div_54 = None + mul_1807 = torch.ops.aten.mul.Tensor(_grouped_mm_169, convert_element_type_617); convert_element_type_617 = None + mul_1808 = torch.ops.aten.mul.Tensor(_grouped_mm_169, _grouped_mm_31); _grouped_mm_169 = _grouped_mm_31 = None + permute_1174 = torch.ops.aten.permute.default(mul_1807, [1, 0]) + _grouped_mm_170 = torch.ops.aten._grouped_mm.default(permute_1174, index_21, cumsum_32); permute_1174 = None + convert_element_type_613 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_613, 8, '513'); convert_element_type_613 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_171 = torch.ops.aten.permute.default(wait_tensor_235, [0, 2, 1]); wait_tensor_235 = None + permute_1176 = torch.ops.aten.permute.default(permute_171, [0, 2, 1]); permute_171 = None + _grouped_mm_171 = torch.ops.aten._grouped_mm.default(mul_1807, permute_1176, cumsum_32); mul_1807 = permute_1176 = None + convert_element_type_2588 = torch.ops.prims.convert_element_type.default(mul_1808, torch.float32); mul_1808 = None + reciprocal_31 = torch.ops.aten.reciprocal.default(add_712); add_712 = None + mul_1809 = torch.ops.aten.mul.Tensor(reciprocal_31, 1); reciprocal_31 = None + mul_1810 = torch.ops.aten.mul.Tensor(convert_element_type_2588, mul_1809); convert_element_type_2588 = None + sub_716 = torch.ops.aten.sub.Tensor(1, mul_1809); mul_1809 = None + mul_1811 = torch.ops.aten.mul.Tensor(convert_element_type_616, sub_716); convert_element_type_616 = sub_716 = None + add_2004 = torch.ops.aten.add.Tensor(mul_1811, 1); mul_1811 = None + mul_1812 = torch.ops.aten.mul.Tensor(mul_1810, add_2004); mul_1810 = add_2004 = None + convert_element_type_2590 = torch.ops.prims.convert_element_type.default(mul_1812, torch.bfloat16); mul_1812 = None + permute_1178 = torch.ops.aten.permute.default(convert_element_type_2590, [1, 0]) + _grouped_mm_172 = torch.ops.aten._grouped_mm.default(permute_1178, index_21, cumsum_32); permute_1178 = index_21 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16); primals_193 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_610, 8, '513'); convert_element_type_610 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + permute_170 = torch.ops.aten.permute.default(wait_tensor_232, [0, 2, 1]); wait_tensor_232 = None + permute_1180 = torch.ops.aten.permute.default(permute_170, [0, 2, 1]); permute_170 = None + _grouped_mm_173 = torch.ops.aten._grouped_mm.default(convert_element_type_2590, permute_1180, cumsum_32); convert_element_type_2590 = permute_1180 = cumsum_32 = None + add_2005 = torch.ops.aten.add.Tensor(_grouped_mm_171, _grouped_mm_173); _grouped_mm_171 = _grouped_mm_173 = None + convert_element_type_2591 = torch.ops.prims.convert_element_type.default(_grouped_mm_170, torch.float32); _grouped_mm_170 = None + div_222 = torch.ops.aten.div.Tensor(convert_element_type_2591, 64); convert_element_type_2591 = None + reduce_scatter_tensor_215 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_222, 'sum', 8, '513'); div_222 = None + wait_tensor_805 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_215); reduce_scatter_tensor_215 = None + convert_element_type_2592 = torch.ops.prims.convert_element_type.default(_grouped_mm_168, torch.float32); _grouped_mm_168 = None + div_223 = torch.ops.aten.div.Tensor(convert_element_type_2592, 64); convert_element_type_2592 = None + reduce_scatter_tensor_216 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_223, 'sum', 8, '513'); div_223 = None + wait_tensor_806 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_216); reduce_scatter_tensor_216 = None + convert_element_type_2593 = torch.ops.prims.convert_element_type.default(_grouped_mm_172, torch.float32); _grouped_mm_172 = None + div_224 = torch.ops.aten.div.Tensor(convert_element_type_2593, 64); convert_element_type_2593 = None + reduce_scatter_tensor_217 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_224, 'sum', 8, '513'); div_224 = None + wait_tensor_807 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_217); reduce_scatter_tensor_217 = None + index_put_82 = torch.ops.aten.index_put.default(full_408, [getitem_162], add_2005, True); full_408 = getitem_162 = add_2005 = None + slice_197 = torch.ops.aten.slice.Tensor(index_put_82, 0, 0, add_2006); index_put_82 = add_2006 = None + all_to_all_single_109 = torch.ops._c10d_functional.all_to_all_single.default(slice_197, [_local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167], [_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175], '521'); slice_197 = _local_scalar_dense_160 = _local_scalar_dense_161 = _local_scalar_dense_162 = _local_scalar_dense_163 = _local_scalar_dense_164 = _local_scalar_dense_165 = _local_scalar_dense_166 = _local_scalar_dense_167 = _local_scalar_dense_168 = _local_scalar_dense_169 = _local_scalar_dense_170 = _local_scalar_dense_171 = _local_scalar_dense_172 = _local_scalar_dense_173 = _local_scalar_dense_174 = _local_scalar_dense_175 = None + wait_tensor_808 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_109); all_to_all_single_109 = None + index_put_83 = torch.ops.aten.index_put.default(full_default_52, [div_52], wait_tensor_808, True); div_52 = wait_tensor_808 = None + add_2010 = torch.ops.aten.add.Tensor(add_2002, index_put_83); add_2002 = index_put_83 = None + mul_1813 = torch.ops.aten.mul.Tensor(view_2068, 1.0); view_2068 = None + scatter_add_15 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_159, mul_1813); getitem_159 = mul_1813 = None + convert_element_type_605 = torch.ops.prims.convert_element_type.default(mm_91, torch.float32); mm_91 = None + sub_240 = torch.ops.aten.sub.Tensor(convert_element_type_605, amax_10); convert_element_type_605 = amax_10 = None + exp_31 = torch.ops.aten.exp.default(sub_240); sub_240 = None + div_51 = torch.ops.aten.div.Tensor(exp_31, sum_41); exp_31 = sum_41 = None + mul_1814 = torch.ops.aten.mul.Tensor(scatter_add_15, div_51); scatter_add_15 = None + sum_227 = torch.ops.aten.sum.dim_IntList(mul_1814, [1], True) + neg_100 = torch.ops.aten.neg.default(div_51); div_51 = None + fma_15 = torch.ops.prims.fma.default(neg_100, sum_227, mul_1814); neg_100 = sum_227 = mul_1814 = None + convert_element_type_2594 = torch.ops.prims.convert_element_type.default(fma_15, torch.bfloat16); fma_15 = None + permute_1182 = torch.ops.aten.permute.default(convert_element_type_2594, [1, 0]) + mm_464 = torch.ops.aten.mm.default(permute_1182, view_728); permute_1182 = view_728 = None + convert_element_type_602 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_602, 64, '0'); convert_element_type_602 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_169 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + permute_1184 = torch.ops.aten.permute.default(permute_169, [1, 0]); permute_169 = None + mm_465 = torch.ops.aten.mm.default(convert_element_type_2594, permute_1184); convert_element_type_2594 = permute_1184 = None + add_2011 = torch.ops.aten.add.Tensor(add_2010, mm_465); add_2010 = mm_465 = None + convert_element_type_2599 = torch.ops.prims.convert_element_type.default(mm_464, torch.float32); mm_464 = None + reduce_scatter_tensor_218 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2599, 'avg', 64, '0'); convert_element_type_2599 = None + wait_tensor_809 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_218); reduce_scatter_tensor_218 = None + view_2070 = torch.ops.aten.view.default(add_2011, [2, 4096, 2048]); add_2011 = None + convert_element_type_2600 = torch.ops.prims.convert_element_type.default(view_2070, torch.float32); view_2070 = None + convert_element_type_599 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16); primals_189 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_599, 64, '0'); convert_element_type_599 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + convert_element_type_2602 = torch.ops.prims.convert_element_type.default(wait_tensor_227, torch.float32); wait_tensor_227 = None + mul_1815 = torch.ops.aten.mul.Tensor(convert_element_type_2600, convert_element_type_2602); convert_element_type_2602 = None + convert_element_type_600 = torch.ops.prims.convert_element_type.default(add_688, torch.float32); add_688 = None + mul_505 = torch.ops.aten.mul.Tensor(convert_element_type_600, rsqrt_35); convert_element_type_600 = None + mul_1817 = torch.ops.aten.mul.Tensor(mul_505, mul_1815) + sum_228 = torch.ops.aten.sum.dim_IntList(mul_1817, [2], True); mul_1817 = None + div_225 = torch.ops.aten.div.Tensor(mul_505, 2048) + mul_1818 = torch.ops.aten.mul.Tensor(div_225, sum_228); div_225 = sum_228 = None + sub_718 = torch.ops.aten.sub.Tensor(mul_1815, mul_1818); mul_1815 = mul_1818 = None + mul_1819 = torch.ops.aten.mul.Tensor(sub_718, rsqrt_35); sub_718 = rsqrt_35 = None + mul_1820 = torch.ops.aten.mul.Tensor(convert_element_type_2600, mul_505); convert_element_type_2600 = mul_505 = None + sum_229 = torch.ops.aten.sum.dim_IntList(mul_1820, [0, 1]); mul_1820 = None + convert_element_type_2603 = torch.ops.prims.convert_element_type.default(mul_1819, torch.bfloat16); mul_1819 = None + add_2012 = torch.ops.aten.add.Tensor(add_1999, convert_element_type_2603); add_1999 = convert_element_type_2603 = None + convert_element_type_default_36 = torch.ops.prims.convert_element_type.default(sum_229, torch.float32); sum_229 = None + reduce_scatter_tensor_219 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_36, 'avg', 64, '0'); convert_element_type_default_36 = None + wait_tensor_810 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_219); reduce_scatter_tensor_219 = None + view_2071 = torch.ops.aten.view.default(add_2012, [8192, 2048]) + permute_1186 = torch.ops.aten.permute.default(view_2071, [1, 0]) + permute_167 = torch.ops.aten.permute.default(getitem_155, [0, 2, 1, 3]) + view_723 = torch.ops.aten.view.default(permute_167, [2, 4096, -1]); permute_167 = None + view_725 = torch.ops.aten.view.default(view_723, [8192, 2048]); view_723 = None + mm_466 = torch.ops.aten.mm.default(permute_1186, view_725); permute_1186 = view_725 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16); primals_188 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_596, 64, '0'); convert_element_type_596 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_168 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + permute_1188 = torch.ops.aten.permute.default(permute_168, [1, 0]); permute_168 = None + mm_467 = torch.ops.aten.mm.default(view_2071, permute_1188); view_2071 = permute_1188 = None + view_2072 = torch.ops.aten.view.default(mm_467, [2, 4096, 2048]); mm_467 = None + convert_element_type_2610 = torch.ops.prims.convert_element_type.default(mm_466, torch.float32); mm_466 = None + reduce_scatter_tensor_220 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2610, 'avg', 64, '0'); convert_element_type_2610 = None + wait_tensor_811 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_220); reduce_scatter_tensor_220 = None + view_2073 = torch.ops.aten.view.default(view_2072, [2, 4096, 16, 128]); view_2072 = None + permute_1190 = torch.ops.aten.permute.default(view_2073, [0, 2, 1, 3]); view_2073 = None + fw_graph15 = self.fw_graph15 + joint_graph15 = self.joint_graph15 + mask_graph15 = self.mask_graph15 + flex_attention_backward_15 = torch.ops.higher_order.flex_attention_backward(permute_164, permute_165, permute_166, getitem_155, getitem_156, permute_1190, None, fw_graph15, joint_graph15, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph15), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_164 = permute_165 = permute_166 = getitem_155 = getitem_156 = permute_1190 = fw_graph15 = joint_graph15 = mask_graph15 = None + getitem_433 = flex_attention_backward_15[0] + getitem_434 = flex_attention_backward_15[1] + getitem_435 = flex_attention_backward_15[2]; flex_attention_backward_15 = None + permute_1191 = torch.ops.aten.permute.default(getitem_435, [0, 2, 1, 3]); getitem_435 = None + permute_1192 = torch.ops.aten.permute.default(getitem_434, [0, 2, 1, 3]); getitem_434 = None + permute_1193 = torch.ops.aten.permute.default(getitem_433, [0, 2, 1, 3]); getitem_433 = None + slice_199 = torch.ops.aten.slice.Tensor(permute_1192, 3, 0, 128) + slice_200 = torch.ops.aten.slice.Tensor(permute_1192, 3, 128, 192); permute_1192 = None + sum_230 = torch.ops.aten.sum.dim_IntList(slice_200, [2], True); slice_200 = None + cat_125 = torch.ops.aten.cat.default([slice_199, permute_1191], 3); slice_199 = permute_1191 = None + view_2074 = torch.ops.aten.view.default(cat_125, [2, 4096, 4096]); cat_125 = None + view_2075 = torch.ops.aten.view.default(view_2074, [8192, 4096]); view_2074 = None + permute_1194 = torch.ops.aten.permute.default(view_2075, [1, 0]) + mm_468 = torch.ops.aten.mm.default(permute_1194, view_720); permute_1194 = view_720 = None + convert_element_type_593 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16); primals_187 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_593, 64, '0'); convert_element_type_593 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + permute_1196 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_469 = torch.ops.aten.mm.default(view_2075, permute_1196); view_2075 = permute_1196 = None + view_2076 = torch.ops.aten.view.default(mm_469, [2, 4096, 512]); mm_469 = None + convert_element_type_2615 = torch.ops.prims.convert_element_type.default(mm_468, torch.float32); mm_468 = None + reduce_scatter_tensor_221 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2615, 'avg', 64, '0'); convert_element_type_2615 = None + wait_tensor_812 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_221); reduce_scatter_tensor_221 = None + convert_element_type_2616 = torch.ops.prims.convert_element_type.default(view_2076, torch.float32); view_2076 = None + convert_element_type_590 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16); primals_186 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_590, 64, '0'); convert_element_type_590 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + convert_element_type_2618 = torch.ops.prims.convert_element_type.default(wait_tensor_224, torch.float32); wait_tensor_224 = None + mul_1821 = torch.ops.aten.mul.Tensor(convert_element_type_2616, convert_element_type_2618); convert_element_type_2618 = None + convert_element_type_591 = torch.ops.prims.convert_element_type.default(getitem_151, torch.float32); getitem_151 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_591, rsqrt_34); convert_element_type_591 = None + mul_1823 = torch.ops.aten.mul.Tensor(mul_503, mul_1821) + sum_231 = torch.ops.aten.sum.dim_IntList(mul_1823, [2], True); mul_1823 = None + div_226 = torch.ops.aten.div.Tensor(mul_503, 512) + mul_1824 = torch.ops.aten.mul.Tensor(div_226, sum_231); div_226 = sum_231 = None + sub_719 = torch.ops.aten.sub.Tensor(mul_1821, mul_1824); mul_1821 = mul_1824 = None + mul_1825 = torch.ops.aten.mul.Tensor(sub_719, rsqrt_34); sub_719 = rsqrt_34 = None + mul_1826 = torch.ops.aten.mul.Tensor(convert_element_type_2616, mul_503); convert_element_type_2616 = mul_503 = None + sum_232 = torch.ops.aten.sum.dim_IntList(mul_1826, [0, 1]); mul_1826 = None + convert_element_type_2619 = torch.ops.prims.convert_element_type.default(mul_1825, torch.bfloat16); mul_1825 = None + convert_element_type_default_35 = torch.ops.prims.convert_element_type.default(sum_232, torch.float32); sum_232 = None + reduce_scatter_tensor_222 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_35, 'avg', 64, '0'); convert_element_type_default_35 = None + wait_tensor_813 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_222); reduce_scatter_tensor_222 = None + convert_element_type_2622 = torch.ops.prims.convert_element_type.default(sum_230, torch.float32); sum_230 = None + view_2077 = torch.ops.aten.view.default(convert_element_type_2622, [2, 4096, 1, 32, 2]); convert_element_type_2622 = None + view_as_complex_84 = torch.ops.aten.view_as_complex.default(view_2077); view_2077 = None + mul_1827 = torch.ops.aten.mul.Tensor(view_as_complex_84, clone_9); view_as_complex_84 = None + view_as_real_84 = torch.ops.aten.view_as_real.default(mul_1827); mul_1827 = None + view_2078 = torch.ops.aten.view.default(view_as_real_84, [2, 4096, 1, 64]); view_as_real_84 = None + convert_element_type_2623 = torch.ops.prims.convert_element_type.default(view_2078, torch.bfloat16); view_2078 = None + squeeze_41 = torch.ops.aten.squeeze.dim(convert_element_type_2623, 2); convert_element_type_2623 = None + cat_126 = torch.ops.aten.cat.default([convert_element_type_2619, squeeze_41], 2); convert_element_type_2619 = squeeze_41 = None + view_2079 = torch.ops.aten.view.default(cat_126, [8192, 576]); cat_126 = None + permute_1198 = torch.ops.aten.permute.default(view_2079, [1, 0]) + mm_470 = torch.ops.aten.mm.default(permute_1198, view_706); permute_1198 = None + convert_element_type_585 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16); primals_185 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_585, 64, '0'); convert_element_type_585 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + permute_1200 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_471 = torch.ops.aten.mm.default(view_2079, permute_1200); view_2079 = permute_1200 = None + view_2080 = torch.ops.aten.view.default(mm_471, [2, 4096, 2048]); mm_471 = None + convert_element_type_2628 = torch.ops.prims.convert_element_type.default(mm_470, torch.float32); mm_470 = None + reduce_scatter_tensor_223 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2628, 'avg', 64, '0'); convert_element_type_2628 = None + wait_tensor_814 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_223); reduce_scatter_tensor_223 = None + slice_201 = torch.ops.aten.slice.Tensor(permute_1193, 3, 0, 128) + slice_202 = torch.ops.aten.slice.Tensor(permute_1193, 3, 128, 192); permute_1193 = None + convert_element_type_2629 = torch.ops.prims.convert_element_type.default(slice_202, torch.float32); slice_202 = None + view_2081 = torch.ops.aten.view.default(convert_element_type_2629, [2, 4096, 16, 32, 2]); convert_element_type_2629 = None + view_as_complex_85 = torch.ops.aten.view_as_complex.default(view_2081); view_2081 = None + mul_1828 = torch.ops.aten.mul.Tensor(view_as_complex_85, clone_9); view_as_complex_85 = None + view_as_real_85 = torch.ops.aten.view_as_real.default(mul_1828); mul_1828 = None + view_2082 = torch.ops.aten.view.default(view_as_real_85, [2, 4096, 16, 64]); view_as_real_85 = None + convert_element_type_2630 = torch.ops.prims.convert_element_type.default(view_2082, torch.bfloat16); view_2082 = None + cat_127 = torch.ops.aten.cat.default([slice_201, convert_element_type_2630], 3); slice_201 = convert_element_type_2630 = None + view_2083 = torch.ops.aten.view.default(cat_127, [2, 4096, 3072]); cat_127 = None + view_2084 = torch.ops.aten.view.default(view_2083, [8192, 3072]); view_2083 = None + permute_1202 = torch.ops.aten.permute.default(view_2084, [1, 0]) + mm_472 = torch.ops.aten.mm.default(permute_1202, view_706); permute_1202 = view_706 = None + convert_element_type_580 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16); primals_184 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_580, 64, '0'); convert_element_type_580 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_222, [1, 0]); wait_tensor_222 = None + permute_1204 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_473 = torch.ops.aten.mm.default(view_2084, permute_1204); view_2084 = permute_1204 = None + view_2085 = torch.ops.aten.view.default(mm_473, [2, 4096, 2048]); mm_473 = None + add_2013 = torch.ops.aten.add.Tensor(view_2080, view_2085); view_2080 = view_2085 = None + convert_element_type_2635 = torch.ops.prims.convert_element_type.default(mm_472, torch.float32); mm_472 = None + reduce_scatter_tensor_224 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2635, 'avg', 64, '0'); convert_element_type_2635 = None + wait_tensor_815 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_224); reduce_scatter_tensor_224 = None + convert_element_type_2636 = torch.ops.prims.convert_element_type.default(add_2013, torch.float32); add_2013 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16); primals_183 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_577, 64, '0'); convert_element_type_577 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_2638 = torch.ops.prims.convert_element_type.default(wait_tensor_221, torch.float32); wait_tensor_221 = None + mul_1829 = torch.ops.aten.mul.Tensor(convert_element_type_2636, convert_element_type_2638); convert_element_type_2638 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(add_685, torch.float32); add_685 = None + mul_499 = torch.ops.aten.mul.Tensor(convert_element_type_578, rsqrt_33); convert_element_type_578 = None + mul_1831 = torch.ops.aten.mul.Tensor(mul_499, mul_1829) + sum_233 = torch.ops.aten.sum.dim_IntList(mul_1831, [2], True); mul_1831 = None + div_227 = torch.ops.aten.div.Tensor(mul_499, 2048) + mul_1832 = torch.ops.aten.mul.Tensor(div_227, sum_233); div_227 = sum_233 = None + sub_720 = torch.ops.aten.sub.Tensor(mul_1829, mul_1832); mul_1829 = mul_1832 = None + mul_1833 = torch.ops.aten.mul.Tensor(sub_720, rsqrt_33); sub_720 = rsqrt_33 = None + mul_1834 = torch.ops.aten.mul.Tensor(convert_element_type_2636, mul_499); convert_element_type_2636 = mul_499 = None + sum_234 = torch.ops.aten.sum.dim_IntList(mul_1834, [0, 1]); mul_1834 = None + convert_element_type_2639 = torch.ops.prims.convert_element_type.default(mul_1833, torch.bfloat16); mul_1833 = None + add_2014 = torch.ops.aten.add.Tensor(add_2012, convert_element_type_2639); add_2012 = convert_element_type_2639 = None + convert_element_type_default_34 = torch.ops.prims.convert_element_type.default(sum_234, torch.float32); sum_234 = None + reduce_scatter_tensor_225 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_34, 'avg', 64, '0'); convert_element_type_default_34 = None + wait_tensor_816 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_225); reduce_scatter_tensor_225 = None + view_2086 = torch.ops.aten.view.default(add_2014, [8192, 2048]) + unsqueeze_69 = torch.ops.aten.unsqueeze.default(view_2086, 1) + convert_element_type_2642 = torch.ops.prims.convert_element_type.default(unsqueeze_69, torch.float32); unsqueeze_69 = None + bmm_58 = torch.ops.aten.bmm.default(permute_1206, convert_element_type_2642); permute_1206 = None + bmm_59 = torch.ops.aten.bmm.default(convert_element_type_2642, permute_1207); convert_element_type_2642 = permute_1207 = None + convert_element_type_2643 = torch.ops.prims.convert_element_type.default(bmm_58, torch.bfloat16); bmm_58 = None + view_2087 = torch.ops.aten.view.default(bmm_59, [8192, 6]); bmm_59 = None + view_2088 = torch.ops.aten.view.default(convert_element_type_2643, [49152, 2048]); convert_element_type_2643 = None + index_84 = torch.ops.aten.index.Tensor(view_2088, [getitem_147]); view_2088 = getitem_147 = None + permute_1208 = torch.ops.aten.permute.default(view_2086, [1, 0]) + mm_474 = torch.ops.aten.mm.default(permute_1208, mul_496); permute_1208 = mul_496 = None + convert_element_type_572 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16); primals_182 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_572, 64, '0'); convert_element_type_572 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_160 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + permute_1210 = torch.ops.aten.permute.default(permute_160, [1, 0]); permute_160 = None + mm_475 = torch.ops.aten.mm.default(view_2086, permute_1210); view_2086 = permute_1210 = None + convert_element_type_2648 = torch.ops.prims.convert_element_type.default(mm_474, torch.float32); mm_474 = None + reduce_scatter_tensor_226 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2648, 'avg', 64, '0'); convert_element_type_2648 = None + wait_tensor_817 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_226); reduce_scatter_tensor_226 = None + convert_element_type_567 = torch.ops.prims.convert_element_type.default(mm_84, torch.float32); mm_84 = None + neg_20 = torch.ops.aten.neg.default(convert_element_type_567) + exp_30 = torch.ops.aten.exp.default(neg_20); neg_20 = None + add_680 = torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + div_50 = torch.ops.aten.div.Tensor(convert_element_type_567, add_680) + convert_element_type_568 = torch.ops.prims.convert_element_type.default(div_50, torch.bfloat16); div_50 = None + mul_1835 = torch.ops.aten.mul.Tensor(mm_475, convert_element_type_568); convert_element_type_568 = None + mul_1836 = torch.ops.aten.mul.Tensor(mm_475, mm_85); mm_475 = mm_85 = None + permute_1212 = torch.ops.aten.permute.default(mul_1835, [1, 0]) + mm_476 = torch.ops.aten.mm.default(permute_1212, view_661); permute_1212 = None + convert_element_type_569 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16); primals_181 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_569, 64, '0'); convert_element_type_569 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_159 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + permute_1214 = torch.ops.aten.permute.default(permute_159, [1, 0]); permute_159 = None + mm_477 = torch.ops.aten.mm.default(mul_1835, permute_1214); mul_1835 = permute_1214 = None + convert_element_type_2653 = torch.ops.prims.convert_element_type.default(mm_476, torch.float32); mm_476 = None + reduce_scatter_tensor_227 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2653, 'avg', 64, '0'); convert_element_type_2653 = None + wait_tensor_818 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_227); reduce_scatter_tensor_227 = None + convert_element_type_2654 = torch.ops.prims.convert_element_type.default(mul_1836, torch.float32); mul_1836 = None + reciprocal_32 = torch.ops.aten.reciprocal.default(add_680); add_680 = None + mul_1837 = torch.ops.aten.mul.Tensor(reciprocal_32, 1); reciprocal_32 = None + mul_1838 = torch.ops.aten.mul.Tensor(convert_element_type_2654, mul_1837); convert_element_type_2654 = None + sub_721 = torch.ops.aten.sub.Tensor(1, mul_1837); mul_1837 = None + mul_1839 = torch.ops.aten.mul.Tensor(convert_element_type_567, sub_721); convert_element_type_567 = sub_721 = None + add_2016 = torch.ops.aten.add.Tensor(mul_1839, 1); mul_1839 = None + mul_1840 = torch.ops.aten.mul.Tensor(mul_1838, add_2016); mul_1838 = add_2016 = None + convert_element_type_2656 = torch.ops.prims.convert_element_type.default(mul_1840, torch.bfloat16); mul_1840 = None + permute_1216 = torch.ops.aten.permute.default(convert_element_type_2656, [1, 0]) + mm_478 = torch.ops.aten.mm.default(permute_1216, view_661); permute_1216 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 64, '0'); convert_element_type_564 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_158 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + permute_1218 = torch.ops.aten.permute.default(permute_158, [1, 0]); permute_158 = None + mm_479 = torch.ops.aten.mm.default(convert_element_type_2656, permute_1218); convert_element_type_2656 = permute_1218 = None + add_2017 = torch.ops.aten.add.Tensor(mm_477, mm_479); mm_477 = mm_479 = None + convert_element_type_2661 = torch.ops.prims.convert_element_type.default(mm_478, torch.float32); mm_478 = None + reduce_scatter_tensor_228 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2661, 'avg', 64, '0'); convert_element_type_2661 = None + wait_tensor_819 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_228); reduce_scatter_tensor_228 = None + all_to_all_single_110 = torch.ops._c10d_functional.all_to_all_single.default(index_84, [_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159], [_local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151], '521'); index_84 = None + wait_tensor_820 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_110); all_to_all_single_110 = None + full_412 = torch.ops.aten.full.default([sym_size_int_37, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_37 = None + slice_scatter_16 = torch.ops.aten.slice_scatter.default(full_412, wait_tensor_820, 0, 0, -1); wait_tensor_820 = None + index_85 = torch.ops.aten.index.Tensor(slice_scatter_16, [getitem_148]); slice_scatter_16 = None + permute_1220 = torch.ops.aten.permute.default(index_85, [1, 0]) + _grouped_mm_174 = torch.ops.aten._grouped_mm.default(permute_1220, mul_476, cumsum_29); permute_1220 = mul_476 = None + convert_element_type_558 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16); primals_178 = None + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_558, 8, '513'); convert_element_type_558 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_157 = torch.ops.aten.permute.default(wait_tensor_213, [0, 2, 1]); wait_tensor_213 = None + permute_1222 = torch.ops.aten.permute.default(permute_157, [0, 2, 1]); permute_157 = None + _grouped_mm_175 = torch.ops.aten._grouped_mm.default(index_85, permute_1222, cumsum_29); index_85 = permute_1222 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(_grouped_mm_27, torch.float32); _grouped_mm_27 = None + neg_19 = torch.ops.aten.neg.default(convert_element_type_562) + exp_29 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_644 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + div_49 = torch.ops.aten.div.Tensor(convert_element_type_562, add_644) + convert_element_type_563 = torch.ops.prims.convert_element_type.default(div_49, torch.bfloat16); div_49 = None + mul_1841 = torch.ops.aten.mul.Tensor(_grouped_mm_175, convert_element_type_563); convert_element_type_563 = None + mul_1842 = torch.ops.aten.mul.Tensor(_grouped_mm_175, _grouped_mm_28); _grouped_mm_175 = _grouped_mm_28 = None + permute_1224 = torch.ops.aten.permute.default(mul_1841, [1, 0]) + _grouped_mm_176 = torch.ops.aten._grouped_mm.default(permute_1224, index_19, cumsum_29); permute_1224 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 8, '513'); convert_element_type_559 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_214, [0, 2, 1]); wait_tensor_214 = None + permute_1226 = torch.ops.aten.permute.default(permute_156, [0, 2, 1]); permute_156 = None + _grouped_mm_177 = torch.ops.aten._grouped_mm.default(mul_1841, permute_1226, cumsum_29); mul_1841 = permute_1226 = None + convert_element_type_2662 = torch.ops.prims.convert_element_type.default(mul_1842, torch.float32); mul_1842 = None + reciprocal_33 = torch.ops.aten.reciprocal.default(add_644); add_644 = None + mul_1843 = torch.ops.aten.mul.Tensor(reciprocal_33, 1); reciprocal_33 = None + mul_1844 = torch.ops.aten.mul.Tensor(convert_element_type_2662, mul_1843); convert_element_type_2662 = None + sub_722 = torch.ops.aten.sub.Tensor(1, mul_1843); mul_1843 = None + mul_1845 = torch.ops.aten.mul.Tensor(convert_element_type_562, sub_722); convert_element_type_562 = sub_722 = None + add_2019 = torch.ops.aten.add.Tensor(mul_1845, 1); mul_1845 = None + mul_1846 = torch.ops.aten.mul.Tensor(mul_1844, add_2019); mul_1844 = add_2019 = None + convert_element_type_2664 = torch.ops.prims.convert_element_type.default(mul_1846, torch.bfloat16); mul_1846 = None + permute_1228 = torch.ops.aten.permute.default(convert_element_type_2664, [1, 0]) + _grouped_mm_178 = torch.ops.aten._grouped_mm.default(permute_1228, index_19, cumsum_29); permute_1228 = index_19 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16); primals_177 = None + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 8, '513'); convert_element_type_556 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_211, [0, 2, 1]); wait_tensor_211 = None + permute_1230 = torch.ops.aten.permute.default(permute_155, [0, 2, 1]); permute_155 = None + _grouped_mm_179 = torch.ops.aten._grouped_mm.default(convert_element_type_2664, permute_1230, cumsum_29); convert_element_type_2664 = permute_1230 = cumsum_29 = None + add_2020 = torch.ops.aten.add.Tensor(_grouped_mm_177, _grouped_mm_179); _grouped_mm_177 = _grouped_mm_179 = None + convert_element_type_2665 = torch.ops.prims.convert_element_type.default(_grouped_mm_176, torch.float32); _grouped_mm_176 = None + div_228 = torch.ops.aten.div.Tensor(convert_element_type_2665, 64); convert_element_type_2665 = None + reduce_scatter_tensor_229 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_228, 'sum', 8, '513'); div_228 = None + wait_tensor_821 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_229); reduce_scatter_tensor_229 = None + convert_element_type_2666 = torch.ops.prims.convert_element_type.default(_grouped_mm_174, torch.float32); _grouped_mm_174 = None + div_229 = torch.ops.aten.div.Tensor(convert_element_type_2666, 64); convert_element_type_2666 = None + reduce_scatter_tensor_230 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_229, 'sum', 8, '513'); div_229 = None + wait_tensor_822 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_230); reduce_scatter_tensor_230 = None + convert_element_type_2667 = torch.ops.prims.convert_element_type.default(_grouped_mm_178, torch.float32); _grouped_mm_178 = None + div_230 = torch.ops.aten.div.Tensor(convert_element_type_2667, 64); convert_element_type_2667 = None + reduce_scatter_tensor_231 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_230, 'sum', 8, '513'); div_230 = None + wait_tensor_823 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_231); reduce_scatter_tensor_231 = None + index_put_84 = torch.ops.aten.index_put.default(full_412, [getitem_148], add_2020, True); full_412 = getitem_148 = add_2020 = None + slice_203 = torch.ops.aten.slice.Tensor(index_put_84, 0, 0, add_2021); index_put_84 = add_2021 = None + all_to_all_single_111 = torch.ops._c10d_functional.all_to_all_single.default(slice_203, [_local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151], [_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159], '521'); slice_203 = _local_scalar_dense_144 = _local_scalar_dense_145 = _local_scalar_dense_146 = _local_scalar_dense_147 = _local_scalar_dense_148 = _local_scalar_dense_149 = _local_scalar_dense_150 = _local_scalar_dense_151 = _local_scalar_dense_152 = _local_scalar_dense_153 = _local_scalar_dense_154 = _local_scalar_dense_155 = _local_scalar_dense_156 = _local_scalar_dense_157 = _local_scalar_dense_158 = _local_scalar_dense_159 = None + wait_tensor_824 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_111); all_to_all_single_111 = None + index_put_85 = torch.ops.aten.index_put.default(full_default_52, [div_47], wait_tensor_824, True); div_47 = wait_tensor_824 = None + add_2025 = torch.ops.aten.add.Tensor(add_2017, index_put_85); add_2017 = index_put_85 = None + mul_1847 = torch.ops.aten.mul.Tensor(view_2087, 1.0); view_2087 = None + scatter_add_16 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_145, mul_1847); getitem_145 = mul_1847 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(mm_83, torch.float32); mm_83 = None + sub_216 = torch.ops.aten.sub.Tensor(convert_element_type_551, amax_9); convert_element_type_551 = amax_9 = None + exp_28 = torch.ops.aten.exp.default(sub_216); sub_216 = None + div_46 = torch.ops.aten.div.Tensor(exp_28, sum_37); exp_28 = sum_37 = None + mul_1848 = torch.ops.aten.mul.Tensor(scatter_add_16, div_46); scatter_add_16 = None + sum_235 = torch.ops.aten.sum.dim_IntList(mul_1848, [1], True) + neg_103 = torch.ops.aten.neg.default(div_46); div_46 = None + fma_16 = torch.ops.prims.fma.default(neg_103, sum_235, mul_1848); neg_103 = sum_235 = mul_1848 = None + convert_element_type_2668 = torch.ops.prims.convert_element_type.default(fma_16, torch.bfloat16); fma_16 = None + permute_1232 = torch.ops.aten.permute.default(convert_element_type_2668, [1, 0]) + mm_480 = torch.ops.aten.mm.default(permute_1232, view_661); permute_1232 = view_661 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 64, '0'); convert_element_type_548 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + permute_1234 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_481 = torch.ops.aten.mm.default(convert_element_type_2668, permute_1234); convert_element_type_2668 = permute_1234 = None + add_2026 = torch.ops.aten.add.Tensor(add_2025, mm_481); add_2025 = mm_481 = None + convert_element_type_2673 = torch.ops.prims.convert_element_type.default(mm_480, torch.float32); mm_480 = None + reduce_scatter_tensor_232 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2673, 'avg', 64, '0'); convert_element_type_2673 = None + wait_tensor_825 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_232); reduce_scatter_tensor_232 = None + view_2089 = torch.ops.aten.view.default(add_2026, [2, 4096, 2048]); add_2026 = None + convert_element_type_2674 = torch.ops.prims.convert_element_type.default(view_2089, torch.float32); view_2089 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 64, '0'); convert_element_type_545 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + convert_element_type_2676 = torch.ops.prims.convert_element_type.default(wait_tensor_206, torch.float32); wait_tensor_206 = None + mul_1849 = torch.ops.aten.mul.Tensor(convert_element_type_2674, convert_element_type_2676); convert_element_type_2676 = None + convert_element_type_546 = torch.ops.prims.convert_element_type.default(add_620, torch.float32); add_620 = None + mul_456 = torch.ops.aten.mul.Tensor(convert_element_type_546, rsqrt_32); convert_element_type_546 = None + mul_1851 = torch.ops.aten.mul.Tensor(mul_456, mul_1849) + sum_236 = torch.ops.aten.sum.dim_IntList(mul_1851, [2], True); mul_1851 = None + div_231 = torch.ops.aten.div.Tensor(mul_456, 2048) + mul_1852 = torch.ops.aten.mul.Tensor(div_231, sum_236); div_231 = sum_236 = None + sub_724 = torch.ops.aten.sub.Tensor(mul_1849, mul_1852); mul_1849 = mul_1852 = None + mul_1853 = torch.ops.aten.mul.Tensor(sub_724, rsqrt_32); sub_724 = rsqrt_32 = None + mul_1854 = torch.ops.aten.mul.Tensor(convert_element_type_2674, mul_456); convert_element_type_2674 = mul_456 = None + sum_237 = torch.ops.aten.sum.dim_IntList(mul_1854, [0, 1]); mul_1854 = None + convert_element_type_2677 = torch.ops.prims.convert_element_type.default(mul_1853, torch.bfloat16); mul_1853 = None + add_2027 = torch.ops.aten.add.Tensor(add_2014, convert_element_type_2677); add_2014 = convert_element_type_2677 = None + convert_element_type_default_33 = torch.ops.prims.convert_element_type.default(sum_237, torch.float32); sum_237 = None + reduce_scatter_tensor_233 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_33, 'avg', 64, '0'); convert_element_type_default_33 = None + wait_tensor_826 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_233); reduce_scatter_tensor_233 = None + view_2090 = torch.ops.aten.view.default(add_2027, [8192, 2048]) + permute_1236 = torch.ops.aten.permute.default(view_2090, [1, 0]) + permute_152 = torch.ops.aten.permute.default(getitem_141, [0, 2, 1, 3]) + view_656 = torch.ops.aten.view.default(permute_152, [2, 4096, -1]); permute_152 = None + view_658 = torch.ops.aten.view.default(view_656, [8192, 2048]); view_656 = None + mm_482 = torch.ops.aten.mm.default(permute_1236, view_658); permute_1236 = view_658 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_542, 64, '0'); convert_element_type_542 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + permute_1238 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_483 = torch.ops.aten.mm.default(view_2090, permute_1238); view_2090 = permute_1238 = None + view_2091 = torch.ops.aten.view.default(mm_483, [2, 4096, 2048]); mm_483 = None + convert_element_type_2684 = torch.ops.prims.convert_element_type.default(mm_482, torch.float32); mm_482 = None + reduce_scatter_tensor_234 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2684, 'avg', 64, '0'); convert_element_type_2684 = None + wait_tensor_827 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_234); reduce_scatter_tensor_234 = None + view_2092 = torch.ops.aten.view.default(view_2091, [2, 4096, 16, 128]); view_2091 = None + permute_1240 = torch.ops.aten.permute.default(view_2092, [0, 2, 1, 3]); view_2092 = None + fw_graph16 = self.fw_graph16 + joint_graph16 = self.joint_graph16 + mask_graph16 = self.mask_graph16 + flex_attention_backward_16 = torch.ops.higher_order.flex_attention_backward(permute_149, permute_150, permute_151, getitem_141, getitem_142, permute_1240, None, fw_graph16, joint_graph16, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph16), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_149 = permute_150 = permute_151 = getitem_141 = getitem_142 = permute_1240 = fw_graph16 = joint_graph16 = mask_graph16 = None + getitem_437 = flex_attention_backward_16[0] + getitem_438 = flex_attention_backward_16[1] + getitem_439 = flex_attention_backward_16[2]; flex_attention_backward_16 = None + permute_1241 = torch.ops.aten.permute.default(getitem_439, [0, 2, 1, 3]); getitem_439 = None + permute_1242 = torch.ops.aten.permute.default(getitem_438, [0, 2, 1, 3]); getitem_438 = None + permute_1243 = torch.ops.aten.permute.default(getitem_437, [0, 2, 1, 3]); getitem_437 = None + slice_205 = torch.ops.aten.slice.Tensor(permute_1242, 3, 0, 128) + slice_206 = torch.ops.aten.slice.Tensor(permute_1242, 3, 128, 192); permute_1242 = None + sum_238 = torch.ops.aten.sum.dim_IntList(slice_206, [2], True); slice_206 = None + cat_128 = torch.ops.aten.cat.default([slice_205, permute_1241], 3); slice_205 = permute_1241 = None + view_2093 = torch.ops.aten.view.default(cat_128, [2, 4096, 4096]); cat_128 = None + view_2094 = torch.ops.aten.view.default(view_2093, [8192, 4096]); view_2093 = None + permute_1244 = torch.ops.aten.permute.default(view_2094, [1, 0]) + mm_484 = torch.ops.aten.mm.default(permute_1244, view_653); permute_1244 = view_653 = None + convert_element_type_539 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16); primals_171 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_539, 64, '0'); convert_element_type_539 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_148 = torch.ops.aten.permute.default(wait_tensor_204, [1, 0]); wait_tensor_204 = None + permute_1246 = torch.ops.aten.permute.default(permute_148, [1, 0]); permute_148 = None + mm_485 = torch.ops.aten.mm.default(view_2094, permute_1246); view_2094 = permute_1246 = None + view_2095 = torch.ops.aten.view.default(mm_485, [2, 4096, 512]); mm_485 = None + convert_element_type_2689 = torch.ops.prims.convert_element_type.default(mm_484, torch.float32); mm_484 = None + reduce_scatter_tensor_235 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2689, 'avg', 64, '0'); convert_element_type_2689 = None + wait_tensor_828 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_235); reduce_scatter_tensor_235 = None + convert_element_type_2690 = torch.ops.prims.convert_element_type.default(view_2095, torch.float32); view_2095 = None + convert_element_type_536 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16); primals_170 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_536, 64, '0'); convert_element_type_536 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + convert_element_type_2692 = torch.ops.prims.convert_element_type.default(wait_tensor_203, torch.float32); wait_tensor_203 = None + mul_1855 = torch.ops.aten.mul.Tensor(convert_element_type_2690, convert_element_type_2692); convert_element_type_2692 = None + convert_element_type_537 = torch.ops.prims.convert_element_type.default(getitem_137, torch.float32); getitem_137 = None + mul_454 = torch.ops.aten.mul.Tensor(convert_element_type_537, rsqrt_31); convert_element_type_537 = None + mul_1857 = torch.ops.aten.mul.Tensor(mul_454, mul_1855) + sum_239 = torch.ops.aten.sum.dim_IntList(mul_1857, [2], True); mul_1857 = None + div_232 = torch.ops.aten.div.Tensor(mul_454, 512) + mul_1858 = torch.ops.aten.mul.Tensor(div_232, sum_239); div_232 = sum_239 = None + sub_725 = torch.ops.aten.sub.Tensor(mul_1855, mul_1858); mul_1855 = mul_1858 = None + mul_1859 = torch.ops.aten.mul.Tensor(sub_725, rsqrt_31); sub_725 = rsqrt_31 = None + mul_1860 = torch.ops.aten.mul.Tensor(convert_element_type_2690, mul_454); convert_element_type_2690 = mul_454 = None + sum_240 = torch.ops.aten.sum.dim_IntList(mul_1860, [0, 1]); mul_1860 = None + convert_element_type_2693 = torch.ops.prims.convert_element_type.default(mul_1859, torch.bfloat16); mul_1859 = None + convert_element_type_default_32 = torch.ops.prims.convert_element_type.default(sum_240, torch.float32); sum_240 = None + reduce_scatter_tensor_236 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_32, 'avg', 64, '0'); convert_element_type_default_32 = None + wait_tensor_829 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_236); reduce_scatter_tensor_236 = None + convert_element_type_2696 = torch.ops.prims.convert_element_type.default(sum_238, torch.float32); sum_238 = None + view_2096 = torch.ops.aten.view.default(convert_element_type_2696, [2, 4096, 1, 32, 2]); convert_element_type_2696 = None + view_as_complex_86 = torch.ops.aten.view_as_complex.default(view_2096); view_2096 = None + mul_1861 = torch.ops.aten.mul.Tensor(view_as_complex_86, clone_9); view_as_complex_86 = None + view_as_real_86 = torch.ops.aten.view_as_real.default(mul_1861); mul_1861 = None + view_2097 = torch.ops.aten.view.default(view_as_real_86, [2, 4096, 1, 64]); view_as_real_86 = None + convert_element_type_2697 = torch.ops.prims.convert_element_type.default(view_2097, torch.bfloat16); view_2097 = None + squeeze_42 = torch.ops.aten.squeeze.dim(convert_element_type_2697, 2); convert_element_type_2697 = None + cat_129 = torch.ops.aten.cat.default([convert_element_type_2693, squeeze_42], 2); convert_element_type_2693 = squeeze_42 = None + view_2098 = torch.ops.aten.view.default(cat_129, [8192, 576]); cat_129 = None + permute_1248 = torch.ops.aten.permute.default(view_2098, [1, 0]) + mm_486 = torch.ops.aten.mm.default(permute_1248, view_639); permute_1248 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16); primals_169 = None + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 64, '0'); convert_element_type_531 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_147 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + permute_1250 = torch.ops.aten.permute.default(permute_147, [1, 0]); permute_147 = None + mm_487 = torch.ops.aten.mm.default(view_2098, permute_1250); view_2098 = permute_1250 = None + view_2099 = torch.ops.aten.view.default(mm_487, [2, 4096, 2048]); mm_487 = None + convert_element_type_2702 = torch.ops.prims.convert_element_type.default(mm_486, torch.float32); mm_486 = None + reduce_scatter_tensor_237 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2702, 'avg', 64, '0'); convert_element_type_2702 = None + wait_tensor_830 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_237); reduce_scatter_tensor_237 = None + slice_207 = torch.ops.aten.slice.Tensor(permute_1243, 3, 0, 128) + slice_208 = torch.ops.aten.slice.Tensor(permute_1243, 3, 128, 192); permute_1243 = None + convert_element_type_2703 = torch.ops.prims.convert_element_type.default(slice_208, torch.float32); slice_208 = None + view_2100 = torch.ops.aten.view.default(convert_element_type_2703, [2, 4096, 16, 32, 2]); convert_element_type_2703 = None + view_as_complex_87 = torch.ops.aten.view_as_complex.default(view_2100); view_2100 = None + mul_1862 = torch.ops.aten.mul.Tensor(view_as_complex_87, clone_9); view_as_complex_87 = None + view_as_real_87 = torch.ops.aten.view_as_real.default(mul_1862); mul_1862 = None + view_2101 = torch.ops.aten.view.default(view_as_real_87, [2, 4096, 16, 64]); view_as_real_87 = None + convert_element_type_2704 = torch.ops.prims.convert_element_type.default(view_2101, torch.bfloat16); view_2101 = None + cat_130 = torch.ops.aten.cat.default([slice_207, convert_element_type_2704], 3); slice_207 = convert_element_type_2704 = None + view_2102 = torch.ops.aten.view.default(cat_130, [2, 4096, 3072]); cat_130 = None + view_2103 = torch.ops.aten.view.default(view_2102, [8192, 3072]); view_2102 = None + permute_1252 = torch.ops.aten.permute.default(view_2103, [1, 0]) + mm_488 = torch.ops.aten.mm.default(permute_1252, view_639); permute_1252 = view_639 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16); primals_168 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 64, '0'); convert_element_type_526 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_146 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + permute_1254 = torch.ops.aten.permute.default(permute_146, [1, 0]); permute_146 = None + mm_489 = torch.ops.aten.mm.default(view_2103, permute_1254); view_2103 = permute_1254 = None + view_2104 = torch.ops.aten.view.default(mm_489, [2, 4096, 2048]); mm_489 = None + add_2028 = torch.ops.aten.add.Tensor(view_2099, view_2104); view_2099 = view_2104 = None + convert_element_type_2709 = torch.ops.prims.convert_element_type.default(mm_488, torch.float32); mm_488 = None + reduce_scatter_tensor_238 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2709, 'avg', 64, '0'); convert_element_type_2709 = None + wait_tensor_831 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_238); reduce_scatter_tensor_238 = None + convert_element_type_2710 = torch.ops.prims.convert_element_type.default(add_2028, torch.float32); add_2028 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16); primals_167 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 64, '0'); convert_element_type_523 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + convert_element_type_2712 = torch.ops.prims.convert_element_type.default(wait_tensor_200, torch.float32); wait_tensor_200 = None + mul_1863 = torch.ops.aten.mul.Tensor(convert_element_type_2710, convert_element_type_2712); convert_element_type_2712 = None + convert_element_type_524 = torch.ops.prims.convert_element_type.default(add_617, torch.float32); add_617 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_524, rsqrt_30); convert_element_type_524 = None + mul_1865 = torch.ops.aten.mul.Tensor(mul_450, mul_1863) + sum_241 = torch.ops.aten.sum.dim_IntList(mul_1865, [2], True); mul_1865 = None + div_233 = torch.ops.aten.div.Tensor(mul_450, 2048) + mul_1866 = torch.ops.aten.mul.Tensor(div_233, sum_241); div_233 = sum_241 = None + sub_726 = torch.ops.aten.sub.Tensor(mul_1863, mul_1866); mul_1863 = mul_1866 = None + mul_1867 = torch.ops.aten.mul.Tensor(sub_726, rsqrt_30); sub_726 = rsqrt_30 = None + mul_1868 = torch.ops.aten.mul.Tensor(convert_element_type_2710, mul_450); convert_element_type_2710 = mul_450 = None + sum_242 = torch.ops.aten.sum.dim_IntList(mul_1868, [0, 1]); mul_1868 = None + convert_element_type_2713 = torch.ops.prims.convert_element_type.default(mul_1867, torch.bfloat16); mul_1867 = None + add_2029 = torch.ops.aten.add.Tensor(add_2027, convert_element_type_2713); add_2027 = convert_element_type_2713 = None + convert_element_type_default_31 = torch.ops.prims.convert_element_type.default(sum_242, torch.float32); sum_242 = None + reduce_scatter_tensor_239 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_31, 'avg', 64, '0'); convert_element_type_default_31 = None + wait_tensor_832 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_239); reduce_scatter_tensor_239 = None + view_2105 = torch.ops.aten.view.default(add_2029, [8192, 2048]) + unsqueeze_70 = torch.ops.aten.unsqueeze.default(view_2105, 1) + convert_element_type_2716 = torch.ops.prims.convert_element_type.default(unsqueeze_70, torch.float32); unsqueeze_70 = None + bmm_60 = torch.ops.aten.bmm.default(permute_1256, convert_element_type_2716); permute_1256 = None + bmm_61 = torch.ops.aten.bmm.default(convert_element_type_2716, permute_1257); convert_element_type_2716 = permute_1257 = None + convert_element_type_2717 = torch.ops.prims.convert_element_type.default(bmm_60, torch.bfloat16); bmm_60 = None + view_2106 = torch.ops.aten.view.default(bmm_61, [8192, 6]); bmm_61 = None + view_2107 = torch.ops.aten.view.default(convert_element_type_2717, [49152, 2048]); convert_element_type_2717 = None + index_86 = torch.ops.aten.index.Tensor(view_2107, [getitem_133]); view_2107 = getitem_133 = None + permute_1258 = torch.ops.aten.permute.default(view_2105, [1, 0]) + mm_490 = torch.ops.aten.mm.default(permute_1258, mul_447); permute_1258 = mul_447 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16); primals_166 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 64, '0'); convert_element_type_518 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + permute_1260 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_491 = torch.ops.aten.mm.default(view_2105, permute_1260); view_2105 = permute_1260 = None + convert_element_type_2722 = torch.ops.prims.convert_element_type.default(mm_490, torch.float32); mm_490 = None + reduce_scatter_tensor_240 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2722, 'avg', 64, '0'); convert_element_type_2722 = None + wait_tensor_833 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_240); reduce_scatter_tensor_240 = None + convert_element_type_513 = torch.ops.prims.convert_element_type.default(mm_76, torch.float32); mm_76 = None + neg_18 = torch.ops.aten.neg.default(convert_element_type_513) + exp_27 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_612 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + div_45 = torch.ops.aten.div.Tensor(convert_element_type_513, add_612) + convert_element_type_514 = torch.ops.prims.convert_element_type.default(div_45, torch.bfloat16); div_45 = None + mul_1869 = torch.ops.aten.mul.Tensor(mm_491, convert_element_type_514); convert_element_type_514 = None + mul_1870 = torch.ops.aten.mul.Tensor(mm_491, mm_77); mm_491 = mm_77 = None + permute_1262 = torch.ops.aten.permute.default(mul_1869, [1, 0]) + mm_492 = torch.ops.aten.mm.default(permute_1262, view_594); permute_1262 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16); primals_165 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 64, '0'); convert_element_type_515 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + permute_1264 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_493 = torch.ops.aten.mm.default(mul_1869, permute_1264); mul_1869 = permute_1264 = None + convert_element_type_2727 = torch.ops.prims.convert_element_type.default(mm_492, torch.float32); mm_492 = None + reduce_scatter_tensor_241 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2727, 'avg', 64, '0'); convert_element_type_2727 = None + wait_tensor_834 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_241); reduce_scatter_tensor_241 = None + convert_element_type_2728 = torch.ops.prims.convert_element_type.default(mul_1870, torch.float32); mul_1870 = None + reciprocal_34 = torch.ops.aten.reciprocal.default(add_612); add_612 = None + mul_1871 = torch.ops.aten.mul.Tensor(reciprocal_34, 1); reciprocal_34 = None + mul_1872 = torch.ops.aten.mul.Tensor(convert_element_type_2728, mul_1871); convert_element_type_2728 = None + sub_727 = torch.ops.aten.sub.Tensor(1, mul_1871); mul_1871 = None + mul_1873 = torch.ops.aten.mul.Tensor(convert_element_type_513, sub_727); convert_element_type_513 = sub_727 = None + add_2031 = torch.ops.aten.add.Tensor(mul_1873, 1); mul_1873 = None + mul_1874 = torch.ops.aten.mul.Tensor(mul_1872, add_2031); mul_1872 = add_2031 = None + convert_element_type_2730 = torch.ops.prims.convert_element_type.default(mul_1874, torch.bfloat16); mul_1874 = None + permute_1266 = torch.ops.aten.permute.default(convert_element_type_2730, [1, 0]) + mm_494 = torch.ops.aten.mm.default(permute_1266, view_594); permute_1266 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_510, 64, '0'); convert_element_type_510 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + permute_1268 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_495 = torch.ops.aten.mm.default(convert_element_type_2730, permute_1268); convert_element_type_2730 = permute_1268 = None + add_2032 = torch.ops.aten.add.Tensor(mm_493, mm_495); mm_493 = mm_495 = None + convert_element_type_2735 = torch.ops.prims.convert_element_type.default(mm_494, torch.float32); mm_494 = None + reduce_scatter_tensor_242 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2735, 'avg', 64, '0'); convert_element_type_2735 = None + wait_tensor_835 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_242); reduce_scatter_tensor_242 = None + all_to_all_single_112 = torch.ops._c10d_functional.all_to_all_single.default(index_86, [_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143], [_local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135], '521'); index_86 = None + wait_tensor_836 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_112); all_to_all_single_112 = None + full_416 = torch.ops.aten.full.default([sym_size_int_33, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_33 = None + slice_scatter_17 = torch.ops.aten.slice_scatter.default(full_416, wait_tensor_836, 0, 0, -1); wait_tensor_836 = None + index_87 = torch.ops.aten.index.Tensor(slice_scatter_17, [getitem_134]); slice_scatter_17 = None + permute_1270 = torch.ops.aten.permute.default(index_87, [1, 0]) + _grouped_mm_180 = torch.ops.aten._grouped_mm.default(permute_1270, mul_427, cumsum_26); permute_1270 = mul_427 = None + convert_element_type_504 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16); primals_162 = None + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_504, 8, '513'); convert_element_type_504 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_192, [0, 2, 1]); wait_tensor_192 = None + permute_1272 = torch.ops.aten.permute.default(permute_142, [0, 2, 1]); permute_142 = None + _grouped_mm_181 = torch.ops.aten._grouped_mm.default(index_87, permute_1272, cumsum_26); index_87 = permute_1272 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(_grouped_mm_24, torch.float32); _grouped_mm_24 = None + neg_17 = torch.ops.aten.neg.default(convert_element_type_508) + exp_26 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_576 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + div_44 = torch.ops.aten.div.Tensor(convert_element_type_508, add_576) + convert_element_type_509 = torch.ops.prims.convert_element_type.default(div_44, torch.bfloat16); div_44 = None + mul_1875 = torch.ops.aten.mul.Tensor(_grouped_mm_181, convert_element_type_509); convert_element_type_509 = None + mul_1876 = torch.ops.aten.mul.Tensor(_grouped_mm_181, _grouped_mm_25); _grouped_mm_181 = _grouped_mm_25 = None + permute_1274 = torch.ops.aten.permute.default(mul_1875, [1, 0]) + _grouped_mm_182 = torch.ops.aten._grouped_mm.default(permute_1274, index_17, cumsum_26); permute_1274 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16); primals_163 = None + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 8, '513'); convert_element_type_505 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_193, [0, 2, 1]); wait_tensor_193 = None + permute_1276 = torch.ops.aten.permute.default(permute_141, [0, 2, 1]); permute_141 = None + _grouped_mm_183 = torch.ops.aten._grouped_mm.default(mul_1875, permute_1276, cumsum_26); mul_1875 = permute_1276 = None + convert_element_type_2736 = torch.ops.prims.convert_element_type.default(mul_1876, torch.float32); mul_1876 = None + reciprocal_35 = torch.ops.aten.reciprocal.default(add_576); add_576 = None + mul_1877 = torch.ops.aten.mul.Tensor(reciprocal_35, 1); reciprocal_35 = None + mul_1878 = torch.ops.aten.mul.Tensor(convert_element_type_2736, mul_1877); convert_element_type_2736 = None + sub_728 = torch.ops.aten.sub.Tensor(1, mul_1877); mul_1877 = None + mul_1879 = torch.ops.aten.mul.Tensor(convert_element_type_508, sub_728); convert_element_type_508 = sub_728 = None + add_2034 = torch.ops.aten.add.Tensor(mul_1879, 1); mul_1879 = None + mul_1880 = torch.ops.aten.mul.Tensor(mul_1878, add_2034); mul_1878 = add_2034 = None + convert_element_type_2738 = torch.ops.prims.convert_element_type.default(mul_1880, torch.bfloat16); mul_1880 = None + permute_1278 = torch.ops.aten.permute.default(convert_element_type_2738, [1, 0]) + _grouped_mm_184 = torch.ops.aten._grouped_mm.default(permute_1278, index_17, cumsum_26); permute_1278 = index_17 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16); primals_161 = None + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 8, '513'); convert_element_type_502 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_190, [0, 2, 1]); wait_tensor_190 = None + permute_1280 = torch.ops.aten.permute.default(permute_140, [0, 2, 1]); permute_140 = None + _grouped_mm_185 = torch.ops.aten._grouped_mm.default(convert_element_type_2738, permute_1280, cumsum_26); convert_element_type_2738 = permute_1280 = cumsum_26 = None + add_2035 = torch.ops.aten.add.Tensor(_grouped_mm_183, _grouped_mm_185); _grouped_mm_183 = _grouped_mm_185 = None + convert_element_type_2739 = torch.ops.prims.convert_element_type.default(_grouped_mm_182, torch.float32); _grouped_mm_182 = None + div_234 = torch.ops.aten.div.Tensor(convert_element_type_2739, 64); convert_element_type_2739 = None + reduce_scatter_tensor_243 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_234, 'sum', 8, '513'); div_234 = None + wait_tensor_837 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_243); reduce_scatter_tensor_243 = None + convert_element_type_2740 = torch.ops.prims.convert_element_type.default(_grouped_mm_180, torch.float32); _grouped_mm_180 = None + div_235 = torch.ops.aten.div.Tensor(convert_element_type_2740, 64); convert_element_type_2740 = None + reduce_scatter_tensor_244 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_235, 'sum', 8, '513'); div_235 = None + wait_tensor_838 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_244); reduce_scatter_tensor_244 = None + convert_element_type_2741 = torch.ops.prims.convert_element_type.default(_grouped_mm_184, torch.float32); _grouped_mm_184 = None + div_236 = torch.ops.aten.div.Tensor(convert_element_type_2741, 64); convert_element_type_2741 = None + reduce_scatter_tensor_245 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_236, 'sum', 8, '513'); div_236 = None + wait_tensor_839 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_245); reduce_scatter_tensor_245 = None + index_put_86 = torch.ops.aten.index_put.default(full_416, [getitem_134], add_2035, True); full_416 = getitem_134 = add_2035 = None + slice_209 = torch.ops.aten.slice.Tensor(index_put_86, 0, 0, add_2036); index_put_86 = add_2036 = None + all_to_all_single_113 = torch.ops._c10d_functional.all_to_all_single.default(slice_209, [_local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135], [_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143], '521'); slice_209 = _local_scalar_dense_128 = _local_scalar_dense_129 = _local_scalar_dense_130 = _local_scalar_dense_131 = _local_scalar_dense_132 = _local_scalar_dense_133 = _local_scalar_dense_134 = _local_scalar_dense_135 = _local_scalar_dense_136 = _local_scalar_dense_137 = _local_scalar_dense_138 = _local_scalar_dense_139 = _local_scalar_dense_140 = _local_scalar_dense_141 = _local_scalar_dense_142 = _local_scalar_dense_143 = None + wait_tensor_840 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_113); all_to_all_single_113 = None + index_put_87 = torch.ops.aten.index_put.default(full_default_52, [div_42], wait_tensor_840, True); div_42 = wait_tensor_840 = None + add_2040 = torch.ops.aten.add.Tensor(add_2032, index_put_87); add_2032 = index_put_87 = None + mul_1881 = torch.ops.aten.mul.Tensor(view_2106, 1.0); view_2106 = None + scatter_add_17 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_131, mul_1881); getitem_131 = mul_1881 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(mm_75, torch.float32); mm_75 = None + sub_192 = torch.ops.aten.sub.Tensor(convert_element_type_497, amax_8); convert_element_type_497 = amax_8 = None + exp_25 = torch.ops.aten.exp.default(sub_192); sub_192 = None + div_41 = torch.ops.aten.div.Tensor(exp_25, sum_33); exp_25 = sum_33 = None + mul_1882 = torch.ops.aten.mul.Tensor(scatter_add_17, div_41); scatter_add_17 = None + sum_243 = torch.ops.aten.sum.dim_IntList(mul_1882, [1], True) + neg_106 = torch.ops.aten.neg.default(div_41); div_41 = None + fma_17 = torch.ops.prims.fma.default(neg_106, sum_243, mul_1882); neg_106 = sum_243 = mul_1882 = None + convert_element_type_2742 = torch.ops.prims.convert_element_type.default(fma_17, torch.bfloat16); fma_17 = None + permute_1282 = torch.ops.aten.permute.default(convert_element_type_2742, [1, 0]) + mm_496 = torch.ops.aten.mm.default(permute_1282, view_594); permute_1282 = view_594 = None + convert_element_type_494 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_494, 64, '0'); convert_element_type_494 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + permute_1284 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_497 = torch.ops.aten.mm.default(convert_element_type_2742, permute_1284); convert_element_type_2742 = permute_1284 = None + add_2041 = torch.ops.aten.add.Tensor(add_2040, mm_497); add_2040 = mm_497 = None + convert_element_type_2747 = torch.ops.prims.convert_element_type.default(mm_496, torch.float32); mm_496 = None + reduce_scatter_tensor_246 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2747, 'avg', 64, '0'); convert_element_type_2747 = None + wait_tensor_841 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_246); reduce_scatter_tensor_246 = None + view_2108 = torch.ops.aten.view.default(add_2041, [2, 4096, 2048]); add_2041 = None + convert_element_type_2748 = torch.ops.prims.convert_element_type.default(view_2108, torch.float32); view_2108 = None + convert_element_type_491 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_491, 64, '0'); convert_element_type_491 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + convert_element_type_2750 = torch.ops.prims.convert_element_type.default(wait_tensor_185, torch.float32); wait_tensor_185 = None + mul_1883 = torch.ops.aten.mul.Tensor(convert_element_type_2748, convert_element_type_2750); convert_element_type_2750 = None + convert_element_type_492 = torch.ops.prims.convert_element_type.default(add_552, torch.float32); add_552 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_492, rsqrt_29); convert_element_type_492 = None + mul_1885 = torch.ops.aten.mul.Tensor(mul_407, mul_1883) + sum_244 = torch.ops.aten.sum.dim_IntList(mul_1885, [2], True); mul_1885 = None + div_237 = torch.ops.aten.div.Tensor(mul_407, 2048) + mul_1886 = torch.ops.aten.mul.Tensor(div_237, sum_244); div_237 = sum_244 = None + sub_730 = torch.ops.aten.sub.Tensor(mul_1883, mul_1886); mul_1883 = mul_1886 = None + mul_1887 = torch.ops.aten.mul.Tensor(sub_730, rsqrt_29); sub_730 = rsqrt_29 = None + mul_1888 = torch.ops.aten.mul.Tensor(convert_element_type_2748, mul_407); convert_element_type_2748 = mul_407 = None + sum_245 = torch.ops.aten.sum.dim_IntList(mul_1888, [0, 1]); mul_1888 = None + convert_element_type_2751 = torch.ops.prims.convert_element_type.default(mul_1887, torch.bfloat16); mul_1887 = None + add_2042 = torch.ops.aten.add.Tensor(add_2029, convert_element_type_2751); add_2029 = convert_element_type_2751 = None + convert_element_type_default_30 = torch.ops.prims.convert_element_type.default(sum_245, torch.float32); sum_245 = None + reduce_scatter_tensor_247 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_30, 'avg', 64, '0'); convert_element_type_default_30 = None + wait_tensor_842 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_247); reduce_scatter_tensor_247 = None + view_2109 = torch.ops.aten.view.default(add_2042, [8192, 2048]) + permute_1286 = torch.ops.aten.permute.default(view_2109, [1, 0]) + permute_137 = torch.ops.aten.permute.default(getitem_127, [0, 2, 1, 3]) + view_589 = torch.ops.aten.view.default(permute_137, [2, 4096, -1]); permute_137 = None + view_591 = torch.ops.aten.view.default(view_589, [8192, 2048]); view_589 = None + mm_498 = torch.ops.aten.mm.default(permute_1286, view_591); permute_1286 = view_591 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_488, 64, '0'); convert_element_type_488 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_138 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + permute_1288 = torch.ops.aten.permute.default(permute_138, [1, 0]); permute_138 = None + mm_499 = torch.ops.aten.mm.default(view_2109, permute_1288); view_2109 = permute_1288 = None + view_2110 = torch.ops.aten.view.default(mm_499, [2, 4096, 2048]); mm_499 = None + convert_element_type_2758 = torch.ops.prims.convert_element_type.default(mm_498, torch.float32); mm_498 = None + reduce_scatter_tensor_248 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2758, 'avg', 64, '0'); convert_element_type_2758 = None + wait_tensor_843 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_248); reduce_scatter_tensor_248 = None + view_2111 = torch.ops.aten.view.default(view_2110, [2, 4096, 16, 128]); view_2110 = None + permute_1290 = torch.ops.aten.permute.default(view_2111, [0, 2, 1, 3]); view_2111 = None + fw_graph17 = self.fw_graph17 + joint_graph17 = self.joint_graph17 + mask_graph17 = self.mask_graph17 + flex_attention_backward_17 = torch.ops.higher_order.flex_attention_backward(permute_134, permute_135, permute_136, getitem_127, getitem_128, permute_1290, None, fw_graph17, joint_graph17, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph17), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_134 = permute_135 = permute_136 = getitem_127 = getitem_128 = permute_1290 = fw_graph17 = joint_graph17 = mask_graph17 = None + getitem_441 = flex_attention_backward_17[0] + getitem_442 = flex_attention_backward_17[1] + getitem_443 = flex_attention_backward_17[2]; flex_attention_backward_17 = None + permute_1291 = torch.ops.aten.permute.default(getitem_443, [0, 2, 1, 3]); getitem_443 = None + permute_1292 = torch.ops.aten.permute.default(getitem_442, [0, 2, 1, 3]); getitem_442 = None + permute_1293 = torch.ops.aten.permute.default(getitem_441, [0, 2, 1, 3]); getitem_441 = None + slice_211 = torch.ops.aten.slice.Tensor(permute_1292, 3, 0, 128) + slice_212 = torch.ops.aten.slice.Tensor(permute_1292, 3, 128, 192); permute_1292 = None + sum_246 = torch.ops.aten.sum.dim_IntList(slice_212, [2], True); slice_212 = None + cat_131 = torch.ops.aten.cat.default([slice_211, permute_1291], 3); slice_211 = permute_1291 = None + view_2112 = torch.ops.aten.view.default(cat_131, [2, 4096, 4096]); cat_131 = None + view_2113 = torch.ops.aten.view.default(view_2112, [8192, 4096]); view_2112 = None + permute_1294 = torch.ops.aten.permute.default(view_2113, [1, 0]) + mm_500 = torch.ops.aten.mm.default(permute_1294, view_586); permute_1294 = view_586 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 64, '0'); convert_element_type_485 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + permute_1296 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_501 = torch.ops.aten.mm.default(view_2113, permute_1296); view_2113 = permute_1296 = None + view_2114 = torch.ops.aten.view.default(mm_501, [2, 4096, 512]); mm_501 = None + convert_element_type_2763 = torch.ops.prims.convert_element_type.default(mm_500, torch.float32); mm_500 = None + reduce_scatter_tensor_249 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2763, 'avg', 64, '0'); convert_element_type_2763 = None + wait_tensor_844 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_249); reduce_scatter_tensor_249 = None + convert_element_type_2764 = torch.ops.prims.convert_element_type.default(view_2114, torch.float32); view_2114 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 64, '0'); convert_element_type_482 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_2766 = torch.ops.prims.convert_element_type.default(wait_tensor_182, torch.float32); wait_tensor_182 = None + mul_1889 = torch.ops.aten.mul.Tensor(convert_element_type_2764, convert_element_type_2766); convert_element_type_2766 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(getitem_123, torch.float32); getitem_123 = None + mul_405 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_28); convert_element_type_483 = None + mul_1891 = torch.ops.aten.mul.Tensor(mul_405, mul_1889) + sum_247 = torch.ops.aten.sum.dim_IntList(mul_1891, [2], True); mul_1891 = None + div_238 = torch.ops.aten.div.Tensor(mul_405, 512) + mul_1892 = torch.ops.aten.mul.Tensor(div_238, sum_247); div_238 = sum_247 = None + sub_731 = torch.ops.aten.sub.Tensor(mul_1889, mul_1892); mul_1889 = mul_1892 = None + mul_1893 = torch.ops.aten.mul.Tensor(sub_731, rsqrt_28); sub_731 = rsqrt_28 = None + mul_1894 = torch.ops.aten.mul.Tensor(convert_element_type_2764, mul_405); convert_element_type_2764 = mul_405 = None + sum_248 = torch.ops.aten.sum.dim_IntList(mul_1894, [0, 1]); mul_1894 = None + convert_element_type_2767 = torch.ops.prims.convert_element_type.default(mul_1893, torch.bfloat16); mul_1893 = None + convert_element_type_default_29 = torch.ops.prims.convert_element_type.default(sum_248, torch.float32); sum_248 = None + reduce_scatter_tensor_250 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_29, 'avg', 64, '0'); convert_element_type_default_29 = None + wait_tensor_845 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_250); reduce_scatter_tensor_250 = None + convert_element_type_2770 = torch.ops.prims.convert_element_type.default(sum_246, torch.float32); sum_246 = None + view_2115 = torch.ops.aten.view.default(convert_element_type_2770, [2, 4096, 1, 32, 2]); convert_element_type_2770 = None + view_as_complex_88 = torch.ops.aten.view_as_complex.default(view_2115); view_2115 = None + mul_1895 = torch.ops.aten.mul.Tensor(view_as_complex_88, clone_9); view_as_complex_88 = None + view_as_real_88 = torch.ops.aten.view_as_real.default(mul_1895); mul_1895 = None + view_2116 = torch.ops.aten.view.default(view_as_real_88, [2, 4096, 1, 64]); view_as_real_88 = None + convert_element_type_2771 = torch.ops.prims.convert_element_type.default(view_2116, torch.bfloat16); view_2116 = None + squeeze_43 = torch.ops.aten.squeeze.dim(convert_element_type_2771, 2); convert_element_type_2771 = None + cat_132 = torch.ops.aten.cat.default([convert_element_type_2767, squeeze_43], 2); convert_element_type_2767 = squeeze_43 = None + view_2117 = torch.ops.aten.view.default(cat_132, [8192, 576]); cat_132 = None + permute_1298 = torch.ops.aten.permute.default(view_2117, [1, 0]) + mm_502 = torch.ops.aten.mm.default(permute_1298, view_572); permute_1298 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16); primals_153 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_477, 64, '0'); convert_element_type_477 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + permute_1300 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_503 = torch.ops.aten.mm.default(view_2117, permute_1300); view_2117 = permute_1300 = None + view_2118 = torch.ops.aten.view.default(mm_503, [2, 4096, 2048]); mm_503 = None + convert_element_type_2776 = torch.ops.prims.convert_element_type.default(mm_502, torch.float32); mm_502 = None + reduce_scatter_tensor_251 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2776, 'avg', 64, '0'); convert_element_type_2776 = None + wait_tensor_846 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_251); reduce_scatter_tensor_251 = None + slice_213 = torch.ops.aten.slice.Tensor(permute_1293, 3, 0, 128) + slice_214 = torch.ops.aten.slice.Tensor(permute_1293, 3, 128, 192); permute_1293 = None + convert_element_type_2777 = torch.ops.prims.convert_element_type.default(slice_214, torch.float32); slice_214 = None + view_2119 = torch.ops.aten.view.default(convert_element_type_2777, [2, 4096, 16, 32, 2]); convert_element_type_2777 = None + view_as_complex_89 = torch.ops.aten.view_as_complex.default(view_2119); view_2119 = None + mul_1896 = torch.ops.aten.mul.Tensor(view_as_complex_89, clone_9); view_as_complex_89 = None + view_as_real_89 = torch.ops.aten.view_as_real.default(mul_1896); mul_1896 = None + view_2120 = torch.ops.aten.view.default(view_as_real_89, [2, 4096, 16, 64]); view_as_real_89 = None + convert_element_type_2778 = torch.ops.prims.convert_element_type.default(view_2120, torch.bfloat16); view_2120 = None + cat_133 = torch.ops.aten.cat.default([slice_213, convert_element_type_2778], 3); slice_213 = convert_element_type_2778 = None + view_2121 = torch.ops.aten.view.default(cat_133, [2, 4096, 3072]); cat_133 = None + view_2122 = torch.ops.aten.view.default(view_2121, [8192, 3072]); view_2121 = None + permute_1302 = torch.ops.aten.permute.default(view_2122, [1, 0]) + mm_504 = torch.ops.aten.mm.default(permute_1302, view_572); permute_1302 = view_572 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16); primals_152 = None + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 64, '0'); convert_element_type_472 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_1304 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_505 = torch.ops.aten.mm.default(view_2122, permute_1304); view_2122 = permute_1304 = None + view_2123 = torch.ops.aten.view.default(mm_505, [2, 4096, 2048]); mm_505 = None + add_2043 = torch.ops.aten.add.Tensor(view_2118, view_2123); view_2118 = view_2123 = None + convert_element_type_2783 = torch.ops.prims.convert_element_type.default(mm_504, torch.float32); mm_504 = None + reduce_scatter_tensor_252 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2783, 'avg', 64, '0'); convert_element_type_2783 = None + wait_tensor_847 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_252); reduce_scatter_tensor_252 = None + convert_element_type_2784 = torch.ops.prims.convert_element_type.default(add_2043, torch.float32); add_2043 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16); primals_151 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 64, '0'); convert_element_type_469 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + convert_element_type_2786 = torch.ops.prims.convert_element_type.default(wait_tensor_179, torch.float32); wait_tensor_179 = None + mul_1897 = torch.ops.aten.mul.Tensor(convert_element_type_2784, convert_element_type_2786); convert_element_type_2786 = None + convert_element_type_470 = torch.ops.prims.convert_element_type.default(add_549, torch.float32); add_549 = None + mul_401 = torch.ops.aten.mul.Tensor(convert_element_type_470, rsqrt_27); convert_element_type_470 = None + mul_1899 = torch.ops.aten.mul.Tensor(mul_401, mul_1897) + sum_249 = torch.ops.aten.sum.dim_IntList(mul_1899, [2], True); mul_1899 = None + div_239 = torch.ops.aten.div.Tensor(mul_401, 2048) + mul_1900 = torch.ops.aten.mul.Tensor(div_239, sum_249); div_239 = sum_249 = None + sub_732 = torch.ops.aten.sub.Tensor(mul_1897, mul_1900); mul_1897 = mul_1900 = None + mul_1901 = torch.ops.aten.mul.Tensor(sub_732, rsqrt_27); sub_732 = rsqrt_27 = None + mul_1902 = torch.ops.aten.mul.Tensor(convert_element_type_2784, mul_401); convert_element_type_2784 = mul_401 = None + sum_250 = torch.ops.aten.sum.dim_IntList(mul_1902, [0, 1]); mul_1902 = None + convert_element_type_2787 = torch.ops.prims.convert_element_type.default(mul_1901, torch.bfloat16); mul_1901 = None + add_2044 = torch.ops.aten.add.Tensor(add_2042, convert_element_type_2787); add_2042 = convert_element_type_2787 = None + convert_element_type_default_28 = torch.ops.prims.convert_element_type.default(sum_250, torch.float32); sum_250 = None + reduce_scatter_tensor_253 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_28, 'avg', 64, '0'); convert_element_type_default_28 = None + wait_tensor_848 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_253); reduce_scatter_tensor_253 = None + view_2124 = torch.ops.aten.view.default(add_2044, [8192, 2048]) + unsqueeze_71 = torch.ops.aten.unsqueeze.default(view_2124, 1) + convert_element_type_2790 = torch.ops.prims.convert_element_type.default(unsqueeze_71, torch.float32); unsqueeze_71 = None + bmm_62 = torch.ops.aten.bmm.default(permute_1306, convert_element_type_2790); permute_1306 = None + bmm_63 = torch.ops.aten.bmm.default(convert_element_type_2790, permute_1307); convert_element_type_2790 = permute_1307 = None + convert_element_type_2791 = torch.ops.prims.convert_element_type.default(bmm_62, torch.bfloat16); bmm_62 = None + view_2125 = torch.ops.aten.view.default(bmm_63, [8192, 6]); bmm_63 = None + view_2126 = torch.ops.aten.view.default(convert_element_type_2791, [49152, 2048]); convert_element_type_2791 = None + index_88 = torch.ops.aten.index.Tensor(view_2126, [getitem_119]); view_2126 = getitem_119 = None + permute_1308 = torch.ops.aten.permute.default(view_2124, [1, 0]) + mm_506 = torch.ops.aten.mm.default(permute_1308, mul_398); permute_1308 = mul_398 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16); primals_150 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_464, 64, '0'); convert_element_type_464 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + permute_1310 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_507 = torch.ops.aten.mm.default(view_2124, permute_1310); view_2124 = permute_1310 = None + convert_element_type_2796 = torch.ops.prims.convert_element_type.default(mm_506, torch.float32); mm_506 = None + reduce_scatter_tensor_254 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2796, 'avg', 64, '0'); convert_element_type_2796 = None + wait_tensor_849 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_254); reduce_scatter_tensor_254 = None + convert_element_type_459 = torch.ops.prims.convert_element_type.default(mm_68, torch.float32); mm_68 = None + neg_16 = torch.ops.aten.neg.default(convert_element_type_459) + exp_24 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_544 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + div_40 = torch.ops.aten.div.Tensor(convert_element_type_459, add_544) + convert_element_type_460 = torch.ops.prims.convert_element_type.default(div_40, torch.bfloat16); div_40 = None + mul_1903 = torch.ops.aten.mul.Tensor(mm_507, convert_element_type_460); convert_element_type_460 = None + mul_1904 = torch.ops.aten.mul.Tensor(mm_507, mm_69); mm_507 = mm_69 = None + permute_1312 = torch.ops.aten.permute.default(mul_1903, [1, 0]) + mm_508 = torch.ops.aten.mm.default(permute_1312, view_527); permute_1312 = None + convert_element_type_461 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); primals_149 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_461, 64, '0'); convert_element_type_461 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_177, [1, 0]); wait_tensor_177 = None + permute_1314 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_509 = torch.ops.aten.mm.default(mul_1903, permute_1314); mul_1903 = permute_1314 = None + convert_element_type_2801 = torch.ops.prims.convert_element_type.default(mm_508, torch.float32); mm_508 = None + reduce_scatter_tensor_255 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2801, 'avg', 64, '0'); convert_element_type_2801 = None + wait_tensor_850 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_255); reduce_scatter_tensor_255 = None + convert_element_type_2802 = torch.ops.prims.convert_element_type.default(mul_1904, torch.float32); mul_1904 = None + reciprocal_36 = torch.ops.aten.reciprocal.default(add_544); add_544 = None + mul_1905 = torch.ops.aten.mul.Tensor(reciprocal_36, 1); reciprocal_36 = None + mul_1906 = torch.ops.aten.mul.Tensor(convert_element_type_2802, mul_1905); convert_element_type_2802 = None + sub_733 = torch.ops.aten.sub.Tensor(1, mul_1905); mul_1905 = None + mul_1907 = torch.ops.aten.mul.Tensor(convert_element_type_459, sub_733); convert_element_type_459 = sub_733 = None + add_2046 = torch.ops.aten.add.Tensor(mul_1907, 1); mul_1907 = None + mul_1908 = torch.ops.aten.mul.Tensor(mul_1906, add_2046); mul_1906 = add_2046 = None + convert_element_type_2804 = torch.ops.prims.convert_element_type.default(mul_1908, torch.bfloat16); mul_1908 = None + permute_1316 = torch.ops.aten.permute.default(convert_element_type_2804, [1, 0]) + mm_510 = torch.ops.aten.mm.default(permute_1316, view_527); permute_1316 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_456, 64, '0'); convert_element_type_456 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + permute_1318 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_511 = torch.ops.aten.mm.default(convert_element_type_2804, permute_1318); convert_element_type_2804 = permute_1318 = None + add_2047 = torch.ops.aten.add.Tensor(mm_509, mm_511); mm_509 = mm_511 = None + convert_element_type_2809 = torch.ops.prims.convert_element_type.default(mm_510, torch.float32); mm_510 = None + reduce_scatter_tensor_256 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2809, 'avg', 64, '0'); convert_element_type_2809 = None + wait_tensor_851 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_256); reduce_scatter_tensor_256 = None + all_to_all_single_114 = torch.ops._c10d_functional.all_to_all_single.default(index_88, [_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127], [_local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119], '521'); index_88 = None + wait_tensor_852 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_114); all_to_all_single_114 = None + full_420 = torch.ops.aten.full.default([sym_size_int_29, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_29 = None + slice_scatter_18 = torch.ops.aten.slice_scatter.default(full_420, wait_tensor_852, 0, 0, -1); wait_tensor_852 = None + index_89 = torch.ops.aten.index.Tensor(slice_scatter_18, [getitem_120]); slice_scatter_18 = None + permute_1320 = torch.ops.aten.permute.default(index_89, [1, 0]) + _grouped_mm_186 = torch.ops.aten._grouped_mm.default(permute_1320, mul_378, cumsum_23); permute_1320 = mul_378 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16); primals_146 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_450, 8, '513'); convert_element_type_450 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + permute_127 = torch.ops.aten.permute.default(wait_tensor_171, [0, 2, 1]); wait_tensor_171 = None + permute_1322 = torch.ops.aten.permute.default(permute_127, [0, 2, 1]); permute_127 = None + _grouped_mm_187 = torch.ops.aten._grouped_mm.default(index_89, permute_1322, cumsum_23); index_89 = permute_1322 = None + convert_element_type_454 = torch.ops.prims.convert_element_type.default(_grouped_mm_21, torch.float32); _grouped_mm_21 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_454) + exp_23 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_508 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + div_39 = torch.ops.aten.div.Tensor(convert_element_type_454, add_508) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(div_39, torch.bfloat16); div_39 = None + mul_1909 = torch.ops.aten.mul.Tensor(_grouped_mm_187, convert_element_type_455); convert_element_type_455 = None + mul_1910 = torch.ops.aten.mul.Tensor(_grouped_mm_187, _grouped_mm_22); _grouped_mm_187 = _grouped_mm_22 = None + permute_1324 = torch.ops.aten.permute.default(mul_1909, [1, 0]) + _grouped_mm_188 = torch.ops.aten._grouped_mm.default(permute_1324, index_15, cumsum_23); permute_1324 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16); primals_147 = None + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '513'); convert_element_type_451 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + permute_126 = torch.ops.aten.permute.default(wait_tensor_172, [0, 2, 1]); wait_tensor_172 = None + permute_1326 = torch.ops.aten.permute.default(permute_126, [0, 2, 1]); permute_126 = None + _grouped_mm_189 = torch.ops.aten._grouped_mm.default(mul_1909, permute_1326, cumsum_23); mul_1909 = permute_1326 = None + convert_element_type_2810 = torch.ops.prims.convert_element_type.default(mul_1910, torch.float32); mul_1910 = None + reciprocal_37 = torch.ops.aten.reciprocal.default(add_508); add_508 = None + mul_1911 = torch.ops.aten.mul.Tensor(reciprocal_37, 1); reciprocal_37 = None + mul_1912 = torch.ops.aten.mul.Tensor(convert_element_type_2810, mul_1911); convert_element_type_2810 = None + sub_734 = torch.ops.aten.sub.Tensor(1, mul_1911); mul_1911 = None + mul_1913 = torch.ops.aten.mul.Tensor(convert_element_type_454, sub_734); convert_element_type_454 = sub_734 = None + add_2049 = torch.ops.aten.add.Tensor(mul_1913, 1); mul_1913 = None + mul_1914 = torch.ops.aten.mul.Tensor(mul_1912, add_2049); mul_1912 = add_2049 = None + convert_element_type_2812 = torch.ops.prims.convert_element_type.default(mul_1914, torch.bfloat16); mul_1914 = None + permute_1328 = torch.ops.aten.permute.default(convert_element_type_2812, [1, 0]) + _grouped_mm_190 = torch.ops.aten._grouped_mm.default(permute_1328, index_15, cumsum_23); permute_1328 = index_15 = None + convert_element_type_448 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16); primals_145 = None + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_448, 8, '513'); convert_element_type_448 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_125 = torch.ops.aten.permute.default(wait_tensor_169, [0, 2, 1]); wait_tensor_169 = None + permute_1330 = torch.ops.aten.permute.default(permute_125, [0, 2, 1]); permute_125 = None + _grouped_mm_191 = torch.ops.aten._grouped_mm.default(convert_element_type_2812, permute_1330, cumsum_23); convert_element_type_2812 = permute_1330 = cumsum_23 = None + add_2050 = torch.ops.aten.add.Tensor(_grouped_mm_189, _grouped_mm_191); _grouped_mm_189 = _grouped_mm_191 = None + convert_element_type_2813 = torch.ops.prims.convert_element_type.default(_grouped_mm_188, torch.float32); _grouped_mm_188 = None + div_240 = torch.ops.aten.div.Tensor(convert_element_type_2813, 64); convert_element_type_2813 = None + reduce_scatter_tensor_257 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_240, 'sum', 8, '513'); div_240 = None + wait_tensor_853 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_257); reduce_scatter_tensor_257 = None + convert_element_type_2814 = torch.ops.prims.convert_element_type.default(_grouped_mm_186, torch.float32); _grouped_mm_186 = None + div_241 = torch.ops.aten.div.Tensor(convert_element_type_2814, 64); convert_element_type_2814 = None + reduce_scatter_tensor_258 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_241, 'sum', 8, '513'); div_241 = None + wait_tensor_854 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_258); reduce_scatter_tensor_258 = None + convert_element_type_2815 = torch.ops.prims.convert_element_type.default(_grouped_mm_190, torch.float32); _grouped_mm_190 = None + div_242 = torch.ops.aten.div.Tensor(convert_element_type_2815, 64); convert_element_type_2815 = None + reduce_scatter_tensor_259 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_242, 'sum', 8, '513'); div_242 = None + wait_tensor_855 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_259); reduce_scatter_tensor_259 = None + index_put_88 = torch.ops.aten.index_put.default(full_420, [getitem_120], add_2050, True); full_420 = getitem_120 = add_2050 = None + slice_215 = torch.ops.aten.slice.Tensor(index_put_88, 0, 0, add_2051); index_put_88 = add_2051 = None + all_to_all_single_115 = torch.ops._c10d_functional.all_to_all_single.default(slice_215, [_local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119], [_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127], '521'); slice_215 = _local_scalar_dense_112 = _local_scalar_dense_113 = _local_scalar_dense_114 = _local_scalar_dense_115 = _local_scalar_dense_116 = _local_scalar_dense_117 = _local_scalar_dense_118 = _local_scalar_dense_119 = _local_scalar_dense_120 = _local_scalar_dense_121 = _local_scalar_dense_122 = _local_scalar_dense_123 = _local_scalar_dense_124 = _local_scalar_dense_125 = _local_scalar_dense_126 = _local_scalar_dense_127 = None + wait_tensor_856 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_115); all_to_all_single_115 = None + index_put_89 = torch.ops.aten.index_put.default(full_default_52, [div_37], wait_tensor_856, True); div_37 = wait_tensor_856 = None + add_2055 = torch.ops.aten.add.Tensor(add_2047, index_put_89); add_2047 = index_put_89 = None + mul_1915 = torch.ops.aten.mul.Tensor(view_2125, 1.0); view_2125 = None + scatter_add_18 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_117, mul_1915); getitem_117 = mul_1915 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(mm_67, torch.float32); mm_67 = None + sub_168 = torch.ops.aten.sub.Tensor(convert_element_type_443, amax_7); convert_element_type_443 = amax_7 = None + exp_22 = torch.ops.aten.exp.default(sub_168); sub_168 = None + div_36 = torch.ops.aten.div.Tensor(exp_22, sum_29); exp_22 = sum_29 = None + mul_1916 = torch.ops.aten.mul.Tensor(scatter_add_18, div_36); scatter_add_18 = None + sum_251 = torch.ops.aten.sum.dim_IntList(mul_1916, [1], True) + neg_109 = torch.ops.aten.neg.default(div_36); div_36 = None + fma_18 = torch.ops.prims.fma.default(neg_109, sum_251, mul_1916); neg_109 = sum_251 = mul_1916 = None + convert_element_type_2816 = torch.ops.prims.convert_element_type.default(fma_18, torch.bfloat16); fma_18 = None + permute_1332 = torch.ops.aten.permute.default(convert_element_type_2816, [1, 0]) + mm_512 = torch.ops.aten.mm.default(permute_1332, view_527); permute_1332 = view_527 = None + convert_element_type_440 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_440, 64, '0'); convert_element_type_440 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_124 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + permute_1334 = torch.ops.aten.permute.default(permute_124, [1, 0]); permute_124 = None + mm_513 = torch.ops.aten.mm.default(convert_element_type_2816, permute_1334); convert_element_type_2816 = permute_1334 = None + add_2056 = torch.ops.aten.add.Tensor(add_2055, mm_513); add_2055 = mm_513 = None + convert_element_type_2821 = torch.ops.prims.convert_element_type.default(mm_512, torch.float32); mm_512 = None + reduce_scatter_tensor_260 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2821, 'avg', 64, '0'); convert_element_type_2821 = None + wait_tensor_857 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_260); reduce_scatter_tensor_260 = None + view_2127 = torch.ops.aten.view.default(add_2056, [2, 4096, 2048]); add_2056 = None + convert_element_type_2822 = torch.ops.prims.convert_element_type.default(view_2127, torch.float32); view_2127 = None + convert_element_type_437 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_437, 64, '0'); convert_element_type_437 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_2824 = torch.ops.prims.convert_element_type.default(wait_tensor_164, torch.float32); wait_tensor_164 = None + mul_1917 = torch.ops.aten.mul.Tensor(convert_element_type_2822, convert_element_type_2824); convert_element_type_2824 = None + convert_element_type_438 = torch.ops.prims.convert_element_type.default(add_484, torch.float32); add_484 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_438, rsqrt_26); convert_element_type_438 = None + mul_1919 = torch.ops.aten.mul.Tensor(mul_358, mul_1917) + sum_252 = torch.ops.aten.sum.dim_IntList(mul_1919, [2], True); mul_1919 = None + div_243 = torch.ops.aten.div.Tensor(mul_358, 2048) + mul_1920 = torch.ops.aten.mul.Tensor(div_243, sum_252); div_243 = sum_252 = None + sub_736 = torch.ops.aten.sub.Tensor(mul_1917, mul_1920); mul_1917 = mul_1920 = None + mul_1921 = torch.ops.aten.mul.Tensor(sub_736, rsqrt_26); sub_736 = rsqrt_26 = None + mul_1922 = torch.ops.aten.mul.Tensor(convert_element_type_2822, mul_358); convert_element_type_2822 = mul_358 = None + sum_253 = torch.ops.aten.sum.dim_IntList(mul_1922, [0, 1]); mul_1922 = None + convert_element_type_2825 = torch.ops.prims.convert_element_type.default(mul_1921, torch.bfloat16); mul_1921 = None + add_2057 = torch.ops.aten.add.Tensor(add_2044, convert_element_type_2825); add_2044 = convert_element_type_2825 = None + convert_element_type_default_27 = torch.ops.prims.convert_element_type.default(sum_253, torch.float32); sum_253 = None + reduce_scatter_tensor_261 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_27, 'avg', 64, '0'); convert_element_type_default_27 = None + wait_tensor_858 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_261); reduce_scatter_tensor_261 = None + view_2128 = torch.ops.aten.view.default(add_2057, [8192, 2048]) + permute_1336 = torch.ops.aten.permute.default(view_2128, [1, 0]) + permute_122 = torch.ops.aten.permute.default(getitem_113, [0, 2, 1, 3]) + view_522 = torch.ops.aten.view.default(permute_122, [2, 4096, -1]); permute_122 = None + view_524 = torch.ops.aten.view.default(view_522, [8192, 2048]); view_522 = None + mm_514 = torch.ops.aten.mm.default(permute_1336, view_524); permute_1336 = view_524 = None + convert_element_type_434 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_434, 64, '0'); convert_element_type_434 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + permute_1338 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_515 = torch.ops.aten.mm.default(view_2128, permute_1338); view_2128 = permute_1338 = None + view_2129 = torch.ops.aten.view.default(mm_515, [2, 4096, 2048]); mm_515 = None + convert_element_type_2832 = torch.ops.prims.convert_element_type.default(mm_514, torch.float32); mm_514 = None + reduce_scatter_tensor_262 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2832, 'avg', 64, '0'); convert_element_type_2832 = None + wait_tensor_859 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_262); reduce_scatter_tensor_262 = None + view_2130 = torch.ops.aten.view.default(view_2129, [2, 4096, 16, 128]); view_2129 = None + permute_1340 = torch.ops.aten.permute.default(view_2130, [0, 2, 1, 3]); view_2130 = None + fw_graph18 = self.fw_graph18 + joint_graph18 = self.joint_graph18 + mask_graph18 = self.mask_graph18 + flex_attention_backward_18 = torch.ops.higher_order.flex_attention_backward(permute_119, permute_120, permute_121, getitem_113, getitem_114, permute_1340, None, fw_graph18, joint_graph18, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph18), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_119 = permute_120 = permute_121 = getitem_113 = getitem_114 = permute_1340 = fw_graph18 = joint_graph18 = mask_graph18 = None + getitem_445 = flex_attention_backward_18[0] + getitem_446 = flex_attention_backward_18[1] + getitem_447 = flex_attention_backward_18[2]; flex_attention_backward_18 = None + permute_1341 = torch.ops.aten.permute.default(getitem_447, [0, 2, 1, 3]); getitem_447 = None + permute_1342 = torch.ops.aten.permute.default(getitem_446, [0, 2, 1, 3]); getitem_446 = None + permute_1343 = torch.ops.aten.permute.default(getitem_445, [0, 2, 1, 3]); getitem_445 = None + slice_217 = torch.ops.aten.slice.Tensor(permute_1342, 3, 0, 128) + slice_218 = torch.ops.aten.slice.Tensor(permute_1342, 3, 128, 192); permute_1342 = None + sum_254 = torch.ops.aten.sum.dim_IntList(slice_218, [2], True); slice_218 = None + cat_134 = torch.ops.aten.cat.default([slice_217, permute_1341], 3); slice_217 = permute_1341 = None + view_2131 = torch.ops.aten.view.default(cat_134, [2, 4096, 4096]); cat_134 = None + view_2132 = torch.ops.aten.view.default(view_2131, [8192, 4096]); view_2131 = None + permute_1344 = torch.ops.aten.permute.default(view_2132, [1, 0]) + mm_516 = torch.ops.aten.mm.default(permute_1344, view_519); permute_1344 = view_519 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_431, 64, '0'); convert_element_type_431 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_1346 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_517 = torch.ops.aten.mm.default(view_2132, permute_1346); view_2132 = permute_1346 = None + view_2133 = torch.ops.aten.view.default(mm_517, [2, 4096, 512]); mm_517 = None + convert_element_type_2837 = torch.ops.prims.convert_element_type.default(mm_516, torch.float32); mm_516 = None + reduce_scatter_tensor_263 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2837, 'avg', 64, '0'); convert_element_type_2837 = None + wait_tensor_860 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_263); reduce_scatter_tensor_263 = None + convert_element_type_2838 = torch.ops.prims.convert_element_type.default(view_2133, torch.float32); view_2133 = None + convert_element_type_428 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_428, 64, '0'); convert_element_type_428 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_2840 = torch.ops.prims.convert_element_type.default(wait_tensor_161, torch.float32); wait_tensor_161 = None + mul_1923 = torch.ops.aten.mul.Tensor(convert_element_type_2838, convert_element_type_2840); convert_element_type_2840 = None + convert_element_type_429 = torch.ops.prims.convert_element_type.default(getitem_109, torch.float32); getitem_109 = None + mul_356 = torch.ops.aten.mul.Tensor(convert_element_type_429, rsqrt_25); convert_element_type_429 = None + mul_1925 = torch.ops.aten.mul.Tensor(mul_356, mul_1923) + sum_255 = torch.ops.aten.sum.dim_IntList(mul_1925, [2], True); mul_1925 = None + div_244 = torch.ops.aten.div.Tensor(mul_356, 512) + mul_1926 = torch.ops.aten.mul.Tensor(div_244, sum_255); div_244 = sum_255 = None + sub_737 = torch.ops.aten.sub.Tensor(mul_1923, mul_1926); mul_1923 = mul_1926 = None + mul_1927 = torch.ops.aten.mul.Tensor(sub_737, rsqrt_25); sub_737 = rsqrt_25 = None + mul_1928 = torch.ops.aten.mul.Tensor(convert_element_type_2838, mul_356); convert_element_type_2838 = mul_356 = None + sum_256 = torch.ops.aten.sum.dim_IntList(mul_1928, [0, 1]); mul_1928 = None + convert_element_type_2841 = torch.ops.prims.convert_element_type.default(mul_1927, torch.bfloat16); mul_1927 = None + convert_element_type_default_26 = torch.ops.prims.convert_element_type.default(sum_256, torch.float32); sum_256 = None + reduce_scatter_tensor_264 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_26, 'avg', 64, '0'); convert_element_type_default_26 = None + wait_tensor_861 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_264); reduce_scatter_tensor_264 = None + convert_element_type_2844 = torch.ops.prims.convert_element_type.default(sum_254, torch.float32); sum_254 = None + view_2134 = torch.ops.aten.view.default(convert_element_type_2844, [2, 4096, 1, 32, 2]); convert_element_type_2844 = None + view_as_complex_90 = torch.ops.aten.view_as_complex.default(view_2134); view_2134 = None + mul_1929 = torch.ops.aten.mul.Tensor(view_as_complex_90, clone_9); view_as_complex_90 = None + view_as_real_90 = torch.ops.aten.view_as_real.default(mul_1929); mul_1929 = None + view_2135 = torch.ops.aten.view.default(view_as_real_90, [2, 4096, 1, 64]); view_as_real_90 = None + convert_element_type_2845 = torch.ops.prims.convert_element_type.default(view_2135, torch.bfloat16); view_2135 = None + squeeze_44 = torch.ops.aten.squeeze.dim(convert_element_type_2845, 2); convert_element_type_2845 = None + cat_135 = torch.ops.aten.cat.default([convert_element_type_2841, squeeze_44], 2); convert_element_type_2841 = squeeze_44 = None + view_2136 = torch.ops.aten.view.default(cat_135, [8192, 576]); cat_135 = None + permute_1348 = torch.ops.aten.permute.default(view_2136, [1, 0]) + mm_518 = torch.ops.aten.mm.default(permute_1348, view_505); permute_1348 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_423, 64, '0'); convert_element_type_423 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + permute_1350 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_519 = torch.ops.aten.mm.default(view_2136, permute_1350); view_2136 = permute_1350 = None + view_2137 = torch.ops.aten.view.default(mm_519, [2, 4096, 2048]); mm_519 = None + convert_element_type_2850 = torch.ops.prims.convert_element_type.default(mm_518, torch.float32); mm_518 = None + reduce_scatter_tensor_265 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2850, 'avg', 64, '0'); convert_element_type_2850 = None + wait_tensor_862 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_265); reduce_scatter_tensor_265 = None + slice_219 = torch.ops.aten.slice.Tensor(permute_1343, 3, 0, 128) + slice_220 = torch.ops.aten.slice.Tensor(permute_1343, 3, 128, 192); permute_1343 = None + convert_element_type_2851 = torch.ops.prims.convert_element_type.default(slice_220, torch.float32); slice_220 = None + view_2138 = torch.ops.aten.view.default(convert_element_type_2851, [2, 4096, 16, 32, 2]); convert_element_type_2851 = None + view_as_complex_91 = torch.ops.aten.view_as_complex.default(view_2138); view_2138 = None + mul_1930 = torch.ops.aten.mul.Tensor(view_as_complex_91, clone_9); view_as_complex_91 = None + view_as_real_91 = torch.ops.aten.view_as_real.default(mul_1930); mul_1930 = None + view_2139 = torch.ops.aten.view.default(view_as_real_91, [2, 4096, 16, 64]); view_as_real_91 = None + convert_element_type_2852 = torch.ops.prims.convert_element_type.default(view_2139, torch.bfloat16); view_2139 = None + cat_136 = torch.ops.aten.cat.default([slice_219, convert_element_type_2852], 3); slice_219 = convert_element_type_2852 = None + view_2140 = torch.ops.aten.view.default(cat_136, [2, 4096, 3072]); cat_136 = None + view_2141 = torch.ops.aten.view.default(view_2140, [8192, 3072]); view_2140 = None + permute_1352 = torch.ops.aten.permute.default(view_2141, [1, 0]) + mm_520 = torch.ops.aten.mm.default(permute_1352, view_505); permute_1352 = view_505 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 64, '0'); convert_element_type_418 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_116 = torch.ops.aten.permute.default(wait_tensor_159, [1, 0]); wait_tensor_159 = None + permute_1354 = torch.ops.aten.permute.default(permute_116, [1, 0]); permute_116 = None + mm_521 = torch.ops.aten.mm.default(view_2141, permute_1354); view_2141 = permute_1354 = None + view_2142 = torch.ops.aten.view.default(mm_521, [2, 4096, 2048]); mm_521 = None + add_2058 = torch.ops.aten.add.Tensor(view_2137, view_2142); view_2137 = view_2142 = None + convert_element_type_2857 = torch.ops.prims.convert_element_type.default(mm_520, torch.float32); mm_520 = None + reduce_scatter_tensor_266 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2857, 'avg', 64, '0'); convert_element_type_2857 = None + wait_tensor_863 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_266); reduce_scatter_tensor_266 = None + convert_element_type_2858 = torch.ops.prims.convert_element_type.default(add_2058, torch.float32); add_2058 = None + convert_element_type_415 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_415, 64, '0'); convert_element_type_415 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + convert_element_type_2860 = torch.ops.prims.convert_element_type.default(wait_tensor_158, torch.float32); wait_tensor_158 = None + mul_1931 = torch.ops.aten.mul.Tensor(convert_element_type_2858, convert_element_type_2860); convert_element_type_2860 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(add_481, torch.float32); add_481 = None + mul_352 = torch.ops.aten.mul.Tensor(convert_element_type_416, rsqrt_24); convert_element_type_416 = None + mul_1933 = torch.ops.aten.mul.Tensor(mul_352, mul_1931) + sum_257 = torch.ops.aten.sum.dim_IntList(mul_1933, [2], True); mul_1933 = None + div_245 = torch.ops.aten.div.Tensor(mul_352, 2048) + mul_1934 = torch.ops.aten.mul.Tensor(div_245, sum_257); div_245 = sum_257 = None + sub_738 = torch.ops.aten.sub.Tensor(mul_1931, mul_1934); mul_1931 = mul_1934 = None + mul_1935 = torch.ops.aten.mul.Tensor(sub_738, rsqrt_24); sub_738 = rsqrt_24 = None + mul_1936 = torch.ops.aten.mul.Tensor(convert_element_type_2858, mul_352); convert_element_type_2858 = mul_352 = None + sum_258 = torch.ops.aten.sum.dim_IntList(mul_1936, [0, 1]); mul_1936 = None + convert_element_type_2861 = torch.ops.prims.convert_element_type.default(mul_1935, torch.bfloat16); mul_1935 = None + add_2059 = torch.ops.aten.add.Tensor(add_2057, convert_element_type_2861); add_2057 = convert_element_type_2861 = None + convert_element_type_default_25 = torch.ops.prims.convert_element_type.default(sum_258, torch.float32); sum_258 = None + reduce_scatter_tensor_267 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_25, 'avg', 64, '0'); convert_element_type_default_25 = None + wait_tensor_864 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_267); reduce_scatter_tensor_267 = None + view_2143 = torch.ops.aten.view.default(add_2059, [8192, 2048]) + unsqueeze_72 = torch.ops.aten.unsqueeze.default(view_2143, 1) + convert_element_type_2864 = torch.ops.prims.convert_element_type.default(unsqueeze_72, torch.float32); unsqueeze_72 = None + bmm_64 = torch.ops.aten.bmm.default(permute_1356, convert_element_type_2864); permute_1356 = None + bmm_65 = torch.ops.aten.bmm.default(convert_element_type_2864, permute_1357); convert_element_type_2864 = permute_1357 = None + convert_element_type_2865 = torch.ops.prims.convert_element_type.default(bmm_64, torch.bfloat16); bmm_64 = None + view_2144 = torch.ops.aten.view.default(bmm_65, [8192, 6]); bmm_65 = None + view_2145 = torch.ops.aten.view.default(convert_element_type_2865, [49152, 2048]); convert_element_type_2865 = None + index_90 = torch.ops.aten.index.Tensor(view_2145, [getitem_105]); view_2145 = getitem_105 = None + permute_1358 = torch.ops.aten.permute.default(view_2143, [1, 0]) + mm_522 = torch.ops.aten.mm.default(permute_1358, mul_349); permute_1358 = mul_349 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_410, 64, '0'); convert_element_type_410 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_115 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + permute_1360 = torch.ops.aten.permute.default(permute_115, [1, 0]); permute_115 = None + mm_523 = torch.ops.aten.mm.default(view_2143, permute_1360); view_2143 = permute_1360 = None + convert_element_type_2870 = torch.ops.prims.convert_element_type.default(mm_522, torch.float32); mm_522 = None + reduce_scatter_tensor_268 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2870, 'avg', 64, '0'); convert_element_type_2870 = None + wait_tensor_865 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_268); reduce_scatter_tensor_268 = None + convert_element_type_405 = torch.ops.prims.convert_element_type.default(mm_60, torch.float32); mm_60 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_405) + exp_21 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_476 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + div_35 = torch.ops.aten.div.Tensor(convert_element_type_405, add_476) + convert_element_type_406 = torch.ops.prims.convert_element_type.default(div_35, torch.bfloat16); div_35 = None + mul_1937 = torch.ops.aten.mul.Tensor(mm_523, convert_element_type_406); convert_element_type_406 = None + mul_1938 = torch.ops.aten.mul.Tensor(mm_523, mm_61); mm_523 = mm_61 = None + permute_1362 = torch.ops.aten.permute.default(mul_1937, [1, 0]) + mm_524 = torch.ops.aten.mm.default(permute_1362, view_460); permute_1362 = None + convert_element_type_407 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_407, 64, '0'); convert_element_type_407 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_114 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + permute_1364 = torch.ops.aten.permute.default(permute_114, [1, 0]); permute_114 = None + mm_525 = torch.ops.aten.mm.default(mul_1937, permute_1364); mul_1937 = permute_1364 = None + convert_element_type_2875 = torch.ops.prims.convert_element_type.default(mm_524, torch.float32); mm_524 = None + reduce_scatter_tensor_269 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2875, 'avg', 64, '0'); convert_element_type_2875 = None + wait_tensor_866 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_269); reduce_scatter_tensor_269 = None + convert_element_type_2876 = torch.ops.prims.convert_element_type.default(mul_1938, torch.float32); mul_1938 = None + reciprocal_38 = torch.ops.aten.reciprocal.default(add_476); add_476 = None + mul_1939 = torch.ops.aten.mul.Tensor(reciprocal_38, 1); reciprocal_38 = None + mul_1940 = torch.ops.aten.mul.Tensor(convert_element_type_2876, mul_1939); convert_element_type_2876 = None + sub_739 = torch.ops.aten.sub.Tensor(1, mul_1939); mul_1939 = None + mul_1941 = torch.ops.aten.mul.Tensor(convert_element_type_405, sub_739); convert_element_type_405 = sub_739 = None + add_2061 = torch.ops.aten.add.Tensor(mul_1941, 1); mul_1941 = None + mul_1942 = torch.ops.aten.mul.Tensor(mul_1940, add_2061); mul_1940 = add_2061 = None + convert_element_type_2878 = torch.ops.prims.convert_element_type.default(mul_1942, torch.bfloat16); mul_1942 = None + permute_1366 = torch.ops.aten.permute.default(convert_element_type_2878, [1, 0]) + mm_526 = torch.ops.aten.mm.default(permute_1366, view_460); permute_1366 = None + convert_element_type_402 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_402, 64, '0'); convert_element_type_402 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_113 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + permute_1368 = torch.ops.aten.permute.default(permute_113, [1, 0]); permute_113 = None + mm_527 = torch.ops.aten.mm.default(convert_element_type_2878, permute_1368); convert_element_type_2878 = permute_1368 = None + add_2062 = torch.ops.aten.add.Tensor(mm_525, mm_527); mm_525 = mm_527 = None + convert_element_type_2883 = torch.ops.prims.convert_element_type.default(mm_526, torch.float32); mm_526 = None + reduce_scatter_tensor_270 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2883, 'avg', 64, '0'); convert_element_type_2883 = None + wait_tensor_867 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_270); reduce_scatter_tensor_270 = None + all_to_all_single_116 = torch.ops._c10d_functional.all_to_all_single.default(index_90, [_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111], [_local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103], '521'); index_90 = None + wait_tensor_868 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_116); all_to_all_single_116 = None + full_424 = torch.ops.aten.full.default([sym_size_int_25, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_25 = None + slice_scatter_19 = torch.ops.aten.slice_scatter.default(full_424, wait_tensor_868, 0, 0, -1); wait_tensor_868 = None + index_91 = torch.ops.aten.index.Tensor(slice_scatter_19, [getitem_106]); slice_scatter_19 = None + permute_1370 = torch.ops.aten.permute.default(index_91, [1, 0]) + _grouped_mm_192 = torch.ops.aten._grouped_mm.default(permute_1370, mul_329, cumsum_20); permute_1370 = mul_329 = None + convert_element_type_396 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16); primals_130 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_396, 8, '513'); convert_element_type_396 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_150, [0, 2, 1]); wait_tensor_150 = None + permute_1372 = torch.ops.aten.permute.default(permute_112, [0, 2, 1]); permute_112 = None + _grouped_mm_193 = torch.ops.aten._grouped_mm.default(index_91, permute_1372, cumsum_20); index_91 = permute_1372 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(_grouped_mm_18, torch.float32); _grouped_mm_18 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_400) + exp_20 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_440 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + div_34 = torch.ops.aten.div.Tensor(convert_element_type_400, add_440) + convert_element_type_401 = torch.ops.prims.convert_element_type.default(div_34, torch.bfloat16); div_34 = None + mul_1943 = torch.ops.aten.mul.Tensor(_grouped_mm_193, convert_element_type_401); convert_element_type_401 = None + mul_1944 = torch.ops.aten.mul.Tensor(_grouped_mm_193, _grouped_mm_19); _grouped_mm_193 = _grouped_mm_19 = None + permute_1374 = torch.ops.aten.permute.default(mul_1943, [1, 0]) + _grouped_mm_194 = torch.ops.aten._grouped_mm.default(permute_1374, index_13, cumsum_20); permute_1374 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16); primals_131 = None + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 8, '513'); convert_element_type_397 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_151, [0, 2, 1]); wait_tensor_151 = None + permute_1376 = torch.ops.aten.permute.default(permute_111, [0, 2, 1]); permute_111 = None + _grouped_mm_195 = torch.ops.aten._grouped_mm.default(mul_1943, permute_1376, cumsum_20); mul_1943 = permute_1376 = None + convert_element_type_2884 = torch.ops.prims.convert_element_type.default(mul_1944, torch.float32); mul_1944 = None + reciprocal_39 = torch.ops.aten.reciprocal.default(add_440); add_440 = None + mul_1945 = torch.ops.aten.mul.Tensor(reciprocal_39, 1); reciprocal_39 = None + mul_1946 = torch.ops.aten.mul.Tensor(convert_element_type_2884, mul_1945); convert_element_type_2884 = None + sub_740 = torch.ops.aten.sub.Tensor(1, mul_1945); mul_1945 = None + mul_1947 = torch.ops.aten.mul.Tensor(convert_element_type_400, sub_740); convert_element_type_400 = sub_740 = None + add_2064 = torch.ops.aten.add.Tensor(mul_1947, 1); mul_1947 = None + mul_1948 = torch.ops.aten.mul.Tensor(mul_1946, add_2064); mul_1946 = add_2064 = None + convert_element_type_2886 = torch.ops.prims.convert_element_type.default(mul_1948, torch.bfloat16); mul_1948 = None + permute_1378 = torch.ops.aten.permute.default(convert_element_type_2886, [1, 0]) + _grouped_mm_196 = torch.ops.aten._grouped_mm.default(permute_1378, index_13, cumsum_20); permute_1378 = index_13 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16); primals_129 = None + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 8, '513'); convert_element_type_394 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_148, [0, 2, 1]); wait_tensor_148 = None + permute_1380 = torch.ops.aten.permute.default(permute_110, [0, 2, 1]); permute_110 = None + _grouped_mm_197 = torch.ops.aten._grouped_mm.default(convert_element_type_2886, permute_1380, cumsum_20); convert_element_type_2886 = permute_1380 = cumsum_20 = None + add_2065 = torch.ops.aten.add.Tensor(_grouped_mm_195, _grouped_mm_197); _grouped_mm_195 = _grouped_mm_197 = None + convert_element_type_2887 = torch.ops.prims.convert_element_type.default(_grouped_mm_194, torch.float32); _grouped_mm_194 = None + div_246 = torch.ops.aten.div.Tensor(convert_element_type_2887, 64); convert_element_type_2887 = None + reduce_scatter_tensor_271 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_246, 'sum', 8, '513'); div_246 = None + wait_tensor_869 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_271); reduce_scatter_tensor_271 = None + convert_element_type_2888 = torch.ops.prims.convert_element_type.default(_grouped_mm_192, torch.float32); _grouped_mm_192 = None + div_247 = torch.ops.aten.div.Tensor(convert_element_type_2888, 64); convert_element_type_2888 = None + reduce_scatter_tensor_272 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_247, 'sum', 8, '513'); div_247 = None + wait_tensor_870 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_272); reduce_scatter_tensor_272 = None + convert_element_type_2889 = torch.ops.prims.convert_element_type.default(_grouped_mm_196, torch.float32); _grouped_mm_196 = None + div_248 = torch.ops.aten.div.Tensor(convert_element_type_2889, 64); convert_element_type_2889 = None + reduce_scatter_tensor_273 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_248, 'sum', 8, '513'); div_248 = None + wait_tensor_871 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_273); reduce_scatter_tensor_273 = None + index_put_90 = torch.ops.aten.index_put.default(full_424, [getitem_106], add_2065, True); full_424 = getitem_106 = add_2065 = None + slice_221 = torch.ops.aten.slice.Tensor(index_put_90, 0, 0, add_2066); index_put_90 = add_2066 = None + all_to_all_single_117 = torch.ops._c10d_functional.all_to_all_single.default(slice_221, [_local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103], [_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111], '521'); slice_221 = _local_scalar_dense_96 = _local_scalar_dense_97 = _local_scalar_dense_98 = _local_scalar_dense_99 = _local_scalar_dense_100 = _local_scalar_dense_101 = _local_scalar_dense_102 = _local_scalar_dense_103 = _local_scalar_dense_104 = _local_scalar_dense_105 = _local_scalar_dense_106 = _local_scalar_dense_107 = _local_scalar_dense_108 = _local_scalar_dense_109 = _local_scalar_dense_110 = _local_scalar_dense_111 = None + wait_tensor_872 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_117); all_to_all_single_117 = None + index_put_91 = torch.ops.aten.index_put.default(full_default_52, [div_32], wait_tensor_872, True); div_32 = wait_tensor_872 = None + add_2070 = torch.ops.aten.add.Tensor(add_2062, index_put_91); add_2062 = index_put_91 = None + mul_1949 = torch.ops.aten.mul.Tensor(view_2144, 1.0); view_2144 = None + scatter_add_19 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_103, mul_1949); getitem_103 = mul_1949 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(mm_59, torch.float32); mm_59 = None + sub_144 = torch.ops.aten.sub.Tensor(convert_element_type_389, amax_6); convert_element_type_389 = amax_6 = None + exp_19 = torch.ops.aten.exp.default(sub_144); sub_144 = None + div_31 = torch.ops.aten.div.Tensor(exp_19, sum_25); exp_19 = sum_25 = None + mul_1950 = torch.ops.aten.mul.Tensor(scatter_add_19, div_31); scatter_add_19 = None + sum_259 = torch.ops.aten.sum.dim_IntList(mul_1950, [1], True) + neg_112 = torch.ops.aten.neg.default(div_31); div_31 = None + fma_19 = torch.ops.prims.fma.default(neg_112, sum_259, mul_1950); neg_112 = sum_259 = mul_1950 = None + convert_element_type_2890 = torch.ops.prims.convert_element_type.default(fma_19, torch.bfloat16); fma_19 = None + permute_1382 = torch.ops.aten.permute.default(convert_element_type_2890, [1, 0]) + mm_528 = torch.ops.aten.mm.default(permute_1382, view_460); permute_1382 = view_460 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 64, '0'); convert_element_type_386 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + permute_1384 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_529 = torch.ops.aten.mm.default(convert_element_type_2890, permute_1384); convert_element_type_2890 = permute_1384 = None + add_2071 = torch.ops.aten.add.Tensor(add_2070, mm_529); add_2070 = mm_529 = None + convert_element_type_2895 = torch.ops.prims.convert_element_type.default(mm_528, torch.float32); mm_528 = None + reduce_scatter_tensor_274 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2895, 'avg', 64, '0'); convert_element_type_2895 = None + wait_tensor_873 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_274); reduce_scatter_tensor_274 = None + view_2146 = torch.ops.aten.view.default(add_2071, [2, 4096, 2048]); add_2071 = None + convert_element_type_2896 = torch.ops.prims.convert_element_type.default(view_2146, torch.float32); view_2146 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 64, '0'); convert_element_type_383 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + convert_element_type_2898 = torch.ops.prims.convert_element_type.default(wait_tensor_143, torch.float32); wait_tensor_143 = None + mul_1951 = torch.ops.aten.mul.Tensor(convert_element_type_2896, convert_element_type_2898); convert_element_type_2898 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_416, torch.float32); add_416 = None + mul_309 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_1953 = torch.ops.aten.mul.Tensor(mul_309, mul_1951) + sum_260 = torch.ops.aten.sum.dim_IntList(mul_1953, [2], True); mul_1953 = None + div_249 = torch.ops.aten.div.Tensor(mul_309, 2048) + mul_1954 = torch.ops.aten.mul.Tensor(div_249, sum_260); div_249 = sum_260 = None + sub_742 = torch.ops.aten.sub.Tensor(mul_1951, mul_1954); mul_1951 = mul_1954 = None + mul_1955 = torch.ops.aten.mul.Tensor(sub_742, rsqrt_23); sub_742 = rsqrt_23 = None + mul_1956 = torch.ops.aten.mul.Tensor(convert_element_type_2896, mul_309); convert_element_type_2896 = mul_309 = None + sum_261 = torch.ops.aten.sum.dim_IntList(mul_1956, [0, 1]); mul_1956 = None + convert_element_type_2899 = torch.ops.prims.convert_element_type.default(mul_1955, torch.bfloat16); mul_1955 = None + add_2072 = torch.ops.aten.add.Tensor(add_2059, convert_element_type_2899); add_2059 = convert_element_type_2899 = None + convert_element_type_default_24 = torch.ops.prims.convert_element_type.default(sum_261, torch.float32); sum_261 = None + reduce_scatter_tensor_275 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_24, 'avg', 64, '0'); convert_element_type_default_24 = None + wait_tensor_874 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_275); reduce_scatter_tensor_275 = None + view_2147 = torch.ops.aten.view.default(add_2072, [8192, 2048]) + permute_1386 = torch.ops.aten.permute.default(view_2147, [1, 0]) + permute_107 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_455 = torch.ops.aten.view.default(permute_107, [2, 4096, -1]); permute_107 = None + view_457 = torch.ops.aten.view.default(view_455, [8192, 2048]); view_455 = None + mm_530 = torch.ops.aten.mm.default(permute_1386, view_457); permute_1386 = view_457 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 64, '0'); convert_element_type_380 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + permute_1388 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_531 = torch.ops.aten.mm.default(view_2147, permute_1388); view_2147 = permute_1388 = None + view_2148 = torch.ops.aten.view.default(mm_531, [2, 4096, 2048]); mm_531 = None + convert_element_type_2906 = torch.ops.prims.convert_element_type.default(mm_530, torch.float32); mm_530 = None + reduce_scatter_tensor_276 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2906, 'avg', 64, '0'); convert_element_type_2906 = None + wait_tensor_875 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_276); reduce_scatter_tensor_276 = None + view_2149 = torch.ops.aten.view.default(view_2148, [2, 4096, 16, 128]); view_2148 = None + permute_1390 = torch.ops.aten.permute.default(view_2149, [0, 2, 1, 3]); view_2149 = None + fw_graph19 = self.fw_graph19 + joint_graph19 = self.joint_graph19 + mask_graph19 = self.mask_graph19 + flex_attention_backward_19 = torch.ops.higher_order.flex_attention_backward(permute_104, permute_105, permute_106, getitem_99, getitem_100, permute_1390, None, fw_graph19, joint_graph19, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph19), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_104 = permute_105 = permute_106 = getitem_99 = getitem_100 = permute_1390 = fw_graph19 = joint_graph19 = mask_graph19 = None + getitem_449 = flex_attention_backward_19[0] + getitem_450 = flex_attention_backward_19[1] + getitem_451 = flex_attention_backward_19[2]; flex_attention_backward_19 = None + permute_1391 = torch.ops.aten.permute.default(getitem_451, [0, 2, 1, 3]); getitem_451 = None + permute_1392 = torch.ops.aten.permute.default(getitem_450, [0, 2, 1, 3]); getitem_450 = None + permute_1393 = torch.ops.aten.permute.default(getitem_449, [0, 2, 1, 3]); getitem_449 = None + slice_223 = torch.ops.aten.slice.Tensor(permute_1392, 3, 0, 128) + slice_224 = torch.ops.aten.slice.Tensor(permute_1392, 3, 128, 192); permute_1392 = None + sum_262 = torch.ops.aten.sum.dim_IntList(slice_224, [2], True); slice_224 = None + cat_137 = torch.ops.aten.cat.default([slice_223, permute_1391], 3); slice_223 = permute_1391 = None + view_2150 = torch.ops.aten.view.default(cat_137, [2, 4096, 4096]); cat_137 = None + view_2151 = torch.ops.aten.view.default(view_2150, [8192, 4096]); view_2150 = None + permute_1394 = torch.ops.aten.permute.default(view_2151, [1, 0]) + mm_532 = torch.ops.aten.mm.default(permute_1394, view_452); permute_1394 = view_452 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_377, 64, '0'); convert_element_type_377 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_103 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + permute_1396 = torch.ops.aten.permute.default(permute_103, [1, 0]); permute_103 = None + mm_533 = torch.ops.aten.mm.default(view_2151, permute_1396); view_2151 = permute_1396 = None + view_2152 = torch.ops.aten.view.default(mm_533, [2, 4096, 512]); mm_533 = None + convert_element_type_2911 = torch.ops.prims.convert_element_type.default(mm_532, torch.float32); mm_532 = None + reduce_scatter_tensor_277 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2911, 'avg', 64, '0'); convert_element_type_2911 = None + wait_tensor_876 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_277); reduce_scatter_tensor_277 = None + convert_element_type_2912 = torch.ops.prims.convert_element_type.default(view_2152, torch.float32); view_2152 = None + convert_element_type_374 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_374, 64, '0'); convert_element_type_374 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + convert_element_type_2914 = torch.ops.prims.convert_element_type.default(wait_tensor_140, torch.float32); wait_tensor_140 = None + mul_1957 = torch.ops.aten.mul.Tensor(convert_element_type_2912, convert_element_type_2914); convert_element_type_2914 = None + convert_element_type_375 = torch.ops.prims.convert_element_type.default(getitem_95, torch.float32); getitem_95 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_375, rsqrt_22); convert_element_type_375 = None + mul_1959 = torch.ops.aten.mul.Tensor(mul_307, mul_1957) + sum_263 = torch.ops.aten.sum.dim_IntList(mul_1959, [2], True); mul_1959 = None + div_250 = torch.ops.aten.div.Tensor(mul_307, 512) + mul_1960 = torch.ops.aten.mul.Tensor(div_250, sum_263); div_250 = sum_263 = None + sub_743 = torch.ops.aten.sub.Tensor(mul_1957, mul_1960); mul_1957 = mul_1960 = None + mul_1961 = torch.ops.aten.mul.Tensor(sub_743, rsqrt_22); sub_743 = rsqrt_22 = None + mul_1962 = torch.ops.aten.mul.Tensor(convert_element_type_2912, mul_307); convert_element_type_2912 = mul_307 = None + sum_264 = torch.ops.aten.sum.dim_IntList(mul_1962, [0, 1]); mul_1962 = None + convert_element_type_2915 = torch.ops.prims.convert_element_type.default(mul_1961, torch.bfloat16); mul_1961 = None + convert_element_type_default_23 = torch.ops.prims.convert_element_type.default(sum_264, torch.float32); sum_264 = None + reduce_scatter_tensor_278 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_23, 'avg', 64, '0'); convert_element_type_default_23 = None + wait_tensor_877 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_278); reduce_scatter_tensor_278 = None + convert_element_type_2918 = torch.ops.prims.convert_element_type.default(sum_262, torch.float32); sum_262 = None + view_2153 = torch.ops.aten.view.default(convert_element_type_2918, [2, 4096, 1, 32, 2]); convert_element_type_2918 = None + view_as_complex_92 = torch.ops.aten.view_as_complex.default(view_2153); view_2153 = None + mul_1963 = torch.ops.aten.mul.Tensor(view_as_complex_92, clone_9); view_as_complex_92 = None + view_as_real_92 = torch.ops.aten.view_as_real.default(mul_1963); mul_1963 = None + view_2154 = torch.ops.aten.view.default(view_as_real_92, [2, 4096, 1, 64]); view_as_real_92 = None + convert_element_type_2919 = torch.ops.prims.convert_element_type.default(view_2154, torch.bfloat16); view_2154 = None + squeeze_45 = torch.ops.aten.squeeze.dim(convert_element_type_2919, 2); convert_element_type_2919 = None + cat_138 = torch.ops.aten.cat.default([convert_element_type_2915, squeeze_45], 2); convert_element_type_2915 = squeeze_45 = None + view_2155 = torch.ops.aten.view.default(cat_138, [8192, 576]); cat_138 = None + permute_1398 = torch.ops.aten.permute.default(view_2155, [1, 0]) + mm_534 = torch.ops.aten.mm.default(permute_1398, view_438); permute_1398 = None + convert_element_type_369 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_369, 64, '0'); convert_element_type_369 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_102 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + permute_1400 = torch.ops.aten.permute.default(permute_102, [1, 0]); permute_102 = None + mm_535 = torch.ops.aten.mm.default(view_2155, permute_1400); view_2155 = permute_1400 = None + view_2156 = torch.ops.aten.view.default(mm_535, [2, 4096, 2048]); mm_535 = None + convert_element_type_2924 = torch.ops.prims.convert_element_type.default(mm_534, torch.float32); mm_534 = None + reduce_scatter_tensor_279 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2924, 'avg', 64, '0'); convert_element_type_2924 = None + wait_tensor_878 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_279); reduce_scatter_tensor_279 = None + slice_225 = torch.ops.aten.slice.Tensor(permute_1393, 3, 0, 128) + slice_226 = torch.ops.aten.slice.Tensor(permute_1393, 3, 128, 192); permute_1393 = None + convert_element_type_2925 = torch.ops.prims.convert_element_type.default(slice_226, torch.float32); slice_226 = None + view_2157 = torch.ops.aten.view.default(convert_element_type_2925, [2, 4096, 16, 32, 2]); convert_element_type_2925 = None + view_as_complex_93 = torch.ops.aten.view_as_complex.default(view_2157); view_2157 = None + mul_1964 = torch.ops.aten.mul.Tensor(view_as_complex_93, clone_9); view_as_complex_93 = None + view_as_real_93 = torch.ops.aten.view_as_real.default(mul_1964); mul_1964 = None + view_2158 = torch.ops.aten.view.default(view_as_real_93, [2, 4096, 16, 64]); view_as_real_93 = None + convert_element_type_2926 = torch.ops.prims.convert_element_type.default(view_2158, torch.bfloat16); view_2158 = None + cat_139 = torch.ops.aten.cat.default([slice_225, convert_element_type_2926], 3); slice_225 = convert_element_type_2926 = None + view_2159 = torch.ops.aten.view.default(cat_139, [2, 4096, 3072]); cat_139 = None + view_2160 = torch.ops.aten.view.default(view_2159, [8192, 3072]); view_2159 = None + permute_1402 = torch.ops.aten.permute.default(view_2160, [1, 0]) + mm_536 = torch.ops.aten.mm.default(permute_1402, view_438); permute_1402 = view_438 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 64, '0'); convert_element_type_364 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + permute_1404 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_537 = torch.ops.aten.mm.default(view_2160, permute_1404); view_2160 = permute_1404 = None + view_2161 = torch.ops.aten.view.default(mm_537, [2, 4096, 2048]); mm_537 = None + add_2073 = torch.ops.aten.add.Tensor(view_2156, view_2161); view_2156 = view_2161 = None + convert_element_type_2931 = torch.ops.prims.convert_element_type.default(mm_536, torch.float32); mm_536 = None + reduce_scatter_tensor_280 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2931, 'avg', 64, '0'); convert_element_type_2931 = None + wait_tensor_879 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_280); reduce_scatter_tensor_280 = None + convert_element_type_2932 = torch.ops.prims.convert_element_type.default(add_2073, torch.float32); add_2073 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 64, '0'); convert_element_type_361 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + convert_element_type_2934 = torch.ops.prims.convert_element_type.default(wait_tensor_137, torch.float32); wait_tensor_137 = None + mul_1965 = torch.ops.aten.mul.Tensor(convert_element_type_2932, convert_element_type_2934); convert_element_type_2934 = None + convert_element_type_362 = torch.ops.prims.convert_element_type.default(add_413, torch.float32); add_413 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_362, rsqrt_21); convert_element_type_362 = None + mul_1967 = torch.ops.aten.mul.Tensor(mul_303, mul_1965) + sum_265 = torch.ops.aten.sum.dim_IntList(mul_1967, [2], True); mul_1967 = None + div_251 = torch.ops.aten.div.Tensor(mul_303, 2048) + mul_1968 = torch.ops.aten.mul.Tensor(div_251, sum_265); div_251 = sum_265 = None + sub_744 = torch.ops.aten.sub.Tensor(mul_1965, mul_1968); mul_1965 = mul_1968 = None + mul_1969 = torch.ops.aten.mul.Tensor(sub_744, rsqrt_21); sub_744 = rsqrt_21 = None + mul_1970 = torch.ops.aten.mul.Tensor(convert_element_type_2932, mul_303); convert_element_type_2932 = mul_303 = None + sum_266 = torch.ops.aten.sum.dim_IntList(mul_1970, [0, 1]); mul_1970 = None + convert_element_type_2935 = torch.ops.prims.convert_element_type.default(mul_1969, torch.bfloat16); mul_1969 = None + add_2074 = torch.ops.aten.add.Tensor(add_2072, convert_element_type_2935); add_2072 = convert_element_type_2935 = None + convert_element_type_default_22 = torch.ops.prims.convert_element_type.default(sum_266, torch.float32); sum_266 = None + reduce_scatter_tensor_281 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_22, 'avg', 64, '0'); convert_element_type_default_22 = None + wait_tensor_880 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_281); reduce_scatter_tensor_281 = None + view_2162 = torch.ops.aten.view.default(add_2074, [8192, 2048]) + unsqueeze_73 = torch.ops.aten.unsqueeze.default(view_2162, 1) + convert_element_type_2938 = torch.ops.prims.convert_element_type.default(unsqueeze_73, torch.float32); unsqueeze_73 = None + bmm_66 = torch.ops.aten.bmm.default(permute_1406, convert_element_type_2938); permute_1406 = None + bmm_67 = torch.ops.aten.bmm.default(convert_element_type_2938, permute_1407); convert_element_type_2938 = permute_1407 = None + convert_element_type_2939 = torch.ops.prims.convert_element_type.default(bmm_66, torch.bfloat16); bmm_66 = None + view_2163 = torch.ops.aten.view.default(bmm_67, [8192, 6]); bmm_67 = None + view_2164 = torch.ops.aten.view.default(convert_element_type_2939, [49152, 2048]); convert_element_type_2939 = None + index_92 = torch.ops.aten.index.Tensor(view_2164, [getitem_91]); view_2164 = getitem_91 = None + permute_1408 = torch.ops.aten.permute.default(view_2162, [1, 0]) + mm_538 = torch.ops.aten.mm.default(permute_1408, mul_300); permute_1408 = mul_300 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_356, 64, '0'); convert_element_type_356 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + permute_1410 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_539 = torch.ops.aten.mm.default(view_2162, permute_1410); view_2162 = permute_1410 = None + convert_element_type_2944 = torch.ops.prims.convert_element_type.default(mm_538, torch.float32); mm_538 = None + reduce_scatter_tensor_282 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2944, 'avg', 64, '0'); convert_element_type_2944 = None + wait_tensor_881 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_282); reduce_scatter_tensor_282 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(mm_52, torch.float32); mm_52 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_351) + exp_18 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_408 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + div_30 = torch.ops.aten.div.Tensor(convert_element_type_351, add_408) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(div_30, torch.bfloat16); div_30 = None + mul_1971 = torch.ops.aten.mul.Tensor(mm_539, convert_element_type_352); convert_element_type_352 = None + mul_1972 = torch.ops.aten.mul.Tensor(mm_539, mm_53); mm_539 = mm_53 = None + permute_1412 = torch.ops.aten.permute.default(mul_1971, [1, 0]) + mm_540 = torch.ops.aten.mm.default(permute_1412, view_393); permute_1412 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 64, '0'); convert_element_type_353 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + permute_1414 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_541 = torch.ops.aten.mm.default(mul_1971, permute_1414); mul_1971 = permute_1414 = None + convert_element_type_2949 = torch.ops.prims.convert_element_type.default(mm_540, torch.float32); mm_540 = None + reduce_scatter_tensor_283 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2949, 'avg', 64, '0'); convert_element_type_2949 = None + wait_tensor_882 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_283); reduce_scatter_tensor_283 = None + convert_element_type_2950 = torch.ops.prims.convert_element_type.default(mul_1972, torch.float32); mul_1972 = None + reciprocal_40 = torch.ops.aten.reciprocal.default(add_408); add_408 = None + mul_1973 = torch.ops.aten.mul.Tensor(reciprocal_40, 1); reciprocal_40 = None + mul_1974 = torch.ops.aten.mul.Tensor(convert_element_type_2950, mul_1973); convert_element_type_2950 = None + sub_745 = torch.ops.aten.sub.Tensor(1, mul_1973); mul_1973 = None + mul_1975 = torch.ops.aten.mul.Tensor(convert_element_type_351, sub_745); convert_element_type_351 = sub_745 = None + add_2076 = torch.ops.aten.add.Tensor(mul_1975, 1); mul_1975 = None + mul_1976 = torch.ops.aten.mul.Tensor(mul_1974, add_2076); mul_1974 = add_2076 = None + convert_element_type_2952 = torch.ops.prims.convert_element_type.default(mul_1976, torch.bfloat16); mul_1976 = None + permute_1416 = torch.ops.aten.permute.default(convert_element_type_2952, [1, 0]) + mm_542 = torch.ops.aten.mm.default(permute_1416, view_393); permute_1416 = None + convert_element_type_348 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_348, 64, '0'); convert_element_type_348 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + permute_1418 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None + mm_543 = torch.ops.aten.mm.default(convert_element_type_2952, permute_1418); convert_element_type_2952 = permute_1418 = None + add_2077 = torch.ops.aten.add.Tensor(mm_541, mm_543); mm_541 = mm_543 = None + convert_element_type_2957 = torch.ops.prims.convert_element_type.default(mm_542, torch.float32); mm_542 = None + reduce_scatter_tensor_284 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2957, 'avg', 64, '0'); convert_element_type_2957 = None + wait_tensor_883 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_284); reduce_scatter_tensor_284 = None + all_to_all_single_118 = torch.ops._c10d_functional.all_to_all_single.default(index_92, [_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95], [_local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87], '521'); index_92 = None + wait_tensor_884 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_118); all_to_all_single_118 = None + full_428 = torch.ops.aten.full.default([sym_size_int_21, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_21 = None + slice_scatter_20 = torch.ops.aten.slice_scatter.default(full_428, wait_tensor_884, 0, 0, -1); wait_tensor_884 = None + index_93 = torch.ops.aten.index.Tensor(slice_scatter_20, [getitem_92]); slice_scatter_20 = None + permute_1420 = torch.ops.aten.permute.default(index_93, [1, 0]) + _grouped_mm_198 = torch.ops.aten._grouped_mm.default(permute_1420, mul_280, cumsum_17); permute_1420 = mul_280 = None + convert_element_type_342 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16); primals_114 = None + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_342, 8, '513'); convert_element_type_342 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_129, [0, 2, 1]); wait_tensor_129 = None + permute_1422 = torch.ops.aten.permute.default(permute_97, [0, 2, 1]); permute_97 = None + _grouped_mm_199 = torch.ops.aten._grouped_mm.default(index_93, permute_1422, cumsum_17); index_93 = permute_1422 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(_grouped_mm_15, torch.float32); _grouped_mm_15 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_346) + exp_17 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_372 = torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + div_29 = torch.ops.aten.div.Tensor(convert_element_type_346, add_372) + convert_element_type_347 = torch.ops.prims.convert_element_type.default(div_29, torch.bfloat16); div_29 = None + mul_1977 = torch.ops.aten.mul.Tensor(_grouped_mm_199, convert_element_type_347); convert_element_type_347 = None + mul_1978 = torch.ops.aten.mul.Tensor(_grouped_mm_199, _grouped_mm_16); _grouped_mm_199 = _grouped_mm_16 = None + permute_1424 = torch.ops.aten.permute.default(mul_1977, [1, 0]) + _grouped_mm_200 = torch.ops.aten._grouped_mm.default(permute_1424, index_11, cumsum_17); permute_1424 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16); primals_115 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_343, 8, '513'); convert_element_type_343 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_130, [0, 2, 1]); wait_tensor_130 = None + permute_1426 = torch.ops.aten.permute.default(permute_96, [0, 2, 1]); permute_96 = None + _grouped_mm_201 = torch.ops.aten._grouped_mm.default(mul_1977, permute_1426, cumsum_17); mul_1977 = permute_1426 = None + convert_element_type_2958 = torch.ops.prims.convert_element_type.default(mul_1978, torch.float32); mul_1978 = None + reciprocal_41 = torch.ops.aten.reciprocal.default(add_372); add_372 = None + mul_1979 = torch.ops.aten.mul.Tensor(reciprocal_41, 1); reciprocal_41 = None + mul_1980 = torch.ops.aten.mul.Tensor(convert_element_type_2958, mul_1979); convert_element_type_2958 = None + sub_746 = torch.ops.aten.sub.Tensor(1, mul_1979); mul_1979 = None + mul_1981 = torch.ops.aten.mul.Tensor(convert_element_type_346, sub_746); convert_element_type_346 = sub_746 = None + add_2079 = torch.ops.aten.add.Tensor(mul_1981, 1); mul_1981 = None + mul_1982 = torch.ops.aten.mul.Tensor(mul_1980, add_2079); mul_1980 = add_2079 = None + convert_element_type_2960 = torch.ops.prims.convert_element_type.default(mul_1982, torch.bfloat16); mul_1982 = None + permute_1428 = torch.ops.aten.permute.default(convert_element_type_2960, [1, 0]) + _grouped_mm_202 = torch.ops.aten._grouped_mm.default(permute_1428, index_11, cumsum_17); permute_1428 = index_11 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16); primals_113 = None + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 8, '513'); convert_element_type_340 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_127, [0, 2, 1]); wait_tensor_127 = None + permute_1430 = torch.ops.aten.permute.default(permute_95, [0, 2, 1]); permute_95 = None + _grouped_mm_203 = torch.ops.aten._grouped_mm.default(convert_element_type_2960, permute_1430, cumsum_17); convert_element_type_2960 = permute_1430 = cumsum_17 = None + add_2080 = torch.ops.aten.add.Tensor(_grouped_mm_201, _grouped_mm_203); _grouped_mm_201 = _grouped_mm_203 = None + convert_element_type_2961 = torch.ops.prims.convert_element_type.default(_grouped_mm_200, torch.float32); _grouped_mm_200 = None + div_252 = torch.ops.aten.div.Tensor(convert_element_type_2961, 64); convert_element_type_2961 = None + reduce_scatter_tensor_285 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_252, 'sum', 8, '513'); div_252 = None + wait_tensor_885 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_285); reduce_scatter_tensor_285 = None + convert_element_type_2962 = torch.ops.prims.convert_element_type.default(_grouped_mm_198, torch.float32); _grouped_mm_198 = None + div_253 = torch.ops.aten.div.Tensor(convert_element_type_2962, 64); convert_element_type_2962 = None + reduce_scatter_tensor_286 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_253, 'sum', 8, '513'); div_253 = None + wait_tensor_886 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_286); reduce_scatter_tensor_286 = None + convert_element_type_2963 = torch.ops.prims.convert_element_type.default(_grouped_mm_202, torch.float32); _grouped_mm_202 = None + div_254 = torch.ops.aten.div.Tensor(convert_element_type_2963, 64); convert_element_type_2963 = None + reduce_scatter_tensor_287 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_254, 'sum', 8, '513'); div_254 = None + wait_tensor_887 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_287); reduce_scatter_tensor_287 = None + index_put_92 = torch.ops.aten.index_put.default(full_428, [getitem_92], add_2080, True); full_428 = getitem_92 = add_2080 = None + slice_227 = torch.ops.aten.slice.Tensor(index_put_92, 0, 0, add_2081); index_put_92 = add_2081 = None + all_to_all_single_119 = torch.ops._c10d_functional.all_to_all_single.default(slice_227, [_local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87], [_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95], '521'); slice_227 = _local_scalar_dense_80 = _local_scalar_dense_81 = _local_scalar_dense_82 = _local_scalar_dense_83 = _local_scalar_dense_84 = _local_scalar_dense_85 = _local_scalar_dense_86 = _local_scalar_dense_87 = _local_scalar_dense_88 = _local_scalar_dense_89 = _local_scalar_dense_90 = _local_scalar_dense_91 = _local_scalar_dense_92 = _local_scalar_dense_93 = _local_scalar_dense_94 = _local_scalar_dense_95 = None + wait_tensor_888 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_119); all_to_all_single_119 = None + index_put_93 = torch.ops.aten.index_put.default(full_default_52, [div_27], wait_tensor_888, True); div_27 = wait_tensor_888 = None + add_2085 = torch.ops.aten.add.Tensor(add_2077, index_put_93); add_2077 = index_put_93 = None + mul_1983 = torch.ops.aten.mul.Tensor(view_2163, 1.0); view_2163 = None + scatter_add_20 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_89, mul_1983); getitem_89 = mul_1983 = None + convert_element_type_335 = torch.ops.prims.convert_element_type.default(mm_51, torch.float32); mm_51 = None + sub_120 = torch.ops.aten.sub.Tensor(convert_element_type_335, amax_5); convert_element_type_335 = amax_5 = None + exp_16 = torch.ops.aten.exp.default(sub_120); sub_120 = None + div_26 = torch.ops.aten.div.Tensor(exp_16, sum_21); exp_16 = sum_21 = None + mul_1984 = torch.ops.aten.mul.Tensor(scatter_add_20, div_26); scatter_add_20 = None + sum_267 = torch.ops.aten.sum.dim_IntList(mul_1984, [1], True) + neg_115 = torch.ops.aten.neg.default(div_26); div_26 = None + fma_20 = torch.ops.prims.fma.default(neg_115, sum_267, mul_1984); neg_115 = sum_267 = mul_1984 = None + convert_element_type_2964 = torch.ops.prims.convert_element_type.default(fma_20, torch.bfloat16); fma_20 = None + permute_1432 = torch.ops.aten.permute.default(convert_element_type_2964, [1, 0]) + mm_544 = torch.ops.aten.mm.default(permute_1432, view_393); permute_1432 = view_393 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_332, 64, '0'); convert_element_type_332 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_94 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + permute_1434 = torch.ops.aten.permute.default(permute_94, [1, 0]); permute_94 = None + mm_545 = torch.ops.aten.mm.default(convert_element_type_2964, permute_1434); convert_element_type_2964 = permute_1434 = None + add_2086 = torch.ops.aten.add.Tensor(add_2085, mm_545); add_2085 = mm_545 = None + convert_element_type_2969 = torch.ops.prims.convert_element_type.default(mm_544, torch.float32); mm_544 = None + reduce_scatter_tensor_288 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2969, 'avg', 64, '0'); convert_element_type_2969 = None + wait_tensor_889 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_288); reduce_scatter_tensor_288 = None + view_2165 = torch.ops.aten.view.default(add_2086, [2, 4096, 2048]); add_2086 = None + convert_element_type_2970 = torch.ops.prims.convert_element_type.default(view_2165, torch.float32); view_2165 = None + convert_element_type_329 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_329, 64, '0'); convert_element_type_329 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + convert_element_type_2972 = torch.ops.prims.convert_element_type.default(wait_tensor_122, torch.float32); wait_tensor_122 = None + mul_1985 = torch.ops.aten.mul.Tensor(convert_element_type_2970, convert_element_type_2972); convert_element_type_2972 = None + convert_element_type_330 = torch.ops.prims.convert_element_type.default(add_348, torch.float32); add_348 = None + mul_260 = torch.ops.aten.mul.Tensor(convert_element_type_330, rsqrt_20); convert_element_type_330 = None + mul_1987 = torch.ops.aten.mul.Tensor(mul_260, mul_1985) + sum_268 = torch.ops.aten.sum.dim_IntList(mul_1987, [2], True); mul_1987 = None + div_255 = torch.ops.aten.div.Tensor(mul_260, 2048) + mul_1988 = torch.ops.aten.mul.Tensor(div_255, sum_268); div_255 = sum_268 = None + sub_748 = torch.ops.aten.sub.Tensor(mul_1985, mul_1988); mul_1985 = mul_1988 = None + mul_1989 = torch.ops.aten.mul.Tensor(sub_748, rsqrt_20); sub_748 = rsqrt_20 = None + mul_1990 = torch.ops.aten.mul.Tensor(convert_element_type_2970, mul_260); convert_element_type_2970 = mul_260 = None + sum_269 = torch.ops.aten.sum.dim_IntList(mul_1990, [0, 1]); mul_1990 = None + convert_element_type_2973 = torch.ops.prims.convert_element_type.default(mul_1989, torch.bfloat16); mul_1989 = None + add_2087 = torch.ops.aten.add.Tensor(add_2074, convert_element_type_2973); add_2074 = convert_element_type_2973 = None + convert_element_type_default_21 = torch.ops.prims.convert_element_type.default(sum_269, torch.float32); sum_269 = None + reduce_scatter_tensor_289 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_21, 'avg', 64, '0'); convert_element_type_default_21 = None + wait_tensor_890 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_289); reduce_scatter_tensor_289 = None + view_2166 = torch.ops.aten.view.default(add_2087, [8192, 2048]) + permute_1436 = torch.ops.aten.permute.default(view_2166, [1, 0]) + permute_92 = torch.ops.aten.permute.default(getitem_85, [0, 2, 1, 3]) + view_388 = torch.ops.aten.view.default(permute_92, [2, 4096, -1]); permute_92 = None + view_390 = torch.ops.aten.view.default(view_388, [8192, 2048]); view_388 = None + mm_546 = torch.ops.aten.mm.default(permute_1436, view_390); permute_1436 = view_390 = None + convert_element_type_326 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_326, 64, '0'); convert_element_type_326 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_93 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_1438 = torch.ops.aten.permute.default(permute_93, [1, 0]); permute_93 = None + mm_547 = torch.ops.aten.mm.default(view_2166, permute_1438); view_2166 = permute_1438 = None + view_2167 = torch.ops.aten.view.default(mm_547, [2, 4096, 2048]); mm_547 = None + convert_element_type_2980 = torch.ops.prims.convert_element_type.default(mm_546, torch.float32); mm_546 = None + reduce_scatter_tensor_290 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2980, 'avg', 64, '0'); convert_element_type_2980 = None + wait_tensor_891 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_290); reduce_scatter_tensor_290 = None + view_2168 = torch.ops.aten.view.default(view_2167, [2, 4096, 16, 128]); view_2167 = None + permute_1440 = torch.ops.aten.permute.default(view_2168, [0, 2, 1, 3]); view_2168 = None + fw_graph20 = self.fw_graph20 + joint_graph20 = self.joint_graph20 + mask_graph20 = self.mask_graph20 + flex_attention_backward_20 = torch.ops.higher_order.flex_attention_backward(permute_89, permute_90, permute_91, getitem_85, getitem_86, permute_1440, None, fw_graph20, joint_graph20, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph20), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_89 = permute_90 = permute_91 = getitem_85 = getitem_86 = permute_1440 = fw_graph20 = joint_graph20 = mask_graph20 = None + getitem_453 = flex_attention_backward_20[0] + getitem_454 = flex_attention_backward_20[1] + getitem_455 = flex_attention_backward_20[2]; flex_attention_backward_20 = None + permute_1441 = torch.ops.aten.permute.default(getitem_455, [0, 2, 1, 3]); getitem_455 = None + permute_1442 = torch.ops.aten.permute.default(getitem_454, [0, 2, 1, 3]); getitem_454 = None + permute_1443 = torch.ops.aten.permute.default(getitem_453, [0, 2, 1, 3]); getitem_453 = None + slice_229 = torch.ops.aten.slice.Tensor(permute_1442, 3, 0, 128) + slice_230 = torch.ops.aten.slice.Tensor(permute_1442, 3, 128, 192); permute_1442 = None + sum_270 = torch.ops.aten.sum.dim_IntList(slice_230, [2], True); slice_230 = None + cat_140 = torch.ops.aten.cat.default([slice_229, permute_1441], 3); slice_229 = permute_1441 = None + view_2169 = torch.ops.aten.view.default(cat_140, [2, 4096, 4096]); cat_140 = None + view_2170 = torch.ops.aten.view.default(view_2169, [8192, 4096]); view_2169 = None + permute_1444 = torch.ops.aten.permute.default(view_2170, [1, 0]) + mm_548 = torch.ops.aten.mm.default(permute_1444, view_385); permute_1444 = view_385 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_323, 64, '0'); convert_element_type_323 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + permute_1446 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_549 = torch.ops.aten.mm.default(view_2170, permute_1446); view_2170 = permute_1446 = None + view_2171 = torch.ops.aten.view.default(mm_549, [2, 4096, 512]); mm_549 = None + convert_element_type_2985 = torch.ops.prims.convert_element_type.default(mm_548, torch.float32); mm_548 = None + reduce_scatter_tensor_291 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2985, 'avg', 64, '0'); convert_element_type_2985 = None + wait_tensor_892 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_291); reduce_scatter_tensor_291 = None + convert_element_type_2986 = torch.ops.prims.convert_element_type.default(view_2171, torch.float32); view_2171 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 64, '0'); convert_element_type_320 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + convert_element_type_2988 = torch.ops.prims.convert_element_type.default(wait_tensor_119, torch.float32); wait_tensor_119 = None + mul_1991 = torch.ops.aten.mul.Tensor(convert_element_type_2986, convert_element_type_2988); convert_element_type_2988 = None + convert_element_type_321 = torch.ops.prims.convert_element_type.default(getitem_81, torch.float32); getitem_81 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_321, rsqrt_19); convert_element_type_321 = None + mul_1993 = torch.ops.aten.mul.Tensor(mul_258, mul_1991) + sum_271 = torch.ops.aten.sum.dim_IntList(mul_1993, [2], True); mul_1993 = None + div_256 = torch.ops.aten.div.Tensor(mul_258, 512) + mul_1994 = torch.ops.aten.mul.Tensor(div_256, sum_271); div_256 = sum_271 = None + sub_749 = torch.ops.aten.sub.Tensor(mul_1991, mul_1994); mul_1991 = mul_1994 = None + mul_1995 = torch.ops.aten.mul.Tensor(sub_749, rsqrt_19); sub_749 = rsqrt_19 = None + mul_1996 = torch.ops.aten.mul.Tensor(convert_element_type_2986, mul_258); convert_element_type_2986 = mul_258 = None + sum_272 = torch.ops.aten.sum.dim_IntList(mul_1996, [0, 1]); mul_1996 = None + convert_element_type_2989 = torch.ops.prims.convert_element_type.default(mul_1995, torch.bfloat16); mul_1995 = None + convert_element_type_default_20 = torch.ops.prims.convert_element_type.default(sum_272, torch.float32); sum_272 = None + reduce_scatter_tensor_292 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_20, 'avg', 64, '0'); convert_element_type_default_20 = None + wait_tensor_893 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_292); reduce_scatter_tensor_292 = None + convert_element_type_2992 = torch.ops.prims.convert_element_type.default(sum_270, torch.float32); sum_270 = None + view_2172 = torch.ops.aten.view.default(convert_element_type_2992, [2, 4096, 1, 32, 2]); convert_element_type_2992 = None + view_as_complex_94 = torch.ops.aten.view_as_complex.default(view_2172); view_2172 = None + mul_1997 = torch.ops.aten.mul.Tensor(view_as_complex_94, clone_9); view_as_complex_94 = None + view_as_real_94 = torch.ops.aten.view_as_real.default(mul_1997); mul_1997 = None + view_2173 = torch.ops.aten.view.default(view_as_real_94, [2, 4096, 1, 64]); view_as_real_94 = None + convert_element_type_2993 = torch.ops.prims.convert_element_type.default(view_2173, torch.bfloat16); view_2173 = None + squeeze_46 = torch.ops.aten.squeeze.dim(convert_element_type_2993, 2); convert_element_type_2993 = None + cat_141 = torch.ops.aten.cat.default([convert_element_type_2989, squeeze_46], 2); convert_element_type_2989 = squeeze_46 = None + view_2174 = torch.ops.aten.view.default(cat_141, [8192, 576]); cat_141 = None + permute_1448 = torch.ops.aten.permute.default(view_2174, [1, 0]) + mm_550 = torch.ops.aten.mm.default(permute_1448, view_371); permute_1448 = None + convert_element_type_315 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_315, 64, '0'); convert_element_type_315 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_118, [1, 0]); wait_tensor_118 = None + permute_1450 = torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_551 = torch.ops.aten.mm.default(view_2174, permute_1450); view_2174 = permute_1450 = None + view_2175 = torch.ops.aten.view.default(mm_551, [2, 4096, 2048]); mm_551 = None + convert_element_type_2998 = torch.ops.prims.convert_element_type.default(mm_550, torch.float32); mm_550 = None + reduce_scatter_tensor_293 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2998, 'avg', 64, '0'); convert_element_type_2998 = None + wait_tensor_894 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_293); reduce_scatter_tensor_293 = None + slice_231 = torch.ops.aten.slice.Tensor(permute_1443, 3, 0, 128) + slice_232 = torch.ops.aten.slice.Tensor(permute_1443, 3, 128, 192); permute_1443 = None + convert_element_type_2999 = torch.ops.prims.convert_element_type.default(slice_232, torch.float32); slice_232 = None + view_2176 = torch.ops.aten.view.default(convert_element_type_2999, [2, 4096, 16, 32, 2]); convert_element_type_2999 = None + view_as_complex_95 = torch.ops.aten.view_as_complex.default(view_2176); view_2176 = None + mul_1998 = torch.ops.aten.mul.Tensor(view_as_complex_95, clone_9); view_as_complex_95 = None + view_as_real_95 = torch.ops.aten.view_as_real.default(mul_1998); mul_1998 = None + view_2177 = torch.ops.aten.view.default(view_as_real_95, [2, 4096, 16, 64]); view_as_real_95 = None + convert_element_type_3000 = torch.ops.prims.convert_element_type.default(view_2177, torch.bfloat16); view_2177 = None + cat_142 = torch.ops.aten.cat.default([slice_231, convert_element_type_3000], 3); slice_231 = convert_element_type_3000 = None + view_2178 = torch.ops.aten.view.default(cat_142, [2, 4096, 3072]); cat_142 = None + view_2179 = torch.ops.aten.view.default(view_2178, [8192, 3072]); view_2178 = None + permute_1452 = torch.ops.aten.permute.default(view_2179, [1, 0]) + mm_552 = torch.ops.aten.mm.default(permute_1452, view_371); permute_1452 = view_371 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_310, 64, '0'); convert_element_type_310 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_1454 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_553 = torch.ops.aten.mm.default(view_2179, permute_1454); view_2179 = permute_1454 = None + view_2180 = torch.ops.aten.view.default(mm_553, [2, 4096, 2048]); mm_553 = None + add_2088 = torch.ops.aten.add.Tensor(view_2175, view_2180); view_2175 = view_2180 = None + convert_element_type_3005 = torch.ops.prims.convert_element_type.default(mm_552, torch.float32); mm_552 = None + reduce_scatter_tensor_294 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3005, 'avg', 64, '0'); convert_element_type_3005 = None + wait_tensor_895 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_294); reduce_scatter_tensor_294 = None + convert_element_type_3006 = torch.ops.prims.convert_element_type.default(add_2088, torch.float32); add_2088 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 64, '0'); convert_element_type_307 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_3008 = torch.ops.prims.convert_element_type.default(wait_tensor_116, torch.float32); wait_tensor_116 = None + mul_1999 = torch.ops.aten.mul.Tensor(convert_element_type_3006, convert_element_type_3008); convert_element_type_3008 = None + convert_element_type_308 = torch.ops.prims.convert_element_type.default(add_345, torch.float32); add_345 = None + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_308, rsqrt_18); convert_element_type_308 = None + mul_2001 = torch.ops.aten.mul.Tensor(mul_254, mul_1999) + sum_273 = torch.ops.aten.sum.dim_IntList(mul_2001, [2], True); mul_2001 = None + div_257 = torch.ops.aten.div.Tensor(mul_254, 2048) + mul_2002 = torch.ops.aten.mul.Tensor(div_257, sum_273); div_257 = sum_273 = None + sub_750 = torch.ops.aten.sub.Tensor(mul_1999, mul_2002); mul_1999 = mul_2002 = None + mul_2003 = torch.ops.aten.mul.Tensor(sub_750, rsqrt_18); sub_750 = rsqrt_18 = None + mul_2004 = torch.ops.aten.mul.Tensor(convert_element_type_3006, mul_254); convert_element_type_3006 = mul_254 = None + sum_274 = torch.ops.aten.sum.dim_IntList(mul_2004, [0, 1]); mul_2004 = None + convert_element_type_3009 = torch.ops.prims.convert_element_type.default(mul_2003, torch.bfloat16); mul_2003 = None + add_2089 = torch.ops.aten.add.Tensor(add_2087, convert_element_type_3009); add_2087 = convert_element_type_3009 = None + convert_element_type_default_19 = torch.ops.prims.convert_element_type.default(sum_274, torch.float32); sum_274 = None + reduce_scatter_tensor_295 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_19, 'avg', 64, '0'); convert_element_type_default_19 = None + wait_tensor_896 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_295); reduce_scatter_tensor_295 = None + view_2181 = torch.ops.aten.view.default(add_2089, [8192, 2048]) + unsqueeze_74 = torch.ops.aten.unsqueeze.default(view_2181, 1) + convert_element_type_3012 = torch.ops.prims.convert_element_type.default(unsqueeze_74, torch.float32); unsqueeze_74 = None + bmm_68 = torch.ops.aten.bmm.default(permute_1456, convert_element_type_3012); permute_1456 = None + bmm_69 = torch.ops.aten.bmm.default(convert_element_type_3012, permute_1457); convert_element_type_3012 = permute_1457 = None + convert_element_type_3013 = torch.ops.prims.convert_element_type.default(bmm_68, torch.bfloat16); bmm_68 = None + view_2182 = torch.ops.aten.view.default(bmm_69, [8192, 6]); bmm_69 = None + view_2183 = torch.ops.aten.view.default(convert_element_type_3013, [49152, 2048]); convert_element_type_3013 = None + index_94 = torch.ops.aten.index.Tensor(view_2183, [getitem_77]); view_2183 = getitem_77 = None + permute_1458 = torch.ops.aten.permute.default(view_2181, [1, 0]) + mm_554 = torch.ops.aten.mm.default(permute_1458, mul_251); permute_1458 = mul_251 = None + convert_element_type_302 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_302, 64, '0'); convert_element_type_302 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_1460 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_555 = torch.ops.aten.mm.default(view_2181, permute_1460); view_2181 = permute_1460 = None + convert_element_type_3018 = torch.ops.prims.convert_element_type.default(mm_554, torch.float32); mm_554 = None + reduce_scatter_tensor_296 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3018, 'avg', 64, '0'); convert_element_type_3018 = None + wait_tensor_897 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_296); reduce_scatter_tensor_296 = None + convert_element_type_297 = torch.ops.prims.convert_element_type.default(mm_44, torch.float32); mm_44 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_297) + exp_15 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_340 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + div_25 = torch.ops.aten.div.Tensor(convert_element_type_297, add_340) + convert_element_type_298 = torch.ops.prims.convert_element_type.default(div_25, torch.bfloat16); div_25 = None + mul_2005 = torch.ops.aten.mul.Tensor(mm_555, convert_element_type_298); convert_element_type_298 = None + mul_2006 = torch.ops.aten.mul.Tensor(mm_555, mm_45); mm_555 = mm_45 = None + permute_1462 = torch.ops.aten.permute.default(mul_2005, [1, 0]) + mm_556 = torch.ops.aten.mm.default(permute_1462, view_326); permute_1462 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_299, 64, '0'); convert_element_type_299 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_114, [1, 0]); wait_tensor_114 = None + permute_1464 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_557 = torch.ops.aten.mm.default(mul_2005, permute_1464); mul_2005 = permute_1464 = None + convert_element_type_3023 = torch.ops.prims.convert_element_type.default(mm_556, torch.float32); mm_556 = None + reduce_scatter_tensor_297 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3023, 'avg', 64, '0'); convert_element_type_3023 = None + wait_tensor_898 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_297); reduce_scatter_tensor_297 = None + convert_element_type_3024 = torch.ops.prims.convert_element_type.default(mul_2006, torch.float32); mul_2006 = None + reciprocal_42 = torch.ops.aten.reciprocal.default(add_340); add_340 = None + mul_2007 = torch.ops.aten.mul.Tensor(reciprocal_42, 1); reciprocal_42 = None + mul_2008 = torch.ops.aten.mul.Tensor(convert_element_type_3024, mul_2007); convert_element_type_3024 = None + sub_751 = torch.ops.aten.sub.Tensor(1, mul_2007); mul_2007 = None + mul_2009 = torch.ops.aten.mul.Tensor(convert_element_type_297, sub_751); convert_element_type_297 = sub_751 = None + add_2091 = torch.ops.aten.add.Tensor(mul_2009, 1); mul_2009 = None + mul_2010 = torch.ops.aten.mul.Tensor(mul_2008, add_2091); mul_2008 = add_2091 = None + convert_element_type_3026 = torch.ops.prims.convert_element_type.default(mul_2010, torch.bfloat16); mul_2010 = None + permute_1466 = torch.ops.aten.permute.default(convert_element_type_3026, [1, 0]) + mm_558 = torch.ops.aten.mm.default(permute_1466, view_326); permute_1466 = None + convert_element_type_294 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_294, 64, '0'); convert_element_type_294 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_83 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + permute_1468 = torch.ops.aten.permute.default(permute_83, [1, 0]); permute_83 = None + mm_559 = torch.ops.aten.mm.default(convert_element_type_3026, permute_1468); convert_element_type_3026 = permute_1468 = None + add_2092 = torch.ops.aten.add.Tensor(mm_557, mm_559); mm_557 = mm_559 = None + convert_element_type_3031 = torch.ops.prims.convert_element_type.default(mm_558, torch.float32); mm_558 = None + reduce_scatter_tensor_298 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3031, 'avg', 64, '0'); convert_element_type_3031 = None + wait_tensor_899 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_298); reduce_scatter_tensor_298 = None + all_to_all_single_120 = torch.ops._c10d_functional.all_to_all_single.default(index_94, [_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79], [_local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71], '521'); index_94 = None + wait_tensor_900 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_120); all_to_all_single_120 = None + full_432 = torch.ops.aten.full.default([sym_size_int_17, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_17 = None + slice_scatter_21 = torch.ops.aten.slice_scatter.default(full_432, wait_tensor_900, 0, 0, -1); wait_tensor_900 = None + index_95 = torch.ops.aten.index.Tensor(slice_scatter_21, [getitem_78]); slice_scatter_21 = None + permute_1470 = torch.ops.aten.permute.default(index_95, [1, 0]) + _grouped_mm_204 = torch.ops.aten._grouped_mm.default(permute_1470, mul_231, cumsum_14); permute_1470 = mul_231 = None + convert_element_type_288 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16); primals_98 = None + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_288, 8, '513'); convert_element_type_288 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + permute_82 = torch.ops.aten.permute.default(wait_tensor_108, [0, 2, 1]); wait_tensor_108 = None + permute_1472 = torch.ops.aten.permute.default(permute_82, [0, 2, 1]); permute_82 = None + _grouped_mm_205 = torch.ops.aten._grouped_mm.default(index_95, permute_1472, cumsum_14); index_95 = permute_1472 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(_grouped_mm_12, torch.float32); _grouped_mm_12 = None + neg_9 = torch.ops.aten.neg.default(convert_element_type_292) + exp_14 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_304 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + div_24 = torch.ops.aten.div.Tensor(convert_element_type_292, add_304) + convert_element_type_293 = torch.ops.prims.convert_element_type.default(div_24, torch.bfloat16); div_24 = None + mul_2011 = torch.ops.aten.mul.Tensor(_grouped_mm_205, convert_element_type_293); convert_element_type_293 = None + mul_2012 = torch.ops.aten.mul.Tensor(_grouped_mm_205, _grouped_mm_13); _grouped_mm_205 = _grouped_mm_13 = None + permute_1474 = torch.ops.aten.permute.default(mul_2011, [1, 0]) + _grouped_mm_206 = torch.ops.aten._grouped_mm.default(permute_1474, index_9, cumsum_14); permute_1474 = None + convert_element_type_289 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16); primals_99 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_289, 8, '513'); convert_element_type_289 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + permute_81 = torch.ops.aten.permute.default(wait_tensor_109, [0, 2, 1]); wait_tensor_109 = None + permute_1476 = torch.ops.aten.permute.default(permute_81, [0, 2, 1]); permute_81 = None + _grouped_mm_207 = torch.ops.aten._grouped_mm.default(mul_2011, permute_1476, cumsum_14); mul_2011 = permute_1476 = None + convert_element_type_3032 = torch.ops.prims.convert_element_type.default(mul_2012, torch.float32); mul_2012 = None + reciprocal_43 = torch.ops.aten.reciprocal.default(add_304); add_304 = None + mul_2013 = torch.ops.aten.mul.Tensor(reciprocal_43, 1); reciprocal_43 = None + mul_2014 = torch.ops.aten.mul.Tensor(convert_element_type_3032, mul_2013); convert_element_type_3032 = None + sub_752 = torch.ops.aten.sub.Tensor(1, mul_2013); mul_2013 = None + mul_2015 = torch.ops.aten.mul.Tensor(convert_element_type_292, sub_752); convert_element_type_292 = sub_752 = None + add_2094 = torch.ops.aten.add.Tensor(mul_2015, 1); mul_2015 = None + mul_2016 = torch.ops.aten.mul.Tensor(mul_2014, add_2094); mul_2014 = add_2094 = None + convert_element_type_3034 = torch.ops.prims.convert_element_type.default(mul_2016, torch.bfloat16); mul_2016 = None + permute_1478 = torch.ops.aten.permute.default(convert_element_type_3034, [1, 0]) + _grouped_mm_208 = torch.ops.aten._grouped_mm.default(permute_1478, index_9, cumsum_14); permute_1478 = index_9 = None + convert_element_type_286 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16); primals_97 = None + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '513'); convert_element_type_286 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + permute_80 = torch.ops.aten.permute.default(wait_tensor_106, [0, 2, 1]); wait_tensor_106 = None + permute_1480 = torch.ops.aten.permute.default(permute_80, [0, 2, 1]); permute_80 = None + _grouped_mm_209 = torch.ops.aten._grouped_mm.default(convert_element_type_3034, permute_1480, cumsum_14); convert_element_type_3034 = permute_1480 = cumsum_14 = None + add_2095 = torch.ops.aten.add.Tensor(_grouped_mm_207, _grouped_mm_209); _grouped_mm_207 = _grouped_mm_209 = None + convert_element_type_3035 = torch.ops.prims.convert_element_type.default(_grouped_mm_206, torch.float32); _grouped_mm_206 = None + div_258 = torch.ops.aten.div.Tensor(convert_element_type_3035, 64); convert_element_type_3035 = None + reduce_scatter_tensor_299 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_258, 'sum', 8, '513'); div_258 = None + wait_tensor_901 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_299); reduce_scatter_tensor_299 = None + convert_element_type_3036 = torch.ops.prims.convert_element_type.default(_grouped_mm_204, torch.float32); _grouped_mm_204 = None + div_259 = torch.ops.aten.div.Tensor(convert_element_type_3036, 64); convert_element_type_3036 = None + reduce_scatter_tensor_300 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_259, 'sum', 8, '513'); div_259 = None + wait_tensor_902 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_300); reduce_scatter_tensor_300 = None + convert_element_type_3037 = torch.ops.prims.convert_element_type.default(_grouped_mm_208, torch.float32); _grouped_mm_208 = None + div_260 = torch.ops.aten.div.Tensor(convert_element_type_3037, 64); convert_element_type_3037 = None + reduce_scatter_tensor_301 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_260, 'sum', 8, '513'); div_260 = None + wait_tensor_903 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_301); reduce_scatter_tensor_301 = None + index_put_94 = torch.ops.aten.index_put.default(full_432, [getitem_78], add_2095, True); full_432 = getitem_78 = add_2095 = None + slice_233 = torch.ops.aten.slice.Tensor(index_put_94, 0, 0, add_2096); index_put_94 = add_2096 = None + all_to_all_single_121 = torch.ops._c10d_functional.all_to_all_single.default(slice_233, [_local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71], [_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79], '521'); slice_233 = _local_scalar_dense_64 = _local_scalar_dense_65 = _local_scalar_dense_66 = _local_scalar_dense_67 = _local_scalar_dense_68 = _local_scalar_dense_69 = _local_scalar_dense_70 = _local_scalar_dense_71 = _local_scalar_dense_72 = _local_scalar_dense_73 = _local_scalar_dense_74 = _local_scalar_dense_75 = _local_scalar_dense_76 = _local_scalar_dense_77 = _local_scalar_dense_78 = _local_scalar_dense_79 = None + wait_tensor_904 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_121); all_to_all_single_121 = None + index_put_95 = torch.ops.aten.index_put.default(full_default_52, [div_22], wait_tensor_904, True); div_22 = wait_tensor_904 = None + add_2100 = torch.ops.aten.add.Tensor(add_2092, index_put_95); add_2092 = index_put_95 = None + mul_2017 = torch.ops.aten.mul.Tensor(view_2182, 1.0); view_2182 = None + scatter_add_21 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_75, mul_2017); getitem_75 = mul_2017 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(mm_43, torch.float32); mm_43 = None + sub_96 = torch.ops.aten.sub.Tensor(convert_element_type_281, amax_4); convert_element_type_281 = amax_4 = None + exp_13 = torch.ops.aten.exp.default(sub_96); sub_96 = None + div_21 = torch.ops.aten.div.Tensor(exp_13, sum_17); exp_13 = sum_17 = None + mul_2018 = torch.ops.aten.mul.Tensor(scatter_add_21, div_21); scatter_add_21 = None + sum_275 = torch.ops.aten.sum.dim_IntList(mul_2018, [1], True) + neg_118 = torch.ops.aten.neg.default(div_21); div_21 = None + fma_21 = torch.ops.prims.fma.default(neg_118, sum_275, mul_2018); neg_118 = sum_275 = mul_2018 = None + convert_element_type_3038 = torch.ops.prims.convert_element_type.default(fma_21, torch.bfloat16); fma_21 = None + permute_1482 = torch.ops.aten.permute.default(convert_element_type_3038, [1, 0]) + mm_560 = torch.ops.aten.mm.default(permute_1482, view_326); permute_1482 = view_326 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_278, 64, '0'); convert_element_type_278 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + permute_1484 = torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_561 = torch.ops.aten.mm.default(convert_element_type_3038, permute_1484); convert_element_type_3038 = permute_1484 = None + add_2101 = torch.ops.aten.add.Tensor(add_2100, mm_561); add_2100 = mm_561 = None + convert_element_type_3043 = torch.ops.prims.convert_element_type.default(mm_560, torch.float32); mm_560 = None + reduce_scatter_tensor_302 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3043, 'avg', 64, '0'); convert_element_type_3043 = None + wait_tensor_905 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_302); reduce_scatter_tensor_302 = None + view_2184 = torch.ops.aten.view.default(add_2101, [2, 4096, 2048]); add_2101 = None + convert_element_type_3044 = torch.ops.prims.convert_element_type.default(view_2184, torch.float32); view_2184 = None + convert_element_type_275 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_275, 64, '0'); convert_element_type_275 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + convert_element_type_3046 = torch.ops.prims.convert_element_type.default(wait_tensor_101, torch.float32); wait_tensor_101 = None + mul_2019 = torch.ops.aten.mul.Tensor(convert_element_type_3044, convert_element_type_3046); convert_element_type_3046 = None + convert_element_type_276 = torch.ops.prims.convert_element_type.default(add_280, torch.float32); add_280 = None + mul_211 = torch.ops.aten.mul.Tensor(convert_element_type_276, rsqrt_17); convert_element_type_276 = None + mul_2021 = torch.ops.aten.mul.Tensor(mul_211, mul_2019) + sum_276 = torch.ops.aten.sum.dim_IntList(mul_2021, [2], True); mul_2021 = None + div_261 = torch.ops.aten.div.Tensor(mul_211, 2048) + mul_2022 = torch.ops.aten.mul.Tensor(div_261, sum_276); div_261 = sum_276 = None + sub_754 = torch.ops.aten.sub.Tensor(mul_2019, mul_2022); mul_2019 = mul_2022 = None + mul_2023 = torch.ops.aten.mul.Tensor(sub_754, rsqrt_17); sub_754 = rsqrt_17 = None + mul_2024 = torch.ops.aten.mul.Tensor(convert_element_type_3044, mul_211); convert_element_type_3044 = mul_211 = None + sum_277 = torch.ops.aten.sum.dim_IntList(mul_2024, [0, 1]); mul_2024 = None + convert_element_type_3047 = torch.ops.prims.convert_element_type.default(mul_2023, torch.bfloat16); mul_2023 = None + add_2102 = torch.ops.aten.add.Tensor(add_2089, convert_element_type_3047); add_2089 = convert_element_type_3047 = None + convert_element_type_default_18 = torch.ops.prims.convert_element_type.default(sum_277, torch.float32); sum_277 = None + reduce_scatter_tensor_303 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_18, 'avg', 64, '0'); convert_element_type_default_18 = None + wait_tensor_906 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_303); reduce_scatter_tensor_303 = None + view_2185 = torch.ops.aten.view.default(add_2102, [8192, 2048]) + permute_1486 = torch.ops.aten.permute.default(view_2185, [1, 0]) + permute_77 = torch.ops.aten.permute.default(getitem_71, [0, 2, 1, 3]) + view_321 = torch.ops.aten.view.default(permute_77, [2, 4096, -1]); permute_77 = None + view_323 = torch.ops.aten.view.default(view_321, [8192, 2048]); view_321 = None + mm_562 = torch.ops.aten.mm.default(permute_1486, view_323); permute_1486 = view_323 = None + convert_element_type_272 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16); primals_92 = None + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_272, 64, '0'); convert_element_type_272 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_100, [1, 0]); wait_tensor_100 = None + permute_1488 = torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_563 = torch.ops.aten.mm.default(view_2185, permute_1488); view_2185 = permute_1488 = None + view_2186 = torch.ops.aten.view.default(mm_563, [2, 4096, 2048]); mm_563 = None + convert_element_type_3054 = torch.ops.prims.convert_element_type.default(mm_562, torch.float32); mm_562 = None + reduce_scatter_tensor_304 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3054, 'avg', 64, '0'); convert_element_type_3054 = None + wait_tensor_907 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_304); reduce_scatter_tensor_304 = None + view_2187 = torch.ops.aten.view.default(view_2186, [2, 4096, 16, 128]); view_2186 = None + permute_1490 = torch.ops.aten.permute.default(view_2187, [0, 2, 1, 3]); view_2187 = None + fw_graph21 = self.fw_graph21 + joint_graph21 = self.joint_graph21 + mask_graph21 = self.mask_graph21 + flex_attention_backward_21 = torch.ops.higher_order.flex_attention_backward(permute_74, permute_75, permute_76, getitem_71, getitem_72, permute_1490, None, fw_graph21, joint_graph21, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph21), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_74 = permute_75 = permute_76 = getitem_71 = getitem_72 = permute_1490 = fw_graph21 = joint_graph21 = mask_graph21 = None + getitem_457 = flex_attention_backward_21[0] + getitem_458 = flex_attention_backward_21[1] + getitem_459 = flex_attention_backward_21[2]; flex_attention_backward_21 = None + permute_1491 = torch.ops.aten.permute.default(getitem_459, [0, 2, 1, 3]); getitem_459 = None + permute_1492 = torch.ops.aten.permute.default(getitem_458, [0, 2, 1, 3]); getitem_458 = None + permute_1493 = torch.ops.aten.permute.default(getitem_457, [0, 2, 1, 3]); getitem_457 = None + slice_235 = torch.ops.aten.slice.Tensor(permute_1492, 3, 0, 128) + slice_236 = torch.ops.aten.slice.Tensor(permute_1492, 3, 128, 192); permute_1492 = None + sum_278 = torch.ops.aten.sum.dim_IntList(slice_236, [2], True); slice_236 = None + cat_143 = torch.ops.aten.cat.default([slice_235, permute_1491], 3); slice_235 = permute_1491 = None + view_2188 = torch.ops.aten.view.default(cat_143, [2, 4096, 4096]); cat_143 = None + view_2189 = torch.ops.aten.view.default(view_2188, [8192, 4096]); view_2188 = None + permute_1494 = torch.ops.aten.permute.default(view_2189, [1, 0]) + mm_564 = torch.ops.aten.mm.default(permute_1494, view_318); permute_1494 = view_318 = None + convert_element_type_269 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_269, 64, '0'); convert_element_type_269 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + permute_1496 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_565 = torch.ops.aten.mm.default(view_2189, permute_1496); view_2189 = permute_1496 = None + view_2190 = torch.ops.aten.view.default(mm_565, [2, 4096, 512]); mm_565 = None + convert_element_type_3059 = torch.ops.prims.convert_element_type.default(mm_564, torch.float32); mm_564 = None + reduce_scatter_tensor_305 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3059, 'avg', 64, '0'); convert_element_type_3059 = None + wait_tensor_908 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_305); reduce_scatter_tensor_305 = None + convert_element_type_3060 = torch.ops.prims.convert_element_type.default(view_2190, torch.float32); view_2190 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_266, 64, '0'); convert_element_type_266 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_3062 = torch.ops.prims.convert_element_type.default(wait_tensor_98, torch.float32); wait_tensor_98 = None + mul_2025 = torch.ops.aten.mul.Tensor(convert_element_type_3060, convert_element_type_3062); convert_element_type_3062 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(getitem_67, torch.float32); getitem_67 = None + mul_209 = torch.ops.aten.mul.Tensor(convert_element_type_267, rsqrt_16); convert_element_type_267 = None + mul_2027 = torch.ops.aten.mul.Tensor(mul_209, mul_2025) + sum_279 = torch.ops.aten.sum.dim_IntList(mul_2027, [2], True); mul_2027 = None + div_262 = torch.ops.aten.div.Tensor(mul_209, 512) + mul_2028 = torch.ops.aten.mul.Tensor(div_262, sum_279); div_262 = sum_279 = None + sub_755 = torch.ops.aten.sub.Tensor(mul_2025, mul_2028); mul_2025 = mul_2028 = None + mul_2029 = torch.ops.aten.mul.Tensor(sub_755, rsqrt_16); sub_755 = rsqrt_16 = None + mul_2030 = torch.ops.aten.mul.Tensor(convert_element_type_3060, mul_209); convert_element_type_3060 = mul_209 = None + sum_280 = torch.ops.aten.sum.dim_IntList(mul_2030, [0, 1]); mul_2030 = None + convert_element_type_3063 = torch.ops.prims.convert_element_type.default(mul_2029, torch.bfloat16); mul_2029 = None + convert_element_type_default_17 = torch.ops.prims.convert_element_type.default(sum_280, torch.float32); sum_280 = None + reduce_scatter_tensor_306 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_17, 'avg', 64, '0'); convert_element_type_default_17 = None + wait_tensor_909 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_306); reduce_scatter_tensor_306 = None + convert_element_type_3066 = torch.ops.prims.convert_element_type.default(sum_278, torch.float32); sum_278 = None + view_2191 = torch.ops.aten.view.default(convert_element_type_3066, [2, 4096, 1, 32, 2]); convert_element_type_3066 = None + view_as_complex_96 = torch.ops.aten.view_as_complex.default(view_2191); view_2191 = None + mul_2031 = torch.ops.aten.mul.Tensor(view_as_complex_96, clone_9); view_as_complex_96 = None + view_as_real_96 = torch.ops.aten.view_as_real.default(mul_2031); mul_2031 = None + view_2192 = torch.ops.aten.view.default(view_as_real_96, [2, 4096, 1, 64]); view_as_real_96 = None + convert_element_type_3067 = torch.ops.prims.convert_element_type.default(view_2192, torch.bfloat16); view_2192 = None + squeeze_47 = torch.ops.aten.squeeze.dim(convert_element_type_3067, 2); convert_element_type_3067 = None + cat_144 = torch.ops.aten.cat.default([convert_element_type_3063, squeeze_47], 2); convert_element_type_3063 = squeeze_47 = None + view_2193 = torch.ops.aten.view.default(cat_144, [8192, 576]); cat_144 = None + permute_1498 = torch.ops.aten.permute.default(view_2193, [1, 0]) + mm_566 = torch.ops.aten.mm.default(permute_1498, view_304); permute_1498 = None + convert_element_type_261 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_261, 64, '0'); convert_element_type_261 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_72 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + permute_1500 = torch.ops.aten.permute.default(permute_72, [1, 0]); permute_72 = None + mm_567 = torch.ops.aten.mm.default(view_2193, permute_1500); view_2193 = permute_1500 = None + view_2194 = torch.ops.aten.view.default(mm_567, [2, 4096, 2048]); mm_567 = None + convert_element_type_3072 = torch.ops.prims.convert_element_type.default(mm_566, torch.float32); mm_566 = None + reduce_scatter_tensor_307 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3072, 'avg', 64, '0'); convert_element_type_3072 = None + wait_tensor_910 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_307); reduce_scatter_tensor_307 = None + slice_237 = torch.ops.aten.slice.Tensor(permute_1493, 3, 0, 128) + slice_238 = torch.ops.aten.slice.Tensor(permute_1493, 3, 128, 192); permute_1493 = None + convert_element_type_3073 = torch.ops.prims.convert_element_type.default(slice_238, torch.float32); slice_238 = None + view_2195 = torch.ops.aten.view.default(convert_element_type_3073, [2, 4096, 16, 32, 2]); convert_element_type_3073 = None + view_as_complex_97 = torch.ops.aten.view_as_complex.default(view_2195); view_2195 = None + mul_2032 = torch.ops.aten.mul.Tensor(view_as_complex_97, clone_9); view_as_complex_97 = None + view_as_real_97 = torch.ops.aten.view_as_real.default(mul_2032); mul_2032 = None + view_2196 = torch.ops.aten.view.default(view_as_real_97, [2, 4096, 16, 64]); view_as_real_97 = None + convert_element_type_3074 = torch.ops.prims.convert_element_type.default(view_2196, torch.bfloat16); view_2196 = None + cat_145 = torch.ops.aten.cat.default([slice_237, convert_element_type_3074], 3); slice_237 = convert_element_type_3074 = None + view_2197 = torch.ops.aten.view.default(cat_145, [2, 4096, 3072]); cat_145 = None + view_2198 = torch.ops.aten.view.default(view_2197, [8192, 3072]); view_2197 = None + permute_1502 = torch.ops.aten.permute.default(view_2198, [1, 0]) + mm_568 = torch.ops.aten.mm.default(permute_1502, view_304); permute_1502 = view_304 = None + convert_element_type_256 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_256, 64, '0'); convert_element_type_256 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_71 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + permute_1504 = torch.ops.aten.permute.default(permute_71, [1, 0]); permute_71 = None + mm_569 = torch.ops.aten.mm.default(view_2198, permute_1504); view_2198 = permute_1504 = None + view_2199 = torch.ops.aten.view.default(mm_569, [2, 4096, 2048]); mm_569 = None + add_2103 = torch.ops.aten.add.Tensor(view_2194, view_2199); view_2194 = view_2199 = None + convert_element_type_3079 = torch.ops.prims.convert_element_type.default(mm_568, torch.float32); mm_568 = None + reduce_scatter_tensor_308 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3079, 'avg', 64, '0'); convert_element_type_3079 = None + wait_tensor_911 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_308); reduce_scatter_tensor_308 = None + convert_element_type_3080 = torch.ops.prims.convert_element_type.default(add_2103, torch.float32); add_2103 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 64, '0'); convert_element_type_253 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + convert_element_type_3082 = torch.ops.prims.convert_element_type.default(wait_tensor_95, torch.float32); wait_tensor_95 = None + mul_2033 = torch.ops.aten.mul.Tensor(convert_element_type_3080, convert_element_type_3082); convert_element_type_3082 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(add_277, torch.float32); add_277 = None + mul_205 = torch.ops.aten.mul.Tensor(convert_element_type_254, rsqrt_15); convert_element_type_254 = None + mul_2035 = torch.ops.aten.mul.Tensor(mul_205, mul_2033) + sum_281 = torch.ops.aten.sum.dim_IntList(mul_2035, [2], True); mul_2035 = None + div_263 = torch.ops.aten.div.Tensor(mul_205, 2048) + mul_2036 = torch.ops.aten.mul.Tensor(div_263, sum_281); div_263 = sum_281 = None + sub_756 = torch.ops.aten.sub.Tensor(mul_2033, mul_2036); mul_2033 = mul_2036 = None + mul_2037 = torch.ops.aten.mul.Tensor(sub_756, rsqrt_15); sub_756 = rsqrt_15 = None + mul_2038 = torch.ops.aten.mul.Tensor(convert_element_type_3080, mul_205); convert_element_type_3080 = mul_205 = None + sum_282 = torch.ops.aten.sum.dim_IntList(mul_2038, [0, 1]); mul_2038 = None + convert_element_type_3083 = torch.ops.prims.convert_element_type.default(mul_2037, torch.bfloat16); mul_2037 = None + add_2104 = torch.ops.aten.add.Tensor(add_2102, convert_element_type_3083); add_2102 = convert_element_type_3083 = None + convert_element_type_default_16 = torch.ops.prims.convert_element_type.default(sum_282, torch.float32); sum_282 = None + reduce_scatter_tensor_309 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_16, 'avg', 64, '0'); convert_element_type_default_16 = None + wait_tensor_912 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_309); reduce_scatter_tensor_309 = None + view_2200 = torch.ops.aten.view.default(add_2104, [8192, 2048]) + unsqueeze_75 = torch.ops.aten.unsqueeze.default(view_2200, 1) + convert_element_type_3086 = torch.ops.prims.convert_element_type.default(unsqueeze_75, torch.float32); unsqueeze_75 = None + bmm_70 = torch.ops.aten.bmm.default(permute_1506, convert_element_type_3086); permute_1506 = None + bmm_71 = torch.ops.aten.bmm.default(convert_element_type_3086, permute_1507); convert_element_type_3086 = permute_1507 = None + convert_element_type_3087 = torch.ops.prims.convert_element_type.default(bmm_70, torch.bfloat16); bmm_70 = None + view_2201 = torch.ops.aten.view.default(bmm_71, [8192, 6]); bmm_71 = None + view_2202 = torch.ops.aten.view.default(convert_element_type_3087, [49152, 2048]); convert_element_type_3087 = None + index_96 = torch.ops.aten.index.Tensor(view_2202, [getitem_63]); view_2202 = getitem_63 = None + permute_1508 = torch.ops.aten.permute.default(view_2200, [1, 0]) + mm_570 = torch.ops.aten.mm.default(permute_1508, mul_202); permute_1508 = mul_202 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 64, '0'); convert_element_type_248 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + permute_70 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + permute_1510 = torch.ops.aten.permute.default(permute_70, [1, 0]); permute_70 = None + mm_571 = torch.ops.aten.mm.default(view_2200, permute_1510); view_2200 = permute_1510 = None + convert_element_type_3092 = torch.ops.prims.convert_element_type.default(mm_570, torch.float32); mm_570 = None + reduce_scatter_tensor_310 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3092, 'avg', 64, '0'); convert_element_type_3092 = None + wait_tensor_913 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_310); reduce_scatter_tensor_310 = None + convert_element_type_243 = torch.ops.prims.convert_element_type.default(mm_36, torch.float32); mm_36 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_243) + exp_12 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_272 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + div_20 = torch.ops.aten.div.Tensor(convert_element_type_243, add_272) + convert_element_type_244 = torch.ops.prims.convert_element_type.default(div_20, torch.bfloat16); div_20 = None + mul_2039 = torch.ops.aten.mul.Tensor(mm_571, convert_element_type_244); convert_element_type_244 = None + mul_2040 = torch.ops.aten.mul.Tensor(mm_571, mm_37); mm_571 = mm_37 = None + permute_1512 = torch.ops.aten.permute.default(mul_2039, [1, 0]) + mm_572 = torch.ops.aten.mm.default(permute_1512, view_259); permute_1512 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_245, 64, '0'); convert_element_type_245 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_69 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + permute_1514 = torch.ops.aten.permute.default(permute_69, [1, 0]); permute_69 = None + mm_573 = torch.ops.aten.mm.default(mul_2039, permute_1514); mul_2039 = permute_1514 = None + convert_element_type_3097 = torch.ops.prims.convert_element_type.default(mm_572, torch.float32); mm_572 = None + reduce_scatter_tensor_311 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3097, 'avg', 64, '0'); convert_element_type_3097 = None + wait_tensor_914 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_311); reduce_scatter_tensor_311 = None + convert_element_type_3098 = torch.ops.prims.convert_element_type.default(mul_2040, torch.float32); mul_2040 = None + reciprocal_44 = torch.ops.aten.reciprocal.default(add_272); add_272 = None + mul_2041 = torch.ops.aten.mul.Tensor(reciprocal_44, 1); reciprocal_44 = None + mul_2042 = torch.ops.aten.mul.Tensor(convert_element_type_3098, mul_2041); convert_element_type_3098 = None + sub_757 = torch.ops.aten.sub.Tensor(1, mul_2041); mul_2041 = None + mul_2043 = torch.ops.aten.mul.Tensor(convert_element_type_243, sub_757); convert_element_type_243 = sub_757 = None + add_2106 = torch.ops.aten.add.Tensor(mul_2043, 1); mul_2043 = None + mul_2044 = torch.ops.aten.mul.Tensor(mul_2042, add_2106); mul_2042 = add_2106 = None + convert_element_type_3100 = torch.ops.prims.convert_element_type.default(mul_2044, torch.bfloat16); mul_2044 = None + permute_1516 = torch.ops.aten.permute.default(convert_element_type_3100, [1, 0]) + mm_574 = torch.ops.aten.mm.default(permute_1516, view_259); permute_1516 = None + convert_element_type_240 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_240, 64, '0'); convert_element_type_240 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + permute_1518 = torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_575 = torch.ops.aten.mm.default(convert_element_type_3100, permute_1518); convert_element_type_3100 = permute_1518 = None + add_2107 = torch.ops.aten.add.Tensor(mm_573, mm_575); mm_573 = mm_575 = None + convert_element_type_3105 = torch.ops.prims.convert_element_type.default(mm_574, torch.float32); mm_574 = None + reduce_scatter_tensor_312 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3105, 'avg', 64, '0'); convert_element_type_3105 = None + wait_tensor_915 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_312); reduce_scatter_tensor_312 = None + all_to_all_single_122 = torch.ops._c10d_functional.all_to_all_single.default(index_96, [_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63], [_local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55], '521'); index_96 = None + wait_tensor_916 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_122); all_to_all_single_122 = None + full_436 = torch.ops.aten.full.default([sym_size_int_13, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_13 = None + slice_scatter_22 = torch.ops.aten.slice_scatter.default(full_436, wait_tensor_916, 0, 0, -1); wait_tensor_916 = None + index_97 = torch.ops.aten.index.Tensor(slice_scatter_22, [getitem_64]); slice_scatter_22 = None + permute_1520 = torch.ops.aten.permute.default(index_97, [1, 0]) + _grouped_mm_210 = torch.ops.aten._grouped_mm.default(permute_1520, mul_182, cumsum_11); permute_1520 = mul_182 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '513'); convert_element_type_234 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_87, [0, 2, 1]); wait_tensor_87 = None + permute_1522 = torch.ops.aten.permute.default(permute_67, [0, 2, 1]); permute_67 = None + _grouped_mm_211 = torch.ops.aten._grouped_mm.default(index_97, permute_1522, cumsum_11); index_97 = permute_1522 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(_grouped_mm_9, torch.float32); _grouped_mm_9 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_238) + exp_11 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_236 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + div_19 = torch.ops.aten.div.Tensor(convert_element_type_238, add_236) + convert_element_type_239 = torch.ops.prims.convert_element_type.default(div_19, torch.bfloat16); div_19 = None + mul_2045 = torch.ops.aten.mul.Tensor(_grouped_mm_211, convert_element_type_239); convert_element_type_239 = None + mul_2046 = torch.ops.aten.mul.Tensor(_grouped_mm_211, _grouped_mm_10); _grouped_mm_211 = _grouped_mm_10 = None + permute_1524 = torch.ops.aten.permute.default(mul_2045, [1, 0]) + _grouped_mm_212 = torch.ops.aten._grouped_mm.default(permute_1524, index_7, cumsum_11); permute_1524 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 8, '513'); convert_element_type_235 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_88, [0, 2, 1]); wait_tensor_88 = None + permute_1526 = torch.ops.aten.permute.default(permute_66, [0, 2, 1]); permute_66 = None + _grouped_mm_213 = torch.ops.aten._grouped_mm.default(mul_2045, permute_1526, cumsum_11); mul_2045 = permute_1526 = None + convert_element_type_3106 = torch.ops.prims.convert_element_type.default(mul_2046, torch.float32); mul_2046 = None + reciprocal_45 = torch.ops.aten.reciprocal.default(add_236); add_236 = None + mul_2047 = torch.ops.aten.mul.Tensor(reciprocal_45, 1); reciprocal_45 = None + mul_2048 = torch.ops.aten.mul.Tensor(convert_element_type_3106, mul_2047); convert_element_type_3106 = None + sub_758 = torch.ops.aten.sub.Tensor(1, mul_2047); mul_2047 = None + mul_2049 = torch.ops.aten.mul.Tensor(convert_element_type_238, sub_758); convert_element_type_238 = sub_758 = None + add_2109 = torch.ops.aten.add.Tensor(mul_2049, 1); mul_2049 = None + mul_2050 = torch.ops.aten.mul.Tensor(mul_2048, add_2109); mul_2048 = add_2109 = None + convert_element_type_3108 = torch.ops.prims.convert_element_type.default(mul_2050, torch.bfloat16); mul_2050 = None + permute_1528 = torch.ops.aten.permute.default(convert_element_type_3108, [1, 0]) + _grouped_mm_214 = torch.ops.aten._grouped_mm.default(permute_1528, index_7, cumsum_11); permute_1528 = index_7 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16); primals_81 = None + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 8, '513'); convert_element_type_232 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_85, [0, 2, 1]); wait_tensor_85 = None + permute_1530 = torch.ops.aten.permute.default(permute_65, [0, 2, 1]); permute_65 = None + _grouped_mm_215 = torch.ops.aten._grouped_mm.default(convert_element_type_3108, permute_1530, cumsum_11); convert_element_type_3108 = permute_1530 = cumsum_11 = None + add_2110 = torch.ops.aten.add.Tensor(_grouped_mm_213, _grouped_mm_215); _grouped_mm_213 = _grouped_mm_215 = None + convert_element_type_3109 = torch.ops.prims.convert_element_type.default(_grouped_mm_212, torch.float32); _grouped_mm_212 = None + div_264 = torch.ops.aten.div.Tensor(convert_element_type_3109, 64); convert_element_type_3109 = None + reduce_scatter_tensor_313 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_264, 'sum', 8, '513'); div_264 = None + wait_tensor_917 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_313); reduce_scatter_tensor_313 = None + convert_element_type_3110 = torch.ops.prims.convert_element_type.default(_grouped_mm_210, torch.float32); _grouped_mm_210 = None + div_265 = torch.ops.aten.div.Tensor(convert_element_type_3110, 64); convert_element_type_3110 = None + reduce_scatter_tensor_314 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_265, 'sum', 8, '513'); div_265 = None + wait_tensor_918 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_314); reduce_scatter_tensor_314 = None + convert_element_type_3111 = torch.ops.prims.convert_element_type.default(_grouped_mm_214, torch.float32); _grouped_mm_214 = None + div_266 = torch.ops.aten.div.Tensor(convert_element_type_3111, 64); convert_element_type_3111 = None + reduce_scatter_tensor_315 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_266, 'sum', 8, '513'); div_266 = None + wait_tensor_919 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_315); reduce_scatter_tensor_315 = None + index_put_96 = torch.ops.aten.index_put.default(full_436, [getitem_64], add_2110, True); full_436 = getitem_64 = add_2110 = None + slice_239 = torch.ops.aten.slice.Tensor(index_put_96, 0, 0, add_2111); index_put_96 = add_2111 = None + all_to_all_single_123 = torch.ops._c10d_functional.all_to_all_single.default(slice_239, [_local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55], [_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63], '521'); slice_239 = _local_scalar_dense_48 = _local_scalar_dense_49 = _local_scalar_dense_50 = _local_scalar_dense_51 = _local_scalar_dense_52 = _local_scalar_dense_53 = _local_scalar_dense_54 = _local_scalar_dense_55 = _local_scalar_dense_56 = _local_scalar_dense_57 = _local_scalar_dense_58 = _local_scalar_dense_59 = _local_scalar_dense_60 = _local_scalar_dense_61 = _local_scalar_dense_62 = _local_scalar_dense_63 = None + wait_tensor_920 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_123); all_to_all_single_123 = None + index_put_97 = torch.ops.aten.index_put.default(full_default_52, [div_17], wait_tensor_920, True); div_17 = wait_tensor_920 = None + add_2115 = torch.ops.aten.add.Tensor(add_2107, index_put_97); add_2107 = index_put_97 = None + mul_2051 = torch.ops.aten.mul.Tensor(view_2201, 1.0); view_2201 = None + scatter_add_22 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_61, mul_2051); getitem_61 = mul_2051 = None + convert_element_type_227 = torch.ops.prims.convert_element_type.default(mm_35, torch.float32); mm_35 = None + sub_72 = torch.ops.aten.sub.Tensor(convert_element_type_227, amax_3); convert_element_type_227 = amax_3 = None + exp_10 = torch.ops.aten.exp.default(sub_72); sub_72 = None + div_16 = torch.ops.aten.div.Tensor(exp_10, sum_13); exp_10 = sum_13 = None + mul_2052 = torch.ops.aten.mul.Tensor(scatter_add_22, div_16); scatter_add_22 = None + sum_283 = torch.ops.aten.sum.dim_IntList(mul_2052, [1], True) + neg_121 = torch.ops.aten.neg.default(div_16); div_16 = None + fma_22 = torch.ops.prims.fma.default(neg_121, sum_283, mul_2052); neg_121 = sum_283 = mul_2052 = None + convert_element_type_3112 = torch.ops.prims.convert_element_type.default(fma_22, torch.bfloat16); fma_22 = None + permute_1532 = torch.ops.aten.permute.default(convert_element_type_3112, [1, 0]) + mm_576 = torch.ops.aten.mm.default(permute_1532, view_259); permute_1532 = view_259 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); primals_79 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_224, 64, '0'); convert_element_type_224 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + permute_1534 = torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_577 = torch.ops.aten.mm.default(convert_element_type_3112, permute_1534); convert_element_type_3112 = permute_1534 = None + add_2116 = torch.ops.aten.add.Tensor(add_2115, mm_577); add_2115 = mm_577 = None + convert_element_type_3117 = torch.ops.prims.convert_element_type.default(mm_576, torch.float32); mm_576 = None + reduce_scatter_tensor_316 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3117, 'avg', 64, '0'); convert_element_type_3117 = None + wait_tensor_921 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_316); reduce_scatter_tensor_316 = None + view_2203 = torch.ops.aten.view.default(add_2116, [2, 4096, 2048]); add_2116 = None + convert_element_type_3118 = torch.ops.prims.convert_element_type.default(view_2203, torch.float32); view_2203 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 64, '0'); convert_element_type_221 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + convert_element_type_3120 = torch.ops.prims.convert_element_type.default(wait_tensor_80, torch.float32); wait_tensor_80 = None + mul_2053 = torch.ops.aten.mul.Tensor(convert_element_type_3118, convert_element_type_3120); convert_element_type_3120 = None + convert_element_type_222 = torch.ops.prims.convert_element_type.default(add_212, torch.float32); add_212 = None + mul_162 = torch.ops.aten.mul.Tensor(convert_element_type_222, rsqrt_14); convert_element_type_222 = None + mul_2055 = torch.ops.aten.mul.Tensor(mul_162, mul_2053) + sum_284 = torch.ops.aten.sum.dim_IntList(mul_2055, [2], True); mul_2055 = None + div_267 = torch.ops.aten.div.Tensor(mul_162, 2048) + mul_2056 = torch.ops.aten.mul.Tensor(div_267, sum_284); div_267 = sum_284 = None + sub_760 = torch.ops.aten.sub.Tensor(mul_2053, mul_2056); mul_2053 = mul_2056 = None + mul_2057 = torch.ops.aten.mul.Tensor(sub_760, rsqrt_14); sub_760 = rsqrt_14 = None + mul_2058 = torch.ops.aten.mul.Tensor(convert_element_type_3118, mul_162); convert_element_type_3118 = mul_162 = None + sum_285 = torch.ops.aten.sum.dim_IntList(mul_2058, [0, 1]); mul_2058 = None + convert_element_type_3121 = torch.ops.prims.convert_element_type.default(mul_2057, torch.bfloat16); mul_2057 = None + add_2117 = torch.ops.aten.add.Tensor(add_2104, convert_element_type_3121); add_2104 = convert_element_type_3121 = None + convert_element_type_default_15 = torch.ops.prims.convert_element_type.default(sum_285, torch.float32); sum_285 = None + reduce_scatter_tensor_317 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_15, 'avg', 64, '0'); convert_element_type_default_15 = None + wait_tensor_922 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_317); reduce_scatter_tensor_317 = None + view_2204 = torch.ops.aten.view.default(add_2117, [8192, 2048]) + permute_1536 = torch.ops.aten.permute.default(view_2204, [1, 0]) + permute_62 = torch.ops.aten.permute.default(getitem_57, [0, 2, 1, 3]) + view_254 = torch.ops.aten.view.default(permute_62, [2, 4096, -1]); permute_62 = None + view_256 = torch.ops.aten.view.default(view_254, [8192, 2048]); view_254 = None + mm_578 = torch.ops.aten.mm.default(permute_1536, view_256); permute_1536 = view_256 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16); primals_76 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 64, '0'); convert_element_type_218 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + permute_1538 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_579 = torch.ops.aten.mm.default(view_2204, permute_1538); view_2204 = permute_1538 = None + view_2205 = torch.ops.aten.view.default(mm_579, [2, 4096, 2048]); mm_579 = None + convert_element_type_3128 = torch.ops.prims.convert_element_type.default(mm_578, torch.float32); mm_578 = None + reduce_scatter_tensor_318 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3128, 'avg', 64, '0'); convert_element_type_3128 = None + wait_tensor_923 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_318); reduce_scatter_tensor_318 = None + view_2206 = torch.ops.aten.view.default(view_2205, [2, 4096, 16, 128]); view_2205 = None + permute_1540 = torch.ops.aten.permute.default(view_2206, [0, 2, 1, 3]); view_2206 = None + fw_graph22 = self.fw_graph22 + joint_graph22 = self.joint_graph22 + mask_graph22 = self.mask_graph22 + flex_attention_backward_22 = torch.ops.higher_order.flex_attention_backward(permute_59, permute_60, permute_61, getitem_57, getitem_58, permute_1540, None, fw_graph22, joint_graph22, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph22), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_59 = permute_60 = permute_61 = getitem_57 = getitem_58 = permute_1540 = fw_graph22 = joint_graph22 = mask_graph22 = None + getitem_461 = flex_attention_backward_22[0] + getitem_462 = flex_attention_backward_22[1] + getitem_463 = flex_attention_backward_22[2]; flex_attention_backward_22 = None + permute_1541 = torch.ops.aten.permute.default(getitem_463, [0, 2, 1, 3]); getitem_463 = None + permute_1542 = torch.ops.aten.permute.default(getitem_462, [0, 2, 1, 3]); getitem_462 = None + permute_1543 = torch.ops.aten.permute.default(getitem_461, [0, 2, 1, 3]); getitem_461 = None + slice_241 = torch.ops.aten.slice.Tensor(permute_1542, 3, 0, 128) + slice_242 = torch.ops.aten.slice.Tensor(permute_1542, 3, 128, 192); permute_1542 = None + sum_286 = torch.ops.aten.sum.dim_IntList(slice_242, [2], True); slice_242 = None + cat_146 = torch.ops.aten.cat.default([slice_241, permute_1541], 3); slice_241 = permute_1541 = None + view_2207 = torch.ops.aten.view.default(cat_146, [2, 4096, 4096]); cat_146 = None + view_2208 = torch.ops.aten.view.default(view_2207, [8192, 4096]); view_2207 = None + permute_1544 = torch.ops.aten.permute.default(view_2208, [1, 0]) + mm_580 = torch.ops.aten.mm.default(permute_1544, view_251); permute_1544 = view_251 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 64, '0'); convert_element_type_215 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_58 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + permute_1546 = torch.ops.aten.permute.default(permute_58, [1, 0]); permute_58 = None + mm_581 = torch.ops.aten.mm.default(view_2208, permute_1546); view_2208 = permute_1546 = None + view_2209 = torch.ops.aten.view.default(mm_581, [2, 4096, 512]); mm_581 = None + convert_element_type_3133 = torch.ops.prims.convert_element_type.default(mm_580, torch.float32); mm_580 = None + reduce_scatter_tensor_319 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3133, 'avg', 64, '0'); convert_element_type_3133 = None + wait_tensor_924 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_319); reduce_scatter_tensor_319 = None + convert_element_type_3134 = torch.ops.prims.convert_element_type.default(view_2209, torch.float32); view_2209 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_212, 64, '0'); convert_element_type_212 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + convert_element_type_3136 = torch.ops.prims.convert_element_type.default(wait_tensor_77, torch.float32); wait_tensor_77 = None + mul_2059 = torch.ops.aten.mul.Tensor(convert_element_type_3134, convert_element_type_3136); convert_element_type_3136 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(getitem_53, torch.float32); getitem_53 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_213, rsqrt_13); convert_element_type_213 = None + mul_2061 = torch.ops.aten.mul.Tensor(mul_160, mul_2059) + sum_287 = torch.ops.aten.sum.dim_IntList(mul_2061, [2], True); mul_2061 = None + div_268 = torch.ops.aten.div.Tensor(mul_160, 512) + mul_2062 = torch.ops.aten.mul.Tensor(div_268, sum_287); div_268 = sum_287 = None + sub_761 = torch.ops.aten.sub.Tensor(mul_2059, mul_2062); mul_2059 = mul_2062 = None + mul_2063 = torch.ops.aten.mul.Tensor(sub_761, rsqrt_13); sub_761 = rsqrt_13 = None + mul_2064 = torch.ops.aten.mul.Tensor(convert_element_type_3134, mul_160); convert_element_type_3134 = mul_160 = None + sum_288 = torch.ops.aten.sum.dim_IntList(mul_2064, [0, 1]); mul_2064 = None + convert_element_type_3137 = torch.ops.prims.convert_element_type.default(mul_2063, torch.bfloat16); mul_2063 = None + convert_element_type_default_14 = torch.ops.prims.convert_element_type.default(sum_288, torch.float32); sum_288 = None + reduce_scatter_tensor_320 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_14, 'avg', 64, '0'); convert_element_type_default_14 = None + wait_tensor_925 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_320); reduce_scatter_tensor_320 = None + convert_element_type_3140 = torch.ops.prims.convert_element_type.default(sum_286, torch.float32); sum_286 = None + view_2210 = torch.ops.aten.view.default(convert_element_type_3140, [2, 4096, 1, 32, 2]); convert_element_type_3140 = None + view_as_complex_98 = torch.ops.aten.view_as_complex.default(view_2210); view_2210 = None + mul_2065 = torch.ops.aten.mul.Tensor(view_as_complex_98, clone_9); view_as_complex_98 = None + view_as_real_98 = torch.ops.aten.view_as_real.default(mul_2065); mul_2065 = None + view_2211 = torch.ops.aten.view.default(view_as_real_98, [2, 4096, 1, 64]); view_as_real_98 = None + convert_element_type_3141 = torch.ops.prims.convert_element_type.default(view_2211, torch.bfloat16); view_2211 = None + squeeze_48 = torch.ops.aten.squeeze.dim(convert_element_type_3141, 2); convert_element_type_3141 = None + cat_147 = torch.ops.aten.cat.default([convert_element_type_3137, squeeze_48], 2); convert_element_type_3137 = squeeze_48 = None + view_2212 = torch.ops.aten.view.default(cat_147, [8192, 576]); cat_147 = None + permute_1548 = torch.ops.aten.permute.default(view_2212, [1, 0]) + mm_582 = torch.ops.aten.mm.default(permute_1548, view_237); permute_1548 = None + convert_element_type_207 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_207, 64, '0'); convert_element_type_207 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + permute_1550 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_583 = torch.ops.aten.mm.default(view_2212, permute_1550); view_2212 = permute_1550 = None + view_2213 = torch.ops.aten.view.default(mm_583, [2, 4096, 2048]); mm_583 = None + convert_element_type_3146 = torch.ops.prims.convert_element_type.default(mm_582, torch.float32); mm_582 = None + reduce_scatter_tensor_321 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3146, 'avg', 64, '0'); convert_element_type_3146 = None + wait_tensor_926 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_321); reduce_scatter_tensor_321 = None + slice_243 = torch.ops.aten.slice.Tensor(permute_1543, 3, 0, 128) + slice_244 = torch.ops.aten.slice.Tensor(permute_1543, 3, 128, 192); permute_1543 = None + convert_element_type_3147 = torch.ops.prims.convert_element_type.default(slice_244, torch.float32); slice_244 = None + view_2214 = torch.ops.aten.view.default(convert_element_type_3147, [2, 4096, 16, 32, 2]); convert_element_type_3147 = None + view_as_complex_99 = torch.ops.aten.view_as_complex.default(view_2214); view_2214 = None + mul_2066 = torch.ops.aten.mul.Tensor(view_as_complex_99, clone_9); view_as_complex_99 = None + view_as_real_99 = torch.ops.aten.view_as_real.default(mul_2066); mul_2066 = None + view_2215 = torch.ops.aten.view.default(view_as_real_99, [2, 4096, 16, 64]); view_as_real_99 = None + convert_element_type_3148 = torch.ops.prims.convert_element_type.default(view_2215, torch.bfloat16); view_2215 = None + cat_148 = torch.ops.aten.cat.default([slice_243, convert_element_type_3148], 3); slice_243 = convert_element_type_3148 = None + view_2216 = torch.ops.aten.view.default(cat_148, [2, 4096, 3072]); cat_148 = None + view_2217 = torch.ops.aten.view.default(view_2216, [8192, 3072]); view_2216 = None + permute_1552 = torch.ops.aten.permute.default(view_2217, [1, 0]) + mm_584 = torch.ops.aten.mm.default(permute_1552, view_237); permute_1552 = view_237 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 64, '0'); convert_element_type_202 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + permute_1554 = torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_585 = torch.ops.aten.mm.default(view_2217, permute_1554); view_2217 = permute_1554 = None + view_2218 = torch.ops.aten.view.default(mm_585, [2, 4096, 2048]); mm_585 = None + add_2118 = torch.ops.aten.add.Tensor(view_2213, view_2218); view_2213 = view_2218 = None + convert_element_type_3153 = torch.ops.prims.convert_element_type.default(mm_584, torch.float32); mm_584 = None + reduce_scatter_tensor_322 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3153, 'avg', 64, '0'); convert_element_type_3153 = None + wait_tensor_927 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_322); reduce_scatter_tensor_322 = None + convert_element_type_3154 = torch.ops.prims.convert_element_type.default(add_2118, torch.float32); add_2118 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 64, '0'); convert_element_type_199 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_3156 = torch.ops.prims.convert_element_type.default(wait_tensor_74, torch.float32); wait_tensor_74 = None + mul_2067 = torch.ops.aten.mul.Tensor(convert_element_type_3154, convert_element_type_3156); convert_element_type_3156 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_209, torch.float32); add_209 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_2069 = torch.ops.aten.mul.Tensor(mul_156, mul_2067) + sum_289 = torch.ops.aten.sum.dim_IntList(mul_2069, [2], True); mul_2069 = None + div_269 = torch.ops.aten.div.Tensor(mul_156, 2048) + mul_2070 = torch.ops.aten.mul.Tensor(div_269, sum_289); div_269 = sum_289 = None + sub_762 = torch.ops.aten.sub.Tensor(mul_2067, mul_2070); mul_2067 = mul_2070 = None + mul_2071 = torch.ops.aten.mul.Tensor(sub_762, rsqrt_12); sub_762 = rsqrt_12 = None + mul_2072 = torch.ops.aten.mul.Tensor(convert_element_type_3154, mul_156); convert_element_type_3154 = mul_156 = None + sum_290 = torch.ops.aten.sum.dim_IntList(mul_2072, [0, 1]); mul_2072 = None + convert_element_type_3157 = torch.ops.prims.convert_element_type.default(mul_2071, torch.bfloat16); mul_2071 = None + add_2119 = torch.ops.aten.add.Tensor(add_2117, convert_element_type_3157); add_2117 = convert_element_type_3157 = None + convert_element_type_default_13 = torch.ops.prims.convert_element_type.default(sum_290, torch.float32); sum_290 = None + reduce_scatter_tensor_323 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_13, 'avg', 64, '0'); convert_element_type_default_13 = None + wait_tensor_928 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_323); reduce_scatter_tensor_323 = None + view_2219 = torch.ops.aten.view.default(add_2119, [8192, 2048]) + unsqueeze_76 = torch.ops.aten.unsqueeze.default(view_2219, 1) + convert_element_type_3160 = torch.ops.prims.convert_element_type.default(unsqueeze_76, torch.float32); unsqueeze_76 = None + bmm_72 = torch.ops.aten.bmm.default(permute_1556, convert_element_type_3160); permute_1556 = None + bmm_73 = torch.ops.aten.bmm.default(convert_element_type_3160, permute_1557); convert_element_type_3160 = permute_1557 = None + convert_element_type_3161 = torch.ops.prims.convert_element_type.default(bmm_72, torch.bfloat16); bmm_72 = None + view_2220 = torch.ops.aten.view.default(bmm_73, [8192, 6]); bmm_73 = None + view_2221 = torch.ops.aten.view.default(convert_element_type_3161, [49152, 2048]); convert_element_type_3161 = None + index_98 = torch.ops.aten.index.Tensor(view_2221, [getitem_49]); view_2221 = getitem_49 = None + permute_1558 = torch.ops.aten.permute.default(view_2219, [1, 0]) + mm_586 = torch.ops.aten.mm.default(permute_1558, mul_153); permute_1558 = mul_153 = None + convert_element_type_194 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_194, 64, '0'); convert_element_type_194 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_73, [1, 0]); wait_tensor_73 = None + permute_1560 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_587 = torch.ops.aten.mm.default(view_2219, permute_1560); view_2219 = permute_1560 = None + convert_element_type_3166 = torch.ops.prims.convert_element_type.default(mm_586, torch.float32); mm_586 = None + reduce_scatter_tensor_324 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3166, 'avg', 64, '0'); convert_element_type_3166 = None + wait_tensor_929 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_324); reduce_scatter_tensor_324 = None + convert_element_type_189 = torch.ops.prims.convert_element_type.default(mm_28, torch.float32); mm_28 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_189) + exp_9 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_204 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + div_15 = torch.ops.aten.div.Tensor(convert_element_type_189, add_204) + convert_element_type_190 = torch.ops.prims.convert_element_type.default(div_15, torch.bfloat16); div_15 = None + mul_2073 = torch.ops.aten.mul.Tensor(mm_587, convert_element_type_190); convert_element_type_190 = None + mul_2074 = torch.ops.aten.mul.Tensor(mm_587, mm_29); mm_587 = mm_29 = None + permute_1562 = torch.ops.aten.permute.default(mul_2073, [1, 0]) + mm_588 = torch.ops.aten.mm.default(permute_1562, view_192); permute_1562 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_191, 64, '0'); convert_element_type_191 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_1564 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_589 = torch.ops.aten.mm.default(mul_2073, permute_1564); mul_2073 = permute_1564 = None + convert_element_type_3171 = torch.ops.prims.convert_element_type.default(mm_588, torch.float32); mm_588 = None + reduce_scatter_tensor_325 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3171, 'avg', 64, '0'); convert_element_type_3171 = None + wait_tensor_930 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_325); reduce_scatter_tensor_325 = None + convert_element_type_3172 = torch.ops.prims.convert_element_type.default(mul_2074, torch.float32); mul_2074 = None + reciprocal_46 = torch.ops.aten.reciprocal.default(add_204); add_204 = None + mul_2075 = torch.ops.aten.mul.Tensor(reciprocal_46, 1); reciprocal_46 = None + mul_2076 = torch.ops.aten.mul.Tensor(convert_element_type_3172, mul_2075); convert_element_type_3172 = None + sub_763 = torch.ops.aten.sub.Tensor(1, mul_2075); mul_2075 = None + mul_2077 = torch.ops.aten.mul.Tensor(convert_element_type_189, sub_763); convert_element_type_189 = sub_763 = None + add_2121 = torch.ops.aten.add.Tensor(mul_2077, 1); mul_2077 = None + mul_2078 = torch.ops.aten.mul.Tensor(mul_2076, add_2121); mul_2076 = add_2121 = None + convert_element_type_3174 = torch.ops.prims.convert_element_type.default(mul_2078, torch.bfloat16); mul_2078 = None + permute_1566 = torch.ops.aten.permute.default(convert_element_type_3174, [1, 0]) + mm_590 = torch.ops.aten.mm.default(permute_1566, view_192); permute_1566 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_186, 64, '0'); convert_element_type_186 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + permute_1568 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_591 = torch.ops.aten.mm.default(convert_element_type_3174, permute_1568); convert_element_type_3174 = permute_1568 = None + add_2122 = torch.ops.aten.add.Tensor(mm_589, mm_591); mm_589 = mm_591 = None + convert_element_type_3179 = torch.ops.prims.convert_element_type.default(mm_590, torch.float32); mm_590 = None + reduce_scatter_tensor_326 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3179, 'avg', 64, '0'); convert_element_type_3179 = None + wait_tensor_931 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_326); reduce_scatter_tensor_326 = None + all_to_all_single_124 = torch.ops._c10d_functional.all_to_all_single.default(index_98, [_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47], [_local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39], '521'); index_98 = None + wait_tensor_932 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_124); all_to_all_single_124 = None + full_440 = torch.ops.aten.full.default([sym_size_int_9, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_9 = None + slice_scatter_23 = torch.ops.aten.slice_scatter.default(full_440, wait_tensor_932, 0, 0, -1); wait_tensor_932 = None + index_99 = torch.ops.aten.index.Tensor(slice_scatter_23, [getitem_50]); slice_scatter_23 = None + permute_1570 = torch.ops.aten.permute.default(index_99, [1, 0]) + _grouped_mm_216 = torch.ops.aten._grouped_mm.default(permute_1570, mul_133, cumsum_8); permute_1570 = mul_133 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_180, 8, '513'); convert_element_type_180 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_66, [0, 2, 1]); wait_tensor_66 = None + permute_1572 = torch.ops.aten.permute.default(permute_52, [0, 2, 1]); permute_52 = None + _grouped_mm_217 = torch.ops.aten._grouped_mm.default(index_99, permute_1572, cumsum_8); index_99 = permute_1572 = None + convert_element_type_184 = torch.ops.prims.convert_element_type.default(_grouped_mm_6, torch.float32); _grouped_mm_6 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_184) + exp_8 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_168 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + div_14 = torch.ops.aten.div.Tensor(convert_element_type_184, add_168) + convert_element_type_185 = torch.ops.prims.convert_element_type.default(div_14, torch.bfloat16); div_14 = None + mul_2079 = torch.ops.aten.mul.Tensor(_grouped_mm_217, convert_element_type_185); convert_element_type_185 = None + mul_2080 = torch.ops.aten.mul.Tensor(_grouped_mm_217, _grouped_mm_7); _grouped_mm_217 = _grouped_mm_7 = None + permute_1574 = torch.ops.aten.permute.default(mul_2079, [1, 0]) + _grouped_mm_218 = torch.ops.aten._grouped_mm.default(permute_1574, index_5, cumsum_8); permute_1574 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_181, 8, '513'); convert_element_type_181 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_67, [0, 2, 1]); wait_tensor_67 = None + permute_1576 = torch.ops.aten.permute.default(permute_51, [0, 2, 1]); permute_51 = None + _grouped_mm_219 = torch.ops.aten._grouped_mm.default(mul_2079, permute_1576, cumsum_8); mul_2079 = permute_1576 = None + convert_element_type_3180 = torch.ops.prims.convert_element_type.default(mul_2080, torch.float32); mul_2080 = None + reciprocal_47 = torch.ops.aten.reciprocal.default(add_168); add_168 = None + mul_2081 = torch.ops.aten.mul.Tensor(reciprocal_47, 1); reciprocal_47 = None + mul_2082 = torch.ops.aten.mul.Tensor(convert_element_type_3180, mul_2081); convert_element_type_3180 = None + sub_764 = torch.ops.aten.sub.Tensor(1, mul_2081); mul_2081 = None + mul_2083 = torch.ops.aten.mul.Tensor(convert_element_type_184, sub_764); convert_element_type_184 = sub_764 = None + add_2124 = torch.ops.aten.add.Tensor(mul_2083, 1); mul_2083 = None + mul_2084 = torch.ops.aten.mul.Tensor(mul_2082, add_2124); mul_2082 = add_2124 = None + convert_element_type_3182 = torch.ops.prims.convert_element_type.default(mul_2084, torch.bfloat16); mul_2084 = None + permute_1578 = torch.ops.aten.permute.default(convert_element_type_3182, [1, 0]) + _grouped_mm_220 = torch.ops.aten._grouped_mm.default(permute_1578, index_5, cumsum_8); permute_1578 = index_5 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_178, 8, '513'); convert_element_type_178 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_50 = torch.ops.aten.permute.default(wait_tensor_64, [0, 2, 1]); wait_tensor_64 = None + permute_1580 = torch.ops.aten.permute.default(permute_50, [0, 2, 1]); permute_50 = None + _grouped_mm_221 = torch.ops.aten._grouped_mm.default(convert_element_type_3182, permute_1580, cumsum_8); convert_element_type_3182 = permute_1580 = cumsum_8 = None + add_2125 = torch.ops.aten.add.Tensor(_grouped_mm_219, _grouped_mm_221); _grouped_mm_219 = _grouped_mm_221 = None + convert_element_type_3183 = torch.ops.prims.convert_element_type.default(_grouped_mm_218, torch.float32); _grouped_mm_218 = None + div_270 = torch.ops.aten.div.Tensor(convert_element_type_3183, 64); convert_element_type_3183 = None + reduce_scatter_tensor_327 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_270, 'sum', 8, '513'); div_270 = None + wait_tensor_933 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_327); reduce_scatter_tensor_327 = None + convert_element_type_3184 = torch.ops.prims.convert_element_type.default(_grouped_mm_216, torch.float32); _grouped_mm_216 = None + div_271 = torch.ops.aten.div.Tensor(convert_element_type_3184, 64); convert_element_type_3184 = None + reduce_scatter_tensor_328 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_271, 'sum', 8, '513'); div_271 = None + wait_tensor_934 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_328); reduce_scatter_tensor_328 = None + convert_element_type_3185 = torch.ops.prims.convert_element_type.default(_grouped_mm_220, torch.float32); _grouped_mm_220 = None + div_272 = torch.ops.aten.div.Tensor(convert_element_type_3185, 64); convert_element_type_3185 = None + reduce_scatter_tensor_329 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_272, 'sum', 8, '513'); div_272 = None + wait_tensor_935 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_329); reduce_scatter_tensor_329 = None + index_put_98 = torch.ops.aten.index_put.default(full_440, [getitem_50], add_2125, True); full_440 = getitem_50 = add_2125 = None + slice_245 = torch.ops.aten.slice.Tensor(index_put_98, 0, 0, add_2126); index_put_98 = add_2126 = None + all_to_all_single_125 = torch.ops._c10d_functional.all_to_all_single.default(slice_245, [_local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39], [_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47], '521'); slice_245 = _local_scalar_dense_32 = _local_scalar_dense_33 = _local_scalar_dense_34 = _local_scalar_dense_35 = _local_scalar_dense_36 = _local_scalar_dense_37 = _local_scalar_dense_38 = _local_scalar_dense_39 = _local_scalar_dense_40 = _local_scalar_dense_41 = _local_scalar_dense_42 = _local_scalar_dense_43 = _local_scalar_dense_44 = _local_scalar_dense_45 = _local_scalar_dense_46 = _local_scalar_dense_47 = None + wait_tensor_936 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_125); all_to_all_single_125 = None + index_put_99 = torch.ops.aten.index_put.default(full_default_52, [div_12], wait_tensor_936, True); div_12 = wait_tensor_936 = None + add_2130 = torch.ops.aten.add.Tensor(add_2122, index_put_99); add_2122 = index_put_99 = None + mul_2085 = torch.ops.aten.mul.Tensor(view_2220, 1.0); view_2220 = None + scatter_add_23 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_47, mul_2085); getitem_47 = mul_2085 = None + convert_element_type_173 = torch.ops.prims.convert_element_type.default(mm_27, torch.float32); mm_27 = None + sub_48 = torch.ops.aten.sub.Tensor(convert_element_type_173, amax_2); convert_element_type_173 = amax_2 = None + exp_7 = torch.ops.aten.exp.default(sub_48); sub_48 = None + div_11 = torch.ops.aten.div.Tensor(exp_7, sum_9); exp_7 = sum_9 = None + mul_2086 = torch.ops.aten.mul.Tensor(scatter_add_23, div_11); scatter_add_23 = None + sum_291 = torch.ops.aten.sum.dim_IntList(mul_2086, [1], True) + neg_124 = torch.ops.aten.neg.default(div_11); div_11 = None + fma_23 = torch.ops.prims.fma.default(neg_124, sum_291, mul_2086); neg_124 = sum_291 = mul_2086 = None + convert_element_type_3186 = torch.ops.prims.convert_element_type.default(fma_23, torch.bfloat16); fma_23 = None + permute_1582 = torch.ops.aten.permute.default(convert_element_type_3186, [1, 0]) + mm_592 = torch.ops.aten.mm.default(permute_1582, view_192); permute_1582 = view_192 = None + convert_element_type_170 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_170, 64, '0'); convert_element_type_170 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_49 = torch.ops.aten.permute.default(wait_tensor_60, [1, 0]); wait_tensor_60 = None + permute_1584 = torch.ops.aten.permute.default(permute_49, [1, 0]); permute_49 = None + mm_593 = torch.ops.aten.mm.default(convert_element_type_3186, permute_1584); convert_element_type_3186 = permute_1584 = None + add_2131 = torch.ops.aten.add.Tensor(add_2130, mm_593); add_2130 = mm_593 = None + convert_element_type_3191 = torch.ops.prims.convert_element_type.default(mm_592, torch.float32); mm_592 = None + reduce_scatter_tensor_330 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3191, 'avg', 64, '0'); convert_element_type_3191 = None + wait_tensor_937 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_330); reduce_scatter_tensor_330 = None + view_2222 = torch.ops.aten.view.default(add_2131, [2, 4096, 2048]); add_2131 = None + convert_element_type_3192 = torch.ops.prims.convert_element_type.default(view_2222, torch.float32); view_2222 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_167, 64, '0'); convert_element_type_167 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_3194 = torch.ops.prims.convert_element_type.default(wait_tensor_59, torch.float32); wait_tensor_59 = None + mul_2087 = torch.ops.aten.mul.Tensor(convert_element_type_3192, convert_element_type_3194); convert_element_type_3194 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(add_144, torch.float32); add_144 = None + mul_113 = torch.ops.aten.mul.Tensor(convert_element_type_168, rsqrt_11); convert_element_type_168 = None + mul_2089 = torch.ops.aten.mul.Tensor(mul_113, mul_2087) + sum_292 = torch.ops.aten.sum.dim_IntList(mul_2089, [2], True); mul_2089 = None + div_273 = torch.ops.aten.div.Tensor(mul_113, 2048) + mul_2090 = torch.ops.aten.mul.Tensor(div_273, sum_292); div_273 = sum_292 = None + sub_766 = torch.ops.aten.sub.Tensor(mul_2087, mul_2090); mul_2087 = mul_2090 = None + mul_2091 = torch.ops.aten.mul.Tensor(sub_766, rsqrt_11); sub_766 = rsqrt_11 = None + mul_2092 = torch.ops.aten.mul.Tensor(convert_element_type_3192, mul_113); convert_element_type_3192 = mul_113 = None + sum_293 = torch.ops.aten.sum.dim_IntList(mul_2092, [0, 1]); mul_2092 = None + convert_element_type_3195 = torch.ops.prims.convert_element_type.default(mul_2091, torch.bfloat16); mul_2091 = None + add_2132 = torch.ops.aten.add.Tensor(add_2119, convert_element_type_3195); add_2119 = convert_element_type_3195 = None + convert_element_type_default_12 = torch.ops.prims.convert_element_type.default(sum_293, torch.float32); sum_293 = None + reduce_scatter_tensor_331 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_12, 'avg', 64, '0'); convert_element_type_default_12 = None + wait_tensor_938 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_331); reduce_scatter_tensor_331 = None + view_2223 = torch.ops.aten.view.default(add_2132, [8192, 2048]) + permute_1586 = torch.ops.aten.permute.default(view_2223, [1, 0]) + permute_47 = torch.ops.aten.permute.default(getitem_43, [0, 2, 1, 3]) + view_187 = torch.ops.aten.view.default(permute_47, [2, 4096, -1]); permute_47 = None + view_189 = torch.ops.aten.view.default(view_187, [8192, 2048]); view_187 = None + mm_594 = torch.ops.aten.mm.default(permute_1586, view_189); permute_1586 = view_189 = None + convert_element_type_164 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_164, 64, '0'); convert_element_type_164 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_48 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_1588 = torch.ops.aten.permute.default(permute_48, [1, 0]); permute_48 = None + mm_595 = torch.ops.aten.mm.default(view_2223, permute_1588); view_2223 = permute_1588 = None + view_2224 = torch.ops.aten.view.default(mm_595, [2, 4096, 2048]); mm_595 = None + convert_element_type_3202 = torch.ops.prims.convert_element_type.default(mm_594, torch.float32); mm_594 = None + reduce_scatter_tensor_332 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3202, 'avg', 64, '0'); convert_element_type_3202 = None + wait_tensor_939 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_332); reduce_scatter_tensor_332 = None + view_2225 = torch.ops.aten.view.default(view_2224, [2, 4096, 16, 128]); view_2224 = None + permute_1590 = torch.ops.aten.permute.default(view_2225, [0, 2, 1, 3]); view_2225 = None + fw_graph23 = self.fw_graph23 + joint_graph23 = self.joint_graph23 + mask_graph23 = self.mask_graph23 + flex_attention_backward_23 = torch.ops.higher_order.flex_attention_backward(permute_44, permute_45, permute_46, getitem_43, getitem_44, permute_1590, None, fw_graph23, joint_graph23, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph23), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_44 = permute_45 = permute_46 = getitem_43 = getitem_44 = permute_1590 = fw_graph23 = joint_graph23 = mask_graph23 = None + getitem_465 = flex_attention_backward_23[0] + getitem_466 = flex_attention_backward_23[1] + getitem_467 = flex_attention_backward_23[2]; flex_attention_backward_23 = None + permute_1591 = torch.ops.aten.permute.default(getitem_467, [0, 2, 1, 3]); getitem_467 = None + permute_1592 = torch.ops.aten.permute.default(getitem_466, [0, 2, 1, 3]); getitem_466 = None + permute_1593 = torch.ops.aten.permute.default(getitem_465, [0, 2, 1, 3]); getitem_465 = None + slice_247 = torch.ops.aten.slice.Tensor(permute_1592, 3, 0, 128) + slice_248 = torch.ops.aten.slice.Tensor(permute_1592, 3, 128, 192); permute_1592 = None + sum_294 = torch.ops.aten.sum.dim_IntList(slice_248, [2], True); slice_248 = None + cat_149 = torch.ops.aten.cat.default([slice_247, permute_1591], 3); slice_247 = permute_1591 = None + view_2226 = torch.ops.aten.view.default(cat_149, [2, 4096, 4096]); cat_149 = None + view_2227 = torch.ops.aten.view.default(view_2226, [8192, 4096]); view_2226 = None + permute_1594 = torch.ops.aten.permute.default(view_2227, [1, 0]) + mm_596 = torch.ops.aten.mm.default(permute_1594, view_184); permute_1594 = view_184 = None + convert_element_type_161 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); primals_59 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_161, 64, '0'); convert_element_type_161 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + permute_1596 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_597 = torch.ops.aten.mm.default(view_2227, permute_1596); view_2227 = permute_1596 = None + view_2228 = torch.ops.aten.view.default(mm_597, [2, 4096, 512]); mm_597 = None + convert_element_type_3207 = torch.ops.prims.convert_element_type.default(mm_596, torch.float32); mm_596 = None + reduce_scatter_tensor_333 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3207, 'avg', 64, '0'); convert_element_type_3207 = None + wait_tensor_940 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_333); reduce_scatter_tensor_333 = None + convert_element_type_3208 = torch.ops.prims.convert_element_type.default(view_2228, torch.float32); view_2228 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_158, 64, '0'); convert_element_type_158 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + convert_element_type_3210 = torch.ops.prims.convert_element_type.default(wait_tensor_56, torch.float32); wait_tensor_56 = None + mul_2093 = torch.ops.aten.mul.Tensor(convert_element_type_3208, convert_element_type_3210); convert_element_type_3210 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(getitem_39, torch.float32); getitem_39 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_159, rsqrt_10); convert_element_type_159 = None + mul_2095 = torch.ops.aten.mul.Tensor(mul_111, mul_2093) + sum_295 = torch.ops.aten.sum.dim_IntList(mul_2095, [2], True); mul_2095 = None + div_274 = torch.ops.aten.div.Tensor(mul_111, 512) + mul_2096 = torch.ops.aten.mul.Tensor(div_274, sum_295); div_274 = sum_295 = None + sub_767 = torch.ops.aten.sub.Tensor(mul_2093, mul_2096); mul_2093 = mul_2096 = None + mul_2097 = torch.ops.aten.mul.Tensor(sub_767, rsqrt_10); sub_767 = rsqrt_10 = None + mul_2098 = torch.ops.aten.mul.Tensor(convert_element_type_3208, mul_111); convert_element_type_3208 = mul_111 = None + sum_296 = torch.ops.aten.sum.dim_IntList(mul_2098, [0, 1]); mul_2098 = None + convert_element_type_3211 = torch.ops.prims.convert_element_type.default(mul_2097, torch.bfloat16); mul_2097 = None + convert_element_type_default_11 = torch.ops.prims.convert_element_type.default(sum_296, torch.float32); sum_296 = None + reduce_scatter_tensor_334 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_11, 'avg', 64, '0'); convert_element_type_default_11 = None + wait_tensor_941 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_334); reduce_scatter_tensor_334 = None + convert_element_type_3214 = torch.ops.prims.convert_element_type.default(sum_294, torch.float32); sum_294 = None + view_2229 = torch.ops.aten.view.default(convert_element_type_3214, [2, 4096, 1, 32, 2]); convert_element_type_3214 = None + view_as_complex_100 = torch.ops.aten.view_as_complex.default(view_2229); view_2229 = None + mul_2099 = torch.ops.aten.mul.Tensor(view_as_complex_100, clone_9); view_as_complex_100 = None + view_as_real_100 = torch.ops.aten.view_as_real.default(mul_2099); mul_2099 = None + view_2230 = torch.ops.aten.view.default(view_as_real_100, [2, 4096, 1, 64]); view_as_real_100 = None + convert_element_type_3215 = torch.ops.prims.convert_element_type.default(view_2230, torch.bfloat16); view_2230 = None + squeeze_49 = torch.ops.aten.squeeze.dim(convert_element_type_3215, 2); convert_element_type_3215 = None + cat_150 = torch.ops.aten.cat.default([convert_element_type_3211, squeeze_49], 2); convert_element_type_3211 = squeeze_49 = None + view_2231 = torch.ops.aten.view.default(cat_150, [8192, 576]); cat_150 = None + permute_1598 = torch.ops.aten.permute.default(view_2231, [1, 0]) + mm_598 = torch.ops.aten.mm.default(permute_1598, view_170); permute_1598 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_153, 64, '0'); convert_element_type_153 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_55, [1, 0]); wait_tensor_55 = None + permute_1600 = torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_599 = torch.ops.aten.mm.default(view_2231, permute_1600); view_2231 = permute_1600 = None + view_2232 = torch.ops.aten.view.default(mm_599, [2, 4096, 2048]); mm_599 = None + convert_element_type_3220 = torch.ops.prims.convert_element_type.default(mm_598, torch.float32); mm_598 = None + reduce_scatter_tensor_335 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3220, 'avg', 64, '0'); convert_element_type_3220 = None + wait_tensor_942 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_335); reduce_scatter_tensor_335 = None + slice_249 = torch.ops.aten.slice.Tensor(permute_1593, 3, 0, 128) + slice_250 = torch.ops.aten.slice.Tensor(permute_1593, 3, 128, 192); permute_1593 = None + convert_element_type_3221 = torch.ops.prims.convert_element_type.default(slice_250, torch.float32); slice_250 = None + view_2233 = torch.ops.aten.view.default(convert_element_type_3221, [2, 4096, 16, 32, 2]); convert_element_type_3221 = None + view_as_complex_101 = torch.ops.aten.view_as_complex.default(view_2233); view_2233 = None + mul_2100 = torch.ops.aten.mul.Tensor(view_as_complex_101, clone_9); view_as_complex_101 = None + view_as_real_101 = torch.ops.aten.view_as_real.default(mul_2100); mul_2100 = None + view_2234 = torch.ops.aten.view.default(view_as_real_101, [2, 4096, 16, 64]); view_as_real_101 = None + convert_element_type_3222 = torch.ops.prims.convert_element_type.default(view_2234, torch.bfloat16); view_2234 = None + cat_151 = torch.ops.aten.cat.default([slice_249, convert_element_type_3222], 3); slice_249 = convert_element_type_3222 = None + view_2235 = torch.ops.aten.view.default(cat_151, [2, 4096, 3072]); cat_151 = None + view_2236 = torch.ops.aten.view.default(view_2235, [8192, 3072]); view_2235 = None + permute_1602 = torch.ops.aten.permute.default(view_2236, [1, 0]) + mm_600 = torch.ops.aten.mm.default(permute_1602, view_170); permute_1602 = view_170 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_148, 64, '0'); convert_element_type_148 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + permute_1604 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_601 = torch.ops.aten.mm.default(view_2236, permute_1604); view_2236 = permute_1604 = None + view_2237 = torch.ops.aten.view.default(mm_601, [2, 4096, 2048]); mm_601 = None + add_2133 = torch.ops.aten.add.Tensor(view_2232, view_2237); view_2232 = view_2237 = None + convert_element_type_3227 = torch.ops.prims.convert_element_type.default(mm_600, torch.float32); mm_600 = None + reduce_scatter_tensor_336 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3227, 'avg', 64, '0'); convert_element_type_3227 = None + wait_tensor_943 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_336); reduce_scatter_tensor_336 = None + convert_element_type_3228 = torch.ops.prims.convert_element_type.default(add_2133, torch.float32); add_2133 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_145, 64, '0'); convert_element_type_145 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_3230 = torch.ops.prims.convert_element_type.default(wait_tensor_53, torch.float32); wait_tensor_53 = None + mul_2101 = torch.ops.aten.mul.Tensor(convert_element_type_3228, convert_element_type_3230); convert_element_type_3230 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(add_141, torch.float32); add_141 = None + mul_107 = torch.ops.aten.mul.Tensor(convert_element_type_146, rsqrt_9); convert_element_type_146 = None + mul_2103 = torch.ops.aten.mul.Tensor(mul_107, mul_2101) + sum_297 = torch.ops.aten.sum.dim_IntList(mul_2103, [2], True); mul_2103 = None + div_275 = torch.ops.aten.div.Tensor(mul_107, 2048) + mul_2104 = torch.ops.aten.mul.Tensor(div_275, sum_297); div_275 = sum_297 = None + sub_768 = torch.ops.aten.sub.Tensor(mul_2101, mul_2104); mul_2101 = mul_2104 = None + mul_2105 = torch.ops.aten.mul.Tensor(sub_768, rsqrt_9); sub_768 = rsqrt_9 = None + mul_2106 = torch.ops.aten.mul.Tensor(convert_element_type_3228, mul_107); convert_element_type_3228 = mul_107 = None + sum_298 = torch.ops.aten.sum.dim_IntList(mul_2106, [0, 1]); mul_2106 = None + convert_element_type_3231 = torch.ops.prims.convert_element_type.default(mul_2105, torch.bfloat16); mul_2105 = None + add_2134 = torch.ops.aten.add.Tensor(add_2132, convert_element_type_3231); add_2132 = convert_element_type_3231 = None + convert_element_type_default_10 = torch.ops.prims.convert_element_type.default(sum_298, torch.float32); sum_298 = None + reduce_scatter_tensor_337 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_10, 'avg', 64, '0'); convert_element_type_default_10 = None + wait_tensor_944 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_337); reduce_scatter_tensor_337 = None + view_2238 = torch.ops.aten.view.default(add_2134, [8192, 2048]) + unsqueeze_77 = torch.ops.aten.unsqueeze.default(view_2238, 1) + convert_element_type_3234 = torch.ops.prims.convert_element_type.default(unsqueeze_77, torch.float32); unsqueeze_77 = None + bmm_74 = torch.ops.aten.bmm.default(permute_1606, convert_element_type_3234); permute_1606 = None + bmm_75 = torch.ops.aten.bmm.default(convert_element_type_3234, permute_1607); convert_element_type_3234 = permute_1607 = None + convert_element_type_3235 = torch.ops.prims.convert_element_type.default(bmm_74, torch.bfloat16); bmm_74 = None + view_2239 = torch.ops.aten.view.default(bmm_75, [8192, 6]); bmm_75 = None + view_2240 = torch.ops.aten.view.default(convert_element_type_3235, [49152, 2048]); convert_element_type_3235 = None + index_100 = torch.ops.aten.index.Tensor(view_2240, [getitem_35]); view_2240 = getitem_35 = None + permute_1608 = torch.ops.aten.permute.default(view_2238, [1, 0]) + mm_602 = torch.ops.aten.mm.default(permute_1608, mul_104); permute_1608 = mul_104 = None + convert_element_type_140 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_140, 64, '0'); convert_element_type_140 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_1610 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_603 = torch.ops.aten.mm.default(view_2238, permute_1610); view_2238 = permute_1610 = None + convert_element_type_3240 = torch.ops.prims.convert_element_type.default(mm_602, torch.float32); mm_602 = None + reduce_scatter_tensor_338 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3240, 'avg', 64, '0'); convert_element_type_3240 = None + wait_tensor_945 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_338); reduce_scatter_tensor_338 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mm_20, torch.float32); mm_20 = None + neg_4 = torch.ops.aten.neg.default(convert_element_type_135) + exp_6 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_136 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + div_10 = torch.ops.aten.div.Tensor(convert_element_type_135, add_136) + convert_element_type_136 = torch.ops.prims.convert_element_type.default(div_10, torch.bfloat16); div_10 = None + mul_2107 = torch.ops.aten.mul.Tensor(mm_603, convert_element_type_136); convert_element_type_136 = None + mul_2108 = torch.ops.aten.mul.Tensor(mm_603, mm_21); mm_603 = mm_21 = None + permute_1612 = torch.ops.aten.permute.default(mul_2107, [1, 0]) + mm_604 = torch.ops.aten.mm.default(permute_1612, view_125); permute_1612 = None + convert_element_type_137 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_137, 64, '0'); convert_element_type_137 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_39 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + permute_1614 = torch.ops.aten.permute.default(permute_39, [1, 0]); permute_39 = None + mm_605 = torch.ops.aten.mm.default(mul_2107, permute_1614); mul_2107 = permute_1614 = None + convert_element_type_3245 = torch.ops.prims.convert_element_type.default(mm_604, torch.float32); mm_604 = None + reduce_scatter_tensor_339 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3245, 'avg', 64, '0'); convert_element_type_3245 = None + wait_tensor_946 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_339); reduce_scatter_tensor_339 = None + convert_element_type_3246 = torch.ops.prims.convert_element_type.default(mul_2108, torch.float32); mul_2108 = None + reciprocal_48 = torch.ops.aten.reciprocal.default(add_136); add_136 = None + mul_2109 = torch.ops.aten.mul.Tensor(reciprocal_48, 1); reciprocal_48 = None + mul_2110 = torch.ops.aten.mul.Tensor(convert_element_type_3246, mul_2109); convert_element_type_3246 = None + sub_769 = torch.ops.aten.sub.Tensor(1, mul_2109); mul_2109 = None + mul_2111 = torch.ops.aten.mul.Tensor(convert_element_type_135, sub_769); convert_element_type_135 = sub_769 = None + add_2136 = torch.ops.aten.add.Tensor(mul_2111, 1); mul_2111 = None + mul_2112 = torch.ops.aten.mul.Tensor(mul_2110, add_2136); mul_2110 = add_2136 = None + convert_element_type_3248 = torch.ops.prims.convert_element_type.default(mul_2112, torch.bfloat16); mul_2112 = None + permute_1616 = torch.ops.aten.permute.default(convert_element_type_3248, [1, 0]) + mm_606 = torch.ops.aten.mm.default(permute_1616, view_125); permute_1616 = None + convert_element_type_132 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_132, 64, '0'); convert_element_type_132 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_38 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + permute_1618 = torch.ops.aten.permute.default(permute_38, [1, 0]); permute_38 = None + mm_607 = torch.ops.aten.mm.default(convert_element_type_3248, permute_1618); convert_element_type_3248 = permute_1618 = None + add_2137 = torch.ops.aten.add.Tensor(mm_605, mm_607); mm_605 = mm_607 = None + convert_element_type_3253 = torch.ops.prims.convert_element_type.default(mm_606, torch.float32); mm_606 = None + reduce_scatter_tensor_340 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3253, 'avg', 64, '0'); convert_element_type_3253 = None + wait_tensor_947 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_340); reduce_scatter_tensor_340 = None + all_to_all_single_126 = torch.ops._c10d_functional.all_to_all_single.default(index_100, [_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31], [_local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23], '521'); index_100 = None + wait_tensor_948 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_126); all_to_all_single_126 = None + full_444 = torch.ops.aten.full.default([sym_size_int_5, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_5 = None + slice_scatter_24 = torch.ops.aten.slice_scatter.default(full_444, wait_tensor_948, 0, 0, -1); wait_tensor_948 = None + index_101 = torch.ops.aten.index.Tensor(slice_scatter_24, [getitem_36]); slice_scatter_24 = None + permute_1620 = torch.ops.aten.permute.default(index_101, [1, 0]) + _grouped_mm_222 = torch.ops.aten._grouped_mm.default(permute_1620, mul_84, cumsum_5); permute_1620 = mul_84 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_126, 8, '513'); convert_element_type_126 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_37 = torch.ops.aten.permute.default(wait_tensor_45, [0, 2, 1]); wait_tensor_45 = None + permute_1622 = torch.ops.aten.permute.default(permute_37, [0, 2, 1]); permute_37 = None + _grouped_mm_223 = torch.ops.aten._grouped_mm.default(index_101, permute_1622, cumsum_5); index_101 = permute_1622 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(_grouped_mm_3, torch.float32); _grouped_mm_3 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_130) + exp_5 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_100 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + div_9 = torch.ops.aten.div.Tensor(convert_element_type_130, add_100) + convert_element_type_131 = torch.ops.prims.convert_element_type.default(div_9, torch.bfloat16); div_9 = None + mul_2113 = torch.ops.aten.mul.Tensor(_grouped_mm_223, convert_element_type_131); convert_element_type_131 = None + mul_2114 = torch.ops.aten.mul.Tensor(_grouped_mm_223, _grouped_mm_4); _grouped_mm_223 = _grouped_mm_4 = None + permute_1624 = torch.ops.aten.permute.default(mul_2113, [1, 0]) + _grouped_mm_224 = torch.ops.aten._grouped_mm.default(permute_1624, index_3, cumsum_5); permute_1624 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 8, '513'); convert_element_type_127 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_36 = torch.ops.aten.permute.default(wait_tensor_46, [0, 2, 1]); wait_tensor_46 = None + permute_1626 = torch.ops.aten.permute.default(permute_36, [0, 2, 1]); permute_36 = None + _grouped_mm_225 = torch.ops.aten._grouped_mm.default(mul_2113, permute_1626, cumsum_5); mul_2113 = permute_1626 = None + convert_element_type_3254 = torch.ops.prims.convert_element_type.default(mul_2114, torch.float32); mul_2114 = None + reciprocal_49 = torch.ops.aten.reciprocal.default(add_100); add_100 = None + mul_2115 = torch.ops.aten.mul.Tensor(reciprocal_49, 1); reciprocal_49 = None + mul_2116 = torch.ops.aten.mul.Tensor(convert_element_type_3254, mul_2115); convert_element_type_3254 = None + sub_770 = torch.ops.aten.sub.Tensor(1, mul_2115); mul_2115 = None + mul_2117 = torch.ops.aten.mul.Tensor(convert_element_type_130, sub_770); convert_element_type_130 = sub_770 = None + add_2139 = torch.ops.aten.add.Tensor(mul_2117, 1); mul_2117 = None + mul_2118 = torch.ops.aten.mul.Tensor(mul_2116, add_2139); mul_2116 = add_2139 = None + convert_element_type_3256 = torch.ops.prims.convert_element_type.default(mul_2118, torch.bfloat16); mul_2118 = None + permute_1628 = torch.ops.aten.permute.default(convert_element_type_3256, [1, 0]) + _grouped_mm_226 = torch.ops.aten._grouped_mm.default(permute_1628, index_3, cumsum_5); permute_1628 = index_3 = None + convert_element_type_124 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_124, 8, '513'); convert_element_type_124 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_43, [0, 2, 1]); wait_tensor_43 = None + permute_1630 = torch.ops.aten.permute.default(permute_35, [0, 2, 1]); permute_35 = None + _grouped_mm_227 = torch.ops.aten._grouped_mm.default(convert_element_type_3256, permute_1630, cumsum_5); convert_element_type_3256 = permute_1630 = cumsum_5 = None + add_2140 = torch.ops.aten.add.Tensor(_grouped_mm_225, _grouped_mm_227); _grouped_mm_225 = _grouped_mm_227 = None + convert_element_type_3257 = torch.ops.prims.convert_element_type.default(_grouped_mm_224, torch.float32); _grouped_mm_224 = None + div_276 = torch.ops.aten.div.Tensor(convert_element_type_3257, 64); convert_element_type_3257 = None + reduce_scatter_tensor_341 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_276, 'sum', 8, '513'); div_276 = None + wait_tensor_949 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_341); reduce_scatter_tensor_341 = None + convert_element_type_3258 = torch.ops.prims.convert_element_type.default(_grouped_mm_222, torch.float32); _grouped_mm_222 = None + div_277 = torch.ops.aten.div.Tensor(convert_element_type_3258, 64); convert_element_type_3258 = None + reduce_scatter_tensor_342 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_277, 'sum', 8, '513'); div_277 = None + wait_tensor_950 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_342); reduce_scatter_tensor_342 = None + convert_element_type_3259 = torch.ops.prims.convert_element_type.default(_grouped_mm_226, torch.float32); _grouped_mm_226 = None + div_278 = torch.ops.aten.div.Tensor(convert_element_type_3259, 64); convert_element_type_3259 = None + reduce_scatter_tensor_343 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_278, 'sum', 8, '513'); div_278 = None + wait_tensor_951 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_343); reduce_scatter_tensor_343 = None + index_put_100 = torch.ops.aten.index_put.default(full_444, [getitem_36], add_2140, True); full_444 = getitem_36 = add_2140 = None + slice_251 = torch.ops.aten.slice.Tensor(index_put_100, 0, 0, add_2141); index_put_100 = add_2141 = None + all_to_all_single_127 = torch.ops._c10d_functional.all_to_all_single.default(slice_251, [_local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23], [_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31], '521'); slice_251 = _local_scalar_dense_16 = _local_scalar_dense_17 = _local_scalar_dense_18 = _local_scalar_dense_19 = _local_scalar_dense_20 = _local_scalar_dense_21 = _local_scalar_dense_22 = _local_scalar_dense_23 = _local_scalar_dense_24 = _local_scalar_dense_25 = _local_scalar_dense_26 = _local_scalar_dense_27 = _local_scalar_dense_28 = _local_scalar_dense_29 = _local_scalar_dense_30 = _local_scalar_dense_31 = None + wait_tensor_952 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_127); all_to_all_single_127 = None + index_put_101 = torch.ops.aten.index_put.default(full_default_52, [div_7], wait_tensor_952, True); div_7 = wait_tensor_952 = None + add_2145 = torch.ops.aten.add.Tensor(add_2137, index_put_101); add_2137 = index_put_101 = None + mul_2119 = torch.ops.aten.mul.Tensor(view_2239, 1.0); view_2239 = None + scatter_add_24 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_33, mul_2119); getitem_33 = mul_2119 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(mm_19, torch.float32); mm_19 = None + sub_24 = torch.ops.aten.sub.Tensor(convert_element_type_119, amax_1); convert_element_type_119 = amax_1 = None + exp_4 = torch.ops.aten.exp.default(sub_24); sub_24 = None + div_6 = torch.ops.aten.div.Tensor(exp_4, sum_5); exp_4 = sum_5 = None + mul_2120 = torch.ops.aten.mul.Tensor(scatter_add_24, div_6); scatter_add_24 = None + sum_299 = torch.ops.aten.sum.dim_IntList(mul_2120, [1], True) + neg_127 = torch.ops.aten.neg.default(div_6); div_6 = None + fma_24 = torch.ops.prims.fma.default(neg_127, sum_299, mul_2120); neg_127 = sum_299 = mul_2120 = None + convert_element_type_3260 = torch.ops.prims.convert_element_type.default(fma_24, torch.bfloat16); fma_24 = None + permute_1632 = torch.ops.aten.permute.default(convert_element_type_3260, [1, 0]) + mm_608 = torch.ops.aten.mm.default(permute_1632, view_125); permute_1632 = view_125 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 64, '0'); convert_element_type_116 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + permute_1634 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_609 = torch.ops.aten.mm.default(convert_element_type_3260, permute_1634); convert_element_type_3260 = permute_1634 = None + add_2146 = torch.ops.aten.add.Tensor(add_2145, mm_609); add_2145 = mm_609 = None + convert_element_type_3265 = torch.ops.prims.convert_element_type.default(mm_608, torch.float32); mm_608 = None + reduce_scatter_tensor_344 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3265, 'avg', 64, '0'); convert_element_type_3265 = None + wait_tensor_953 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_344); reduce_scatter_tensor_344 = None + view_2241 = torch.ops.aten.view.default(add_2146, [2, 4096, 2048]); add_2146 = None + convert_element_type_3266 = torch.ops.prims.convert_element_type.default(view_2241, torch.float32); view_2241 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_113, 64, '0'); convert_element_type_113 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_3268 = torch.ops.prims.convert_element_type.default(wait_tensor_38, torch.float32); wait_tensor_38 = None + mul_2121 = torch.ops.aten.mul.Tensor(convert_element_type_3266, convert_element_type_3268); convert_element_type_3268 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(add_76, torch.float32); add_76 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_114, rsqrt_8); convert_element_type_114 = None + mul_2123 = torch.ops.aten.mul.Tensor(mul_64, mul_2121) + sum_300 = torch.ops.aten.sum.dim_IntList(mul_2123, [2], True); mul_2123 = None + div_279 = torch.ops.aten.div.Tensor(mul_64, 2048) + mul_2124 = torch.ops.aten.mul.Tensor(div_279, sum_300); div_279 = sum_300 = None + sub_772 = torch.ops.aten.sub.Tensor(mul_2121, mul_2124); mul_2121 = mul_2124 = None + mul_2125 = torch.ops.aten.mul.Tensor(sub_772, rsqrt_8); sub_772 = rsqrt_8 = None + mul_2126 = torch.ops.aten.mul.Tensor(convert_element_type_3266, mul_64); convert_element_type_3266 = mul_64 = None + sum_301 = torch.ops.aten.sum.dim_IntList(mul_2126, [0, 1]); mul_2126 = None + convert_element_type_3269 = torch.ops.prims.convert_element_type.default(mul_2125, torch.bfloat16); mul_2125 = None + add_2147 = torch.ops.aten.add.Tensor(add_2134, convert_element_type_3269); add_2134 = convert_element_type_3269 = None + convert_element_type_default_9 = torch.ops.prims.convert_element_type.default(sum_301, torch.float32); sum_301 = None + reduce_scatter_tensor_345 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_9, 'avg', 64, '0'); convert_element_type_default_9 = None + wait_tensor_954 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_345); reduce_scatter_tensor_345 = None + view_2242 = torch.ops.aten.view.default(add_2147, [8192, 2048]) + permute_1636 = torch.ops.aten.permute.default(view_2242, [1, 0]) + permute_32 = torch.ops.aten.permute.default(getitem_29, [0, 2, 1, 3]) + view_120 = torch.ops.aten.view.default(permute_32, [2, 4096, -1]); permute_32 = None + view_122 = torch.ops.aten.view.default(view_120, [8192, 2048]); view_120 = None + mm_610 = torch.ops.aten.mm.default(permute_1636, view_122); permute_1636 = view_122 = None + convert_element_type_110 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_110, 64, '0'); convert_element_type_110 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + permute_1638 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_611 = torch.ops.aten.mm.default(view_2242, permute_1638); view_2242 = permute_1638 = None + view_2243 = torch.ops.aten.view.default(mm_611, [2, 4096, 2048]); mm_611 = None + convert_element_type_3276 = torch.ops.prims.convert_element_type.default(mm_610, torch.float32); mm_610 = None + reduce_scatter_tensor_346 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3276, 'avg', 64, '0'); convert_element_type_3276 = None + wait_tensor_955 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_346); reduce_scatter_tensor_346 = None + view_2244 = torch.ops.aten.view.default(view_2243, [2, 4096, 16, 128]); view_2243 = None + permute_1640 = torch.ops.aten.permute.default(view_2244, [0, 2, 1, 3]); view_2244 = None + fw_graph24 = self.fw_graph24 + joint_graph24 = self.joint_graph24 + mask_graph24 = self.mask_graph24 + flex_attention_backward_24 = torch.ops.higher_order.flex_attention_backward(permute_29, permute_30, permute_31, getitem_29, getitem_30, permute_1640, None, fw_graph24, joint_graph24, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph24), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_29 = permute_30 = permute_31 = getitem_29 = getitem_30 = permute_1640 = fw_graph24 = joint_graph24 = mask_graph24 = None + getitem_469 = flex_attention_backward_24[0] + getitem_470 = flex_attention_backward_24[1] + getitem_471 = flex_attention_backward_24[2]; flex_attention_backward_24 = None + permute_1641 = torch.ops.aten.permute.default(getitem_471, [0, 2, 1, 3]); getitem_471 = None + permute_1642 = torch.ops.aten.permute.default(getitem_470, [0, 2, 1, 3]); getitem_470 = None + permute_1643 = torch.ops.aten.permute.default(getitem_469, [0, 2, 1, 3]); getitem_469 = None + slice_253 = torch.ops.aten.slice.Tensor(permute_1642, 3, 0, 128) + slice_254 = torch.ops.aten.slice.Tensor(permute_1642, 3, 128, 192); permute_1642 = None + sum_302 = torch.ops.aten.sum.dim_IntList(slice_254, [2], True); slice_254 = None + cat_152 = torch.ops.aten.cat.default([slice_253, permute_1641], 3); slice_253 = permute_1641 = None + view_2245 = torch.ops.aten.view.default(cat_152, [2, 4096, 4096]); cat_152 = None + view_2246 = torch.ops.aten.view.default(view_2245, [8192, 4096]); view_2245 = None + permute_1644 = torch.ops.aten.permute.default(view_2246, [1, 0]) + mm_612 = torch.ops.aten.mm.default(permute_1644, view_117); permute_1644 = view_117 = None + convert_element_type_107 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_107, 64, '0'); convert_element_type_107 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_28 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + permute_1646 = torch.ops.aten.permute.default(permute_28, [1, 0]); permute_28 = None + mm_613 = torch.ops.aten.mm.default(view_2246, permute_1646); view_2246 = permute_1646 = None + view_2247 = torch.ops.aten.view.default(mm_613, [2, 4096, 512]); mm_613 = None + convert_element_type_3281 = torch.ops.prims.convert_element_type.default(mm_612, torch.float32); mm_612 = None + reduce_scatter_tensor_347 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3281, 'avg', 64, '0'); convert_element_type_3281 = None + wait_tensor_956 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_347); reduce_scatter_tensor_347 = None + convert_element_type_3282 = torch.ops.prims.convert_element_type.default(view_2247, torch.float32); view_2247 = None + convert_element_type_104 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_104, 64, '0'); convert_element_type_104 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + convert_element_type_3284 = torch.ops.prims.convert_element_type.default(wait_tensor_35, torch.float32); wait_tensor_35 = None + mul_2127 = torch.ops.aten.mul.Tensor(convert_element_type_3282, convert_element_type_3284); convert_element_type_3284 = None + convert_element_type_105 = torch.ops.prims.convert_element_type.default(getitem_25, torch.float32); getitem_25 = None + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_105, rsqrt_7); convert_element_type_105 = None + mul_2129 = torch.ops.aten.mul.Tensor(mul_62, mul_2127) + sum_303 = torch.ops.aten.sum.dim_IntList(mul_2129, [2], True); mul_2129 = None + div_280 = torch.ops.aten.div.Tensor(mul_62, 512) + mul_2130 = torch.ops.aten.mul.Tensor(div_280, sum_303); div_280 = sum_303 = None + sub_773 = torch.ops.aten.sub.Tensor(mul_2127, mul_2130); mul_2127 = mul_2130 = None + mul_2131 = torch.ops.aten.mul.Tensor(sub_773, rsqrt_7); sub_773 = rsqrt_7 = None + mul_2132 = torch.ops.aten.mul.Tensor(convert_element_type_3282, mul_62); convert_element_type_3282 = mul_62 = None + sum_304 = torch.ops.aten.sum.dim_IntList(mul_2132, [0, 1]); mul_2132 = None + convert_element_type_3285 = torch.ops.prims.convert_element_type.default(mul_2131, torch.bfloat16); mul_2131 = None + convert_element_type_default_8 = torch.ops.prims.convert_element_type.default(sum_304, torch.float32); sum_304 = None + reduce_scatter_tensor_348 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_8, 'avg', 64, '0'); convert_element_type_default_8 = None + wait_tensor_957 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_348); reduce_scatter_tensor_348 = None + convert_element_type_3288 = torch.ops.prims.convert_element_type.default(sum_302, torch.float32); sum_302 = None + view_2248 = torch.ops.aten.view.default(convert_element_type_3288, [2, 4096, 1, 32, 2]); convert_element_type_3288 = None + view_as_complex_102 = torch.ops.aten.view_as_complex.default(view_2248); view_2248 = None + mul_2133 = torch.ops.aten.mul.Tensor(view_as_complex_102, clone_9); view_as_complex_102 = None + view_as_real_102 = torch.ops.aten.view_as_real.default(mul_2133); mul_2133 = None + view_2249 = torch.ops.aten.view.default(view_as_real_102, [2, 4096, 1, 64]); view_as_real_102 = None + convert_element_type_3289 = torch.ops.prims.convert_element_type.default(view_2249, torch.bfloat16); view_2249 = None + squeeze_50 = torch.ops.aten.squeeze.dim(convert_element_type_3289, 2); convert_element_type_3289 = None + cat_153 = torch.ops.aten.cat.default([convert_element_type_3285, squeeze_50], 2); convert_element_type_3285 = squeeze_50 = None + view_2250 = torch.ops.aten.view.default(cat_153, [8192, 576]); cat_153 = None + permute_1648 = torch.ops.aten.permute.default(view_2250, [1, 0]) + mm_614 = torch.ops.aten.mm.default(permute_1648, view_103); permute_1648 = None + convert_element_type_99 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_99, 64, '0'); convert_element_type_99 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_27 = torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + permute_1650 = torch.ops.aten.permute.default(permute_27, [1, 0]); permute_27 = None + mm_615 = torch.ops.aten.mm.default(view_2250, permute_1650); view_2250 = permute_1650 = None + view_2251 = torch.ops.aten.view.default(mm_615, [2, 4096, 2048]); mm_615 = None + convert_element_type_3294 = torch.ops.prims.convert_element_type.default(mm_614, torch.float32); mm_614 = None + reduce_scatter_tensor_349 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3294, 'avg', 64, '0'); convert_element_type_3294 = None + wait_tensor_958 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_349); reduce_scatter_tensor_349 = None + slice_255 = torch.ops.aten.slice.Tensor(permute_1643, 3, 0, 128) + slice_256 = torch.ops.aten.slice.Tensor(permute_1643, 3, 128, 192); permute_1643 = None + convert_element_type_3295 = torch.ops.prims.convert_element_type.default(slice_256, torch.float32); slice_256 = None + view_2252 = torch.ops.aten.view.default(convert_element_type_3295, [2, 4096, 16, 32, 2]); convert_element_type_3295 = None + view_as_complex_103 = torch.ops.aten.view_as_complex.default(view_2252); view_2252 = None + mul_2134 = torch.ops.aten.mul.Tensor(view_as_complex_103, clone_9); view_as_complex_103 = None + view_as_real_103 = torch.ops.aten.view_as_real.default(mul_2134); mul_2134 = None + view_2253 = torch.ops.aten.view.default(view_as_real_103, [2, 4096, 16, 64]); view_as_real_103 = None + convert_element_type_3296 = torch.ops.prims.convert_element_type.default(view_2253, torch.bfloat16); view_2253 = None + cat_154 = torch.ops.aten.cat.default([slice_255, convert_element_type_3296], 3); slice_255 = convert_element_type_3296 = None + view_2254 = torch.ops.aten.view.default(cat_154, [2, 4096, 3072]); cat_154 = None + view_2255 = torch.ops.aten.view.default(view_2254, [8192, 3072]); view_2254 = None + permute_1652 = torch.ops.aten.permute.default(view_2255, [1, 0]) + mm_616 = torch.ops.aten.mm.default(permute_1652, view_103); permute_1652 = view_103 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 64, '0'); convert_element_type_94 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_26 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + permute_1654 = torch.ops.aten.permute.default(permute_26, [1, 0]); permute_26 = None + mm_617 = torch.ops.aten.mm.default(view_2255, permute_1654); view_2255 = permute_1654 = None + view_2256 = torch.ops.aten.view.default(mm_617, [2, 4096, 2048]); mm_617 = None + add_2148 = torch.ops.aten.add.Tensor(view_2251, view_2256); view_2251 = view_2256 = None + convert_element_type_3301 = torch.ops.prims.convert_element_type.default(mm_616, torch.float32); mm_616 = None + reduce_scatter_tensor_350 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3301, 'avg', 64, '0'); convert_element_type_3301 = None + wait_tensor_959 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_350); reduce_scatter_tensor_350 = None + convert_element_type_3302 = torch.ops.prims.convert_element_type.default(add_2148, torch.float32); add_2148 = None + convert_element_type_91 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_91, 64, '0'); convert_element_type_91 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_3304 = torch.ops.prims.convert_element_type.default(wait_tensor_32, torch.float32); wait_tensor_32 = None + mul_2135 = torch.ops.aten.mul.Tensor(convert_element_type_3302, convert_element_type_3304); convert_element_type_3304 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(add_73, torch.float32); add_73 = None + mul_58 = torch.ops.aten.mul.Tensor(convert_element_type_92, rsqrt_6); convert_element_type_92 = None + mul_2137 = torch.ops.aten.mul.Tensor(mul_58, mul_2135) + sum_305 = torch.ops.aten.sum.dim_IntList(mul_2137, [2], True); mul_2137 = None + div_281 = torch.ops.aten.div.Tensor(mul_58, 2048) + mul_2138 = torch.ops.aten.mul.Tensor(div_281, sum_305); div_281 = sum_305 = None + sub_774 = torch.ops.aten.sub.Tensor(mul_2135, mul_2138); mul_2135 = mul_2138 = None + mul_2139 = torch.ops.aten.mul.Tensor(sub_774, rsqrt_6); sub_774 = rsqrt_6 = None + mul_2140 = torch.ops.aten.mul.Tensor(convert_element_type_3302, mul_58); convert_element_type_3302 = mul_58 = None + sum_306 = torch.ops.aten.sum.dim_IntList(mul_2140, [0, 1]); mul_2140 = None + convert_element_type_3305 = torch.ops.prims.convert_element_type.default(mul_2139, torch.bfloat16); mul_2139 = None + add_2149 = torch.ops.aten.add.Tensor(add_2147, convert_element_type_3305); add_2147 = convert_element_type_3305 = None + convert_element_type_default_7 = torch.ops.prims.convert_element_type.default(sum_306, torch.float32); sum_306 = None + reduce_scatter_tensor_351 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_7, 'avg', 64, '0'); convert_element_type_default_7 = None + wait_tensor_960 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_351); reduce_scatter_tensor_351 = None + view_2257 = torch.ops.aten.view.default(add_2149, [8192, 2048]) + unsqueeze_78 = torch.ops.aten.unsqueeze.default(view_2257, 1) + convert_element_type_3308 = torch.ops.prims.convert_element_type.default(unsqueeze_78, torch.float32); unsqueeze_78 = None + bmm_76 = torch.ops.aten.bmm.default(permute_1656, convert_element_type_3308); permute_1656 = None + bmm_77 = torch.ops.aten.bmm.default(convert_element_type_3308, permute_1657); convert_element_type_3308 = permute_1657 = None + convert_element_type_3309 = torch.ops.prims.convert_element_type.default(bmm_76, torch.bfloat16); bmm_76 = None + view_2258 = torch.ops.aten.view.default(bmm_77, [8192, 6]); bmm_77 = None + view_2259 = torch.ops.aten.view.default(convert_element_type_3309, [49152, 2048]); convert_element_type_3309 = None + index_102 = torch.ops.aten.index.Tensor(view_2259, [getitem_21]); view_2259 = getitem_21 = None + permute_1658 = torch.ops.aten.permute.default(view_2257, [1, 0]) + mm_618 = torch.ops.aten.mm.default(permute_1658, mul_55); permute_1658 = mul_55 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 64, '0'); convert_element_type_86 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_25 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + permute_1660 = torch.ops.aten.permute.default(permute_25, [1, 0]); permute_25 = None + mm_619 = torch.ops.aten.mm.default(view_2257, permute_1660); view_2257 = permute_1660 = None + convert_element_type_3314 = torch.ops.prims.convert_element_type.default(mm_618, torch.float32); mm_618 = None + reduce_scatter_tensor_352 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3314, 'avg', 64, '0'); convert_element_type_3314 = None + wait_tensor_961 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_352); reduce_scatter_tensor_352 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(mm_12, torch.float32); mm_12 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_81) + exp_3 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_68 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + div_5 = torch.ops.aten.div.Tensor(convert_element_type_81, add_68) + convert_element_type_82 = torch.ops.prims.convert_element_type.default(div_5, torch.bfloat16); div_5 = None + mul_2141 = torch.ops.aten.mul.Tensor(mm_619, convert_element_type_82); convert_element_type_82 = None + mul_2142 = torch.ops.aten.mul.Tensor(mm_619, mm_13); mm_619 = mm_13 = None + permute_1662 = torch.ops.aten.permute.default(mul_2141, [1, 0]) + mm_620 = torch.ops.aten.mm.default(permute_1662, view_58); permute_1662 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 64, '0'); convert_element_type_83 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + permute_1664 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_621 = torch.ops.aten.mm.default(mul_2141, permute_1664); mul_2141 = permute_1664 = None + convert_element_type_3319 = torch.ops.prims.convert_element_type.default(mm_620, torch.float32); mm_620 = None + reduce_scatter_tensor_353 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3319, 'avg', 64, '0'); convert_element_type_3319 = None + wait_tensor_962 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_353); reduce_scatter_tensor_353 = None + convert_element_type_3320 = torch.ops.prims.convert_element_type.default(mul_2142, torch.float32); mul_2142 = None + reciprocal_50 = torch.ops.aten.reciprocal.default(add_68); add_68 = None + mul_2143 = torch.ops.aten.mul.Tensor(reciprocal_50, 1); reciprocal_50 = None + mul_2144 = torch.ops.aten.mul.Tensor(convert_element_type_3320, mul_2143); convert_element_type_3320 = None + sub_775 = torch.ops.aten.sub.Tensor(1, mul_2143); mul_2143 = None + mul_2145 = torch.ops.aten.mul.Tensor(convert_element_type_81, sub_775); convert_element_type_81 = sub_775 = None + add_2151 = torch.ops.aten.add.Tensor(mul_2145, 1); mul_2145 = None + mul_2146 = torch.ops.aten.mul.Tensor(mul_2144, add_2151); mul_2144 = add_2151 = None + convert_element_type_3322 = torch.ops.prims.convert_element_type.default(mul_2146, torch.bfloat16); mul_2146 = None + permute_1666 = torch.ops.aten.permute.default(convert_element_type_3322, [1, 0]) + mm_622 = torch.ops.aten.mm.default(permute_1666, view_58); permute_1666 = None + convert_element_type_78 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_78, 64, '0'); convert_element_type_78 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + permute_1668 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_623 = torch.ops.aten.mm.default(convert_element_type_3322, permute_1668); convert_element_type_3322 = permute_1668 = None + add_2152 = torch.ops.aten.add.Tensor(mm_621, mm_623); mm_621 = mm_623 = None + convert_element_type_3327 = torch.ops.prims.convert_element_type.default(mm_622, torch.float32); mm_622 = None + reduce_scatter_tensor_354 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3327, 'avg', 64, '0'); convert_element_type_3327 = None + wait_tensor_963 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_354); reduce_scatter_tensor_354 = None + all_to_all_single_128 = torch.ops._c10d_functional.all_to_all_single.default(index_102, [_local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15], [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7], '521'); index_102 = None + wait_tensor_964 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_128); all_to_all_single_128 = None + full_448 = torch.ops.aten.full.default([sym_size_int_1, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False); sym_size_int_1 = None + slice_scatter_25 = torch.ops.aten.slice_scatter.default(full_448, wait_tensor_964, 0, 0, -1); wait_tensor_964 = None + index_103 = torch.ops.aten.index.Tensor(slice_scatter_25, [getitem_22]); slice_scatter_25 = None + permute_1670 = torch.ops.aten.permute.default(index_103, [1, 0]) + _grouped_mm_228 = torch.ops.aten._grouped_mm.default(permute_1670, mul_35, cumsum_2); permute_1670 = mul_35 = None + convert_element_type_72 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_72, 8, '513'); convert_element_type_72 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_24, [0, 2, 1]); wait_tensor_24 = None + permute_1672 = torch.ops.aten.permute.default(permute_22, [0, 2, 1]); permute_22 = None + _grouped_mm_229 = torch.ops.aten._grouped_mm.default(index_103, permute_1672, cumsum_2); index_103 = permute_1672 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(_grouped_mm, torch.float32); _grouped_mm = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_76) + exp_2 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_32 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + div_4 = torch.ops.aten.div.Tensor(convert_element_type_76, add_32) + convert_element_type_77 = torch.ops.prims.convert_element_type.default(div_4, torch.bfloat16); div_4 = None + mul_2147 = torch.ops.aten.mul.Tensor(_grouped_mm_229, convert_element_type_77); convert_element_type_77 = None + mul_2148 = torch.ops.aten.mul.Tensor(_grouped_mm_229, _grouped_mm_1); _grouped_mm_229 = _grouped_mm_1 = None + permute_1674 = torch.ops.aten.permute.default(mul_2147, [1, 0]) + _grouped_mm_230 = torch.ops.aten._grouped_mm.default(permute_1674, index_1, cumsum_2); permute_1674 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 8, '513'); convert_element_type_73 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_25, [0, 2, 1]); wait_tensor_25 = None + permute_1676 = torch.ops.aten.permute.default(permute_21, [0, 2, 1]); permute_21 = None + _grouped_mm_231 = torch.ops.aten._grouped_mm.default(mul_2147, permute_1676, cumsum_2); mul_2147 = permute_1676 = None + convert_element_type_3328 = torch.ops.prims.convert_element_type.default(mul_2148, torch.float32); mul_2148 = None + reciprocal_51 = torch.ops.aten.reciprocal.default(add_32); add_32 = None + mul_2149 = torch.ops.aten.mul.Tensor(reciprocal_51, 1); reciprocal_51 = None + mul_2150 = torch.ops.aten.mul.Tensor(convert_element_type_3328, mul_2149); convert_element_type_3328 = None + sub_776 = torch.ops.aten.sub.Tensor(1, mul_2149); mul_2149 = None + mul_2151 = torch.ops.aten.mul.Tensor(convert_element_type_76, sub_776); convert_element_type_76 = sub_776 = None + add_2154 = torch.ops.aten.add.Tensor(mul_2151, 1); mul_2151 = None + mul_2152 = torch.ops.aten.mul.Tensor(mul_2150, add_2154); mul_2150 = add_2154 = None + convert_element_type_3330 = torch.ops.prims.convert_element_type.default(mul_2152, torch.bfloat16); mul_2152 = None + permute_1678 = torch.ops.aten.permute.default(convert_element_type_3330, [1, 0]) + _grouped_mm_232 = torch.ops.aten._grouped_mm.default(permute_1678, index_1, cumsum_2); permute_1678 = index_1 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 8, '513'); convert_element_type_70 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_22, [0, 2, 1]); wait_tensor_22 = None + permute_1680 = torch.ops.aten.permute.default(permute_20, [0, 2, 1]); permute_20 = None + _grouped_mm_233 = torch.ops.aten._grouped_mm.default(convert_element_type_3330, permute_1680, cumsum_2); convert_element_type_3330 = permute_1680 = cumsum_2 = None + add_2155 = torch.ops.aten.add.Tensor(_grouped_mm_231, _grouped_mm_233); _grouped_mm_231 = _grouped_mm_233 = None + convert_element_type_3331 = torch.ops.prims.convert_element_type.default(_grouped_mm_230, torch.float32); _grouped_mm_230 = None + div_282 = torch.ops.aten.div.Tensor(convert_element_type_3331, 64); convert_element_type_3331 = None + reduce_scatter_tensor_355 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_282, 'sum', 8, '513'); div_282 = None + wait_tensor_965 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_355); reduce_scatter_tensor_355 = None + convert_element_type_3332 = torch.ops.prims.convert_element_type.default(_grouped_mm_228, torch.float32); _grouped_mm_228 = None + div_283 = torch.ops.aten.div.Tensor(convert_element_type_3332, 64); convert_element_type_3332 = None + reduce_scatter_tensor_356 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_283, 'sum', 8, '513'); div_283 = None + wait_tensor_966 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_356); reduce_scatter_tensor_356 = None + convert_element_type_3333 = torch.ops.prims.convert_element_type.default(_grouped_mm_232, torch.float32); _grouped_mm_232 = None + div_284 = torch.ops.aten.div.Tensor(convert_element_type_3333, 64); convert_element_type_3333 = None + reduce_scatter_tensor_357 = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_284, 'sum', 8, '513'); div_284 = None + wait_tensor_967 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_357); reduce_scatter_tensor_357 = None + index_put_102 = torch.ops.aten.index_put.default(full_448, [getitem_22], add_2155, True); full_448 = getitem_22 = add_2155 = None + slice_257 = torch.ops.aten.slice.Tensor(index_put_102, 0, 0, add_2156); index_put_102 = add_2156 = None + all_to_all_single_129 = torch.ops._c10d_functional.all_to_all_single.default(slice_257, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7], [_local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15], '521'); slice_257 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = _local_scalar_dense_3 = _local_scalar_dense_4 = _local_scalar_dense_5 = _local_scalar_dense_6 = _local_scalar_dense_7 = _local_scalar_dense_8 = _local_scalar_dense_9 = _local_scalar_dense_10 = _local_scalar_dense_11 = _local_scalar_dense_12 = _local_scalar_dense_13 = _local_scalar_dense_14 = _local_scalar_dense_15 = None + wait_tensor_968 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_129); all_to_all_single_129 = None + index_put_103 = torch.ops.aten.index_put.default(full_default_52, [div_2], wait_tensor_968, True); full_default_52 = div_2 = wait_tensor_968 = None + add_2160 = torch.ops.aten.add.Tensor(add_2152, index_put_103); add_2152 = index_put_103 = None + mul_2153 = torch.ops.aten.mul.Tensor(view_2258, 1.0); view_2258 = None + scatter_add_25 = torch.ops.aten.scatter_add.default(full_default_53, 1, getitem_19, mul_2153); full_default_53 = getitem_19 = mul_2153 = None + convert_element_type_65 = torch.ops.prims.convert_element_type.default(mm_11, torch.float32); mm_11 = None + sub = torch.ops.aten.sub.Tensor(convert_element_type_65, amax); convert_element_type_65 = amax = None + exp_1 = torch.ops.aten.exp.default(sub); sub = None + div_1 = torch.ops.aten.div.Tensor(exp_1, sum_1); exp_1 = sum_1 = None + mul_2154 = torch.ops.aten.mul.Tensor(scatter_add_25, div_1); scatter_add_25 = None + sum_307 = torch.ops.aten.sum.dim_IntList(mul_2154, [1], True) + neg_130 = torch.ops.aten.neg.default(div_1); div_1 = None + fma_25 = torch.ops.prims.fma.default(neg_130, sum_307, mul_2154); neg_130 = sum_307 = mul_2154 = None + convert_element_type_3334 = torch.ops.prims.convert_element_type.default(fma_25, torch.bfloat16); fma_25 = None + permute_1682 = torch.ops.aten.permute.default(convert_element_type_3334, [1, 0]) + mm_624 = torch.ops.aten.mm.default(permute_1682, view_58); permute_1682 = view_58 = None + convert_element_type_62 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_62, 64, '0'); convert_element_type_62 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + permute_1684 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_625 = torch.ops.aten.mm.default(convert_element_type_3334, permute_1684); convert_element_type_3334 = permute_1684 = None + add_2161 = torch.ops.aten.add.Tensor(add_2160, mm_625); add_2160 = mm_625 = None + convert_element_type_3339 = torch.ops.prims.convert_element_type.default(mm_624, torch.float32); mm_624 = None + reduce_scatter_tensor_358 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3339, 'avg', 64, '0'); convert_element_type_3339 = None + wait_tensor_969 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_358); reduce_scatter_tensor_358 = None + view_2260 = torch.ops.aten.view.default(add_2161, [2, 4096, 2048]); add_2161 = None + convert_element_type_3340 = torch.ops.prims.convert_element_type.default(view_2260, torch.float32); view_2260 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_59, 64, '0'); convert_element_type_59 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + convert_element_type_3342 = torch.ops.prims.convert_element_type.default(wait_tensor_17, torch.float32); wait_tensor_17 = None + mul_2155 = torch.ops.aten.mul.Tensor(convert_element_type_3340, convert_element_type_3342); convert_element_type_3342 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(add_8, torch.float32); add_8 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, rsqrt_5); convert_element_type_60 = None + mul_2157 = torch.ops.aten.mul.Tensor(mul_15, mul_2155) + sum_308 = torch.ops.aten.sum.dim_IntList(mul_2157, [2], True); mul_2157 = None + div_285 = torch.ops.aten.div.Tensor(mul_15, 2048) + mul_2158 = torch.ops.aten.mul.Tensor(div_285, sum_308); div_285 = sum_308 = None + sub_778 = torch.ops.aten.sub.Tensor(mul_2155, mul_2158); mul_2155 = mul_2158 = None + mul_2159 = torch.ops.aten.mul.Tensor(sub_778, rsqrt_5); sub_778 = rsqrt_5 = None + mul_2160 = torch.ops.aten.mul.Tensor(convert_element_type_3340, mul_15); convert_element_type_3340 = mul_15 = None + sum_309 = torch.ops.aten.sum.dim_IntList(mul_2160, [0, 1]); mul_2160 = None + convert_element_type_3343 = torch.ops.prims.convert_element_type.default(mul_2159, torch.bfloat16); mul_2159 = None + add_2162 = torch.ops.aten.add.Tensor(add_2149, convert_element_type_3343); add_2149 = convert_element_type_3343 = None + convert_element_type_default_6 = torch.ops.prims.convert_element_type.default(sum_309, torch.float32); sum_309 = None + reduce_scatter_tensor_359 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_6, 'avg', 64, '0'); convert_element_type_default_6 = None + wait_tensor_970 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_359); reduce_scatter_tensor_359 = None + view_2261 = torch.ops.aten.view.default(add_2162, [8192, 2048]) + permute_1686 = torch.ops.aten.permute.default(view_2261, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_15, [0, 2, 1, 3]) + view_53 = torch.ops.aten.view.default(permute_17, [2, 4096, -1]); permute_17 = None + view_55 = torch.ops.aten.view.default(view_53, [8192, 2048]); view_53 = None + mm_626 = torch.ops.aten.mm.default(permute_1686, view_55); permute_1686 = view_55 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 64, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + permute_1688 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_627 = torch.ops.aten.mm.default(view_2261, permute_1688); view_2261 = permute_1688 = None + view_2262 = torch.ops.aten.view.default(mm_627, [2, 4096, 2048]); mm_627 = None + convert_element_type_3350 = torch.ops.prims.convert_element_type.default(mm_626, torch.float32); mm_626 = None + reduce_scatter_tensor_360 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3350, 'avg', 64, '0'); convert_element_type_3350 = None + wait_tensor_971 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_360); reduce_scatter_tensor_360 = None + view_2263 = torch.ops.aten.view.default(view_2262, [2, 4096, 16, 128]); view_2262 = None + permute_1690 = torch.ops.aten.permute.default(view_2263, [0, 2, 1, 3]); view_2263 = None + fw_graph25 = self.fw_graph25 + joint_graph25 = self.joint_graph25 + mask_graph25 = self.mask_graph25 + flex_attention_backward_25 = torch.ops.higher_order.flex_attention_backward(permute_14, permute_15, permute_16, getitem_15, getitem_16, permute_1690, None, fw_graph25, joint_graph25, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph25), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_14 = permute_15 = permute_16 = getitem_15 = getitem_16 = permute_1690 = fw_graph25 = joint_graph25 = mask_graph25 = None + getitem_473 = flex_attention_backward_25[0] + getitem_474 = flex_attention_backward_25[1] + getitem_475 = flex_attention_backward_25[2]; flex_attention_backward_25 = None + permute_1691 = torch.ops.aten.permute.default(getitem_475, [0, 2, 1, 3]); getitem_475 = None + permute_1692 = torch.ops.aten.permute.default(getitem_474, [0, 2, 1, 3]); getitem_474 = None + permute_1693 = torch.ops.aten.permute.default(getitem_473, [0, 2, 1, 3]); getitem_473 = None + slice_259 = torch.ops.aten.slice.Tensor(permute_1692, 3, 0, 128) + slice_260 = torch.ops.aten.slice.Tensor(permute_1692, 3, 128, 192); permute_1692 = None + sum_310 = torch.ops.aten.sum.dim_IntList(slice_260, [2], True); slice_260 = None + cat_155 = torch.ops.aten.cat.default([slice_259, permute_1691], 3); slice_259 = permute_1691 = None + view_2264 = torch.ops.aten.view.default(cat_155, [2, 4096, 4096]); cat_155 = None + view_2265 = torch.ops.aten.view.default(view_2264, [8192, 4096]); view_2264 = None + permute_1694 = torch.ops.aten.permute.default(view_2265, [1, 0]) + mm_628 = torch.ops.aten.mm.default(permute_1694, view_50); permute_1694 = view_50 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 64, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_15, [1, 0]); wait_tensor_15 = None + permute_1696 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_629 = torch.ops.aten.mm.default(view_2265, permute_1696); view_2265 = permute_1696 = None + view_2266 = torch.ops.aten.view.default(mm_629, [2, 4096, 512]); mm_629 = None + convert_element_type_3355 = torch.ops.prims.convert_element_type.default(mm_628, torch.float32); mm_628 = None + reduce_scatter_tensor_361 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3355, 'avg', 64, '0'); convert_element_type_3355 = None + wait_tensor_972 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_361); reduce_scatter_tensor_361 = None + convert_element_type_3356 = torch.ops.prims.convert_element_type.default(view_2266, torch.float32); view_2266 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 64, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + convert_element_type_3358 = torch.ops.prims.convert_element_type.default(wait_tensor_14, torch.float32); wait_tensor_14 = None + mul_2161 = torch.ops.aten.mul.Tensor(convert_element_type_3356, convert_element_type_3358); convert_element_type_3358 = None + convert_element_type_51 = torch.ops.prims.convert_element_type.default(getitem_11, torch.float32); getitem_11 = None + mul_13 = torch.ops.aten.mul.Tensor(convert_element_type_51, rsqrt_4); convert_element_type_51 = None + mul_2163 = torch.ops.aten.mul.Tensor(mul_13, mul_2161) + sum_311 = torch.ops.aten.sum.dim_IntList(mul_2163, [2], True); mul_2163 = None + div_286 = torch.ops.aten.div.Tensor(mul_13, 512) + mul_2164 = torch.ops.aten.mul.Tensor(div_286, sum_311); div_286 = sum_311 = None + sub_779 = torch.ops.aten.sub.Tensor(mul_2161, mul_2164); mul_2161 = mul_2164 = None + mul_2165 = torch.ops.aten.mul.Tensor(sub_779, rsqrt_4); sub_779 = rsqrt_4 = None + mul_2166 = torch.ops.aten.mul.Tensor(convert_element_type_3356, mul_13); convert_element_type_3356 = mul_13 = None + sum_312 = torch.ops.aten.sum.dim_IntList(mul_2166, [0, 1]); mul_2166 = None + convert_element_type_3359 = torch.ops.prims.convert_element_type.default(mul_2165, torch.bfloat16); mul_2165 = None + convert_element_type_default_5 = torch.ops.prims.convert_element_type.default(sum_312, torch.float32); sum_312 = None + reduce_scatter_tensor_362 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_5, 'avg', 64, '0'); convert_element_type_default_5 = None + wait_tensor_973 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_362); reduce_scatter_tensor_362 = None + convert_element_type_3362 = torch.ops.prims.convert_element_type.default(sum_310, torch.float32); sum_310 = None + view_2267 = torch.ops.aten.view.default(convert_element_type_3362, [2, 4096, 1, 32, 2]); convert_element_type_3362 = None + view_as_complex_104 = torch.ops.aten.view_as_complex.default(view_2267); view_2267 = None + mul_2167 = torch.ops.aten.mul.Tensor(view_as_complex_104, clone_9); view_as_complex_104 = None + view_as_real_104 = torch.ops.aten.view_as_real.default(mul_2167); mul_2167 = None + view_2268 = torch.ops.aten.view.default(view_as_real_104, [2, 4096, 1, 64]); view_as_real_104 = None + convert_element_type_3363 = torch.ops.prims.convert_element_type.default(view_2268, torch.bfloat16); view_2268 = None + squeeze_51 = torch.ops.aten.squeeze.dim(convert_element_type_3363, 2); convert_element_type_3363 = None + cat_156 = torch.ops.aten.cat.default([convert_element_type_3359, squeeze_51], 2); convert_element_type_3359 = squeeze_51 = None + view_2269 = torch.ops.aten.view.default(cat_156, [8192, 576]); cat_156 = None + permute_1698 = torch.ops.aten.permute.default(view_2269, [1, 0]) + mm_630 = torch.ops.aten.mm.default(permute_1698, view_36); permute_1698 = None + convert_element_type_45 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_45, 64, '0'); convert_element_type_45 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + permute_1700 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_631 = torch.ops.aten.mm.default(view_2269, permute_1700); view_2269 = permute_1700 = None + view_2270 = torch.ops.aten.view.default(mm_631, [2, 4096, 2048]); mm_631 = None + convert_element_type_3368 = torch.ops.prims.convert_element_type.default(mm_630, torch.float32); mm_630 = None + reduce_scatter_tensor_363 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3368, 'avg', 64, '0'); convert_element_type_3368 = None + wait_tensor_974 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_363); reduce_scatter_tensor_363 = None + slice_261 = torch.ops.aten.slice.Tensor(permute_1693, 3, 0, 128) + slice_262 = torch.ops.aten.slice.Tensor(permute_1693, 3, 128, 192); permute_1693 = None + convert_element_type_3369 = torch.ops.prims.convert_element_type.default(slice_262, torch.float32); slice_262 = None + view_2271 = torch.ops.aten.view.default(convert_element_type_3369, [2, 4096, 16, 32, 2]); convert_element_type_3369 = None + view_as_complex_105 = torch.ops.aten.view_as_complex.default(view_2271); view_2271 = None + mul_2168 = torch.ops.aten.mul.Tensor(view_as_complex_105, clone_9); view_as_complex_105 = None + view_as_real_105 = torch.ops.aten.view_as_real.default(mul_2168); mul_2168 = None + view_2272 = torch.ops.aten.view.default(view_as_real_105, [2, 4096, 16, 64]); view_as_real_105 = None + convert_element_type_3370 = torch.ops.prims.convert_element_type.default(view_2272, torch.bfloat16); view_2272 = None + cat_157 = torch.ops.aten.cat.default([slice_261, convert_element_type_3370], 3); slice_261 = convert_element_type_3370 = None + view_2273 = torch.ops.aten.view.default(cat_157, [2, 4096, 3072]); cat_157 = None + view_2274 = torch.ops.aten.view.default(view_2273, [8192, 3072]); view_2273 = None + permute_1702 = torch.ops.aten.permute.default(view_2274, [1, 0]) + mm_632 = torch.ops.aten.mm.default(permute_1702, view_36); permute_1702 = view_36 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 64, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + permute_1704 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_633 = torch.ops.aten.mm.default(view_2274, permute_1704); view_2274 = permute_1704 = None + view_2275 = torch.ops.aten.view.default(mm_633, [2, 4096, 2048]); mm_633 = None + add_2163 = torch.ops.aten.add.Tensor(view_2270, view_2275); view_2270 = view_2275 = None + convert_element_type_3375 = torch.ops.prims.convert_element_type.default(mm_632, torch.float32); mm_632 = None + reduce_scatter_tensor_364 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3375, 'avg', 64, '0'); convert_element_type_3375 = None + wait_tensor_975 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_364); reduce_scatter_tensor_364 = None + convert_element_type_3376 = torch.ops.prims.convert_element_type.default(add_2163, torch.float32); add_2163 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 64, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + convert_element_type_3378 = torch.ops.prims.convert_element_type.default(wait_tensor_11, torch.float32); wait_tensor_11 = None + mul_2169 = torch.ops.aten.mul.Tensor(convert_element_type_3376, convert_element_type_3378); convert_element_type_3378 = None + convert_element_type_38 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + mul_9 = torch.ops.aten.mul.Tensor(convert_element_type_38, rsqrt_3); convert_element_type_38 = None + mul_2171 = torch.ops.aten.mul.Tensor(mul_9, mul_2169) + sum_313 = torch.ops.aten.sum.dim_IntList(mul_2171, [2], True); mul_2171 = None + div_287 = torch.ops.aten.div.Tensor(mul_9, 2048) + mul_2172 = torch.ops.aten.mul.Tensor(div_287, sum_313); div_287 = sum_313 = None + sub_780 = torch.ops.aten.sub.Tensor(mul_2169, mul_2172); mul_2169 = mul_2172 = None + mul_2173 = torch.ops.aten.mul.Tensor(sub_780, rsqrt_3); sub_780 = rsqrt_3 = None + mul_2174 = torch.ops.aten.mul.Tensor(convert_element_type_3376, mul_9); convert_element_type_3376 = mul_9 = None + sum_314 = torch.ops.aten.sum.dim_IntList(mul_2174, [0, 1]); mul_2174 = None + convert_element_type_3379 = torch.ops.prims.convert_element_type.default(mul_2173, torch.bfloat16); mul_2173 = None + add_2164 = torch.ops.aten.add.Tensor(add_2162, convert_element_type_3379); add_2162 = convert_element_type_3379 = None + convert_element_type_default_4 = torch.ops.prims.convert_element_type.default(sum_314, torch.float32); sum_314 = None + reduce_scatter_tensor_365 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_4, 'avg', 64, '0'); convert_element_type_default_4 = None + wait_tensor_976 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_365); reduce_scatter_tensor_365 = None + view_2276 = torch.ops.aten.view.default(add_2164, [8192, 2048]) + permute_1706 = torch.ops.aten.permute.default(view_2276, [1, 0]) + mm_634 = torch.ops.aten.mm.default(permute_1706, view_32); permute_1706 = view_32 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 64, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_10, [1, 0]); wait_tensor_10 = None + permute_1708 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_635 = torch.ops.aten.mm.default(view_2276, permute_1708); view_2276 = permute_1708 = None + view_2277 = torch.ops.aten.view.default(mm_635, [2, 4096, 10944]); mm_635 = None + convert_element_type_3386 = torch.ops.prims.convert_element_type.default(mm_634, torch.float32); mm_634 = None + reduce_scatter_tensor_366 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3386, 'avg', 64, '0'); convert_element_type_3386 = None + wait_tensor_977 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_366); reduce_scatter_tensor_366 = None + view_27 = torch.ops.aten.view.default(mm_4, [2, 4096, 10944]); mm_4 = None + convert_element_type_29 = torch.ops.prims.convert_element_type.default(view_27, torch.float32); view_27 = None + neg = torch.ops.aten.neg.default(convert_element_type_29) + exp = torch.ops.aten.exp.default(neg); neg = None + add_4 = torch.ops.aten.add.Tensor(exp, 1); exp = None + div = torch.ops.aten.div.Tensor(convert_element_type_29, add_4) + convert_element_type_30 = torch.ops.prims.convert_element_type.default(div, torch.bfloat16); div = None + mul_2175 = torch.ops.aten.mul.Tensor(view_2277, convert_element_type_30); convert_element_type_30 = None + view_30 = torch.ops.aten.view.default(mm_5, [2, 4096, 10944]); mm_5 = None + mul_2176 = torch.ops.aten.mul.Tensor(view_2277, view_30); view_2277 = view_30 = None + view_2278 = torch.ops.aten.view.default(mul_2175, [8192, 10944]); mul_2175 = None + permute_1710 = torch.ops.aten.permute.default(view_2278, [1, 0]) + mm_636 = torch.ops.aten.mm.default(permute_1710, view_26); permute_1710 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 64, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + permute_1712 = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_637 = torch.ops.aten.mm.default(view_2278, permute_1712); view_2278 = permute_1712 = None + view_2279 = torch.ops.aten.view.default(mm_637, [2, 4096, 2048]); mm_637 = None + convert_element_type_3391 = torch.ops.prims.convert_element_type.default(mm_636, torch.float32); mm_636 = None + reduce_scatter_tensor_367 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3391, 'avg', 64, '0'); convert_element_type_3391 = None + wait_tensor_978 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_367); reduce_scatter_tensor_367 = None + convert_element_type_3392 = torch.ops.prims.convert_element_type.default(mul_2176, torch.float32); mul_2176 = None + reciprocal_52 = torch.ops.aten.reciprocal.default(add_4); add_4 = None + mul_2177 = torch.ops.aten.mul.Tensor(reciprocal_52, 1); reciprocal_52 = None + mul_2178 = torch.ops.aten.mul.Tensor(convert_element_type_3392, mul_2177); convert_element_type_3392 = None + sub_781 = torch.ops.aten.sub.Tensor(1, mul_2177); mul_2177 = None + mul_2179 = torch.ops.aten.mul.Tensor(convert_element_type_29, sub_781); convert_element_type_29 = sub_781 = None + add_2166 = torch.ops.aten.add.Tensor(mul_2179, 1); mul_2179 = None + mul_2180 = torch.ops.aten.mul.Tensor(mul_2178, add_2166); mul_2178 = add_2166 = None + convert_element_type_3394 = torch.ops.prims.convert_element_type.default(mul_2180, torch.bfloat16); mul_2180 = None + view_2280 = torch.ops.aten.view.default(convert_element_type_3394, [8192, 10944]); convert_element_type_3394 = None + permute_1714 = torch.ops.aten.permute.default(view_2280, [1, 0]) + mm_638 = torch.ops.aten.mm.default(permute_1714, view_26); permute_1714 = view_26 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_26, 64, '0'); convert_element_type_26 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + permute_1716 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_639 = torch.ops.aten.mm.default(view_2280, permute_1716); view_2280 = permute_1716 = None + view_2281 = torch.ops.aten.view.default(mm_639, [2, 4096, 2048]); mm_639 = None + add_2167 = torch.ops.aten.add.Tensor(view_2279, view_2281); view_2279 = view_2281 = None + convert_element_type_3399 = torch.ops.prims.convert_element_type.default(mm_638, torch.float32); mm_638 = None + reduce_scatter_tensor_368 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3399, 'avg', 64, '0'); convert_element_type_3399 = None + wait_tensor_979 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_368); reduce_scatter_tensor_368 = None + convert_element_type_3400 = torch.ops.prims.convert_element_type.default(add_2167, torch.float32); add_2167 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 64, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_3402 = torch.ops.prims.convert_element_type.default(wait_tensor_7, torch.float32); wait_tensor_7 = None + mul_2181 = torch.ops.aten.mul.Tensor(convert_element_type_3400, convert_element_type_3402); convert_element_type_3402 = None + view_23 = torch.ops.aten.view.default(mm_3, [2, 4096, 2048]); mm_3 = None + add_2 = torch.ops.aten.add.Tensor(embedding, view_23); view_23 = None + convert_element_type_24 = torch.ops.prims.convert_element_type.default(add_2, torch.float32); add_2 = None + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_24, rsqrt_2); convert_element_type_24 = None + mul_2183 = torch.ops.aten.mul.Tensor(mul_6, mul_2181) + sum_315 = torch.ops.aten.sum.dim_IntList(mul_2183, [2], True); mul_2183 = None + div_288 = torch.ops.aten.div.Tensor(mul_6, 2048) + mul_2184 = torch.ops.aten.mul.Tensor(div_288, sum_315); div_288 = sum_315 = None + sub_782 = torch.ops.aten.sub.Tensor(mul_2181, mul_2184); mul_2181 = mul_2184 = None + mul_2185 = torch.ops.aten.mul.Tensor(sub_782, rsqrt_2); sub_782 = rsqrt_2 = None + mul_2186 = torch.ops.aten.mul.Tensor(convert_element_type_3400, mul_6); convert_element_type_3400 = mul_6 = None + sum_316 = torch.ops.aten.sum.dim_IntList(mul_2186, [0, 1]); mul_2186 = None + convert_element_type_3403 = torch.ops.prims.convert_element_type.default(mul_2185, torch.bfloat16); mul_2185 = None + add_2168 = torch.ops.aten.add.Tensor(add_2164, convert_element_type_3403); add_2164 = convert_element_type_3403 = None + convert_element_type_default_3 = torch.ops.prims.convert_element_type.default(sum_316, torch.float32); sum_316 = None + reduce_scatter_tensor_369 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_3, 'avg', 64, '0'); convert_element_type_default_3 = None + wait_tensor_980 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_369); reduce_scatter_tensor_369 = None + view_2282 = torch.ops.aten.view.default(add_2168, [8192, 2048]) + permute_1718 = torch.ops.aten.permute.default(view_2282, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem_6, [0, 2, 1, 3]) + view_20 = torch.ops.aten.view.default(permute_6, [2, 4096, -1]); permute_6 = None + view_22 = torch.ops.aten.view.default(view_20, [8192, 2048]); view_20 = None + mm_640 = torch.ops.aten.mm.default(permute_1718, view_22); permute_1718 = view_22 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 64, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + permute_1720 = torch.ops.aten.permute.default(permute_7, [1, 0]); permute_7 = None + mm_641 = torch.ops.aten.mm.default(view_2282, permute_1720); view_2282 = permute_1720 = None + view_2283 = torch.ops.aten.view.default(mm_641, [2, 4096, 2048]); mm_641 = None + convert_element_type_3410 = torch.ops.prims.convert_element_type.default(mm_640, torch.float32); mm_640 = None + reduce_scatter_tensor_370 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3410, 'avg', 64, '0'); convert_element_type_3410 = None + wait_tensor_981 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_370); reduce_scatter_tensor_370 = None + view_2284 = torch.ops.aten.view.default(view_2283, [2, 4096, 16, 128]); view_2283 = None + permute_1722 = torch.ops.aten.permute.default(view_2284, [0, 2, 1, 3]); view_2284 = None + fw_graph26 = self.fw_graph26 + joint_graph26 = self.joint_graph26 + mask_graph26 = self.mask_graph26 + flex_attention_backward_26 = torch.ops.higher_order.flex_attention_backward(permute_3, permute_4, permute_5, getitem_6, getitem_7, permute_1722, None, fw_graph26, joint_graph26, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, mask_graph26), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); permute_3 = permute_4 = permute_5 = getitem_6 = getitem_7 = permute_1722 = fw_graph26 = joint_graph26 = primals_10 = primals_9 = primals_12 = primals_13 = primals_14 = primals_15 = primals_16 = primals_17 = mask_graph26 = primals_11 = None + getitem_477 = flex_attention_backward_26[0] + getitem_478 = flex_attention_backward_26[1] + getitem_479 = flex_attention_backward_26[2]; flex_attention_backward_26 = None + permute_1723 = torch.ops.aten.permute.default(getitem_479, [0, 2, 1, 3]); getitem_479 = None + permute_1724 = torch.ops.aten.permute.default(getitem_478, [0, 2, 1, 3]); getitem_478 = None + permute_1725 = torch.ops.aten.permute.default(getitem_477, [0, 2, 1, 3]); getitem_477 = None + slice_263 = torch.ops.aten.slice.Tensor(permute_1724, 3, 0, 128) + slice_264 = torch.ops.aten.slice.Tensor(permute_1724, 3, 128, 192); permute_1724 = None + sum_317 = torch.ops.aten.sum.dim_IntList(slice_264, [2], True); slice_264 = None + cat_158 = torch.ops.aten.cat.default([slice_263, permute_1723], 3); slice_263 = permute_1723 = None + view_2285 = torch.ops.aten.view.default(cat_158, [2, 4096, 4096]); cat_158 = None + view_2286 = torch.ops.aten.view.default(view_2285, [8192, 4096]); view_2285 = None + permute_1726 = torch.ops.aten.permute.default(view_2286, [1, 0]) + mm_642 = torch.ops.aten.mm.default(permute_1726, view_17); permute_1726 = view_17 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 64, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + permute_1728 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_643 = torch.ops.aten.mm.default(view_2286, permute_1728); view_2286 = permute_1728 = None + view_2287 = torch.ops.aten.view.default(mm_643, [2, 4096, 512]); mm_643 = None + convert_element_type_3415 = torch.ops.prims.convert_element_type.default(mm_642, torch.float32); mm_642 = None + reduce_scatter_tensor_371 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3415, 'avg', 64, '0'); convert_element_type_3415 = None + wait_tensor_982 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_371); reduce_scatter_tensor_371 = None + convert_element_type_3416 = torch.ops.prims.convert_element_type.default(view_2287, torch.float32); view_2287 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_14, 64, '0'); convert_element_type_14 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + convert_element_type_3418 = torch.ops.prims.convert_element_type.default(wait_tensor_4, torch.float32); wait_tensor_4 = None + mul_2187 = torch.ops.aten.mul.Tensor(convert_element_type_3416, convert_element_type_3418); convert_element_type_3418 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(getitem_2, torch.float32); getitem_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_15, rsqrt_1); convert_element_type_15 = None + mul_2189 = torch.ops.aten.mul.Tensor(mul_4, mul_2187) + sum_318 = torch.ops.aten.sum.dim_IntList(mul_2189, [2], True); mul_2189 = None + div_289 = torch.ops.aten.div.Tensor(mul_4, 512) + mul_2190 = torch.ops.aten.mul.Tensor(div_289, sum_318); div_289 = sum_318 = None + sub_783 = torch.ops.aten.sub.Tensor(mul_2187, mul_2190); mul_2187 = mul_2190 = None + mul_2191 = torch.ops.aten.mul.Tensor(sub_783, rsqrt_1); sub_783 = rsqrt_1 = None + mul_2192 = torch.ops.aten.mul.Tensor(convert_element_type_3416, mul_4); convert_element_type_3416 = mul_4 = None + sum_319 = torch.ops.aten.sum.dim_IntList(mul_2192, [0, 1]); mul_2192 = None + convert_element_type_3419 = torch.ops.prims.convert_element_type.default(mul_2191, torch.bfloat16); mul_2191 = None + convert_element_type_default_2 = torch.ops.prims.convert_element_type.default(sum_319, torch.float32); sum_319 = None + reduce_scatter_tensor_372 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_2, 'avg', 64, '0'); convert_element_type_default_2 = None + wait_tensor_983 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_372); reduce_scatter_tensor_372 = None + convert_element_type_3422 = torch.ops.prims.convert_element_type.default(sum_317, torch.float32); sum_317 = None + view_2288 = torch.ops.aten.view.default(convert_element_type_3422, [2, 4096, 1, 32, 2]); convert_element_type_3422 = None + view_as_complex_106 = torch.ops.aten.view_as_complex.default(view_2288); view_2288 = None + mul_2193 = torch.ops.aten.mul.Tensor(view_as_complex_106, clone_9); view_as_complex_106 = None + view_as_real_106 = torch.ops.aten.view_as_real.default(mul_2193); mul_2193 = None + view_2289 = torch.ops.aten.view.default(view_as_real_106, [2, 4096, 1, 64]); view_as_real_106 = None + convert_element_type_3423 = torch.ops.prims.convert_element_type.default(view_2289, torch.bfloat16); view_2289 = None + squeeze_52 = torch.ops.aten.squeeze.dim(convert_element_type_3423, 2); convert_element_type_3423 = None + cat_159 = torch.ops.aten.cat.default([convert_element_type_3419, squeeze_52], 2); convert_element_type_3419 = squeeze_52 = None + view_2290 = torch.ops.aten.view.default(cat_159, [8192, 576]); cat_159 = None + permute_1730 = torch.ops.aten.permute.default(view_2290, [1, 0]) + mm_644 = torch.ops.aten.mm.default(permute_1730, view_3); permute_1730 = None + convert_element_type_9 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_9, 64, '0'); convert_element_type_9 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + permute_1732 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_645 = torch.ops.aten.mm.default(view_2290, permute_1732); view_2290 = permute_1732 = None + view_2291 = torch.ops.aten.view.default(mm_645, [2, 4096, 2048]); mm_645 = None + convert_element_type_3428 = torch.ops.prims.convert_element_type.default(mm_644, torch.float32); mm_644 = None + reduce_scatter_tensor_373 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3428, 'avg', 64, '0'); convert_element_type_3428 = None + wait_tensor_984 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_373); reduce_scatter_tensor_373 = None + slice_265 = torch.ops.aten.slice.Tensor(permute_1725, 3, 0, 128) + slice_266 = torch.ops.aten.slice.Tensor(permute_1725, 3, 128, 192); permute_1725 = None + convert_element_type_3429 = torch.ops.prims.convert_element_type.default(slice_266, torch.float32); slice_266 = None + view_2292 = torch.ops.aten.view.default(convert_element_type_3429, [2, 4096, 16, 32, 2]); convert_element_type_3429 = None + view_as_complex_107 = torch.ops.aten.view_as_complex.default(view_2292); view_2292 = None + mul_2194 = torch.ops.aten.mul.Tensor(view_as_complex_107, clone_9); view_as_complex_107 = clone_9 = None + view_as_real_107 = torch.ops.aten.view_as_real.default(mul_2194); mul_2194 = None + view_2293 = torch.ops.aten.view.default(view_as_real_107, [2, 4096, 16, 64]); view_as_real_107 = None + convert_element_type_3430 = torch.ops.prims.convert_element_type.default(view_2293, torch.bfloat16); view_2293 = None + cat_160 = torch.ops.aten.cat.default([slice_265, convert_element_type_3430], 3); slice_265 = convert_element_type_3430 = None + view_2294 = torch.ops.aten.view.default(cat_160, [2, 4096, 3072]); cat_160 = None + view_2295 = torch.ops.aten.view.default(view_2294, [8192, 3072]); view_2294 = None + permute_1734 = torch.ops.aten.permute.default(view_2295, [1, 0]) + mm_646 = torch.ops.aten.mm.default(permute_1734, view_3); permute_1734 = view_3 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 64, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + permute_1736 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_647 = torch.ops.aten.mm.default(view_2295, permute_1736); view_2295 = permute_1736 = None + view_2296 = torch.ops.aten.view.default(mm_647, [2, 4096, 2048]); mm_647 = None + add_2169 = torch.ops.aten.add.Tensor(view_2291, view_2296); view_2291 = view_2296 = None + convert_element_type_3435 = torch.ops.prims.convert_element_type.default(mm_646, torch.float32); mm_646 = None + reduce_scatter_tensor_374 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_3435, 'avg', 64, '0'); convert_element_type_3435 = None + wait_tensor_985 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_374); reduce_scatter_tensor_374 = None + convert_element_type_3436 = torch.ops.prims.convert_element_type.default(add_2169, torch.float32); add_2169 = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 64, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_3438 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + mul_2195 = torch.ops.aten.mul.Tensor(convert_element_type_3436, convert_element_type_3438); convert_element_type_3438 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32); embedding = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_2197 = torch.ops.aten.mul.Tensor(mul, mul_2195) + sum_320 = torch.ops.aten.sum.dim_IntList(mul_2197, [2], True); mul_2197 = None + div_290 = torch.ops.aten.div.Tensor(mul, 2048) + mul_2198 = torch.ops.aten.mul.Tensor(div_290, sum_320); div_290 = sum_320 = None + sub_784 = torch.ops.aten.sub.Tensor(mul_2195, mul_2198); mul_2195 = mul_2198 = None + mul_2199 = torch.ops.aten.mul.Tensor(sub_784, rsqrt); sub_784 = rsqrt = None + mul_2200 = torch.ops.aten.mul.Tensor(convert_element_type_3436, mul); convert_element_type_3436 = mul = None + sum_321 = torch.ops.aten.sum.dim_IntList(mul_2200, [0, 1]); mul_2200 = None + convert_element_type_3439 = torch.ops.prims.convert_element_type.default(mul_2199, torch.bfloat16); mul_2199 = None + add_2170 = torch.ops.aten.add.Tensor(add_2168, convert_element_type_3439); add_2168 = convert_element_type_3439 = None + convert_element_type_default_1 = torch.ops.prims.convert_element_type.default(sum_321, torch.float32); sum_321 = None + reduce_scatter_tensor_375 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_1, 'avg', 64, '0'); convert_element_type_default_1 = None + wait_tensor_986 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_375); reduce_scatter_tensor_375 = None + convert_element_type_3442 = torch.ops.prims.convert_element_type.default(add_2170, torch.float32); add_2170 = None + eq_572 = torch.ops.aten.eq.Scalar(primals_2, -1) + unsqueeze_79 = torch.ops.aten.unsqueeze.default(eq_572, -1); eq_572 = None + full_default_104 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_79, full_default_104, convert_element_type_3442); unsqueeze_79 = full_default_104 = convert_element_type_3442 = None + full_default_105 = torch.ops.aten.full.default([102400, 2048], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_104 = torch.ops.aten.index_put.default(full_default_105, [primals_2], where, True); full_default_105 = primals_2 = where = None + convert_element_type_default = torch.ops.prims.convert_element_type.default(index_put_104, torch.float32); index_put_104 = None + reduce_scatter_tensor_376 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default, 'avg', 64, '0'); convert_element_type_default = None + wait_tensor_987 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_376); reduce_scatter_tensor_376 = None + return (wait_tensor_987, None, None, wait_tensor_986, wait_tensor_985, wait_tensor_984, wait_tensor_983, wait_tensor_982, None, None, None, None, None, None, None, None, None, wait_tensor_981, wait_tensor_980, wait_tensor_979, wait_tensor_978, wait_tensor_977, wait_tensor_976, wait_tensor_975, wait_tensor_974, wait_tensor_973, wait_tensor_972, wait_tensor_971, wait_tensor_970, None, wait_tensor_969, None, wait_tensor_967, wait_tensor_966, wait_tensor_965, wait_tensor_963, wait_tensor_962, wait_tensor_961, wait_tensor_960, wait_tensor_959, wait_tensor_958, wait_tensor_957, wait_tensor_956, wait_tensor_955, wait_tensor_954, None, wait_tensor_953, None, wait_tensor_951, wait_tensor_950, wait_tensor_949, wait_tensor_947, wait_tensor_946, wait_tensor_945, wait_tensor_944, wait_tensor_943, wait_tensor_942, wait_tensor_941, wait_tensor_940, wait_tensor_939, wait_tensor_938, None, wait_tensor_937, None, wait_tensor_935, wait_tensor_934, wait_tensor_933, wait_tensor_931, wait_tensor_930, wait_tensor_929, wait_tensor_928, wait_tensor_927, wait_tensor_926, wait_tensor_925, wait_tensor_924, wait_tensor_923, wait_tensor_922, None, wait_tensor_921, None, wait_tensor_919, wait_tensor_918, wait_tensor_917, wait_tensor_915, wait_tensor_914, wait_tensor_913, wait_tensor_912, wait_tensor_911, wait_tensor_910, wait_tensor_909, wait_tensor_908, wait_tensor_907, wait_tensor_906, None, wait_tensor_905, None, wait_tensor_903, wait_tensor_902, wait_tensor_901, wait_tensor_899, wait_tensor_898, wait_tensor_897, wait_tensor_896, wait_tensor_895, wait_tensor_894, wait_tensor_893, wait_tensor_892, wait_tensor_891, wait_tensor_890, None, wait_tensor_889, None, wait_tensor_887, wait_tensor_886, wait_tensor_885, wait_tensor_883, wait_tensor_882, wait_tensor_881, wait_tensor_880, wait_tensor_879, wait_tensor_878, wait_tensor_877, wait_tensor_876, wait_tensor_875, wait_tensor_874, None, wait_tensor_873, None, wait_tensor_871, wait_tensor_870, wait_tensor_869, wait_tensor_867, wait_tensor_866, wait_tensor_865, wait_tensor_864, wait_tensor_863, wait_tensor_862, wait_tensor_861, wait_tensor_860, wait_tensor_859, wait_tensor_858, None, wait_tensor_857, None, wait_tensor_855, wait_tensor_854, wait_tensor_853, wait_tensor_851, wait_tensor_850, wait_tensor_849, wait_tensor_848, wait_tensor_847, wait_tensor_846, wait_tensor_845, wait_tensor_844, wait_tensor_843, wait_tensor_842, None, wait_tensor_841, None, wait_tensor_839, wait_tensor_838, wait_tensor_837, wait_tensor_835, wait_tensor_834, wait_tensor_833, wait_tensor_832, wait_tensor_831, wait_tensor_830, wait_tensor_829, wait_tensor_828, wait_tensor_827, wait_tensor_826, None, wait_tensor_825, None, wait_tensor_823, wait_tensor_822, wait_tensor_821, wait_tensor_819, wait_tensor_818, wait_tensor_817, wait_tensor_816, wait_tensor_815, wait_tensor_814, wait_tensor_813, wait_tensor_812, wait_tensor_811, wait_tensor_810, None, wait_tensor_809, None, wait_tensor_807, wait_tensor_806, wait_tensor_805, wait_tensor_803, wait_tensor_802, wait_tensor_801, wait_tensor_800, wait_tensor_799, wait_tensor_798, wait_tensor_797, wait_tensor_796, wait_tensor_795, wait_tensor_794, None, wait_tensor_793, None, wait_tensor_791, wait_tensor_790, wait_tensor_789, wait_tensor_787, wait_tensor_786, wait_tensor_785, wait_tensor_784, wait_tensor_783, wait_tensor_782, wait_tensor_781, wait_tensor_780, wait_tensor_779, wait_tensor_778, None, wait_tensor_777, None, wait_tensor_775, wait_tensor_774, wait_tensor_773, wait_tensor_771, wait_tensor_770, wait_tensor_769, wait_tensor_768, wait_tensor_767, wait_tensor_766, wait_tensor_765, wait_tensor_764, wait_tensor_763, wait_tensor_762, None, wait_tensor_761, None, wait_tensor_759, wait_tensor_758, wait_tensor_757, wait_tensor_755, wait_tensor_754, wait_tensor_753, wait_tensor_752, wait_tensor_751, wait_tensor_750, wait_tensor_749, wait_tensor_748, wait_tensor_747, wait_tensor_746, None, wait_tensor_745, None, wait_tensor_743, wait_tensor_742, wait_tensor_741, wait_tensor_739, wait_tensor_738, wait_tensor_737, wait_tensor_736, wait_tensor_735, wait_tensor_734, wait_tensor_733, wait_tensor_732, wait_tensor_731, wait_tensor_730, None, wait_tensor_729, None, wait_tensor_727, wait_tensor_726, wait_tensor_725, wait_tensor_723, wait_tensor_722, wait_tensor_721, wait_tensor_720, wait_tensor_719, wait_tensor_718, wait_tensor_717, wait_tensor_716, wait_tensor_715, wait_tensor_714, None, wait_tensor_713, None, wait_tensor_711, wait_tensor_710, wait_tensor_709, wait_tensor_707, wait_tensor_706, wait_tensor_705, wait_tensor_704, wait_tensor_703, wait_tensor_702, wait_tensor_701, wait_tensor_700, wait_tensor_699, wait_tensor_698, None, wait_tensor_697, None, wait_tensor_695, wait_tensor_694, wait_tensor_693, wait_tensor_691, wait_tensor_690, wait_tensor_689, wait_tensor_688, wait_tensor_687, wait_tensor_686, wait_tensor_685, wait_tensor_684, wait_tensor_683, wait_tensor_682, None, wait_tensor_681, None, wait_tensor_679, wait_tensor_678, wait_tensor_677, wait_tensor_675, wait_tensor_674, wait_tensor_673, wait_tensor_672, wait_tensor_671, wait_tensor_670, wait_tensor_669, wait_tensor_668, wait_tensor_667, wait_tensor_666, None, wait_tensor_665, None, wait_tensor_663, wait_tensor_662, wait_tensor_661, wait_tensor_659, wait_tensor_658, wait_tensor_657, wait_tensor_656, wait_tensor_655, wait_tensor_654, wait_tensor_653, wait_tensor_652, wait_tensor_651, wait_tensor_650, None, wait_tensor_649, None, wait_tensor_647, wait_tensor_646, wait_tensor_645, wait_tensor_643, wait_tensor_642, wait_tensor_641, wait_tensor_640, wait_tensor_639, wait_tensor_638, wait_tensor_637, wait_tensor_636, wait_tensor_635, wait_tensor_634, None, wait_tensor_633, None, wait_tensor_631, wait_tensor_630, wait_tensor_629, wait_tensor_627, wait_tensor_626, wait_tensor_625, wait_tensor_624, wait_tensor_623, wait_tensor_622, wait_tensor_621, wait_tensor_620, wait_tensor_619, wait_tensor_618, None, wait_tensor_617, None, wait_tensor_615, wait_tensor_614, wait_tensor_613, wait_tensor_611, wait_tensor_610, wait_tensor_609, wait_tensor_608, wait_tensor_607, wait_tensor_606, wait_tensor_605, wait_tensor_604, wait_tensor_603, wait_tensor_602, None, wait_tensor_601, None, wait_tensor_599, wait_tensor_598, wait_tensor_597, wait_tensor_595, wait_tensor_594, wait_tensor_593, wait_tensor_592, wait_tensor_591, wait_tensor_590, wait_tensor_589, wait_tensor_588, wait_tensor_587, wait_tensor_586, None, wait_tensor_585, None, wait_tensor_583, wait_tensor_582, wait_tensor_581, wait_tensor_579, wait_tensor_578, wait_tensor_577, wait_tensor_576, wait_tensor_575, wait_tensor_574, wait_tensor_573, wait_tensor_572, wait_tensor_571, wait_tensor_570, None, wait_tensor_569, None, wait_tensor_567, wait_tensor_566, wait_tensor_565, wait_tensor_563, wait_tensor_562, wait_tensor_561, wait_tensor_560, wait_tensor_559) + +def load_args(reader): + # MoE expert token counts (approximate uniform distribution) + u8 = u9 = u10 = u11 = u12 = u13 = u14 = u15 = u24 = u25 = u26 = u27 = u28 = u29 = u30 = u31 = u40 = u41 = u42 = u43 = u44 = u45 = u46 = u47 = u56 = u57 = u58 = u59 = u60 = u61 = u62 = u63 = u72 = u73 = u74 = u75 = u76 = u77 = u78 = u79 = u88 = u89 = u90 = u91 = u92 = u93 = u94 = u95 = u104 = u105 = u106 = u107 = u108 = u109 = u110 = u111 = u120 = u121 = u122 = u123 = u124 = u125 = u126 = u127 = u136 = u137 = u138 = u139 = u140 = u141 = u142 = u143 = u152 = u153 = u154 = u155 = u156 = u157 = u158 = u159 = u168 = u169 = u170 = u171 = u172 = u173 = u174 = u175 = u184 = u185 = u186 = u187 = u188 = u189 = u190 = u191 = u200 = u201 = u202 = u203 = u204 = u205 = u206 = u207 = u216 = u217 = u218 = u219 = u220 = u221 = u222 = u223 = u232 = u233 = u234 = u235 = u236 = u237 = u238 = u239 = u248 = u249 = u250 = u251 = u252 = u253 = u254 = u255 = u264 = u265 = u266 = u267 = u268 = u269 = u270 = u271 = u280 = u281 = u282 = u283 = u284 = u285 = u286 = u287 = u296 = u297 = u298 = u299 = u300 = u301 = u302 = u303 = u312 = u313 = u314 = u315 = u316 = u317 = u318 = u319 = u328 = u329 = u330 = u331 = u332 = u333 = u334 = u335 = u344 = u345 = u346 = u347 = u348 = u349 = u350 = u351 = u360 = u361 = u362 = u363 = u364 = u365 = u366 = u367 = u376 = u377 = u378 = u379 = u380 = u381 = u382 = u383 = u392 = u393 = u394 = u395 = u396 = u397 = u398 = u399 = u408 = u409 = u410 = u411 = u412 = u413 = u414 = u415 = 512 + reader.symint(512) # _local_scalar_dense + reader.symint(512) # _local_scalar_dense_1 + reader.symint(512) # _local_scalar_dense_2 + reader.symint(512) # _local_scalar_dense_3 + reader.symint(512) # _local_scalar_dense_4 + reader.symint(512) # _local_scalar_dense_5 + reader.symint(512) # _local_scalar_dense_6 + reader.symint(512) # _local_scalar_dense_7 + reader.symint(512) # _local_scalar_dense_8 + reader.symint(512) # _local_scalar_dense_9 + reader.symint(512) # _local_scalar_dense_10 + reader.symint(512) # _local_scalar_dense_11 + reader.symint(512) # _local_scalar_dense_12 + reader.symint(512) # _local_scalar_dense_13 + reader.symint(512) # _local_scalar_dense_14 + reader.symint(512) # _local_scalar_dense_15 + reader.symint(512) # _local_scalar_dense_16 + reader.symint(512) # _local_scalar_dense_17 + reader.symint(512) # _local_scalar_dense_18 + reader.symint(512) # _local_scalar_dense_19 + reader.symint(512) # _local_scalar_dense_20 + reader.symint(512) # _local_scalar_dense_21 + reader.symint(512) # _local_scalar_dense_22 + reader.symint(512) # _local_scalar_dense_23 + reader.symint(512) # _local_scalar_dense_24 + reader.symint(512) # _local_scalar_dense_25 + reader.symint(512) # _local_scalar_dense_26 + reader.symint(512) # _local_scalar_dense_27 + reader.symint(512) # _local_scalar_dense_28 + reader.symint(512) # _local_scalar_dense_29 + reader.symint(512) # _local_scalar_dense_30 + reader.symint(512) # _local_scalar_dense_31 + reader.symint(512) # _local_scalar_dense_32 + reader.symint(512) # _local_scalar_dense_33 + reader.symint(512) # _local_scalar_dense_34 + reader.symint(512) # _local_scalar_dense_35 + reader.symint(512) # _local_scalar_dense_36 + reader.symint(512) # _local_scalar_dense_37 + reader.symint(512) # _local_scalar_dense_38 + reader.symint(512) # _local_scalar_dense_39 + reader.symint(512) # _local_scalar_dense_40 + reader.symint(512) # _local_scalar_dense_41 + reader.symint(512) # _local_scalar_dense_42 + reader.symint(512) # _local_scalar_dense_43 + reader.symint(512) # _local_scalar_dense_44 + reader.symint(512) # _local_scalar_dense_45 + reader.symint(512) # _local_scalar_dense_46 + reader.symint(512) # _local_scalar_dense_47 + reader.symint(512) # _local_scalar_dense_48 + reader.symint(512) # _local_scalar_dense_49 + reader.symint(512) # _local_scalar_dense_50 + reader.symint(512) # _local_scalar_dense_51 + reader.symint(512) # _local_scalar_dense_52 + reader.symint(512) # _local_scalar_dense_53 + reader.symint(512) # _local_scalar_dense_54 + reader.symint(512) # _local_scalar_dense_55 + reader.symint(512) # _local_scalar_dense_56 + reader.symint(512) # _local_scalar_dense_57 + reader.symint(512) # _local_scalar_dense_58 + reader.symint(512) # _local_scalar_dense_59 + reader.symint(512) # _local_scalar_dense_60 + reader.symint(512) # _local_scalar_dense_61 + reader.symint(512) # _local_scalar_dense_62 + reader.symint(512) # _local_scalar_dense_63 + reader.symint(512) # _local_scalar_dense_64 + reader.symint(512) # _local_scalar_dense_65 + reader.symint(512) # _local_scalar_dense_66 + reader.symint(512) # _local_scalar_dense_67 + reader.symint(512) # _local_scalar_dense_68 + reader.symint(512) # _local_scalar_dense_69 + reader.symint(512) # _local_scalar_dense_70 + reader.symint(512) # _local_scalar_dense_71 + reader.symint(512) # _local_scalar_dense_72 + reader.symint(512) # _local_scalar_dense_73 + reader.symint(512) # _local_scalar_dense_74 + reader.symint(512) # _local_scalar_dense_75 + reader.symint(512) # _local_scalar_dense_76 + reader.symint(512) # _local_scalar_dense_77 + reader.symint(512) # _local_scalar_dense_78 + reader.symint(512) # _local_scalar_dense_79 + reader.symint(512) # _local_scalar_dense_80 + reader.symint(512) # _local_scalar_dense_81 + reader.symint(512) # _local_scalar_dense_82 + reader.symint(512) # _local_scalar_dense_83 + reader.symint(512) # _local_scalar_dense_84 + reader.symint(512) # _local_scalar_dense_85 + reader.symint(512) # _local_scalar_dense_86 + reader.symint(512) # _local_scalar_dense_87 + reader.symint(512) # _local_scalar_dense_88 + reader.symint(512) # _local_scalar_dense_89 + reader.symint(512) # _local_scalar_dense_90 + reader.symint(512) # _local_scalar_dense_91 + reader.symint(512) # _local_scalar_dense_92 + reader.symint(512) # _local_scalar_dense_93 + reader.symint(512) # _local_scalar_dense_94 + reader.symint(512) # _local_scalar_dense_95 + reader.symint(512) # _local_scalar_dense_96 + reader.symint(512) # _local_scalar_dense_97 + reader.symint(512) # _local_scalar_dense_98 + reader.symint(512) # _local_scalar_dense_99 + reader.symint(512) # _local_scalar_dense_100 + reader.symint(512) # _local_scalar_dense_101 + reader.symint(512) # _local_scalar_dense_102 + reader.symint(512) # _local_scalar_dense_103 + reader.symint(512) # _local_scalar_dense_104 + reader.symint(512) # _local_scalar_dense_105 + reader.symint(512) # _local_scalar_dense_106 + reader.symint(512) # _local_scalar_dense_107 + reader.symint(512) # _local_scalar_dense_108 + reader.symint(512) # _local_scalar_dense_109 + reader.symint(512) # _local_scalar_dense_110 + reader.symint(512) # _local_scalar_dense_111 + reader.symint(512) # _local_scalar_dense_112 + reader.symint(512) # _local_scalar_dense_113 + reader.symint(512) # _local_scalar_dense_114 + reader.symint(512) # _local_scalar_dense_115 + reader.symint(512) # _local_scalar_dense_116 + reader.symint(512) # _local_scalar_dense_117 + reader.symint(512) # _local_scalar_dense_118 + reader.symint(512) # _local_scalar_dense_119 + reader.symint(512) # _local_scalar_dense_120 + reader.symint(512) # _local_scalar_dense_121 + reader.symint(512) # _local_scalar_dense_122 + reader.symint(512) # _local_scalar_dense_123 + reader.symint(512) # _local_scalar_dense_124 + reader.symint(512) # _local_scalar_dense_125 + reader.symint(512) # _local_scalar_dense_126 + reader.symint(512) # _local_scalar_dense_127 + reader.symint(512) # _local_scalar_dense_128 + reader.symint(512) # _local_scalar_dense_129 + reader.symint(512) # _local_scalar_dense_130 + reader.symint(512) # _local_scalar_dense_131 + reader.symint(512) # _local_scalar_dense_132 + reader.symint(512) # _local_scalar_dense_133 + reader.symint(512) # _local_scalar_dense_134 + reader.symint(512) # _local_scalar_dense_135 + reader.symint(512) # _local_scalar_dense_136 + reader.symint(512) # _local_scalar_dense_137 + reader.symint(512) # _local_scalar_dense_138 + reader.symint(512) # _local_scalar_dense_139 + reader.symint(512) # _local_scalar_dense_140 + reader.symint(512) # _local_scalar_dense_141 + reader.symint(512) # _local_scalar_dense_142 + reader.symint(512) # _local_scalar_dense_143 + reader.symint(512) # _local_scalar_dense_144 + reader.symint(512) # _local_scalar_dense_145 + reader.symint(512) # _local_scalar_dense_146 + reader.symint(512) # _local_scalar_dense_147 + reader.symint(512) # _local_scalar_dense_148 + reader.symint(512) # _local_scalar_dense_149 + reader.symint(512) # _local_scalar_dense_150 + reader.symint(512) # _local_scalar_dense_151 + reader.symint(512) # _local_scalar_dense_152 + reader.symint(512) # _local_scalar_dense_153 + reader.symint(512) # _local_scalar_dense_154 + reader.symint(512) # _local_scalar_dense_155 + reader.symint(512) # _local_scalar_dense_156 + reader.symint(512) # _local_scalar_dense_157 + reader.symint(512) # _local_scalar_dense_158 + reader.symint(512) # _local_scalar_dense_159 + reader.symint(512) # _local_scalar_dense_160 + reader.symint(512) # _local_scalar_dense_161 + reader.symint(512) # _local_scalar_dense_162 + reader.symint(512) # _local_scalar_dense_163 + reader.symint(512) # _local_scalar_dense_164 + reader.symint(512) # _local_scalar_dense_165 + reader.symint(512) # _local_scalar_dense_166 + reader.symint(512) # _local_scalar_dense_167 + reader.symint(512) # _local_scalar_dense_168 + reader.symint(512) # _local_scalar_dense_169 + reader.symint(512) # _local_scalar_dense_170 + reader.symint(512) # _local_scalar_dense_171 + reader.symint(512) # _local_scalar_dense_172 + reader.symint(512) # _local_scalar_dense_173 + reader.symint(512) # _local_scalar_dense_174 + reader.symint(512) # _local_scalar_dense_175 + reader.symint(512) # _local_scalar_dense_176 + reader.symint(512) # _local_scalar_dense_177 + reader.symint(512) # _local_scalar_dense_178 + reader.symint(512) # _local_scalar_dense_179 + reader.symint(512) # _local_scalar_dense_180 + reader.symint(512) # _local_scalar_dense_181 + reader.symint(512) # _local_scalar_dense_182 + reader.symint(512) # _local_scalar_dense_183 + reader.symint(512) # _local_scalar_dense_184 + reader.symint(512) # _local_scalar_dense_185 + reader.symint(512) # _local_scalar_dense_186 + reader.symint(512) # _local_scalar_dense_187 + reader.symint(512) # _local_scalar_dense_188 + reader.symint(512) # _local_scalar_dense_189 + reader.symint(512) # _local_scalar_dense_190 + reader.symint(512) # _local_scalar_dense_191 + reader.symint(512) # _local_scalar_dense_192 + reader.symint(512) # _local_scalar_dense_193 + reader.symint(512) # _local_scalar_dense_194 + reader.symint(512) # _local_scalar_dense_195 + reader.symint(512) # _local_scalar_dense_196 + reader.symint(512) # _local_scalar_dense_197 + reader.symint(512) # _local_scalar_dense_198 + reader.symint(512) # _local_scalar_dense_199 + reader.symint(512) # _local_scalar_dense_200 + reader.symint(512) # _local_scalar_dense_201 + reader.symint(512) # _local_scalar_dense_202 + reader.symint(512) # _local_scalar_dense_203 + reader.symint(512) # _local_scalar_dense_204 + reader.symint(512) # _local_scalar_dense_205 + reader.symint(512) # _local_scalar_dense_206 + reader.symint(512) # _local_scalar_dense_207 + reader.symint(512) # _local_scalar_dense_208 + reader.symint(512) # _local_scalar_dense_209 + reader.symint(512) # _local_scalar_dense_210 + reader.symint(512) # _local_scalar_dense_211 + reader.symint(512) # _local_scalar_dense_212 + reader.symint(512) # _local_scalar_dense_213 + reader.symint(512) # _local_scalar_dense_214 + reader.symint(512) # _local_scalar_dense_215 + reader.symint(512) # _local_scalar_dense_216 + reader.symint(512) # _local_scalar_dense_217 + reader.symint(512) # _local_scalar_dense_218 + reader.symint(512) # _local_scalar_dense_219 + reader.symint(512) # _local_scalar_dense_220 + reader.symint(512) # _local_scalar_dense_221 + reader.symint(512) # _local_scalar_dense_222 + reader.symint(512) # _local_scalar_dense_223 + reader.symint(512) # _local_scalar_dense_224 + reader.symint(512) # _local_scalar_dense_225 + reader.symint(512) # _local_scalar_dense_226 + reader.symint(512) # _local_scalar_dense_227 + reader.symint(512) # _local_scalar_dense_228 + reader.symint(512) # _local_scalar_dense_229 + reader.symint(512) # _local_scalar_dense_230 + reader.symint(512) # _local_scalar_dense_231 + reader.symint(512) # _local_scalar_dense_232 + reader.symint(512) # _local_scalar_dense_233 + reader.symint(512) # _local_scalar_dense_234 + reader.symint(512) # _local_scalar_dense_235 + reader.symint(512) # _local_scalar_dense_236 + reader.symint(512) # _local_scalar_dense_237 + reader.symint(512) # _local_scalar_dense_238 + reader.symint(512) # _local_scalar_dense_239 + reader.symint(512) # _local_scalar_dense_240 + reader.symint(512) # _local_scalar_dense_241 + reader.symint(512) # _local_scalar_dense_242 + reader.symint(512) # _local_scalar_dense_243 + reader.symint(512) # _local_scalar_dense_244 + reader.symint(512) # _local_scalar_dense_245 + reader.symint(512) # _local_scalar_dense_246 + reader.symint(512) # _local_scalar_dense_247 + reader.symint(512) # _local_scalar_dense_248 + reader.symint(512) # _local_scalar_dense_249 + reader.symint(512) # _local_scalar_dense_250 + reader.symint(512) # _local_scalar_dense_251 + reader.symint(512) # _local_scalar_dense_252 + reader.symint(512) # _local_scalar_dense_253 + reader.symint(512) # _local_scalar_dense_254 + reader.symint(512) # _local_scalar_dense_255 + reader.symint(512) # _local_scalar_dense_256 + reader.symint(512) # _local_scalar_dense_257 + reader.symint(512) # _local_scalar_dense_258 + reader.symint(512) # _local_scalar_dense_259 + reader.symint(512) # _local_scalar_dense_260 + reader.symint(512) # _local_scalar_dense_261 + reader.symint(512) # _local_scalar_dense_262 + reader.symint(512) # _local_scalar_dense_263 + reader.symint(512) # _local_scalar_dense_264 + reader.symint(512) # _local_scalar_dense_265 + reader.symint(512) # _local_scalar_dense_266 + reader.symint(512) # _local_scalar_dense_267 + reader.symint(512) # _local_scalar_dense_268 + reader.symint(512) # _local_scalar_dense_269 + reader.symint(512) # _local_scalar_dense_270 + reader.symint(512) # _local_scalar_dense_271 + reader.symint(512) # _local_scalar_dense_272 + reader.symint(512) # _local_scalar_dense_273 + reader.symint(512) # _local_scalar_dense_274 + reader.symint(512) # _local_scalar_dense_275 + reader.symint(512) # _local_scalar_dense_276 + reader.symint(512) # _local_scalar_dense_277 + reader.symint(512) # _local_scalar_dense_278 + reader.symint(512) # _local_scalar_dense_279 + reader.symint(512) # _local_scalar_dense_280 + reader.symint(512) # _local_scalar_dense_281 + reader.symint(512) # _local_scalar_dense_282 + reader.symint(512) # _local_scalar_dense_283 + reader.symint(512) # _local_scalar_dense_284 + reader.symint(512) # _local_scalar_dense_285 + reader.symint(512) # _local_scalar_dense_286 + reader.symint(512) # _local_scalar_dense_287 + reader.symint(512) # _local_scalar_dense_288 + reader.symint(512) # _local_scalar_dense_289 + reader.symint(512) # _local_scalar_dense_290 + reader.symint(512) # _local_scalar_dense_291 + reader.symint(512) # _local_scalar_dense_292 + reader.symint(512) # _local_scalar_dense_293 + reader.symint(512) # _local_scalar_dense_294 + reader.symint(512) # _local_scalar_dense_295 + reader.symint(512) # _local_scalar_dense_296 + reader.symint(512) # _local_scalar_dense_297 + reader.symint(512) # _local_scalar_dense_298 + reader.symint(512) # _local_scalar_dense_299 + reader.symint(512) # _local_scalar_dense_300 + reader.symint(512) # _local_scalar_dense_301 + reader.symint(512) # _local_scalar_dense_302 + reader.symint(512) # _local_scalar_dense_303 + reader.symint(512) # _local_scalar_dense_304 + reader.symint(512) # _local_scalar_dense_305 + reader.symint(512) # _local_scalar_dense_306 + reader.symint(512) # _local_scalar_dense_307 + reader.symint(512) # _local_scalar_dense_308 + reader.symint(512) # _local_scalar_dense_309 + reader.symint(512) # _local_scalar_dense_310 + reader.symint(512) # _local_scalar_dense_311 + reader.symint(512) # _local_scalar_dense_312 + reader.symint(512) # _local_scalar_dense_313 + reader.symint(512) # _local_scalar_dense_314 + reader.symint(512) # _local_scalar_dense_315 + reader.symint(512) # _local_scalar_dense_316 + reader.symint(512) # _local_scalar_dense_317 + reader.symint(512) # _local_scalar_dense_318 + reader.symint(512) # _local_scalar_dense_319 + reader.symint(512) # _local_scalar_dense_320 + reader.symint(512) # _local_scalar_dense_321 + reader.symint(512) # _local_scalar_dense_322 + reader.symint(512) # _local_scalar_dense_323 + reader.symint(512) # _local_scalar_dense_324 + reader.symint(512) # _local_scalar_dense_325 + reader.symint(512) # _local_scalar_dense_326 + reader.symint(512) # _local_scalar_dense_327 + reader.symint(512) # _local_scalar_dense_328 + reader.symint(512) # _local_scalar_dense_329 + reader.symint(512) # _local_scalar_dense_330 + reader.symint(512) # _local_scalar_dense_331 + reader.symint(512) # _local_scalar_dense_332 + reader.symint(512) # _local_scalar_dense_333 + reader.symint(512) # _local_scalar_dense_334 + reader.symint(512) # _local_scalar_dense_335 + reader.symint(512) # _local_scalar_dense_336 + reader.symint(512) # _local_scalar_dense_337 + reader.symint(512) # _local_scalar_dense_338 + reader.symint(512) # _local_scalar_dense_339 + reader.symint(512) # _local_scalar_dense_340 + reader.symint(512) # _local_scalar_dense_341 + reader.symint(512) # _local_scalar_dense_342 + reader.symint(512) # _local_scalar_dense_343 + reader.symint(512) # _local_scalar_dense_344 + reader.symint(512) # _local_scalar_dense_345 + reader.symint(512) # _local_scalar_dense_346 + reader.symint(512) # _local_scalar_dense_347 + reader.symint(512) # _local_scalar_dense_348 + reader.symint(512) # _local_scalar_dense_349 + reader.symint(512) # _local_scalar_dense_350 + reader.symint(512) # _local_scalar_dense_351 + reader.symint(512) # _local_scalar_dense_352 + reader.symint(512) # _local_scalar_dense_353 + reader.symint(512) # _local_scalar_dense_354 + reader.symint(512) # _local_scalar_dense_355 + reader.symint(512) # _local_scalar_dense_356 + reader.symint(512) # _local_scalar_dense_357 + reader.symint(512) # _local_scalar_dense_358 + reader.symint(512) # _local_scalar_dense_359 + reader.symint(512) # _local_scalar_dense_360 + reader.symint(512) # _local_scalar_dense_361 + reader.symint(512) # _local_scalar_dense_362 + reader.symint(512) # _local_scalar_dense_363 + reader.symint(512) # _local_scalar_dense_364 + reader.symint(512) # _local_scalar_dense_365 + reader.symint(512) # _local_scalar_dense_366 + reader.symint(512) # _local_scalar_dense_367 + reader.symint(512) # _local_scalar_dense_368 + reader.symint(512) # _local_scalar_dense_369 + reader.symint(512) # _local_scalar_dense_370 + reader.symint(512) # _local_scalar_dense_371 + reader.symint(512) # _local_scalar_dense_372 + reader.symint(512) # _local_scalar_dense_373 + reader.symint(512) # _local_scalar_dense_374 + reader.symint(512) # _local_scalar_dense_375 + reader.symint(512) # _local_scalar_dense_376 + reader.symint(512) # _local_scalar_dense_377 + reader.symint(512) # _local_scalar_dense_378 + reader.symint(512) # _local_scalar_dense_379 + reader.symint(512) # _local_scalar_dense_380 + reader.symint(512) # _local_scalar_dense_381 + reader.symint(512) # _local_scalar_dense_382 + reader.symint(512) # _local_scalar_dense_383 + reader.symint(512) # _local_scalar_dense_384 + reader.symint(512) # _local_scalar_dense_385 + reader.symint(512) # _local_scalar_dense_386 + reader.symint(512) # _local_scalar_dense_387 + reader.symint(512) # _local_scalar_dense_388 + reader.symint(512) # _local_scalar_dense_389 + reader.symint(512) # _local_scalar_dense_390 + reader.symint(512) # _local_scalar_dense_391 + reader.symint(512) # _local_scalar_dense_392 + reader.symint(512) # _local_scalar_dense_393 + reader.symint(512) # _local_scalar_dense_394 + reader.symint(512) # _local_scalar_dense_395 + reader.symint(512) # _local_scalar_dense_396 + reader.symint(512) # _local_scalar_dense_397 + reader.symint(512) # _local_scalar_dense_398 + reader.symint(512) # _local_scalar_dense_399 + reader.symint(512) # _local_scalar_dense_400 + reader.symint(512) # _local_scalar_dense_401 + reader.symint(512) # _local_scalar_dense_402 + reader.symint(512) # _local_scalar_dense_403 + reader.symint(512) # _local_scalar_dense_404 + reader.symint(512) # _local_scalar_dense_405 + reader.symint(512) # _local_scalar_dense_406 + reader.symint(512) # _local_scalar_dense_407 + reader.symint(512) # _local_scalar_dense_408 + reader.symint(512) # _local_scalar_dense_409 + reader.symint(512) # _local_scalar_dense_410 + reader.symint(512) # _local_scalar_dense_411 + reader.symint(512) # _local_scalar_dense_412 + reader.symint(512) # _local_scalar_dense_413 + reader.symint(512) # _local_scalar_dense_414 + reader.symint(512) # _local_scalar_dense_415 + reader.symint(512) # sym_size_int_1 + reader.symint(512) # sym_size_int_5 + reader.symint(512) # sym_size_int_9 + reader.symint(512) # sym_size_int_13 + reader.symint(512) # sym_size_int_17 + reader.symint(512) # sym_size_int_21 + reader.symint(512) # sym_size_int_25 + reader.symint(512) # sym_size_int_29 + reader.symint(512) # sym_size_int_33 + reader.symint(512) # sym_size_int_37 + reader.symint(512) # sym_size_int_41 + reader.symint(512) # sym_size_int_45 + reader.symint(512) # sym_size_int_49 + reader.symint(512) # sym_size_int_53 + reader.symint(512) # sym_size_int_57 + reader.symint(512) # sym_size_int_61 + reader.symint(512) # sym_size_int_65 + reader.symint(512) # sym_size_int_69 + reader.symint(512) # sym_size_int_73 + reader.symint(512) # sym_size_int_77 + reader.symint(512) # sym_size_int_81 + reader.symint(512) # sym_size_int_85 + reader.symint(512) # sym_size_int_89 + reader.symint(512) # sym_size_int_93 + reader.symint(512) # sym_size_int_97 + reader.symint(512) # sym_size_int_101 + reader.symint(512) # add_1781 + reader.symint(512) # add_1796 + reader.symint(512) # add_1811 + reader.symint(512) # add_1826 + reader.symint(512) # add_1841 + reader.symint(512) # add_1856 + reader.symint(512) # add_1871 + reader.symint(512) # add_1886 + reader.symint(512) # add_1901 + reader.symint(512) # add_1916 + reader.symint(512) # add_1931 + reader.symint(512) # add_1946 + reader.symint(512) # add_1961 + reader.symint(512) # add_1976 + reader.symint(512) # add_1991 + reader.symint(512) # add_2006 + reader.symint(512) # add_2021 + reader.symint(512) # add_2036 + reader.symint(512) # add_2051 + reader.symint(512) # add_2066 + reader.symint(512) # add_2081 + reader.symint(512) # add_2096 + reader.symint(512) # add_2111 + reader.symint(512) # add_2126 + reader.symint(512) # add_2141 + reader.symint(512) # add_2156 + buf0 = reader.storage(None, 13107200, device=device(type='cuda', index=0)) + reader.tensor(buf0, (1600, 2048), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 65536, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 4096), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (4096, 32), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf3, (32,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf4, (48, 2048), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf5, (9, 2048), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf6, (8,), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf7, (64, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf8, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_9 + buf9 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf9, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_10 + buf10 = reader.storage(None, 32768, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf10, (2, 4096), dtype=torch.int32, is_leaf=True) # primals_11 + buf11 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf11, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_12 + buf12 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf12, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_13 + buf13 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf13, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_14 + buf14 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf14, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_15 + buf15 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf15, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_16 + buf16 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf16, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_17 + buf17 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf17, (32, 2048), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf18, (32,), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 1400832, device=device(type='cuda', index=0)) + reader.tensor(buf19, (171, 2048), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 1400832, device=device(type='cuda', index=0)) + reader.tensor(buf20, (171, 2048), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 1400832, device=device(type='cuda', index=0)) + reader.tensor(buf21, (32, 10944), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf22, (32,), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf23, (48, 2048), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf24, (9, 2048), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf25, (8,), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf26, (64, 512), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf27, (32, 2048), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf28, (32,), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf29, (1, 2048), is_leaf=True) # primals_31 + buf30 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf30, (1, 1408, 2048), is_leaf=True) # primals_33 + buf31 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf31, (1, 2048, 1408), is_leaf=True) # primals_34 + buf32 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf32, (1, 1408, 2048), is_leaf=True) # primals_35 + buf33 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf33, (44, 2048), is_leaf=True) # primals_36 + buf34 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf34, (44, 2048), is_leaf=True) # primals_37 + buf35 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf35, (32, 2816), is_leaf=True) # primals_38 + buf36 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf36, (32,), is_leaf=True) # primals_39 + buf37 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf37, (48, 2048), is_leaf=True) # primals_40 + buf38 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf38, (9, 2048), is_leaf=True) # primals_41 + buf39 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf39, (8,), is_leaf=True) # primals_42 + buf40 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf40, (64, 512), is_leaf=True) # primals_43 + buf41 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf41, (32, 2048), is_leaf=True) # primals_44 + buf42 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf42, (32,), is_leaf=True) # primals_45 + buf43 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf43, (1, 2048), is_leaf=True) # primals_47 + buf44 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf44, (1, 1408, 2048), is_leaf=True) # primals_49 + buf45 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf45, (1, 2048, 1408), is_leaf=True) # primals_50 + buf46 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf46, (1, 1408, 2048), is_leaf=True) # primals_51 + buf47 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf47, (44, 2048), is_leaf=True) # primals_52 + buf48 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf48, (44, 2048), is_leaf=True) # primals_53 + buf49 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf49, (32, 2816), is_leaf=True) # primals_54 + buf50 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf50, (32,), is_leaf=True) # primals_55 + buf51 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf51, (48, 2048), is_leaf=True) # primals_56 + buf52 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf52, (9, 2048), is_leaf=True) # primals_57 + buf53 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf53, (8,), is_leaf=True) # primals_58 + buf54 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf54, (64, 512), is_leaf=True) # primals_59 + buf55 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf55, (32, 2048), is_leaf=True) # primals_60 + buf56 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf56, (32,), is_leaf=True) # primals_61 + buf57 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf57, (1, 2048), is_leaf=True) # primals_63 + buf58 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf58, (1, 1408, 2048), is_leaf=True) # primals_65 + buf59 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf59, (1, 2048, 1408), is_leaf=True) # primals_66 + buf60 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf60, (1, 1408, 2048), is_leaf=True) # primals_67 + buf61 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf61, (44, 2048), is_leaf=True) # primals_68 + buf62 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf62, (44, 2048), is_leaf=True) # primals_69 + buf63 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf63, (32, 2816), is_leaf=True) # primals_70 + buf64 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf64, (32,), is_leaf=True) # primals_71 + buf65 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf65, (48, 2048), is_leaf=True) # primals_72 + buf66 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf66, (9, 2048), is_leaf=True) # primals_73 + buf67 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf67, (8,), is_leaf=True) # primals_74 + buf68 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf68, (64, 512), is_leaf=True) # primals_75 + buf69 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf69, (32, 2048), is_leaf=True) # primals_76 + buf70 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf70, (32,), is_leaf=True) # primals_77 + buf71 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf71, (1, 2048), is_leaf=True) # primals_79 + buf72 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf72, (1, 1408, 2048), is_leaf=True) # primals_81 + buf73 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf73, (1, 2048, 1408), is_leaf=True) # primals_82 + buf74 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf74, (1, 1408, 2048), is_leaf=True) # primals_83 + buf75 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf75, (44, 2048), is_leaf=True) # primals_84 + buf76 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf76, (44, 2048), is_leaf=True) # primals_85 + buf77 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf77, (32, 2816), is_leaf=True) # primals_86 + buf78 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf78, (32,), is_leaf=True) # primals_87 + buf79 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf79, (48, 2048), is_leaf=True) # primals_88 + buf80 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf80, (9, 2048), is_leaf=True) # primals_89 + buf81 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf81, (8,), is_leaf=True) # primals_90 + buf82 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf82, (64, 512), is_leaf=True) # primals_91 + buf83 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf83, (32, 2048), is_leaf=True) # primals_92 + buf84 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf84, (32,), is_leaf=True) # primals_93 + buf85 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf85, (1, 2048), is_leaf=True) # primals_95 + buf86 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf86, (1, 1408, 2048), is_leaf=True) # primals_97 + buf87 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf87, (1, 2048, 1408), is_leaf=True) # primals_98 + buf88 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf88, (1, 1408, 2048), is_leaf=True) # primals_99 + buf89 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf89, (44, 2048), is_leaf=True) # primals_100 + buf90 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf90, (44, 2048), is_leaf=True) # primals_101 + buf91 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf91, (32, 2816), is_leaf=True) # primals_102 + buf92 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf92, (32,), is_leaf=True) # primals_103 + buf93 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf93, (48, 2048), is_leaf=True) # primals_104 + buf94 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf94, (9, 2048), is_leaf=True) # primals_105 + buf95 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf95, (8,), is_leaf=True) # primals_106 + buf96 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf96, (64, 512), is_leaf=True) # primals_107 + buf97 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf97, (32, 2048), is_leaf=True) # primals_108 + buf98 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf98, (32,), is_leaf=True) # primals_109 + buf99 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf99, (1, 2048), is_leaf=True) # primals_111 + buf100 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf100, (1, 1408, 2048), is_leaf=True) # primals_113 + buf101 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf101, (1, 2048, 1408), is_leaf=True) # primals_114 + buf102 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf102, (1, 1408, 2048), is_leaf=True) # primals_115 + buf103 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf103, (44, 2048), is_leaf=True) # primals_116 + buf104 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf104, (44, 2048), is_leaf=True) # primals_117 + buf105 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf105, (32, 2816), is_leaf=True) # primals_118 + buf106 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf106, (32,), is_leaf=True) # primals_119 + buf107 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf107, (48, 2048), is_leaf=True) # primals_120 + buf108 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf108, (9, 2048), is_leaf=True) # primals_121 + buf109 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf109, (8,), is_leaf=True) # primals_122 + buf110 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf110, (64, 512), is_leaf=True) # primals_123 + buf111 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf111, (32, 2048), is_leaf=True) # primals_124 + buf112 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf112, (32,), is_leaf=True) # primals_125 + buf113 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf113, (1, 2048), is_leaf=True) # primals_127 + buf114 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf114, (1, 1408, 2048), is_leaf=True) # primals_129 + buf115 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf115, (1, 2048, 1408), is_leaf=True) # primals_130 + buf116 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf116, (1, 1408, 2048), is_leaf=True) # primals_131 + buf117 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf117, (44, 2048), is_leaf=True) # primals_132 + buf118 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf118, (44, 2048), is_leaf=True) # primals_133 + buf119 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf119, (32, 2816), is_leaf=True) # primals_134 + buf120 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf120, (32,), is_leaf=True) # primals_135 + buf121 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf121, (48, 2048), is_leaf=True) # primals_136 + buf122 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf122, (9, 2048), is_leaf=True) # primals_137 + buf123 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf123, (8,), is_leaf=True) # primals_138 + buf124 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf124, (64, 512), is_leaf=True) # primals_139 + buf125 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf125, (32, 2048), is_leaf=True) # primals_140 + buf126 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf126, (32,), is_leaf=True) # primals_141 + buf127 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf127, (1, 2048), is_leaf=True) # primals_143 + buf128 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf128, (1, 1408, 2048), is_leaf=True) # primals_145 + buf129 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf129, (1, 2048, 1408), is_leaf=True) # primals_146 + buf130 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf130, (1, 1408, 2048), is_leaf=True) # primals_147 + buf131 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf131, (44, 2048), is_leaf=True) # primals_148 + buf132 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf132, (44, 2048), is_leaf=True) # primals_149 + buf133 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf133, (32, 2816), is_leaf=True) # primals_150 + buf134 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf134, (32,), is_leaf=True) # primals_151 + buf135 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf135, (48, 2048), is_leaf=True) # primals_152 + buf136 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf136, (9, 2048), is_leaf=True) # primals_153 + buf137 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf137, (8,), is_leaf=True) # primals_154 + buf138 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf138, (64, 512), is_leaf=True) # primals_155 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (32, 2048), is_leaf=True) # primals_156 + buf140 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf140, (32,), is_leaf=True) # primals_157 + buf141 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf141, (1, 2048), is_leaf=True) # primals_159 + buf142 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf142, (1, 1408, 2048), is_leaf=True) # primals_161 + buf143 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf143, (1, 2048, 1408), is_leaf=True) # primals_162 + buf144 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf144, (1, 1408, 2048), is_leaf=True) # primals_163 + buf145 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf145, (44, 2048), is_leaf=True) # primals_164 + buf146 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf146, (44, 2048), is_leaf=True) # primals_165 + buf147 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf147, (32, 2816), is_leaf=True) # primals_166 + buf148 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf148, (32,), is_leaf=True) # primals_167 + buf149 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf149, (48, 2048), is_leaf=True) # primals_168 + buf150 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf150, (9, 2048), is_leaf=True) # primals_169 + buf151 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf151, (8,), is_leaf=True) # primals_170 + buf152 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf152, (64, 512), is_leaf=True) # primals_171 + buf153 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf153, (32, 2048), is_leaf=True) # primals_172 + buf154 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf154, (32,), is_leaf=True) # primals_173 + buf155 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf155, (1, 2048), is_leaf=True) # primals_175 + buf156 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf156, (1, 1408, 2048), is_leaf=True) # primals_177 + buf157 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf157, (1, 2048, 1408), is_leaf=True) # primals_178 + buf158 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf158, (1, 1408, 2048), is_leaf=True) # primals_179 + buf159 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf159, (44, 2048), is_leaf=True) # primals_180 + buf160 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf160, (44, 2048), is_leaf=True) # primals_181 + buf161 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf161, (32, 2816), is_leaf=True) # primals_182 + buf162 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf162, (32,), is_leaf=True) # primals_183 + buf163 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf163, (48, 2048), is_leaf=True) # primals_184 + buf164 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf164, (9, 2048), is_leaf=True) # primals_185 + buf165 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf165, (8,), is_leaf=True) # primals_186 + buf166 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf166, (64, 512), is_leaf=True) # primals_187 + buf167 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf167, (32, 2048), is_leaf=True) # primals_188 + buf168 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf168, (32,), is_leaf=True) # primals_189 + buf169 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf169, (1, 2048), is_leaf=True) # primals_191 + buf170 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf170, (1, 1408, 2048), is_leaf=True) # primals_193 + buf171 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf171, (1, 2048, 1408), is_leaf=True) # primals_194 + buf172 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf172, (1, 1408, 2048), is_leaf=True) # primals_195 + buf173 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf173, (44, 2048), is_leaf=True) # primals_196 + buf174 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf174, (44, 2048), is_leaf=True) # primals_197 + buf175 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf175, (32, 2816), is_leaf=True) # primals_198 + buf176 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf176, (32,), is_leaf=True) # primals_199 + buf177 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf177, (48, 2048), is_leaf=True) # primals_200 + buf178 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf178, (9, 2048), is_leaf=True) # primals_201 + buf179 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf179, (8,), is_leaf=True) # primals_202 + buf180 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf180, (64, 512), is_leaf=True) # primals_203 + buf181 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf181, (32, 2048), is_leaf=True) # primals_204 + buf182 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf182, (32,), is_leaf=True) # primals_205 + buf183 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf183, (1, 2048), is_leaf=True) # primals_207 + buf184 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf184, (1, 1408, 2048), is_leaf=True) # primals_209 + buf185 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf185, (1, 2048, 1408), is_leaf=True) # primals_210 + buf186 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf186, (1, 1408, 2048), is_leaf=True) # primals_211 + buf187 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf187, (44, 2048), is_leaf=True) # primals_212 + buf188 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf188, (44, 2048), is_leaf=True) # primals_213 + buf189 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf189, (32, 2816), is_leaf=True) # primals_214 + buf190 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf190, (32,), is_leaf=True) # primals_215 + buf191 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf191, (48, 2048), is_leaf=True) # primals_216 + buf192 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf192, (9, 2048), is_leaf=True) # primals_217 + buf193 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf193, (8,), is_leaf=True) # primals_218 + buf194 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf194, (64, 512), is_leaf=True) # primals_219 + buf195 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf195, (32, 2048), is_leaf=True) # primals_220 + buf196 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf196, (32,), is_leaf=True) # primals_221 + buf197 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf197, (1, 2048), is_leaf=True) # primals_223 + buf198 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf198, (1, 1408, 2048), is_leaf=True) # primals_225 + buf199 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf199, (1, 2048, 1408), is_leaf=True) # primals_226 + buf200 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf200, (1, 1408, 2048), is_leaf=True) # primals_227 + buf201 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf201, (44, 2048), is_leaf=True) # primals_228 + buf202 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf202, (44, 2048), is_leaf=True) # primals_229 + buf203 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf203, (32, 2816), is_leaf=True) # primals_230 + buf204 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf204, (32,), is_leaf=True) # primals_231 + buf205 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf205, (48, 2048), is_leaf=True) # primals_232 + buf206 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf206, (9, 2048), is_leaf=True) # primals_233 + buf207 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf207, (8,), is_leaf=True) # primals_234 + buf208 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf208, (64, 512), is_leaf=True) # primals_235 + buf209 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf209, (32, 2048), is_leaf=True) # primals_236 + buf210 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf210, (32,), is_leaf=True) # primals_237 + buf211 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf211, (1, 2048), is_leaf=True) # primals_239 + buf212 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf212, (1, 1408, 2048), is_leaf=True) # primals_241 + buf213 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf213, (1, 2048, 1408), is_leaf=True) # primals_242 + buf214 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf214, (1, 1408, 2048), is_leaf=True) # primals_243 + buf215 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf215, (44, 2048), is_leaf=True) # primals_244 + buf216 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf216, (44, 2048), is_leaf=True) # primals_245 + buf217 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf217, (32, 2816), is_leaf=True) # primals_246 + buf218 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf218, (32,), is_leaf=True) # primals_247 + buf219 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf219, (48, 2048), is_leaf=True) # primals_248 + buf220 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf220, (9, 2048), is_leaf=True) # primals_249 + buf221 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf221, (8,), is_leaf=True) # primals_250 + buf222 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf222, (64, 512), is_leaf=True) # primals_251 + buf223 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf223, (32, 2048), is_leaf=True) # primals_252 + buf224 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf224, (32,), is_leaf=True) # primals_253 + buf225 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf225, (1, 2048), is_leaf=True) # primals_255 + buf226 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf226, (1, 1408, 2048), is_leaf=True) # primals_257 + buf227 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf227, (1, 2048, 1408), is_leaf=True) # primals_258 + buf228 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf228, (1, 1408, 2048), is_leaf=True) # primals_259 + buf229 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf229, (44, 2048), is_leaf=True) # primals_260 + buf230 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf230, (44, 2048), is_leaf=True) # primals_261 + buf231 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf231, (32, 2816), is_leaf=True) # primals_262 + buf232 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf232, (32,), is_leaf=True) # primals_263 + buf233 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf233, (48, 2048), is_leaf=True) # primals_264 + buf234 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf234, (9, 2048), is_leaf=True) # primals_265 + buf235 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf235, (8,), is_leaf=True) # primals_266 + buf236 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf236, (64, 512), is_leaf=True) # primals_267 + buf237 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf237, (32, 2048), is_leaf=True) # primals_268 + buf238 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf238, (32,), is_leaf=True) # primals_269 + buf239 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf239, (1, 2048), is_leaf=True) # primals_271 + buf240 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf240, (1, 1408, 2048), is_leaf=True) # primals_273 + buf241 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf241, (1, 2048, 1408), is_leaf=True) # primals_274 + buf242 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf242, (1, 1408, 2048), is_leaf=True) # primals_275 + buf243 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf243, (44, 2048), is_leaf=True) # primals_276 + buf244 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf244, (44, 2048), is_leaf=True) # primals_277 + buf245 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf245, (32, 2816), is_leaf=True) # primals_278 + buf246 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf246, (32,), is_leaf=True) # primals_279 + buf247 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf247, (48, 2048), is_leaf=True) # primals_280 + buf248 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf248, (9, 2048), is_leaf=True) # primals_281 + buf249 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf249, (8,), is_leaf=True) # primals_282 + buf250 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf250, (64, 512), is_leaf=True) # primals_283 + buf251 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf251, (32, 2048), is_leaf=True) # primals_284 + buf252 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf252, (32,), is_leaf=True) # primals_285 + buf253 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf253, (1, 2048), is_leaf=True) # primals_287 + buf254 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf254, (1, 1408, 2048), is_leaf=True) # primals_289 + buf255 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf255, (1, 2048, 1408), is_leaf=True) # primals_290 + buf256 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf256, (1, 1408, 2048), is_leaf=True) # primals_291 + buf257 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf257, (44, 2048), is_leaf=True) # primals_292 + buf258 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf258, (44, 2048), is_leaf=True) # primals_293 + buf259 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf259, (32, 2816), is_leaf=True) # primals_294 + buf260 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf260, (32,), is_leaf=True) # primals_295 + buf261 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf261, (48, 2048), is_leaf=True) # primals_296 + buf262 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf262, (9, 2048), is_leaf=True) # primals_297 + buf263 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf263, (8,), is_leaf=True) # primals_298 + buf264 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf264, (64, 512), is_leaf=True) # primals_299 + buf265 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf265, (32, 2048), is_leaf=True) # primals_300 + buf266 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf266, (32,), is_leaf=True) # primals_301 + buf267 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf267, (1, 2048), is_leaf=True) # primals_303 + buf268 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf268, (1, 1408, 2048), is_leaf=True) # primals_305 + buf269 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf269, (1, 2048, 1408), is_leaf=True) # primals_306 + buf270 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf270, (1, 1408, 2048), is_leaf=True) # primals_307 + buf271 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf271, (44, 2048), is_leaf=True) # primals_308 + buf272 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf272, (44, 2048), is_leaf=True) # primals_309 + buf273 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf273, (32, 2816), is_leaf=True) # primals_310 + buf274 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf274, (32,), is_leaf=True) # primals_311 + buf275 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf275, (48, 2048), is_leaf=True) # primals_312 + buf276 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf276, (9, 2048), is_leaf=True) # primals_313 + buf277 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf277, (8,), is_leaf=True) # primals_314 + buf278 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf278, (64, 512), is_leaf=True) # primals_315 + buf279 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf279, (32, 2048), is_leaf=True) # primals_316 + buf280 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf280, (32,), is_leaf=True) # primals_317 + buf281 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf281, (1, 2048), is_leaf=True) # primals_319 + buf282 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf282, (1, 1408, 2048), is_leaf=True) # primals_321 + buf283 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf283, (1, 2048, 1408), is_leaf=True) # primals_322 + buf284 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf284, (1, 1408, 2048), is_leaf=True) # primals_323 + buf285 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf285, (44, 2048), is_leaf=True) # primals_324 + buf286 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf286, (44, 2048), is_leaf=True) # primals_325 + buf287 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf287, (32, 2816), is_leaf=True) # primals_326 + buf288 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf288, (32,), is_leaf=True) # primals_327 + buf289 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf289, (48, 2048), is_leaf=True) # primals_328 + buf290 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf290, (9, 2048), is_leaf=True) # primals_329 + buf291 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf291, (8,), is_leaf=True) # primals_330 + buf292 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf292, (64, 512), is_leaf=True) # primals_331 + buf293 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf293, (32, 2048), is_leaf=True) # primals_332 + buf294 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf294, (32,), is_leaf=True) # primals_333 + buf295 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf295, (1, 2048), is_leaf=True) # primals_335 + buf296 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf296, (1, 1408, 2048), is_leaf=True) # primals_337 + buf297 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf297, (1, 2048, 1408), is_leaf=True) # primals_338 + buf298 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf298, (1, 1408, 2048), is_leaf=True) # primals_339 + buf299 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf299, (44, 2048), is_leaf=True) # primals_340 + buf300 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf300, (44, 2048), is_leaf=True) # primals_341 + buf301 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf301, (32, 2816), is_leaf=True) # primals_342 + buf302 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf302, (32,), is_leaf=True) # primals_343 + buf303 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf303, (48, 2048), is_leaf=True) # primals_344 + buf304 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf304, (9, 2048), is_leaf=True) # primals_345 + buf305 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf305, (8,), is_leaf=True) # primals_346 + buf306 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf306, (64, 512), is_leaf=True) # primals_347 + buf307 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf307, (32, 2048), is_leaf=True) # primals_348 + buf308 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf308, (32,), is_leaf=True) # primals_349 + buf309 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf309, (1, 2048), is_leaf=True) # primals_351 + buf310 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf310, (1, 1408, 2048), is_leaf=True) # primals_353 + buf311 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf311, (1, 2048, 1408), is_leaf=True) # primals_354 + buf312 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf312, (1, 1408, 2048), is_leaf=True) # primals_355 + buf313 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf313, (44, 2048), is_leaf=True) # primals_356 + buf314 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf314, (44, 2048), is_leaf=True) # primals_357 + buf315 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf315, (32, 2816), is_leaf=True) # primals_358 + buf316 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf316, (32,), is_leaf=True) # primals_359 + buf317 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf317, (48, 2048), is_leaf=True) # primals_360 + buf318 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf318, (9, 2048), is_leaf=True) # primals_361 + buf319 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf319, (8,), is_leaf=True) # primals_362 + buf320 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf320, (64, 512), is_leaf=True) # primals_363 + buf321 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf321, (32, 2048), is_leaf=True) # primals_364 + buf322 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf322, (32,), is_leaf=True) # primals_365 + buf323 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf323, (1, 2048), is_leaf=True) # primals_367 + buf324 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf324, (1, 1408, 2048), is_leaf=True) # primals_369 + buf325 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf325, (1, 2048, 1408), is_leaf=True) # primals_370 + buf326 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf326, (1, 1408, 2048), is_leaf=True) # primals_371 + buf327 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf327, (44, 2048), is_leaf=True) # primals_372 + buf328 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf328, (44, 2048), is_leaf=True) # primals_373 + buf329 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf329, (32, 2816), is_leaf=True) # primals_374 + buf330 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf330, (32,), is_leaf=True) # primals_375 + buf331 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf331, (48, 2048), is_leaf=True) # primals_376 + buf332 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf332, (9, 2048), is_leaf=True) # primals_377 + buf333 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf333, (8,), is_leaf=True) # primals_378 + buf334 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf334, (64, 512), is_leaf=True) # primals_379 + buf335 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf335, (32, 2048), is_leaf=True) # primals_380 + buf336 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf336, (32,), is_leaf=True) # primals_381 + buf337 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf337, (1, 2048), is_leaf=True) # primals_383 + buf338 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf338, (1, 1408, 2048), is_leaf=True) # primals_385 + buf339 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf339, (1, 2048, 1408), is_leaf=True) # primals_386 + buf340 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf340, (1, 1408, 2048), is_leaf=True) # primals_387 + buf341 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf341, (44, 2048), is_leaf=True) # primals_388 + buf342 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf342, (44, 2048), is_leaf=True) # primals_389 + buf343 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf343, (32, 2816), is_leaf=True) # primals_390 + buf344 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf344, (32,), is_leaf=True) # primals_391 + buf345 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf345, (48, 2048), is_leaf=True) # primals_392 + buf346 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf346, (9, 2048), is_leaf=True) # primals_393 + buf347 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf347, (8,), is_leaf=True) # primals_394 + buf348 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf348, (64, 512), is_leaf=True) # primals_395 + buf349 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf349, (32, 2048), is_leaf=True) # primals_396 + buf350 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf350, (32,), is_leaf=True) # primals_397 + buf351 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf351, (1, 2048), is_leaf=True) # primals_399 + buf352 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf352, (1, 1408, 2048), is_leaf=True) # primals_401 + buf353 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf353, (1, 2048, 1408), is_leaf=True) # primals_402 + buf354 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf354, (1, 1408, 2048), is_leaf=True) # primals_403 + buf355 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf355, (44, 2048), is_leaf=True) # primals_404 + buf356 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf356, (44, 2048), is_leaf=True) # primals_405 + buf357 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf357, (32, 2816), is_leaf=True) # primals_406 + buf358 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf358, (32,), is_leaf=True) # primals_407 + buf359 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf359, (48, 2048), is_leaf=True) # primals_408 + buf360 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf360, (9, 2048), is_leaf=True) # primals_409 + buf361 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf361, (8,), is_leaf=True) # primals_410 + buf362 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf362, (64, 512), is_leaf=True) # primals_411 + buf363 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf363, (32, 2048), is_leaf=True) # primals_412 + buf364 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf364, (32,), is_leaf=True) # primals_413 + buf365 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf365, (1, 2048), is_leaf=True) # primals_415 + buf366 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf366, (1, 1408, 2048), is_leaf=True) # primals_417 + buf367 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf367, (1, 2048, 1408), is_leaf=True) # primals_418 + buf368 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf368, (1, 1408, 2048), is_leaf=True) # primals_419 + buf369 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf369, (44, 2048), is_leaf=True) # primals_420 + buf370 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf370, (44, 2048), is_leaf=True) # primals_421 + buf371 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf371, (32, 2816), is_leaf=True) # primals_422 + buf372 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf372, (32,), is_leaf=True) # primals_423 + buf373 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf373, (48, 2048), is_leaf=True) # primals_424 + buf374 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf374, (9, 2048), is_leaf=True) # primals_425 + buf375 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf375, (8,), is_leaf=True) # primals_426 + buf376 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf376, (64, 512), is_leaf=True) # primals_427 + buf377 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf377, (32, 2048), is_leaf=True) # primals_428 + buf378 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf378, (32,), is_leaf=True) # primals_429 + buf379 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf379, (1, 2048), is_leaf=True) # primals_431 + buf380 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf380, (1, 1408, 2048), is_leaf=True) # primals_433 + buf381 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf381, (1, 2048, 1408), is_leaf=True) # primals_434 + buf382 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf382, (1, 1408, 2048), is_leaf=True) # primals_435 + buf383 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf383, (44, 2048), is_leaf=True) # primals_436 + buf384 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf384, (44, 2048), is_leaf=True) # primals_437 + buf385 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf385, (32, 2816), is_leaf=True) # primals_438 + buf386 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf386, (32,), is_leaf=True) # primals_439 + buf387 = reader.storage(None, 13107200, device=device(type='cuda', index=0)) + reader.tensor(buf387, (1600, 2048), is_leaf=True) # primals_440 + buf388 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf388, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # embedding + buf389 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf389, (2, 4096, 1), is_leaf=True) # rsqrt + buf390 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf390, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_3 + buf391 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf391, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_2 + buf392 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf392, (2, 4096, 1), is_leaf=True) # rsqrt_1 + buf393 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf393, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_17 + buf394 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf394, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_3 + buf395 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf395, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_4 + buf396 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf396, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_5 + buf397 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf397, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_6 + buf398 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf398, (2, 16, 4096), is_leaf=True) # getitem_7 + buf399 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf399, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # mm_3 + buf400 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf400, (2, 4096, 1), is_leaf=True) # rsqrt_2 + buf401 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf401, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_26 + buf402 = reader.storage(None, 179306496, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf402, (8192, 10944), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf403 = reader.storage(None, 179306496, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf403, (8192, 10944), dtype=torch.bfloat16, is_leaf=True) # mm_5 + buf404 = reader.storage(None, 179306496, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf404, (8192, 10944), dtype=torch.bfloat16, is_leaf=True) # view_32 + buf405 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf405, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_5 + buf406 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf406, (2, 4096, 1), is_leaf=True) # rsqrt_3 + buf407 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf407, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_36 + buf408 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf408, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_11 + buf409 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf409, (2, 4096, 1), is_leaf=True) # rsqrt_4 + buf410 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf410, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_50 + buf411 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf411, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_14 + buf412 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf412, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_15 + buf413 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf413, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_16 + buf414 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf414, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_15 + buf415 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf415, (2, 16, 4096), is_leaf=True) # getitem_16 + buf416 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf416, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_8 + buf417 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf417, (2, 4096, 1), is_leaf=True) # rsqrt_5 + buf418 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf418, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_58 + buf419 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf419, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf420 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf420, (8192, 1), is_leaf=True) # amax + buf421 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf421, (8192, 1), is_leaf=True) # sum_1 + buf422 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf422, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_19 + buf423 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf423, (49152,), dtype=torch.int64, is_leaf=True) # getitem_21 + buf424 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf424, (49152,), dtype=torch.int64, is_leaf=True) # div_2 + buf425 = reader.storage(None, 32*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf425, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_22 + buf426 = reader.storage(None, 32768*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf426, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_1 + buf427 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf427, (8,), dtype=torch.int32, is_leaf=True) # cumsum_2 + buf428 = reader.storage(None, 22528*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf428, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm + buf429 = reader.storage(None, 22528*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf429, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_1 + buf430 = reader.storage(None, 22528*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf430, (8*(((u10 + u11 + u12 + u13 + u14 + u15 + u8 + u9 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_35 + buf431 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf431, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_12 + buf432 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf432, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_13 + buf433 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf433, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_55 + buf434 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf434, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_73 + buf435 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf435, (2, 4096, 1), is_leaf=True) # rsqrt_6 + buf436 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf436, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_103 + buf437 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf437, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_25 + buf438 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf438, (2, 4096, 1), is_leaf=True) # rsqrt_7 + buf439 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf439, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_117 + buf440 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf440, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_29 + buf441 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf441, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_30 + buf442 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf442, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_31 + buf443 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf443, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_29 + buf444 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf444, (2, 16, 4096), is_leaf=True) # getitem_30 + buf445 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf445, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_76 + buf446 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf446, (2, 4096, 1), is_leaf=True) # rsqrt_8 + buf447 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf447, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_125 + buf448 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf448, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_19 + buf449 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf449, (8192, 1), is_leaf=True) # amax_1 + buf450 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf450, (8192, 1), is_leaf=True) # sum_5 + buf451 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf451, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_33 + buf452 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf452, (49152,), dtype=torch.int64, is_leaf=True) # getitem_35 + buf453 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf453, (49152,), dtype=torch.int64, is_leaf=True) # div_7 + buf454 = reader.storage(None, 32*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf454, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_36 + buf455 = reader.storage(None, 32768*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf455, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_3 + buf456 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf456, (8,), dtype=torch.int32, is_leaf=True) # cumsum_5 + buf457 = reader.storage(None, 22528*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf457, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_3 + buf458 = reader.storage(None, 22528*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf458, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_4 + buf459 = reader.storage(None, 22528*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf459, (8*(((u24 + u25 + u26 + u27 + u28 + u29 + u30 + u31 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_84 + buf460 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf460, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_20 + buf461 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf461, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf462 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf462, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_104 + buf463 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf463, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_141 + buf464 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf464, (2, 4096, 1), is_leaf=True) # rsqrt_9 + buf465 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf465, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_170 + buf466 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf466, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_39 + buf467 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf467, (2, 4096, 1), is_leaf=True) # rsqrt_10 + buf468 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf468, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_184 + buf469 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf469, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_44 + buf470 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf470, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_45 + buf471 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf471, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_46 + buf472 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf472, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_43 + buf473 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf473, (2, 16, 4096), is_leaf=True) # getitem_44 + buf474 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf474, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_144 + buf475 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf475, (2, 4096, 1), is_leaf=True) # rsqrt_11 + buf476 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf476, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_192 + buf477 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf477, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_27 + buf478 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf478, (8192, 1), is_leaf=True) # amax_2 + buf479 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf479, (8192, 1), is_leaf=True) # sum_9 + buf480 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf480, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_47 + buf481 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf481, (49152,), dtype=torch.int64, is_leaf=True) # getitem_49 + buf482 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf482, (49152,), dtype=torch.int64, is_leaf=True) # div_12 + buf483 = reader.storage(None, 32*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf483, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_50 + buf484 = reader.storage(None, 32768*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf484, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_5 + buf485 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf485, (8,), dtype=torch.int32, is_leaf=True) # cumsum_8 + buf486 = reader.storage(None, 22528*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf486, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_6 + buf487 = reader.storage(None, 22528*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf487, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_7 + buf488 = reader.storage(None, 22528*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf488, (8*(((u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_133 + buf489 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf489, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf490 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf490, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_29 + buf491 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf491, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_153 + buf492 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf492, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_209 + buf493 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf493, (2, 4096, 1), is_leaf=True) # rsqrt_12 + buf494 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf494, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_237 + buf495 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf495, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_53 + buf496 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf496, (2, 4096, 1), is_leaf=True) # rsqrt_13 + buf497 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf497, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_251 + buf498 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf498, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_59 + buf499 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf499, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_60 + buf500 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf500, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_61 + buf501 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf501, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_57 + buf502 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf502, (2, 16, 4096), is_leaf=True) # getitem_58 + buf503 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf503, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_212 + buf504 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf504, (2, 4096, 1), is_leaf=True) # rsqrt_14 + buf505 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf505, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_259 + buf506 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf506, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf507 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf507, (8192, 1), is_leaf=True) # amax_3 + buf508 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf508, (8192, 1), is_leaf=True) # sum_13 + buf509 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf509, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_61 + buf510 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf510, (49152,), dtype=torch.int64, is_leaf=True) # getitem_63 + buf511 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf511, (49152,), dtype=torch.int64, is_leaf=True) # div_17 + buf512 = reader.storage(None, 32*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf512, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_64 + buf513 = reader.storage(None, 32768*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf513, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_7 + buf514 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf514, (8,), dtype=torch.int32, is_leaf=True) # cumsum_11 + buf515 = reader.storage(None, 22528*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf515, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_9 + buf516 = reader.storage(None, 22528*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf516, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_10 + buf517 = reader.storage(None, 22528*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf517, (8*(((u56 + u57 + u58 + u59 + u60 + u61 + u62 + u63 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_182 + buf518 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf518, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_36 + buf519 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf519, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf520 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf520, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_202 + buf521 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf521, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_277 + buf522 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf522, (2, 4096, 1), is_leaf=True) # rsqrt_15 + buf523 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf523, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_304 + buf524 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf524, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_67 + buf525 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf525, (2, 4096, 1), is_leaf=True) # rsqrt_16 + buf526 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf526, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_318 + buf527 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf527, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_74 + buf528 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf528, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_75 + buf529 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf529, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_76 + buf530 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf530, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_71 + buf531 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf531, (2, 16, 4096), is_leaf=True) # getitem_72 + buf532 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf532, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_280 + buf533 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf533, (2, 4096, 1), is_leaf=True) # rsqrt_17 + buf534 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf534, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_326 + buf535 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf535, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_43 + buf536 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf536, (8192, 1), is_leaf=True) # amax_4 + buf537 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf537, (8192, 1), is_leaf=True) # sum_17 + buf538 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf538, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_75 + buf539 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf539, (49152,), dtype=torch.int64, is_leaf=True) # getitem_77 + buf540 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf540, (49152,), dtype=torch.int64, is_leaf=True) # div_22 + buf541 = reader.storage(None, 32*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf541, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_78 + buf542 = reader.storage(None, 32768*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf542, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_9 + buf543 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf543, (8,), dtype=torch.int32, is_leaf=True) # cumsum_14 + buf544 = reader.storage(None, 22528*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf544, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_12 + buf545 = reader.storage(None, 22528*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf545, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_13 + buf546 = reader.storage(None, 22528*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf546, (8*(((u72 + u73 + u74 + u75 + u76 + u77 + u78 + u79 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_231 + buf547 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf547, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf548 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf548, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_45 + buf549 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf549, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_251 + buf550 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf550, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_345 + buf551 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf551, (2, 4096, 1), is_leaf=True) # rsqrt_18 + buf552 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf552, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_371 + buf553 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf553, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_81 + buf554 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf554, (2, 4096, 1), is_leaf=True) # rsqrt_19 + buf555 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf555, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_385 + buf556 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf556, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_89 + buf557 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf557, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_90 + buf558 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf558, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_91 + buf559 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf559, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_85 + buf560 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf560, (2, 16, 4096), is_leaf=True) # getitem_86 + buf561 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf561, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_348 + buf562 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf562, (2, 4096, 1), is_leaf=True) # rsqrt_20 + buf563 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf563, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_393 + buf564 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf564, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf565 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf565, (8192, 1), is_leaf=True) # amax_5 + buf566 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf566, (8192, 1), is_leaf=True) # sum_21 + buf567 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf567, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_89 + buf568 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf568, (49152,), dtype=torch.int64, is_leaf=True) # getitem_91 + buf569 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf569, (49152,), dtype=torch.int64, is_leaf=True) # div_27 + buf570 = reader.storage(None, 32*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf570, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_92 + buf571 = reader.storage(None, 32768*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf571, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_11 + buf572 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf572, (8,), dtype=torch.int32, is_leaf=True) # cumsum_17 + buf573 = reader.storage(None, 22528*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf573, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_15 + buf574 = reader.storage(None, 22528*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf574, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_16 + buf575 = reader.storage(None, 22528*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf575, (8*(((u88 + u89 + u90 + u91 + u92 + u93 + u94 + u95 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_280 + buf576 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf576, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_52 + buf577 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf577, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf578 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf578, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_300 + buf579 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf579, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_413 + buf580 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf580, (2, 4096, 1), is_leaf=True) # rsqrt_21 + buf581 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf581, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_438 + buf582 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf582, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_95 + buf583 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf583, (2, 4096, 1), is_leaf=True) # rsqrt_22 + buf584 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf584, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_452 + buf585 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf585, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_104 + buf586 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf586, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_105 + buf587 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf587, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_106 + buf588 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf588, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_99 + buf589 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf589, (2, 16, 4096), is_leaf=True) # getitem_100 + buf590 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf590, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_416 + buf591 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf591, (2, 4096, 1), is_leaf=True) # rsqrt_23 + buf592 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf592, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_460 + buf593 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf593, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_59 + buf594 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf594, (8192, 1), is_leaf=True) # amax_6 + buf595 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf595, (8192, 1), is_leaf=True) # sum_25 + buf596 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf596, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_103 + buf597 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf597, (49152,), dtype=torch.int64, is_leaf=True) # getitem_105 + buf598 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf598, (49152,), dtype=torch.int64, is_leaf=True) # div_32 + buf599 = reader.storage(None, 32*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf599, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_106 + buf600 = reader.storage(None, 32768*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf600, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_13 + buf601 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf601, (8,), dtype=torch.int32, is_leaf=True) # cumsum_20 + buf602 = reader.storage(None, 22528*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf602, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_18 + buf603 = reader.storage(None, 22528*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf603, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_19 + buf604 = reader.storage(None, 22528*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf604, (8*(((u104 + u105 + u106 + u107 + u108 + u109 + u110 + u111 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_329 + buf605 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf605, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf606 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf606, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_61 + buf607 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf607, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_349 + buf608 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf608, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_481 + buf609 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf609, (2, 4096, 1), is_leaf=True) # rsqrt_24 + buf610 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf610, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_505 + buf611 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf611, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_109 + buf612 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf612, (2, 4096, 1), is_leaf=True) # rsqrt_25 + buf613 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf613, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_519 + buf614 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf614, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_119 + buf615 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf615, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_120 + buf616 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf616, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_121 + buf617 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf617, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_113 + buf618 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf618, (2, 16, 4096), is_leaf=True) # getitem_114 + buf619 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf619, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_484 + buf620 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf620, (2, 4096, 1), is_leaf=True) # rsqrt_26 + buf621 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf621, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_527 + buf622 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf622, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf623 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf623, (8192, 1), is_leaf=True) # amax_7 + buf624 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf624, (8192, 1), is_leaf=True) # sum_29 + buf625 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf625, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_117 + buf626 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf626, (49152,), dtype=torch.int64, is_leaf=True) # getitem_119 + buf627 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf627, (49152,), dtype=torch.int64, is_leaf=True) # div_37 + buf628 = reader.storage(None, 32*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf628, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_120 + buf629 = reader.storage(None, 32768*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf629, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_15 + buf630 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf630, (8,), dtype=torch.int32, is_leaf=True) # cumsum_23 + buf631 = reader.storage(None, 22528*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf631, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_21 + buf632 = reader.storage(None, 22528*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf632, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_22 + buf633 = reader.storage(None, 22528*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf633, (8*(((u120 + u121 + u122 + u123 + u124 + u125 + u126 + u127 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_378 + buf634 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf634, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_68 + buf635 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf635, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_69 + buf636 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf636, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_398 + buf637 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf637, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_549 + buf638 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf638, (2, 4096, 1), is_leaf=True) # rsqrt_27 + buf639 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf639, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_572 + buf640 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf640, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_123 + buf641 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf641, (2, 4096, 1), is_leaf=True) # rsqrt_28 + buf642 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf642, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_586 + buf643 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf643, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_134 + buf644 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf644, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_135 + buf645 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf645, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_136 + buf646 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf646, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_127 + buf647 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf647, (2, 16, 4096), is_leaf=True) # getitem_128 + buf648 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf648, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_552 + buf649 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf649, (2, 4096, 1), is_leaf=True) # rsqrt_29 + buf650 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf650, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_594 + buf651 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf651, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_75 + buf652 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf652, (8192, 1), is_leaf=True) # amax_8 + buf653 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf653, (8192, 1), is_leaf=True) # sum_33 + buf654 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf654, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_131 + buf655 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf655, (49152,), dtype=torch.int64, is_leaf=True) # getitem_133 + buf656 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf656, (49152,), dtype=torch.int64, is_leaf=True) # div_42 + buf657 = reader.storage(None, 32*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf657, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_134 + buf658 = reader.storage(None, 32768*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf658, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_17 + buf659 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf659, (8,), dtype=torch.int32, is_leaf=True) # cumsum_26 + buf660 = reader.storage(None, 22528*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf660, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_24 + buf661 = reader.storage(None, 22528*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf661, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_25 + buf662 = reader.storage(None, 22528*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf662, (8*(((u136 + u137 + u138 + u139 + u140 + u141 + u142 + u143 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_427 + buf663 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf663, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_76 + buf664 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf664, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf665 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf665, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_447 + buf666 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf666, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_617 + buf667 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf667, (2, 4096, 1), is_leaf=True) # rsqrt_30 + buf668 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf668, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_639 + buf669 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf669, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_137 + buf670 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf670, (2, 4096, 1), is_leaf=True) # rsqrt_31 + buf671 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf671, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_653 + buf672 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf672, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_149 + buf673 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf673, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_150 + buf674 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf674, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_151 + buf675 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf675, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_141 + buf676 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf676, (2, 16, 4096), is_leaf=True) # getitem_142 + buf677 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf677, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_620 + buf678 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf678, (2, 4096, 1), is_leaf=True) # rsqrt_32 + buf679 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf679, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_661 + buf680 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf680, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_83 + buf681 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf681, (8192, 1), is_leaf=True) # amax_9 + buf682 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf682, (8192, 1), is_leaf=True) # sum_37 + buf683 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf683, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_145 + buf684 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf684, (49152,), dtype=torch.int64, is_leaf=True) # getitem_147 + buf685 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf685, (49152,), dtype=torch.int64, is_leaf=True) # div_47 + buf686 = reader.storage(None, 32*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf686, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_148 + buf687 = reader.storage(None, 32768*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf687, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_19 + buf688 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf688, (8,), dtype=torch.int32, is_leaf=True) # cumsum_29 + buf689 = reader.storage(None, 22528*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf689, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_27 + buf690 = reader.storage(None, 22528*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf690, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_28 + buf691 = reader.storage(None, 22528*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf691, (8*(((u152 + u153 + u154 + u155 + u156 + u157 + u158 + u159 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_476 + buf692 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf692, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf693 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf693, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_85 + buf694 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf694, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_496 + buf695 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf695, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_685 + buf696 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf696, (2, 4096, 1), is_leaf=True) # rsqrt_33 + buf697 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf697, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_706 + buf698 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf698, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_151 + buf699 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf699, (2, 4096, 1), is_leaf=True) # rsqrt_34 + buf700 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf700, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_720 + buf701 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf701, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_164 + buf702 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf702, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_165 + buf703 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf703, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_166 + buf704 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf704, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_155 + buf705 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf705, (2, 16, 4096), is_leaf=True) # getitem_156 + buf706 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf706, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_688 + buf707 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf707, (2, 4096, 1), is_leaf=True) # rsqrt_35 + buf708 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf708, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_728 + buf709 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf709, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_91 + buf710 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf710, (8192, 1), is_leaf=True) # amax_10 + buf711 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf711, (8192, 1), is_leaf=True) # sum_41 + buf712 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf712, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_159 + buf713 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf713, (49152,), dtype=torch.int64, is_leaf=True) # getitem_161 + buf714 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf714, (49152,), dtype=torch.int64, is_leaf=True) # div_52 + buf715 = reader.storage(None, 32*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf715, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_162 + buf716 = reader.storage(None, 32768*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf716, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_21 + buf717 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf717, (8,), dtype=torch.int32, is_leaf=True) # cumsum_32 + buf718 = reader.storage(None, 22528*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf718, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_30 + buf719 = reader.storage(None, 22528*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf719, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_31 + buf720 = reader.storage(None, 22528*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf720, (8*(((u168 + u169 + u170 + u171 + u172 + u173 + u174 + u175 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_525 + buf721 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf721, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_92 + buf722 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf722, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf723 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf723, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_545 + buf724 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf724, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_753 + buf725 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf725, (2, 4096, 1), is_leaf=True) # rsqrt_36 + buf726 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf726, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_773 + buf727 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf727, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_165 + buf728 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf728, (2, 4096, 1), is_leaf=True) # rsqrt_37 + buf729 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf729, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_787 + buf730 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf730, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_179 + buf731 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf731, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_180 + buf732 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf732, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_181 + buf733 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf733, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_169 + buf734 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf734, (2, 16, 4096), is_leaf=True) # getitem_170 + buf735 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf735, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_756 + buf736 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf736, (2, 4096, 1), is_leaf=True) # rsqrt_38 + buf737 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf737, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_795 + buf738 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf738, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_99 + buf739 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf739, (8192, 1), is_leaf=True) # amax_11 + buf740 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf740, (8192, 1), is_leaf=True) # sum_45 + buf741 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf741, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_173 + buf742 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf742, (49152,), dtype=torch.int64, is_leaf=True) # getitem_175 + buf743 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf743, (49152,), dtype=torch.int64, is_leaf=True) # div_57 + buf744 = reader.storage(None, 32*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf744, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_176 + buf745 = reader.storage(None, 32768*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf745, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_23 + buf746 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf746, (8,), dtype=torch.int32, is_leaf=True) # cumsum_35 + buf747 = reader.storage(None, 22528*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf747, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_33 + buf748 = reader.storage(None, 22528*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf748, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_34 + buf749 = reader.storage(None, 22528*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf749, (8*(((u184 + u185 + u186 + u187 + u188 + u189 + u190 + u191 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_574 + buf750 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf750, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf751 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf751, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_101 + buf752 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf752, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_594 + buf753 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf753, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_821 + buf754 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf754, (2, 4096, 1), is_leaf=True) # rsqrt_39 + buf755 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf755, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_840 + buf756 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf756, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_179 + buf757 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf757, (2, 4096, 1), is_leaf=True) # rsqrt_40 + buf758 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf758, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_854 + buf759 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf759, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_194 + buf760 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf760, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_195 + buf761 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf761, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_196 + buf762 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf762, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_183 + buf763 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf763, (2, 16, 4096), is_leaf=True) # getitem_184 + buf764 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf764, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_824 + buf765 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf765, (2, 4096, 1), is_leaf=True) # rsqrt_41 + buf766 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf766, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_862 + buf767 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf767, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf768 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf768, (8192, 1), is_leaf=True) # amax_12 + buf769 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf769, (8192, 1), is_leaf=True) # sum_49 + buf770 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf770, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_187 + buf771 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf771, (49152,), dtype=torch.int64, is_leaf=True) # getitem_189 + buf772 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf772, (49152,), dtype=torch.int64, is_leaf=True) # div_62 + buf773 = reader.storage(None, 32*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf773, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_190 + buf774 = reader.storage(None, 32768*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf774, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_25 + buf775 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf775, (8,), dtype=torch.int32, is_leaf=True) # cumsum_38 + buf776 = reader.storage(None, 22528*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf776, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_36 + buf777 = reader.storage(None, 22528*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf777, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_37 + buf778 = reader.storage(None, 22528*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf778, (8*(((u200 + u201 + u202 + u203 + u204 + u205 + u206 + u207 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_623 + buf779 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf779, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_108 + buf780 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf780, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf781 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf781, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_643 + buf782 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf782, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_889 + buf783 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf783, (2, 4096, 1), is_leaf=True) # rsqrt_42 + buf784 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf784, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_907 + buf785 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf785, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_193 + buf786 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf786, (2, 4096, 1), is_leaf=True) # rsqrt_43 + buf787 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf787, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_921 + buf788 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf788, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_209 + buf789 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf789, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_210 + buf790 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf790, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_211 + buf791 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf791, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_197 + buf792 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf792, (2, 16, 4096), is_leaf=True) # getitem_198 + buf793 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf793, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_892 + buf794 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf794, (2, 4096, 1), is_leaf=True) # rsqrt_44 + buf795 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf795, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_929 + buf796 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf796, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_115 + buf797 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf797, (8192, 1), is_leaf=True) # amax_13 + buf798 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf798, (8192, 1), is_leaf=True) # sum_53 + buf799 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf799, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_201 + buf800 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf800, (49152,), dtype=torch.int64, is_leaf=True) # getitem_203 + buf801 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf801, (49152,), dtype=torch.int64, is_leaf=True) # div_67 + buf802 = reader.storage(None, 32*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf802, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_204 + buf803 = reader.storage(None, 32768*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf803, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_27 + buf804 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf804, (8,), dtype=torch.int32, is_leaf=True) # cumsum_41 + buf805 = reader.storage(None, 22528*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf805, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_39 + buf806 = reader.storage(None, 22528*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf806, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_40 + buf807 = reader.storage(None, 22528*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf807, (8*(((u216 + u217 + u218 + u219 + u220 + u221 + u222 + u223 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_672 + buf808 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf808, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_116 + buf809 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf809, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_117 + buf810 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf810, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_692 + buf811 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf811, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_957 + buf812 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf812, (2, 4096, 1), is_leaf=True) # rsqrt_45 + buf813 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf813, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_974 + buf814 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf814, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_207 + buf815 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf815, (2, 4096, 1), is_leaf=True) # rsqrt_46 + buf816 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf816, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_988 + buf817 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf817, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_224 + buf818 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf818, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_225 + buf819 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf819, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_226 + buf820 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf820, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_211 + buf821 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf821, (2, 16, 4096), is_leaf=True) # getitem_212 + buf822 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf822, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_960 + buf823 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf823, (2, 4096, 1), is_leaf=True) # rsqrt_47 + buf824 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf824, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_996 + buf825 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf825, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_123 + buf826 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf826, (8192, 1), is_leaf=True) # amax_14 + buf827 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf827, (8192, 1), is_leaf=True) # sum_57 + buf828 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf828, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_215 + buf829 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf829, (49152,), dtype=torch.int64, is_leaf=True) # getitem_217 + buf830 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf830, (49152,), dtype=torch.int64, is_leaf=True) # div_72 + buf831 = reader.storage(None, 32*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf831, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_218 + buf832 = reader.storage(None, 32768*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf832, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_29 + buf833 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf833, (8,), dtype=torch.int32, is_leaf=True) # cumsum_44 + buf834 = reader.storage(None, 22528*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf834, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_42 + buf835 = reader.storage(None, 22528*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf835, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_43 + buf836 = reader.storage(None, 22528*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf836, (8*(((u232 + u233 + u234 + u235 + u236 + u237 + u238 + u239 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_721 + buf837 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf837, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_124 + buf838 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf838, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_125 + buf839 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf839, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_741 + buf840 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf840, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1025 + buf841 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf841, (2, 4096, 1), is_leaf=True) # rsqrt_48 + buf842 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf842, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1041 + buf843 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf843, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_221 + buf844 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf844, (2, 4096, 1), is_leaf=True) # rsqrt_49 + buf845 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf845, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1055 + buf846 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf846, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_239 + buf847 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf847, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_240 + buf848 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf848, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_241 + buf849 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf849, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_225 + buf850 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf850, (2, 16, 4096), is_leaf=True) # getitem_226 + buf851 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf851, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1028 + buf852 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf852, (2, 4096, 1), is_leaf=True) # rsqrt_50 + buf853 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf853, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1063 + buf854 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf854, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_131 + buf855 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf855, (8192, 1), is_leaf=True) # amax_15 + buf856 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf856, (8192, 1), is_leaf=True) # sum_61 + buf857 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf857, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_229 + buf858 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf858, (49152,), dtype=torch.int64, is_leaf=True) # getitem_231 + buf859 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf859, (49152,), dtype=torch.int64, is_leaf=True) # div_77 + buf860 = reader.storage(None, 32*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf860, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_232 + buf861 = reader.storage(None, 32768*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf861, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_31 + buf862 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf862, (8,), dtype=torch.int32, is_leaf=True) # cumsum_47 + buf863 = reader.storage(None, 22528*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf863, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_45 + buf864 = reader.storage(None, 22528*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf864, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_46 + buf865 = reader.storage(None, 22528*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf865, (8*(((u248 + u249 + u250 + u251 + u252 + u253 + u254 + u255 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_770 + buf866 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf866, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_132 + buf867 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf867, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_133 + buf868 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf868, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_790 + buf869 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf869, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1093 + buf870 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf870, (2, 4096, 1), is_leaf=True) # rsqrt_51 + buf871 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf871, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1108 + buf872 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf872, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_235 + buf873 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf873, (2, 4096, 1), is_leaf=True) # rsqrt_52 + buf874 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf874, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1122 + buf875 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf875, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_254 + buf876 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf876, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_255 + buf877 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf877, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_256 + buf878 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf878, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_239 + buf879 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf879, (2, 16, 4096), is_leaf=True) # getitem_240 + buf880 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf880, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1096 + buf881 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf881, (2, 4096, 1), is_leaf=True) # rsqrt_53 + buf882 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf882, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1130 + buf883 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf883, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_139 + buf884 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf884, (8192, 1), is_leaf=True) # amax_16 + buf885 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf885, (8192, 1), is_leaf=True) # sum_65 + buf886 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf886, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_243 + buf887 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf887, (49152,), dtype=torch.int64, is_leaf=True) # getitem_245 + buf888 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf888, (49152,), dtype=torch.int64, is_leaf=True) # div_82 + buf889 = reader.storage(None, 32*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf889, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_246 + buf890 = reader.storage(None, 32768*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf890, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_33 + buf891 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf891, (8,), dtype=torch.int32, is_leaf=True) # cumsum_50 + buf892 = reader.storage(None, 22528*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf892, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_48 + buf893 = reader.storage(None, 22528*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf893, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_49 + buf894 = reader.storage(None, 22528*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf894, (8*(((u264 + u265 + u266 + u267 + u268 + u269 + u270 + u271 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_819 + buf895 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf895, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_140 + buf896 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf896, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_141 + buf897 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf897, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_839 + buf898 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf898, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1161 + buf899 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf899, (2, 4096, 1), is_leaf=True) # rsqrt_54 + buf900 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf900, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1175 + buf901 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf901, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_249 + buf902 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf902, (2, 4096, 1), is_leaf=True) # rsqrt_55 + buf903 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf903, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1189 + buf904 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf904, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_269 + buf905 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf905, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_270 + buf906 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf906, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_271 + buf907 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf907, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_253 + buf908 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf908, (2, 16, 4096), is_leaf=True) # getitem_254 + buf909 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf909, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1164 + buf910 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf910, (2, 4096, 1), is_leaf=True) # rsqrt_56 + buf911 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf911, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1197 + buf912 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf912, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_147 + buf913 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf913, (8192, 1), is_leaf=True) # amax_17 + buf914 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf914, (8192, 1), is_leaf=True) # sum_69 + buf915 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf915, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_257 + buf916 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf916, (49152,), dtype=torch.int64, is_leaf=True) # getitem_259 + buf917 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf917, (49152,), dtype=torch.int64, is_leaf=True) # div_87 + buf918 = reader.storage(None, 32*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf918, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_260 + buf919 = reader.storage(None, 32768*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf919, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_35 + buf920 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf920, (8,), dtype=torch.int32, is_leaf=True) # cumsum_53 + buf921 = reader.storage(None, 22528*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf921, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_51 + buf922 = reader.storage(None, 22528*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf922, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_52 + buf923 = reader.storage(None, 22528*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf923, (8*(((u280 + u281 + u282 + u283 + u284 + u285 + u286 + u287 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_868 + buf924 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf924, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_148 + buf925 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf925, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_149 + buf926 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf926, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_888 + buf927 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf927, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1229 + buf928 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf928, (2, 4096, 1), is_leaf=True) # rsqrt_57 + buf929 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf929, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1242 + buf930 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf930, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_263 + buf931 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf931, (2, 4096, 1), is_leaf=True) # rsqrt_58 + buf932 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf932, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1256 + buf933 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf933, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_284 + buf934 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf934, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_285 + buf935 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf935, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_286 + buf936 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf936, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_267 + buf937 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf937, (2, 16, 4096), is_leaf=True) # getitem_268 + buf938 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf938, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1232 + buf939 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf939, (2, 4096, 1), is_leaf=True) # rsqrt_59 + buf940 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf940, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1264 + buf941 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf941, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_155 + buf942 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf942, (8192, 1), is_leaf=True) # amax_18 + buf943 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf943, (8192, 1), is_leaf=True) # sum_73 + buf944 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf944, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_271 + buf945 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf945, (49152,), dtype=torch.int64, is_leaf=True) # getitem_273 + buf946 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf946, (49152,), dtype=torch.int64, is_leaf=True) # div_92 + buf947 = reader.storage(None, 32*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf947, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_274 + buf948 = reader.storage(None, 32768*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf948, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_37 + buf949 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf949, (8,), dtype=torch.int32, is_leaf=True) # cumsum_56 + buf950 = reader.storage(None, 22528*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf950, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_54 + buf951 = reader.storage(None, 22528*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf951, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_55 + buf952 = reader.storage(None, 22528*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf952, (8*(((u296 + u297 + u298 + u299 + u300 + u301 + u302 + u303 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_917 + buf953 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf953, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_156 + buf954 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf954, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_157 + buf955 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf955, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_937 + buf956 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf956, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1297 + buf957 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf957, (2, 4096, 1), is_leaf=True) # rsqrt_60 + buf958 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf958, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1309 + buf959 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf959, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_277 + buf960 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf960, (2, 4096, 1), is_leaf=True) # rsqrt_61 + buf961 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf961, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1323 + buf962 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf962, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_299 + buf963 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf963, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_300 + buf964 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf964, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_301 + buf965 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf965, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_281 + buf966 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf966, (2, 16, 4096), is_leaf=True) # getitem_282 + buf967 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf967, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1300 + buf968 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf968, (2, 4096, 1), is_leaf=True) # rsqrt_62 + buf969 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf969, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1331 + buf970 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf970, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_163 + buf971 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf971, (8192, 1), is_leaf=True) # amax_19 + buf972 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf972, (8192, 1), is_leaf=True) # sum_77 + buf973 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf973, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_285 + buf974 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf974, (49152,), dtype=torch.int64, is_leaf=True) # getitem_287 + buf975 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf975, (49152,), dtype=torch.int64, is_leaf=True) # div_97 + buf976 = reader.storage(None, 32*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf976, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_288 + buf977 = reader.storage(None, 32768*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf977, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_39 + buf978 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf978, (8,), dtype=torch.int32, is_leaf=True) # cumsum_59 + buf979 = reader.storage(None, 22528*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf979, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_57 + buf980 = reader.storage(None, 22528*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf980, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_58 + buf981 = reader.storage(None, 22528*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf981, (8*(((u312 + u313 + u314 + u315 + u316 + u317 + u318 + u319 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_966 + buf982 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf982, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_164 + buf983 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf983, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_165 + buf984 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf984, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_986 + buf985 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf985, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1365 + buf986 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf986, (2, 4096, 1), is_leaf=True) # rsqrt_63 + buf987 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf987, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1376 + buf988 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf988, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_291 + buf989 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf989, (2, 4096, 1), is_leaf=True) # rsqrt_64 + buf990 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf990, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1390 + buf991 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf991, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_314 + buf992 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf992, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_315 + buf993 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf993, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_316 + buf994 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf994, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_295 + buf995 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf995, (2, 16, 4096), is_leaf=True) # getitem_296 + buf996 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf996, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1368 + buf997 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf997, (2, 4096, 1), is_leaf=True) # rsqrt_65 + buf998 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf998, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1398 + buf999 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf999, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_171 + buf1000 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1000, (8192, 1), is_leaf=True) # amax_20 + buf1001 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1001, (8192, 1), is_leaf=True) # sum_81 + buf1002 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1002, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_299 + buf1003 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1003, (49152,), dtype=torch.int64, is_leaf=True) # getitem_301 + buf1004 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1004, (49152,), dtype=torch.int64, is_leaf=True) # div_102 + buf1005 = reader.storage(None, 32*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1005, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_302 + buf1006 = reader.storage(None, 32768*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1006, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_41 + buf1007 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1007, (8,), dtype=torch.int32, is_leaf=True) # cumsum_62 + buf1008 = reader.storage(None, 22528*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1008, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_60 + buf1009 = reader.storage(None, 22528*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1009, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_61 + buf1010 = reader.storage(None, 22528*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1010, (8*(((u328 + u329 + u330 + u331 + u332 + u333 + u334 + u335 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1015 + buf1011 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1011, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_172 + buf1012 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1012, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_173 + buf1013 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1013, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1035 + buf1014 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1014, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1433 + buf1015 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1015, (2, 4096, 1), is_leaf=True) # rsqrt_66 + buf1016 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1016, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1443 + buf1017 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1017, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_305 + buf1018 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1018, (2, 4096, 1), is_leaf=True) # rsqrt_67 + buf1019 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1019, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1457 + buf1020 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1020, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_329 + buf1021 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1021, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_330 + buf1022 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1022, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_331 + buf1023 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1023, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_309 + buf1024 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1024, (2, 16, 4096), is_leaf=True) # getitem_310 + buf1025 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1025, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1436 + buf1026 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1026, (2, 4096, 1), is_leaf=True) # rsqrt_68 + buf1027 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1027, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1465 + buf1028 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1028, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_179 + buf1029 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1029, (8192, 1), is_leaf=True) # amax_21 + buf1030 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1030, (8192, 1), is_leaf=True) # sum_85 + buf1031 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1031, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_313 + buf1032 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1032, (49152,), dtype=torch.int64, is_leaf=True) # getitem_315 + buf1033 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1033, (49152,), dtype=torch.int64, is_leaf=True) # div_107 + buf1034 = reader.storage(None, 32*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1034, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_316 + buf1035 = reader.storage(None, 32768*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1035, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_43 + buf1036 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1036, (8,), dtype=torch.int32, is_leaf=True) # cumsum_65 + buf1037 = reader.storage(None, 22528*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1037, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_63 + buf1038 = reader.storage(None, 22528*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1038, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_64 + buf1039 = reader.storage(None, 22528*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1039, (8*(((u344 + u345 + u346 + u347 + u348 + u349 + u350 + u351 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1064 + buf1040 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1040, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_180 + buf1041 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1041, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_181 + buf1042 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1042, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1084 + buf1043 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1043, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1501 + buf1044 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1044, (2, 4096, 1), is_leaf=True) # rsqrt_69 + buf1045 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1045, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1510 + buf1046 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1046, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_319 + buf1047 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1047, (2, 4096, 1), is_leaf=True) # rsqrt_70 + buf1048 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1048, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1524 + buf1049 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1049, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_344 + buf1050 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1050, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_345 + buf1051 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1051, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_346 + buf1052 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1052, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_323 + buf1053 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1053, (2, 16, 4096), is_leaf=True) # getitem_324 + buf1054 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1054, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1504 + buf1055 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1055, (2, 4096, 1), is_leaf=True) # rsqrt_71 + buf1056 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1056, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1532 + buf1057 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1057, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_187 + buf1058 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1058, (8192, 1), is_leaf=True) # amax_22 + buf1059 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1059, (8192, 1), is_leaf=True) # sum_89 + buf1060 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1060, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_327 + buf1061 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1061, (49152,), dtype=torch.int64, is_leaf=True) # getitem_329 + buf1062 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1062, (49152,), dtype=torch.int64, is_leaf=True) # div_112 + buf1063 = reader.storage(None, 32*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1063, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_330 + buf1064 = reader.storage(None, 32768*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1064, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_45 + buf1065 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1065, (8,), dtype=torch.int32, is_leaf=True) # cumsum_68 + buf1066 = reader.storage(None, 22528*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1066, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_66 + buf1067 = reader.storage(None, 22528*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1067, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_67 + buf1068 = reader.storage(None, 22528*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1068, (8*(((u360 + u361 + u362 + u363 + u364 + u365 + u366 + u367 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1113 + buf1069 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1069, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_188 + buf1070 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1070, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_189 + buf1071 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1071, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1133 + buf1072 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1072, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1569 + buf1073 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1073, (2, 4096, 1), is_leaf=True) # rsqrt_72 + buf1074 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1074, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1577 + buf1075 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1075, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_333 + buf1076 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1076, (2, 4096, 1), is_leaf=True) # rsqrt_73 + buf1077 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1077, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1591 + buf1078 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1078, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_359 + buf1079 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1079, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_360 + buf1080 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1080, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_361 + buf1081 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1081, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_337 + buf1082 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1082, (2, 16, 4096), is_leaf=True) # getitem_338 + buf1083 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1083, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1572 + buf1084 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1084, (2, 4096, 1), is_leaf=True) # rsqrt_74 + buf1085 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1085, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1599 + buf1086 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1086, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_195 + buf1087 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1087, (8192, 1), is_leaf=True) # amax_23 + buf1088 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1088, (8192, 1), is_leaf=True) # sum_93 + buf1089 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1089, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_341 + buf1090 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1090, (49152,), dtype=torch.int64, is_leaf=True) # getitem_343 + buf1091 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1091, (49152,), dtype=torch.int64, is_leaf=True) # div_117 + buf1092 = reader.storage(None, 32*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1092, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_344 + buf1093 = reader.storage(None, 32768*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1093, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_47 + buf1094 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1094, (8,), dtype=torch.int32, is_leaf=True) # cumsum_71 + buf1095 = reader.storage(None, 22528*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1095, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_69 + buf1096 = reader.storage(None, 22528*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1096, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_70 + buf1097 = reader.storage(None, 22528*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1097, (8*(((u376 + u377 + u378 + u379 + u380 + u381 + u382 + u383 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1162 + buf1098 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1098, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_196 + buf1099 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1099, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_197 + buf1100 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1100, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1182 + buf1101 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1101, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1637 + buf1102 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1102, (2, 4096, 1), is_leaf=True) # rsqrt_75 + buf1103 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1103, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1644 + buf1104 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1104, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_347 + buf1105 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1105, (2, 4096, 1), is_leaf=True) # rsqrt_76 + buf1106 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1106, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1658 + buf1107 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1107, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_374 + buf1108 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1108, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_375 + buf1109 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1109, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_376 + buf1110 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1110, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_351 + buf1111 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1111, (2, 16, 4096), is_leaf=True) # getitem_352 + buf1112 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1112, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1640 + buf1113 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1113, (2, 4096, 1), is_leaf=True) # rsqrt_77 + buf1114 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1114, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1666 + buf1115 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1115, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_203 + buf1116 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1116, (8192, 1), is_leaf=True) # amax_24 + buf1117 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1117, (8192, 1), is_leaf=True) # sum_97 + buf1118 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1118, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_355 + buf1119 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1119, (49152,), dtype=torch.int64, is_leaf=True) # getitem_357 + buf1120 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1120, (49152,), dtype=torch.int64, is_leaf=True) # div_122 + buf1121 = reader.storage(None, 32*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1121, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_358 + buf1122 = reader.storage(None, 32768*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1122, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_49 + buf1123 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1123, (8,), dtype=torch.int32, is_leaf=True) # cumsum_74 + buf1124 = reader.storage(None, 22528*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1124, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_72 + buf1125 = reader.storage(None, 22528*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1125, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_73 + buf1126 = reader.storage(None, 22528*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1126, (8*(((u392 + u393 + u394 + u395 + u396 + u397 + u398 + u399 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1211 + buf1127 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1127, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_204 + buf1128 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1128, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_205 + buf1129 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1129, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1231 + buf1130 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1130, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1705 + buf1131 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1131, (2, 4096, 1), is_leaf=True) # rsqrt_78 + buf1132 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1132, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1711 + buf1133 = reader.storage(None, 9437184, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1133, (2, 4096, 512), (2359296, 576, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_361 + buf1134 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1134, (2, 4096, 1), is_leaf=True) # rsqrt_79 + buf1135 = reader.storage(None, 8388608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1135, (8192, 512), dtype=torch.bfloat16, is_leaf=True) # view_1725 + buf1136 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1136, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_389 + buf1137 = reader.storage(None, 50331648, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1137, (2, 16, 4096, 192), (12582912, 192, 3072, 1), dtype=torch.bfloat16, is_leaf=True) # permute_390 + buf1138 = reader.storage(None, 67108864, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1138, (2, 16, 4096, 128), (16777216, 256, 4096, 1), dtype=torch.bfloat16, storage_offset=128, is_leaf=True) # permute_391 + buf1139 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1139, (2, 16, 4096, 128), (8388608, 128, 2048, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_365 + buf1140 = reader.storage(None, 524288, device=device(type='cuda', index=0)) + reader.tensor(buf1140, (2, 16, 4096), is_leaf=True) # getitem_366 + buf1141 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1141, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1708 + buf1142 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1142, (2, 4096, 1), is_leaf=True) # rsqrt_80 + buf1143 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1143, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1733 + buf1144 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1144, (8192, 64), dtype=torch.bfloat16, is_leaf=True) # mm_211 + buf1145 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1145, (8192, 1), is_leaf=True) # amax_25 + buf1146 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1146, (8192, 1), is_leaf=True) # sum_101 + buf1147 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1147, (8192, 6), dtype=torch.int64, is_leaf=True) # getitem_369 + buf1148 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1148, (49152,), dtype=torch.int64, is_leaf=True) # getitem_371 + buf1149 = reader.storage(None, 393216, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1149, (49152,), dtype=torch.int64, is_leaf=True) # div_127 + buf1150 = reader.storage(None, 32*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1150, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)),), dtype=torch.int32, is_leaf=True) # getitem_372 + buf1151 = reader.storage(None, 32768*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1151, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), 2048), dtype=torch.bfloat16, is_leaf=True) # index_51 + buf1152 = reader.storage(None, 32, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf1152, (8,), dtype=torch.int32, is_leaf=True) # cumsum_77 + buf1153 = reader.storage(None, 22528*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1153, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_75 + buf1154 = reader.storage(None, 22528*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1154, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # _grouped_mm_76 + buf1155 = reader.storage(None, 22528*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1155, (8*(((u408 + u409 + u410 + u411 + u412 + u413 + u414 + u415 + 71)//8)), 1408), dtype=torch.bfloat16, is_leaf=True) # mul_1260 + buf1156 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1156, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_212 + buf1157 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1157, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mm_213 + buf1158 = reader.storage(None, 46137344, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1158, (8192, 2816), dtype=torch.bfloat16, is_leaf=True) # mul_1280 + buf1159 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1159, (2, 4096, 2048), dtype=torch.bfloat16, is_leaf=True) # add_1773 + buf1160 = reader.storage(None, 32768, device=device(type='cuda', index=0)) + reader.tensor(buf1160, (2, 4096, 1), is_leaf=True) # rsqrt_81 + buf1161 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1161, (8192, 2048), dtype=torch.bfloat16, is_leaf=True) # view_1778 + buf1162 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1162, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_406 + buf1163 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1163, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_407 + buf1164 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1164, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_456 + buf1165 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1165, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_457 + buf1166 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1166, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_506 + buf1167 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1167, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_507 + buf1168 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1168, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_556 + buf1169 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1169, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_557 + buf1170 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1170, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_606 + buf1171 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1171, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_607 + buf1172 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1172, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_656 + buf1173 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1173, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_657 + buf1174 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1174, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_706 + buf1175 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1175, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_707 + buf1176 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1176, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_756 + buf1177 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1177, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_757 + buf1178 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1178, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_806 + buf1179 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1179, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_807 + buf1180 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1180, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_856 + buf1181 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1181, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_857 + buf1182 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1182, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_906 + buf1183 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1183, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_907 + buf1184 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1184, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_956 + buf1185 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1185, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_957 + buf1186 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1186, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1006 + buf1187 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1187, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1007 + buf1188 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1188, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1056 + buf1189 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1189, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1057 + buf1190 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1190, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1106 + buf1191 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1191, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1107 + buf1192 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1192, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1156 + buf1193 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1193, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1157 + buf1194 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1194, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1206 + buf1195 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1195, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1207 + buf1196 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1196, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1256 + buf1197 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1197, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1257 + buf1198 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1198, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1306 + buf1199 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1199, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1307 + buf1200 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1200, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1356 + buf1201 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1201, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1357 + buf1202 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1202, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1406 + buf1203 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1203, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1407 + buf1204 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1204, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1456 + buf1205 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1205, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1457 + buf1206 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1206, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1506 + buf1207 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1207, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1507 + buf1208 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1208, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1556 + buf1209 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1209, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1557 + buf1210 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1210, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1606 + buf1211 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1211, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1607 + buf1212 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf1212, (8192, 6, 1), (6, 1, 6), is_leaf=True) # permute_1656 + buf1213 = reader.storage(None, 402653184, device=device(type='cuda', index=0)) + reader.tensor(buf1213, (8192, 2048, 6), (12288, 1, 2048), is_leaf=True) # permute_1657 + buf1214 = reader.storage(None, 1677721600, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf1214, (2, 4096, 102400), dtype=torch.bfloat16, is_leaf=True) # tangents_1 +load_args._version = 0 +mod = Repro() +if __name__ == '__main__': + from torch._dynamo.repro.after_aot import run_repro + from torch._dynamo.repro.after_aot import setup_fake_process_groups + setup_fake_process_groups({'0': {'size': 64, 'rank': 0}, '521': {'size': 8, 'rank': 0}, '513': {'size': 8, 'rank': 0}}) + with torch.no_grad(): + run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='symbolic', check_str=None) + # To run it separately, do + # mod, args = run_repro(mod, load_args, accuracy=False, command='get_args', save_dir=None, tracing_mode='symbolic', check_str=None) + # mod(*args) + dist.destroy_process_group() + +# Helper functions for overlap simulator +def get_pg_config(): + """DSv3 64 GPUs: FSDP=64, TP=1, EP=8.""" + return {'0': {'size': 64, 'rank': 0}, '513': {'size': 8, 'rank': 0}, '521': {'size': 8, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls8_8.table" + +def get_colls_group_mapping(): + # FSDP "0" → internode (table group "0"), EP "513","521" → intranode (table group "1") + return {'0': '0', '513': '1', '521': '1'} diff --git a/autoparallel/tools/overlap_simulator/repro_dsv3_fw_128.py b/autoparallel/tools/overlap_simulator/repro_dsv3_fw_128.py new file mode 100644 index 00000000..7120138b --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_dsv3_fw_128.py @@ -0,0 +1,10290 @@ +# fmt: off +# flake8: noqa +# isort: skip_file + +import os +os.environ['PYTORCH_KERNEL_CACHE_PATH'] = '/mnt/mffuse/.cache/torch/kernels' +os.environ['TORCH_DISABLE_ADDR2LINE'] = '1' +os.environ['TORCH_TRACE'] = '/mnt/mffuse/outputs/sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3/torch_trace/' +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' +os.environ['TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE'] = '[${role_name}${rank}|${local_rank}]:' +os.environ['TORCHELASTIC_MAX_RESTARTS'] = '0' +os.environ['TORCHX_INTERNAL_SESSION_ID'] = '03a200cc-023c-47d4-8372-8d223aedc5c2' +os.environ['TORCHX_RUN_PYTHONPATH'] = '' +os.environ['TORCHELASTIC_ERROR_FILE'] = '/tmp/torchelastic_i226b1gg/sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3_ukylfuu9/attempt_0/0/error.json' +os.environ['TORCH_ADDR2LINE_BINARY'] = '/packages/folly.symbolizer/folly-addr2line' +os.environ['TORCHX_JOB_ID'] = 'mast_conda://torchx/sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3' +os.environ['TORCH_NCCL_ASYNC_ERROR_HANDLING'] = '3' +os.environ['TORCHELASTIC_SIGNALS_TO_HANDLE'] = 'SIGTERM,SIGINT,SIGHUP,SIGQUIT' +os.environ['TORCHELASTIC_RUN_ID'] = 'sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3' +os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1' +os.environ['TORCHELASTIC_RESTART_COUNT'] = '0' +os.environ['TORCHELASTIC_USE_AGENT_STORE'] = 'False' +os.environ['PYTORCH_DDP_USE_SIDE_STREAM'] = '0' +os.environ['TORCHINDUCTOR_CACHE_DIR'] = '/tmp/torchinductor_root' +os.environ['TORCH_FR_BUFFER_SIZE'] = '20000' +os.environ['TORCH_NCCL_DUMP_ON_TIMEOUT'] = '1' +os.environ['TORCH_FR_DUMP_TEMP_FILE'] = '/mnt/mffuse_nccl_trace/nccl_trace/sfsdp-dsv3-16b--tp1-bs2-inductor-128-ivankobzarev-hkpmbjc3/v_0/attempt_0/nccl_trace_rank_' +os.environ['TRITON_CACHE_DIR'] = '/tmp/torchinductor_root/triton/0' + +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims +import torch.distributed as dist +from torch.testing._internal.distributed.fake_pg import FakeStore +import triton +import triton.language as tl + +import torch._dynamo.config +import torch._inductor.config +import torch._functorch.config +import torch.fx.experimental._config +torch._dynamo.config.capture_scalar_outputs = True +torch._inductor.config.allow_buffer_reuse = False +torch._inductor.config.reorder_for_compute_comm_overlap = False +torch._inductor.config.reorder_for_peak_memory = False +torch._inductor.config.max_autotune = False +torch._inductor.config.coordinate_descent_tuning = False +torch._inductor.config.deterministic = False +torch._inductor.config.aten_distributed_optimizations.collective_bucketing = True +torch._inductor.config.aten_distributed_optimizations.insert_overlap_deps = True +torch._inductor.config.wrap_inductor_compiled_regions = False +torch._inductor.config.triton.cudagraphs = False +torch._inductor.config.triton.store_cubin = False +torch._inductor.config.test_configs.runtime_triton_dtype_assert = False +torch._functorch.config.functionalize_rng_ops = False +torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = True +torch._functorch.config.unlift_effect_tokens = True +torch._functorch.config.selective_decompose = False + + + +isolate_fails_code_str = None + + + + + +if "__compile_source__" in globals(): + import inspect as __after_aot_inspect + import linecache as __after_aot_linecache + __after_aot_filename = __after_aot_inspect.currentframe().f_code.co_filename + __after_aot_linecache.cache[__after_aot_filename] = ( + len(__compile_source__), + None, + __compile_source__.splitlines(True), + __after_aot_filename, + ) +# torch version: 2.11.0a0+git5ac4d4b +# torch cuda version: 12.4 +# torch git version: 5ac4d4bf3f85e15fdd6676f46b090568ea91e47e + + +# CUDA Info: +# nvcc not found +# GPU Hardware Info: +# NVIDIA H100 80GB HBM3 : 8 + +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.reset_table() + +@triton.jit +def _fill_indices_kernel_0( + tokens_per_expert_group_ptr, + start_index_values_ptr, + write_offsets_ptr, + output_ptr, + experts_per_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # Number of threads per block +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # map programs (blocks) to the experts and loop (grid stride) if needed + for expert_id in range(pid, experts_per_rank, num_programs): + # read this experts write offset + write_offset = tl.load(write_offsets_ptr + expert_id) + + for r in range(num_ranks): + # index into tokens_per_expert_group array + i = r * experts_per_rank + expert_id + + # load start index and number of tokens for this expert-rank pair + start_index = tl.load(start_index_values_ptr + i) + length = tl.load(tokens_per_expert_group_ptr + i) + + # each thread in block processes tokens in parallel + offsets = tl.arange(0, BLOCK_SIZE) + + # tokens are processed in chunks of BLOCK_SIZE + for chunk_start in range(0, length, BLOCK_SIZE): + chunk_offsets = chunk_start + offsets + + # mask valid indices + mask = chunk_offsets < length + + values = start_index + chunk_offsets + + # destination + dest_indices = write_offset + chunk_offsets + + # store + tl.store(output_ptr + dest_indices, values, mask=mask) + + # update write offset for next rank + write_offset += length + +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(_fill_indices_kernel_0) +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.constant_args={0: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 1: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 2: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 3: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 4: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 5: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 6: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 7: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 8: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 9: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 10: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 11: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 12: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 13: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 14: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 15: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 16: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 17: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 18: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 19: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 20: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 21: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 22: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 23: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 24: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 25: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}} + +from torch.nn import * +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sdpa_score0 = lambda score, b, h, m, n, *args: score + self.sdpa_mask0 = lambda b, h, m, n, *args: True + self.sdpa_score1 = lambda score, b, h, m, n, *args: score + self.sdpa_mask1 = lambda b, h, m, n, *args: True + self.sdpa_score2 = lambda score, b, h, m, n, *args: score + self.sdpa_mask2 = lambda b, h, m, n, *args: True + self.sdpa_score3 = lambda score, b, h, m, n, *args: score + self.sdpa_mask3 = lambda b, h, m, n, *args: True + self.sdpa_score4 = lambda score, b, h, m, n, *args: score + self.sdpa_mask4 = lambda b, h, m, n, *args: True + self.sdpa_score5 = lambda score, b, h, m, n, *args: score + self.sdpa_mask5 = lambda b, h, m, n, *args: True + self.sdpa_score6 = lambda score, b, h, m, n, *args: score + self.sdpa_mask6 = lambda b, h, m, n, *args: True + self.sdpa_score7 = lambda score, b, h, m, n, *args: score + self.sdpa_mask7 = lambda b, h, m, n, *args: True + self.sdpa_score8 = lambda score, b, h, m, n, *args: score + self.sdpa_mask8 = lambda b, h, m, n, *args: True + self.sdpa_score9 = lambda score, b, h, m, n, *args: score + self.sdpa_mask9 = lambda b, h, m, n, *args: True + self.sdpa_score10 = lambda score, b, h, m, n, *args: score + self.sdpa_mask10 = lambda b, h, m, n, *args: True + self.sdpa_score11 = lambda score, b, h, m, n, *args: score + self.sdpa_mask11 = lambda b, h, m, n, *args: True + self.sdpa_score12 = lambda score, b, h, m, n, *args: score + self.sdpa_mask12 = lambda b, h, m, n, *args: True + self.sdpa_score13 = lambda score, b, h, m, n, *args: score + self.sdpa_mask13 = lambda b, h, m, n, *args: True + self.sdpa_score14 = lambda score, b, h, m, n, *args: score + self.sdpa_mask14 = lambda b, h, m, n, *args: True + self.sdpa_score15 = lambda score, b, h, m, n, *args: score + self.sdpa_mask15 = lambda b, h, m, n, *args: True + self.sdpa_score16 = lambda score, b, h, m, n, *args: score + self.sdpa_mask16 = lambda b, h, m, n, *args: True + self.sdpa_score17 = lambda score, b, h, m, n, *args: score + self.sdpa_mask17 = lambda b, h, m, n, *args: True + self.sdpa_score18 = lambda score, b, h, m, n, *args: score + self.sdpa_mask18 = lambda b, h, m, n, *args: True + self.sdpa_score19 = lambda score, b, h, m, n, *args: score + self.sdpa_mask19 = lambda b, h, m, n, *args: True + self.sdpa_score20 = lambda score, b, h, m, n, *args: score + self.sdpa_mask20 = lambda b, h, m, n, *args: True + self.sdpa_score21 = lambda score, b, h, m, n, *args: score + self.sdpa_mask21 = lambda b, h, m, n, *args: True + self.sdpa_score22 = lambda score, b, h, m, n, *args: score + self.sdpa_mask22 = lambda b, h, m, n, *args: True + self.sdpa_score23 = lambda score, b, h, m, n, *args: score + self.sdpa_mask23 = lambda b, h, m, n, *args: True + self.sdpa_score24 = lambda score, b, h, m, n, *args: score + self.sdpa_mask24 = lambda b, h, m, n, *args: True + self.sdpa_score25 = lambda score, b, h, m, n, *args: score + self.sdpa_mask25 = lambda b, h, m, n, *args: True + self.sdpa_score26 = lambda score, b, h, m, n, *args: score + self.sdpa_mask26 = lambda b, h, m, n, *args: True + + + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_298, primals_299, primals_300, primals_301, primals_302, primals_303, primals_304, primals_305, primals_306, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_316, primals_317, primals_318, primals_319, primals_320, primals_321, primals_322, primals_323, primals_324, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, primals_334, primals_335, primals_336, primals_337, primals_338, primals_339, primals_340, primals_341, primals_342, primals_343, primals_344, primals_345, primals_346, primals_347, primals_348, primals_349, primals_350, primals_351, primals_352, primals_353, primals_354, primals_355, primals_356, primals_357, primals_358, primals_359, primals_360, primals_361, primals_362, primals_363, primals_364, primals_365, primals_366, primals_367, primals_368, primals_369, primals_370, primals_371, primals_372, primals_373, primals_374, primals_375, primals_376, primals_377, primals_378, primals_379, primals_380, primals_381, primals_382, primals_383, primals_384, primals_385, primals_386, primals_387, primals_388, primals_389, primals_390, primals_391, primals_392, primals_393, primals_394, primals_395, primals_396, primals_397, primals_398, primals_399, primals_400, primals_401, primals_402, primals_403, primals_404, primals_405, primals_406, primals_407, primals_408, primals_409, primals_410, primals_411, primals_412, primals_413, primals_414, primals_415, primals_416, primals_417, primals_418, primals_419, primals_420, primals_421, primals_422, primals_423, primals_424, primals_425, primals_426, primals_427, primals_428, primals_429, primals_430, primals_431, primals_432, primals_433, primals_434, primals_435, primals_436, primals_437, primals_438, primals_439, primals_440): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_1, torch.bfloat16) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 128, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + embedding = torch.ops.aten.embedding.default(wait_tensor, primals_2); wait_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 128, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1); mul = wait_tensor_1 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 128, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [8192, 2048]); convert_element_type_3 = None + mm = torch.ops.aten.mm.default(view_3, permute); permute = None + view_4 = torch.ops.aten.view.default(mm, [2, 4096, 3072]); mm = None + view_5 = torch.ops.aten.view.default(view_4, [2, 4096, -1, 192]); view_4 = None + split_with_sizes = torch.ops.aten.split_with_sizes.default(view_5, [128, 64], -1); view_5 = None + getitem = split_with_sizes[0] + getitem_1 = split_with_sizes[1]; split_with_sizes = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(getitem_1, torch.float32); getitem_1 = None + view_6 = torch.ops.aten.view.default(convert_element_type_7, [2, 4096, 16, -1, 2]); convert_element_type_7 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_6); view_6 = None + view_7 = torch.ops.aten.view.default(primals_3, [1, 4096, 1, 32]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_7); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_8 = torch.ops.aten.view.default(view_as_real, [2, 4096, 16, 64]); view_as_real = None + convert_element_type_8 = torch.ops.prims.convert_element_type.default(view_8, torch.bfloat16); view_8 = None + cat = torch.ops.aten.cat.default([getitem, convert_element_type_8], -1); getitem = convert_element_type_8 = None + convert_element_type_9 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_9, 128, '0'); convert_element_type_9 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + slice_2 = torch.ops.aten.slice.Tensor(wait_tensor_3, 0, 0, 576); wait_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(slice_2, [1, 0]); slice_2 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1); permute_1 = None + view_11 = torch.ops.aten.view.default(mm_1, [2, 4096, 576]); mm_1 = None + split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(view_11, [512, 64], -1); view_11 = None + getitem_2 = split_with_sizes_1[0] + getitem_3 = split_with_sizes_1[1]; split_with_sizes_1 = None + unsqueeze = torch.ops.aten.unsqueeze.default(getitem_3, 2); getitem_3 = None + convert_element_type_12 = torch.ops.prims.convert_element_type.default(unsqueeze, torch.float32); unsqueeze = None + view_12 = torch.ops.aten.view.default(convert_element_type_12, [2, 4096, 1, -1, 2]); convert_element_type_12 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_12); view_12 = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_7); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_14 = torch.ops.aten.view.default(view_as_real_1, [2, 4096, 1, 64]); view_as_real_1 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_14, torch.bfloat16); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_14, 128, '0'); convert_element_type_14 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(getitem_2, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_15, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_1 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_1); add_1 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_15, rsqrt_1); convert_element_type_15 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_4); mul_4 = wait_tensor_4 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 128, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_17 = torch.ops.aten.view.default(convert_element_type_16, [8192, 512]); convert_element_type_16 = None + mm_2 = torch.ops.aten.mm.default(view_17, permute_2); permute_2 = None + view_18 = torch.ops.aten.view.default(mm_2, [2, 4096, 4096]); mm_2 = None + view_19 = torch.ops.aten.view.default(view_18, [2, 4096, -1, 256]); view_18 = None + split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(view_19, [128, 128], -1); view_19 = None + getitem_4 = split_with_sizes_2[0] + getitem_5 = split_with_sizes_2[1]; split_with_sizes_2 = None + expand = torch.ops.aten.expand.default(convert_element_type_13, [-1, -1, 16, -1]); convert_element_type_13 = None + cat_1 = torch.ops.aten.cat.default([getitem_4, expand], -1); getitem_4 = expand = None + permute_3 = torch.ops.aten.permute.default(cat, [0, 2, 1, 3]); cat = None + permute_4 = torch.ops.aten.permute.default(cat_1, [0, 2, 1, 3]); cat_1 = None + permute_5 = torch.ops.aten.permute.default(getitem_5, [0, 2, 1, 3]); getitem_5 = None + sdpa_score0 = self.sdpa_score0 + sdpa_mask0 = self.sdpa_mask0 + flex_attention = torch.ops.higher_order.flex_attention(permute_3, permute_4, permute_5, sdpa_score0, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask0), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score0 = sdpa_mask0 = None + getitem_6 = flex_attention[0] + getitem_7 = flex_attention[1]; flex_attention = None + permute_6 = torch.ops.aten.permute.default(getitem_6, [0, 2, 1, 3]) + view_20 = torch.ops.aten.view.default(permute_6, [2, 4096, -1]); permute_6 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 128, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + view_22 = torch.ops.aten.view.default(view_20, [8192, 2048]); view_20 = None + mm_3 = torch.ops.aten.mm.default(view_22, permute_7); view_22 = permute_7 = None + view_23 = torch.ops.aten.view.default(mm_3, [2, 4096, 2048]) + add_2 = torch.ops.aten.add.Tensor(embedding, view_23); view_23 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 128, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_24 = torch.ops.prims.convert_element_type.default(add_2, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_24, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_3 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_3); add_3 = None + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_24, rsqrt_2); convert_element_type_24 = None + mul_7 = torch.ops.aten.mul.Tensor(mul_6, wait_tensor_7); mul_6 = wait_tensor_7 = None + convert_element_type_25 = torch.ops.prims.convert_element_type.default(mul_7, torch.bfloat16); mul_7 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_26, 128, '0'); convert_element_type_26 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + slice_4 = torch.ops.aten.slice.Tensor(wait_tensor_8, 0, 0, 10944); wait_tensor_8 = None + permute_8 = torch.ops.aten.permute.default(slice_4, [1, 0]); slice_4 = None + view_26 = torch.ops.aten.view.default(convert_element_type_25, [8192, 2048]); convert_element_type_25 = None + mm_4 = torch.ops.aten.mm.default(view_26, permute_8); permute_8 = None + view_27 = torch.ops.aten.view.default(mm_4, [2, 4096, 10944]) + convert_element_type_29 = torch.ops.prims.convert_element_type.default(view_27, torch.float32); view_27 = None + neg = torch.ops.aten.neg.default(convert_element_type_29) + exp = torch.ops.aten.exp.default(neg); neg = None + add_4 = torch.ops.aten.add.Tensor(exp, 1); exp = None + div = torch.ops.aten.div.Tensor(convert_element_type_29, add_4); convert_element_type_29 = add_4 = None + convert_element_type_30 = torch.ops.prims.convert_element_type.default(div, torch.bfloat16); div = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 128, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + slice_5 = torch.ops.aten.slice.Tensor(wait_tensor_9, 0, 0, 10944); wait_tensor_9 = None + permute_9 = torch.ops.aten.permute.default(slice_5, [1, 0]); slice_5 = None + mm_5 = torch.ops.aten.mm.default(view_26, permute_9); permute_9 = None + view_30 = torch.ops.aten.view.default(mm_5, [2, 4096, 10944]) + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_30, view_30); convert_element_type_30 = view_30 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 128, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_10, [1, 0]); wait_tensor_10 = None + view_32 = torch.ops.aten.view.default(mul_8, [8192, 10944]); mul_8 = None + mm_6 = torch.ops.aten.mm.default(view_32, permute_10); permute_10 = None + view_33 = torch.ops.aten.view.default(mm_6, [2, 4096, 2048]); mm_6 = None + add_5 = torch.ops.aten.add.Tensor(add_2, view_33); add_2 = view_33 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 128, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + convert_element_type_38 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_38, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_9 = torch.ops.aten.mul.Tensor(convert_element_type_38, rsqrt_3); convert_element_type_38 = None + mul_10 = torch.ops.aten.mul.Tensor(mul_9, wait_tensor_11); mul_9 = wait_tensor_11 = None + convert_element_type_39 = torch.ops.prims.convert_element_type.default(mul_10, torch.bfloat16); mul_10 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 128, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + view_36 = torch.ops.aten.view.default(convert_element_type_39, [8192, 2048]); convert_element_type_39 = None + mm_7 = torch.ops.aten.mm.default(view_36, permute_11); permute_11 = None + view_37 = torch.ops.aten.view.default(mm_7, [2, 4096, 3072]); mm_7 = None + view_38 = torch.ops.aten.view.default(view_37, [2, 4096, -1, 192]); view_37 = None + split_with_sizes_3 = torch.ops.aten.split_with_sizes.default(view_38, [128, 64], -1); view_38 = None + getitem_9 = split_with_sizes_3[0] + getitem_10 = split_with_sizes_3[1]; split_with_sizes_3 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(getitem_10, torch.float32); getitem_10 = None + view_39 = torch.ops.aten.view.default(convert_element_type_43, [2, 4096, 16, -1, 2]); convert_element_type_43 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_39); view_39 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_7); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_41 = torch.ops.aten.view.default(view_as_real_2, [2, 4096, 16, 64]); view_as_real_2 = None + convert_element_type_44 = torch.ops.prims.convert_element_type.default(view_41, torch.bfloat16); view_41 = None + cat_2 = torch.ops.aten.cat.default([getitem_9, convert_element_type_44], -1); getitem_9 = convert_element_type_44 = None + convert_element_type_45 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_45, 128, '0'); convert_element_type_45 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + slice_7 = torch.ops.aten.slice.Tensor(wait_tensor_13, 0, 0, 576); wait_tensor_13 = None + permute_12 = torch.ops.aten.permute.default(slice_7, [1, 0]); slice_7 = None + mm_8 = torch.ops.aten.mm.default(view_36, permute_12); permute_12 = None + view_44 = torch.ops.aten.view.default(mm_8, [2, 4096, 576]); mm_8 = None + split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_44, [512, 64], -1); view_44 = None + getitem_11 = split_with_sizes_4[0] + getitem_12 = split_with_sizes_4[1]; split_with_sizes_4 = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(getitem_12, 2); getitem_12 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(unsqueeze_1, torch.float32); unsqueeze_1 = None + view_45 = torch.ops.aten.view.default(convert_element_type_48, [2, 4096, 1, -1, 2]); convert_element_type_48 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_45); view_45 = None + mul_12 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_7); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_12); mul_12 = None + view_47 = torch.ops.aten.view.default(view_as_real_3, [2, 4096, 1, 64]); view_as_real_3 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_47, torch.bfloat16); view_47 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 128, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + convert_element_type_51 = torch.ops.prims.convert_element_type.default(getitem_11, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_51, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_7 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_7); add_7 = None + mul_13 = torch.ops.aten.mul.Tensor(convert_element_type_51, rsqrt_4); convert_element_type_51 = None + mul_14 = torch.ops.aten.mul.Tensor(mul_13, wait_tensor_14); mul_13 = wait_tensor_14 = None + convert_element_type_52 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 128, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_15, [1, 0]); wait_tensor_15 = None + view_50 = torch.ops.aten.view.default(convert_element_type_52, [8192, 512]); convert_element_type_52 = None + mm_9 = torch.ops.aten.mm.default(view_50, permute_13); permute_13 = None + view_51 = torch.ops.aten.view.default(mm_9, [2, 4096, 4096]); mm_9 = None + view_52 = torch.ops.aten.view.default(view_51, [2, 4096, -1, 256]); view_51 = None + split_with_sizes_5 = torch.ops.aten.split_with_sizes.default(view_52, [128, 128], -1); view_52 = None + getitem_13 = split_with_sizes_5[0] + getitem_14 = split_with_sizes_5[1]; split_with_sizes_5 = None + expand_1 = torch.ops.aten.expand.default(convert_element_type_49, [-1, -1, 16, -1]); convert_element_type_49 = None + cat_3 = torch.ops.aten.cat.default([getitem_13, expand_1], -1); getitem_13 = expand_1 = None + permute_14 = torch.ops.aten.permute.default(cat_2, [0, 2, 1, 3]); cat_2 = None + permute_15 = torch.ops.aten.permute.default(cat_3, [0, 2, 1, 3]); cat_3 = None + permute_16 = torch.ops.aten.permute.default(getitem_14, [0, 2, 1, 3]); getitem_14 = None + sdpa_score1 = self.sdpa_score1 + sdpa_mask1 = self.sdpa_mask1 + flex_attention_1 = torch.ops.higher_order.flex_attention(permute_14, permute_15, permute_16, sdpa_score1, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask1), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score1 = sdpa_mask1 = None + getitem_15 = flex_attention_1[0] + getitem_16 = flex_attention_1[1]; flex_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_15, [0, 2, 1, 3]) + view_53 = torch.ops.aten.view.default(permute_17, [2, 4096, -1]); permute_17 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 128, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + view_55 = torch.ops.aten.view.default(view_53, [8192, 2048]); view_53 = None + mm_10 = torch.ops.aten.mm.default(view_55, permute_18); view_55 = permute_18 = None + view_56 = torch.ops.aten.view.default(mm_10, [2, 4096, 2048]); mm_10 = None + add_8 = torch.ops.aten.add.Tensor(add_5, view_56); view_56 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_59, 128, '0'); convert_element_type_59 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(add_8, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_60, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_9 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_9); add_9 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, rsqrt_5); convert_element_type_60 = None + mul_16 = torch.ops.aten.mul.Tensor(mul_15, wait_tensor_17); mul_15 = wait_tensor_17 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(mul_16, torch.bfloat16); mul_16 = None + view_58 = torch.ops.aten.view.default(convert_element_type_61, [-1, 2048]); convert_element_type_61 = None + convert_element_type_62 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_62, 128, '0'); convert_element_type_62 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + slice_9 = torch.ops.aten.slice.Tensor(wait_tensor_18, 0, 0, 64); wait_tensor_18 = None + permute_19 = torch.ops.aten.permute.default(slice_9, [1, 0]); slice_9 = None + mm_11 = torch.ops.aten.mm.default(view_58, permute_19); permute_19 = None + convert_element_type_65 = torch.ops.prims.convert_element_type.default(mm_11, torch.float32) + amax = torch.ops.aten.amax.default(convert_element_type_65, [1], True) + sub = torch.ops.aten.sub.Tensor(convert_element_type_65, amax); convert_element_type_65 = None + exp_1 = torch.ops.aten.exp.default(sub); sub = None + sum_1 = torch.ops.aten.sum.dim_IntList(exp_1, [1], True) + div_1 = torch.ops.aten.div.Tensor(exp_1, sum_1); exp_1 = None + add_10 = torch.ops.aten.add.Tensor(div_1, primals_30); primals_30 = None + topk = torch.ops.aten.topk.default(add_10, 6, -1, True, False); add_10 = None + getitem_19 = topk[1]; topk = None + gather = torch.ops.aten.gather.default(div_1, 1, getitem_19); div_1 = None + mul_17 = torch.ops.aten.mul.Tensor(gather, 1.0); gather = None + view_60 = torch.ops.aten.view.default(getitem_19, [-1]) + histc = torch.ops.aten.histc.default(view_60, 64, 0, 64) + add_11 = torch.ops.aten.add.Tensor(primals_32, histc) + sort = torch.ops.aten.sort.stable(view_60, stable = True); view_60 = None + getitem_21 = sort[1]; sort = None + div_2 = torch.ops.aten.div.Tensor_mode(getitem_21, 6, rounding_mode = 'floor') + index = torch.ops.aten.index.Tensor(view_58, [div_2]) + all_to_all_single = torch.ops._c10d_functional.all_to_all_single.default(histc, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_19); wait_tensor_19 = None + view_64 = torch.ops.aten.view.default(histc, [8, -1]); histc = None + sum_2 = torch.ops.aten.sum.dim_IntList(view_64, [1]); view_64 = None + device_put = torch.ops.prims.device_put.default(sum_2, device(type='cpu'), True); sum_2 = None + view_65 = torch.ops.aten.view.default(wait_tensor_20, [8, -1]) + sum_3 = torch.ops.aten.sum.dim_IntList(view_65, [1]) + device_put_1 = torch.ops.prims.device_put.default(sum_3, device(type='cpu')); sum_3 = None + select = torch.ops.aten.select.int(device_put, 0, 0) + _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select); select = None + ge = _local_scalar_dense >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None + select_1 = torch.ops.aten.select.int(device_put, 0, 1) + _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None + ge_1 = _local_scalar_dense_1 >= 0 + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None + select_2 = torch.ops.aten.select.int(device_put, 0, 2) + _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None + ge_2 = _local_scalar_dense_2 >= 0 + _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None + select_3 = torch.ops.aten.select.int(device_put, 0, 3) + _local_scalar_dense_3 = torch.ops.aten._local_scalar_dense.default(select_3); select_3 = None + ge_3 = _local_scalar_dense_3 >= 0 + _assert_scalar_3 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_3'"); ge_3 = _assert_scalar_3 = None + select_4 = torch.ops.aten.select.int(device_put, 0, 4) + _local_scalar_dense_4 = torch.ops.aten._local_scalar_dense.default(select_4); select_4 = None + ge_4 = _local_scalar_dense_4 >= 0 + _assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u4 >= 0 on node 'ge_4'"); ge_4 = _assert_scalar_4 = None + select_5 = torch.ops.aten.select.int(device_put, 0, 5) + _local_scalar_dense_5 = torch.ops.aten._local_scalar_dense.default(select_5); select_5 = None + ge_5 = _local_scalar_dense_5 >= 0 + _assert_scalar_5 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u5 >= 0 on node 'ge_5'"); ge_5 = _assert_scalar_5 = None + select_6 = torch.ops.aten.select.int(device_put, 0, 6) + _local_scalar_dense_6 = torch.ops.aten._local_scalar_dense.default(select_6); select_6 = None + ge_6 = _local_scalar_dense_6 >= 0 + _assert_scalar_6 = torch.ops.aten._assert_scalar.default(ge_6, "Runtime assertion failed for expression u6 >= 0 on node 'ge_6'"); ge_6 = _assert_scalar_6 = None + select_7 = torch.ops.aten.select.int(device_put, 0, 7); device_put = None + _local_scalar_dense_7 = torch.ops.aten._local_scalar_dense.default(select_7); select_7 = None + ge_7 = _local_scalar_dense_7 >= 0 + _assert_scalar_7 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u7 >= 0 on node 'ge_7'"); ge_7 = _assert_scalar_7 = None + select_8 = torch.ops.aten.select.int(device_put_1, 0, 0) + _local_scalar_dense_8 = torch.ops.aten._local_scalar_dense.default(select_8); select_8 = None + ge_8 = _local_scalar_dense_8 >= 0 + _assert_scalar_8 = torch.ops.aten._assert_scalar.default(ge_8, "Runtime assertion failed for expression u8 >= 0 on node 'ge_8'"); ge_8 = _assert_scalar_8 = None + select_9 = torch.ops.aten.select.int(device_put_1, 0, 1) + _local_scalar_dense_9 = torch.ops.aten._local_scalar_dense.default(select_9); select_9 = None + ge_9 = _local_scalar_dense_9 >= 0 + _assert_scalar_9 = torch.ops.aten._assert_scalar.default(ge_9, "Runtime assertion failed for expression u9 >= 0 on node 'ge_9'"); ge_9 = _assert_scalar_9 = None + select_10 = torch.ops.aten.select.int(device_put_1, 0, 2) + _local_scalar_dense_10 = torch.ops.aten._local_scalar_dense.default(select_10); select_10 = None + ge_10 = _local_scalar_dense_10 >= 0 + _assert_scalar_10 = torch.ops.aten._assert_scalar.default(ge_10, "Runtime assertion failed for expression u10 >= 0 on node 'ge_10'"); ge_10 = _assert_scalar_10 = None + select_11 = torch.ops.aten.select.int(device_put_1, 0, 3) + _local_scalar_dense_11 = torch.ops.aten._local_scalar_dense.default(select_11); select_11 = None + ge_11 = _local_scalar_dense_11 >= 0 + _assert_scalar_11 = torch.ops.aten._assert_scalar.default(ge_11, "Runtime assertion failed for expression u11 >= 0 on node 'ge_11'"); ge_11 = _assert_scalar_11 = None + select_12 = torch.ops.aten.select.int(device_put_1, 0, 4) + _local_scalar_dense_12 = torch.ops.aten._local_scalar_dense.default(select_12); select_12 = None + ge_12 = _local_scalar_dense_12 >= 0 + _assert_scalar_12 = torch.ops.aten._assert_scalar.default(ge_12, "Runtime assertion failed for expression u12 >= 0 on node 'ge_12'"); ge_12 = _assert_scalar_12 = None + select_13 = torch.ops.aten.select.int(device_put_1, 0, 5) + _local_scalar_dense_13 = torch.ops.aten._local_scalar_dense.default(select_13); select_13 = None + ge_13 = _local_scalar_dense_13 >= 0 + _assert_scalar_13 = torch.ops.aten._assert_scalar.default(ge_13, "Runtime assertion failed for expression u13 >= 0 on node 'ge_13'"); ge_13 = _assert_scalar_13 = None + select_14 = torch.ops.aten.select.int(device_put_1, 0, 6) + _local_scalar_dense_14 = torch.ops.aten._local_scalar_dense.default(select_14); select_14 = None + ge_14 = _local_scalar_dense_14 >= 0 + _assert_scalar_14 = torch.ops.aten._assert_scalar.default(ge_14, "Runtime assertion failed for expression u14 >= 0 on node 'ge_14'"); ge_14 = _assert_scalar_14 = None + select_15 = torch.ops.aten.select.int(device_put_1, 0, 7); device_put_1 = None + _local_scalar_dense_15 = torch.ops.aten._local_scalar_dense.default(select_15); select_15 = None + ge_15 = _local_scalar_dense_15 >= 0 + _assert_scalar_15 = torch.ops.aten._assert_scalar.default(ge_15, "Runtime assertion failed for expression u15 >= 0 on node 'ge_15'"); ge_15 = _assert_scalar_15 = None + all_to_all_single_1 = torch.ops._c10d_functional.all_to_all_single.default(index, [_local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15], [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7], '1033'); index = None + sym_size_int = torch.ops.aten.sym_size.int(all_to_all_single_1, 0) + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None + sym_sum = torch.sym_sum((_local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense_8, _local_scalar_dense_9)) + add_18 = sym_sum + 64; sym_sum = None + add_19 = add_18 + 8; add_18 = None + sub_3 = add_19 - 1; add_19 = None + floordiv = sub_3 // 8; sub_3 = None + mul_22 = floordiv * 8; floordiv = None + cumsum = torch.ops.aten.cumsum.default(wait_tensor_20, 0) + sub_4 = torch.ops.aten.sub.Tensor(cumsum, wait_tensor_20); cumsum = None + sum_4 = torch.ops.aten.sum.dim_IntList(view_65, [0]); view_65 = None + clamp_min = torch.ops.aten.clamp_min.default(sum_4, 8); sum_4 = None + add_20 = torch.ops.aten.add.Tensor(clamp_min, 8); clamp_min = None + sub_5 = torch.ops.aten.sub.Tensor(add_20, 1); add_20 = None + div_3 = torch.ops.aten.div.Tensor_mode(sub_5, 8, rounding_mode = 'floor'); sub_5 = None + mul_23 = torch.ops.aten.mul.Tensor(div_3, 8); div_3 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(mul_23, torch.int32); mul_23 = None + cumsum_1 = torch.ops.aten.cumsum.default(convert_element_type_68, 0) + sub_6 = torch.ops.aten.sub.Tensor(cumsum_1, convert_element_type_68); cumsum_1 = None + full_20 = torch.ops.aten.full.default([mul_22], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_22 = None + triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_20, 'start_index_values_ptr': sub_4, 'write_offsets_ptr': sub_6, 'output_ptr': full_20}, tensors_to_clone = ['output_ptr']); wait_tensor_20 = sub_4 = sub_6 = full_20 = None + getitem_22 = triton_kernel_wrapper_functional_proxy['output_ptr']; triton_kernel_wrapper_functional_proxy = None + full_default = torch.ops.aten.full.default([1, 2048], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + cat_4 = torch.ops.aten.cat.default([wait_tensor_21, full_default]); wait_tensor_21 = None + sym_size_int_1 = torch.ops.aten.sym_size.int(cat_4, 0) + sym_sum_1 = torch.sym_sum((1, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense_8, _local_scalar_dense_9)) + index_1 = torch.ops.aten.index.Tensor(cat_4, [getitem_22]); cat_4 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 16, '1025'); convert_element_type_70 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + split_1 = torch.ops.aten.split.Tensor(wait_tensor_22, 8); wait_tensor_22 = None + getitem_39 = split_1[0] + getitem_40 = split_1[1] + getitem_41 = split_1[2] + getitem_42 = split_1[3] + getitem_43 = split_1[4] + getitem_44 = split_1[5] + getitem_45 = split_1[6] + getitem_46 = split_1[7] + getitem_47 = split_1[8] + getitem_48 = split_1[9] + getitem_49 = split_1[10] + getitem_50 = split_1[11] + getitem_51 = split_1[12] + getitem_52 = split_1[13] + getitem_53 = split_1[14] + getitem_54 = split_1[15]; split_1 = None + cat_6 = torch.ops.aten.cat.default([getitem_39, getitem_40, getitem_41, getitem_42, getitem_43, getitem_44, getitem_45, getitem_46, getitem_47, getitem_48, getitem_49, getitem_50, getitem_51, getitem_52, getitem_53, getitem_54], 1); getitem_39 = getitem_40 = getitem_41 = getitem_42 = getitem_43 = getitem_44 = getitem_45 = getitem_46 = getitem_47 = getitem_48 = getitem_49 = getitem_50 = getitem_51 = getitem_52 = getitem_53 = getitem_54 = None + convert_element_type_72 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_72, 16, '1025'); convert_element_type_72 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + split_2 = torch.ops.aten.split.Tensor(wait_tensor_24, 8); wait_tensor_24 = None + getitem_55 = split_2[0] + getitem_56 = split_2[1] + getitem_57 = split_2[2] + getitem_58 = split_2[3] + getitem_59 = split_2[4] + getitem_60 = split_2[5] + getitem_61 = split_2[6] + getitem_62 = split_2[7] + getitem_63 = split_2[8] + getitem_64 = split_2[9] + getitem_65 = split_2[10] + getitem_66 = split_2[11] + getitem_67 = split_2[12] + getitem_68 = split_2[13] + getitem_69 = split_2[14] + getitem_70 = split_2[15]; split_2 = None + cat_7 = torch.ops.aten.cat.default([getitem_55, getitem_56, getitem_57, getitem_58, getitem_59, getitem_60, getitem_61, getitem_62, getitem_63, getitem_64, getitem_65, getitem_66, getitem_67, getitem_68, getitem_69, getitem_70], 1); getitem_55 = getitem_56 = getitem_57 = getitem_58 = getitem_59 = getitem_60 = getitem_61 = getitem_62 = getitem_63 = getitem_64 = getitem_65 = getitem_66 = getitem_67 = getitem_68 = getitem_69 = getitem_70 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 16, '1025'); convert_element_type_73 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + split_3 = torch.ops.aten.split.Tensor(wait_tensor_25, 8); wait_tensor_25 = None + getitem_71 = split_3[0] + getitem_72 = split_3[1] + getitem_73 = split_3[2] + getitem_74 = split_3[3] + getitem_75 = split_3[4] + getitem_76 = split_3[5] + getitem_77 = split_3[6] + getitem_78 = split_3[7] + getitem_79 = split_3[8] + getitem_80 = split_3[9] + getitem_81 = split_3[10] + getitem_82 = split_3[11] + getitem_83 = split_3[12] + getitem_84 = split_3[13] + getitem_85 = split_3[14] + getitem_86 = split_3[15]; split_3 = None + cat_8 = torch.ops.aten.cat.default([getitem_71, getitem_72, getitem_73, getitem_74, getitem_75, getitem_76, getitem_77, getitem_78, getitem_79, getitem_80, getitem_81, getitem_82, getitem_83, getitem_84, getitem_85, getitem_86], 1); getitem_71 = getitem_72 = getitem_73 = getitem_74 = getitem_75 = getitem_76 = getitem_77 = getitem_78 = getitem_79 = getitem_80 = getitem_81 = getitem_82 = getitem_83 = getitem_84 = getitem_85 = getitem_86 = None + cumsum_2 = torch.ops.aten.cumsum.default(convert_element_type_68, 0, dtype = torch.int32); convert_element_type_68 = None + permute_20 = torch.ops.aten.permute.default(cat_6, [0, 2, 1]); cat_6 = None + _grouped_mm = torch.ops.aten._grouped_mm.default(index_1, permute_20, cumsum_2) + convert_element_type_76 = torch.ops.prims.convert_element_type.default(_grouped_mm, torch.float32) + neg_1 = torch.ops.aten.neg.default(convert_element_type_76) + exp_2 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_32 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + div_4 = torch.ops.aten.div.Tensor(convert_element_type_76, add_32); convert_element_type_76 = add_32 = None + convert_element_type_77 = torch.ops.prims.convert_element_type.default(div_4, torch.bfloat16); div_4 = None + permute_21 = torch.ops.aten.permute.default(cat_8, [0, 2, 1]); cat_8 = None + _grouped_mm_1 = torch.ops.aten._grouped_mm.default(index_1, permute_21, cumsum_2) + mul_35 = torch.ops.aten.mul.Tensor(convert_element_type_77, _grouped_mm_1); convert_element_type_77 = None + permute_22 = torch.ops.aten.permute.default(cat_7, [0, 2, 1]); cat_7 = None + _grouped_mm_2 = torch.ops.aten._grouped_mm.default(mul_35, permute_22, cumsum_2) + empty = torch.ops.aten.empty.memory_format([sym_size_int_1, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put = torch.ops.aten.index_put.default(empty, [getitem_22], _grouped_mm_2); empty = _grouped_mm_2 = None + slice_11 = torch.ops.aten.slice.Tensor(index_put, 0, 0, -1); index_put = None + all_to_all_single_2 = torch.ops._c10d_functional.all_to_all_single.default(slice_11, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7], [_local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15], '1033'); slice_11 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_2); all_to_all_single_2 = None + convert_element_type_78 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_78, 128, '0'); convert_element_type_78 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + mm_12 = torch.ops.aten.mm.default(view_58, permute_23); permute_23 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(mm_12, torch.float32) + neg_2 = torch.ops.aten.neg.default(convert_element_type_81) + exp_3 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_68 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + div_5 = torch.ops.aten.div.Tensor(convert_element_type_81, add_68); convert_element_type_81 = add_68 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(div_5, torch.bfloat16); div_5 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 128, '0'); convert_element_type_83 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_13 = torch.ops.aten.mm.default(view_58, permute_24); permute_24 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_82, mm_13); convert_element_type_82 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 128, '0'); convert_element_type_86 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_25 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_14 = torch.ops.aten.mm.default(mul_55, permute_25); permute_25 = None + full_default_1 = torch.ops.aten.full.default([49152, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_1 = torch.ops.aten.index_put.default(full_default_1, [getitem_21], wait_tensor_28); wait_tensor_28 = None + view_98 = torch.ops.aten.view.default(mul_17, [-1, 1, 6]); mul_17 = None + view_99 = torch.ops.aten.view.default(index_put_1, [-1, 6, 2048]); index_put_1 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(view_99, torch.float32); view_99 = None + bmm = torch.ops.aten.bmm.default(view_98, convert_element_type_89) + convert_element_type_90 = torch.ops.prims.convert_element_type.default(bmm, torch.bfloat16); bmm = None + squeeze = torch.ops.aten.squeeze.dim(convert_element_type_90, 1); convert_element_type_90 = None + add_72 = torch.ops.aten.add.Tensor(mm_14, squeeze); mm_14 = squeeze = None + view_100 = torch.ops.aten.view.default(add_72, [2, 4096, 2048]); add_72 = None + add_73 = torch.ops.aten.add.Tensor(add_8, view_100); view_100 = None + convert_element_type_91 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_91, 128, '0'); convert_element_type_91 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_92, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_74 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_58 = torch.ops.aten.mul.Tensor(convert_element_type_92, rsqrt_6); convert_element_type_92 = None + mul_59 = torch.ops.aten.mul.Tensor(mul_58, wait_tensor_32); mul_58 = wait_tensor_32 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_59, torch.bfloat16); mul_59 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 128, '0'); convert_element_type_94 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_26 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + view_103 = torch.ops.aten.view.default(convert_element_type_93, [8192, 2048]); convert_element_type_93 = None + mm_15 = torch.ops.aten.mm.default(view_103, permute_26); permute_26 = None + view_104 = torch.ops.aten.view.default(mm_15, [2, 4096, 3072]); mm_15 = None + view_105 = torch.ops.aten.view.default(view_104, [2, 4096, -1, 192]); view_104 = None + split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_105, [128, 64], -1); view_105 = None + getitem_119 = split_with_sizes_6[0] + getitem_120 = split_with_sizes_6[1]; split_with_sizes_6 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(getitem_120, torch.float32); getitem_120 = None + view_106 = torch.ops.aten.view.default(convert_element_type_97, [2, 4096, 16, -1, 2]); convert_element_type_97 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_106); view_106 = None + mul_60 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_7); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_60); mul_60 = None + view_108 = torch.ops.aten.view.default(view_as_real_4, [2, 4096, 16, 64]); view_as_real_4 = None + convert_element_type_98 = torch.ops.prims.convert_element_type.default(view_108, torch.bfloat16); view_108 = None + cat_11 = torch.ops.aten.cat.default([getitem_119, convert_element_type_98], -1); getitem_119 = convert_element_type_98 = None + convert_element_type_99 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_99, 128, '0'); convert_element_type_99 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + slice_13 = torch.ops.aten.slice.Tensor(wait_tensor_34, 0, 0, 576); wait_tensor_34 = None + permute_27 = torch.ops.aten.permute.default(slice_13, [1, 0]); slice_13 = None + mm_16 = torch.ops.aten.mm.default(view_103, permute_27); permute_27 = None + view_111 = torch.ops.aten.view.default(mm_16, [2, 4096, 576]); mm_16 = None + split_with_sizes_7 = torch.ops.aten.split_with_sizes.default(view_111, [512, 64], -1); view_111 = None + getitem_121 = split_with_sizes_7[0] + getitem_122 = split_with_sizes_7[1]; split_with_sizes_7 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(getitem_122, 2); getitem_122 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(unsqueeze_3, torch.float32); unsqueeze_3 = None + view_112 = torch.ops.aten.view.default(convert_element_type_102, [2, 4096, 1, -1, 2]); convert_element_type_102 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_112); view_112 = None + mul_61 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_7); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_61); mul_61 = None + view_114 = torch.ops.aten.view.default(view_as_real_5, [2, 4096, 1, 64]); view_as_real_5 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(view_114, torch.bfloat16); view_114 = None + convert_element_type_104 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_104, 128, '0'); convert_element_type_104 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + convert_element_type_105 = torch.ops.prims.convert_element_type.default(getitem_121, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_105, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_75 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_75); add_75 = None + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_105, rsqrt_7); convert_element_type_105 = None + mul_63 = torch.ops.aten.mul.Tensor(mul_62, wait_tensor_35); mul_62 = wait_tensor_35 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(mul_63, torch.bfloat16); mul_63 = None + convert_element_type_107 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_107, 128, '0'); convert_element_type_107 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_28 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + view_117 = torch.ops.aten.view.default(convert_element_type_106, [8192, 512]); convert_element_type_106 = None + mm_17 = torch.ops.aten.mm.default(view_117, permute_28); permute_28 = None + view_118 = torch.ops.aten.view.default(mm_17, [2, 4096, 4096]); mm_17 = None + view_119 = torch.ops.aten.view.default(view_118, [2, 4096, -1, 256]); view_118 = None + split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(view_119, [128, 128], -1); view_119 = None + getitem_123 = split_with_sizes_8[0] + getitem_124 = split_with_sizes_8[1]; split_with_sizes_8 = None + expand_2 = torch.ops.aten.expand.default(convert_element_type_103, [-1, -1, 16, -1]); convert_element_type_103 = None + cat_12 = torch.ops.aten.cat.default([getitem_123, expand_2], -1); getitem_123 = expand_2 = None + permute_29 = torch.ops.aten.permute.default(cat_11, [0, 2, 1, 3]); cat_11 = None + permute_30 = torch.ops.aten.permute.default(cat_12, [0, 2, 1, 3]); cat_12 = None + permute_31 = torch.ops.aten.permute.default(getitem_124, [0, 2, 1, 3]); getitem_124 = None + sdpa_score2 = self.sdpa_score2 + sdpa_mask2 = self.sdpa_mask2 + flex_attention_2 = torch.ops.higher_order.flex_attention(permute_29, permute_30, permute_31, sdpa_score2, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask2), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score2 = sdpa_mask2 = None + getitem_125 = flex_attention_2[0] + getitem_126 = flex_attention_2[1]; flex_attention_2 = None + permute_32 = torch.ops.aten.permute.default(getitem_125, [0, 2, 1, 3]) + view_120 = torch.ops.aten.view.default(permute_32, [2, 4096, -1]); permute_32 = None + convert_element_type_110 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_110, 128, '0'); convert_element_type_110 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + view_122 = torch.ops.aten.view.default(view_120, [8192, 2048]); view_120 = None + mm_18 = torch.ops.aten.mm.default(view_122, permute_33); view_122 = permute_33 = None + view_123 = torch.ops.aten.view.default(mm_18, [2, 4096, 2048]); mm_18 = None + add_76 = torch.ops.aten.add.Tensor(add_73, view_123); view_123 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_113, 128, '0'); convert_element_type_113 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(add_76, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_114, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_77 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_77); add_77 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_114, rsqrt_8); convert_element_type_114 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_38); mul_64 = wait_tensor_38 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + view_125 = torch.ops.aten.view.default(convert_element_type_115, [-1, 2048]); convert_element_type_115 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 128, '0'); convert_element_type_116 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + slice_15 = torch.ops.aten.slice.Tensor(wait_tensor_39, 0, 0, 64); wait_tensor_39 = None + permute_34 = torch.ops.aten.permute.default(slice_15, [1, 0]); slice_15 = None + mm_19 = torch.ops.aten.mm.default(view_125, permute_34); permute_34 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(mm_19, torch.float32) + amax_1 = torch.ops.aten.amax.default(convert_element_type_119, [1], True) + sub_24 = torch.ops.aten.sub.Tensor(convert_element_type_119, amax_1); convert_element_type_119 = None + exp_4 = torch.ops.aten.exp.default(sub_24); sub_24 = None + sum_5 = torch.ops.aten.sum.dim_IntList(exp_4, [1], True) + div_6 = torch.ops.aten.div.Tensor(exp_4, sum_5); exp_4 = None + add_78 = torch.ops.aten.add.Tensor(div_6, primals_46); primals_46 = None + topk_1 = torch.ops.aten.topk.default(add_78, 6, -1, True, False); add_78 = None + getitem_129 = topk_1[1]; topk_1 = None + gather_1 = torch.ops.aten.gather.default(div_6, 1, getitem_129); div_6 = None + mul_66 = torch.ops.aten.mul.Tensor(gather_1, 1.0); gather_1 = None + view_127 = torch.ops.aten.view.default(getitem_129, [-1]) + histc_2 = torch.ops.aten.histc.default(view_127, 64, 0, 64) + add_79 = torch.ops.aten.add.Tensor(primals_48, histc_2) + sort_1 = torch.ops.aten.sort.stable(view_127, stable = True); view_127 = None + getitem_131 = sort_1[1]; sort_1 = None + div_7 = torch.ops.aten.div.Tensor_mode(getitem_131, 6, rounding_mode = 'floor') + index_2 = torch.ops.aten.index.Tensor(view_125, [div_7]) + all_to_all_single_3 = torch.ops._c10d_functional.all_to_all_single.default(histc_2, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_3); all_to_all_single_3 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_40); wait_tensor_40 = None + view_131 = torch.ops.aten.view.default(histc_2, [8, -1]); histc_2 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_131, [1]); view_131 = None + device_put_2 = torch.ops.prims.device_put.default(sum_6, device(type='cpu'), True); sum_6 = None + view_132 = torch.ops.aten.view.default(wait_tensor_41, [8, -1]) + sum_7 = torch.ops.aten.sum.dim_IntList(view_132, [1]) + device_put_3 = torch.ops.prims.device_put.default(sum_7, device(type='cpu')); sum_7 = None + select_16 = torch.ops.aten.select.int(device_put_2, 0, 0) + _local_scalar_dense_16 = torch.ops.aten._local_scalar_dense.default(select_16); select_16 = None + ge_20 = _local_scalar_dense_16 >= 0 + _assert_scalar_16 = torch.ops.aten._assert_scalar.default(ge_20, "Runtime assertion failed for expression u16 >= 0 on node 'ge_16'"); ge_20 = _assert_scalar_16 = None + select_17 = torch.ops.aten.select.int(device_put_2, 0, 1) + _local_scalar_dense_17 = torch.ops.aten._local_scalar_dense.default(select_17); select_17 = None + ge_21 = _local_scalar_dense_17 >= 0 + _assert_scalar_17 = torch.ops.aten._assert_scalar.default(ge_21, "Runtime assertion failed for expression u17 >= 0 on node 'ge_17'"); ge_21 = _assert_scalar_17 = None + select_18 = torch.ops.aten.select.int(device_put_2, 0, 2) + _local_scalar_dense_18 = torch.ops.aten._local_scalar_dense.default(select_18); select_18 = None + ge_22 = _local_scalar_dense_18 >= 0 + _assert_scalar_18 = torch.ops.aten._assert_scalar.default(ge_22, "Runtime assertion failed for expression u18 >= 0 on node 'ge_18'"); ge_22 = _assert_scalar_18 = None + select_19 = torch.ops.aten.select.int(device_put_2, 0, 3) + _local_scalar_dense_19 = torch.ops.aten._local_scalar_dense.default(select_19); select_19 = None + ge_23 = _local_scalar_dense_19 >= 0 + _assert_scalar_19 = torch.ops.aten._assert_scalar.default(ge_23, "Runtime assertion failed for expression u19 >= 0 on node 'ge_19'"); ge_23 = _assert_scalar_19 = None + select_20 = torch.ops.aten.select.int(device_put_2, 0, 4) + _local_scalar_dense_20 = torch.ops.aten._local_scalar_dense.default(select_20); select_20 = None + ge_24 = _local_scalar_dense_20 >= 0 + _assert_scalar_20 = torch.ops.aten._assert_scalar.default(ge_24, "Runtime assertion failed for expression u20 >= 0 on node 'ge_20'"); ge_24 = _assert_scalar_20 = None + select_21 = torch.ops.aten.select.int(device_put_2, 0, 5) + _local_scalar_dense_21 = torch.ops.aten._local_scalar_dense.default(select_21); select_21 = None + ge_25 = _local_scalar_dense_21 >= 0 + _assert_scalar_21 = torch.ops.aten._assert_scalar.default(ge_25, "Runtime assertion failed for expression u21 >= 0 on node 'ge_21'"); ge_25 = _assert_scalar_21 = None + select_22 = torch.ops.aten.select.int(device_put_2, 0, 6) + _local_scalar_dense_22 = torch.ops.aten._local_scalar_dense.default(select_22); select_22 = None + ge_26 = _local_scalar_dense_22 >= 0 + _assert_scalar_22 = torch.ops.aten._assert_scalar.default(ge_26, "Runtime assertion failed for expression u22 >= 0 on node 'ge_22'"); ge_26 = _assert_scalar_22 = None + select_23 = torch.ops.aten.select.int(device_put_2, 0, 7); device_put_2 = None + _local_scalar_dense_23 = torch.ops.aten._local_scalar_dense.default(select_23); select_23 = None + ge_27 = _local_scalar_dense_23 >= 0 + _assert_scalar_23 = torch.ops.aten._assert_scalar.default(ge_27, "Runtime assertion failed for expression u23 >= 0 on node 'ge_23'"); ge_27 = _assert_scalar_23 = None + select_24 = torch.ops.aten.select.int(device_put_3, 0, 0) + _local_scalar_dense_24 = torch.ops.aten._local_scalar_dense.default(select_24); select_24 = None + ge_28 = _local_scalar_dense_24 >= 0 + _assert_scalar_24 = torch.ops.aten._assert_scalar.default(ge_28, "Runtime assertion failed for expression u24 >= 0 on node 'ge_24'"); ge_28 = _assert_scalar_24 = None + select_25 = torch.ops.aten.select.int(device_put_3, 0, 1) + _local_scalar_dense_25 = torch.ops.aten._local_scalar_dense.default(select_25); select_25 = None + ge_29 = _local_scalar_dense_25 >= 0 + _assert_scalar_25 = torch.ops.aten._assert_scalar.default(ge_29, "Runtime assertion failed for expression u25 >= 0 on node 'ge_25'"); ge_29 = _assert_scalar_25 = None + select_26 = torch.ops.aten.select.int(device_put_3, 0, 2) + _local_scalar_dense_26 = torch.ops.aten._local_scalar_dense.default(select_26); select_26 = None + ge_30 = _local_scalar_dense_26 >= 0 + _assert_scalar_26 = torch.ops.aten._assert_scalar.default(ge_30, "Runtime assertion failed for expression u26 >= 0 on node 'ge_26'"); ge_30 = _assert_scalar_26 = None + select_27 = torch.ops.aten.select.int(device_put_3, 0, 3) + _local_scalar_dense_27 = torch.ops.aten._local_scalar_dense.default(select_27); select_27 = None + ge_31 = _local_scalar_dense_27 >= 0 + _assert_scalar_27 = torch.ops.aten._assert_scalar.default(ge_31, "Runtime assertion failed for expression u27 >= 0 on node 'ge_27'"); ge_31 = _assert_scalar_27 = None + select_28 = torch.ops.aten.select.int(device_put_3, 0, 4) + _local_scalar_dense_28 = torch.ops.aten._local_scalar_dense.default(select_28); select_28 = None + ge_32 = _local_scalar_dense_28 >= 0 + _assert_scalar_28 = torch.ops.aten._assert_scalar.default(ge_32, "Runtime assertion failed for expression u28 >= 0 on node 'ge_28'"); ge_32 = _assert_scalar_28 = None + select_29 = torch.ops.aten.select.int(device_put_3, 0, 5) + _local_scalar_dense_29 = torch.ops.aten._local_scalar_dense.default(select_29); select_29 = None + ge_33 = _local_scalar_dense_29 >= 0 + _assert_scalar_29 = torch.ops.aten._assert_scalar.default(ge_33, "Runtime assertion failed for expression u29 >= 0 on node 'ge_29'"); ge_33 = _assert_scalar_29 = None + select_30 = torch.ops.aten.select.int(device_put_3, 0, 6) + _local_scalar_dense_30 = torch.ops.aten._local_scalar_dense.default(select_30); select_30 = None + ge_34 = _local_scalar_dense_30 >= 0 + _assert_scalar_30 = torch.ops.aten._assert_scalar.default(ge_34, "Runtime assertion failed for expression u30 >= 0 on node 'ge_30'"); ge_34 = _assert_scalar_30 = None + select_31 = torch.ops.aten.select.int(device_put_3, 0, 7); device_put_3 = None + _local_scalar_dense_31 = torch.ops.aten._local_scalar_dense.default(select_31); select_31 = None + ge_35 = _local_scalar_dense_31 >= 0 + _assert_scalar_31 = torch.ops.aten._assert_scalar.default(ge_35, "Runtime assertion failed for expression u31 >= 0 on node 'ge_31'"); ge_35 = _assert_scalar_31 = None + all_to_all_single_4 = torch.ops._c10d_functional.all_to_all_single.default(index_2, [_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31], [_local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23], '1033'); index_2 = None + sym_size_int_4 = torch.ops.aten.sym_size.int(all_to_all_single_4, 0) + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_4); all_to_all_single_4 = None + sym_sum_2 = torch.sym_sum((_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31)) + add_86 = sym_sum_2 + 64; sym_sum_2 = None + add_87 = add_86 + 8; add_86 = None + sub_27 = add_87 - 1; add_87 = None + floordiv_1 = sub_27 // 8; sub_27 = None + mul_71 = floordiv_1 * 8; floordiv_1 = None + cumsum_3 = torch.ops.aten.cumsum.default(wait_tensor_41, 0) + sub_28 = torch.ops.aten.sub.Tensor(cumsum_3, wait_tensor_41); cumsum_3 = None + sum_8 = torch.ops.aten.sum.dim_IntList(view_132, [0]); view_132 = None + clamp_min_1 = torch.ops.aten.clamp_min.default(sum_8, 8); sum_8 = None + add_88 = torch.ops.aten.add.Tensor(clamp_min_1, 8); clamp_min_1 = None + sub_29 = torch.ops.aten.sub.Tensor(add_88, 1); add_88 = None + div_8 = torch.ops.aten.div.Tensor_mode(sub_29, 8, rounding_mode = 'floor'); sub_29 = None + mul_72 = torch.ops.aten.mul.Tensor(div_8, 8); div_8 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(mul_72, torch.int32); mul_72 = None + cumsum_4 = torch.ops.aten.cumsum.default(convert_element_type_122, 0) + sub_30 = torch.ops.aten.sub.Tensor(cumsum_4, convert_element_type_122); cumsum_4 = None + full_33 = torch.ops.aten.full.default([mul_71], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_71 = None + triton_kernel_wrapper_functional_proxy_1 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 1, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_41, 'start_index_values_ptr': sub_28, 'write_offsets_ptr': sub_30, 'output_ptr': full_33}, tensors_to_clone = ['output_ptr']); wait_tensor_41 = sub_28 = sub_30 = full_33 = None + getitem_132 = triton_kernel_wrapper_functional_proxy_1['output_ptr']; triton_kernel_wrapper_functional_proxy_1 = None + cat_13 = torch.ops.aten.cat.default([wait_tensor_42, full_default]); wait_tensor_42 = None + sym_size_int_5 = torch.ops.aten.sym_size.int(cat_13, 0) + sym_sum_3 = torch.sym_sum((1, _local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31)) + index_3 = torch.ops.aten.index.Tensor(cat_13, [getitem_132]); cat_13 = None + convert_element_type_124 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_124, 16, '1025'); convert_element_type_124 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + split_7 = torch.ops.aten.split.Tensor(wait_tensor_43, 8); wait_tensor_43 = None + getitem_149 = split_7[0] + getitem_150 = split_7[1] + getitem_151 = split_7[2] + getitem_152 = split_7[3] + getitem_153 = split_7[4] + getitem_154 = split_7[5] + getitem_155 = split_7[6] + getitem_156 = split_7[7] + getitem_157 = split_7[8] + getitem_158 = split_7[9] + getitem_159 = split_7[10] + getitem_160 = split_7[11] + getitem_161 = split_7[12] + getitem_162 = split_7[13] + getitem_163 = split_7[14] + getitem_164 = split_7[15]; split_7 = None + cat_15 = torch.ops.aten.cat.default([getitem_149, getitem_150, getitem_151, getitem_152, getitem_153, getitem_154, getitem_155, getitem_156, getitem_157, getitem_158, getitem_159, getitem_160, getitem_161, getitem_162, getitem_163, getitem_164], 1); getitem_149 = getitem_150 = getitem_151 = getitem_152 = getitem_153 = getitem_154 = getitem_155 = getitem_156 = getitem_157 = getitem_158 = getitem_159 = getitem_160 = getitem_161 = getitem_162 = getitem_163 = getitem_164 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_126, 16, '1025'); convert_element_type_126 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + split_8 = torch.ops.aten.split.Tensor(wait_tensor_45, 8); wait_tensor_45 = None + getitem_165 = split_8[0] + getitem_166 = split_8[1] + getitem_167 = split_8[2] + getitem_168 = split_8[3] + getitem_169 = split_8[4] + getitem_170 = split_8[5] + getitem_171 = split_8[6] + getitem_172 = split_8[7] + getitem_173 = split_8[8] + getitem_174 = split_8[9] + getitem_175 = split_8[10] + getitem_176 = split_8[11] + getitem_177 = split_8[12] + getitem_178 = split_8[13] + getitem_179 = split_8[14] + getitem_180 = split_8[15]; split_8 = None + cat_16 = torch.ops.aten.cat.default([getitem_165, getitem_166, getitem_167, getitem_168, getitem_169, getitem_170, getitem_171, getitem_172, getitem_173, getitem_174, getitem_175, getitem_176, getitem_177, getitem_178, getitem_179, getitem_180], 1); getitem_165 = getitem_166 = getitem_167 = getitem_168 = getitem_169 = getitem_170 = getitem_171 = getitem_172 = getitem_173 = getitem_174 = getitem_175 = getitem_176 = getitem_177 = getitem_178 = getitem_179 = getitem_180 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 16, '1025'); convert_element_type_127 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + split_9 = torch.ops.aten.split.Tensor(wait_tensor_46, 8); wait_tensor_46 = None + getitem_181 = split_9[0] + getitem_182 = split_9[1] + getitem_183 = split_9[2] + getitem_184 = split_9[3] + getitem_185 = split_9[4] + getitem_186 = split_9[5] + getitem_187 = split_9[6] + getitem_188 = split_9[7] + getitem_189 = split_9[8] + getitem_190 = split_9[9] + getitem_191 = split_9[10] + getitem_192 = split_9[11] + getitem_193 = split_9[12] + getitem_194 = split_9[13] + getitem_195 = split_9[14] + getitem_196 = split_9[15]; split_9 = None + cat_17 = torch.ops.aten.cat.default([getitem_181, getitem_182, getitem_183, getitem_184, getitem_185, getitem_186, getitem_187, getitem_188, getitem_189, getitem_190, getitem_191, getitem_192, getitem_193, getitem_194, getitem_195, getitem_196], 1); getitem_181 = getitem_182 = getitem_183 = getitem_184 = getitem_185 = getitem_186 = getitem_187 = getitem_188 = getitem_189 = getitem_190 = getitem_191 = getitem_192 = getitem_193 = getitem_194 = getitem_195 = getitem_196 = None + cumsum_5 = torch.ops.aten.cumsum.default(convert_element_type_122, 0, dtype = torch.int32); convert_element_type_122 = None + permute_35 = torch.ops.aten.permute.default(cat_15, [0, 2, 1]); cat_15 = None + _grouped_mm_3 = torch.ops.aten._grouped_mm.default(index_3, permute_35, cumsum_5) + convert_element_type_130 = torch.ops.prims.convert_element_type.default(_grouped_mm_3, torch.float32) + neg_3 = torch.ops.aten.neg.default(convert_element_type_130) + exp_5 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_100 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + div_9 = torch.ops.aten.div.Tensor(convert_element_type_130, add_100); convert_element_type_130 = add_100 = None + convert_element_type_131 = torch.ops.prims.convert_element_type.default(div_9, torch.bfloat16); div_9 = None + permute_36 = torch.ops.aten.permute.default(cat_17, [0, 2, 1]); cat_17 = None + _grouped_mm_4 = torch.ops.aten._grouped_mm.default(index_3, permute_36, cumsum_5) + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_131, _grouped_mm_4); convert_element_type_131 = None + permute_37 = torch.ops.aten.permute.default(cat_16, [0, 2, 1]); cat_16 = None + _grouped_mm_5 = torch.ops.aten._grouped_mm.default(mul_84, permute_37, cumsum_5) + empty_1 = torch.ops.aten.empty.memory_format([sym_size_int_5, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_2 = torch.ops.aten.index_put.default(empty_1, [getitem_132], _grouped_mm_5); empty_1 = _grouped_mm_5 = None + slice_17 = torch.ops.aten.slice.Tensor(index_put_2, 0, 0, -1); index_put_2 = None + all_to_all_single_5 = torch.ops._c10d_functional.all_to_all_single.default(slice_17, [_local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23], [_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31], '1033'); slice_17 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_5); all_to_all_single_5 = None + convert_element_type_132 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_132, 128, '0'); convert_element_type_132 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_38 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + mm_20 = torch.ops.aten.mm.default(view_125, permute_38); permute_38 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mm_20, torch.float32) + neg_4 = torch.ops.aten.neg.default(convert_element_type_135) + exp_6 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_136 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + div_10 = torch.ops.aten.div.Tensor(convert_element_type_135, add_136); convert_element_type_135 = add_136 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(div_10, torch.bfloat16); div_10 = None + convert_element_type_137 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_137, 128, '0'); convert_element_type_137 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_39 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + mm_21 = torch.ops.aten.mm.default(view_125, permute_39); permute_39 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_136, mm_21); convert_element_type_136 = None + convert_element_type_140 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_140, 128, '0'); convert_element_type_140 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + mm_22 = torch.ops.aten.mm.default(mul_104, permute_40); permute_40 = None + index_put_3 = torch.ops.aten.index_put.default(full_default_1, [getitem_131], wait_tensor_49); wait_tensor_49 = None + view_165 = torch.ops.aten.view.default(mul_66, [-1, 1, 6]); mul_66 = None + view_166 = torch.ops.aten.view.default(index_put_3, [-1, 6, 2048]); index_put_3 = None + convert_element_type_143 = torch.ops.prims.convert_element_type.default(view_166, torch.float32); view_166 = None + bmm_1 = torch.ops.aten.bmm.default(view_165, convert_element_type_143) + convert_element_type_144 = torch.ops.prims.convert_element_type.default(bmm_1, torch.bfloat16); bmm_1 = None + squeeze_1 = torch.ops.aten.squeeze.dim(convert_element_type_144, 1); convert_element_type_144 = None + add_140 = torch.ops.aten.add.Tensor(mm_22, squeeze_1); mm_22 = squeeze_1 = None + view_167 = torch.ops.aten.view.default(add_140, [2, 4096, 2048]); add_140 = None + add_141 = torch.ops.aten.add.Tensor(add_76, view_167); view_167 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_145, 128, '0'); convert_element_type_145 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(add_141, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_146, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_142 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_142); add_142 = None + mul_107 = torch.ops.aten.mul.Tensor(convert_element_type_146, rsqrt_9); convert_element_type_146 = None + mul_108 = torch.ops.aten.mul.Tensor(mul_107, wait_tensor_53); mul_107 = wait_tensor_53 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(mul_108, torch.bfloat16); mul_108 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_148, 128, '0'); convert_element_type_148 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + view_170 = torch.ops.aten.view.default(convert_element_type_147, [8192, 2048]); convert_element_type_147 = None + mm_23 = torch.ops.aten.mm.default(view_170, permute_41); permute_41 = None + view_171 = torch.ops.aten.view.default(mm_23, [2, 4096, 3072]); mm_23 = None + view_172 = torch.ops.aten.view.default(view_171, [2, 4096, -1, 192]); view_171 = None + split_with_sizes_9 = torch.ops.aten.split_with_sizes.default(view_172, [128, 64], -1); view_172 = None + getitem_229 = split_with_sizes_9[0] + getitem_230 = split_with_sizes_9[1]; split_with_sizes_9 = None + convert_element_type_151 = torch.ops.prims.convert_element_type.default(getitem_230, torch.float32); getitem_230 = None + view_173 = torch.ops.aten.view.default(convert_element_type_151, [2, 4096, 16, -1, 2]); convert_element_type_151 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_173); view_173 = None + mul_109 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_7); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_109); mul_109 = None + view_175 = torch.ops.aten.view.default(view_as_real_6, [2, 4096, 16, 64]); view_as_real_6 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(view_175, torch.bfloat16); view_175 = None + cat_20 = torch.ops.aten.cat.default([getitem_229, convert_element_type_152], -1); getitem_229 = convert_element_type_152 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_153, 128, '0'); convert_element_type_153 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + slice_19 = torch.ops.aten.slice.Tensor(wait_tensor_55, 0, 0, 576); wait_tensor_55 = None + permute_42 = torch.ops.aten.permute.default(slice_19, [1, 0]); slice_19 = None + mm_24 = torch.ops.aten.mm.default(view_170, permute_42); permute_42 = None + view_178 = torch.ops.aten.view.default(mm_24, [2, 4096, 576]); mm_24 = None + split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_178, [512, 64], -1); view_178 = None + getitem_231 = split_with_sizes_10[0] + getitem_232 = split_with_sizes_10[1]; split_with_sizes_10 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(getitem_232, 2); getitem_232 = None + convert_element_type_156 = torch.ops.prims.convert_element_type.default(unsqueeze_5, torch.float32); unsqueeze_5 = None + view_179 = torch.ops.aten.view.default(convert_element_type_156, [2, 4096, 1, -1, 2]); convert_element_type_156 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + mul_110 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_7); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_110); mul_110 = None + view_181 = torch.ops.aten.view.default(view_as_real_7, [2, 4096, 1, 64]); view_as_real_7 = None + convert_element_type_157 = torch.ops.prims.convert_element_type.default(view_181, torch.bfloat16); view_181 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_158, 128, '0'); convert_element_type_158 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(getitem_231, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_159, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_143 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_143); add_143 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_159, rsqrt_10); convert_element_type_159 = None + mul_112 = torch.ops.aten.mul.Tensor(mul_111, wait_tensor_56); mul_111 = wait_tensor_56 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(mul_112, torch.bfloat16); mul_112 = None + convert_element_type_161 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_161, 128, '0'); convert_element_type_161 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + view_184 = torch.ops.aten.view.default(convert_element_type_160, [8192, 512]); convert_element_type_160 = None + mm_25 = torch.ops.aten.mm.default(view_184, permute_43); permute_43 = None + view_185 = torch.ops.aten.view.default(mm_25, [2, 4096, 4096]); mm_25 = None + view_186 = torch.ops.aten.view.default(view_185, [2, 4096, -1, 256]); view_185 = None + split_with_sizes_11 = torch.ops.aten.split_with_sizes.default(view_186, [128, 128], -1); view_186 = None + getitem_233 = split_with_sizes_11[0] + getitem_234 = split_with_sizes_11[1]; split_with_sizes_11 = None + expand_3 = torch.ops.aten.expand.default(convert_element_type_157, [-1, -1, 16, -1]); convert_element_type_157 = None + cat_21 = torch.ops.aten.cat.default([getitem_233, expand_3], -1); getitem_233 = expand_3 = None + permute_44 = torch.ops.aten.permute.default(cat_20, [0, 2, 1, 3]); cat_20 = None + permute_45 = torch.ops.aten.permute.default(cat_21, [0, 2, 1, 3]); cat_21 = None + permute_46 = torch.ops.aten.permute.default(getitem_234, [0, 2, 1, 3]); getitem_234 = None + sdpa_score3 = self.sdpa_score3 + sdpa_mask3 = self.sdpa_mask3 + flex_attention_3 = torch.ops.higher_order.flex_attention(permute_44, permute_45, permute_46, sdpa_score3, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask3), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score3 = sdpa_mask3 = None + getitem_235 = flex_attention_3[0] + getitem_236 = flex_attention_3[1]; flex_attention_3 = None + permute_47 = torch.ops.aten.permute.default(getitem_235, [0, 2, 1, 3]) + view_187 = torch.ops.aten.view.default(permute_47, [2, 4096, -1]); permute_47 = None + convert_element_type_164 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_164, 128, '0'); convert_element_type_164 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_48 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + view_189 = torch.ops.aten.view.default(view_187, [8192, 2048]); view_187 = None + mm_26 = torch.ops.aten.mm.default(view_189, permute_48); view_189 = permute_48 = None + view_190 = torch.ops.aten.view.default(mm_26, [2, 4096, 2048]); mm_26 = None + add_144 = torch.ops.aten.add.Tensor(add_141, view_190); view_190 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_167, 128, '0'); convert_element_type_167 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(add_144, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_168, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_145 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_145); add_145 = None + mul_113 = torch.ops.aten.mul.Tensor(convert_element_type_168, rsqrt_11); convert_element_type_168 = None + mul_114 = torch.ops.aten.mul.Tensor(mul_113, wait_tensor_59); mul_113 = wait_tensor_59 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(mul_114, torch.bfloat16); mul_114 = None + view_192 = torch.ops.aten.view.default(convert_element_type_169, [-1, 2048]); convert_element_type_169 = None + convert_element_type_170 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_170, 128, '0'); convert_element_type_170 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + slice_21 = torch.ops.aten.slice.Tensor(wait_tensor_60, 0, 0, 64); wait_tensor_60 = None + permute_49 = torch.ops.aten.permute.default(slice_21, [1, 0]); slice_21 = None + mm_27 = torch.ops.aten.mm.default(view_192, permute_49); permute_49 = None + convert_element_type_173 = torch.ops.prims.convert_element_type.default(mm_27, torch.float32) + amax_2 = torch.ops.aten.amax.default(convert_element_type_173, [1], True) + sub_48 = torch.ops.aten.sub.Tensor(convert_element_type_173, amax_2); convert_element_type_173 = None + exp_7 = torch.ops.aten.exp.default(sub_48); sub_48 = None + sum_9 = torch.ops.aten.sum.dim_IntList(exp_7, [1], True) + div_11 = torch.ops.aten.div.Tensor(exp_7, sum_9); exp_7 = None + add_146 = torch.ops.aten.add.Tensor(div_11, primals_62); primals_62 = None + topk_2 = torch.ops.aten.topk.default(add_146, 6, -1, True, False); add_146 = None + getitem_239 = topk_2[1]; topk_2 = None + gather_2 = torch.ops.aten.gather.default(div_11, 1, getitem_239); div_11 = None + mul_115 = torch.ops.aten.mul.Tensor(gather_2, 1.0); gather_2 = None + view_194 = torch.ops.aten.view.default(getitem_239, [-1]) + histc_4 = torch.ops.aten.histc.default(view_194, 64, 0, 64) + add_147 = torch.ops.aten.add.Tensor(primals_64, histc_4) + sort_2 = torch.ops.aten.sort.stable(view_194, stable = True); view_194 = None + getitem_241 = sort_2[1]; sort_2 = None + div_12 = torch.ops.aten.div.Tensor_mode(getitem_241, 6, rounding_mode = 'floor') + index_4 = torch.ops.aten.index.Tensor(view_192, [div_12]) + all_to_all_single_6 = torch.ops._c10d_functional.all_to_all_single.default(histc_4, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_6); all_to_all_single_6 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_61); wait_tensor_61 = None + view_198 = torch.ops.aten.view.default(histc_4, [8, -1]); histc_4 = None + sum_10 = torch.ops.aten.sum.dim_IntList(view_198, [1]); view_198 = None + device_put_4 = torch.ops.prims.device_put.default(sum_10, device(type='cpu'), True); sum_10 = None + view_199 = torch.ops.aten.view.default(wait_tensor_62, [8, -1]) + sum_11 = torch.ops.aten.sum.dim_IntList(view_199, [1]) + device_put_5 = torch.ops.prims.device_put.default(sum_11, device(type='cpu')); sum_11 = None + select_32 = torch.ops.aten.select.int(device_put_4, 0, 0) + _local_scalar_dense_32 = torch.ops.aten._local_scalar_dense.default(select_32); select_32 = None + ge_40 = _local_scalar_dense_32 >= 0 + _assert_scalar_32 = torch.ops.aten._assert_scalar.default(ge_40, "Runtime assertion failed for expression u32 >= 0 on node 'ge_32'"); ge_40 = _assert_scalar_32 = None + select_33 = torch.ops.aten.select.int(device_put_4, 0, 1) + _local_scalar_dense_33 = torch.ops.aten._local_scalar_dense.default(select_33); select_33 = None + ge_41 = _local_scalar_dense_33 >= 0 + _assert_scalar_33 = torch.ops.aten._assert_scalar.default(ge_41, "Runtime assertion failed for expression u33 >= 0 on node 'ge_33'"); ge_41 = _assert_scalar_33 = None + select_34 = torch.ops.aten.select.int(device_put_4, 0, 2) + _local_scalar_dense_34 = torch.ops.aten._local_scalar_dense.default(select_34); select_34 = None + ge_42 = _local_scalar_dense_34 >= 0 + _assert_scalar_34 = torch.ops.aten._assert_scalar.default(ge_42, "Runtime assertion failed for expression u34 >= 0 on node 'ge_34'"); ge_42 = _assert_scalar_34 = None + select_35 = torch.ops.aten.select.int(device_put_4, 0, 3) + _local_scalar_dense_35 = torch.ops.aten._local_scalar_dense.default(select_35); select_35 = None + ge_43 = _local_scalar_dense_35 >= 0 + _assert_scalar_35 = torch.ops.aten._assert_scalar.default(ge_43, "Runtime assertion failed for expression u35 >= 0 on node 'ge_35'"); ge_43 = _assert_scalar_35 = None + select_36 = torch.ops.aten.select.int(device_put_4, 0, 4) + _local_scalar_dense_36 = torch.ops.aten._local_scalar_dense.default(select_36); select_36 = None + ge_44 = _local_scalar_dense_36 >= 0 + _assert_scalar_36 = torch.ops.aten._assert_scalar.default(ge_44, "Runtime assertion failed for expression u36 >= 0 on node 'ge_36'"); ge_44 = _assert_scalar_36 = None + select_37 = torch.ops.aten.select.int(device_put_4, 0, 5) + _local_scalar_dense_37 = torch.ops.aten._local_scalar_dense.default(select_37); select_37 = None + ge_45 = _local_scalar_dense_37 >= 0 + _assert_scalar_37 = torch.ops.aten._assert_scalar.default(ge_45, "Runtime assertion failed for expression u37 >= 0 on node 'ge_37'"); ge_45 = _assert_scalar_37 = None + select_38 = torch.ops.aten.select.int(device_put_4, 0, 6) + _local_scalar_dense_38 = torch.ops.aten._local_scalar_dense.default(select_38); select_38 = None + ge_46 = _local_scalar_dense_38 >= 0 + _assert_scalar_38 = torch.ops.aten._assert_scalar.default(ge_46, "Runtime assertion failed for expression u38 >= 0 on node 'ge_38'"); ge_46 = _assert_scalar_38 = None + select_39 = torch.ops.aten.select.int(device_put_4, 0, 7); device_put_4 = None + _local_scalar_dense_39 = torch.ops.aten._local_scalar_dense.default(select_39); select_39 = None + ge_47 = _local_scalar_dense_39 >= 0 + _assert_scalar_39 = torch.ops.aten._assert_scalar.default(ge_47, "Runtime assertion failed for expression u39 >= 0 on node 'ge_39'"); ge_47 = _assert_scalar_39 = None + select_40 = torch.ops.aten.select.int(device_put_5, 0, 0) + _local_scalar_dense_40 = torch.ops.aten._local_scalar_dense.default(select_40); select_40 = None + ge_48 = _local_scalar_dense_40 >= 0 + _assert_scalar_40 = torch.ops.aten._assert_scalar.default(ge_48, "Runtime assertion failed for expression u40 >= 0 on node 'ge_40'"); ge_48 = _assert_scalar_40 = None + select_41 = torch.ops.aten.select.int(device_put_5, 0, 1) + _local_scalar_dense_41 = torch.ops.aten._local_scalar_dense.default(select_41); select_41 = None + ge_49 = _local_scalar_dense_41 >= 0 + _assert_scalar_41 = torch.ops.aten._assert_scalar.default(ge_49, "Runtime assertion failed for expression u41 >= 0 on node 'ge_41'"); ge_49 = _assert_scalar_41 = None + select_42 = torch.ops.aten.select.int(device_put_5, 0, 2) + _local_scalar_dense_42 = torch.ops.aten._local_scalar_dense.default(select_42); select_42 = None + ge_50 = _local_scalar_dense_42 >= 0 + _assert_scalar_42 = torch.ops.aten._assert_scalar.default(ge_50, "Runtime assertion failed for expression u42 >= 0 on node 'ge_42'"); ge_50 = _assert_scalar_42 = None + select_43 = torch.ops.aten.select.int(device_put_5, 0, 3) + _local_scalar_dense_43 = torch.ops.aten._local_scalar_dense.default(select_43); select_43 = None + ge_51 = _local_scalar_dense_43 >= 0 + _assert_scalar_43 = torch.ops.aten._assert_scalar.default(ge_51, "Runtime assertion failed for expression u43 >= 0 on node 'ge_43'"); ge_51 = _assert_scalar_43 = None + select_44 = torch.ops.aten.select.int(device_put_5, 0, 4) + _local_scalar_dense_44 = torch.ops.aten._local_scalar_dense.default(select_44); select_44 = None + ge_52 = _local_scalar_dense_44 >= 0 + _assert_scalar_44 = torch.ops.aten._assert_scalar.default(ge_52, "Runtime assertion failed for expression u44 >= 0 on node 'ge_44'"); ge_52 = _assert_scalar_44 = None + select_45 = torch.ops.aten.select.int(device_put_5, 0, 5) + _local_scalar_dense_45 = torch.ops.aten._local_scalar_dense.default(select_45); select_45 = None + ge_53 = _local_scalar_dense_45 >= 0 + _assert_scalar_45 = torch.ops.aten._assert_scalar.default(ge_53, "Runtime assertion failed for expression u45 >= 0 on node 'ge_45'"); ge_53 = _assert_scalar_45 = None + select_46 = torch.ops.aten.select.int(device_put_5, 0, 6) + _local_scalar_dense_46 = torch.ops.aten._local_scalar_dense.default(select_46); select_46 = None + ge_54 = _local_scalar_dense_46 >= 0 + _assert_scalar_46 = torch.ops.aten._assert_scalar.default(ge_54, "Runtime assertion failed for expression u46 >= 0 on node 'ge_46'"); ge_54 = _assert_scalar_46 = None + select_47 = torch.ops.aten.select.int(device_put_5, 0, 7); device_put_5 = None + _local_scalar_dense_47 = torch.ops.aten._local_scalar_dense.default(select_47); select_47 = None + ge_55 = _local_scalar_dense_47 >= 0 + _assert_scalar_47 = torch.ops.aten._assert_scalar.default(ge_55, "Runtime assertion failed for expression u47 >= 0 on node 'ge_47'"); ge_55 = _assert_scalar_47 = None + all_to_all_single_7 = torch.ops._c10d_functional.all_to_all_single.default(index_4, [_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47], [_local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39], '1033'); index_4 = None + sym_size_int_8 = torch.ops.aten.sym_size.int(all_to_all_single_7, 0) + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_7); all_to_all_single_7 = None + sym_sum_4 = torch.sym_sum((_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47)) + add_154 = sym_sum_4 + 64; sym_sum_4 = None + add_155 = add_154 + 8; add_154 = None + sub_51 = add_155 - 1; add_155 = None + floordiv_2 = sub_51 // 8; sub_51 = None + mul_120 = floordiv_2 * 8; floordiv_2 = None + cumsum_6 = torch.ops.aten.cumsum.default(wait_tensor_62, 0) + sub_52 = torch.ops.aten.sub.Tensor(cumsum_6, wait_tensor_62); cumsum_6 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_199, [0]); view_199 = None + clamp_min_2 = torch.ops.aten.clamp_min.default(sum_12, 8); sum_12 = None + add_156 = torch.ops.aten.add.Tensor(clamp_min_2, 8); clamp_min_2 = None + sub_53 = torch.ops.aten.sub.Tensor(add_156, 1); add_156 = None + div_13 = torch.ops.aten.div.Tensor_mode(sub_53, 8, rounding_mode = 'floor'); sub_53 = None + mul_121 = torch.ops.aten.mul.Tensor(div_13, 8); div_13 = None + convert_element_type_176 = torch.ops.prims.convert_element_type.default(mul_121, torch.int32); mul_121 = None + cumsum_7 = torch.ops.aten.cumsum.default(convert_element_type_176, 0) + sub_54 = torch.ops.aten.sub.Tensor(cumsum_7, convert_element_type_176); cumsum_7 = None + full_46 = torch.ops.aten.full.default([mul_120], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_120 = None + triton_kernel_wrapper_functional_proxy_2 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 2, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_62, 'start_index_values_ptr': sub_52, 'write_offsets_ptr': sub_54, 'output_ptr': full_46}, tensors_to_clone = ['output_ptr']); wait_tensor_62 = sub_52 = sub_54 = full_46 = None + getitem_242 = triton_kernel_wrapper_functional_proxy_2['output_ptr']; triton_kernel_wrapper_functional_proxy_2 = None + cat_22 = torch.ops.aten.cat.default([wait_tensor_63, full_default]); wait_tensor_63 = None + sym_size_int_9 = torch.ops.aten.sym_size.int(cat_22, 0) + sym_sum_5 = torch.sym_sum((1, _local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47)) + index_5 = torch.ops.aten.index.Tensor(cat_22, [getitem_242]); cat_22 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_178, 16, '1025'); convert_element_type_178 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + split_13 = torch.ops.aten.split.Tensor(wait_tensor_64, 8); wait_tensor_64 = None + getitem_259 = split_13[0] + getitem_260 = split_13[1] + getitem_261 = split_13[2] + getitem_262 = split_13[3] + getitem_263 = split_13[4] + getitem_264 = split_13[5] + getitem_265 = split_13[6] + getitem_266 = split_13[7] + getitem_267 = split_13[8] + getitem_268 = split_13[9] + getitem_269 = split_13[10] + getitem_270 = split_13[11] + getitem_271 = split_13[12] + getitem_272 = split_13[13] + getitem_273 = split_13[14] + getitem_274 = split_13[15]; split_13 = None + cat_24 = torch.ops.aten.cat.default([getitem_259, getitem_260, getitem_261, getitem_262, getitem_263, getitem_264, getitem_265, getitem_266, getitem_267, getitem_268, getitem_269, getitem_270, getitem_271, getitem_272, getitem_273, getitem_274], 1); getitem_259 = getitem_260 = getitem_261 = getitem_262 = getitem_263 = getitem_264 = getitem_265 = getitem_266 = getitem_267 = getitem_268 = getitem_269 = getitem_270 = getitem_271 = getitem_272 = getitem_273 = getitem_274 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_180, 16, '1025'); convert_element_type_180 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + split_14 = torch.ops.aten.split.Tensor(wait_tensor_66, 8); wait_tensor_66 = None + getitem_275 = split_14[0] + getitem_276 = split_14[1] + getitem_277 = split_14[2] + getitem_278 = split_14[3] + getitem_279 = split_14[4] + getitem_280 = split_14[5] + getitem_281 = split_14[6] + getitem_282 = split_14[7] + getitem_283 = split_14[8] + getitem_284 = split_14[9] + getitem_285 = split_14[10] + getitem_286 = split_14[11] + getitem_287 = split_14[12] + getitem_288 = split_14[13] + getitem_289 = split_14[14] + getitem_290 = split_14[15]; split_14 = None + cat_25 = torch.ops.aten.cat.default([getitem_275, getitem_276, getitem_277, getitem_278, getitem_279, getitem_280, getitem_281, getitem_282, getitem_283, getitem_284, getitem_285, getitem_286, getitem_287, getitem_288, getitem_289, getitem_290], 1); getitem_275 = getitem_276 = getitem_277 = getitem_278 = getitem_279 = getitem_280 = getitem_281 = getitem_282 = getitem_283 = getitem_284 = getitem_285 = getitem_286 = getitem_287 = getitem_288 = getitem_289 = getitem_290 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_181, 16, '1025'); convert_element_type_181 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + split_15 = torch.ops.aten.split.Tensor(wait_tensor_67, 8); wait_tensor_67 = None + getitem_291 = split_15[0] + getitem_292 = split_15[1] + getitem_293 = split_15[2] + getitem_294 = split_15[3] + getitem_295 = split_15[4] + getitem_296 = split_15[5] + getitem_297 = split_15[6] + getitem_298 = split_15[7] + getitem_299 = split_15[8] + getitem_300 = split_15[9] + getitem_301 = split_15[10] + getitem_302 = split_15[11] + getitem_303 = split_15[12] + getitem_304 = split_15[13] + getitem_305 = split_15[14] + getitem_306 = split_15[15]; split_15 = None + cat_26 = torch.ops.aten.cat.default([getitem_291, getitem_292, getitem_293, getitem_294, getitem_295, getitem_296, getitem_297, getitem_298, getitem_299, getitem_300, getitem_301, getitem_302, getitem_303, getitem_304, getitem_305, getitem_306], 1); getitem_291 = getitem_292 = getitem_293 = getitem_294 = getitem_295 = getitem_296 = getitem_297 = getitem_298 = getitem_299 = getitem_300 = getitem_301 = getitem_302 = getitem_303 = getitem_304 = getitem_305 = getitem_306 = None + cumsum_8 = torch.ops.aten.cumsum.default(convert_element_type_176, 0, dtype = torch.int32); convert_element_type_176 = None + permute_50 = torch.ops.aten.permute.default(cat_24, [0, 2, 1]); cat_24 = None + _grouped_mm_6 = torch.ops.aten._grouped_mm.default(index_5, permute_50, cumsum_8) + convert_element_type_184 = torch.ops.prims.convert_element_type.default(_grouped_mm_6, torch.float32) + neg_5 = torch.ops.aten.neg.default(convert_element_type_184) + exp_8 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_168 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + div_14 = torch.ops.aten.div.Tensor(convert_element_type_184, add_168); convert_element_type_184 = add_168 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(div_14, torch.bfloat16); div_14 = None + permute_51 = torch.ops.aten.permute.default(cat_26, [0, 2, 1]); cat_26 = None + _grouped_mm_7 = torch.ops.aten._grouped_mm.default(index_5, permute_51, cumsum_8) + mul_133 = torch.ops.aten.mul.Tensor(convert_element_type_185, _grouped_mm_7); convert_element_type_185 = None + permute_52 = torch.ops.aten.permute.default(cat_25, [0, 2, 1]); cat_25 = None + _grouped_mm_8 = torch.ops.aten._grouped_mm.default(mul_133, permute_52, cumsum_8) + empty_2 = torch.ops.aten.empty.memory_format([sym_size_int_9, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_4 = torch.ops.aten.index_put.default(empty_2, [getitem_242], _grouped_mm_8); empty_2 = _grouped_mm_8 = None + slice_23 = torch.ops.aten.slice.Tensor(index_put_4, 0, 0, -1); index_put_4 = None + all_to_all_single_8 = torch.ops._c10d_functional.all_to_all_single.default(slice_23, [_local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39], [_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47], '1033'); slice_23 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_8); all_to_all_single_8 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_186, 128, '0'); convert_element_type_186 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_28 = torch.ops.aten.mm.default(view_192, permute_53); permute_53 = None + convert_element_type_189 = torch.ops.prims.convert_element_type.default(mm_28, torch.float32) + neg_6 = torch.ops.aten.neg.default(convert_element_type_189) + exp_9 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_204 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + div_15 = torch.ops.aten.div.Tensor(convert_element_type_189, add_204); convert_element_type_189 = add_204 = None + convert_element_type_190 = torch.ops.prims.convert_element_type.default(div_15, torch.bfloat16); div_15 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_191, 128, '0'); convert_element_type_191 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + mm_29 = torch.ops.aten.mm.default(view_192, permute_54); permute_54 = None + mul_153 = torch.ops.aten.mul.Tensor(convert_element_type_190, mm_29); convert_element_type_190 = None + convert_element_type_194 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_194, 128, '0'); convert_element_type_194 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_73, [1, 0]); wait_tensor_73 = None + mm_30 = torch.ops.aten.mm.default(mul_153, permute_55); permute_55 = None + index_put_5 = torch.ops.aten.index_put.default(full_default_1, [getitem_241], wait_tensor_70); wait_tensor_70 = None + view_232 = torch.ops.aten.view.default(mul_115, [-1, 1, 6]); mul_115 = None + view_233 = torch.ops.aten.view.default(index_put_5, [-1, 6, 2048]); index_put_5 = None + convert_element_type_197 = torch.ops.prims.convert_element_type.default(view_233, torch.float32); view_233 = None + bmm_2 = torch.ops.aten.bmm.default(view_232, convert_element_type_197) + convert_element_type_198 = torch.ops.prims.convert_element_type.default(bmm_2, torch.bfloat16); bmm_2 = None + squeeze_2 = torch.ops.aten.squeeze.dim(convert_element_type_198, 1); convert_element_type_198 = None + add_208 = torch.ops.aten.add.Tensor(mm_30, squeeze_2); mm_30 = squeeze_2 = None + view_234 = torch.ops.aten.view.default(add_208, [2, 4096, 2048]); add_208 = None + add_209 = torch.ops.aten.add.Tensor(add_144, view_234); view_234 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 128, '0'); convert_element_type_199 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_209, torch.float32) + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_210 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_210); add_210 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_74); mul_156 = wait_tensor_74 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 128, '0'); convert_element_type_202 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + view_237 = torch.ops.aten.view.default(convert_element_type_201, [8192, 2048]); convert_element_type_201 = None + mm_31 = torch.ops.aten.mm.default(view_237, permute_56); permute_56 = None + view_238 = torch.ops.aten.view.default(mm_31, [2, 4096, 3072]); mm_31 = None + view_239 = torch.ops.aten.view.default(view_238, [2, 4096, -1, 192]); view_238 = None + split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(view_239, [128, 64], -1); view_239 = None + getitem_339 = split_with_sizes_12[0] + getitem_340 = split_with_sizes_12[1]; split_with_sizes_12 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(getitem_340, torch.float32); getitem_340 = None + view_240 = torch.ops.aten.view.default(convert_element_type_205, [2, 4096, 16, -1, 2]); convert_element_type_205 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_240); view_240 = None + mul_158 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_7); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_158); mul_158 = None + view_242 = torch.ops.aten.view.default(view_as_real_8, [2, 4096, 16, 64]); view_as_real_8 = None + convert_element_type_206 = torch.ops.prims.convert_element_type.default(view_242, torch.bfloat16); view_242 = None + cat_29 = torch.ops.aten.cat.default([getitem_339, convert_element_type_206], -1); getitem_339 = convert_element_type_206 = None + convert_element_type_207 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_207, 128, '0'); convert_element_type_207 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + slice_25 = torch.ops.aten.slice.Tensor(wait_tensor_76, 0, 0, 576); wait_tensor_76 = None + permute_57 = torch.ops.aten.permute.default(slice_25, [1, 0]); slice_25 = None + mm_32 = torch.ops.aten.mm.default(view_237, permute_57); permute_57 = None + view_245 = torch.ops.aten.view.default(mm_32, [2, 4096, 576]); mm_32 = None + split_with_sizes_13 = torch.ops.aten.split_with_sizes.default(view_245, [512, 64], -1); view_245 = None + getitem_341 = split_with_sizes_13[0] + getitem_342 = split_with_sizes_13[1]; split_with_sizes_13 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(getitem_342, 2); getitem_342 = None + convert_element_type_210 = torch.ops.prims.convert_element_type.default(unsqueeze_7, torch.float32); unsqueeze_7 = None + view_246 = torch.ops.aten.view.default(convert_element_type_210, [2, 4096, 1, -1, 2]); convert_element_type_210 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_246); view_246 = None + mul_159 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_7); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_159); mul_159 = None + view_248 = torch.ops.aten.view.default(view_as_real_9, [2, 4096, 1, 64]); view_as_real_9 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_248, torch.bfloat16); view_248 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_212, 128, '0'); convert_element_type_212 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(getitem_341, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_213, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_211 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_211); add_211 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_213, rsqrt_13); convert_element_type_213 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_77); mul_160 = wait_tensor_77 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 128, '0'); convert_element_type_215 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_58 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + view_251 = torch.ops.aten.view.default(convert_element_type_214, [8192, 512]); convert_element_type_214 = None + mm_33 = torch.ops.aten.mm.default(view_251, permute_58); permute_58 = None + view_252 = torch.ops.aten.view.default(mm_33, [2, 4096, 4096]); mm_33 = None + view_253 = torch.ops.aten.view.default(view_252, [2, 4096, -1, 256]); view_252 = None + split_with_sizes_14 = torch.ops.aten.split_with_sizes.default(view_253, [128, 128], -1); view_253 = None + getitem_343 = split_with_sizes_14[0] + getitem_344 = split_with_sizes_14[1]; split_with_sizes_14 = None + expand_4 = torch.ops.aten.expand.default(convert_element_type_211, [-1, -1, 16, -1]); convert_element_type_211 = None + cat_30 = torch.ops.aten.cat.default([getitem_343, expand_4], -1); getitem_343 = expand_4 = None + permute_59 = torch.ops.aten.permute.default(cat_29, [0, 2, 1, 3]); cat_29 = None + permute_60 = torch.ops.aten.permute.default(cat_30, [0, 2, 1, 3]); cat_30 = None + permute_61 = torch.ops.aten.permute.default(getitem_344, [0, 2, 1, 3]); getitem_344 = None + sdpa_score4 = self.sdpa_score4 + sdpa_mask4 = self.sdpa_mask4 + flex_attention_4 = torch.ops.higher_order.flex_attention(permute_59, permute_60, permute_61, sdpa_score4, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask4), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score4 = sdpa_mask4 = None + getitem_345 = flex_attention_4[0] + getitem_346 = flex_attention_4[1]; flex_attention_4 = None + permute_62 = torch.ops.aten.permute.default(getitem_345, [0, 2, 1, 3]) + view_254 = torch.ops.aten.view.default(permute_62, [2, 4096, -1]); permute_62 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 128, '0'); convert_element_type_218 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + view_256 = torch.ops.aten.view.default(view_254, [8192, 2048]); view_254 = None + mm_34 = torch.ops.aten.mm.default(view_256, permute_63); view_256 = permute_63 = None + view_257 = torch.ops.aten.view.default(mm_34, [2, 4096, 2048]); mm_34 = None + add_212 = torch.ops.aten.add.Tensor(add_209, view_257); view_257 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 128, '0'); convert_element_type_221 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + convert_element_type_222 = torch.ops.prims.convert_element_type.default(add_212, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_222, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_213 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_213); add_213 = None + mul_162 = torch.ops.aten.mul.Tensor(convert_element_type_222, rsqrt_14); convert_element_type_222 = None + mul_163 = torch.ops.aten.mul.Tensor(mul_162, wait_tensor_80); mul_162 = wait_tensor_80 = None + convert_element_type_223 = torch.ops.prims.convert_element_type.default(mul_163, torch.bfloat16); mul_163 = None + view_259 = torch.ops.aten.view.default(convert_element_type_223, [-1, 2048]); convert_element_type_223 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_224, 128, '0'); convert_element_type_224 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + slice_27 = torch.ops.aten.slice.Tensor(wait_tensor_81, 0, 0, 64); wait_tensor_81 = None + permute_64 = torch.ops.aten.permute.default(slice_27, [1, 0]); slice_27 = None + mm_35 = torch.ops.aten.mm.default(view_259, permute_64); permute_64 = None + convert_element_type_227 = torch.ops.prims.convert_element_type.default(mm_35, torch.float32) + amax_3 = torch.ops.aten.amax.default(convert_element_type_227, [1], True) + sub_72 = torch.ops.aten.sub.Tensor(convert_element_type_227, amax_3); convert_element_type_227 = None + exp_10 = torch.ops.aten.exp.default(sub_72); sub_72 = None + sum_13 = torch.ops.aten.sum.dim_IntList(exp_10, [1], True) + div_16 = torch.ops.aten.div.Tensor(exp_10, sum_13); exp_10 = None + add_214 = torch.ops.aten.add.Tensor(div_16, primals_78); primals_78 = None + topk_3 = torch.ops.aten.topk.default(add_214, 6, -1, True, False); add_214 = None + getitem_349 = topk_3[1]; topk_3 = None + gather_3 = torch.ops.aten.gather.default(div_16, 1, getitem_349); div_16 = None + mul_164 = torch.ops.aten.mul.Tensor(gather_3, 1.0); gather_3 = None + view_261 = torch.ops.aten.view.default(getitem_349, [-1]) + histc_6 = torch.ops.aten.histc.default(view_261, 64, 0, 64) + add_215 = torch.ops.aten.add.Tensor(primals_80, histc_6) + sort_3 = torch.ops.aten.sort.stable(view_261, stable = True); view_261 = None + getitem_351 = sort_3[1]; sort_3 = None + div_17 = torch.ops.aten.div.Tensor_mode(getitem_351, 6, rounding_mode = 'floor') + index_6 = torch.ops.aten.index.Tensor(view_259, [div_17]) + all_to_all_single_9 = torch.ops._c10d_functional.all_to_all_single.default(histc_6, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_9); all_to_all_single_9 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_82); wait_tensor_82 = None + view_265 = torch.ops.aten.view.default(histc_6, [8, -1]); histc_6 = None + sum_14 = torch.ops.aten.sum.dim_IntList(view_265, [1]); view_265 = None + device_put_6 = torch.ops.prims.device_put.default(sum_14, device(type='cpu'), True); sum_14 = None + view_266 = torch.ops.aten.view.default(wait_tensor_83, [8, -1]) + sum_15 = torch.ops.aten.sum.dim_IntList(view_266, [1]) + device_put_7 = torch.ops.prims.device_put.default(sum_15, device(type='cpu')); sum_15 = None + select_48 = torch.ops.aten.select.int(device_put_6, 0, 0) + _local_scalar_dense_48 = torch.ops.aten._local_scalar_dense.default(select_48); select_48 = None + ge_60 = _local_scalar_dense_48 >= 0 + _assert_scalar_48 = torch.ops.aten._assert_scalar.default(ge_60, "Runtime assertion failed for expression u48 >= 0 on node 'ge_48'"); ge_60 = _assert_scalar_48 = None + select_49 = torch.ops.aten.select.int(device_put_6, 0, 1) + _local_scalar_dense_49 = torch.ops.aten._local_scalar_dense.default(select_49); select_49 = None + ge_61 = _local_scalar_dense_49 >= 0 + _assert_scalar_49 = torch.ops.aten._assert_scalar.default(ge_61, "Runtime assertion failed for expression u49 >= 0 on node 'ge_49'"); ge_61 = _assert_scalar_49 = None + select_50 = torch.ops.aten.select.int(device_put_6, 0, 2) + _local_scalar_dense_50 = torch.ops.aten._local_scalar_dense.default(select_50); select_50 = None + ge_62 = _local_scalar_dense_50 >= 0 + _assert_scalar_50 = torch.ops.aten._assert_scalar.default(ge_62, "Runtime assertion failed for expression u50 >= 0 on node 'ge_50'"); ge_62 = _assert_scalar_50 = None + select_51 = torch.ops.aten.select.int(device_put_6, 0, 3) + _local_scalar_dense_51 = torch.ops.aten._local_scalar_dense.default(select_51); select_51 = None + ge_63 = _local_scalar_dense_51 >= 0 + _assert_scalar_51 = torch.ops.aten._assert_scalar.default(ge_63, "Runtime assertion failed for expression u51 >= 0 on node 'ge_51'"); ge_63 = _assert_scalar_51 = None + select_52 = torch.ops.aten.select.int(device_put_6, 0, 4) + _local_scalar_dense_52 = torch.ops.aten._local_scalar_dense.default(select_52); select_52 = None + ge_64 = _local_scalar_dense_52 >= 0 + _assert_scalar_52 = torch.ops.aten._assert_scalar.default(ge_64, "Runtime assertion failed for expression u52 >= 0 on node 'ge_52'"); ge_64 = _assert_scalar_52 = None + select_53 = torch.ops.aten.select.int(device_put_6, 0, 5) + _local_scalar_dense_53 = torch.ops.aten._local_scalar_dense.default(select_53); select_53 = None + ge_65 = _local_scalar_dense_53 >= 0 + _assert_scalar_53 = torch.ops.aten._assert_scalar.default(ge_65, "Runtime assertion failed for expression u53 >= 0 on node 'ge_53'"); ge_65 = _assert_scalar_53 = None + select_54 = torch.ops.aten.select.int(device_put_6, 0, 6) + _local_scalar_dense_54 = torch.ops.aten._local_scalar_dense.default(select_54); select_54 = None + ge_66 = _local_scalar_dense_54 >= 0 + _assert_scalar_54 = torch.ops.aten._assert_scalar.default(ge_66, "Runtime assertion failed for expression u54 >= 0 on node 'ge_54'"); ge_66 = _assert_scalar_54 = None + select_55 = torch.ops.aten.select.int(device_put_6, 0, 7); device_put_6 = None + _local_scalar_dense_55 = torch.ops.aten._local_scalar_dense.default(select_55); select_55 = None + ge_67 = _local_scalar_dense_55 >= 0 + _assert_scalar_55 = torch.ops.aten._assert_scalar.default(ge_67, "Runtime assertion failed for expression u55 >= 0 on node 'ge_55'"); ge_67 = _assert_scalar_55 = None + select_56 = torch.ops.aten.select.int(device_put_7, 0, 0) + _local_scalar_dense_56 = torch.ops.aten._local_scalar_dense.default(select_56); select_56 = None + ge_68 = _local_scalar_dense_56 >= 0 + _assert_scalar_56 = torch.ops.aten._assert_scalar.default(ge_68, "Runtime assertion failed for expression u56 >= 0 on node 'ge_56'"); ge_68 = _assert_scalar_56 = None + select_57 = torch.ops.aten.select.int(device_put_7, 0, 1) + _local_scalar_dense_57 = torch.ops.aten._local_scalar_dense.default(select_57); select_57 = None + ge_69 = _local_scalar_dense_57 >= 0 + _assert_scalar_57 = torch.ops.aten._assert_scalar.default(ge_69, "Runtime assertion failed for expression u57 >= 0 on node 'ge_57'"); ge_69 = _assert_scalar_57 = None + select_58 = torch.ops.aten.select.int(device_put_7, 0, 2) + _local_scalar_dense_58 = torch.ops.aten._local_scalar_dense.default(select_58); select_58 = None + ge_70 = _local_scalar_dense_58 >= 0 + _assert_scalar_58 = torch.ops.aten._assert_scalar.default(ge_70, "Runtime assertion failed for expression u58 >= 0 on node 'ge_58'"); ge_70 = _assert_scalar_58 = None + select_59 = torch.ops.aten.select.int(device_put_7, 0, 3) + _local_scalar_dense_59 = torch.ops.aten._local_scalar_dense.default(select_59); select_59 = None + ge_71 = _local_scalar_dense_59 >= 0 + _assert_scalar_59 = torch.ops.aten._assert_scalar.default(ge_71, "Runtime assertion failed for expression u59 >= 0 on node 'ge_59'"); ge_71 = _assert_scalar_59 = None + select_60 = torch.ops.aten.select.int(device_put_7, 0, 4) + _local_scalar_dense_60 = torch.ops.aten._local_scalar_dense.default(select_60); select_60 = None + ge_72 = _local_scalar_dense_60 >= 0 + _assert_scalar_60 = torch.ops.aten._assert_scalar.default(ge_72, "Runtime assertion failed for expression u60 >= 0 on node 'ge_60'"); ge_72 = _assert_scalar_60 = None + select_61 = torch.ops.aten.select.int(device_put_7, 0, 5) + _local_scalar_dense_61 = torch.ops.aten._local_scalar_dense.default(select_61); select_61 = None + ge_73 = _local_scalar_dense_61 >= 0 + _assert_scalar_61 = torch.ops.aten._assert_scalar.default(ge_73, "Runtime assertion failed for expression u61 >= 0 on node 'ge_61'"); ge_73 = _assert_scalar_61 = None + select_62 = torch.ops.aten.select.int(device_put_7, 0, 6) + _local_scalar_dense_62 = torch.ops.aten._local_scalar_dense.default(select_62); select_62 = None + ge_74 = _local_scalar_dense_62 >= 0 + _assert_scalar_62 = torch.ops.aten._assert_scalar.default(ge_74, "Runtime assertion failed for expression u62 >= 0 on node 'ge_62'"); ge_74 = _assert_scalar_62 = None + select_63 = torch.ops.aten.select.int(device_put_7, 0, 7); device_put_7 = None + _local_scalar_dense_63 = torch.ops.aten._local_scalar_dense.default(select_63); select_63 = None + ge_75 = _local_scalar_dense_63 >= 0 + _assert_scalar_63 = torch.ops.aten._assert_scalar.default(ge_75, "Runtime assertion failed for expression u63 >= 0 on node 'ge_63'"); ge_75 = _assert_scalar_63 = None + all_to_all_single_10 = torch.ops._c10d_functional.all_to_all_single.default(index_6, [_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63], [_local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55], '1033'); index_6 = None + sym_size_int_12 = torch.ops.aten.sym_size.int(all_to_all_single_10, 0) + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_10); all_to_all_single_10 = None + sym_sum_6 = torch.sym_sum((_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63)) + add_222 = sym_sum_6 + 64; sym_sum_6 = None + add_223 = add_222 + 8; add_222 = None + sub_75 = add_223 - 1; add_223 = None + floordiv_3 = sub_75 // 8; sub_75 = None + mul_169 = floordiv_3 * 8; floordiv_3 = None + cumsum_9 = torch.ops.aten.cumsum.default(wait_tensor_83, 0) + sub_76 = torch.ops.aten.sub.Tensor(cumsum_9, wait_tensor_83); cumsum_9 = None + sum_16 = torch.ops.aten.sum.dim_IntList(view_266, [0]); view_266 = None + clamp_min_3 = torch.ops.aten.clamp_min.default(sum_16, 8); sum_16 = None + add_224 = torch.ops.aten.add.Tensor(clamp_min_3, 8); clamp_min_3 = None + sub_77 = torch.ops.aten.sub.Tensor(add_224, 1); add_224 = None + div_18 = torch.ops.aten.div.Tensor_mode(sub_77, 8, rounding_mode = 'floor'); sub_77 = None + mul_170 = torch.ops.aten.mul.Tensor(div_18, 8); div_18 = None + convert_element_type_230 = torch.ops.prims.convert_element_type.default(mul_170, torch.int32); mul_170 = None + cumsum_10 = torch.ops.aten.cumsum.default(convert_element_type_230, 0) + sub_78 = torch.ops.aten.sub.Tensor(cumsum_10, convert_element_type_230); cumsum_10 = None + full_59 = torch.ops.aten.full.default([mul_169], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_169 = None + triton_kernel_wrapper_functional_proxy_3 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_83, 'start_index_values_ptr': sub_76, 'write_offsets_ptr': sub_78, 'output_ptr': full_59}, tensors_to_clone = ['output_ptr']); wait_tensor_83 = sub_76 = sub_78 = full_59 = None + getitem_352 = triton_kernel_wrapper_functional_proxy_3['output_ptr']; triton_kernel_wrapper_functional_proxy_3 = None + cat_31 = torch.ops.aten.cat.default([wait_tensor_84, full_default]); wait_tensor_84 = None + sym_size_int_13 = torch.ops.aten.sym_size.int(cat_31, 0) + sym_sum_7 = torch.sym_sum((1, _local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63)) + index_7 = torch.ops.aten.index.Tensor(cat_31, [getitem_352]); cat_31 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 16, '1025'); convert_element_type_232 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + split_19 = torch.ops.aten.split.Tensor(wait_tensor_85, 8); wait_tensor_85 = None + getitem_369 = split_19[0] + getitem_370 = split_19[1] + getitem_371 = split_19[2] + getitem_372 = split_19[3] + getitem_373 = split_19[4] + getitem_374 = split_19[5] + getitem_375 = split_19[6] + getitem_376 = split_19[7] + getitem_377 = split_19[8] + getitem_378 = split_19[9] + getitem_379 = split_19[10] + getitem_380 = split_19[11] + getitem_381 = split_19[12] + getitem_382 = split_19[13] + getitem_383 = split_19[14] + getitem_384 = split_19[15]; split_19 = None + cat_33 = torch.ops.aten.cat.default([getitem_369, getitem_370, getitem_371, getitem_372, getitem_373, getitem_374, getitem_375, getitem_376, getitem_377, getitem_378, getitem_379, getitem_380, getitem_381, getitem_382, getitem_383, getitem_384], 1); getitem_369 = getitem_370 = getitem_371 = getitem_372 = getitem_373 = getitem_374 = getitem_375 = getitem_376 = getitem_377 = getitem_378 = getitem_379 = getitem_380 = getitem_381 = getitem_382 = getitem_383 = getitem_384 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 16, '1025'); convert_element_type_234 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + split_20 = torch.ops.aten.split.Tensor(wait_tensor_87, 8); wait_tensor_87 = None + getitem_385 = split_20[0] + getitem_386 = split_20[1] + getitem_387 = split_20[2] + getitem_388 = split_20[3] + getitem_389 = split_20[4] + getitem_390 = split_20[5] + getitem_391 = split_20[6] + getitem_392 = split_20[7] + getitem_393 = split_20[8] + getitem_394 = split_20[9] + getitem_395 = split_20[10] + getitem_396 = split_20[11] + getitem_397 = split_20[12] + getitem_398 = split_20[13] + getitem_399 = split_20[14] + getitem_400 = split_20[15]; split_20 = None + cat_34 = torch.ops.aten.cat.default([getitem_385, getitem_386, getitem_387, getitem_388, getitem_389, getitem_390, getitem_391, getitem_392, getitem_393, getitem_394, getitem_395, getitem_396, getitem_397, getitem_398, getitem_399, getitem_400], 1); getitem_385 = getitem_386 = getitem_387 = getitem_388 = getitem_389 = getitem_390 = getitem_391 = getitem_392 = getitem_393 = getitem_394 = getitem_395 = getitem_396 = getitem_397 = getitem_398 = getitem_399 = getitem_400 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 16, '1025'); convert_element_type_235 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + split_21 = torch.ops.aten.split.Tensor(wait_tensor_88, 8); wait_tensor_88 = None + getitem_401 = split_21[0] + getitem_402 = split_21[1] + getitem_403 = split_21[2] + getitem_404 = split_21[3] + getitem_405 = split_21[4] + getitem_406 = split_21[5] + getitem_407 = split_21[6] + getitem_408 = split_21[7] + getitem_409 = split_21[8] + getitem_410 = split_21[9] + getitem_411 = split_21[10] + getitem_412 = split_21[11] + getitem_413 = split_21[12] + getitem_414 = split_21[13] + getitem_415 = split_21[14] + getitem_416 = split_21[15]; split_21 = None + cat_35 = torch.ops.aten.cat.default([getitem_401, getitem_402, getitem_403, getitem_404, getitem_405, getitem_406, getitem_407, getitem_408, getitem_409, getitem_410, getitem_411, getitem_412, getitem_413, getitem_414, getitem_415, getitem_416], 1); getitem_401 = getitem_402 = getitem_403 = getitem_404 = getitem_405 = getitem_406 = getitem_407 = getitem_408 = getitem_409 = getitem_410 = getitem_411 = getitem_412 = getitem_413 = getitem_414 = getitem_415 = getitem_416 = None + cumsum_11 = torch.ops.aten.cumsum.default(convert_element_type_230, 0, dtype = torch.int32); convert_element_type_230 = None + permute_65 = torch.ops.aten.permute.default(cat_33, [0, 2, 1]); cat_33 = None + _grouped_mm_9 = torch.ops.aten._grouped_mm.default(index_7, permute_65, cumsum_11) + convert_element_type_238 = torch.ops.prims.convert_element_type.default(_grouped_mm_9, torch.float32) + neg_7 = torch.ops.aten.neg.default(convert_element_type_238) + exp_11 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_236 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + div_19 = torch.ops.aten.div.Tensor(convert_element_type_238, add_236); convert_element_type_238 = add_236 = None + convert_element_type_239 = torch.ops.prims.convert_element_type.default(div_19, torch.bfloat16); div_19 = None + permute_66 = torch.ops.aten.permute.default(cat_35, [0, 2, 1]); cat_35 = None + _grouped_mm_10 = torch.ops.aten._grouped_mm.default(index_7, permute_66, cumsum_11) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_239, _grouped_mm_10); convert_element_type_239 = None + permute_67 = torch.ops.aten.permute.default(cat_34, [0, 2, 1]); cat_34 = None + _grouped_mm_11 = torch.ops.aten._grouped_mm.default(mul_182, permute_67, cumsum_11) + empty_3 = torch.ops.aten.empty.memory_format([sym_size_int_13, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_6 = torch.ops.aten.index_put.default(empty_3, [getitem_352], _grouped_mm_11); empty_3 = _grouped_mm_11 = None + slice_29 = torch.ops.aten.slice.Tensor(index_put_6, 0, 0, -1); index_put_6 = None + all_to_all_single_11 = torch.ops._c10d_functional.all_to_all_single.default(slice_29, [_local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55], [_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63], '1033'); slice_29 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_11); all_to_all_single_11 = None + convert_element_type_240 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_240, 128, '0'); convert_element_type_240 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + mm_36 = torch.ops.aten.mm.default(view_259, permute_68); permute_68 = None + convert_element_type_243 = torch.ops.prims.convert_element_type.default(mm_36, torch.float32) + neg_8 = torch.ops.aten.neg.default(convert_element_type_243) + exp_12 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_272 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + div_20 = torch.ops.aten.div.Tensor(convert_element_type_243, add_272); convert_element_type_243 = add_272 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(div_20, torch.bfloat16); div_20 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_245, 128, '0'); convert_element_type_245 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_69 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_37 = torch.ops.aten.mm.default(view_259, permute_69); permute_69 = None + mul_202 = torch.ops.aten.mul.Tensor(convert_element_type_244, mm_37); convert_element_type_244 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 128, '0'); convert_element_type_248 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + permute_70 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + mm_38 = torch.ops.aten.mm.default(mul_202, permute_70); permute_70 = None + index_put_7 = torch.ops.aten.index_put.default(full_default_1, [getitem_351], wait_tensor_91); wait_tensor_91 = None + view_299 = torch.ops.aten.view.default(mul_164, [-1, 1, 6]); mul_164 = None + view_300 = torch.ops.aten.view.default(index_put_7, [-1, 6, 2048]); index_put_7 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + bmm_3 = torch.ops.aten.bmm.default(view_299, convert_element_type_251) + convert_element_type_252 = torch.ops.prims.convert_element_type.default(bmm_3, torch.bfloat16); bmm_3 = None + squeeze_3 = torch.ops.aten.squeeze.dim(convert_element_type_252, 1); convert_element_type_252 = None + add_276 = torch.ops.aten.add.Tensor(mm_38, squeeze_3); mm_38 = squeeze_3 = None + view_301 = torch.ops.aten.view.default(add_276, [2, 4096, 2048]); add_276 = None + add_277 = torch.ops.aten.add.Tensor(add_212, view_301); view_301 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 128, '0'); convert_element_type_253 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(add_277, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_254, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_278 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_278); add_278 = None + mul_205 = torch.ops.aten.mul.Tensor(convert_element_type_254, rsqrt_15); convert_element_type_254 = None + mul_206 = torch.ops.aten.mul.Tensor(mul_205, wait_tensor_95); mul_205 = wait_tensor_95 = None + convert_element_type_255 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_256 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_256, 128, '0'); convert_element_type_256 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_71 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + view_304 = torch.ops.aten.view.default(convert_element_type_255, [8192, 2048]); convert_element_type_255 = None + mm_39 = torch.ops.aten.mm.default(view_304, permute_71); permute_71 = None + view_305 = torch.ops.aten.view.default(mm_39, [2, 4096, 3072]); mm_39 = None + view_306 = torch.ops.aten.view.default(view_305, [2, 4096, -1, 192]); view_305 = None + split_with_sizes_15 = torch.ops.aten.split_with_sizes.default(view_306, [128, 64], -1); view_306 = None + getitem_449 = split_with_sizes_15[0] + getitem_450 = split_with_sizes_15[1]; split_with_sizes_15 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(getitem_450, torch.float32); getitem_450 = None + view_307 = torch.ops.aten.view.default(convert_element_type_259, [2, 4096, 16, -1, 2]); convert_element_type_259 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_307); view_307 = None + mul_207 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_7); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_207); mul_207 = None + view_309 = torch.ops.aten.view.default(view_as_real_10, [2, 4096, 16, 64]); view_as_real_10 = None + convert_element_type_260 = torch.ops.prims.convert_element_type.default(view_309, torch.bfloat16); view_309 = None + cat_38 = torch.ops.aten.cat.default([getitem_449, convert_element_type_260], -1); getitem_449 = convert_element_type_260 = None + convert_element_type_261 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_261, 128, '0'); convert_element_type_261 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + slice_31 = torch.ops.aten.slice.Tensor(wait_tensor_97, 0, 0, 576); wait_tensor_97 = None + permute_72 = torch.ops.aten.permute.default(slice_31, [1, 0]); slice_31 = None + mm_40 = torch.ops.aten.mm.default(view_304, permute_72); permute_72 = None + view_312 = torch.ops.aten.view.default(mm_40, [2, 4096, 576]); mm_40 = None + split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_312, [512, 64], -1); view_312 = None + getitem_451 = split_with_sizes_16[0] + getitem_452 = split_with_sizes_16[1]; split_with_sizes_16 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(getitem_452, 2); getitem_452 = None + convert_element_type_264 = torch.ops.prims.convert_element_type.default(unsqueeze_9, torch.float32); unsqueeze_9 = None + view_313 = torch.ops.aten.view.default(convert_element_type_264, [2, 4096, 1, -1, 2]); convert_element_type_264 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_313); view_313 = None + mul_208 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_7); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_208); mul_208 = None + view_315 = torch.ops.aten.view.default(view_as_real_11, [2, 4096, 1, 64]); view_as_real_11 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(view_315, torch.bfloat16); view_315 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_266, 128, '0'); convert_element_type_266 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(getitem_451, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_267, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_279 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_279); add_279 = None + mul_209 = torch.ops.aten.mul.Tensor(convert_element_type_267, rsqrt_16); convert_element_type_267 = None + mul_210 = torch.ops.aten.mul.Tensor(mul_209, wait_tensor_98); mul_209 = wait_tensor_98 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(mul_210, torch.bfloat16); mul_210 = None + convert_element_type_269 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_269, 128, '0'); convert_element_type_269 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + view_318 = torch.ops.aten.view.default(convert_element_type_268, [8192, 512]); convert_element_type_268 = None + mm_41 = torch.ops.aten.mm.default(view_318, permute_73); permute_73 = None + view_319 = torch.ops.aten.view.default(mm_41, [2, 4096, 4096]); mm_41 = None + view_320 = torch.ops.aten.view.default(view_319, [2, 4096, -1, 256]); view_319 = None + split_with_sizes_17 = torch.ops.aten.split_with_sizes.default(view_320, [128, 128], -1); view_320 = None + getitem_453 = split_with_sizes_17[0] + getitem_454 = split_with_sizes_17[1]; split_with_sizes_17 = None + expand_5 = torch.ops.aten.expand.default(convert_element_type_265, [-1, -1, 16, -1]); convert_element_type_265 = None + cat_39 = torch.ops.aten.cat.default([getitem_453, expand_5], -1); getitem_453 = expand_5 = None + permute_74 = torch.ops.aten.permute.default(cat_38, [0, 2, 1, 3]); cat_38 = None + permute_75 = torch.ops.aten.permute.default(cat_39, [0, 2, 1, 3]); cat_39 = None + permute_76 = torch.ops.aten.permute.default(getitem_454, [0, 2, 1, 3]); getitem_454 = None + sdpa_score5 = self.sdpa_score5 + sdpa_mask5 = self.sdpa_mask5 + flex_attention_5 = torch.ops.higher_order.flex_attention(permute_74, permute_75, permute_76, sdpa_score5, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask5), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score5 = sdpa_mask5 = None + getitem_455 = flex_attention_5[0] + getitem_456 = flex_attention_5[1]; flex_attention_5 = None + permute_77 = torch.ops.aten.permute.default(getitem_455, [0, 2, 1, 3]) + view_321 = torch.ops.aten.view.default(permute_77, [2, 4096, -1]); permute_77 = None + convert_element_type_272 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_272, 128, '0'); convert_element_type_272 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_100, [1, 0]); wait_tensor_100 = None + view_323 = torch.ops.aten.view.default(view_321, [8192, 2048]); view_321 = None + mm_42 = torch.ops.aten.mm.default(view_323, permute_78); view_323 = permute_78 = None + view_324 = torch.ops.aten.view.default(mm_42, [2, 4096, 2048]); mm_42 = None + add_280 = torch.ops.aten.add.Tensor(add_277, view_324); view_324 = None + convert_element_type_275 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_275, 128, '0'); convert_element_type_275 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + convert_element_type_276 = torch.ops.prims.convert_element_type.default(add_280, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_276, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_281 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_281); add_281 = None + mul_211 = torch.ops.aten.mul.Tensor(convert_element_type_276, rsqrt_17); convert_element_type_276 = None + mul_212 = torch.ops.aten.mul.Tensor(mul_211, wait_tensor_101); mul_211 = wait_tensor_101 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(mul_212, torch.bfloat16); mul_212 = None + view_326 = torch.ops.aten.view.default(convert_element_type_277, [-1, 2048]); convert_element_type_277 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_278, 128, '0'); convert_element_type_278 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + slice_33 = torch.ops.aten.slice.Tensor(wait_tensor_102, 0, 0, 64); wait_tensor_102 = None + permute_79 = torch.ops.aten.permute.default(slice_33, [1, 0]); slice_33 = None + mm_43 = torch.ops.aten.mm.default(view_326, permute_79); permute_79 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(mm_43, torch.float32) + amax_4 = torch.ops.aten.amax.default(convert_element_type_281, [1], True) + sub_96 = torch.ops.aten.sub.Tensor(convert_element_type_281, amax_4); convert_element_type_281 = None + exp_13 = torch.ops.aten.exp.default(sub_96); sub_96 = None + sum_17 = torch.ops.aten.sum.dim_IntList(exp_13, [1], True) + div_21 = torch.ops.aten.div.Tensor(exp_13, sum_17); exp_13 = None + add_282 = torch.ops.aten.add.Tensor(div_21, primals_94); primals_94 = None + topk_4 = torch.ops.aten.topk.default(add_282, 6, -1, True, False); add_282 = None + getitem_459 = topk_4[1]; topk_4 = None + gather_4 = torch.ops.aten.gather.default(div_21, 1, getitem_459); div_21 = None + mul_213 = torch.ops.aten.mul.Tensor(gather_4, 1.0); gather_4 = None + view_328 = torch.ops.aten.view.default(getitem_459, [-1]) + histc_8 = torch.ops.aten.histc.default(view_328, 64, 0, 64) + add_283 = torch.ops.aten.add.Tensor(primals_96, histc_8) + sort_4 = torch.ops.aten.sort.stable(view_328, stable = True); view_328 = None + getitem_461 = sort_4[1]; sort_4 = None + div_22 = torch.ops.aten.div.Tensor_mode(getitem_461, 6, rounding_mode = 'floor') + index_8 = torch.ops.aten.index.Tensor(view_326, [div_22]) + all_to_all_single_12 = torch.ops._c10d_functional.all_to_all_single.default(histc_8, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_12); all_to_all_single_12 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_103); wait_tensor_103 = None + view_332 = torch.ops.aten.view.default(histc_8, [8, -1]); histc_8 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_332, [1]); view_332 = None + device_put_8 = torch.ops.prims.device_put.default(sum_18, device(type='cpu'), True); sum_18 = None + view_333 = torch.ops.aten.view.default(wait_tensor_104, [8, -1]) + sum_19 = torch.ops.aten.sum.dim_IntList(view_333, [1]) + device_put_9 = torch.ops.prims.device_put.default(sum_19, device(type='cpu')); sum_19 = None + select_64 = torch.ops.aten.select.int(device_put_8, 0, 0) + _local_scalar_dense_64 = torch.ops.aten._local_scalar_dense.default(select_64); select_64 = None + ge_80 = _local_scalar_dense_64 >= 0 + _assert_scalar_64 = torch.ops.aten._assert_scalar.default(ge_80, "Runtime assertion failed for expression u64 >= 0 on node 'ge_64'"); ge_80 = _assert_scalar_64 = None + select_65 = torch.ops.aten.select.int(device_put_8, 0, 1) + _local_scalar_dense_65 = torch.ops.aten._local_scalar_dense.default(select_65); select_65 = None + ge_81 = _local_scalar_dense_65 >= 0 + _assert_scalar_65 = torch.ops.aten._assert_scalar.default(ge_81, "Runtime assertion failed for expression u65 >= 0 on node 'ge_65'"); ge_81 = _assert_scalar_65 = None + select_66 = torch.ops.aten.select.int(device_put_8, 0, 2) + _local_scalar_dense_66 = torch.ops.aten._local_scalar_dense.default(select_66); select_66 = None + ge_82 = _local_scalar_dense_66 >= 0 + _assert_scalar_66 = torch.ops.aten._assert_scalar.default(ge_82, "Runtime assertion failed for expression u66 >= 0 on node 'ge_66'"); ge_82 = _assert_scalar_66 = None + select_67 = torch.ops.aten.select.int(device_put_8, 0, 3) + _local_scalar_dense_67 = torch.ops.aten._local_scalar_dense.default(select_67); select_67 = None + ge_83 = _local_scalar_dense_67 >= 0 + _assert_scalar_67 = torch.ops.aten._assert_scalar.default(ge_83, "Runtime assertion failed for expression u67 >= 0 on node 'ge_67'"); ge_83 = _assert_scalar_67 = None + select_68 = torch.ops.aten.select.int(device_put_8, 0, 4) + _local_scalar_dense_68 = torch.ops.aten._local_scalar_dense.default(select_68); select_68 = None + ge_84 = _local_scalar_dense_68 >= 0 + _assert_scalar_68 = torch.ops.aten._assert_scalar.default(ge_84, "Runtime assertion failed for expression u68 >= 0 on node 'ge_68'"); ge_84 = _assert_scalar_68 = None + select_69 = torch.ops.aten.select.int(device_put_8, 0, 5) + _local_scalar_dense_69 = torch.ops.aten._local_scalar_dense.default(select_69); select_69 = None + ge_85 = _local_scalar_dense_69 >= 0 + _assert_scalar_69 = torch.ops.aten._assert_scalar.default(ge_85, "Runtime assertion failed for expression u69 >= 0 on node 'ge_69'"); ge_85 = _assert_scalar_69 = None + select_70 = torch.ops.aten.select.int(device_put_8, 0, 6) + _local_scalar_dense_70 = torch.ops.aten._local_scalar_dense.default(select_70); select_70 = None + ge_86 = _local_scalar_dense_70 >= 0 + _assert_scalar_70 = torch.ops.aten._assert_scalar.default(ge_86, "Runtime assertion failed for expression u70 >= 0 on node 'ge_70'"); ge_86 = _assert_scalar_70 = None + select_71 = torch.ops.aten.select.int(device_put_8, 0, 7); device_put_8 = None + _local_scalar_dense_71 = torch.ops.aten._local_scalar_dense.default(select_71); select_71 = None + ge_87 = _local_scalar_dense_71 >= 0 + _assert_scalar_71 = torch.ops.aten._assert_scalar.default(ge_87, "Runtime assertion failed for expression u71 >= 0 on node 'ge_71'"); ge_87 = _assert_scalar_71 = None + select_72 = torch.ops.aten.select.int(device_put_9, 0, 0) + _local_scalar_dense_72 = torch.ops.aten._local_scalar_dense.default(select_72); select_72 = None + ge_88 = _local_scalar_dense_72 >= 0 + _assert_scalar_72 = torch.ops.aten._assert_scalar.default(ge_88, "Runtime assertion failed for expression u72 >= 0 on node 'ge_72'"); ge_88 = _assert_scalar_72 = None + select_73 = torch.ops.aten.select.int(device_put_9, 0, 1) + _local_scalar_dense_73 = torch.ops.aten._local_scalar_dense.default(select_73); select_73 = None + ge_89 = _local_scalar_dense_73 >= 0 + _assert_scalar_73 = torch.ops.aten._assert_scalar.default(ge_89, "Runtime assertion failed for expression u73 >= 0 on node 'ge_73'"); ge_89 = _assert_scalar_73 = None + select_74 = torch.ops.aten.select.int(device_put_9, 0, 2) + _local_scalar_dense_74 = torch.ops.aten._local_scalar_dense.default(select_74); select_74 = None + ge_90 = _local_scalar_dense_74 >= 0 + _assert_scalar_74 = torch.ops.aten._assert_scalar.default(ge_90, "Runtime assertion failed for expression u74 >= 0 on node 'ge_74'"); ge_90 = _assert_scalar_74 = None + select_75 = torch.ops.aten.select.int(device_put_9, 0, 3) + _local_scalar_dense_75 = torch.ops.aten._local_scalar_dense.default(select_75); select_75 = None + ge_91 = _local_scalar_dense_75 >= 0 + _assert_scalar_75 = torch.ops.aten._assert_scalar.default(ge_91, "Runtime assertion failed for expression u75 >= 0 on node 'ge_75'"); ge_91 = _assert_scalar_75 = None + select_76 = torch.ops.aten.select.int(device_put_9, 0, 4) + _local_scalar_dense_76 = torch.ops.aten._local_scalar_dense.default(select_76); select_76 = None + ge_92 = _local_scalar_dense_76 >= 0 + _assert_scalar_76 = torch.ops.aten._assert_scalar.default(ge_92, "Runtime assertion failed for expression u76 >= 0 on node 'ge_76'"); ge_92 = _assert_scalar_76 = None + select_77 = torch.ops.aten.select.int(device_put_9, 0, 5) + _local_scalar_dense_77 = torch.ops.aten._local_scalar_dense.default(select_77); select_77 = None + ge_93 = _local_scalar_dense_77 >= 0 + _assert_scalar_77 = torch.ops.aten._assert_scalar.default(ge_93, "Runtime assertion failed for expression u77 >= 0 on node 'ge_77'"); ge_93 = _assert_scalar_77 = None + select_78 = torch.ops.aten.select.int(device_put_9, 0, 6) + _local_scalar_dense_78 = torch.ops.aten._local_scalar_dense.default(select_78); select_78 = None + ge_94 = _local_scalar_dense_78 >= 0 + _assert_scalar_78 = torch.ops.aten._assert_scalar.default(ge_94, "Runtime assertion failed for expression u78 >= 0 on node 'ge_78'"); ge_94 = _assert_scalar_78 = None + select_79 = torch.ops.aten.select.int(device_put_9, 0, 7); device_put_9 = None + _local_scalar_dense_79 = torch.ops.aten._local_scalar_dense.default(select_79); select_79 = None + ge_95 = _local_scalar_dense_79 >= 0 + _assert_scalar_79 = torch.ops.aten._assert_scalar.default(ge_95, "Runtime assertion failed for expression u79 >= 0 on node 'ge_79'"); ge_95 = _assert_scalar_79 = None + all_to_all_single_13 = torch.ops._c10d_functional.all_to_all_single.default(index_8, [_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79], [_local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71], '1033'); index_8 = None + sym_size_int_16 = torch.ops.aten.sym_size.int(all_to_all_single_13, 0) + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_13); all_to_all_single_13 = None + sym_sum_8 = torch.sym_sum((_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79)) + add_290 = sym_sum_8 + 64; sym_sum_8 = None + add_291 = add_290 + 8; add_290 = None + sub_99 = add_291 - 1; add_291 = None + floordiv_4 = sub_99 // 8; sub_99 = None + mul_218 = floordiv_4 * 8; floordiv_4 = None + cumsum_12 = torch.ops.aten.cumsum.default(wait_tensor_104, 0) + sub_100 = torch.ops.aten.sub.Tensor(cumsum_12, wait_tensor_104); cumsum_12 = None + sum_20 = torch.ops.aten.sum.dim_IntList(view_333, [0]); view_333 = None + clamp_min_4 = torch.ops.aten.clamp_min.default(sum_20, 8); sum_20 = None + add_292 = torch.ops.aten.add.Tensor(clamp_min_4, 8); clamp_min_4 = None + sub_101 = torch.ops.aten.sub.Tensor(add_292, 1); add_292 = None + div_23 = torch.ops.aten.div.Tensor_mode(sub_101, 8, rounding_mode = 'floor'); sub_101 = None + mul_219 = torch.ops.aten.mul.Tensor(div_23, 8); div_23 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(mul_219, torch.int32); mul_219 = None + cumsum_13 = torch.ops.aten.cumsum.default(convert_element_type_284, 0) + sub_102 = torch.ops.aten.sub.Tensor(cumsum_13, convert_element_type_284); cumsum_13 = None + full_72 = torch.ops.aten.full.default([mul_218], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_218 = None + triton_kernel_wrapper_functional_proxy_4 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 4, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_104, 'start_index_values_ptr': sub_100, 'write_offsets_ptr': sub_102, 'output_ptr': full_72}, tensors_to_clone = ['output_ptr']); wait_tensor_104 = sub_100 = sub_102 = full_72 = None + getitem_462 = triton_kernel_wrapper_functional_proxy_4['output_ptr']; triton_kernel_wrapper_functional_proxy_4 = None + cat_40 = torch.ops.aten.cat.default([wait_tensor_105, full_default]); wait_tensor_105 = None + sym_size_int_17 = torch.ops.aten.sym_size.int(cat_40, 0) + sym_sum_9 = torch.sym_sum((1, _local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79)) + index_9 = torch.ops.aten.index.Tensor(cat_40, [getitem_462]); cat_40 = None + convert_element_type_286 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 16, '1025'); convert_element_type_286 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + split_25 = torch.ops.aten.split.Tensor(wait_tensor_106, 8); wait_tensor_106 = None + getitem_479 = split_25[0] + getitem_480 = split_25[1] + getitem_481 = split_25[2] + getitem_482 = split_25[3] + getitem_483 = split_25[4] + getitem_484 = split_25[5] + getitem_485 = split_25[6] + getitem_486 = split_25[7] + getitem_487 = split_25[8] + getitem_488 = split_25[9] + getitem_489 = split_25[10] + getitem_490 = split_25[11] + getitem_491 = split_25[12] + getitem_492 = split_25[13] + getitem_493 = split_25[14] + getitem_494 = split_25[15]; split_25 = None + cat_42 = torch.ops.aten.cat.default([getitem_479, getitem_480, getitem_481, getitem_482, getitem_483, getitem_484, getitem_485, getitem_486, getitem_487, getitem_488, getitem_489, getitem_490, getitem_491, getitem_492, getitem_493, getitem_494], 1); getitem_479 = getitem_480 = getitem_481 = getitem_482 = getitem_483 = getitem_484 = getitem_485 = getitem_486 = getitem_487 = getitem_488 = getitem_489 = getitem_490 = getitem_491 = getitem_492 = getitem_493 = getitem_494 = None + convert_element_type_288 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_288, 16, '1025'); convert_element_type_288 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + split_26 = torch.ops.aten.split.Tensor(wait_tensor_108, 8); wait_tensor_108 = None + getitem_495 = split_26[0] + getitem_496 = split_26[1] + getitem_497 = split_26[2] + getitem_498 = split_26[3] + getitem_499 = split_26[4] + getitem_500 = split_26[5] + getitem_501 = split_26[6] + getitem_502 = split_26[7] + getitem_503 = split_26[8] + getitem_504 = split_26[9] + getitem_505 = split_26[10] + getitem_506 = split_26[11] + getitem_507 = split_26[12] + getitem_508 = split_26[13] + getitem_509 = split_26[14] + getitem_510 = split_26[15]; split_26 = None + cat_43 = torch.ops.aten.cat.default([getitem_495, getitem_496, getitem_497, getitem_498, getitem_499, getitem_500, getitem_501, getitem_502, getitem_503, getitem_504, getitem_505, getitem_506, getitem_507, getitem_508, getitem_509, getitem_510], 1); getitem_495 = getitem_496 = getitem_497 = getitem_498 = getitem_499 = getitem_500 = getitem_501 = getitem_502 = getitem_503 = getitem_504 = getitem_505 = getitem_506 = getitem_507 = getitem_508 = getitem_509 = getitem_510 = None + convert_element_type_289 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_289, 16, '1025'); convert_element_type_289 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + split_27 = torch.ops.aten.split.Tensor(wait_tensor_109, 8); wait_tensor_109 = None + getitem_511 = split_27[0] + getitem_512 = split_27[1] + getitem_513 = split_27[2] + getitem_514 = split_27[3] + getitem_515 = split_27[4] + getitem_516 = split_27[5] + getitem_517 = split_27[6] + getitem_518 = split_27[7] + getitem_519 = split_27[8] + getitem_520 = split_27[9] + getitem_521 = split_27[10] + getitem_522 = split_27[11] + getitem_523 = split_27[12] + getitem_524 = split_27[13] + getitem_525 = split_27[14] + getitem_526 = split_27[15]; split_27 = None + cat_44 = torch.ops.aten.cat.default([getitem_511, getitem_512, getitem_513, getitem_514, getitem_515, getitem_516, getitem_517, getitem_518, getitem_519, getitem_520, getitem_521, getitem_522, getitem_523, getitem_524, getitem_525, getitem_526], 1); getitem_511 = getitem_512 = getitem_513 = getitem_514 = getitem_515 = getitem_516 = getitem_517 = getitem_518 = getitem_519 = getitem_520 = getitem_521 = getitem_522 = getitem_523 = getitem_524 = getitem_525 = getitem_526 = None + cumsum_14 = torch.ops.aten.cumsum.default(convert_element_type_284, 0, dtype = torch.int32); convert_element_type_284 = None + permute_80 = torch.ops.aten.permute.default(cat_42, [0, 2, 1]); cat_42 = None + _grouped_mm_12 = torch.ops.aten._grouped_mm.default(index_9, permute_80, cumsum_14) + convert_element_type_292 = torch.ops.prims.convert_element_type.default(_grouped_mm_12, torch.float32) + neg_9 = torch.ops.aten.neg.default(convert_element_type_292) + exp_14 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_304 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + div_24 = torch.ops.aten.div.Tensor(convert_element_type_292, add_304); convert_element_type_292 = add_304 = None + convert_element_type_293 = torch.ops.prims.convert_element_type.default(div_24, torch.bfloat16); div_24 = None + permute_81 = torch.ops.aten.permute.default(cat_44, [0, 2, 1]); cat_44 = None + _grouped_mm_13 = torch.ops.aten._grouped_mm.default(index_9, permute_81, cumsum_14) + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_293, _grouped_mm_13); convert_element_type_293 = None + permute_82 = torch.ops.aten.permute.default(cat_43, [0, 2, 1]); cat_43 = None + _grouped_mm_14 = torch.ops.aten._grouped_mm.default(mul_231, permute_82, cumsum_14) + empty_4 = torch.ops.aten.empty.memory_format([sym_size_int_17, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_8 = torch.ops.aten.index_put.default(empty_4, [getitem_462], _grouped_mm_14); empty_4 = _grouped_mm_14 = None + slice_35 = torch.ops.aten.slice.Tensor(index_put_8, 0, 0, -1); index_put_8 = None + all_to_all_single_14 = torch.ops._c10d_functional.all_to_all_single.default(slice_35, [_local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71], [_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79], '1033'); slice_35 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_14); all_to_all_single_14 = None + convert_element_type_294 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_294, 128, '0'); convert_element_type_294 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_83 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + mm_44 = torch.ops.aten.mm.default(view_326, permute_83); permute_83 = None + convert_element_type_297 = torch.ops.prims.convert_element_type.default(mm_44, torch.float32) + neg_10 = torch.ops.aten.neg.default(convert_element_type_297) + exp_15 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_340 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + div_25 = torch.ops.aten.div.Tensor(convert_element_type_297, add_340); convert_element_type_297 = add_340 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(div_25, torch.bfloat16); div_25 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_299, 128, '0'); convert_element_type_299 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_114, [1, 0]); wait_tensor_114 = None + mm_45 = torch.ops.aten.mm.default(view_326, permute_84); permute_84 = None + mul_251 = torch.ops.aten.mul.Tensor(convert_element_type_298, mm_45); convert_element_type_298 = None + convert_element_type_302 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_302, 128, '0'); convert_element_type_302 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + mm_46 = torch.ops.aten.mm.default(mul_251, permute_85); permute_85 = None + index_put_9 = torch.ops.aten.index_put.default(full_default_1, [getitem_461], wait_tensor_112); wait_tensor_112 = None + view_366 = torch.ops.aten.view.default(mul_213, [-1, 1, 6]); mul_213 = None + view_367 = torch.ops.aten.view.default(index_put_9, [-1, 6, 2048]); index_put_9 = None + convert_element_type_305 = torch.ops.prims.convert_element_type.default(view_367, torch.float32); view_367 = None + bmm_4 = torch.ops.aten.bmm.default(view_366, convert_element_type_305) + convert_element_type_306 = torch.ops.prims.convert_element_type.default(bmm_4, torch.bfloat16); bmm_4 = None + squeeze_4 = torch.ops.aten.squeeze.dim(convert_element_type_306, 1); convert_element_type_306 = None + add_344 = torch.ops.aten.add.Tensor(mm_46, squeeze_4); mm_46 = squeeze_4 = None + view_368 = torch.ops.aten.view.default(add_344, [2, 4096, 2048]); add_344 = None + add_345 = torch.ops.aten.add.Tensor(add_280, view_368); view_368 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 128, '0'); convert_element_type_307 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_308 = torch.ops.prims.convert_element_type.default(add_345, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_308, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_346 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_346); add_346 = None + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_308, rsqrt_18); convert_element_type_308 = None + mul_255 = torch.ops.aten.mul.Tensor(mul_254, wait_tensor_116); mul_254 = wait_tensor_116 = None + convert_element_type_309 = torch.ops.prims.convert_element_type.default(mul_255, torch.bfloat16); mul_255 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_310, 128, '0'); convert_element_type_310 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_371 = torch.ops.aten.view.default(convert_element_type_309, [8192, 2048]); convert_element_type_309 = None + mm_47 = torch.ops.aten.mm.default(view_371, permute_86); permute_86 = None + view_372 = torch.ops.aten.view.default(mm_47, [2, 4096, 3072]); mm_47 = None + view_373 = torch.ops.aten.view.default(view_372, [2, 4096, -1, 192]); view_372 = None + split_with_sizes_18 = torch.ops.aten.split_with_sizes.default(view_373, [128, 64], -1); view_373 = None + getitem_559 = split_with_sizes_18[0] + getitem_560 = split_with_sizes_18[1]; split_with_sizes_18 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(getitem_560, torch.float32); getitem_560 = None + view_374 = torch.ops.aten.view.default(convert_element_type_313, [2, 4096, 16, -1, 2]); convert_element_type_313 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_374); view_374 = None + mul_256 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_7); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_256); mul_256 = None + view_376 = torch.ops.aten.view.default(view_as_real_12, [2, 4096, 16, 64]); view_as_real_12 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(view_376, torch.bfloat16); view_376 = None + cat_47 = torch.ops.aten.cat.default([getitem_559, convert_element_type_314], -1); getitem_559 = convert_element_type_314 = None + convert_element_type_315 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_315, 128, '0'); convert_element_type_315 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + slice_37 = torch.ops.aten.slice.Tensor(wait_tensor_118, 0, 0, 576); wait_tensor_118 = None + permute_87 = torch.ops.aten.permute.default(slice_37, [1, 0]); slice_37 = None + mm_48 = torch.ops.aten.mm.default(view_371, permute_87); permute_87 = None + view_379 = torch.ops.aten.view.default(mm_48, [2, 4096, 576]); mm_48 = None + split_with_sizes_19 = torch.ops.aten.split_with_sizes.default(view_379, [512, 64], -1); view_379 = None + getitem_561 = split_with_sizes_19[0] + getitem_562 = split_with_sizes_19[1]; split_with_sizes_19 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(getitem_562, 2); getitem_562 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(unsqueeze_11, torch.float32); unsqueeze_11 = None + view_380 = torch.ops.aten.view.default(convert_element_type_318, [2, 4096, 1, -1, 2]); convert_element_type_318 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_380); view_380 = None + mul_257 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_7); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_257); mul_257 = None + view_382 = torch.ops.aten.view.default(view_as_real_13, [2, 4096, 1, 64]); view_as_real_13 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(view_382, torch.bfloat16); view_382 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 128, '0'); convert_element_type_320 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + convert_element_type_321 = torch.ops.prims.convert_element_type.default(getitem_561, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_321, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_347 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_347); add_347 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_321, rsqrt_19); convert_element_type_321 = None + mul_259 = torch.ops.aten.mul.Tensor(mul_258, wait_tensor_119); mul_258 = wait_tensor_119 = None + convert_element_type_322 = torch.ops.prims.convert_element_type.default(mul_259, torch.bfloat16); mul_259 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_323, 128, '0'); convert_element_type_323 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + view_385 = torch.ops.aten.view.default(convert_element_type_322, [8192, 512]); convert_element_type_322 = None + mm_49 = torch.ops.aten.mm.default(view_385, permute_88); permute_88 = None + view_386 = torch.ops.aten.view.default(mm_49, [2, 4096, 4096]); mm_49 = None + view_387 = torch.ops.aten.view.default(view_386, [2, 4096, -1, 256]); view_386 = None + split_with_sizes_20 = torch.ops.aten.split_with_sizes.default(view_387, [128, 128], -1); view_387 = None + getitem_563 = split_with_sizes_20[0] + getitem_564 = split_with_sizes_20[1]; split_with_sizes_20 = None + expand_6 = torch.ops.aten.expand.default(convert_element_type_319, [-1, -1, 16, -1]); convert_element_type_319 = None + cat_48 = torch.ops.aten.cat.default([getitem_563, expand_6], -1); getitem_563 = expand_6 = None + permute_89 = torch.ops.aten.permute.default(cat_47, [0, 2, 1, 3]); cat_47 = None + permute_90 = torch.ops.aten.permute.default(cat_48, [0, 2, 1, 3]); cat_48 = None + permute_91 = torch.ops.aten.permute.default(getitem_564, [0, 2, 1, 3]); getitem_564 = None + sdpa_score6 = self.sdpa_score6 + sdpa_mask6 = self.sdpa_mask6 + flex_attention_6 = torch.ops.higher_order.flex_attention(permute_89, permute_90, permute_91, sdpa_score6, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask6), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score6 = sdpa_mask6 = None + getitem_565 = flex_attention_6[0] + getitem_566 = flex_attention_6[1]; flex_attention_6 = None + permute_92 = torch.ops.aten.permute.default(getitem_565, [0, 2, 1, 3]) + view_388 = torch.ops.aten.view.default(permute_92, [2, 4096, -1]); permute_92 = None + convert_element_type_326 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_326, 128, '0'); convert_element_type_326 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_93 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + view_390 = torch.ops.aten.view.default(view_388, [8192, 2048]); view_388 = None + mm_50 = torch.ops.aten.mm.default(view_390, permute_93); view_390 = permute_93 = None + view_391 = torch.ops.aten.view.default(mm_50, [2, 4096, 2048]); mm_50 = None + add_348 = torch.ops.aten.add.Tensor(add_345, view_391); view_391 = None + convert_element_type_329 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_329, 128, '0'); convert_element_type_329 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + convert_element_type_330 = torch.ops.prims.convert_element_type.default(add_348, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_330, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_349 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_349); add_349 = None + mul_260 = torch.ops.aten.mul.Tensor(convert_element_type_330, rsqrt_20); convert_element_type_330 = None + mul_261 = torch.ops.aten.mul.Tensor(mul_260, wait_tensor_122); mul_260 = wait_tensor_122 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(mul_261, torch.bfloat16); mul_261 = None + view_393 = torch.ops.aten.view.default(convert_element_type_331, [-1, 2048]); convert_element_type_331 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_332, 128, '0'); convert_element_type_332 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + slice_39 = torch.ops.aten.slice.Tensor(wait_tensor_123, 0, 0, 64); wait_tensor_123 = None + permute_94 = torch.ops.aten.permute.default(slice_39, [1, 0]); slice_39 = None + mm_51 = torch.ops.aten.mm.default(view_393, permute_94); permute_94 = None + convert_element_type_335 = torch.ops.prims.convert_element_type.default(mm_51, torch.float32) + amax_5 = torch.ops.aten.amax.default(convert_element_type_335, [1], True) + sub_120 = torch.ops.aten.sub.Tensor(convert_element_type_335, amax_5); convert_element_type_335 = None + exp_16 = torch.ops.aten.exp.default(sub_120); sub_120 = None + sum_21 = torch.ops.aten.sum.dim_IntList(exp_16, [1], True) + div_26 = torch.ops.aten.div.Tensor(exp_16, sum_21); exp_16 = None + add_350 = torch.ops.aten.add.Tensor(div_26, primals_110); primals_110 = None + topk_5 = torch.ops.aten.topk.default(add_350, 6, -1, True, False); add_350 = None + getitem_569 = topk_5[1]; topk_5 = None + gather_5 = torch.ops.aten.gather.default(div_26, 1, getitem_569); div_26 = None + mul_262 = torch.ops.aten.mul.Tensor(gather_5, 1.0); gather_5 = None + view_395 = torch.ops.aten.view.default(getitem_569, [-1]) + histc_10 = torch.ops.aten.histc.default(view_395, 64, 0, 64) + add_351 = torch.ops.aten.add.Tensor(primals_112, histc_10) + sort_5 = torch.ops.aten.sort.stable(view_395, stable = True); view_395 = None + getitem_571 = sort_5[1]; sort_5 = None + div_27 = torch.ops.aten.div.Tensor_mode(getitem_571, 6, rounding_mode = 'floor') + index_10 = torch.ops.aten.index.Tensor(view_393, [div_27]) + all_to_all_single_15 = torch.ops._c10d_functional.all_to_all_single.default(histc_10, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_15); all_to_all_single_15 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_124); wait_tensor_124 = None + view_399 = torch.ops.aten.view.default(histc_10, [8, -1]); histc_10 = None + sum_22 = torch.ops.aten.sum.dim_IntList(view_399, [1]); view_399 = None + device_put_10 = torch.ops.prims.device_put.default(sum_22, device(type='cpu'), True); sum_22 = None + view_400 = torch.ops.aten.view.default(wait_tensor_125, [8, -1]) + sum_23 = torch.ops.aten.sum.dim_IntList(view_400, [1]) + device_put_11 = torch.ops.prims.device_put.default(sum_23, device(type='cpu')); sum_23 = None + select_80 = torch.ops.aten.select.int(device_put_10, 0, 0) + _local_scalar_dense_80 = torch.ops.aten._local_scalar_dense.default(select_80); select_80 = None + ge_100 = _local_scalar_dense_80 >= 0 + _assert_scalar_80 = torch.ops.aten._assert_scalar.default(ge_100, "Runtime assertion failed for expression u80 >= 0 on node 'ge_80'"); ge_100 = _assert_scalar_80 = None + select_81 = torch.ops.aten.select.int(device_put_10, 0, 1) + _local_scalar_dense_81 = torch.ops.aten._local_scalar_dense.default(select_81); select_81 = None + ge_101 = _local_scalar_dense_81 >= 0 + _assert_scalar_81 = torch.ops.aten._assert_scalar.default(ge_101, "Runtime assertion failed for expression u81 >= 0 on node 'ge_81'"); ge_101 = _assert_scalar_81 = None + select_82 = torch.ops.aten.select.int(device_put_10, 0, 2) + _local_scalar_dense_82 = torch.ops.aten._local_scalar_dense.default(select_82); select_82 = None + ge_102 = _local_scalar_dense_82 >= 0 + _assert_scalar_82 = torch.ops.aten._assert_scalar.default(ge_102, "Runtime assertion failed for expression u82 >= 0 on node 'ge_82'"); ge_102 = _assert_scalar_82 = None + select_83 = torch.ops.aten.select.int(device_put_10, 0, 3) + _local_scalar_dense_83 = torch.ops.aten._local_scalar_dense.default(select_83); select_83 = None + ge_103 = _local_scalar_dense_83 >= 0 + _assert_scalar_83 = torch.ops.aten._assert_scalar.default(ge_103, "Runtime assertion failed for expression u83 >= 0 on node 'ge_83'"); ge_103 = _assert_scalar_83 = None + select_84 = torch.ops.aten.select.int(device_put_10, 0, 4) + _local_scalar_dense_84 = torch.ops.aten._local_scalar_dense.default(select_84); select_84 = None + ge_104 = _local_scalar_dense_84 >= 0 + _assert_scalar_84 = torch.ops.aten._assert_scalar.default(ge_104, "Runtime assertion failed for expression u84 >= 0 on node 'ge_84'"); ge_104 = _assert_scalar_84 = None + select_85 = torch.ops.aten.select.int(device_put_10, 0, 5) + _local_scalar_dense_85 = torch.ops.aten._local_scalar_dense.default(select_85); select_85 = None + ge_105 = _local_scalar_dense_85 >= 0 + _assert_scalar_85 = torch.ops.aten._assert_scalar.default(ge_105, "Runtime assertion failed for expression u85 >= 0 on node 'ge_85'"); ge_105 = _assert_scalar_85 = None + select_86 = torch.ops.aten.select.int(device_put_10, 0, 6) + _local_scalar_dense_86 = torch.ops.aten._local_scalar_dense.default(select_86); select_86 = None + ge_106 = _local_scalar_dense_86 >= 0 + _assert_scalar_86 = torch.ops.aten._assert_scalar.default(ge_106, "Runtime assertion failed for expression u86 >= 0 on node 'ge_86'"); ge_106 = _assert_scalar_86 = None + select_87 = torch.ops.aten.select.int(device_put_10, 0, 7); device_put_10 = None + _local_scalar_dense_87 = torch.ops.aten._local_scalar_dense.default(select_87); select_87 = None + ge_107 = _local_scalar_dense_87 >= 0 + _assert_scalar_87 = torch.ops.aten._assert_scalar.default(ge_107, "Runtime assertion failed for expression u87 >= 0 on node 'ge_87'"); ge_107 = _assert_scalar_87 = None + select_88 = torch.ops.aten.select.int(device_put_11, 0, 0) + _local_scalar_dense_88 = torch.ops.aten._local_scalar_dense.default(select_88); select_88 = None + ge_108 = _local_scalar_dense_88 >= 0 + _assert_scalar_88 = torch.ops.aten._assert_scalar.default(ge_108, "Runtime assertion failed for expression u88 >= 0 on node 'ge_88'"); ge_108 = _assert_scalar_88 = None + select_89 = torch.ops.aten.select.int(device_put_11, 0, 1) + _local_scalar_dense_89 = torch.ops.aten._local_scalar_dense.default(select_89); select_89 = None + ge_109 = _local_scalar_dense_89 >= 0 + _assert_scalar_89 = torch.ops.aten._assert_scalar.default(ge_109, "Runtime assertion failed for expression u89 >= 0 on node 'ge_89'"); ge_109 = _assert_scalar_89 = None + select_90 = torch.ops.aten.select.int(device_put_11, 0, 2) + _local_scalar_dense_90 = torch.ops.aten._local_scalar_dense.default(select_90); select_90 = None + ge_110 = _local_scalar_dense_90 >= 0 + _assert_scalar_90 = torch.ops.aten._assert_scalar.default(ge_110, "Runtime assertion failed for expression u90 >= 0 on node 'ge_90'"); ge_110 = _assert_scalar_90 = None + select_91 = torch.ops.aten.select.int(device_put_11, 0, 3) + _local_scalar_dense_91 = torch.ops.aten._local_scalar_dense.default(select_91); select_91 = None + ge_111 = _local_scalar_dense_91 >= 0 + _assert_scalar_91 = torch.ops.aten._assert_scalar.default(ge_111, "Runtime assertion failed for expression u91 >= 0 on node 'ge_91'"); ge_111 = _assert_scalar_91 = None + select_92 = torch.ops.aten.select.int(device_put_11, 0, 4) + _local_scalar_dense_92 = torch.ops.aten._local_scalar_dense.default(select_92); select_92 = None + ge_112 = _local_scalar_dense_92 >= 0 + _assert_scalar_92 = torch.ops.aten._assert_scalar.default(ge_112, "Runtime assertion failed for expression u92 >= 0 on node 'ge_92'"); ge_112 = _assert_scalar_92 = None + select_93 = torch.ops.aten.select.int(device_put_11, 0, 5) + _local_scalar_dense_93 = torch.ops.aten._local_scalar_dense.default(select_93); select_93 = None + ge_113 = _local_scalar_dense_93 >= 0 + _assert_scalar_93 = torch.ops.aten._assert_scalar.default(ge_113, "Runtime assertion failed for expression u93 >= 0 on node 'ge_93'"); ge_113 = _assert_scalar_93 = None + select_94 = torch.ops.aten.select.int(device_put_11, 0, 6) + _local_scalar_dense_94 = torch.ops.aten._local_scalar_dense.default(select_94); select_94 = None + ge_114 = _local_scalar_dense_94 >= 0 + _assert_scalar_94 = torch.ops.aten._assert_scalar.default(ge_114, "Runtime assertion failed for expression u94 >= 0 on node 'ge_94'"); ge_114 = _assert_scalar_94 = None + select_95 = torch.ops.aten.select.int(device_put_11, 0, 7); device_put_11 = None + _local_scalar_dense_95 = torch.ops.aten._local_scalar_dense.default(select_95); select_95 = None + ge_115 = _local_scalar_dense_95 >= 0 + _assert_scalar_95 = torch.ops.aten._assert_scalar.default(ge_115, "Runtime assertion failed for expression u95 >= 0 on node 'ge_95'"); ge_115 = _assert_scalar_95 = None + all_to_all_single_16 = torch.ops._c10d_functional.all_to_all_single.default(index_10, [_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95], [_local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87], '1033'); index_10 = None + sym_size_int_20 = torch.ops.aten.sym_size.int(all_to_all_single_16, 0) + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_16); all_to_all_single_16 = None + sym_sum_10 = torch.sym_sum((_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95)) + add_358 = sym_sum_10 + 64; sym_sum_10 = None + add_359 = add_358 + 8; add_358 = None + sub_123 = add_359 - 1; add_359 = None + floordiv_5 = sub_123 // 8; sub_123 = None + mul_267 = floordiv_5 * 8; floordiv_5 = None + cumsum_15 = torch.ops.aten.cumsum.default(wait_tensor_125, 0) + sub_124 = torch.ops.aten.sub.Tensor(cumsum_15, wait_tensor_125); cumsum_15 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_400, [0]); view_400 = None + clamp_min_5 = torch.ops.aten.clamp_min.default(sum_24, 8); sum_24 = None + add_360 = torch.ops.aten.add.Tensor(clamp_min_5, 8); clamp_min_5 = None + sub_125 = torch.ops.aten.sub.Tensor(add_360, 1); add_360 = None + div_28 = torch.ops.aten.div.Tensor_mode(sub_125, 8, rounding_mode = 'floor'); sub_125 = None + mul_268 = torch.ops.aten.mul.Tensor(div_28, 8); div_28 = None + convert_element_type_338 = torch.ops.prims.convert_element_type.default(mul_268, torch.int32); mul_268 = None + cumsum_16 = torch.ops.aten.cumsum.default(convert_element_type_338, 0) + sub_126 = torch.ops.aten.sub.Tensor(cumsum_16, convert_element_type_338); cumsum_16 = None + full_85 = torch.ops.aten.full.default([mul_267], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_267 = None + triton_kernel_wrapper_functional_proxy_5 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 5, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_125, 'start_index_values_ptr': sub_124, 'write_offsets_ptr': sub_126, 'output_ptr': full_85}, tensors_to_clone = ['output_ptr']); wait_tensor_125 = sub_124 = sub_126 = full_85 = None + getitem_572 = triton_kernel_wrapper_functional_proxy_5['output_ptr']; triton_kernel_wrapper_functional_proxy_5 = None + cat_49 = torch.ops.aten.cat.default([wait_tensor_126, full_default]); wait_tensor_126 = None + sym_size_int_21 = torch.ops.aten.sym_size.int(cat_49, 0) + sym_sum_11 = torch.sym_sum((1, _local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95)) + index_11 = torch.ops.aten.index.Tensor(cat_49, [getitem_572]); cat_49 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 16, '1025'); convert_element_type_340 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + split_31 = torch.ops.aten.split.Tensor(wait_tensor_127, 8); wait_tensor_127 = None + getitem_589 = split_31[0] + getitem_590 = split_31[1] + getitem_591 = split_31[2] + getitem_592 = split_31[3] + getitem_593 = split_31[4] + getitem_594 = split_31[5] + getitem_595 = split_31[6] + getitem_596 = split_31[7] + getitem_597 = split_31[8] + getitem_598 = split_31[9] + getitem_599 = split_31[10] + getitem_600 = split_31[11] + getitem_601 = split_31[12] + getitem_602 = split_31[13] + getitem_603 = split_31[14] + getitem_604 = split_31[15]; split_31 = None + cat_51 = torch.ops.aten.cat.default([getitem_589, getitem_590, getitem_591, getitem_592, getitem_593, getitem_594, getitem_595, getitem_596, getitem_597, getitem_598, getitem_599, getitem_600, getitem_601, getitem_602, getitem_603, getitem_604], 1); getitem_589 = getitem_590 = getitem_591 = getitem_592 = getitem_593 = getitem_594 = getitem_595 = getitem_596 = getitem_597 = getitem_598 = getitem_599 = getitem_600 = getitem_601 = getitem_602 = getitem_603 = getitem_604 = None + convert_element_type_342 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_342, 16, '1025'); convert_element_type_342 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + split_32 = torch.ops.aten.split.Tensor(wait_tensor_129, 8); wait_tensor_129 = None + getitem_605 = split_32[0] + getitem_606 = split_32[1] + getitem_607 = split_32[2] + getitem_608 = split_32[3] + getitem_609 = split_32[4] + getitem_610 = split_32[5] + getitem_611 = split_32[6] + getitem_612 = split_32[7] + getitem_613 = split_32[8] + getitem_614 = split_32[9] + getitem_615 = split_32[10] + getitem_616 = split_32[11] + getitem_617 = split_32[12] + getitem_618 = split_32[13] + getitem_619 = split_32[14] + getitem_620 = split_32[15]; split_32 = None + cat_52 = torch.ops.aten.cat.default([getitem_605, getitem_606, getitem_607, getitem_608, getitem_609, getitem_610, getitem_611, getitem_612, getitem_613, getitem_614, getitem_615, getitem_616, getitem_617, getitem_618, getitem_619, getitem_620], 1); getitem_605 = getitem_606 = getitem_607 = getitem_608 = getitem_609 = getitem_610 = getitem_611 = getitem_612 = getitem_613 = getitem_614 = getitem_615 = getitem_616 = getitem_617 = getitem_618 = getitem_619 = getitem_620 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_343, 16, '1025'); convert_element_type_343 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + split_33 = torch.ops.aten.split.Tensor(wait_tensor_130, 8); wait_tensor_130 = None + getitem_621 = split_33[0] + getitem_622 = split_33[1] + getitem_623 = split_33[2] + getitem_624 = split_33[3] + getitem_625 = split_33[4] + getitem_626 = split_33[5] + getitem_627 = split_33[6] + getitem_628 = split_33[7] + getitem_629 = split_33[8] + getitem_630 = split_33[9] + getitem_631 = split_33[10] + getitem_632 = split_33[11] + getitem_633 = split_33[12] + getitem_634 = split_33[13] + getitem_635 = split_33[14] + getitem_636 = split_33[15]; split_33 = None + cat_53 = torch.ops.aten.cat.default([getitem_621, getitem_622, getitem_623, getitem_624, getitem_625, getitem_626, getitem_627, getitem_628, getitem_629, getitem_630, getitem_631, getitem_632, getitem_633, getitem_634, getitem_635, getitem_636], 1); getitem_621 = getitem_622 = getitem_623 = getitem_624 = getitem_625 = getitem_626 = getitem_627 = getitem_628 = getitem_629 = getitem_630 = getitem_631 = getitem_632 = getitem_633 = getitem_634 = getitem_635 = getitem_636 = None + cumsum_17 = torch.ops.aten.cumsum.default(convert_element_type_338, 0, dtype = torch.int32); convert_element_type_338 = None + permute_95 = torch.ops.aten.permute.default(cat_51, [0, 2, 1]); cat_51 = None + _grouped_mm_15 = torch.ops.aten._grouped_mm.default(index_11, permute_95, cumsum_17) + convert_element_type_346 = torch.ops.prims.convert_element_type.default(_grouped_mm_15, torch.float32) + neg_11 = torch.ops.aten.neg.default(convert_element_type_346) + exp_17 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_372 = torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + div_29 = torch.ops.aten.div.Tensor(convert_element_type_346, add_372); convert_element_type_346 = add_372 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(div_29, torch.bfloat16); div_29 = None + permute_96 = torch.ops.aten.permute.default(cat_53, [0, 2, 1]); cat_53 = None + _grouped_mm_16 = torch.ops.aten._grouped_mm.default(index_11, permute_96, cumsum_17) + mul_280 = torch.ops.aten.mul.Tensor(convert_element_type_347, _grouped_mm_16); convert_element_type_347 = None + permute_97 = torch.ops.aten.permute.default(cat_52, [0, 2, 1]); cat_52 = None + _grouped_mm_17 = torch.ops.aten._grouped_mm.default(mul_280, permute_97, cumsum_17) + empty_5 = torch.ops.aten.empty.memory_format([sym_size_int_21, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_10 = torch.ops.aten.index_put.default(empty_5, [getitem_572], _grouped_mm_17); empty_5 = _grouped_mm_17 = None + slice_41 = torch.ops.aten.slice.Tensor(index_put_10, 0, 0, -1); index_put_10 = None + all_to_all_single_17 = torch.ops._c10d_functional.all_to_all_single.default(slice_41, [_local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87], [_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95], '1033'); slice_41 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_17); all_to_all_single_17 = None + convert_element_type_348 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_348, 128, '0'); convert_element_type_348 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_52 = torch.ops.aten.mm.default(view_393, permute_98); permute_98 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(mm_52, torch.float32) + neg_12 = torch.ops.aten.neg.default(convert_element_type_351) + exp_18 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_408 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + div_30 = torch.ops.aten.div.Tensor(convert_element_type_351, add_408); convert_element_type_351 = add_408 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(div_30, torch.bfloat16); div_30 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 128, '0'); convert_element_type_353 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_53 = torch.ops.aten.mm.default(view_393, permute_99); permute_99 = None + mul_300 = torch.ops.aten.mul.Tensor(convert_element_type_352, mm_53); convert_element_type_352 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_356, 128, '0'); convert_element_type_356 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + mm_54 = torch.ops.aten.mm.default(mul_300, permute_100); permute_100 = None + index_put_11 = torch.ops.aten.index_put.default(full_default_1, [getitem_571], wait_tensor_133); wait_tensor_133 = None + view_433 = torch.ops.aten.view.default(mul_262, [-1, 1, 6]); mul_262 = None + view_434 = torch.ops.aten.view.default(index_put_11, [-1, 6, 2048]); index_put_11 = None + convert_element_type_359 = torch.ops.prims.convert_element_type.default(view_434, torch.float32); view_434 = None + bmm_5 = torch.ops.aten.bmm.default(view_433, convert_element_type_359) + convert_element_type_360 = torch.ops.prims.convert_element_type.default(bmm_5, torch.bfloat16); bmm_5 = None + squeeze_5 = torch.ops.aten.squeeze.dim(convert_element_type_360, 1); convert_element_type_360 = None + add_412 = torch.ops.aten.add.Tensor(mm_54, squeeze_5); mm_54 = squeeze_5 = None + view_435 = torch.ops.aten.view.default(add_412, [2, 4096, 2048]); add_412 = None + add_413 = torch.ops.aten.add.Tensor(add_348, view_435); view_435 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 128, '0'); convert_element_type_361 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + convert_element_type_362 = torch.ops.prims.convert_element_type.default(add_413, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_362, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_414 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_414); add_414 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_362, rsqrt_21); convert_element_type_362 = None + mul_304 = torch.ops.aten.mul.Tensor(mul_303, wait_tensor_137); mul_303 = wait_tensor_137 = None + convert_element_type_363 = torch.ops.prims.convert_element_type.default(mul_304, torch.bfloat16); mul_304 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 128, '0'); convert_element_type_364 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + view_438 = torch.ops.aten.view.default(convert_element_type_363, [8192, 2048]); convert_element_type_363 = None + mm_55 = torch.ops.aten.mm.default(view_438, permute_101); permute_101 = None + view_439 = torch.ops.aten.view.default(mm_55, [2, 4096, 3072]); mm_55 = None + view_440 = torch.ops.aten.view.default(view_439, [2, 4096, -1, 192]); view_439 = None + split_with_sizes_21 = torch.ops.aten.split_with_sizes.default(view_440, [128, 64], -1); view_440 = None + getitem_669 = split_with_sizes_21[0] + getitem_670 = split_with_sizes_21[1]; split_with_sizes_21 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(getitem_670, torch.float32); getitem_670 = None + view_441 = torch.ops.aten.view.default(convert_element_type_367, [2, 4096, 16, -1, 2]); convert_element_type_367 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_441); view_441 = None + mul_305 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_7); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_305); mul_305 = None + view_443 = torch.ops.aten.view.default(view_as_real_14, [2, 4096, 16, 64]); view_as_real_14 = None + convert_element_type_368 = torch.ops.prims.convert_element_type.default(view_443, torch.bfloat16); view_443 = None + cat_56 = torch.ops.aten.cat.default([getitem_669, convert_element_type_368], -1); getitem_669 = convert_element_type_368 = None + convert_element_type_369 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_369, 128, '0'); convert_element_type_369 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + slice_43 = torch.ops.aten.slice.Tensor(wait_tensor_139, 0, 0, 576); wait_tensor_139 = None + permute_102 = torch.ops.aten.permute.default(slice_43, [1, 0]); slice_43 = None + mm_56 = torch.ops.aten.mm.default(view_438, permute_102); permute_102 = None + view_446 = torch.ops.aten.view.default(mm_56, [2, 4096, 576]); mm_56 = None + split_with_sizes_22 = torch.ops.aten.split_with_sizes.default(view_446, [512, 64], -1); view_446 = None + getitem_671 = split_with_sizes_22[0] + getitem_672 = split_with_sizes_22[1]; split_with_sizes_22 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(getitem_672, 2); getitem_672 = None + convert_element_type_372 = torch.ops.prims.convert_element_type.default(unsqueeze_13, torch.float32); unsqueeze_13 = None + view_447 = torch.ops.aten.view.default(convert_element_type_372, [2, 4096, 1, -1, 2]); convert_element_type_372 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_447); view_447 = None + mul_306 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_7); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_306); mul_306 = None + view_449 = torch.ops.aten.view.default(view_as_real_15, [2, 4096, 1, 64]); view_as_real_15 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(view_449, torch.bfloat16); view_449 = None + convert_element_type_374 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_374, 128, '0'); convert_element_type_374 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + convert_element_type_375 = torch.ops.prims.convert_element_type.default(getitem_671, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_375, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_415 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_415); add_415 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_375, rsqrt_22); convert_element_type_375 = None + mul_308 = torch.ops.aten.mul.Tensor(mul_307, wait_tensor_140); mul_307 = wait_tensor_140 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(mul_308, torch.bfloat16); mul_308 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_377, 128, '0'); convert_element_type_377 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_103 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + view_452 = torch.ops.aten.view.default(convert_element_type_376, [8192, 512]); convert_element_type_376 = None + mm_57 = torch.ops.aten.mm.default(view_452, permute_103); permute_103 = None + view_453 = torch.ops.aten.view.default(mm_57, [2, 4096, 4096]); mm_57 = None + view_454 = torch.ops.aten.view.default(view_453, [2, 4096, -1, 256]); view_453 = None + split_with_sizes_23 = torch.ops.aten.split_with_sizes.default(view_454, [128, 128], -1); view_454 = None + getitem_673 = split_with_sizes_23[0] + getitem_674 = split_with_sizes_23[1]; split_with_sizes_23 = None + expand_7 = torch.ops.aten.expand.default(convert_element_type_373, [-1, -1, 16, -1]); convert_element_type_373 = None + cat_57 = torch.ops.aten.cat.default([getitem_673, expand_7], -1); getitem_673 = expand_7 = None + permute_104 = torch.ops.aten.permute.default(cat_56, [0, 2, 1, 3]); cat_56 = None + permute_105 = torch.ops.aten.permute.default(cat_57, [0, 2, 1, 3]); cat_57 = None + permute_106 = torch.ops.aten.permute.default(getitem_674, [0, 2, 1, 3]); getitem_674 = None + sdpa_score7 = self.sdpa_score7 + sdpa_mask7 = self.sdpa_mask7 + flex_attention_7 = torch.ops.higher_order.flex_attention(permute_104, permute_105, permute_106, sdpa_score7, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask7), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score7 = sdpa_mask7 = None + getitem_675 = flex_attention_7[0] + getitem_676 = flex_attention_7[1]; flex_attention_7 = None + permute_107 = torch.ops.aten.permute.default(getitem_675, [0, 2, 1, 3]) + view_455 = torch.ops.aten.view.default(permute_107, [2, 4096, -1]); permute_107 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 128, '0'); convert_element_type_380 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + view_457 = torch.ops.aten.view.default(view_455, [8192, 2048]); view_455 = None + mm_58 = torch.ops.aten.mm.default(view_457, permute_108); view_457 = permute_108 = None + view_458 = torch.ops.aten.view.default(mm_58, [2, 4096, 2048]); mm_58 = None + add_416 = torch.ops.aten.add.Tensor(add_413, view_458); view_458 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 128, '0'); convert_element_type_383 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_416, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_417 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_417); add_417 = None + mul_309 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_310 = torch.ops.aten.mul.Tensor(mul_309, wait_tensor_143); mul_309 = wait_tensor_143 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_310, torch.bfloat16); mul_310 = None + view_460 = torch.ops.aten.view.default(convert_element_type_385, [-1, 2048]); convert_element_type_385 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 128, '0'); convert_element_type_386 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + slice_45 = torch.ops.aten.slice.Tensor(wait_tensor_144, 0, 0, 64); wait_tensor_144 = None + permute_109 = torch.ops.aten.permute.default(slice_45, [1, 0]); slice_45 = None + mm_59 = torch.ops.aten.mm.default(view_460, permute_109); permute_109 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(mm_59, torch.float32) + amax_6 = torch.ops.aten.amax.default(convert_element_type_389, [1], True) + sub_144 = torch.ops.aten.sub.Tensor(convert_element_type_389, amax_6); convert_element_type_389 = None + exp_19 = torch.ops.aten.exp.default(sub_144); sub_144 = None + sum_25 = torch.ops.aten.sum.dim_IntList(exp_19, [1], True) + div_31 = torch.ops.aten.div.Tensor(exp_19, sum_25); exp_19 = None + add_418 = torch.ops.aten.add.Tensor(div_31, primals_126); primals_126 = None + topk_6 = torch.ops.aten.topk.default(add_418, 6, -1, True, False); add_418 = None + getitem_679 = topk_6[1]; topk_6 = None + gather_6 = torch.ops.aten.gather.default(div_31, 1, getitem_679); div_31 = None + mul_311 = torch.ops.aten.mul.Tensor(gather_6, 1.0); gather_6 = None + view_462 = torch.ops.aten.view.default(getitem_679, [-1]) + histc_12 = torch.ops.aten.histc.default(view_462, 64, 0, 64) + add_419 = torch.ops.aten.add.Tensor(primals_128, histc_12) + sort_6 = torch.ops.aten.sort.stable(view_462, stable = True); view_462 = None + getitem_681 = sort_6[1]; sort_6 = None + div_32 = torch.ops.aten.div.Tensor_mode(getitem_681, 6, rounding_mode = 'floor') + index_12 = torch.ops.aten.index.Tensor(view_460, [div_32]) + all_to_all_single_18 = torch.ops._c10d_functional.all_to_all_single.default(histc_12, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_18); all_to_all_single_18 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_145); wait_tensor_145 = None + view_466 = torch.ops.aten.view.default(histc_12, [8, -1]); histc_12 = None + sum_26 = torch.ops.aten.sum.dim_IntList(view_466, [1]); view_466 = None + device_put_12 = torch.ops.prims.device_put.default(sum_26, device(type='cpu'), True); sum_26 = None + view_467 = torch.ops.aten.view.default(wait_tensor_146, [8, -1]) + sum_27 = torch.ops.aten.sum.dim_IntList(view_467, [1]) + device_put_13 = torch.ops.prims.device_put.default(sum_27, device(type='cpu')); sum_27 = None + select_96 = torch.ops.aten.select.int(device_put_12, 0, 0) + _local_scalar_dense_96 = torch.ops.aten._local_scalar_dense.default(select_96); select_96 = None + ge_120 = _local_scalar_dense_96 >= 0 + _assert_scalar_96 = torch.ops.aten._assert_scalar.default(ge_120, "Runtime assertion failed for expression u96 >= 0 on node 'ge_96'"); ge_120 = _assert_scalar_96 = None + select_97 = torch.ops.aten.select.int(device_put_12, 0, 1) + _local_scalar_dense_97 = torch.ops.aten._local_scalar_dense.default(select_97); select_97 = None + ge_121 = _local_scalar_dense_97 >= 0 + _assert_scalar_97 = torch.ops.aten._assert_scalar.default(ge_121, "Runtime assertion failed for expression u97 >= 0 on node 'ge_97'"); ge_121 = _assert_scalar_97 = None + select_98 = torch.ops.aten.select.int(device_put_12, 0, 2) + _local_scalar_dense_98 = torch.ops.aten._local_scalar_dense.default(select_98); select_98 = None + ge_122 = _local_scalar_dense_98 >= 0 + _assert_scalar_98 = torch.ops.aten._assert_scalar.default(ge_122, "Runtime assertion failed for expression u98 >= 0 on node 'ge_98'"); ge_122 = _assert_scalar_98 = None + select_99 = torch.ops.aten.select.int(device_put_12, 0, 3) + _local_scalar_dense_99 = torch.ops.aten._local_scalar_dense.default(select_99); select_99 = None + ge_123 = _local_scalar_dense_99 >= 0 + _assert_scalar_99 = torch.ops.aten._assert_scalar.default(ge_123, "Runtime assertion failed for expression u99 >= 0 on node 'ge_99'"); ge_123 = _assert_scalar_99 = None + select_100 = torch.ops.aten.select.int(device_put_12, 0, 4) + _local_scalar_dense_100 = torch.ops.aten._local_scalar_dense.default(select_100); select_100 = None + ge_124 = _local_scalar_dense_100 >= 0 + _assert_scalar_100 = torch.ops.aten._assert_scalar.default(ge_124, "Runtime assertion failed for expression u100 >= 0 on node 'ge_100'"); ge_124 = _assert_scalar_100 = None + select_101 = torch.ops.aten.select.int(device_put_12, 0, 5) + _local_scalar_dense_101 = torch.ops.aten._local_scalar_dense.default(select_101); select_101 = None + ge_125 = _local_scalar_dense_101 >= 0 + _assert_scalar_101 = torch.ops.aten._assert_scalar.default(ge_125, "Runtime assertion failed for expression u101 >= 0 on node 'ge_101'"); ge_125 = _assert_scalar_101 = None + select_102 = torch.ops.aten.select.int(device_put_12, 0, 6) + _local_scalar_dense_102 = torch.ops.aten._local_scalar_dense.default(select_102); select_102 = None + ge_126 = _local_scalar_dense_102 >= 0 + _assert_scalar_102 = torch.ops.aten._assert_scalar.default(ge_126, "Runtime assertion failed for expression u102 >= 0 on node 'ge_102'"); ge_126 = _assert_scalar_102 = None + select_103 = torch.ops.aten.select.int(device_put_12, 0, 7); device_put_12 = None + _local_scalar_dense_103 = torch.ops.aten._local_scalar_dense.default(select_103); select_103 = None + ge_127 = _local_scalar_dense_103 >= 0 + _assert_scalar_103 = torch.ops.aten._assert_scalar.default(ge_127, "Runtime assertion failed for expression u103 >= 0 on node 'ge_103'"); ge_127 = _assert_scalar_103 = None + select_104 = torch.ops.aten.select.int(device_put_13, 0, 0) + _local_scalar_dense_104 = torch.ops.aten._local_scalar_dense.default(select_104); select_104 = None + ge_128 = _local_scalar_dense_104 >= 0 + _assert_scalar_104 = torch.ops.aten._assert_scalar.default(ge_128, "Runtime assertion failed for expression u104 >= 0 on node 'ge_104'"); ge_128 = _assert_scalar_104 = None + select_105 = torch.ops.aten.select.int(device_put_13, 0, 1) + _local_scalar_dense_105 = torch.ops.aten._local_scalar_dense.default(select_105); select_105 = None + ge_129 = _local_scalar_dense_105 >= 0 + _assert_scalar_105 = torch.ops.aten._assert_scalar.default(ge_129, "Runtime assertion failed for expression u105 >= 0 on node 'ge_105'"); ge_129 = _assert_scalar_105 = None + select_106 = torch.ops.aten.select.int(device_put_13, 0, 2) + _local_scalar_dense_106 = torch.ops.aten._local_scalar_dense.default(select_106); select_106 = None + ge_130 = _local_scalar_dense_106 >= 0 + _assert_scalar_106 = torch.ops.aten._assert_scalar.default(ge_130, "Runtime assertion failed for expression u106 >= 0 on node 'ge_106'"); ge_130 = _assert_scalar_106 = None + select_107 = torch.ops.aten.select.int(device_put_13, 0, 3) + _local_scalar_dense_107 = torch.ops.aten._local_scalar_dense.default(select_107); select_107 = None + ge_131 = _local_scalar_dense_107 >= 0 + _assert_scalar_107 = torch.ops.aten._assert_scalar.default(ge_131, "Runtime assertion failed for expression u107 >= 0 on node 'ge_107'"); ge_131 = _assert_scalar_107 = None + select_108 = torch.ops.aten.select.int(device_put_13, 0, 4) + _local_scalar_dense_108 = torch.ops.aten._local_scalar_dense.default(select_108); select_108 = None + ge_132 = _local_scalar_dense_108 >= 0 + _assert_scalar_108 = torch.ops.aten._assert_scalar.default(ge_132, "Runtime assertion failed for expression u108 >= 0 on node 'ge_108'"); ge_132 = _assert_scalar_108 = None + select_109 = torch.ops.aten.select.int(device_put_13, 0, 5) + _local_scalar_dense_109 = torch.ops.aten._local_scalar_dense.default(select_109); select_109 = None + ge_133 = _local_scalar_dense_109 >= 0 + _assert_scalar_109 = torch.ops.aten._assert_scalar.default(ge_133, "Runtime assertion failed for expression u109 >= 0 on node 'ge_109'"); ge_133 = _assert_scalar_109 = None + select_110 = torch.ops.aten.select.int(device_put_13, 0, 6) + _local_scalar_dense_110 = torch.ops.aten._local_scalar_dense.default(select_110); select_110 = None + ge_134 = _local_scalar_dense_110 >= 0 + _assert_scalar_110 = torch.ops.aten._assert_scalar.default(ge_134, "Runtime assertion failed for expression u110 >= 0 on node 'ge_110'"); ge_134 = _assert_scalar_110 = None + select_111 = torch.ops.aten.select.int(device_put_13, 0, 7); device_put_13 = None + _local_scalar_dense_111 = torch.ops.aten._local_scalar_dense.default(select_111); select_111 = None + ge_135 = _local_scalar_dense_111 >= 0 + _assert_scalar_111 = torch.ops.aten._assert_scalar.default(ge_135, "Runtime assertion failed for expression u111 >= 0 on node 'ge_111'"); ge_135 = _assert_scalar_111 = None + all_to_all_single_19 = torch.ops._c10d_functional.all_to_all_single.default(index_12, [_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111], [_local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103], '1033'); index_12 = None + sym_size_int_24 = torch.ops.aten.sym_size.int(all_to_all_single_19, 0) + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_19); all_to_all_single_19 = None + sym_sum_12 = torch.sym_sum((_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111)) + add_426 = sym_sum_12 + 64; sym_sum_12 = None + add_427 = add_426 + 8; add_426 = None + sub_147 = add_427 - 1; add_427 = None + floordiv_6 = sub_147 // 8; sub_147 = None + mul_316 = floordiv_6 * 8; floordiv_6 = None + cumsum_18 = torch.ops.aten.cumsum.default(wait_tensor_146, 0) + sub_148 = torch.ops.aten.sub.Tensor(cumsum_18, wait_tensor_146); cumsum_18 = None + sum_28 = torch.ops.aten.sum.dim_IntList(view_467, [0]); view_467 = None + clamp_min_6 = torch.ops.aten.clamp_min.default(sum_28, 8); sum_28 = None + add_428 = torch.ops.aten.add.Tensor(clamp_min_6, 8); clamp_min_6 = None + sub_149 = torch.ops.aten.sub.Tensor(add_428, 1); add_428 = None + div_33 = torch.ops.aten.div.Tensor_mode(sub_149, 8, rounding_mode = 'floor'); sub_149 = None + mul_317 = torch.ops.aten.mul.Tensor(div_33, 8); div_33 = None + convert_element_type_392 = torch.ops.prims.convert_element_type.default(mul_317, torch.int32); mul_317 = None + cumsum_19 = torch.ops.aten.cumsum.default(convert_element_type_392, 0) + sub_150 = torch.ops.aten.sub.Tensor(cumsum_19, convert_element_type_392); cumsum_19 = None + full_98 = torch.ops.aten.full.default([mul_316], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_316 = None + triton_kernel_wrapper_functional_proxy_6 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 6, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_146, 'start_index_values_ptr': sub_148, 'write_offsets_ptr': sub_150, 'output_ptr': full_98}, tensors_to_clone = ['output_ptr']); wait_tensor_146 = sub_148 = sub_150 = full_98 = None + getitem_682 = triton_kernel_wrapper_functional_proxy_6['output_ptr']; triton_kernel_wrapper_functional_proxy_6 = None + cat_58 = torch.ops.aten.cat.default([wait_tensor_147, full_default]); wait_tensor_147 = None + sym_size_int_25 = torch.ops.aten.sym_size.int(cat_58, 0) + sym_sum_13 = torch.sym_sum((1, _local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111)) + index_13 = torch.ops.aten.index.Tensor(cat_58, [getitem_682]); cat_58 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 16, '1025'); convert_element_type_394 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + split_37 = torch.ops.aten.split.Tensor(wait_tensor_148, 8); wait_tensor_148 = None + getitem_699 = split_37[0] + getitem_700 = split_37[1] + getitem_701 = split_37[2] + getitem_702 = split_37[3] + getitem_703 = split_37[4] + getitem_704 = split_37[5] + getitem_705 = split_37[6] + getitem_706 = split_37[7] + getitem_707 = split_37[8] + getitem_708 = split_37[9] + getitem_709 = split_37[10] + getitem_710 = split_37[11] + getitem_711 = split_37[12] + getitem_712 = split_37[13] + getitem_713 = split_37[14] + getitem_714 = split_37[15]; split_37 = None + cat_60 = torch.ops.aten.cat.default([getitem_699, getitem_700, getitem_701, getitem_702, getitem_703, getitem_704, getitem_705, getitem_706, getitem_707, getitem_708, getitem_709, getitem_710, getitem_711, getitem_712, getitem_713, getitem_714], 1); getitem_699 = getitem_700 = getitem_701 = getitem_702 = getitem_703 = getitem_704 = getitem_705 = getitem_706 = getitem_707 = getitem_708 = getitem_709 = getitem_710 = getitem_711 = getitem_712 = getitem_713 = getitem_714 = None + convert_element_type_396 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_396, 16, '1025'); convert_element_type_396 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + split_38 = torch.ops.aten.split.Tensor(wait_tensor_150, 8); wait_tensor_150 = None + getitem_715 = split_38[0] + getitem_716 = split_38[1] + getitem_717 = split_38[2] + getitem_718 = split_38[3] + getitem_719 = split_38[4] + getitem_720 = split_38[5] + getitem_721 = split_38[6] + getitem_722 = split_38[7] + getitem_723 = split_38[8] + getitem_724 = split_38[9] + getitem_725 = split_38[10] + getitem_726 = split_38[11] + getitem_727 = split_38[12] + getitem_728 = split_38[13] + getitem_729 = split_38[14] + getitem_730 = split_38[15]; split_38 = None + cat_61 = torch.ops.aten.cat.default([getitem_715, getitem_716, getitem_717, getitem_718, getitem_719, getitem_720, getitem_721, getitem_722, getitem_723, getitem_724, getitem_725, getitem_726, getitem_727, getitem_728, getitem_729, getitem_730], 1); getitem_715 = getitem_716 = getitem_717 = getitem_718 = getitem_719 = getitem_720 = getitem_721 = getitem_722 = getitem_723 = getitem_724 = getitem_725 = getitem_726 = getitem_727 = getitem_728 = getitem_729 = getitem_730 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 16, '1025'); convert_element_type_397 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + split_39 = torch.ops.aten.split.Tensor(wait_tensor_151, 8); wait_tensor_151 = None + getitem_731 = split_39[0] + getitem_732 = split_39[1] + getitem_733 = split_39[2] + getitem_734 = split_39[3] + getitem_735 = split_39[4] + getitem_736 = split_39[5] + getitem_737 = split_39[6] + getitem_738 = split_39[7] + getitem_739 = split_39[8] + getitem_740 = split_39[9] + getitem_741 = split_39[10] + getitem_742 = split_39[11] + getitem_743 = split_39[12] + getitem_744 = split_39[13] + getitem_745 = split_39[14] + getitem_746 = split_39[15]; split_39 = None + cat_62 = torch.ops.aten.cat.default([getitem_731, getitem_732, getitem_733, getitem_734, getitem_735, getitem_736, getitem_737, getitem_738, getitem_739, getitem_740, getitem_741, getitem_742, getitem_743, getitem_744, getitem_745, getitem_746], 1); getitem_731 = getitem_732 = getitem_733 = getitem_734 = getitem_735 = getitem_736 = getitem_737 = getitem_738 = getitem_739 = getitem_740 = getitem_741 = getitem_742 = getitem_743 = getitem_744 = getitem_745 = getitem_746 = None + cumsum_20 = torch.ops.aten.cumsum.default(convert_element_type_392, 0, dtype = torch.int32); convert_element_type_392 = None + permute_110 = torch.ops.aten.permute.default(cat_60, [0, 2, 1]); cat_60 = None + _grouped_mm_18 = torch.ops.aten._grouped_mm.default(index_13, permute_110, cumsum_20) + convert_element_type_400 = torch.ops.prims.convert_element_type.default(_grouped_mm_18, torch.float32) + neg_13 = torch.ops.aten.neg.default(convert_element_type_400) + exp_20 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_440 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + div_34 = torch.ops.aten.div.Tensor(convert_element_type_400, add_440); convert_element_type_400 = add_440 = None + convert_element_type_401 = torch.ops.prims.convert_element_type.default(div_34, torch.bfloat16); div_34 = None + permute_111 = torch.ops.aten.permute.default(cat_62, [0, 2, 1]); cat_62 = None + _grouped_mm_19 = torch.ops.aten._grouped_mm.default(index_13, permute_111, cumsum_20) + mul_329 = torch.ops.aten.mul.Tensor(convert_element_type_401, _grouped_mm_19); convert_element_type_401 = None + permute_112 = torch.ops.aten.permute.default(cat_61, [0, 2, 1]); cat_61 = None + _grouped_mm_20 = torch.ops.aten._grouped_mm.default(mul_329, permute_112, cumsum_20) + empty_6 = torch.ops.aten.empty.memory_format([sym_size_int_25, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_12 = torch.ops.aten.index_put.default(empty_6, [getitem_682], _grouped_mm_20); empty_6 = _grouped_mm_20 = None + slice_47 = torch.ops.aten.slice.Tensor(index_put_12, 0, 0, -1); index_put_12 = None + all_to_all_single_20 = torch.ops._c10d_functional.all_to_all_single.default(slice_47, [_local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103], [_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111], '1033'); slice_47 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_20); all_to_all_single_20 = None + convert_element_type_402 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_402, 128, '0'); convert_element_type_402 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_113 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_60 = torch.ops.aten.mm.default(view_460, permute_113); permute_113 = None + convert_element_type_405 = torch.ops.prims.convert_element_type.default(mm_60, torch.float32) + neg_14 = torch.ops.aten.neg.default(convert_element_type_405) + exp_21 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_476 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + div_35 = torch.ops.aten.div.Tensor(convert_element_type_405, add_476); convert_element_type_405 = add_476 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(div_35, torch.bfloat16); div_35 = None + convert_element_type_407 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_407, 128, '0'); convert_element_type_407 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_114 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_61 = torch.ops.aten.mm.default(view_460, permute_114); permute_114 = None + mul_349 = torch.ops.aten.mul.Tensor(convert_element_type_406, mm_61); convert_element_type_406 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_410, 128, '0'); convert_element_type_410 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_115 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + mm_62 = torch.ops.aten.mm.default(mul_349, permute_115); permute_115 = None + index_put_13 = torch.ops.aten.index_put.default(full_default_1, [getitem_681], wait_tensor_154); wait_tensor_154 = None + view_500 = torch.ops.aten.view.default(mul_311, [-1, 1, 6]); mul_311 = None + view_501 = torch.ops.aten.view.default(index_put_13, [-1, 6, 2048]); index_put_13 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(view_501, torch.float32); view_501 = None + bmm_6 = torch.ops.aten.bmm.default(view_500, convert_element_type_413) + convert_element_type_414 = torch.ops.prims.convert_element_type.default(bmm_6, torch.bfloat16); bmm_6 = None + squeeze_6 = torch.ops.aten.squeeze.dim(convert_element_type_414, 1); convert_element_type_414 = None + add_480 = torch.ops.aten.add.Tensor(mm_62, squeeze_6); mm_62 = squeeze_6 = None + view_502 = torch.ops.aten.view.default(add_480, [2, 4096, 2048]); add_480 = None + add_481 = torch.ops.aten.add.Tensor(add_416, view_502); view_502 = None + convert_element_type_415 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_415, 128, '0'); convert_element_type_415 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(add_481, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_416, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_482 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_482); add_482 = None + mul_352 = torch.ops.aten.mul.Tensor(convert_element_type_416, rsqrt_24); convert_element_type_416 = None + mul_353 = torch.ops.aten.mul.Tensor(mul_352, wait_tensor_158); mul_352 = wait_tensor_158 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(mul_353, torch.bfloat16); mul_353 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 128, '0'); convert_element_type_418 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_116 = torch.ops.aten.permute.default(wait_tensor_159, [1, 0]); wait_tensor_159 = None + view_505 = torch.ops.aten.view.default(convert_element_type_417, [8192, 2048]); convert_element_type_417 = None + mm_63 = torch.ops.aten.mm.default(view_505, permute_116); permute_116 = None + view_506 = torch.ops.aten.view.default(mm_63, [2, 4096, 3072]); mm_63 = None + view_507 = torch.ops.aten.view.default(view_506, [2, 4096, -1, 192]); view_506 = None + split_with_sizes_24 = torch.ops.aten.split_with_sizes.default(view_507, [128, 64], -1); view_507 = None + getitem_779 = split_with_sizes_24[0] + getitem_780 = split_with_sizes_24[1]; split_with_sizes_24 = None + convert_element_type_421 = torch.ops.prims.convert_element_type.default(getitem_780, torch.float32); getitem_780 = None + view_508 = torch.ops.aten.view.default(convert_element_type_421, [2, 4096, 16, -1, 2]); convert_element_type_421 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_508); view_508 = None + mul_354 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_7); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_354); mul_354 = None + view_510 = torch.ops.aten.view.default(view_as_real_16, [2, 4096, 16, 64]); view_as_real_16 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_510, torch.bfloat16); view_510 = None + cat_65 = torch.ops.aten.cat.default([getitem_779, convert_element_type_422], -1); getitem_779 = convert_element_type_422 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_423, 128, '0'); convert_element_type_423 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + slice_49 = torch.ops.aten.slice.Tensor(wait_tensor_160, 0, 0, 576); wait_tensor_160 = None + permute_117 = torch.ops.aten.permute.default(slice_49, [1, 0]); slice_49 = None + mm_64 = torch.ops.aten.mm.default(view_505, permute_117); permute_117 = None + view_513 = torch.ops.aten.view.default(mm_64, [2, 4096, 576]); mm_64 = None + split_with_sizes_25 = torch.ops.aten.split_with_sizes.default(view_513, [512, 64], -1); view_513 = None + getitem_781 = split_with_sizes_25[0] + getitem_782 = split_with_sizes_25[1]; split_with_sizes_25 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(getitem_782, 2); getitem_782 = None + convert_element_type_426 = torch.ops.prims.convert_element_type.default(unsqueeze_15, torch.float32); unsqueeze_15 = None + view_514 = torch.ops.aten.view.default(convert_element_type_426, [2, 4096, 1, -1, 2]); convert_element_type_426 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_514); view_514 = None + mul_355 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_7); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_355); mul_355 = None + view_516 = torch.ops.aten.view.default(view_as_real_17, [2, 4096, 1, 64]); view_as_real_17 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(view_516, torch.bfloat16); view_516 = None + convert_element_type_428 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_428, 128, '0'); convert_element_type_428 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_429 = torch.ops.prims.convert_element_type.default(getitem_781, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_429, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_483 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_483); add_483 = None + mul_356 = torch.ops.aten.mul.Tensor(convert_element_type_429, rsqrt_25); convert_element_type_429 = None + mul_357 = torch.ops.aten.mul.Tensor(mul_356, wait_tensor_161); mul_356 = wait_tensor_161 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(mul_357, torch.bfloat16); mul_357 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_431, 128, '0'); convert_element_type_431 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + view_519 = torch.ops.aten.view.default(convert_element_type_430, [8192, 512]); convert_element_type_430 = None + mm_65 = torch.ops.aten.mm.default(view_519, permute_118); permute_118 = None + view_520 = torch.ops.aten.view.default(mm_65, [2, 4096, 4096]); mm_65 = None + view_521 = torch.ops.aten.view.default(view_520, [2, 4096, -1, 256]); view_520 = None + split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_521, [128, 128], -1); view_521 = None + getitem_783 = split_with_sizes_26[0] + getitem_784 = split_with_sizes_26[1]; split_with_sizes_26 = None + expand_8 = torch.ops.aten.expand.default(convert_element_type_427, [-1, -1, 16, -1]); convert_element_type_427 = None + cat_66 = torch.ops.aten.cat.default([getitem_783, expand_8], -1); getitem_783 = expand_8 = None + permute_119 = torch.ops.aten.permute.default(cat_65, [0, 2, 1, 3]); cat_65 = None + permute_120 = torch.ops.aten.permute.default(cat_66, [0, 2, 1, 3]); cat_66 = None + permute_121 = torch.ops.aten.permute.default(getitem_784, [0, 2, 1, 3]); getitem_784 = None + sdpa_score8 = self.sdpa_score8 + sdpa_mask8 = self.sdpa_mask8 + flex_attention_8 = torch.ops.higher_order.flex_attention(permute_119, permute_120, permute_121, sdpa_score8, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask8), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score8 = sdpa_mask8 = None + getitem_785 = flex_attention_8[0] + getitem_786 = flex_attention_8[1]; flex_attention_8 = None + permute_122 = torch.ops.aten.permute.default(getitem_785, [0, 2, 1, 3]) + view_522 = torch.ops.aten.view.default(permute_122, [2, 4096, -1]); permute_122 = None + convert_element_type_434 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_434, 128, '0'); convert_element_type_434 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + view_524 = torch.ops.aten.view.default(view_522, [8192, 2048]); view_522 = None + mm_66 = torch.ops.aten.mm.default(view_524, permute_123); view_524 = permute_123 = None + view_525 = torch.ops.aten.view.default(mm_66, [2, 4096, 2048]); mm_66 = None + add_484 = torch.ops.aten.add.Tensor(add_481, view_525); view_525 = None + convert_element_type_437 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_437, 128, '0'); convert_element_type_437 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_438 = torch.ops.prims.convert_element_type.default(add_484, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_438, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_485 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_485); add_485 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_438, rsqrt_26); convert_element_type_438 = None + mul_359 = torch.ops.aten.mul.Tensor(mul_358, wait_tensor_164); mul_358 = wait_tensor_164 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(mul_359, torch.bfloat16); mul_359 = None + view_527 = torch.ops.aten.view.default(convert_element_type_439, [-1, 2048]); convert_element_type_439 = None + convert_element_type_440 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_440, 128, '0'); convert_element_type_440 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + slice_51 = torch.ops.aten.slice.Tensor(wait_tensor_165, 0, 0, 64); wait_tensor_165 = None + permute_124 = torch.ops.aten.permute.default(slice_51, [1, 0]); slice_51 = None + mm_67 = torch.ops.aten.mm.default(view_527, permute_124); permute_124 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(mm_67, torch.float32) + amax_7 = torch.ops.aten.amax.default(convert_element_type_443, [1], True) + sub_168 = torch.ops.aten.sub.Tensor(convert_element_type_443, amax_7); convert_element_type_443 = None + exp_22 = torch.ops.aten.exp.default(sub_168); sub_168 = None + sum_29 = torch.ops.aten.sum.dim_IntList(exp_22, [1], True) + div_36 = torch.ops.aten.div.Tensor(exp_22, sum_29); exp_22 = None + add_486 = torch.ops.aten.add.Tensor(div_36, primals_142); primals_142 = None + topk_7 = torch.ops.aten.topk.default(add_486, 6, -1, True, False); add_486 = None + getitem_789 = topk_7[1]; topk_7 = None + gather_7 = torch.ops.aten.gather.default(div_36, 1, getitem_789); div_36 = None + mul_360 = torch.ops.aten.mul.Tensor(gather_7, 1.0); gather_7 = None + view_529 = torch.ops.aten.view.default(getitem_789, [-1]) + histc_14 = torch.ops.aten.histc.default(view_529, 64, 0, 64) + add_487 = torch.ops.aten.add.Tensor(primals_144, histc_14) + sort_7 = torch.ops.aten.sort.stable(view_529, stable = True); view_529 = None + getitem_791 = sort_7[1]; sort_7 = None + div_37 = torch.ops.aten.div.Tensor_mode(getitem_791, 6, rounding_mode = 'floor') + index_14 = torch.ops.aten.index.Tensor(view_527, [div_37]) + all_to_all_single_21 = torch.ops._c10d_functional.all_to_all_single.default(histc_14, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_21); all_to_all_single_21 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_166); wait_tensor_166 = None + view_533 = torch.ops.aten.view.default(histc_14, [8, -1]); histc_14 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_533, [1]); view_533 = None + device_put_14 = torch.ops.prims.device_put.default(sum_30, device(type='cpu'), True); sum_30 = None + view_534 = torch.ops.aten.view.default(wait_tensor_167, [8, -1]) + sum_31 = torch.ops.aten.sum.dim_IntList(view_534, [1]) + device_put_15 = torch.ops.prims.device_put.default(sum_31, device(type='cpu')); sum_31 = None + select_112 = torch.ops.aten.select.int(device_put_14, 0, 0) + _local_scalar_dense_112 = torch.ops.aten._local_scalar_dense.default(select_112); select_112 = None + ge_140 = _local_scalar_dense_112 >= 0 + _assert_scalar_112 = torch.ops.aten._assert_scalar.default(ge_140, "Runtime assertion failed for expression u112 >= 0 on node 'ge_112'"); ge_140 = _assert_scalar_112 = None + select_113 = torch.ops.aten.select.int(device_put_14, 0, 1) + _local_scalar_dense_113 = torch.ops.aten._local_scalar_dense.default(select_113); select_113 = None + ge_141 = _local_scalar_dense_113 >= 0 + _assert_scalar_113 = torch.ops.aten._assert_scalar.default(ge_141, "Runtime assertion failed for expression u113 >= 0 on node 'ge_113'"); ge_141 = _assert_scalar_113 = None + select_114 = torch.ops.aten.select.int(device_put_14, 0, 2) + _local_scalar_dense_114 = torch.ops.aten._local_scalar_dense.default(select_114); select_114 = None + ge_142 = _local_scalar_dense_114 >= 0 + _assert_scalar_114 = torch.ops.aten._assert_scalar.default(ge_142, "Runtime assertion failed for expression u114 >= 0 on node 'ge_114'"); ge_142 = _assert_scalar_114 = None + select_115 = torch.ops.aten.select.int(device_put_14, 0, 3) + _local_scalar_dense_115 = torch.ops.aten._local_scalar_dense.default(select_115); select_115 = None + ge_143 = _local_scalar_dense_115 >= 0 + _assert_scalar_115 = torch.ops.aten._assert_scalar.default(ge_143, "Runtime assertion failed for expression u115 >= 0 on node 'ge_115'"); ge_143 = _assert_scalar_115 = None + select_116 = torch.ops.aten.select.int(device_put_14, 0, 4) + _local_scalar_dense_116 = torch.ops.aten._local_scalar_dense.default(select_116); select_116 = None + ge_144 = _local_scalar_dense_116 >= 0 + _assert_scalar_116 = torch.ops.aten._assert_scalar.default(ge_144, "Runtime assertion failed for expression u116 >= 0 on node 'ge_116'"); ge_144 = _assert_scalar_116 = None + select_117 = torch.ops.aten.select.int(device_put_14, 0, 5) + _local_scalar_dense_117 = torch.ops.aten._local_scalar_dense.default(select_117); select_117 = None + ge_145 = _local_scalar_dense_117 >= 0 + _assert_scalar_117 = torch.ops.aten._assert_scalar.default(ge_145, "Runtime assertion failed for expression u117 >= 0 on node 'ge_117'"); ge_145 = _assert_scalar_117 = None + select_118 = torch.ops.aten.select.int(device_put_14, 0, 6) + _local_scalar_dense_118 = torch.ops.aten._local_scalar_dense.default(select_118); select_118 = None + ge_146 = _local_scalar_dense_118 >= 0 + _assert_scalar_118 = torch.ops.aten._assert_scalar.default(ge_146, "Runtime assertion failed for expression u118 >= 0 on node 'ge_118'"); ge_146 = _assert_scalar_118 = None + select_119 = torch.ops.aten.select.int(device_put_14, 0, 7); device_put_14 = None + _local_scalar_dense_119 = torch.ops.aten._local_scalar_dense.default(select_119); select_119 = None + ge_147 = _local_scalar_dense_119 >= 0 + _assert_scalar_119 = torch.ops.aten._assert_scalar.default(ge_147, "Runtime assertion failed for expression u119 >= 0 on node 'ge_119'"); ge_147 = _assert_scalar_119 = None + select_120 = torch.ops.aten.select.int(device_put_15, 0, 0) + _local_scalar_dense_120 = torch.ops.aten._local_scalar_dense.default(select_120); select_120 = None + ge_148 = _local_scalar_dense_120 >= 0 + _assert_scalar_120 = torch.ops.aten._assert_scalar.default(ge_148, "Runtime assertion failed for expression u120 >= 0 on node 'ge_120'"); ge_148 = _assert_scalar_120 = None + select_121 = torch.ops.aten.select.int(device_put_15, 0, 1) + _local_scalar_dense_121 = torch.ops.aten._local_scalar_dense.default(select_121); select_121 = None + ge_149 = _local_scalar_dense_121 >= 0 + _assert_scalar_121 = torch.ops.aten._assert_scalar.default(ge_149, "Runtime assertion failed for expression u121 >= 0 on node 'ge_121'"); ge_149 = _assert_scalar_121 = None + select_122 = torch.ops.aten.select.int(device_put_15, 0, 2) + _local_scalar_dense_122 = torch.ops.aten._local_scalar_dense.default(select_122); select_122 = None + ge_150 = _local_scalar_dense_122 >= 0 + _assert_scalar_122 = torch.ops.aten._assert_scalar.default(ge_150, "Runtime assertion failed for expression u122 >= 0 on node 'ge_122'"); ge_150 = _assert_scalar_122 = None + select_123 = torch.ops.aten.select.int(device_put_15, 0, 3) + _local_scalar_dense_123 = torch.ops.aten._local_scalar_dense.default(select_123); select_123 = None + ge_151 = _local_scalar_dense_123 >= 0 + _assert_scalar_123 = torch.ops.aten._assert_scalar.default(ge_151, "Runtime assertion failed for expression u123 >= 0 on node 'ge_123'"); ge_151 = _assert_scalar_123 = None + select_124 = torch.ops.aten.select.int(device_put_15, 0, 4) + _local_scalar_dense_124 = torch.ops.aten._local_scalar_dense.default(select_124); select_124 = None + ge_152 = _local_scalar_dense_124 >= 0 + _assert_scalar_124 = torch.ops.aten._assert_scalar.default(ge_152, "Runtime assertion failed for expression u124 >= 0 on node 'ge_124'"); ge_152 = _assert_scalar_124 = None + select_125 = torch.ops.aten.select.int(device_put_15, 0, 5) + _local_scalar_dense_125 = torch.ops.aten._local_scalar_dense.default(select_125); select_125 = None + ge_153 = _local_scalar_dense_125 >= 0 + _assert_scalar_125 = torch.ops.aten._assert_scalar.default(ge_153, "Runtime assertion failed for expression u125 >= 0 on node 'ge_125'"); ge_153 = _assert_scalar_125 = None + select_126 = torch.ops.aten.select.int(device_put_15, 0, 6) + _local_scalar_dense_126 = torch.ops.aten._local_scalar_dense.default(select_126); select_126 = None + ge_154 = _local_scalar_dense_126 >= 0 + _assert_scalar_126 = torch.ops.aten._assert_scalar.default(ge_154, "Runtime assertion failed for expression u126 >= 0 on node 'ge_126'"); ge_154 = _assert_scalar_126 = None + select_127 = torch.ops.aten.select.int(device_put_15, 0, 7); device_put_15 = None + _local_scalar_dense_127 = torch.ops.aten._local_scalar_dense.default(select_127); select_127 = None + ge_155 = _local_scalar_dense_127 >= 0 + _assert_scalar_127 = torch.ops.aten._assert_scalar.default(ge_155, "Runtime assertion failed for expression u127 >= 0 on node 'ge_127'"); ge_155 = _assert_scalar_127 = None + all_to_all_single_22 = torch.ops._c10d_functional.all_to_all_single.default(index_14, [_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127], [_local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119], '1033'); index_14 = None + sym_size_int_28 = torch.ops.aten.sym_size.int(all_to_all_single_22, 0) + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_22); all_to_all_single_22 = None + sym_sum_14 = torch.sym_sum((_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127)) + add_494 = sym_sum_14 + 64; sym_sum_14 = None + add_495 = add_494 + 8; add_494 = None + sub_171 = add_495 - 1; add_495 = None + floordiv_7 = sub_171 // 8; sub_171 = None + mul_365 = floordiv_7 * 8; floordiv_7 = None + cumsum_21 = torch.ops.aten.cumsum.default(wait_tensor_167, 0) + sub_172 = torch.ops.aten.sub.Tensor(cumsum_21, wait_tensor_167); cumsum_21 = None + sum_32 = torch.ops.aten.sum.dim_IntList(view_534, [0]); view_534 = None + clamp_min_7 = torch.ops.aten.clamp_min.default(sum_32, 8); sum_32 = None + add_496 = torch.ops.aten.add.Tensor(clamp_min_7, 8); clamp_min_7 = None + sub_173 = torch.ops.aten.sub.Tensor(add_496, 1); add_496 = None + div_38 = torch.ops.aten.div.Tensor_mode(sub_173, 8, rounding_mode = 'floor'); sub_173 = None + mul_366 = torch.ops.aten.mul.Tensor(div_38, 8); div_38 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(mul_366, torch.int32); mul_366 = None + cumsum_22 = torch.ops.aten.cumsum.default(convert_element_type_446, 0) + sub_174 = torch.ops.aten.sub.Tensor(cumsum_22, convert_element_type_446); cumsum_22 = None + full_111 = torch.ops.aten.full.default([mul_365], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_365 = None + triton_kernel_wrapper_functional_proxy_7 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 7, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_167, 'start_index_values_ptr': sub_172, 'write_offsets_ptr': sub_174, 'output_ptr': full_111}, tensors_to_clone = ['output_ptr']); wait_tensor_167 = sub_172 = sub_174 = full_111 = None + getitem_792 = triton_kernel_wrapper_functional_proxy_7['output_ptr']; triton_kernel_wrapper_functional_proxy_7 = None + cat_67 = torch.ops.aten.cat.default([wait_tensor_168, full_default]); wait_tensor_168 = None + sym_size_int_29 = torch.ops.aten.sym_size.int(cat_67, 0) + sym_sum_15 = torch.sym_sum((1, _local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127)) + index_15 = torch.ops.aten.index.Tensor(cat_67, [getitem_792]); cat_67 = None + convert_element_type_448 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_448, 16, '1025'); convert_element_type_448 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + split_43 = torch.ops.aten.split.Tensor(wait_tensor_169, 8); wait_tensor_169 = None + getitem_809 = split_43[0] + getitem_810 = split_43[1] + getitem_811 = split_43[2] + getitem_812 = split_43[3] + getitem_813 = split_43[4] + getitem_814 = split_43[5] + getitem_815 = split_43[6] + getitem_816 = split_43[7] + getitem_817 = split_43[8] + getitem_818 = split_43[9] + getitem_819 = split_43[10] + getitem_820 = split_43[11] + getitem_821 = split_43[12] + getitem_822 = split_43[13] + getitem_823 = split_43[14] + getitem_824 = split_43[15]; split_43 = None + cat_69 = torch.ops.aten.cat.default([getitem_809, getitem_810, getitem_811, getitem_812, getitem_813, getitem_814, getitem_815, getitem_816, getitem_817, getitem_818, getitem_819, getitem_820, getitem_821, getitem_822, getitem_823, getitem_824], 1); getitem_809 = getitem_810 = getitem_811 = getitem_812 = getitem_813 = getitem_814 = getitem_815 = getitem_816 = getitem_817 = getitem_818 = getitem_819 = getitem_820 = getitem_821 = getitem_822 = getitem_823 = getitem_824 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_450, 16, '1025'); convert_element_type_450 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + split_44 = torch.ops.aten.split.Tensor(wait_tensor_171, 8); wait_tensor_171 = None + getitem_825 = split_44[0] + getitem_826 = split_44[1] + getitem_827 = split_44[2] + getitem_828 = split_44[3] + getitem_829 = split_44[4] + getitem_830 = split_44[5] + getitem_831 = split_44[6] + getitem_832 = split_44[7] + getitem_833 = split_44[8] + getitem_834 = split_44[9] + getitem_835 = split_44[10] + getitem_836 = split_44[11] + getitem_837 = split_44[12] + getitem_838 = split_44[13] + getitem_839 = split_44[14] + getitem_840 = split_44[15]; split_44 = None + cat_70 = torch.ops.aten.cat.default([getitem_825, getitem_826, getitem_827, getitem_828, getitem_829, getitem_830, getitem_831, getitem_832, getitem_833, getitem_834, getitem_835, getitem_836, getitem_837, getitem_838, getitem_839, getitem_840], 1); getitem_825 = getitem_826 = getitem_827 = getitem_828 = getitem_829 = getitem_830 = getitem_831 = getitem_832 = getitem_833 = getitem_834 = getitem_835 = getitem_836 = getitem_837 = getitem_838 = getitem_839 = getitem_840 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 16, '1025'); convert_element_type_451 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + split_45 = torch.ops.aten.split.Tensor(wait_tensor_172, 8); wait_tensor_172 = None + getitem_841 = split_45[0] + getitem_842 = split_45[1] + getitem_843 = split_45[2] + getitem_844 = split_45[3] + getitem_845 = split_45[4] + getitem_846 = split_45[5] + getitem_847 = split_45[6] + getitem_848 = split_45[7] + getitem_849 = split_45[8] + getitem_850 = split_45[9] + getitem_851 = split_45[10] + getitem_852 = split_45[11] + getitem_853 = split_45[12] + getitem_854 = split_45[13] + getitem_855 = split_45[14] + getitem_856 = split_45[15]; split_45 = None + cat_71 = torch.ops.aten.cat.default([getitem_841, getitem_842, getitem_843, getitem_844, getitem_845, getitem_846, getitem_847, getitem_848, getitem_849, getitem_850, getitem_851, getitem_852, getitem_853, getitem_854, getitem_855, getitem_856], 1); getitem_841 = getitem_842 = getitem_843 = getitem_844 = getitem_845 = getitem_846 = getitem_847 = getitem_848 = getitem_849 = getitem_850 = getitem_851 = getitem_852 = getitem_853 = getitem_854 = getitem_855 = getitem_856 = None + cumsum_23 = torch.ops.aten.cumsum.default(convert_element_type_446, 0, dtype = torch.int32); convert_element_type_446 = None + permute_125 = torch.ops.aten.permute.default(cat_69, [0, 2, 1]); cat_69 = None + _grouped_mm_21 = torch.ops.aten._grouped_mm.default(index_15, permute_125, cumsum_23) + convert_element_type_454 = torch.ops.prims.convert_element_type.default(_grouped_mm_21, torch.float32) + neg_15 = torch.ops.aten.neg.default(convert_element_type_454) + exp_23 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_508 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + div_39 = torch.ops.aten.div.Tensor(convert_element_type_454, add_508); convert_element_type_454 = add_508 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(div_39, torch.bfloat16); div_39 = None + permute_126 = torch.ops.aten.permute.default(cat_71, [0, 2, 1]); cat_71 = None + _grouped_mm_22 = torch.ops.aten._grouped_mm.default(index_15, permute_126, cumsum_23) + mul_378 = torch.ops.aten.mul.Tensor(convert_element_type_455, _grouped_mm_22); convert_element_type_455 = None + permute_127 = torch.ops.aten.permute.default(cat_70, [0, 2, 1]); cat_70 = None + _grouped_mm_23 = torch.ops.aten._grouped_mm.default(mul_378, permute_127, cumsum_23) + empty_7 = torch.ops.aten.empty.memory_format([sym_size_int_29, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_14 = torch.ops.aten.index_put.default(empty_7, [getitem_792], _grouped_mm_23); empty_7 = _grouped_mm_23 = None + slice_53 = torch.ops.aten.slice.Tensor(index_put_14, 0, 0, -1); index_put_14 = None + all_to_all_single_23 = torch.ops._c10d_functional.all_to_all_single.default(slice_53, [_local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119], [_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127], '1033'); slice_53 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_23); all_to_all_single_23 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_456, 128, '0'); convert_element_type_456 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + mm_68 = torch.ops.aten.mm.default(view_527, permute_128); permute_128 = None + convert_element_type_459 = torch.ops.prims.convert_element_type.default(mm_68, torch.float32) + neg_16 = torch.ops.aten.neg.default(convert_element_type_459) + exp_24 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_544 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + div_40 = torch.ops.aten.div.Tensor(convert_element_type_459, add_544); convert_element_type_459 = add_544 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(div_40, torch.bfloat16); div_40 = None + convert_element_type_461 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_461, 128, '0'); convert_element_type_461 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_177, [1, 0]); wait_tensor_177 = None + mm_69 = torch.ops.aten.mm.default(view_527, permute_129); permute_129 = None + mul_398 = torch.ops.aten.mul.Tensor(convert_element_type_460, mm_69); convert_element_type_460 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_464, 128, '0'); convert_element_type_464 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + mm_70 = torch.ops.aten.mm.default(mul_398, permute_130); permute_130 = None + index_put_15 = torch.ops.aten.index_put.default(full_default_1, [getitem_791], wait_tensor_175); wait_tensor_175 = None + view_567 = torch.ops.aten.view.default(mul_360, [-1, 1, 6]); mul_360 = None + view_568 = torch.ops.aten.view.default(index_put_15, [-1, 6, 2048]); index_put_15 = None + convert_element_type_467 = torch.ops.prims.convert_element_type.default(view_568, torch.float32); view_568 = None + bmm_7 = torch.ops.aten.bmm.default(view_567, convert_element_type_467) + convert_element_type_468 = torch.ops.prims.convert_element_type.default(bmm_7, torch.bfloat16); bmm_7 = None + squeeze_7 = torch.ops.aten.squeeze.dim(convert_element_type_468, 1); convert_element_type_468 = None + add_548 = torch.ops.aten.add.Tensor(mm_70, squeeze_7); mm_70 = squeeze_7 = None + view_569 = torch.ops.aten.view.default(add_548, [2, 4096, 2048]); add_548 = None + add_549 = torch.ops.aten.add.Tensor(add_484, view_569); view_569 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 128, '0'); convert_element_type_469 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + convert_element_type_470 = torch.ops.prims.convert_element_type.default(add_549, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_470, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_550 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_550); add_550 = None + mul_401 = torch.ops.aten.mul.Tensor(convert_element_type_470, rsqrt_27); convert_element_type_470 = None + mul_402 = torch.ops.aten.mul.Tensor(mul_401, wait_tensor_179); mul_401 = wait_tensor_179 = None + convert_element_type_471 = torch.ops.prims.convert_element_type.default(mul_402, torch.bfloat16); mul_402 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 128, '0'); convert_element_type_472 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_572 = torch.ops.aten.view.default(convert_element_type_471, [8192, 2048]); convert_element_type_471 = None + mm_71 = torch.ops.aten.mm.default(view_572, permute_131); permute_131 = None + view_573 = torch.ops.aten.view.default(mm_71, [2, 4096, 3072]); mm_71 = None + view_574 = torch.ops.aten.view.default(view_573, [2, 4096, -1, 192]); view_573 = None + split_with_sizes_27 = torch.ops.aten.split_with_sizes.default(view_574, [128, 64], -1); view_574 = None + getitem_889 = split_with_sizes_27[0] + getitem_890 = split_with_sizes_27[1]; split_with_sizes_27 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(getitem_890, torch.float32); getitem_890 = None + view_575 = torch.ops.aten.view.default(convert_element_type_475, [2, 4096, 16, -1, 2]); convert_element_type_475 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_575); view_575 = None + mul_403 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_7); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_403); mul_403 = None + view_577 = torch.ops.aten.view.default(view_as_real_18, [2, 4096, 16, 64]); view_as_real_18 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_577, torch.bfloat16); view_577 = None + cat_74 = torch.ops.aten.cat.default([getitem_889, convert_element_type_476], -1); getitem_889 = convert_element_type_476 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_477, 128, '0'); convert_element_type_477 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + slice_55 = torch.ops.aten.slice.Tensor(wait_tensor_181, 0, 0, 576); wait_tensor_181 = None + permute_132 = torch.ops.aten.permute.default(slice_55, [1, 0]); slice_55 = None + mm_72 = torch.ops.aten.mm.default(view_572, permute_132); permute_132 = None + view_580 = torch.ops.aten.view.default(mm_72, [2, 4096, 576]); mm_72 = None + split_with_sizes_28 = torch.ops.aten.split_with_sizes.default(view_580, [512, 64], -1); view_580 = None + getitem_891 = split_with_sizes_28[0] + getitem_892 = split_with_sizes_28[1]; split_with_sizes_28 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(getitem_892, 2); getitem_892 = None + convert_element_type_480 = torch.ops.prims.convert_element_type.default(unsqueeze_17, torch.float32); unsqueeze_17 = None + view_581 = torch.ops.aten.view.default(convert_element_type_480, [2, 4096, 1, -1, 2]); convert_element_type_480 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_581); view_581 = None + mul_404 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_7); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_404); mul_404 = None + view_583 = torch.ops.aten.view.default(view_as_real_19, [2, 4096, 1, 64]); view_as_real_19 = None + convert_element_type_481 = torch.ops.prims.convert_element_type.default(view_583, torch.bfloat16); view_583 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 128, '0'); convert_element_type_482 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(getitem_891, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_551 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_551); add_551 = None + mul_405 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_28); convert_element_type_483 = None + mul_406 = torch.ops.aten.mul.Tensor(mul_405, wait_tensor_182); mul_405 = wait_tensor_182 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_406, torch.bfloat16); mul_406 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 128, '0'); convert_element_type_485 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + view_586 = torch.ops.aten.view.default(convert_element_type_484, [8192, 512]); convert_element_type_484 = None + mm_73 = torch.ops.aten.mm.default(view_586, permute_133); permute_133 = None + view_587 = torch.ops.aten.view.default(mm_73, [2, 4096, 4096]); mm_73 = None + view_588 = torch.ops.aten.view.default(view_587, [2, 4096, -1, 256]); view_587 = None + split_with_sizes_29 = torch.ops.aten.split_with_sizes.default(view_588, [128, 128], -1); view_588 = None + getitem_893 = split_with_sizes_29[0] + getitem_894 = split_with_sizes_29[1]; split_with_sizes_29 = None + expand_9 = torch.ops.aten.expand.default(convert_element_type_481, [-1, -1, 16, -1]); convert_element_type_481 = None + cat_75 = torch.ops.aten.cat.default([getitem_893, expand_9], -1); getitem_893 = expand_9 = None + permute_134 = torch.ops.aten.permute.default(cat_74, [0, 2, 1, 3]); cat_74 = None + permute_135 = torch.ops.aten.permute.default(cat_75, [0, 2, 1, 3]); cat_75 = None + permute_136 = torch.ops.aten.permute.default(getitem_894, [0, 2, 1, 3]); getitem_894 = None + sdpa_score9 = self.sdpa_score9 + sdpa_mask9 = self.sdpa_mask9 + flex_attention_9 = torch.ops.higher_order.flex_attention(permute_134, permute_135, permute_136, sdpa_score9, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask9), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score9 = sdpa_mask9 = None + getitem_895 = flex_attention_9[0] + getitem_896 = flex_attention_9[1]; flex_attention_9 = None + permute_137 = torch.ops.aten.permute.default(getitem_895, [0, 2, 1, 3]) + view_589 = torch.ops.aten.view.default(permute_137, [2, 4096, -1]); permute_137 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_488, 128, '0'); convert_element_type_488 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_138 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + view_591 = torch.ops.aten.view.default(view_589, [8192, 2048]); view_589 = None + mm_74 = torch.ops.aten.mm.default(view_591, permute_138); view_591 = permute_138 = None + view_592 = torch.ops.aten.view.default(mm_74, [2, 4096, 2048]); mm_74 = None + add_552 = torch.ops.aten.add.Tensor(add_549, view_592); view_592 = None + convert_element_type_491 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_491, 128, '0'); convert_element_type_491 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + convert_element_type_492 = torch.ops.prims.convert_element_type.default(add_552, torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_492, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_553 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_553); add_553 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_492, rsqrt_29); convert_element_type_492 = None + mul_408 = torch.ops.aten.mul.Tensor(mul_407, wait_tensor_185); mul_407 = wait_tensor_185 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(mul_408, torch.bfloat16); mul_408 = None + view_594 = torch.ops.aten.view.default(convert_element_type_493, [-1, 2048]); convert_element_type_493 = None + convert_element_type_494 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_494, 128, '0'); convert_element_type_494 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + slice_57 = torch.ops.aten.slice.Tensor(wait_tensor_186, 0, 0, 64); wait_tensor_186 = None + permute_139 = torch.ops.aten.permute.default(slice_57, [1, 0]); slice_57 = None + mm_75 = torch.ops.aten.mm.default(view_594, permute_139); permute_139 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(mm_75, torch.float32) + amax_8 = torch.ops.aten.amax.default(convert_element_type_497, [1], True) + sub_192 = torch.ops.aten.sub.Tensor(convert_element_type_497, amax_8); convert_element_type_497 = None + exp_25 = torch.ops.aten.exp.default(sub_192); sub_192 = None + sum_33 = torch.ops.aten.sum.dim_IntList(exp_25, [1], True) + div_41 = torch.ops.aten.div.Tensor(exp_25, sum_33); exp_25 = None + add_554 = torch.ops.aten.add.Tensor(div_41, primals_158); primals_158 = None + topk_8 = torch.ops.aten.topk.default(add_554, 6, -1, True, False); add_554 = None + getitem_899 = topk_8[1]; topk_8 = None + gather_8 = torch.ops.aten.gather.default(div_41, 1, getitem_899); div_41 = None + mul_409 = torch.ops.aten.mul.Tensor(gather_8, 1.0); gather_8 = None + view_596 = torch.ops.aten.view.default(getitem_899, [-1]) + histc_16 = torch.ops.aten.histc.default(view_596, 64, 0, 64) + add_555 = torch.ops.aten.add.Tensor(primals_160, histc_16) + sort_8 = torch.ops.aten.sort.stable(view_596, stable = True); view_596 = None + getitem_901 = sort_8[1]; sort_8 = None + div_42 = torch.ops.aten.div.Tensor_mode(getitem_901, 6, rounding_mode = 'floor') + index_16 = torch.ops.aten.index.Tensor(view_594, [div_42]) + all_to_all_single_24 = torch.ops._c10d_functional.all_to_all_single.default(histc_16, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_24); all_to_all_single_24 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_187); wait_tensor_187 = None + view_600 = torch.ops.aten.view.default(histc_16, [8, -1]); histc_16 = None + sum_34 = torch.ops.aten.sum.dim_IntList(view_600, [1]); view_600 = None + device_put_16 = torch.ops.prims.device_put.default(sum_34, device(type='cpu'), True); sum_34 = None + view_601 = torch.ops.aten.view.default(wait_tensor_188, [8, -1]) + sum_35 = torch.ops.aten.sum.dim_IntList(view_601, [1]) + device_put_17 = torch.ops.prims.device_put.default(sum_35, device(type='cpu')); sum_35 = None + select_128 = torch.ops.aten.select.int(device_put_16, 0, 0) + _local_scalar_dense_128 = torch.ops.aten._local_scalar_dense.default(select_128); select_128 = None + ge_160 = _local_scalar_dense_128 >= 0 + _assert_scalar_128 = torch.ops.aten._assert_scalar.default(ge_160, "Runtime assertion failed for expression u128 >= 0 on node 'ge_128'"); ge_160 = _assert_scalar_128 = None + select_129 = torch.ops.aten.select.int(device_put_16, 0, 1) + _local_scalar_dense_129 = torch.ops.aten._local_scalar_dense.default(select_129); select_129 = None + ge_161 = _local_scalar_dense_129 >= 0 + _assert_scalar_129 = torch.ops.aten._assert_scalar.default(ge_161, "Runtime assertion failed for expression u129 >= 0 on node 'ge_129'"); ge_161 = _assert_scalar_129 = None + select_130 = torch.ops.aten.select.int(device_put_16, 0, 2) + _local_scalar_dense_130 = torch.ops.aten._local_scalar_dense.default(select_130); select_130 = None + ge_162 = _local_scalar_dense_130 >= 0 + _assert_scalar_130 = torch.ops.aten._assert_scalar.default(ge_162, "Runtime assertion failed for expression u130 >= 0 on node 'ge_130'"); ge_162 = _assert_scalar_130 = None + select_131 = torch.ops.aten.select.int(device_put_16, 0, 3) + _local_scalar_dense_131 = torch.ops.aten._local_scalar_dense.default(select_131); select_131 = None + ge_163 = _local_scalar_dense_131 >= 0 + _assert_scalar_131 = torch.ops.aten._assert_scalar.default(ge_163, "Runtime assertion failed for expression u131 >= 0 on node 'ge_131'"); ge_163 = _assert_scalar_131 = None + select_132 = torch.ops.aten.select.int(device_put_16, 0, 4) + _local_scalar_dense_132 = torch.ops.aten._local_scalar_dense.default(select_132); select_132 = None + ge_164 = _local_scalar_dense_132 >= 0 + _assert_scalar_132 = torch.ops.aten._assert_scalar.default(ge_164, "Runtime assertion failed for expression u132 >= 0 on node 'ge_132'"); ge_164 = _assert_scalar_132 = None + select_133 = torch.ops.aten.select.int(device_put_16, 0, 5) + _local_scalar_dense_133 = torch.ops.aten._local_scalar_dense.default(select_133); select_133 = None + ge_165 = _local_scalar_dense_133 >= 0 + _assert_scalar_133 = torch.ops.aten._assert_scalar.default(ge_165, "Runtime assertion failed for expression u133 >= 0 on node 'ge_133'"); ge_165 = _assert_scalar_133 = None + select_134 = torch.ops.aten.select.int(device_put_16, 0, 6) + _local_scalar_dense_134 = torch.ops.aten._local_scalar_dense.default(select_134); select_134 = None + ge_166 = _local_scalar_dense_134 >= 0 + _assert_scalar_134 = torch.ops.aten._assert_scalar.default(ge_166, "Runtime assertion failed for expression u134 >= 0 on node 'ge_134'"); ge_166 = _assert_scalar_134 = None + select_135 = torch.ops.aten.select.int(device_put_16, 0, 7); device_put_16 = None + _local_scalar_dense_135 = torch.ops.aten._local_scalar_dense.default(select_135); select_135 = None + ge_167 = _local_scalar_dense_135 >= 0 + _assert_scalar_135 = torch.ops.aten._assert_scalar.default(ge_167, "Runtime assertion failed for expression u135 >= 0 on node 'ge_135'"); ge_167 = _assert_scalar_135 = None + select_136 = torch.ops.aten.select.int(device_put_17, 0, 0) + _local_scalar_dense_136 = torch.ops.aten._local_scalar_dense.default(select_136); select_136 = None + ge_168 = _local_scalar_dense_136 >= 0 + _assert_scalar_136 = torch.ops.aten._assert_scalar.default(ge_168, "Runtime assertion failed for expression u136 >= 0 on node 'ge_136'"); ge_168 = _assert_scalar_136 = None + select_137 = torch.ops.aten.select.int(device_put_17, 0, 1) + _local_scalar_dense_137 = torch.ops.aten._local_scalar_dense.default(select_137); select_137 = None + ge_169 = _local_scalar_dense_137 >= 0 + _assert_scalar_137 = torch.ops.aten._assert_scalar.default(ge_169, "Runtime assertion failed for expression u137 >= 0 on node 'ge_137'"); ge_169 = _assert_scalar_137 = None + select_138 = torch.ops.aten.select.int(device_put_17, 0, 2) + _local_scalar_dense_138 = torch.ops.aten._local_scalar_dense.default(select_138); select_138 = None + ge_170 = _local_scalar_dense_138 >= 0 + _assert_scalar_138 = torch.ops.aten._assert_scalar.default(ge_170, "Runtime assertion failed for expression u138 >= 0 on node 'ge_138'"); ge_170 = _assert_scalar_138 = None + select_139 = torch.ops.aten.select.int(device_put_17, 0, 3) + _local_scalar_dense_139 = torch.ops.aten._local_scalar_dense.default(select_139); select_139 = None + ge_171 = _local_scalar_dense_139 >= 0 + _assert_scalar_139 = torch.ops.aten._assert_scalar.default(ge_171, "Runtime assertion failed for expression u139 >= 0 on node 'ge_139'"); ge_171 = _assert_scalar_139 = None + select_140 = torch.ops.aten.select.int(device_put_17, 0, 4) + _local_scalar_dense_140 = torch.ops.aten._local_scalar_dense.default(select_140); select_140 = None + ge_172 = _local_scalar_dense_140 >= 0 + _assert_scalar_140 = torch.ops.aten._assert_scalar.default(ge_172, "Runtime assertion failed for expression u140 >= 0 on node 'ge_140'"); ge_172 = _assert_scalar_140 = None + select_141 = torch.ops.aten.select.int(device_put_17, 0, 5) + _local_scalar_dense_141 = torch.ops.aten._local_scalar_dense.default(select_141); select_141 = None + ge_173 = _local_scalar_dense_141 >= 0 + _assert_scalar_141 = torch.ops.aten._assert_scalar.default(ge_173, "Runtime assertion failed for expression u141 >= 0 on node 'ge_141'"); ge_173 = _assert_scalar_141 = None + select_142 = torch.ops.aten.select.int(device_put_17, 0, 6) + _local_scalar_dense_142 = torch.ops.aten._local_scalar_dense.default(select_142); select_142 = None + ge_174 = _local_scalar_dense_142 >= 0 + _assert_scalar_142 = torch.ops.aten._assert_scalar.default(ge_174, "Runtime assertion failed for expression u142 >= 0 on node 'ge_142'"); ge_174 = _assert_scalar_142 = None + select_143 = torch.ops.aten.select.int(device_put_17, 0, 7); device_put_17 = None + _local_scalar_dense_143 = torch.ops.aten._local_scalar_dense.default(select_143); select_143 = None + ge_175 = _local_scalar_dense_143 >= 0 + _assert_scalar_143 = torch.ops.aten._assert_scalar.default(ge_175, "Runtime assertion failed for expression u143 >= 0 on node 'ge_143'"); ge_175 = _assert_scalar_143 = None + all_to_all_single_25 = torch.ops._c10d_functional.all_to_all_single.default(index_16, [_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143], [_local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135], '1033'); index_16 = None + sym_size_int_32 = torch.ops.aten.sym_size.int(all_to_all_single_25, 0) + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_25); all_to_all_single_25 = None + sym_sum_16 = torch.sym_sum((_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143)) + add_562 = sym_sum_16 + 64; sym_sum_16 = None + add_563 = add_562 + 8; add_562 = None + sub_195 = add_563 - 1; add_563 = None + floordiv_8 = sub_195 // 8; sub_195 = None + mul_414 = floordiv_8 * 8; floordiv_8 = None + cumsum_24 = torch.ops.aten.cumsum.default(wait_tensor_188, 0) + sub_196 = torch.ops.aten.sub.Tensor(cumsum_24, wait_tensor_188); cumsum_24 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_601, [0]); view_601 = None + clamp_min_8 = torch.ops.aten.clamp_min.default(sum_36, 8); sum_36 = None + add_564 = torch.ops.aten.add.Tensor(clamp_min_8, 8); clamp_min_8 = None + sub_197 = torch.ops.aten.sub.Tensor(add_564, 1); add_564 = None + div_43 = torch.ops.aten.div.Tensor_mode(sub_197, 8, rounding_mode = 'floor'); sub_197 = None + mul_415 = torch.ops.aten.mul.Tensor(div_43, 8); div_43 = None + convert_element_type_500 = torch.ops.prims.convert_element_type.default(mul_415, torch.int32); mul_415 = None + cumsum_25 = torch.ops.aten.cumsum.default(convert_element_type_500, 0) + sub_198 = torch.ops.aten.sub.Tensor(cumsum_25, convert_element_type_500); cumsum_25 = None + full_124 = torch.ops.aten.full.default([mul_414], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_414 = None + triton_kernel_wrapper_functional_proxy_8 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 8, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_188, 'start_index_values_ptr': sub_196, 'write_offsets_ptr': sub_198, 'output_ptr': full_124}, tensors_to_clone = ['output_ptr']); wait_tensor_188 = sub_196 = sub_198 = full_124 = None + getitem_902 = triton_kernel_wrapper_functional_proxy_8['output_ptr']; triton_kernel_wrapper_functional_proxy_8 = None + cat_76 = torch.ops.aten.cat.default([wait_tensor_189, full_default]); wait_tensor_189 = None + sym_size_int_33 = torch.ops.aten.sym_size.int(cat_76, 0) + sym_sum_17 = torch.sym_sum((1, _local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143)) + index_17 = torch.ops.aten.index.Tensor(cat_76, [getitem_902]); cat_76 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 16, '1025'); convert_element_type_502 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + split_49 = torch.ops.aten.split.Tensor(wait_tensor_190, 8); wait_tensor_190 = None + getitem_919 = split_49[0] + getitem_920 = split_49[1] + getitem_921 = split_49[2] + getitem_922 = split_49[3] + getitem_923 = split_49[4] + getitem_924 = split_49[5] + getitem_925 = split_49[6] + getitem_926 = split_49[7] + getitem_927 = split_49[8] + getitem_928 = split_49[9] + getitem_929 = split_49[10] + getitem_930 = split_49[11] + getitem_931 = split_49[12] + getitem_932 = split_49[13] + getitem_933 = split_49[14] + getitem_934 = split_49[15]; split_49 = None + cat_78 = torch.ops.aten.cat.default([getitem_919, getitem_920, getitem_921, getitem_922, getitem_923, getitem_924, getitem_925, getitem_926, getitem_927, getitem_928, getitem_929, getitem_930, getitem_931, getitem_932, getitem_933, getitem_934], 1); getitem_919 = getitem_920 = getitem_921 = getitem_922 = getitem_923 = getitem_924 = getitem_925 = getitem_926 = getitem_927 = getitem_928 = getitem_929 = getitem_930 = getitem_931 = getitem_932 = getitem_933 = getitem_934 = None + convert_element_type_504 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_504, 16, '1025'); convert_element_type_504 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + split_50 = torch.ops.aten.split.Tensor(wait_tensor_192, 8); wait_tensor_192 = None + getitem_935 = split_50[0] + getitem_936 = split_50[1] + getitem_937 = split_50[2] + getitem_938 = split_50[3] + getitem_939 = split_50[4] + getitem_940 = split_50[5] + getitem_941 = split_50[6] + getitem_942 = split_50[7] + getitem_943 = split_50[8] + getitem_944 = split_50[9] + getitem_945 = split_50[10] + getitem_946 = split_50[11] + getitem_947 = split_50[12] + getitem_948 = split_50[13] + getitem_949 = split_50[14] + getitem_950 = split_50[15]; split_50 = None + cat_79 = torch.ops.aten.cat.default([getitem_935, getitem_936, getitem_937, getitem_938, getitem_939, getitem_940, getitem_941, getitem_942, getitem_943, getitem_944, getitem_945, getitem_946, getitem_947, getitem_948, getitem_949, getitem_950], 1); getitem_935 = getitem_936 = getitem_937 = getitem_938 = getitem_939 = getitem_940 = getitem_941 = getitem_942 = getitem_943 = getitem_944 = getitem_945 = getitem_946 = getitem_947 = getitem_948 = getitem_949 = getitem_950 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 16, '1025'); convert_element_type_505 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + split_51 = torch.ops.aten.split.Tensor(wait_tensor_193, 8); wait_tensor_193 = None + getitem_951 = split_51[0] + getitem_952 = split_51[1] + getitem_953 = split_51[2] + getitem_954 = split_51[3] + getitem_955 = split_51[4] + getitem_956 = split_51[5] + getitem_957 = split_51[6] + getitem_958 = split_51[7] + getitem_959 = split_51[8] + getitem_960 = split_51[9] + getitem_961 = split_51[10] + getitem_962 = split_51[11] + getitem_963 = split_51[12] + getitem_964 = split_51[13] + getitem_965 = split_51[14] + getitem_966 = split_51[15]; split_51 = None + cat_80 = torch.ops.aten.cat.default([getitem_951, getitem_952, getitem_953, getitem_954, getitem_955, getitem_956, getitem_957, getitem_958, getitem_959, getitem_960, getitem_961, getitem_962, getitem_963, getitem_964, getitem_965, getitem_966], 1); getitem_951 = getitem_952 = getitem_953 = getitem_954 = getitem_955 = getitem_956 = getitem_957 = getitem_958 = getitem_959 = getitem_960 = getitem_961 = getitem_962 = getitem_963 = getitem_964 = getitem_965 = getitem_966 = None + cumsum_26 = torch.ops.aten.cumsum.default(convert_element_type_500, 0, dtype = torch.int32); convert_element_type_500 = None + permute_140 = torch.ops.aten.permute.default(cat_78, [0, 2, 1]); cat_78 = None + _grouped_mm_24 = torch.ops.aten._grouped_mm.default(index_17, permute_140, cumsum_26) + convert_element_type_508 = torch.ops.prims.convert_element_type.default(_grouped_mm_24, torch.float32) + neg_17 = torch.ops.aten.neg.default(convert_element_type_508) + exp_26 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_576 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + div_44 = torch.ops.aten.div.Tensor(convert_element_type_508, add_576); convert_element_type_508 = add_576 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(div_44, torch.bfloat16); div_44 = None + permute_141 = torch.ops.aten.permute.default(cat_80, [0, 2, 1]); cat_80 = None + _grouped_mm_25 = torch.ops.aten._grouped_mm.default(index_17, permute_141, cumsum_26) + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_509, _grouped_mm_25); convert_element_type_509 = None + permute_142 = torch.ops.aten.permute.default(cat_79, [0, 2, 1]); cat_79 = None + _grouped_mm_26 = torch.ops.aten._grouped_mm.default(mul_427, permute_142, cumsum_26) + empty_8 = torch.ops.aten.empty.memory_format([sym_size_int_33, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_16 = torch.ops.aten.index_put.default(empty_8, [getitem_902], _grouped_mm_26); empty_8 = _grouped_mm_26 = None + slice_59 = torch.ops.aten.slice.Tensor(index_put_16, 0, 0, -1); index_put_16 = None + all_to_all_single_26 = torch.ops._c10d_functional.all_to_all_single.default(slice_59, [_local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135], [_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143], '1033'); slice_59 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_26); all_to_all_single_26 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_510, 128, '0'); convert_element_type_510 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_76 = torch.ops.aten.mm.default(view_594, permute_143); permute_143 = None + convert_element_type_513 = torch.ops.prims.convert_element_type.default(mm_76, torch.float32) + neg_18 = torch.ops.aten.neg.default(convert_element_type_513) + exp_27 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_612 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + div_45 = torch.ops.aten.div.Tensor(convert_element_type_513, add_612); convert_element_type_513 = add_612 = None + convert_element_type_514 = torch.ops.prims.convert_element_type.default(div_45, torch.bfloat16); div_45 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 128, '0'); convert_element_type_515 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + mm_77 = torch.ops.aten.mm.default(view_594, permute_144); permute_144 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_514, mm_77); convert_element_type_514 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 128, '0'); convert_element_type_518 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + mm_78 = torch.ops.aten.mm.default(mul_447, permute_145); permute_145 = None + index_put_17 = torch.ops.aten.index_put.default(full_default_1, [getitem_901], wait_tensor_196); wait_tensor_196 = None + view_634 = torch.ops.aten.view.default(mul_409, [-1, 1, 6]); mul_409 = None + view_635 = torch.ops.aten.view.default(index_put_17, [-1, 6, 2048]); index_put_17 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_635, torch.float32); view_635 = None + bmm_8 = torch.ops.aten.bmm.default(view_634, convert_element_type_521) + convert_element_type_522 = torch.ops.prims.convert_element_type.default(bmm_8, torch.bfloat16); bmm_8 = None + squeeze_8 = torch.ops.aten.squeeze.dim(convert_element_type_522, 1); convert_element_type_522 = None + add_616 = torch.ops.aten.add.Tensor(mm_78, squeeze_8); mm_78 = squeeze_8 = None + view_636 = torch.ops.aten.view.default(add_616, [2, 4096, 2048]); add_616 = None + add_617 = torch.ops.aten.add.Tensor(add_552, view_636); view_636 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 128, '0'); convert_element_type_523 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + convert_element_type_524 = torch.ops.prims.convert_element_type.default(add_617, torch.float32) + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_524, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_618 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_618); add_618 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_524, rsqrt_30); convert_element_type_524 = None + mul_451 = torch.ops.aten.mul.Tensor(mul_450, wait_tensor_200); mul_450 = wait_tensor_200 = None + convert_element_type_525 = torch.ops.prims.convert_element_type.default(mul_451, torch.bfloat16); mul_451 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 128, '0'); convert_element_type_526 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_146 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + view_639 = torch.ops.aten.view.default(convert_element_type_525, [8192, 2048]); convert_element_type_525 = None + mm_79 = torch.ops.aten.mm.default(view_639, permute_146); permute_146 = None + view_640 = torch.ops.aten.view.default(mm_79, [2, 4096, 3072]); mm_79 = None + view_641 = torch.ops.aten.view.default(view_640, [2, 4096, -1, 192]); view_640 = None + split_with_sizes_30 = torch.ops.aten.split_with_sizes.default(view_641, [128, 64], -1); view_641 = None + getitem_999 = split_with_sizes_30[0] + getitem_1000 = split_with_sizes_30[1]; split_with_sizes_30 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(getitem_1000, torch.float32); getitem_1000 = None + view_642 = torch.ops.aten.view.default(convert_element_type_529, [2, 4096, 16, -1, 2]); convert_element_type_529 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_642); view_642 = None + mul_452 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_7); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_452); mul_452 = None + view_644 = torch.ops.aten.view.default(view_as_real_20, [2, 4096, 16, 64]); view_as_real_20 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(view_644, torch.bfloat16); view_644 = None + cat_83 = torch.ops.aten.cat.default([getitem_999, convert_element_type_530], -1); getitem_999 = convert_element_type_530 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 128, '0'); convert_element_type_531 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + slice_61 = torch.ops.aten.slice.Tensor(wait_tensor_202, 0, 0, 576); wait_tensor_202 = None + permute_147 = torch.ops.aten.permute.default(slice_61, [1, 0]); slice_61 = None + mm_80 = torch.ops.aten.mm.default(view_639, permute_147); permute_147 = None + view_647 = torch.ops.aten.view.default(mm_80, [2, 4096, 576]); mm_80 = None + split_with_sizes_31 = torch.ops.aten.split_with_sizes.default(view_647, [512, 64], -1); view_647 = None + getitem_1001 = split_with_sizes_31[0] + getitem_1002 = split_with_sizes_31[1]; split_with_sizes_31 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(getitem_1002, 2); getitem_1002 = None + convert_element_type_534 = torch.ops.prims.convert_element_type.default(unsqueeze_19, torch.float32); unsqueeze_19 = None + view_648 = torch.ops.aten.view.default(convert_element_type_534, [2, 4096, 1, -1, 2]); convert_element_type_534 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_648); view_648 = None + mul_453 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_7); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_453); mul_453 = None + view_650 = torch.ops.aten.view.default(view_as_real_21, [2, 4096, 1, 64]); view_as_real_21 = None + convert_element_type_535 = torch.ops.prims.convert_element_type.default(view_650, torch.bfloat16); view_650 = None + convert_element_type_536 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_536, 128, '0'); convert_element_type_536 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + convert_element_type_537 = torch.ops.prims.convert_element_type.default(getitem_1001, torch.float32) + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_537, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_619 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_619); add_619 = None + mul_454 = torch.ops.aten.mul.Tensor(convert_element_type_537, rsqrt_31); convert_element_type_537 = None + mul_455 = torch.ops.aten.mul.Tensor(mul_454, wait_tensor_203); mul_454 = wait_tensor_203 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(mul_455, torch.bfloat16); mul_455 = None + convert_element_type_539 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_539, 128, '0'); convert_element_type_539 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_148 = torch.ops.aten.permute.default(wait_tensor_204, [1, 0]); wait_tensor_204 = None + view_653 = torch.ops.aten.view.default(convert_element_type_538, [8192, 512]); convert_element_type_538 = None + mm_81 = torch.ops.aten.mm.default(view_653, permute_148); permute_148 = None + view_654 = torch.ops.aten.view.default(mm_81, [2, 4096, 4096]); mm_81 = None + view_655 = torch.ops.aten.view.default(view_654, [2, 4096, -1, 256]); view_654 = None + split_with_sizes_32 = torch.ops.aten.split_with_sizes.default(view_655, [128, 128], -1); view_655 = None + getitem_1003 = split_with_sizes_32[0] + getitem_1004 = split_with_sizes_32[1]; split_with_sizes_32 = None + expand_10 = torch.ops.aten.expand.default(convert_element_type_535, [-1, -1, 16, -1]); convert_element_type_535 = None + cat_84 = torch.ops.aten.cat.default([getitem_1003, expand_10], -1); getitem_1003 = expand_10 = None + permute_149 = torch.ops.aten.permute.default(cat_83, [0, 2, 1, 3]); cat_83 = None + permute_150 = torch.ops.aten.permute.default(cat_84, [0, 2, 1, 3]); cat_84 = None + permute_151 = torch.ops.aten.permute.default(getitem_1004, [0, 2, 1, 3]); getitem_1004 = None + sdpa_score10 = self.sdpa_score10 + sdpa_mask10 = self.sdpa_mask10 + flex_attention_10 = torch.ops.higher_order.flex_attention(permute_149, permute_150, permute_151, sdpa_score10, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask10), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score10 = sdpa_mask10 = None + getitem_1005 = flex_attention_10[0] + getitem_1006 = flex_attention_10[1]; flex_attention_10 = None + permute_152 = torch.ops.aten.permute.default(getitem_1005, [0, 2, 1, 3]) + view_656 = torch.ops.aten.view.default(permute_152, [2, 4096, -1]); permute_152 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_542, 128, '0'); convert_element_type_542 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + view_658 = torch.ops.aten.view.default(view_656, [8192, 2048]); view_656 = None + mm_82 = torch.ops.aten.mm.default(view_658, permute_153); view_658 = permute_153 = None + view_659 = torch.ops.aten.view.default(mm_82, [2, 4096, 2048]); mm_82 = None + add_620 = torch.ops.aten.add.Tensor(add_617, view_659); view_659 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 128, '0'); convert_element_type_545 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + convert_element_type_546 = torch.ops.prims.convert_element_type.default(add_620, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_546, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_621 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_621); add_621 = None + mul_456 = torch.ops.aten.mul.Tensor(convert_element_type_546, rsqrt_32); convert_element_type_546 = None + mul_457 = torch.ops.aten.mul.Tensor(mul_456, wait_tensor_206); mul_456 = wait_tensor_206 = None + convert_element_type_547 = torch.ops.prims.convert_element_type.default(mul_457, torch.bfloat16); mul_457 = None + view_661 = torch.ops.aten.view.default(convert_element_type_547, [-1, 2048]); convert_element_type_547 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 128, '0'); convert_element_type_548 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + slice_63 = torch.ops.aten.slice.Tensor(wait_tensor_207, 0, 0, 64); wait_tensor_207 = None + permute_154 = torch.ops.aten.permute.default(slice_63, [1, 0]); slice_63 = None + mm_83 = torch.ops.aten.mm.default(view_661, permute_154); permute_154 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(mm_83, torch.float32) + amax_9 = torch.ops.aten.amax.default(convert_element_type_551, [1], True) + sub_216 = torch.ops.aten.sub.Tensor(convert_element_type_551, amax_9); convert_element_type_551 = None + exp_28 = torch.ops.aten.exp.default(sub_216); sub_216 = None + sum_37 = torch.ops.aten.sum.dim_IntList(exp_28, [1], True) + div_46 = torch.ops.aten.div.Tensor(exp_28, sum_37); exp_28 = None + add_622 = torch.ops.aten.add.Tensor(div_46, primals_174); primals_174 = None + topk_9 = torch.ops.aten.topk.default(add_622, 6, -1, True, False); add_622 = None + getitem_1009 = topk_9[1]; topk_9 = None + gather_9 = torch.ops.aten.gather.default(div_46, 1, getitem_1009); div_46 = None + mul_458 = torch.ops.aten.mul.Tensor(gather_9, 1.0); gather_9 = None + view_663 = torch.ops.aten.view.default(getitem_1009, [-1]) + histc_18 = torch.ops.aten.histc.default(view_663, 64, 0, 64) + add_623 = torch.ops.aten.add.Tensor(primals_176, histc_18) + sort_9 = torch.ops.aten.sort.stable(view_663, stable = True); view_663 = None + getitem_1011 = sort_9[1]; sort_9 = None + div_47 = torch.ops.aten.div.Tensor_mode(getitem_1011, 6, rounding_mode = 'floor') + index_18 = torch.ops.aten.index.Tensor(view_661, [div_47]) + all_to_all_single_27 = torch.ops._c10d_functional.all_to_all_single.default(histc_18, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_27); all_to_all_single_27 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_208); wait_tensor_208 = None + view_667 = torch.ops.aten.view.default(histc_18, [8, -1]); histc_18 = None + sum_38 = torch.ops.aten.sum.dim_IntList(view_667, [1]); view_667 = None + device_put_18 = torch.ops.prims.device_put.default(sum_38, device(type='cpu'), True); sum_38 = None + view_668 = torch.ops.aten.view.default(wait_tensor_209, [8, -1]) + sum_39 = torch.ops.aten.sum.dim_IntList(view_668, [1]) + device_put_19 = torch.ops.prims.device_put.default(sum_39, device(type='cpu')); sum_39 = None + select_144 = torch.ops.aten.select.int(device_put_18, 0, 0) + _local_scalar_dense_144 = torch.ops.aten._local_scalar_dense.default(select_144); select_144 = None + ge_180 = _local_scalar_dense_144 >= 0 + _assert_scalar_144 = torch.ops.aten._assert_scalar.default(ge_180, "Runtime assertion failed for expression u144 >= 0 on node 'ge_144'"); ge_180 = _assert_scalar_144 = None + select_145 = torch.ops.aten.select.int(device_put_18, 0, 1) + _local_scalar_dense_145 = torch.ops.aten._local_scalar_dense.default(select_145); select_145 = None + ge_181 = _local_scalar_dense_145 >= 0 + _assert_scalar_145 = torch.ops.aten._assert_scalar.default(ge_181, "Runtime assertion failed for expression u145 >= 0 on node 'ge_145'"); ge_181 = _assert_scalar_145 = None + select_146 = torch.ops.aten.select.int(device_put_18, 0, 2) + _local_scalar_dense_146 = torch.ops.aten._local_scalar_dense.default(select_146); select_146 = None + ge_182 = _local_scalar_dense_146 >= 0 + _assert_scalar_146 = torch.ops.aten._assert_scalar.default(ge_182, "Runtime assertion failed for expression u146 >= 0 on node 'ge_146'"); ge_182 = _assert_scalar_146 = None + select_147 = torch.ops.aten.select.int(device_put_18, 0, 3) + _local_scalar_dense_147 = torch.ops.aten._local_scalar_dense.default(select_147); select_147 = None + ge_183 = _local_scalar_dense_147 >= 0 + _assert_scalar_147 = torch.ops.aten._assert_scalar.default(ge_183, "Runtime assertion failed for expression u147 >= 0 on node 'ge_147'"); ge_183 = _assert_scalar_147 = None + select_148 = torch.ops.aten.select.int(device_put_18, 0, 4) + _local_scalar_dense_148 = torch.ops.aten._local_scalar_dense.default(select_148); select_148 = None + ge_184 = _local_scalar_dense_148 >= 0 + _assert_scalar_148 = torch.ops.aten._assert_scalar.default(ge_184, "Runtime assertion failed for expression u148 >= 0 on node 'ge_148'"); ge_184 = _assert_scalar_148 = None + select_149 = torch.ops.aten.select.int(device_put_18, 0, 5) + _local_scalar_dense_149 = torch.ops.aten._local_scalar_dense.default(select_149); select_149 = None + ge_185 = _local_scalar_dense_149 >= 0 + _assert_scalar_149 = torch.ops.aten._assert_scalar.default(ge_185, "Runtime assertion failed for expression u149 >= 0 on node 'ge_149'"); ge_185 = _assert_scalar_149 = None + select_150 = torch.ops.aten.select.int(device_put_18, 0, 6) + _local_scalar_dense_150 = torch.ops.aten._local_scalar_dense.default(select_150); select_150 = None + ge_186 = _local_scalar_dense_150 >= 0 + _assert_scalar_150 = torch.ops.aten._assert_scalar.default(ge_186, "Runtime assertion failed for expression u150 >= 0 on node 'ge_150'"); ge_186 = _assert_scalar_150 = None + select_151 = torch.ops.aten.select.int(device_put_18, 0, 7); device_put_18 = None + _local_scalar_dense_151 = torch.ops.aten._local_scalar_dense.default(select_151); select_151 = None + ge_187 = _local_scalar_dense_151 >= 0 + _assert_scalar_151 = torch.ops.aten._assert_scalar.default(ge_187, "Runtime assertion failed for expression u151 >= 0 on node 'ge_151'"); ge_187 = _assert_scalar_151 = None + select_152 = torch.ops.aten.select.int(device_put_19, 0, 0) + _local_scalar_dense_152 = torch.ops.aten._local_scalar_dense.default(select_152); select_152 = None + ge_188 = _local_scalar_dense_152 >= 0 + _assert_scalar_152 = torch.ops.aten._assert_scalar.default(ge_188, "Runtime assertion failed for expression u152 >= 0 on node 'ge_152'"); ge_188 = _assert_scalar_152 = None + select_153 = torch.ops.aten.select.int(device_put_19, 0, 1) + _local_scalar_dense_153 = torch.ops.aten._local_scalar_dense.default(select_153); select_153 = None + ge_189 = _local_scalar_dense_153 >= 0 + _assert_scalar_153 = torch.ops.aten._assert_scalar.default(ge_189, "Runtime assertion failed for expression u153 >= 0 on node 'ge_153'"); ge_189 = _assert_scalar_153 = None + select_154 = torch.ops.aten.select.int(device_put_19, 0, 2) + _local_scalar_dense_154 = torch.ops.aten._local_scalar_dense.default(select_154); select_154 = None + ge_190 = _local_scalar_dense_154 >= 0 + _assert_scalar_154 = torch.ops.aten._assert_scalar.default(ge_190, "Runtime assertion failed for expression u154 >= 0 on node 'ge_154'"); ge_190 = _assert_scalar_154 = None + select_155 = torch.ops.aten.select.int(device_put_19, 0, 3) + _local_scalar_dense_155 = torch.ops.aten._local_scalar_dense.default(select_155); select_155 = None + ge_191 = _local_scalar_dense_155 >= 0 + _assert_scalar_155 = torch.ops.aten._assert_scalar.default(ge_191, "Runtime assertion failed for expression u155 >= 0 on node 'ge_155'"); ge_191 = _assert_scalar_155 = None + select_156 = torch.ops.aten.select.int(device_put_19, 0, 4) + _local_scalar_dense_156 = torch.ops.aten._local_scalar_dense.default(select_156); select_156 = None + ge_192 = _local_scalar_dense_156 >= 0 + _assert_scalar_156 = torch.ops.aten._assert_scalar.default(ge_192, "Runtime assertion failed for expression u156 >= 0 on node 'ge_156'"); ge_192 = _assert_scalar_156 = None + select_157 = torch.ops.aten.select.int(device_put_19, 0, 5) + _local_scalar_dense_157 = torch.ops.aten._local_scalar_dense.default(select_157); select_157 = None + ge_193 = _local_scalar_dense_157 >= 0 + _assert_scalar_157 = torch.ops.aten._assert_scalar.default(ge_193, "Runtime assertion failed for expression u157 >= 0 on node 'ge_157'"); ge_193 = _assert_scalar_157 = None + select_158 = torch.ops.aten.select.int(device_put_19, 0, 6) + _local_scalar_dense_158 = torch.ops.aten._local_scalar_dense.default(select_158); select_158 = None + ge_194 = _local_scalar_dense_158 >= 0 + _assert_scalar_158 = torch.ops.aten._assert_scalar.default(ge_194, "Runtime assertion failed for expression u158 >= 0 on node 'ge_158'"); ge_194 = _assert_scalar_158 = None + select_159 = torch.ops.aten.select.int(device_put_19, 0, 7); device_put_19 = None + _local_scalar_dense_159 = torch.ops.aten._local_scalar_dense.default(select_159); select_159 = None + ge_195 = _local_scalar_dense_159 >= 0 + _assert_scalar_159 = torch.ops.aten._assert_scalar.default(ge_195, "Runtime assertion failed for expression u159 >= 0 on node 'ge_159'"); ge_195 = _assert_scalar_159 = None + all_to_all_single_28 = torch.ops._c10d_functional.all_to_all_single.default(index_18, [_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159], [_local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151], '1033'); index_18 = None + sym_size_int_36 = torch.ops.aten.sym_size.int(all_to_all_single_28, 0) + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_28); all_to_all_single_28 = None + sym_sum_18 = torch.sym_sum((_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159)) + add_630 = sym_sum_18 + 64; sym_sum_18 = None + add_631 = add_630 + 8; add_630 = None + sub_219 = add_631 - 1; add_631 = None + floordiv_9 = sub_219 // 8; sub_219 = None + mul_463 = floordiv_9 * 8; floordiv_9 = None + cumsum_27 = torch.ops.aten.cumsum.default(wait_tensor_209, 0) + sub_220 = torch.ops.aten.sub.Tensor(cumsum_27, wait_tensor_209); cumsum_27 = None + sum_40 = torch.ops.aten.sum.dim_IntList(view_668, [0]); view_668 = None + clamp_min_9 = torch.ops.aten.clamp_min.default(sum_40, 8); sum_40 = None + add_632 = torch.ops.aten.add.Tensor(clamp_min_9, 8); clamp_min_9 = None + sub_221 = torch.ops.aten.sub.Tensor(add_632, 1); add_632 = None + div_48 = torch.ops.aten.div.Tensor_mode(sub_221, 8, rounding_mode = 'floor'); sub_221 = None + mul_464 = torch.ops.aten.mul.Tensor(div_48, 8); div_48 = None + convert_element_type_554 = torch.ops.prims.convert_element_type.default(mul_464, torch.int32); mul_464 = None + cumsum_28 = torch.ops.aten.cumsum.default(convert_element_type_554, 0) + sub_222 = torch.ops.aten.sub.Tensor(cumsum_28, convert_element_type_554); cumsum_28 = None + full_137 = torch.ops.aten.full.default([mul_463], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_463 = None + triton_kernel_wrapper_functional_proxy_9 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 9, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_209, 'start_index_values_ptr': sub_220, 'write_offsets_ptr': sub_222, 'output_ptr': full_137}, tensors_to_clone = ['output_ptr']); wait_tensor_209 = sub_220 = sub_222 = full_137 = None + getitem_1012 = triton_kernel_wrapper_functional_proxy_9['output_ptr']; triton_kernel_wrapper_functional_proxy_9 = None + cat_85 = torch.ops.aten.cat.default([wait_tensor_210, full_default]); wait_tensor_210 = None + sym_size_int_37 = torch.ops.aten.sym_size.int(cat_85, 0) + sym_sum_19 = torch.sym_sum((1, _local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159)) + index_19 = torch.ops.aten.index.Tensor(cat_85, [getitem_1012]); cat_85 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 16, '1025'); convert_element_type_556 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + split_55 = torch.ops.aten.split.Tensor(wait_tensor_211, 8); wait_tensor_211 = None + getitem_1029 = split_55[0] + getitem_1030 = split_55[1] + getitem_1031 = split_55[2] + getitem_1032 = split_55[3] + getitem_1033 = split_55[4] + getitem_1034 = split_55[5] + getitem_1035 = split_55[6] + getitem_1036 = split_55[7] + getitem_1037 = split_55[8] + getitem_1038 = split_55[9] + getitem_1039 = split_55[10] + getitem_1040 = split_55[11] + getitem_1041 = split_55[12] + getitem_1042 = split_55[13] + getitem_1043 = split_55[14] + getitem_1044 = split_55[15]; split_55 = None + cat_87 = torch.ops.aten.cat.default([getitem_1029, getitem_1030, getitem_1031, getitem_1032, getitem_1033, getitem_1034, getitem_1035, getitem_1036, getitem_1037, getitem_1038, getitem_1039, getitem_1040, getitem_1041, getitem_1042, getitem_1043, getitem_1044], 1); getitem_1029 = getitem_1030 = getitem_1031 = getitem_1032 = getitem_1033 = getitem_1034 = getitem_1035 = getitem_1036 = getitem_1037 = getitem_1038 = getitem_1039 = getitem_1040 = getitem_1041 = getitem_1042 = getitem_1043 = getitem_1044 = None + convert_element_type_558 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_558, 16, '1025'); convert_element_type_558 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + split_56 = torch.ops.aten.split.Tensor(wait_tensor_213, 8); wait_tensor_213 = None + getitem_1045 = split_56[0] + getitem_1046 = split_56[1] + getitem_1047 = split_56[2] + getitem_1048 = split_56[3] + getitem_1049 = split_56[4] + getitem_1050 = split_56[5] + getitem_1051 = split_56[6] + getitem_1052 = split_56[7] + getitem_1053 = split_56[8] + getitem_1054 = split_56[9] + getitem_1055 = split_56[10] + getitem_1056 = split_56[11] + getitem_1057 = split_56[12] + getitem_1058 = split_56[13] + getitem_1059 = split_56[14] + getitem_1060 = split_56[15]; split_56 = None + cat_88 = torch.ops.aten.cat.default([getitem_1045, getitem_1046, getitem_1047, getitem_1048, getitem_1049, getitem_1050, getitem_1051, getitem_1052, getitem_1053, getitem_1054, getitem_1055, getitem_1056, getitem_1057, getitem_1058, getitem_1059, getitem_1060], 1); getitem_1045 = getitem_1046 = getitem_1047 = getitem_1048 = getitem_1049 = getitem_1050 = getitem_1051 = getitem_1052 = getitem_1053 = getitem_1054 = getitem_1055 = getitem_1056 = getitem_1057 = getitem_1058 = getitem_1059 = getitem_1060 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 16, '1025'); convert_element_type_559 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + split_57 = torch.ops.aten.split.Tensor(wait_tensor_214, 8); wait_tensor_214 = None + getitem_1061 = split_57[0] + getitem_1062 = split_57[1] + getitem_1063 = split_57[2] + getitem_1064 = split_57[3] + getitem_1065 = split_57[4] + getitem_1066 = split_57[5] + getitem_1067 = split_57[6] + getitem_1068 = split_57[7] + getitem_1069 = split_57[8] + getitem_1070 = split_57[9] + getitem_1071 = split_57[10] + getitem_1072 = split_57[11] + getitem_1073 = split_57[12] + getitem_1074 = split_57[13] + getitem_1075 = split_57[14] + getitem_1076 = split_57[15]; split_57 = None + cat_89 = torch.ops.aten.cat.default([getitem_1061, getitem_1062, getitem_1063, getitem_1064, getitem_1065, getitem_1066, getitem_1067, getitem_1068, getitem_1069, getitem_1070, getitem_1071, getitem_1072, getitem_1073, getitem_1074, getitem_1075, getitem_1076], 1); getitem_1061 = getitem_1062 = getitem_1063 = getitem_1064 = getitem_1065 = getitem_1066 = getitem_1067 = getitem_1068 = getitem_1069 = getitem_1070 = getitem_1071 = getitem_1072 = getitem_1073 = getitem_1074 = getitem_1075 = getitem_1076 = None + cumsum_29 = torch.ops.aten.cumsum.default(convert_element_type_554, 0, dtype = torch.int32); convert_element_type_554 = None + permute_155 = torch.ops.aten.permute.default(cat_87, [0, 2, 1]); cat_87 = None + _grouped_mm_27 = torch.ops.aten._grouped_mm.default(index_19, permute_155, cumsum_29) + convert_element_type_562 = torch.ops.prims.convert_element_type.default(_grouped_mm_27, torch.float32) + neg_19 = torch.ops.aten.neg.default(convert_element_type_562) + exp_29 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_644 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + div_49 = torch.ops.aten.div.Tensor(convert_element_type_562, add_644); convert_element_type_562 = add_644 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(div_49, torch.bfloat16); div_49 = None + permute_156 = torch.ops.aten.permute.default(cat_89, [0, 2, 1]); cat_89 = None + _grouped_mm_28 = torch.ops.aten._grouped_mm.default(index_19, permute_156, cumsum_29) + mul_476 = torch.ops.aten.mul.Tensor(convert_element_type_563, _grouped_mm_28); convert_element_type_563 = None + permute_157 = torch.ops.aten.permute.default(cat_88, [0, 2, 1]); cat_88 = None + _grouped_mm_29 = torch.ops.aten._grouped_mm.default(mul_476, permute_157, cumsum_29) + empty_9 = torch.ops.aten.empty.memory_format([sym_size_int_37, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_18 = torch.ops.aten.index_put.default(empty_9, [getitem_1012], _grouped_mm_29); empty_9 = _grouped_mm_29 = None + slice_65 = torch.ops.aten.slice.Tensor(index_put_18, 0, 0, -1); index_put_18 = None + all_to_all_single_29 = torch.ops._c10d_functional.all_to_all_single.default(slice_65, [_local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151], [_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159], '1033'); slice_65 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_29); all_to_all_single_29 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 128, '0'); convert_element_type_564 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_158 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + mm_84 = torch.ops.aten.mm.default(view_661, permute_158); permute_158 = None + convert_element_type_567 = torch.ops.prims.convert_element_type.default(mm_84, torch.float32) + neg_20 = torch.ops.aten.neg.default(convert_element_type_567) + exp_30 = torch.ops.aten.exp.default(neg_20); neg_20 = None + add_680 = torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + div_50 = torch.ops.aten.div.Tensor(convert_element_type_567, add_680); convert_element_type_567 = add_680 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(div_50, torch.bfloat16); div_50 = None + convert_element_type_569 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_569, 128, '0'); convert_element_type_569 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_159 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_85 = torch.ops.aten.mm.default(view_661, permute_159); permute_159 = None + mul_496 = torch.ops.aten.mul.Tensor(convert_element_type_568, mm_85); convert_element_type_568 = None + convert_element_type_572 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_572, 128, '0'); convert_element_type_572 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_160 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_86 = torch.ops.aten.mm.default(mul_496, permute_160); permute_160 = None + index_put_19 = torch.ops.aten.index_put.default(full_default_1, [getitem_1011], wait_tensor_217); wait_tensor_217 = None + view_701 = torch.ops.aten.view.default(mul_458, [-1, 1, 6]); mul_458 = None + view_702 = torch.ops.aten.view.default(index_put_19, [-1, 6, 2048]); index_put_19 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_702, torch.float32); view_702 = None + bmm_9 = torch.ops.aten.bmm.default(view_701, convert_element_type_575) + convert_element_type_576 = torch.ops.prims.convert_element_type.default(bmm_9, torch.bfloat16); bmm_9 = None + squeeze_9 = torch.ops.aten.squeeze.dim(convert_element_type_576, 1); convert_element_type_576 = None + add_684 = torch.ops.aten.add.Tensor(mm_86, squeeze_9); mm_86 = squeeze_9 = None + view_703 = torch.ops.aten.view.default(add_684, [2, 4096, 2048]); add_684 = None + add_685 = torch.ops.aten.add.Tensor(add_620, view_703); view_703 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_577, 128, '0'); convert_element_type_577 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(add_685, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_578, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_686 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_686); add_686 = None + mul_499 = torch.ops.aten.mul.Tensor(convert_element_type_578, rsqrt_33); convert_element_type_578 = None + mul_500 = torch.ops.aten.mul.Tensor(mul_499, wait_tensor_221); mul_499 = wait_tensor_221 = None + convert_element_type_579 = torch.ops.prims.convert_element_type.default(mul_500, torch.bfloat16); mul_500 = None + convert_element_type_580 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_580, 128, '0'); convert_element_type_580 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_222, [1, 0]); wait_tensor_222 = None + view_706 = torch.ops.aten.view.default(convert_element_type_579, [8192, 2048]); convert_element_type_579 = None + mm_87 = torch.ops.aten.mm.default(view_706, permute_161); permute_161 = None + view_707 = torch.ops.aten.view.default(mm_87, [2, 4096, 3072]); mm_87 = None + view_708 = torch.ops.aten.view.default(view_707, [2, 4096, -1, 192]); view_707 = None + split_with_sizes_33 = torch.ops.aten.split_with_sizes.default(view_708, [128, 64], -1); view_708 = None + getitem_1109 = split_with_sizes_33[0] + getitem_1110 = split_with_sizes_33[1]; split_with_sizes_33 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(getitem_1110, torch.float32); getitem_1110 = None + view_709 = torch.ops.aten.view.default(convert_element_type_583, [2, 4096, 16, -1, 2]); convert_element_type_583 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_709); view_709 = None + mul_501 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_7); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_501); mul_501 = None + view_711 = torch.ops.aten.view.default(view_as_real_22, [2, 4096, 16, 64]); view_as_real_22 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(view_711, torch.bfloat16); view_711 = None + cat_92 = torch.ops.aten.cat.default([getitem_1109, convert_element_type_584], -1); getitem_1109 = convert_element_type_584 = None + convert_element_type_585 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_585, 128, '0'); convert_element_type_585 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + slice_67 = torch.ops.aten.slice.Tensor(wait_tensor_223, 0, 0, 576); wait_tensor_223 = None + permute_162 = torch.ops.aten.permute.default(slice_67, [1, 0]); slice_67 = None + mm_88 = torch.ops.aten.mm.default(view_706, permute_162); permute_162 = None + view_714 = torch.ops.aten.view.default(mm_88, [2, 4096, 576]); mm_88 = None + split_with_sizes_34 = torch.ops.aten.split_with_sizes.default(view_714, [512, 64], -1); view_714 = None + getitem_1111 = split_with_sizes_34[0] + getitem_1112 = split_with_sizes_34[1]; split_with_sizes_34 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(getitem_1112, 2); getitem_1112 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(unsqueeze_21, torch.float32); unsqueeze_21 = None + view_715 = torch.ops.aten.view.default(convert_element_type_588, [2, 4096, 1, -1, 2]); convert_element_type_588 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_715); view_715 = None + mul_502 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_7); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_502); mul_502 = None + view_717 = torch.ops.aten.view.default(view_as_real_23, [2, 4096, 1, 64]); view_as_real_23 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(view_717, torch.bfloat16); view_717 = None + convert_element_type_590 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_590, 128, '0'); convert_element_type_590 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + convert_element_type_591 = torch.ops.prims.convert_element_type.default(getitem_1111, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_591, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_687 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_687); add_687 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_591, rsqrt_34); convert_element_type_591 = None + mul_504 = torch.ops.aten.mul.Tensor(mul_503, wait_tensor_224); mul_503 = wait_tensor_224 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(mul_504, torch.bfloat16); mul_504 = None + convert_element_type_593 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_593, 128, '0'); convert_element_type_593 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_720 = torch.ops.aten.view.default(convert_element_type_592, [8192, 512]); convert_element_type_592 = None + mm_89 = torch.ops.aten.mm.default(view_720, permute_163); permute_163 = None + view_721 = torch.ops.aten.view.default(mm_89, [2, 4096, 4096]); mm_89 = None + view_722 = torch.ops.aten.view.default(view_721, [2, 4096, -1, 256]); view_721 = None + split_with_sizes_35 = torch.ops.aten.split_with_sizes.default(view_722, [128, 128], -1); view_722 = None + getitem_1113 = split_with_sizes_35[0] + getitem_1114 = split_with_sizes_35[1]; split_with_sizes_35 = None + expand_11 = torch.ops.aten.expand.default(convert_element_type_589, [-1, -1, 16, -1]); convert_element_type_589 = None + cat_93 = torch.ops.aten.cat.default([getitem_1113, expand_11], -1); getitem_1113 = expand_11 = None + permute_164 = torch.ops.aten.permute.default(cat_92, [0, 2, 1, 3]); cat_92 = None + permute_165 = torch.ops.aten.permute.default(cat_93, [0, 2, 1, 3]); cat_93 = None + permute_166 = torch.ops.aten.permute.default(getitem_1114, [0, 2, 1, 3]); getitem_1114 = None + sdpa_score11 = self.sdpa_score11 + sdpa_mask11 = self.sdpa_mask11 + flex_attention_11 = torch.ops.higher_order.flex_attention(permute_164, permute_165, permute_166, sdpa_score11, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask11), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score11 = sdpa_mask11 = None + getitem_1115 = flex_attention_11[0] + getitem_1116 = flex_attention_11[1]; flex_attention_11 = None + permute_167 = torch.ops.aten.permute.default(getitem_1115, [0, 2, 1, 3]) + view_723 = torch.ops.aten.view.default(permute_167, [2, 4096, -1]); permute_167 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_596, 128, '0'); convert_element_type_596 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_168 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + view_725 = torch.ops.aten.view.default(view_723, [8192, 2048]); view_723 = None + mm_90 = torch.ops.aten.mm.default(view_725, permute_168); view_725 = permute_168 = None + view_726 = torch.ops.aten.view.default(mm_90, [2, 4096, 2048]); mm_90 = None + add_688 = torch.ops.aten.add.Tensor(add_685, view_726); view_726 = None + convert_element_type_599 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_599, 128, '0'); convert_element_type_599 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + convert_element_type_600 = torch.ops.prims.convert_element_type.default(add_688, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_600, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_689 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_689); add_689 = None + mul_505 = torch.ops.aten.mul.Tensor(convert_element_type_600, rsqrt_35); convert_element_type_600 = None + mul_506 = torch.ops.aten.mul.Tensor(mul_505, wait_tensor_227); mul_505 = wait_tensor_227 = None + convert_element_type_601 = torch.ops.prims.convert_element_type.default(mul_506, torch.bfloat16); mul_506 = None + view_728 = torch.ops.aten.view.default(convert_element_type_601, [-1, 2048]); convert_element_type_601 = None + convert_element_type_602 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_602, 128, '0'); convert_element_type_602 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + slice_69 = torch.ops.aten.slice.Tensor(wait_tensor_228, 0, 0, 64); wait_tensor_228 = None + permute_169 = torch.ops.aten.permute.default(slice_69, [1, 0]); slice_69 = None + mm_91 = torch.ops.aten.mm.default(view_728, permute_169); permute_169 = None + convert_element_type_605 = torch.ops.prims.convert_element_type.default(mm_91, torch.float32) + amax_10 = torch.ops.aten.amax.default(convert_element_type_605, [1], True) + sub_240 = torch.ops.aten.sub.Tensor(convert_element_type_605, amax_10); convert_element_type_605 = None + exp_31 = torch.ops.aten.exp.default(sub_240); sub_240 = None + sum_41 = torch.ops.aten.sum.dim_IntList(exp_31, [1], True) + div_51 = torch.ops.aten.div.Tensor(exp_31, sum_41); exp_31 = None + add_690 = torch.ops.aten.add.Tensor(div_51, primals_190); primals_190 = None + topk_10 = torch.ops.aten.topk.default(add_690, 6, -1, True, False); add_690 = None + getitem_1119 = topk_10[1]; topk_10 = None + gather_10 = torch.ops.aten.gather.default(div_51, 1, getitem_1119); div_51 = None + mul_507 = torch.ops.aten.mul.Tensor(gather_10, 1.0); gather_10 = None + view_730 = torch.ops.aten.view.default(getitem_1119, [-1]) + histc_20 = torch.ops.aten.histc.default(view_730, 64, 0, 64) + add_691 = torch.ops.aten.add.Tensor(primals_192, histc_20) + sort_10 = torch.ops.aten.sort.stable(view_730, stable = True); view_730 = None + getitem_1121 = sort_10[1]; sort_10 = None + div_52 = torch.ops.aten.div.Tensor_mode(getitem_1121, 6, rounding_mode = 'floor') + index_20 = torch.ops.aten.index.Tensor(view_728, [div_52]) + all_to_all_single_30 = torch.ops._c10d_functional.all_to_all_single.default(histc_20, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_30); all_to_all_single_30 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_229); wait_tensor_229 = None + view_734 = torch.ops.aten.view.default(histc_20, [8, -1]); histc_20 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_734, [1]); view_734 = None + device_put_20 = torch.ops.prims.device_put.default(sum_42, device(type='cpu'), True); sum_42 = None + view_735 = torch.ops.aten.view.default(wait_tensor_230, [8, -1]) + sum_43 = torch.ops.aten.sum.dim_IntList(view_735, [1]) + device_put_21 = torch.ops.prims.device_put.default(sum_43, device(type='cpu')); sum_43 = None + select_160 = torch.ops.aten.select.int(device_put_20, 0, 0) + _local_scalar_dense_160 = torch.ops.aten._local_scalar_dense.default(select_160); select_160 = None + ge_200 = _local_scalar_dense_160 >= 0 + _assert_scalar_160 = torch.ops.aten._assert_scalar.default(ge_200, "Runtime assertion failed for expression u160 >= 0 on node 'ge_160'"); ge_200 = _assert_scalar_160 = None + select_161 = torch.ops.aten.select.int(device_put_20, 0, 1) + _local_scalar_dense_161 = torch.ops.aten._local_scalar_dense.default(select_161); select_161 = None + ge_201 = _local_scalar_dense_161 >= 0 + _assert_scalar_161 = torch.ops.aten._assert_scalar.default(ge_201, "Runtime assertion failed for expression u161 >= 0 on node 'ge_161'"); ge_201 = _assert_scalar_161 = None + select_162 = torch.ops.aten.select.int(device_put_20, 0, 2) + _local_scalar_dense_162 = torch.ops.aten._local_scalar_dense.default(select_162); select_162 = None + ge_202 = _local_scalar_dense_162 >= 0 + _assert_scalar_162 = torch.ops.aten._assert_scalar.default(ge_202, "Runtime assertion failed for expression u162 >= 0 on node 'ge_162'"); ge_202 = _assert_scalar_162 = None + select_163 = torch.ops.aten.select.int(device_put_20, 0, 3) + _local_scalar_dense_163 = torch.ops.aten._local_scalar_dense.default(select_163); select_163 = None + ge_203 = _local_scalar_dense_163 >= 0 + _assert_scalar_163 = torch.ops.aten._assert_scalar.default(ge_203, "Runtime assertion failed for expression u163 >= 0 on node 'ge_163'"); ge_203 = _assert_scalar_163 = None + select_164 = torch.ops.aten.select.int(device_put_20, 0, 4) + _local_scalar_dense_164 = torch.ops.aten._local_scalar_dense.default(select_164); select_164 = None + ge_204 = _local_scalar_dense_164 >= 0 + _assert_scalar_164 = torch.ops.aten._assert_scalar.default(ge_204, "Runtime assertion failed for expression u164 >= 0 on node 'ge_164'"); ge_204 = _assert_scalar_164 = None + select_165 = torch.ops.aten.select.int(device_put_20, 0, 5) + _local_scalar_dense_165 = torch.ops.aten._local_scalar_dense.default(select_165); select_165 = None + ge_205 = _local_scalar_dense_165 >= 0 + _assert_scalar_165 = torch.ops.aten._assert_scalar.default(ge_205, "Runtime assertion failed for expression u165 >= 0 on node 'ge_165'"); ge_205 = _assert_scalar_165 = None + select_166 = torch.ops.aten.select.int(device_put_20, 0, 6) + _local_scalar_dense_166 = torch.ops.aten._local_scalar_dense.default(select_166); select_166 = None + ge_206 = _local_scalar_dense_166 >= 0 + _assert_scalar_166 = torch.ops.aten._assert_scalar.default(ge_206, "Runtime assertion failed for expression u166 >= 0 on node 'ge_166'"); ge_206 = _assert_scalar_166 = None + select_167 = torch.ops.aten.select.int(device_put_20, 0, 7); device_put_20 = None + _local_scalar_dense_167 = torch.ops.aten._local_scalar_dense.default(select_167); select_167 = None + ge_207 = _local_scalar_dense_167 >= 0 + _assert_scalar_167 = torch.ops.aten._assert_scalar.default(ge_207, "Runtime assertion failed for expression u167 >= 0 on node 'ge_167'"); ge_207 = _assert_scalar_167 = None + select_168 = torch.ops.aten.select.int(device_put_21, 0, 0) + _local_scalar_dense_168 = torch.ops.aten._local_scalar_dense.default(select_168); select_168 = None + ge_208 = _local_scalar_dense_168 >= 0 + _assert_scalar_168 = torch.ops.aten._assert_scalar.default(ge_208, "Runtime assertion failed for expression u168 >= 0 on node 'ge_168'"); ge_208 = _assert_scalar_168 = None + select_169 = torch.ops.aten.select.int(device_put_21, 0, 1) + _local_scalar_dense_169 = torch.ops.aten._local_scalar_dense.default(select_169); select_169 = None + ge_209 = _local_scalar_dense_169 >= 0 + _assert_scalar_169 = torch.ops.aten._assert_scalar.default(ge_209, "Runtime assertion failed for expression u169 >= 0 on node 'ge_169'"); ge_209 = _assert_scalar_169 = None + select_170 = torch.ops.aten.select.int(device_put_21, 0, 2) + _local_scalar_dense_170 = torch.ops.aten._local_scalar_dense.default(select_170); select_170 = None + ge_210 = _local_scalar_dense_170 >= 0 + _assert_scalar_170 = torch.ops.aten._assert_scalar.default(ge_210, "Runtime assertion failed for expression u170 >= 0 on node 'ge_170'"); ge_210 = _assert_scalar_170 = None + select_171 = torch.ops.aten.select.int(device_put_21, 0, 3) + _local_scalar_dense_171 = torch.ops.aten._local_scalar_dense.default(select_171); select_171 = None + ge_211 = _local_scalar_dense_171 >= 0 + _assert_scalar_171 = torch.ops.aten._assert_scalar.default(ge_211, "Runtime assertion failed for expression u171 >= 0 on node 'ge_171'"); ge_211 = _assert_scalar_171 = None + select_172 = torch.ops.aten.select.int(device_put_21, 0, 4) + _local_scalar_dense_172 = torch.ops.aten._local_scalar_dense.default(select_172); select_172 = None + ge_212 = _local_scalar_dense_172 >= 0 + _assert_scalar_172 = torch.ops.aten._assert_scalar.default(ge_212, "Runtime assertion failed for expression u172 >= 0 on node 'ge_172'"); ge_212 = _assert_scalar_172 = None + select_173 = torch.ops.aten.select.int(device_put_21, 0, 5) + _local_scalar_dense_173 = torch.ops.aten._local_scalar_dense.default(select_173); select_173 = None + ge_213 = _local_scalar_dense_173 >= 0 + _assert_scalar_173 = torch.ops.aten._assert_scalar.default(ge_213, "Runtime assertion failed for expression u173 >= 0 on node 'ge_173'"); ge_213 = _assert_scalar_173 = None + select_174 = torch.ops.aten.select.int(device_put_21, 0, 6) + _local_scalar_dense_174 = torch.ops.aten._local_scalar_dense.default(select_174); select_174 = None + ge_214 = _local_scalar_dense_174 >= 0 + _assert_scalar_174 = torch.ops.aten._assert_scalar.default(ge_214, "Runtime assertion failed for expression u174 >= 0 on node 'ge_174'"); ge_214 = _assert_scalar_174 = None + select_175 = torch.ops.aten.select.int(device_put_21, 0, 7); device_put_21 = None + _local_scalar_dense_175 = torch.ops.aten._local_scalar_dense.default(select_175); select_175 = None + ge_215 = _local_scalar_dense_175 >= 0 + _assert_scalar_175 = torch.ops.aten._assert_scalar.default(ge_215, "Runtime assertion failed for expression u175 >= 0 on node 'ge_175'"); ge_215 = _assert_scalar_175 = None + all_to_all_single_31 = torch.ops._c10d_functional.all_to_all_single.default(index_20, [_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175], [_local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167], '1033'); index_20 = None + sym_size_int_40 = torch.ops.aten.sym_size.int(all_to_all_single_31, 0) + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_31); all_to_all_single_31 = None + sym_sum_20 = torch.sym_sum((_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175)) + add_698 = sym_sum_20 + 64; sym_sum_20 = None + add_699 = add_698 + 8; add_698 = None + sub_243 = add_699 - 1; add_699 = None + floordiv_10 = sub_243 // 8; sub_243 = None + mul_512 = floordiv_10 * 8; floordiv_10 = None + cumsum_30 = torch.ops.aten.cumsum.default(wait_tensor_230, 0) + sub_244 = torch.ops.aten.sub.Tensor(cumsum_30, wait_tensor_230); cumsum_30 = None + sum_44 = torch.ops.aten.sum.dim_IntList(view_735, [0]); view_735 = None + clamp_min_10 = torch.ops.aten.clamp_min.default(sum_44, 8); sum_44 = None + add_700 = torch.ops.aten.add.Tensor(clamp_min_10, 8); clamp_min_10 = None + sub_245 = torch.ops.aten.sub.Tensor(add_700, 1); add_700 = None + div_53 = torch.ops.aten.div.Tensor_mode(sub_245, 8, rounding_mode = 'floor'); sub_245 = None + mul_513 = torch.ops.aten.mul.Tensor(div_53, 8); div_53 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(mul_513, torch.int32); mul_513 = None + cumsum_31 = torch.ops.aten.cumsum.default(convert_element_type_608, 0) + sub_246 = torch.ops.aten.sub.Tensor(cumsum_31, convert_element_type_608); cumsum_31 = None + full_150 = torch.ops.aten.full.default([mul_512], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_512 = None + triton_kernel_wrapper_functional_proxy_10 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 10, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_230, 'start_index_values_ptr': sub_244, 'write_offsets_ptr': sub_246, 'output_ptr': full_150}, tensors_to_clone = ['output_ptr']); wait_tensor_230 = sub_244 = sub_246 = full_150 = None + getitem_1122 = triton_kernel_wrapper_functional_proxy_10['output_ptr']; triton_kernel_wrapper_functional_proxy_10 = None + cat_94 = torch.ops.aten.cat.default([wait_tensor_231, full_default]); wait_tensor_231 = None + sym_size_int_41 = torch.ops.aten.sym_size.int(cat_94, 0) + sym_sum_21 = torch.sym_sum((1, _local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175)) + index_21 = torch.ops.aten.index.Tensor(cat_94, [getitem_1122]); cat_94 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_610, 16, '1025'); convert_element_type_610 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + split_61 = torch.ops.aten.split.Tensor(wait_tensor_232, 8); wait_tensor_232 = None + getitem_1139 = split_61[0] + getitem_1140 = split_61[1] + getitem_1141 = split_61[2] + getitem_1142 = split_61[3] + getitem_1143 = split_61[4] + getitem_1144 = split_61[5] + getitem_1145 = split_61[6] + getitem_1146 = split_61[7] + getitem_1147 = split_61[8] + getitem_1148 = split_61[9] + getitem_1149 = split_61[10] + getitem_1150 = split_61[11] + getitem_1151 = split_61[12] + getitem_1152 = split_61[13] + getitem_1153 = split_61[14] + getitem_1154 = split_61[15]; split_61 = None + cat_96 = torch.ops.aten.cat.default([getitem_1139, getitem_1140, getitem_1141, getitem_1142, getitem_1143, getitem_1144, getitem_1145, getitem_1146, getitem_1147, getitem_1148, getitem_1149, getitem_1150, getitem_1151, getitem_1152, getitem_1153, getitem_1154], 1); getitem_1139 = getitem_1140 = getitem_1141 = getitem_1142 = getitem_1143 = getitem_1144 = getitem_1145 = getitem_1146 = getitem_1147 = getitem_1148 = getitem_1149 = getitem_1150 = getitem_1151 = getitem_1152 = getitem_1153 = getitem_1154 = None + convert_element_type_612 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_612, 16, '1025'); convert_element_type_612 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + split_62 = torch.ops.aten.split.Tensor(wait_tensor_234, 8); wait_tensor_234 = None + getitem_1155 = split_62[0] + getitem_1156 = split_62[1] + getitem_1157 = split_62[2] + getitem_1158 = split_62[3] + getitem_1159 = split_62[4] + getitem_1160 = split_62[5] + getitem_1161 = split_62[6] + getitem_1162 = split_62[7] + getitem_1163 = split_62[8] + getitem_1164 = split_62[9] + getitem_1165 = split_62[10] + getitem_1166 = split_62[11] + getitem_1167 = split_62[12] + getitem_1168 = split_62[13] + getitem_1169 = split_62[14] + getitem_1170 = split_62[15]; split_62 = None + cat_97 = torch.ops.aten.cat.default([getitem_1155, getitem_1156, getitem_1157, getitem_1158, getitem_1159, getitem_1160, getitem_1161, getitem_1162, getitem_1163, getitem_1164, getitem_1165, getitem_1166, getitem_1167, getitem_1168, getitem_1169, getitem_1170], 1); getitem_1155 = getitem_1156 = getitem_1157 = getitem_1158 = getitem_1159 = getitem_1160 = getitem_1161 = getitem_1162 = getitem_1163 = getitem_1164 = getitem_1165 = getitem_1166 = getitem_1167 = getitem_1168 = getitem_1169 = getitem_1170 = None + convert_element_type_613 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_613, 16, '1025'); convert_element_type_613 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + split_63 = torch.ops.aten.split.Tensor(wait_tensor_235, 8); wait_tensor_235 = None + getitem_1171 = split_63[0] + getitem_1172 = split_63[1] + getitem_1173 = split_63[2] + getitem_1174 = split_63[3] + getitem_1175 = split_63[4] + getitem_1176 = split_63[5] + getitem_1177 = split_63[6] + getitem_1178 = split_63[7] + getitem_1179 = split_63[8] + getitem_1180 = split_63[9] + getitem_1181 = split_63[10] + getitem_1182 = split_63[11] + getitem_1183 = split_63[12] + getitem_1184 = split_63[13] + getitem_1185 = split_63[14] + getitem_1186 = split_63[15]; split_63 = None + cat_98 = torch.ops.aten.cat.default([getitem_1171, getitem_1172, getitem_1173, getitem_1174, getitem_1175, getitem_1176, getitem_1177, getitem_1178, getitem_1179, getitem_1180, getitem_1181, getitem_1182, getitem_1183, getitem_1184, getitem_1185, getitem_1186], 1); getitem_1171 = getitem_1172 = getitem_1173 = getitem_1174 = getitem_1175 = getitem_1176 = getitem_1177 = getitem_1178 = getitem_1179 = getitem_1180 = getitem_1181 = getitem_1182 = getitem_1183 = getitem_1184 = getitem_1185 = getitem_1186 = None + cumsum_32 = torch.ops.aten.cumsum.default(convert_element_type_608, 0, dtype = torch.int32); convert_element_type_608 = None + permute_170 = torch.ops.aten.permute.default(cat_96, [0, 2, 1]); cat_96 = None + _grouped_mm_30 = torch.ops.aten._grouped_mm.default(index_21, permute_170, cumsum_32) + convert_element_type_616 = torch.ops.prims.convert_element_type.default(_grouped_mm_30, torch.float32) + neg_21 = torch.ops.aten.neg.default(convert_element_type_616) + exp_32 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_712 = torch.ops.aten.add.Tensor(exp_32, 1); exp_32 = None + div_54 = torch.ops.aten.div.Tensor(convert_element_type_616, add_712); convert_element_type_616 = add_712 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(div_54, torch.bfloat16); div_54 = None + permute_171 = torch.ops.aten.permute.default(cat_98, [0, 2, 1]); cat_98 = None + _grouped_mm_31 = torch.ops.aten._grouped_mm.default(index_21, permute_171, cumsum_32) + mul_525 = torch.ops.aten.mul.Tensor(convert_element_type_617, _grouped_mm_31); convert_element_type_617 = None + permute_172 = torch.ops.aten.permute.default(cat_97, [0, 2, 1]); cat_97 = None + _grouped_mm_32 = torch.ops.aten._grouped_mm.default(mul_525, permute_172, cumsum_32) + empty_10 = torch.ops.aten.empty.memory_format([sym_size_int_41, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_20 = torch.ops.aten.index_put.default(empty_10, [getitem_1122], _grouped_mm_32); empty_10 = _grouped_mm_32 = None + slice_71 = torch.ops.aten.slice.Tensor(index_put_20, 0, 0, -1); index_put_20 = None + all_to_all_single_32 = torch.ops._c10d_functional.all_to_all_single.default(slice_71, [_local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167], [_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175], '1033'); slice_71 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_32); all_to_all_single_32 = None + convert_element_type_618 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_618, 128, '0'); convert_element_type_618 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + mm_92 = torch.ops.aten.mm.default(view_728, permute_173); permute_173 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mm_92, torch.float32) + neg_22 = torch.ops.aten.neg.default(convert_element_type_621) + exp_33 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_748 = torch.ops.aten.add.Tensor(exp_33, 1); exp_33 = None + div_55 = torch.ops.aten.div.Tensor(convert_element_type_621, add_748); convert_element_type_621 = add_748 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(div_55, torch.bfloat16); div_55 = None + convert_element_type_623 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_623, 128, '0'); convert_element_type_623 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + mm_93 = torch.ops.aten.mm.default(view_728, permute_174); permute_174 = None + mul_545 = torch.ops.aten.mul.Tensor(convert_element_type_622, mm_93); convert_element_type_622 = None + convert_element_type_626 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_626, 128, '0'); convert_element_type_626 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + mm_94 = torch.ops.aten.mm.default(mul_545, permute_175); permute_175 = None + index_put_21 = torch.ops.aten.index_put.default(full_default_1, [getitem_1121], wait_tensor_238); wait_tensor_238 = None + view_768 = torch.ops.aten.view.default(mul_507, [-1, 1, 6]); mul_507 = None + view_769 = torch.ops.aten.view.default(index_put_21, [-1, 6, 2048]); index_put_21 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(view_769, torch.float32); view_769 = None + bmm_10 = torch.ops.aten.bmm.default(view_768, convert_element_type_629) + convert_element_type_630 = torch.ops.prims.convert_element_type.default(bmm_10, torch.bfloat16); bmm_10 = None + squeeze_10 = torch.ops.aten.squeeze.dim(convert_element_type_630, 1); convert_element_type_630 = None + add_752 = torch.ops.aten.add.Tensor(mm_94, squeeze_10); mm_94 = squeeze_10 = None + view_770 = torch.ops.aten.view.default(add_752, [2, 4096, 2048]); add_752 = None + add_753 = torch.ops.aten.add.Tensor(add_688, view_770); view_770 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 128, '0'); convert_element_type_631 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + convert_element_type_632 = torch.ops.prims.convert_element_type.default(add_753, torch.float32) + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_632, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_754 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_754); add_754 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_632, rsqrt_36); convert_element_type_632 = None + mul_549 = torch.ops.aten.mul.Tensor(mul_548, wait_tensor_242); mul_548 = wait_tensor_242 = None + convert_element_type_633 = torch.ops.prims.convert_element_type.default(mul_549, torch.bfloat16); mul_549 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 128, '0'); convert_element_type_634 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + view_773 = torch.ops.aten.view.default(convert_element_type_633, [8192, 2048]); convert_element_type_633 = None + mm_95 = torch.ops.aten.mm.default(view_773, permute_176); permute_176 = None + view_774 = torch.ops.aten.view.default(mm_95, [2, 4096, 3072]); mm_95 = None + view_775 = torch.ops.aten.view.default(view_774, [2, 4096, -1, 192]); view_774 = None + split_with_sizes_36 = torch.ops.aten.split_with_sizes.default(view_775, [128, 64], -1); view_775 = None + getitem_1219 = split_with_sizes_36[0] + getitem_1220 = split_with_sizes_36[1]; split_with_sizes_36 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(getitem_1220, torch.float32); getitem_1220 = None + view_776 = torch.ops.aten.view.default(convert_element_type_637, [2, 4096, 16, -1, 2]); convert_element_type_637 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_776); view_776 = None + mul_550 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_7); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_550); mul_550 = None + view_778 = torch.ops.aten.view.default(view_as_real_24, [2, 4096, 16, 64]); view_as_real_24 = None + convert_element_type_638 = torch.ops.prims.convert_element_type.default(view_778, torch.bfloat16); view_778 = None + cat_101 = torch.ops.aten.cat.default([getitem_1219, convert_element_type_638], -1); getitem_1219 = convert_element_type_638 = None + convert_element_type_639 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_639, 128, '0'); convert_element_type_639 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + slice_73 = torch.ops.aten.slice.Tensor(wait_tensor_244, 0, 0, 576); wait_tensor_244 = None + permute_177 = torch.ops.aten.permute.default(slice_73, [1, 0]); slice_73 = None + mm_96 = torch.ops.aten.mm.default(view_773, permute_177); permute_177 = None + view_781 = torch.ops.aten.view.default(mm_96, [2, 4096, 576]); mm_96 = None + split_with_sizes_37 = torch.ops.aten.split_with_sizes.default(view_781, [512, 64], -1); view_781 = None + getitem_1221 = split_with_sizes_37[0] + getitem_1222 = split_with_sizes_37[1]; split_with_sizes_37 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(getitem_1222, 2); getitem_1222 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(unsqueeze_23, torch.float32); unsqueeze_23 = None + view_782 = torch.ops.aten.view.default(convert_element_type_642, [2, 4096, 1, -1, 2]); convert_element_type_642 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_782); view_782 = None + mul_551 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_7); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_551); mul_551 = None + view_784 = torch.ops.aten.view.default(view_as_real_25, [2, 4096, 1, 64]); view_as_real_25 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_784, torch.bfloat16); view_784 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 128, '0'); convert_element_type_644 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + convert_element_type_645 = torch.ops.prims.convert_element_type.default(getitem_1221, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_645, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_755 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_755); add_755 = None + mul_552 = torch.ops.aten.mul.Tensor(convert_element_type_645, rsqrt_37); convert_element_type_645 = None + mul_553 = torch.ops.aten.mul.Tensor(mul_552, wait_tensor_245); mul_552 = wait_tensor_245 = None + convert_element_type_646 = torch.ops.prims.convert_element_type.default(mul_553, torch.bfloat16); mul_553 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 128, '0'); convert_element_type_647 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + view_787 = torch.ops.aten.view.default(convert_element_type_646, [8192, 512]); convert_element_type_646 = None + mm_97 = torch.ops.aten.mm.default(view_787, permute_178); permute_178 = None + view_788 = torch.ops.aten.view.default(mm_97, [2, 4096, 4096]); mm_97 = None + view_789 = torch.ops.aten.view.default(view_788, [2, 4096, -1, 256]); view_788 = None + split_with_sizes_38 = torch.ops.aten.split_with_sizes.default(view_789, [128, 128], -1); view_789 = None + getitem_1223 = split_with_sizes_38[0] + getitem_1224 = split_with_sizes_38[1]; split_with_sizes_38 = None + expand_12 = torch.ops.aten.expand.default(convert_element_type_643, [-1, -1, 16, -1]); convert_element_type_643 = None + cat_102 = torch.ops.aten.cat.default([getitem_1223, expand_12], -1); getitem_1223 = expand_12 = None + permute_179 = torch.ops.aten.permute.default(cat_101, [0, 2, 1, 3]); cat_101 = None + permute_180 = torch.ops.aten.permute.default(cat_102, [0, 2, 1, 3]); cat_102 = None + permute_181 = torch.ops.aten.permute.default(getitem_1224, [0, 2, 1, 3]); getitem_1224 = None + sdpa_score12 = self.sdpa_score12 + sdpa_mask12 = self.sdpa_mask12 + flex_attention_12 = torch.ops.higher_order.flex_attention(permute_179, permute_180, permute_181, sdpa_score12, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask12), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score12 = sdpa_mask12 = None + getitem_1225 = flex_attention_12[0] + getitem_1226 = flex_attention_12[1]; flex_attention_12 = None + permute_182 = torch.ops.aten.permute.default(getitem_1225, [0, 2, 1, 3]) + view_790 = torch.ops.aten.view.default(permute_182, [2, 4096, -1]); permute_182 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 128, '0'); convert_element_type_650 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + view_792 = torch.ops.aten.view.default(view_790, [8192, 2048]); view_790 = None + mm_98 = torch.ops.aten.mm.default(view_792, permute_183); view_792 = permute_183 = None + view_793 = torch.ops.aten.view.default(mm_98, [2, 4096, 2048]); mm_98 = None + add_756 = torch.ops.aten.add.Tensor(add_753, view_793); view_793 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_653, 128, '0'); convert_element_type_653 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(add_756, torch.float32) + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_654, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_757 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_757); add_757 = None + mul_554 = torch.ops.aten.mul.Tensor(convert_element_type_654, rsqrt_38); convert_element_type_654 = None + mul_555 = torch.ops.aten.mul.Tensor(mul_554, wait_tensor_248); mul_554 = wait_tensor_248 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(mul_555, torch.bfloat16); mul_555 = None + view_795 = torch.ops.aten.view.default(convert_element_type_655, [-1, 2048]); convert_element_type_655 = None + convert_element_type_656 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_656, 128, '0'); convert_element_type_656 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + slice_75 = torch.ops.aten.slice.Tensor(wait_tensor_249, 0, 0, 64); wait_tensor_249 = None + permute_184 = torch.ops.aten.permute.default(slice_75, [1, 0]); slice_75 = None + mm_99 = torch.ops.aten.mm.default(view_795, permute_184); permute_184 = None + convert_element_type_659 = torch.ops.prims.convert_element_type.default(mm_99, torch.float32) + amax_11 = torch.ops.aten.amax.default(convert_element_type_659, [1], True) + sub_264 = torch.ops.aten.sub.Tensor(convert_element_type_659, amax_11); convert_element_type_659 = None + exp_34 = torch.ops.aten.exp.default(sub_264); sub_264 = None + sum_45 = torch.ops.aten.sum.dim_IntList(exp_34, [1], True) + div_56 = torch.ops.aten.div.Tensor(exp_34, sum_45); exp_34 = None + add_758 = torch.ops.aten.add.Tensor(div_56, primals_206); primals_206 = None + topk_11 = torch.ops.aten.topk.default(add_758, 6, -1, True, False); add_758 = None + getitem_1229 = topk_11[1]; topk_11 = None + gather_11 = torch.ops.aten.gather.default(div_56, 1, getitem_1229); div_56 = None + mul_556 = torch.ops.aten.mul.Tensor(gather_11, 1.0); gather_11 = None + view_797 = torch.ops.aten.view.default(getitem_1229, [-1]) + histc_22 = torch.ops.aten.histc.default(view_797, 64, 0, 64) + add_759 = torch.ops.aten.add.Tensor(primals_208, histc_22) + sort_11 = torch.ops.aten.sort.stable(view_797, stable = True); view_797 = None + getitem_1231 = sort_11[1]; sort_11 = None + div_57 = torch.ops.aten.div.Tensor_mode(getitem_1231, 6, rounding_mode = 'floor') + index_22 = torch.ops.aten.index.Tensor(view_795, [div_57]) + all_to_all_single_33 = torch.ops._c10d_functional.all_to_all_single.default(histc_22, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_33); all_to_all_single_33 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_250); wait_tensor_250 = None + view_801 = torch.ops.aten.view.default(histc_22, [8, -1]); histc_22 = None + sum_46 = torch.ops.aten.sum.dim_IntList(view_801, [1]); view_801 = None + device_put_22 = torch.ops.prims.device_put.default(sum_46, device(type='cpu'), True); sum_46 = None + view_802 = torch.ops.aten.view.default(wait_tensor_251, [8, -1]) + sum_47 = torch.ops.aten.sum.dim_IntList(view_802, [1]) + device_put_23 = torch.ops.prims.device_put.default(sum_47, device(type='cpu')); sum_47 = None + select_176 = torch.ops.aten.select.int(device_put_22, 0, 0) + _local_scalar_dense_176 = torch.ops.aten._local_scalar_dense.default(select_176); select_176 = None + ge_220 = _local_scalar_dense_176 >= 0 + _assert_scalar_176 = torch.ops.aten._assert_scalar.default(ge_220, "Runtime assertion failed for expression u176 >= 0 on node 'ge_176'"); ge_220 = _assert_scalar_176 = None + select_177 = torch.ops.aten.select.int(device_put_22, 0, 1) + _local_scalar_dense_177 = torch.ops.aten._local_scalar_dense.default(select_177); select_177 = None + ge_221 = _local_scalar_dense_177 >= 0 + _assert_scalar_177 = torch.ops.aten._assert_scalar.default(ge_221, "Runtime assertion failed for expression u177 >= 0 on node 'ge_177'"); ge_221 = _assert_scalar_177 = None + select_178 = torch.ops.aten.select.int(device_put_22, 0, 2) + _local_scalar_dense_178 = torch.ops.aten._local_scalar_dense.default(select_178); select_178 = None + ge_222 = _local_scalar_dense_178 >= 0 + _assert_scalar_178 = torch.ops.aten._assert_scalar.default(ge_222, "Runtime assertion failed for expression u178 >= 0 on node 'ge_178'"); ge_222 = _assert_scalar_178 = None + select_179 = torch.ops.aten.select.int(device_put_22, 0, 3) + _local_scalar_dense_179 = torch.ops.aten._local_scalar_dense.default(select_179); select_179 = None + ge_223 = _local_scalar_dense_179 >= 0 + _assert_scalar_179 = torch.ops.aten._assert_scalar.default(ge_223, "Runtime assertion failed for expression u179 >= 0 on node 'ge_179'"); ge_223 = _assert_scalar_179 = None + select_180 = torch.ops.aten.select.int(device_put_22, 0, 4) + _local_scalar_dense_180 = torch.ops.aten._local_scalar_dense.default(select_180); select_180 = None + ge_224 = _local_scalar_dense_180 >= 0 + _assert_scalar_180 = torch.ops.aten._assert_scalar.default(ge_224, "Runtime assertion failed for expression u180 >= 0 on node 'ge_180'"); ge_224 = _assert_scalar_180 = None + select_181 = torch.ops.aten.select.int(device_put_22, 0, 5) + _local_scalar_dense_181 = torch.ops.aten._local_scalar_dense.default(select_181); select_181 = None + ge_225 = _local_scalar_dense_181 >= 0 + _assert_scalar_181 = torch.ops.aten._assert_scalar.default(ge_225, "Runtime assertion failed for expression u181 >= 0 on node 'ge_181'"); ge_225 = _assert_scalar_181 = None + select_182 = torch.ops.aten.select.int(device_put_22, 0, 6) + _local_scalar_dense_182 = torch.ops.aten._local_scalar_dense.default(select_182); select_182 = None + ge_226 = _local_scalar_dense_182 >= 0 + _assert_scalar_182 = torch.ops.aten._assert_scalar.default(ge_226, "Runtime assertion failed for expression u182 >= 0 on node 'ge_182'"); ge_226 = _assert_scalar_182 = None + select_183 = torch.ops.aten.select.int(device_put_22, 0, 7); device_put_22 = None + _local_scalar_dense_183 = torch.ops.aten._local_scalar_dense.default(select_183); select_183 = None + ge_227 = _local_scalar_dense_183 >= 0 + _assert_scalar_183 = torch.ops.aten._assert_scalar.default(ge_227, "Runtime assertion failed for expression u183 >= 0 on node 'ge_183'"); ge_227 = _assert_scalar_183 = None + select_184 = torch.ops.aten.select.int(device_put_23, 0, 0) + _local_scalar_dense_184 = torch.ops.aten._local_scalar_dense.default(select_184); select_184 = None + ge_228 = _local_scalar_dense_184 >= 0 + _assert_scalar_184 = torch.ops.aten._assert_scalar.default(ge_228, "Runtime assertion failed for expression u184 >= 0 on node 'ge_184'"); ge_228 = _assert_scalar_184 = None + select_185 = torch.ops.aten.select.int(device_put_23, 0, 1) + _local_scalar_dense_185 = torch.ops.aten._local_scalar_dense.default(select_185); select_185 = None + ge_229 = _local_scalar_dense_185 >= 0 + _assert_scalar_185 = torch.ops.aten._assert_scalar.default(ge_229, "Runtime assertion failed for expression u185 >= 0 on node 'ge_185'"); ge_229 = _assert_scalar_185 = None + select_186 = torch.ops.aten.select.int(device_put_23, 0, 2) + _local_scalar_dense_186 = torch.ops.aten._local_scalar_dense.default(select_186); select_186 = None + ge_230 = _local_scalar_dense_186 >= 0 + _assert_scalar_186 = torch.ops.aten._assert_scalar.default(ge_230, "Runtime assertion failed for expression u186 >= 0 on node 'ge_186'"); ge_230 = _assert_scalar_186 = None + select_187 = torch.ops.aten.select.int(device_put_23, 0, 3) + _local_scalar_dense_187 = torch.ops.aten._local_scalar_dense.default(select_187); select_187 = None + ge_231 = _local_scalar_dense_187 >= 0 + _assert_scalar_187 = torch.ops.aten._assert_scalar.default(ge_231, "Runtime assertion failed for expression u187 >= 0 on node 'ge_187'"); ge_231 = _assert_scalar_187 = None + select_188 = torch.ops.aten.select.int(device_put_23, 0, 4) + _local_scalar_dense_188 = torch.ops.aten._local_scalar_dense.default(select_188); select_188 = None + ge_232 = _local_scalar_dense_188 >= 0 + _assert_scalar_188 = torch.ops.aten._assert_scalar.default(ge_232, "Runtime assertion failed for expression u188 >= 0 on node 'ge_188'"); ge_232 = _assert_scalar_188 = None + select_189 = torch.ops.aten.select.int(device_put_23, 0, 5) + _local_scalar_dense_189 = torch.ops.aten._local_scalar_dense.default(select_189); select_189 = None + ge_233 = _local_scalar_dense_189 >= 0 + _assert_scalar_189 = torch.ops.aten._assert_scalar.default(ge_233, "Runtime assertion failed for expression u189 >= 0 on node 'ge_189'"); ge_233 = _assert_scalar_189 = None + select_190 = torch.ops.aten.select.int(device_put_23, 0, 6) + _local_scalar_dense_190 = torch.ops.aten._local_scalar_dense.default(select_190); select_190 = None + ge_234 = _local_scalar_dense_190 >= 0 + _assert_scalar_190 = torch.ops.aten._assert_scalar.default(ge_234, "Runtime assertion failed for expression u190 >= 0 on node 'ge_190'"); ge_234 = _assert_scalar_190 = None + select_191 = torch.ops.aten.select.int(device_put_23, 0, 7); device_put_23 = None + _local_scalar_dense_191 = torch.ops.aten._local_scalar_dense.default(select_191); select_191 = None + ge_235 = _local_scalar_dense_191 >= 0 + _assert_scalar_191 = torch.ops.aten._assert_scalar.default(ge_235, "Runtime assertion failed for expression u191 >= 0 on node 'ge_191'"); ge_235 = _assert_scalar_191 = None + all_to_all_single_34 = torch.ops._c10d_functional.all_to_all_single.default(index_22, [_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191], [_local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183], '1033'); index_22 = None + sym_size_int_44 = torch.ops.aten.sym_size.int(all_to_all_single_34, 0) + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_34); all_to_all_single_34 = None + sym_sum_22 = torch.sym_sum((_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191)) + add_766 = sym_sum_22 + 64; sym_sum_22 = None + add_767 = add_766 + 8; add_766 = None + sub_267 = add_767 - 1; add_767 = None + floordiv_11 = sub_267 // 8; sub_267 = None + mul_561 = floordiv_11 * 8; floordiv_11 = None + cumsum_33 = torch.ops.aten.cumsum.default(wait_tensor_251, 0) + sub_268 = torch.ops.aten.sub.Tensor(cumsum_33, wait_tensor_251); cumsum_33 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_802, [0]); view_802 = None + clamp_min_11 = torch.ops.aten.clamp_min.default(sum_48, 8); sum_48 = None + add_768 = torch.ops.aten.add.Tensor(clamp_min_11, 8); clamp_min_11 = None + sub_269 = torch.ops.aten.sub.Tensor(add_768, 1); add_768 = None + div_58 = torch.ops.aten.div.Tensor_mode(sub_269, 8, rounding_mode = 'floor'); sub_269 = None + mul_562 = torch.ops.aten.mul.Tensor(div_58, 8); div_58 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(mul_562, torch.int32); mul_562 = None + cumsum_34 = torch.ops.aten.cumsum.default(convert_element_type_662, 0) + sub_270 = torch.ops.aten.sub.Tensor(cumsum_34, convert_element_type_662); cumsum_34 = None + full_163 = torch.ops.aten.full.default([mul_561], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_561 = None + triton_kernel_wrapper_functional_proxy_11 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 11, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_251, 'start_index_values_ptr': sub_268, 'write_offsets_ptr': sub_270, 'output_ptr': full_163}, tensors_to_clone = ['output_ptr']); wait_tensor_251 = sub_268 = sub_270 = full_163 = None + getitem_1232 = triton_kernel_wrapper_functional_proxy_11['output_ptr']; triton_kernel_wrapper_functional_proxy_11 = None + cat_103 = torch.ops.aten.cat.default([wait_tensor_252, full_default]); wait_tensor_252 = None + sym_size_int_45 = torch.ops.aten.sym_size.int(cat_103, 0) + sym_sum_23 = torch.sym_sum((1, _local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191)) + index_23 = torch.ops.aten.index.Tensor(cat_103, [getitem_1232]); cat_103 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 16, '1025'); convert_element_type_664 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + split_67 = torch.ops.aten.split.Tensor(wait_tensor_253, 8); wait_tensor_253 = None + getitem_1249 = split_67[0] + getitem_1250 = split_67[1] + getitem_1251 = split_67[2] + getitem_1252 = split_67[3] + getitem_1253 = split_67[4] + getitem_1254 = split_67[5] + getitem_1255 = split_67[6] + getitem_1256 = split_67[7] + getitem_1257 = split_67[8] + getitem_1258 = split_67[9] + getitem_1259 = split_67[10] + getitem_1260 = split_67[11] + getitem_1261 = split_67[12] + getitem_1262 = split_67[13] + getitem_1263 = split_67[14] + getitem_1264 = split_67[15]; split_67 = None + cat_105 = torch.ops.aten.cat.default([getitem_1249, getitem_1250, getitem_1251, getitem_1252, getitem_1253, getitem_1254, getitem_1255, getitem_1256, getitem_1257, getitem_1258, getitem_1259, getitem_1260, getitem_1261, getitem_1262, getitem_1263, getitem_1264], 1); getitem_1249 = getitem_1250 = getitem_1251 = getitem_1252 = getitem_1253 = getitem_1254 = getitem_1255 = getitem_1256 = getitem_1257 = getitem_1258 = getitem_1259 = getitem_1260 = getitem_1261 = getitem_1262 = getitem_1263 = getitem_1264 = None + convert_element_type_666 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_666, 16, '1025'); convert_element_type_666 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + split_68 = torch.ops.aten.split.Tensor(wait_tensor_255, 8); wait_tensor_255 = None + getitem_1265 = split_68[0] + getitem_1266 = split_68[1] + getitem_1267 = split_68[2] + getitem_1268 = split_68[3] + getitem_1269 = split_68[4] + getitem_1270 = split_68[5] + getitem_1271 = split_68[6] + getitem_1272 = split_68[7] + getitem_1273 = split_68[8] + getitem_1274 = split_68[9] + getitem_1275 = split_68[10] + getitem_1276 = split_68[11] + getitem_1277 = split_68[12] + getitem_1278 = split_68[13] + getitem_1279 = split_68[14] + getitem_1280 = split_68[15]; split_68 = None + cat_106 = torch.ops.aten.cat.default([getitem_1265, getitem_1266, getitem_1267, getitem_1268, getitem_1269, getitem_1270, getitem_1271, getitem_1272, getitem_1273, getitem_1274, getitem_1275, getitem_1276, getitem_1277, getitem_1278, getitem_1279, getitem_1280], 1); getitem_1265 = getitem_1266 = getitem_1267 = getitem_1268 = getitem_1269 = getitem_1270 = getitem_1271 = getitem_1272 = getitem_1273 = getitem_1274 = getitem_1275 = getitem_1276 = getitem_1277 = getitem_1278 = getitem_1279 = getitem_1280 = None + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 16, '1025'); convert_element_type_667 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + split_69 = torch.ops.aten.split.Tensor(wait_tensor_256, 8); wait_tensor_256 = None + getitem_1281 = split_69[0] + getitem_1282 = split_69[1] + getitem_1283 = split_69[2] + getitem_1284 = split_69[3] + getitem_1285 = split_69[4] + getitem_1286 = split_69[5] + getitem_1287 = split_69[6] + getitem_1288 = split_69[7] + getitem_1289 = split_69[8] + getitem_1290 = split_69[9] + getitem_1291 = split_69[10] + getitem_1292 = split_69[11] + getitem_1293 = split_69[12] + getitem_1294 = split_69[13] + getitem_1295 = split_69[14] + getitem_1296 = split_69[15]; split_69 = None + cat_107 = torch.ops.aten.cat.default([getitem_1281, getitem_1282, getitem_1283, getitem_1284, getitem_1285, getitem_1286, getitem_1287, getitem_1288, getitem_1289, getitem_1290, getitem_1291, getitem_1292, getitem_1293, getitem_1294, getitem_1295, getitem_1296], 1); getitem_1281 = getitem_1282 = getitem_1283 = getitem_1284 = getitem_1285 = getitem_1286 = getitem_1287 = getitem_1288 = getitem_1289 = getitem_1290 = getitem_1291 = getitem_1292 = getitem_1293 = getitem_1294 = getitem_1295 = getitem_1296 = None + cumsum_35 = torch.ops.aten.cumsum.default(convert_element_type_662, 0, dtype = torch.int32); convert_element_type_662 = None + permute_185 = torch.ops.aten.permute.default(cat_105, [0, 2, 1]); cat_105 = None + _grouped_mm_33 = torch.ops.aten._grouped_mm.default(index_23, permute_185, cumsum_35) + convert_element_type_670 = torch.ops.prims.convert_element_type.default(_grouped_mm_33, torch.float32) + neg_23 = torch.ops.aten.neg.default(convert_element_type_670) + exp_35 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_780 = torch.ops.aten.add.Tensor(exp_35, 1); exp_35 = None + div_59 = torch.ops.aten.div.Tensor(convert_element_type_670, add_780); convert_element_type_670 = add_780 = None + convert_element_type_671 = torch.ops.prims.convert_element_type.default(div_59, torch.bfloat16); div_59 = None + permute_186 = torch.ops.aten.permute.default(cat_107, [0, 2, 1]); cat_107 = None + _grouped_mm_34 = torch.ops.aten._grouped_mm.default(index_23, permute_186, cumsum_35) + mul_574 = torch.ops.aten.mul.Tensor(convert_element_type_671, _grouped_mm_34); convert_element_type_671 = None + permute_187 = torch.ops.aten.permute.default(cat_106, [0, 2, 1]); cat_106 = None + _grouped_mm_35 = torch.ops.aten._grouped_mm.default(mul_574, permute_187, cumsum_35) + empty_11 = torch.ops.aten.empty.memory_format([sym_size_int_45, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_22 = torch.ops.aten.index_put.default(empty_11, [getitem_1232], _grouped_mm_35); empty_11 = _grouped_mm_35 = None + slice_77 = torch.ops.aten.slice.Tensor(index_put_22, 0, 0, -1); index_put_22 = None + all_to_all_single_35 = torch.ops._c10d_functional.all_to_all_single.default(slice_77, [_local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183], [_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191], '1033'); slice_77 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_35); all_to_all_single_35 = None + convert_element_type_672 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_672, 128, '0'); convert_element_type_672 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_100 = torch.ops.aten.mm.default(view_795, permute_188); permute_188 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(mm_100, torch.float32) + neg_24 = torch.ops.aten.neg.default(convert_element_type_675) + exp_36 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_816 = torch.ops.aten.add.Tensor(exp_36, 1); exp_36 = None + div_60 = torch.ops.aten.div.Tensor(convert_element_type_675, add_816); convert_element_type_675 = add_816 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(div_60, torch.bfloat16); div_60 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 128, '0'); convert_element_type_677 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + mm_101 = torch.ops.aten.mm.default(view_795, permute_189); permute_189 = None + mul_594 = torch.ops.aten.mul.Tensor(convert_element_type_676, mm_101); convert_element_type_676 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 128, '0'); convert_element_type_680 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_190 = torch.ops.aten.permute.default(wait_tensor_262, [1, 0]); wait_tensor_262 = None + mm_102 = torch.ops.aten.mm.default(mul_594, permute_190); permute_190 = None + index_put_23 = torch.ops.aten.index_put.default(full_default_1, [getitem_1231], wait_tensor_259); wait_tensor_259 = None + view_835 = torch.ops.aten.view.default(mul_556, [-1, 1, 6]); mul_556 = None + view_836 = torch.ops.aten.view.default(index_put_23, [-1, 6, 2048]); index_put_23 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(view_836, torch.float32); view_836 = None + bmm_11 = torch.ops.aten.bmm.default(view_835, convert_element_type_683) + convert_element_type_684 = torch.ops.prims.convert_element_type.default(bmm_11, torch.bfloat16); bmm_11 = None + squeeze_11 = torch.ops.aten.squeeze.dim(convert_element_type_684, 1); convert_element_type_684 = None + add_820 = torch.ops.aten.add.Tensor(mm_102, squeeze_11); mm_102 = squeeze_11 = None + view_837 = torch.ops.aten.view.default(add_820, [2, 4096, 2048]); add_820 = None + add_821 = torch.ops.aten.add.Tensor(add_756, view_837); view_837 = None + convert_element_type_685 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_685, 128, '0'); convert_element_type_685 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(add_821, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_686, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_822 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_822); add_822 = None + mul_597 = torch.ops.aten.mul.Tensor(convert_element_type_686, rsqrt_39); convert_element_type_686 = None + mul_598 = torch.ops.aten.mul.Tensor(mul_597, wait_tensor_263); mul_597 = wait_tensor_263 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_598, torch.bfloat16); mul_598 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 128, '0'); convert_element_type_688 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_191 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + view_840 = torch.ops.aten.view.default(convert_element_type_687, [8192, 2048]); convert_element_type_687 = None + mm_103 = torch.ops.aten.mm.default(view_840, permute_191); permute_191 = None + view_841 = torch.ops.aten.view.default(mm_103, [2, 4096, 3072]); mm_103 = None + view_842 = torch.ops.aten.view.default(view_841, [2, 4096, -1, 192]); view_841 = None + split_with_sizes_39 = torch.ops.aten.split_with_sizes.default(view_842, [128, 64], -1); view_842 = None + getitem_1329 = split_with_sizes_39[0] + getitem_1330 = split_with_sizes_39[1]; split_with_sizes_39 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(getitem_1330, torch.float32); getitem_1330 = None + view_843 = torch.ops.aten.view.default(convert_element_type_691, [2, 4096, 16, -1, 2]); convert_element_type_691 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_843); view_843 = None + mul_599 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_7); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_599); mul_599 = None + view_845 = torch.ops.aten.view.default(view_as_real_26, [2, 4096, 16, 64]); view_as_real_26 = None + convert_element_type_692 = torch.ops.prims.convert_element_type.default(view_845, torch.bfloat16); view_845 = None + cat_110 = torch.ops.aten.cat.default([getitem_1329, convert_element_type_692], -1); getitem_1329 = convert_element_type_692 = None + convert_element_type_693 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_693, 128, '0'); convert_element_type_693 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + slice_79 = torch.ops.aten.slice.Tensor(wait_tensor_265, 0, 0, 576); wait_tensor_265 = None + permute_192 = torch.ops.aten.permute.default(slice_79, [1, 0]); slice_79 = None + mm_104 = torch.ops.aten.mm.default(view_840, permute_192); permute_192 = None + view_848 = torch.ops.aten.view.default(mm_104, [2, 4096, 576]); mm_104 = None + split_with_sizes_40 = torch.ops.aten.split_with_sizes.default(view_848, [512, 64], -1); view_848 = None + getitem_1331 = split_with_sizes_40[0] + getitem_1332 = split_with_sizes_40[1]; split_with_sizes_40 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(getitem_1332, 2); getitem_1332 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(unsqueeze_25, torch.float32); unsqueeze_25 = None + view_849 = torch.ops.aten.view.default(convert_element_type_696, [2, 4096, 1, -1, 2]); convert_element_type_696 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_849); view_849 = None + mul_600 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_7); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_600); mul_600 = None + view_851 = torch.ops.aten.view.default(view_as_real_27, [2, 4096, 1, 64]); view_as_real_27 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(view_851, torch.bfloat16); view_851 = None + convert_element_type_698 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_698, 128, '0'); convert_element_type_698 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + convert_element_type_699 = torch.ops.prims.convert_element_type.default(getitem_1331, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_699, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_823 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_823); add_823 = None + mul_601 = torch.ops.aten.mul.Tensor(convert_element_type_699, rsqrt_40); convert_element_type_699 = None + mul_602 = torch.ops.aten.mul.Tensor(mul_601, wait_tensor_266); mul_601 = wait_tensor_266 = None + convert_element_type_700 = torch.ops.prims.convert_element_type.default(mul_602, torch.bfloat16); mul_602 = None + convert_element_type_701 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_701, 128, '0'); convert_element_type_701 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_193 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + view_854 = torch.ops.aten.view.default(convert_element_type_700, [8192, 512]); convert_element_type_700 = None + mm_105 = torch.ops.aten.mm.default(view_854, permute_193); permute_193 = None + view_855 = torch.ops.aten.view.default(mm_105, [2, 4096, 4096]); mm_105 = None + view_856 = torch.ops.aten.view.default(view_855, [2, 4096, -1, 256]); view_855 = None + split_with_sizes_41 = torch.ops.aten.split_with_sizes.default(view_856, [128, 128], -1); view_856 = None + getitem_1333 = split_with_sizes_41[0] + getitem_1334 = split_with_sizes_41[1]; split_with_sizes_41 = None + expand_13 = torch.ops.aten.expand.default(convert_element_type_697, [-1, -1, 16, -1]); convert_element_type_697 = None + cat_111 = torch.ops.aten.cat.default([getitem_1333, expand_13], -1); getitem_1333 = expand_13 = None + permute_194 = torch.ops.aten.permute.default(cat_110, [0, 2, 1, 3]); cat_110 = None + permute_195 = torch.ops.aten.permute.default(cat_111, [0, 2, 1, 3]); cat_111 = None + permute_196 = torch.ops.aten.permute.default(getitem_1334, [0, 2, 1, 3]); getitem_1334 = None + sdpa_score13 = self.sdpa_score13 + sdpa_mask13 = self.sdpa_mask13 + flex_attention_13 = torch.ops.higher_order.flex_attention(permute_194, permute_195, permute_196, sdpa_score13, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask13), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score13 = sdpa_mask13 = None + getitem_1335 = flex_attention_13[0] + getitem_1336 = flex_attention_13[1]; flex_attention_13 = None + permute_197 = torch.ops.aten.permute.default(getitem_1335, [0, 2, 1, 3]) + view_857 = torch.ops.aten.view.default(permute_197, [2, 4096, -1]); permute_197 = None + convert_element_type_704 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_704, 128, '0'); convert_element_type_704 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + view_859 = torch.ops.aten.view.default(view_857, [8192, 2048]); view_857 = None + mm_106 = torch.ops.aten.mm.default(view_859, permute_198); view_859 = permute_198 = None + view_860 = torch.ops.aten.view.default(mm_106, [2, 4096, 2048]); mm_106 = None + add_824 = torch.ops.aten.add.Tensor(add_821, view_860); view_860 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_707, 128, '0'); convert_element_type_707 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(add_824, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_708, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_825 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_825); add_825 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_708, rsqrt_41); convert_element_type_708 = None + mul_604 = torch.ops.aten.mul.Tensor(mul_603, wait_tensor_269); mul_603 = wait_tensor_269 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(mul_604, torch.bfloat16); mul_604 = None + view_862 = torch.ops.aten.view.default(convert_element_type_709, [-1, 2048]); convert_element_type_709 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 128, '0'); convert_element_type_710 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + slice_81 = torch.ops.aten.slice.Tensor(wait_tensor_270, 0, 0, 64); wait_tensor_270 = None + permute_199 = torch.ops.aten.permute.default(slice_81, [1, 0]); slice_81 = None + mm_107 = torch.ops.aten.mm.default(view_862, permute_199); permute_199 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(mm_107, torch.float32) + amax_12 = torch.ops.aten.amax.default(convert_element_type_713, [1], True) + sub_288 = torch.ops.aten.sub.Tensor(convert_element_type_713, amax_12); convert_element_type_713 = None + exp_37 = torch.ops.aten.exp.default(sub_288); sub_288 = None + sum_49 = torch.ops.aten.sum.dim_IntList(exp_37, [1], True) + div_61 = torch.ops.aten.div.Tensor(exp_37, sum_49); exp_37 = None + add_826 = torch.ops.aten.add.Tensor(div_61, primals_222); primals_222 = None + topk_12 = torch.ops.aten.topk.default(add_826, 6, -1, True, False); add_826 = None + getitem_1339 = topk_12[1]; topk_12 = None + gather_12 = torch.ops.aten.gather.default(div_61, 1, getitem_1339); div_61 = None + mul_605 = torch.ops.aten.mul.Tensor(gather_12, 1.0); gather_12 = None + view_864 = torch.ops.aten.view.default(getitem_1339, [-1]) + histc_24 = torch.ops.aten.histc.default(view_864, 64, 0, 64) + add_827 = torch.ops.aten.add.Tensor(primals_224, histc_24) + sort_12 = torch.ops.aten.sort.stable(view_864, stable = True); view_864 = None + getitem_1341 = sort_12[1]; sort_12 = None + div_62 = torch.ops.aten.div.Tensor_mode(getitem_1341, 6, rounding_mode = 'floor') + index_24 = torch.ops.aten.index.Tensor(view_862, [div_62]) + all_to_all_single_36 = torch.ops._c10d_functional.all_to_all_single.default(histc_24, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_36); all_to_all_single_36 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_271); wait_tensor_271 = None + view_868 = torch.ops.aten.view.default(histc_24, [8, -1]); histc_24 = None + sum_50 = torch.ops.aten.sum.dim_IntList(view_868, [1]); view_868 = None + device_put_24 = torch.ops.prims.device_put.default(sum_50, device(type='cpu'), True); sum_50 = None + view_869 = torch.ops.aten.view.default(wait_tensor_272, [8, -1]) + sum_51 = torch.ops.aten.sum.dim_IntList(view_869, [1]) + device_put_25 = torch.ops.prims.device_put.default(sum_51, device(type='cpu')); sum_51 = None + select_192 = torch.ops.aten.select.int(device_put_24, 0, 0) + _local_scalar_dense_192 = torch.ops.aten._local_scalar_dense.default(select_192); select_192 = None + ge_240 = _local_scalar_dense_192 >= 0 + _assert_scalar_192 = torch.ops.aten._assert_scalar.default(ge_240, "Runtime assertion failed for expression u192 >= 0 on node 'ge_192'"); ge_240 = _assert_scalar_192 = None + select_193 = torch.ops.aten.select.int(device_put_24, 0, 1) + _local_scalar_dense_193 = torch.ops.aten._local_scalar_dense.default(select_193); select_193 = None + ge_241 = _local_scalar_dense_193 >= 0 + _assert_scalar_193 = torch.ops.aten._assert_scalar.default(ge_241, "Runtime assertion failed for expression u193 >= 0 on node 'ge_193'"); ge_241 = _assert_scalar_193 = None + select_194 = torch.ops.aten.select.int(device_put_24, 0, 2) + _local_scalar_dense_194 = torch.ops.aten._local_scalar_dense.default(select_194); select_194 = None + ge_242 = _local_scalar_dense_194 >= 0 + _assert_scalar_194 = torch.ops.aten._assert_scalar.default(ge_242, "Runtime assertion failed for expression u194 >= 0 on node 'ge_194'"); ge_242 = _assert_scalar_194 = None + select_195 = torch.ops.aten.select.int(device_put_24, 0, 3) + _local_scalar_dense_195 = torch.ops.aten._local_scalar_dense.default(select_195); select_195 = None + ge_243 = _local_scalar_dense_195 >= 0 + _assert_scalar_195 = torch.ops.aten._assert_scalar.default(ge_243, "Runtime assertion failed for expression u195 >= 0 on node 'ge_195'"); ge_243 = _assert_scalar_195 = None + select_196 = torch.ops.aten.select.int(device_put_24, 0, 4) + _local_scalar_dense_196 = torch.ops.aten._local_scalar_dense.default(select_196); select_196 = None + ge_244 = _local_scalar_dense_196 >= 0 + _assert_scalar_196 = torch.ops.aten._assert_scalar.default(ge_244, "Runtime assertion failed for expression u196 >= 0 on node 'ge_196'"); ge_244 = _assert_scalar_196 = None + select_197 = torch.ops.aten.select.int(device_put_24, 0, 5) + _local_scalar_dense_197 = torch.ops.aten._local_scalar_dense.default(select_197); select_197 = None + ge_245 = _local_scalar_dense_197 >= 0 + _assert_scalar_197 = torch.ops.aten._assert_scalar.default(ge_245, "Runtime assertion failed for expression u197 >= 0 on node 'ge_197'"); ge_245 = _assert_scalar_197 = None + select_198 = torch.ops.aten.select.int(device_put_24, 0, 6) + _local_scalar_dense_198 = torch.ops.aten._local_scalar_dense.default(select_198); select_198 = None + ge_246 = _local_scalar_dense_198 >= 0 + _assert_scalar_198 = torch.ops.aten._assert_scalar.default(ge_246, "Runtime assertion failed for expression u198 >= 0 on node 'ge_198'"); ge_246 = _assert_scalar_198 = None + select_199 = torch.ops.aten.select.int(device_put_24, 0, 7); device_put_24 = None + _local_scalar_dense_199 = torch.ops.aten._local_scalar_dense.default(select_199); select_199 = None + ge_247 = _local_scalar_dense_199 >= 0 + _assert_scalar_199 = torch.ops.aten._assert_scalar.default(ge_247, "Runtime assertion failed for expression u199 >= 0 on node 'ge_199'"); ge_247 = _assert_scalar_199 = None + select_200 = torch.ops.aten.select.int(device_put_25, 0, 0) + _local_scalar_dense_200 = torch.ops.aten._local_scalar_dense.default(select_200); select_200 = None + ge_248 = _local_scalar_dense_200 >= 0 + _assert_scalar_200 = torch.ops.aten._assert_scalar.default(ge_248, "Runtime assertion failed for expression u200 >= 0 on node 'ge_200'"); ge_248 = _assert_scalar_200 = None + select_201 = torch.ops.aten.select.int(device_put_25, 0, 1) + _local_scalar_dense_201 = torch.ops.aten._local_scalar_dense.default(select_201); select_201 = None + ge_249 = _local_scalar_dense_201 >= 0 + _assert_scalar_201 = torch.ops.aten._assert_scalar.default(ge_249, "Runtime assertion failed for expression u201 >= 0 on node 'ge_201'"); ge_249 = _assert_scalar_201 = None + select_202 = torch.ops.aten.select.int(device_put_25, 0, 2) + _local_scalar_dense_202 = torch.ops.aten._local_scalar_dense.default(select_202); select_202 = None + ge_250 = _local_scalar_dense_202 >= 0 + _assert_scalar_202 = torch.ops.aten._assert_scalar.default(ge_250, "Runtime assertion failed for expression u202 >= 0 on node 'ge_202'"); ge_250 = _assert_scalar_202 = None + select_203 = torch.ops.aten.select.int(device_put_25, 0, 3) + _local_scalar_dense_203 = torch.ops.aten._local_scalar_dense.default(select_203); select_203 = None + ge_251 = _local_scalar_dense_203 >= 0 + _assert_scalar_203 = torch.ops.aten._assert_scalar.default(ge_251, "Runtime assertion failed for expression u203 >= 0 on node 'ge_203'"); ge_251 = _assert_scalar_203 = None + select_204 = torch.ops.aten.select.int(device_put_25, 0, 4) + _local_scalar_dense_204 = torch.ops.aten._local_scalar_dense.default(select_204); select_204 = None + ge_252 = _local_scalar_dense_204 >= 0 + _assert_scalar_204 = torch.ops.aten._assert_scalar.default(ge_252, "Runtime assertion failed for expression u204 >= 0 on node 'ge_204'"); ge_252 = _assert_scalar_204 = None + select_205 = torch.ops.aten.select.int(device_put_25, 0, 5) + _local_scalar_dense_205 = torch.ops.aten._local_scalar_dense.default(select_205); select_205 = None + ge_253 = _local_scalar_dense_205 >= 0 + _assert_scalar_205 = torch.ops.aten._assert_scalar.default(ge_253, "Runtime assertion failed for expression u205 >= 0 on node 'ge_205'"); ge_253 = _assert_scalar_205 = None + select_206 = torch.ops.aten.select.int(device_put_25, 0, 6) + _local_scalar_dense_206 = torch.ops.aten._local_scalar_dense.default(select_206); select_206 = None + ge_254 = _local_scalar_dense_206 >= 0 + _assert_scalar_206 = torch.ops.aten._assert_scalar.default(ge_254, "Runtime assertion failed for expression u206 >= 0 on node 'ge_206'"); ge_254 = _assert_scalar_206 = None + select_207 = torch.ops.aten.select.int(device_put_25, 0, 7); device_put_25 = None + _local_scalar_dense_207 = torch.ops.aten._local_scalar_dense.default(select_207); select_207 = None + ge_255 = _local_scalar_dense_207 >= 0 + _assert_scalar_207 = torch.ops.aten._assert_scalar.default(ge_255, "Runtime assertion failed for expression u207 >= 0 on node 'ge_207'"); ge_255 = _assert_scalar_207 = None + all_to_all_single_37 = torch.ops._c10d_functional.all_to_all_single.default(index_24, [_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207], [_local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199], '1033'); index_24 = None + sym_size_int_48 = torch.ops.aten.sym_size.int(all_to_all_single_37, 0) + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_37); all_to_all_single_37 = None + sym_sum_24 = torch.sym_sum((_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207)) + add_834 = sym_sum_24 + 64; sym_sum_24 = None + add_835 = add_834 + 8; add_834 = None + sub_291 = add_835 - 1; add_835 = None + floordiv_12 = sub_291 // 8; sub_291 = None + mul_610 = floordiv_12 * 8; floordiv_12 = None + cumsum_36 = torch.ops.aten.cumsum.default(wait_tensor_272, 0) + sub_292 = torch.ops.aten.sub.Tensor(cumsum_36, wait_tensor_272); cumsum_36 = None + sum_52 = torch.ops.aten.sum.dim_IntList(view_869, [0]); view_869 = None + clamp_min_12 = torch.ops.aten.clamp_min.default(sum_52, 8); sum_52 = None + add_836 = torch.ops.aten.add.Tensor(clamp_min_12, 8); clamp_min_12 = None + sub_293 = torch.ops.aten.sub.Tensor(add_836, 1); add_836 = None + div_63 = torch.ops.aten.div.Tensor_mode(sub_293, 8, rounding_mode = 'floor'); sub_293 = None + mul_611 = torch.ops.aten.mul.Tensor(div_63, 8); div_63 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(mul_611, torch.int32); mul_611 = None + cumsum_37 = torch.ops.aten.cumsum.default(convert_element_type_716, 0) + sub_294 = torch.ops.aten.sub.Tensor(cumsum_37, convert_element_type_716); cumsum_37 = None + full_176 = torch.ops.aten.full.default([mul_610], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_610 = None + triton_kernel_wrapper_functional_proxy_12 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 12, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_272, 'start_index_values_ptr': sub_292, 'write_offsets_ptr': sub_294, 'output_ptr': full_176}, tensors_to_clone = ['output_ptr']); wait_tensor_272 = sub_292 = sub_294 = full_176 = None + getitem_1342 = triton_kernel_wrapper_functional_proxy_12['output_ptr']; triton_kernel_wrapper_functional_proxy_12 = None + cat_112 = torch.ops.aten.cat.default([wait_tensor_273, full_default]); wait_tensor_273 = None + sym_size_int_49 = torch.ops.aten.sym_size.int(cat_112, 0) + sym_sum_25 = torch.sym_sum((1, _local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207)) + index_25 = torch.ops.aten.index.Tensor(cat_112, [getitem_1342]); cat_112 = None + convert_element_type_718 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_718, 16, '1025'); convert_element_type_718 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + split_73 = torch.ops.aten.split.Tensor(wait_tensor_274, 8); wait_tensor_274 = None + getitem_1359 = split_73[0] + getitem_1360 = split_73[1] + getitem_1361 = split_73[2] + getitem_1362 = split_73[3] + getitem_1363 = split_73[4] + getitem_1364 = split_73[5] + getitem_1365 = split_73[6] + getitem_1366 = split_73[7] + getitem_1367 = split_73[8] + getitem_1368 = split_73[9] + getitem_1369 = split_73[10] + getitem_1370 = split_73[11] + getitem_1371 = split_73[12] + getitem_1372 = split_73[13] + getitem_1373 = split_73[14] + getitem_1374 = split_73[15]; split_73 = None + cat_114 = torch.ops.aten.cat.default([getitem_1359, getitem_1360, getitem_1361, getitem_1362, getitem_1363, getitem_1364, getitem_1365, getitem_1366, getitem_1367, getitem_1368, getitem_1369, getitem_1370, getitem_1371, getitem_1372, getitem_1373, getitem_1374], 1); getitem_1359 = getitem_1360 = getitem_1361 = getitem_1362 = getitem_1363 = getitem_1364 = getitem_1365 = getitem_1366 = getitem_1367 = getitem_1368 = getitem_1369 = getitem_1370 = getitem_1371 = getitem_1372 = getitem_1373 = getitem_1374 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_720, 16, '1025'); convert_element_type_720 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + split_74 = torch.ops.aten.split.Tensor(wait_tensor_276, 8); wait_tensor_276 = None + getitem_1375 = split_74[0] + getitem_1376 = split_74[1] + getitem_1377 = split_74[2] + getitem_1378 = split_74[3] + getitem_1379 = split_74[4] + getitem_1380 = split_74[5] + getitem_1381 = split_74[6] + getitem_1382 = split_74[7] + getitem_1383 = split_74[8] + getitem_1384 = split_74[9] + getitem_1385 = split_74[10] + getitem_1386 = split_74[11] + getitem_1387 = split_74[12] + getitem_1388 = split_74[13] + getitem_1389 = split_74[14] + getitem_1390 = split_74[15]; split_74 = None + cat_115 = torch.ops.aten.cat.default([getitem_1375, getitem_1376, getitem_1377, getitem_1378, getitem_1379, getitem_1380, getitem_1381, getitem_1382, getitem_1383, getitem_1384, getitem_1385, getitem_1386, getitem_1387, getitem_1388, getitem_1389, getitem_1390], 1); getitem_1375 = getitem_1376 = getitem_1377 = getitem_1378 = getitem_1379 = getitem_1380 = getitem_1381 = getitem_1382 = getitem_1383 = getitem_1384 = getitem_1385 = getitem_1386 = getitem_1387 = getitem_1388 = getitem_1389 = getitem_1390 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 16, '1025'); convert_element_type_721 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + split_75 = torch.ops.aten.split.Tensor(wait_tensor_277, 8); wait_tensor_277 = None + getitem_1391 = split_75[0] + getitem_1392 = split_75[1] + getitem_1393 = split_75[2] + getitem_1394 = split_75[3] + getitem_1395 = split_75[4] + getitem_1396 = split_75[5] + getitem_1397 = split_75[6] + getitem_1398 = split_75[7] + getitem_1399 = split_75[8] + getitem_1400 = split_75[9] + getitem_1401 = split_75[10] + getitem_1402 = split_75[11] + getitem_1403 = split_75[12] + getitem_1404 = split_75[13] + getitem_1405 = split_75[14] + getitem_1406 = split_75[15]; split_75 = None + cat_116 = torch.ops.aten.cat.default([getitem_1391, getitem_1392, getitem_1393, getitem_1394, getitem_1395, getitem_1396, getitem_1397, getitem_1398, getitem_1399, getitem_1400, getitem_1401, getitem_1402, getitem_1403, getitem_1404, getitem_1405, getitem_1406], 1); getitem_1391 = getitem_1392 = getitem_1393 = getitem_1394 = getitem_1395 = getitem_1396 = getitem_1397 = getitem_1398 = getitem_1399 = getitem_1400 = getitem_1401 = getitem_1402 = getitem_1403 = getitem_1404 = getitem_1405 = getitem_1406 = None + cumsum_38 = torch.ops.aten.cumsum.default(convert_element_type_716, 0, dtype = torch.int32); convert_element_type_716 = None + permute_200 = torch.ops.aten.permute.default(cat_114, [0, 2, 1]); cat_114 = None + _grouped_mm_36 = torch.ops.aten._grouped_mm.default(index_25, permute_200, cumsum_38) + convert_element_type_724 = torch.ops.prims.convert_element_type.default(_grouped_mm_36, torch.float32) + neg_25 = torch.ops.aten.neg.default(convert_element_type_724) + exp_38 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_848 = torch.ops.aten.add.Tensor(exp_38, 1); exp_38 = None + div_64 = torch.ops.aten.div.Tensor(convert_element_type_724, add_848); convert_element_type_724 = add_848 = None + convert_element_type_725 = torch.ops.prims.convert_element_type.default(div_64, torch.bfloat16); div_64 = None + permute_201 = torch.ops.aten.permute.default(cat_116, [0, 2, 1]); cat_116 = None + _grouped_mm_37 = torch.ops.aten._grouped_mm.default(index_25, permute_201, cumsum_38) + mul_623 = torch.ops.aten.mul.Tensor(convert_element_type_725, _grouped_mm_37); convert_element_type_725 = None + permute_202 = torch.ops.aten.permute.default(cat_115, [0, 2, 1]); cat_115 = None + _grouped_mm_38 = torch.ops.aten._grouped_mm.default(mul_623, permute_202, cumsum_38) + empty_12 = torch.ops.aten.empty.memory_format([sym_size_int_49, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_24 = torch.ops.aten.index_put.default(empty_12, [getitem_1342], _grouped_mm_38); empty_12 = _grouped_mm_38 = None + slice_83 = torch.ops.aten.slice.Tensor(index_put_24, 0, 0, -1); index_put_24 = None + all_to_all_single_38 = torch.ops._c10d_functional.all_to_all_single.default(slice_83, [_local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199], [_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207], '1033'); slice_83 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_38); all_to_all_single_38 = None + convert_element_type_726 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_726, 128, '0'); convert_element_type_726 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_203 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + mm_108 = torch.ops.aten.mm.default(view_862, permute_203); permute_203 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mm_108, torch.float32) + neg_26 = torch.ops.aten.neg.default(convert_element_type_729) + exp_39 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_884 = torch.ops.aten.add.Tensor(exp_39, 1); exp_39 = None + div_65 = torch.ops.aten.div.Tensor(convert_element_type_729, add_884); convert_element_type_729 = add_884 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(div_65, torch.bfloat16); div_65 = None + convert_element_type_731 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_731, 128, '0'); convert_element_type_731 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_204 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + mm_109 = torch.ops.aten.mm.default(view_862, permute_204); permute_204 = None + mul_643 = torch.ops.aten.mul.Tensor(convert_element_type_730, mm_109); convert_element_type_730 = None + convert_element_type_734 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_734, 128, '0'); convert_element_type_734 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + mm_110 = torch.ops.aten.mm.default(mul_643, permute_205); permute_205 = None + index_put_25 = torch.ops.aten.index_put.default(full_default_1, [getitem_1341], wait_tensor_280); wait_tensor_280 = None + view_902 = torch.ops.aten.view.default(mul_605, [-1, 1, 6]); mul_605 = None + view_903 = torch.ops.aten.view.default(index_put_25, [-1, 6, 2048]); index_put_25 = None + convert_element_type_737 = torch.ops.prims.convert_element_type.default(view_903, torch.float32); view_903 = None + bmm_12 = torch.ops.aten.bmm.default(view_902, convert_element_type_737) + convert_element_type_738 = torch.ops.prims.convert_element_type.default(bmm_12, torch.bfloat16); bmm_12 = None + squeeze_12 = torch.ops.aten.squeeze.dim(convert_element_type_738, 1); convert_element_type_738 = None + add_888 = torch.ops.aten.add.Tensor(mm_110, squeeze_12); mm_110 = squeeze_12 = None + view_904 = torch.ops.aten.view.default(add_888, [2, 4096, 2048]); add_888 = None + add_889 = torch.ops.aten.add.Tensor(add_824, view_904); view_904 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_739, 128, '0'); convert_element_type_739 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(add_889, torch.float32) + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_740, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_890 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_890); add_890 = None + mul_646 = torch.ops.aten.mul.Tensor(convert_element_type_740, rsqrt_42); convert_element_type_740 = None + mul_647 = torch.ops.aten.mul.Tensor(mul_646, wait_tensor_284); mul_646 = wait_tensor_284 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(mul_647, torch.bfloat16); mul_647 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_742, 128, '0'); convert_element_type_742 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + view_907 = torch.ops.aten.view.default(convert_element_type_741, [8192, 2048]); convert_element_type_741 = None + mm_111 = torch.ops.aten.mm.default(view_907, permute_206); permute_206 = None + view_908 = torch.ops.aten.view.default(mm_111, [2, 4096, 3072]); mm_111 = None + view_909 = torch.ops.aten.view.default(view_908, [2, 4096, -1, 192]); view_908 = None + split_with_sizes_42 = torch.ops.aten.split_with_sizes.default(view_909, [128, 64], -1); view_909 = None + getitem_1439 = split_with_sizes_42[0] + getitem_1440 = split_with_sizes_42[1]; split_with_sizes_42 = None + convert_element_type_745 = torch.ops.prims.convert_element_type.default(getitem_1440, torch.float32); getitem_1440 = None + view_910 = torch.ops.aten.view.default(convert_element_type_745, [2, 4096, 16, -1, 2]); convert_element_type_745 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_910); view_910 = None + mul_648 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_7); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_648); mul_648 = None + view_912 = torch.ops.aten.view.default(view_as_real_28, [2, 4096, 16, 64]); view_as_real_28 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(view_912, torch.bfloat16); view_912 = None + cat_119 = torch.ops.aten.cat.default([getitem_1439, convert_element_type_746], -1); getitem_1439 = convert_element_type_746 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_747, 128, '0'); convert_element_type_747 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + slice_85 = torch.ops.aten.slice.Tensor(wait_tensor_286, 0, 0, 576); wait_tensor_286 = None + permute_207 = torch.ops.aten.permute.default(slice_85, [1, 0]); slice_85 = None + mm_112 = torch.ops.aten.mm.default(view_907, permute_207); permute_207 = None + view_915 = torch.ops.aten.view.default(mm_112, [2, 4096, 576]); mm_112 = None + split_with_sizes_43 = torch.ops.aten.split_with_sizes.default(view_915, [512, 64], -1); view_915 = None + getitem_1441 = split_with_sizes_43[0] + getitem_1442 = split_with_sizes_43[1]; split_with_sizes_43 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(getitem_1442, 2); getitem_1442 = None + convert_element_type_750 = torch.ops.prims.convert_element_type.default(unsqueeze_27, torch.float32); unsqueeze_27 = None + view_916 = torch.ops.aten.view.default(convert_element_type_750, [2, 4096, 1, -1, 2]); convert_element_type_750 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_916); view_916 = None + mul_649 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_7); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_649); mul_649 = None + view_918 = torch.ops.aten.view.default(view_as_real_29, [2, 4096, 1, 64]); view_as_real_29 = None + convert_element_type_751 = torch.ops.prims.convert_element_type.default(view_918, torch.bfloat16); view_918 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_752, 128, '0'); convert_element_type_752 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(getitem_1441, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_753, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_891 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_891); add_891 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_753, rsqrt_43); convert_element_type_753 = None + mul_651 = torch.ops.aten.mul.Tensor(mul_650, wait_tensor_287); mul_650 = wait_tensor_287 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(mul_651, torch.bfloat16); mul_651 = None + convert_element_type_755 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_755, 128, '0'); convert_element_type_755 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + view_921 = torch.ops.aten.view.default(convert_element_type_754, [8192, 512]); convert_element_type_754 = None + mm_113 = torch.ops.aten.mm.default(view_921, permute_208); permute_208 = None + view_922 = torch.ops.aten.view.default(mm_113, [2, 4096, 4096]); mm_113 = None + view_923 = torch.ops.aten.view.default(view_922, [2, 4096, -1, 256]); view_922 = None + split_with_sizes_44 = torch.ops.aten.split_with_sizes.default(view_923, [128, 128], -1); view_923 = None + getitem_1443 = split_with_sizes_44[0] + getitem_1444 = split_with_sizes_44[1]; split_with_sizes_44 = None + expand_14 = torch.ops.aten.expand.default(convert_element_type_751, [-1, -1, 16, -1]); convert_element_type_751 = None + cat_120 = torch.ops.aten.cat.default([getitem_1443, expand_14], -1); getitem_1443 = expand_14 = None + permute_209 = torch.ops.aten.permute.default(cat_119, [0, 2, 1, 3]); cat_119 = None + permute_210 = torch.ops.aten.permute.default(cat_120, [0, 2, 1, 3]); cat_120 = None + permute_211 = torch.ops.aten.permute.default(getitem_1444, [0, 2, 1, 3]); getitem_1444 = None + sdpa_score14 = self.sdpa_score14 + sdpa_mask14 = self.sdpa_mask14 + flex_attention_14 = torch.ops.higher_order.flex_attention(permute_209, permute_210, permute_211, sdpa_score14, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask14), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score14 = sdpa_mask14 = None + getitem_1445 = flex_attention_14[0] + getitem_1446 = flex_attention_14[1]; flex_attention_14 = None + permute_212 = torch.ops.aten.permute.default(getitem_1445, [0, 2, 1, 3]) + view_924 = torch.ops.aten.view.default(permute_212, [2, 4096, -1]); permute_212 = None + convert_element_type_758 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_758, 128, '0'); convert_element_type_758 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_213 = torch.ops.aten.permute.default(wait_tensor_289, [1, 0]); wait_tensor_289 = None + view_926 = torch.ops.aten.view.default(view_924, [8192, 2048]); view_924 = None + mm_114 = torch.ops.aten.mm.default(view_926, permute_213); view_926 = permute_213 = None + view_927 = torch.ops.aten.view.default(mm_114, [2, 4096, 2048]); mm_114 = None + add_892 = torch.ops.aten.add.Tensor(add_889, view_927); view_927 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_761, 128, '0'); convert_element_type_761 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(add_892, torch.float32) + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_762, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_893 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_893); add_893 = None + mul_652 = torch.ops.aten.mul.Tensor(convert_element_type_762, rsqrt_44); convert_element_type_762 = None + mul_653 = torch.ops.aten.mul.Tensor(mul_652, wait_tensor_290); mul_652 = wait_tensor_290 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(mul_653, torch.bfloat16); mul_653 = None + view_929 = torch.ops.aten.view.default(convert_element_type_763, [-1, 2048]); convert_element_type_763 = None + convert_element_type_764 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_764, 128, '0'); convert_element_type_764 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + slice_87 = torch.ops.aten.slice.Tensor(wait_tensor_291, 0, 0, 64); wait_tensor_291 = None + permute_214 = torch.ops.aten.permute.default(slice_87, [1, 0]); slice_87 = None + mm_115 = torch.ops.aten.mm.default(view_929, permute_214); permute_214 = None + convert_element_type_767 = torch.ops.prims.convert_element_type.default(mm_115, torch.float32) + amax_13 = torch.ops.aten.amax.default(convert_element_type_767, [1], True) + sub_312 = torch.ops.aten.sub.Tensor(convert_element_type_767, amax_13); convert_element_type_767 = None + exp_40 = torch.ops.aten.exp.default(sub_312); sub_312 = None + sum_53 = torch.ops.aten.sum.dim_IntList(exp_40, [1], True) + div_66 = torch.ops.aten.div.Tensor(exp_40, sum_53); exp_40 = None + add_894 = torch.ops.aten.add.Tensor(div_66, primals_238); primals_238 = None + topk_13 = torch.ops.aten.topk.default(add_894, 6, -1, True, False); add_894 = None + getitem_1449 = topk_13[1]; topk_13 = None + gather_13 = torch.ops.aten.gather.default(div_66, 1, getitem_1449); div_66 = None + mul_654 = torch.ops.aten.mul.Tensor(gather_13, 1.0); gather_13 = None + view_931 = torch.ops.aten.view.default(getitem_1449, [-1]) + histc_26 = torch.ops.aten.histc.default(view_931, 64, 0, 64) + add_895 = torch.ops.aten.add.Tensor(primals_240, histc_26) + sort_13 = torch.ops.aten.sort.stable(view_931, stable = True); view_931 = None + getitem_1451 = sort_13[1]; sort_13 = None + div_67 = torch.ops.aten.div.Tensor_mode(getitem_1451, 6, rounding_mode = 'floor') + index_26 = torch.ops.aten.index.Tensor(view_929, [div_67]) + all_to_all_single_39 = torch.ops._c10d_functional.all_to_all_single.default(histc_26, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_39); all_to_all_single_39 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_292); wait_tensor_292 = None + view_935 = torch.ops.aten.view.default(histc_26, [8, -1]); histc_26 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_935, [1]); view_935 = None + device_put_26 = torch.ops.prims.device_put.default(sum_54, device(type='cpu'), True); sum_54 = None + view_936 = torch.ops.aten.view.default(wait_tensor_293, [8, -1]) + sum_55 = torch.ops.aten.sum.dim_IntList(view_936, [1]) + device_put_27 = torch.ops.prims.device_put.default(sum_55, device(type='cpu')); sum_55 = None + select_208 = torch.ops.aten.select.int(device_put_26, 0, 0) + _local_scalar_dense_208 = torch.ops.aten._local_scalar_dense.default(select_208); select_208 = None + ge_260 = _local_scalar_dense_208 >= 0 + _assert_scalar_208 = torch.ops.aten._assert_scalar.default(ge_260, "Runtime assertion failed for expression u208 >= 0 on node 'ge_208'"); ge_260 = _assert_scalar_208 = None + select_209 = torch.ops.aten.select.int(device_put_26, 0, 1) + _local_scalar_dense_209 = torch.ops.aten._local_scalar_dense.default(select_209); select_209 = None + ge_261 = _local_scalar_dense_209 >= 0 + _assert_scalar_209 = torch.ops.aten._assert_scalar.default(ge_261, "Runtime assertion failed for expression u209 >= 0 on node 'ge_209'"); ge_261 = _assert_scalar_209 = None + select_210 = torch.ops.aten.select.int(device_put_26, 0, 2) + _local_scalar_dense_210 = torch.ops.aten._local_scalar_dense.default(select_210); select_210 = None + ge_262 = _local_scalar_dense_210 >= 0 + _assert_scalar_210 = torch.ops.aten._assert_scalar.default(ge_262, "Runtime assertion failed for expression u210 >= 0 on node 'ge_210'"); ge_262 = _assert_scalar_210 = None + select_211 = torch.ops.aten.select.int(device_put_26, 0, 3) + _local_scalar_dense_211 = torch.ops.aten._local_scalar_dense.default(select_211); select_211 = None + ge_263 = _local_scalar_dense_211 >= 0 + _assert_scalar_211 = torch.ops.aten._assert_scalar.default(ge_263, "Runtime assertion failed for expression u211 >= 0 on node 'ge_211'"); ge_263 = _assert_scalar_211 = None + select_212 = torch.ops.aten.select.int(device_put_26, 0, 4) + _local_scalar_dense_212 = torch.ops.aten._local_scalar_dense.default(select_212); select_212 = None + ge_264 = _local_scalar_dense_212 >= 0 + _assert_scalar_212 = torch.ops.aten._assert_scalar.default(ge_264, "Runtime assertion failed for expression u212 >= 0 on node 'ge_212'"); ge_264 = _assert_scalar_212 = None + select_213 = torch.ops.aten.select.int(device_put_26, 0, 5) + _local_scalar_dense_213 = torch.ops.aten._local_scalar_dense.default(select_213); select_213 = None + ge_265 = _local_scalar_dense_213 >= 0 + _assert_scalar_213 = torch.ops.aten._assert_scalar.default(ge_265, "Runtime assertion failed for expression u213 >= 0 on node 'ge_213'"); ge_265 = _assert_scalar_213 = None + select_214 = torch.ops.aten.select.int(device_put_26, 0, 6) + _local_scalar_dense_214 = torch.ops.aten._local_scalar_dense.default(select_214); select_214 = None + ge_266 = _local_scalar_dense_214 >= 0 + _assert_scalar_214 = torch.ops.aten._assert_scalar.default(ge_266, "Runtime assertion failed for expression u214 >= 0 on node 'ge_214'"); ge_266 = _assert_scalar_214 = None + select_215 = torch.ops.aten.select.int(device_put_26, 0, 7); device_put_26 = None + _local_scalar_dense_215 = torch.ops.aten._local_scalar_dense.default(select_215); select_215 = None + ge_267 = _local_scalar_dense_215 >= 0 + _assert_scalar_215 = torch.ops.aten._assert_scalar.default(ge_267, "Runtime assertion failed for expression u215 >= 0 on node 'ge_215'"); ge_267 = _assert_scalar_215 = None + select_216 = torch.ops.aten.select.int(device_put_27, 0, 0) + _local_scalar_dense_216 = torch.ops.aten._local_scalar_dense.default(select_216); select_216 = None + ge_268 = _local_scalar_dense_216 >= 0 + _assert_scalar_216 = torch.ops.aten._assert_scalar.default(ge_268, "Runtime assertion failed for expression u216 >= 0 on node 'ge_216'"); ge_268 = _assert_scalar_216 = None + select_217 = torch.ops.aten.select.int(device_put_27, 0, 1) + _local_scalar_dense_217 = torch.ops.aten._local_scalar_dense.default(select_217); select_217 = None + ge_269 = _local_scalar_dense_217 >= 0 + _assert_scalar_217 = torch.ops.aten._assert_scalar.default(ge_269, "Runtime assertion failed for expression u217 >= 0 on node 'ge_217'"); ge_269 = _assert_scalar_217 = None + select_218 = torch.ops.aten.select.int(device_put_27, 0, 2) + _local_scalar_dense_218 = torch.ops.aten._local_scalar_dense.default(select_218); select_218 = None + ge_270 = _local_scalar_dense_218 >= 0 + _assert_scalar_218 = torch.ops.aten._assert_scalar.default(ge_270, "Runtime assertion failed for expression u218 >= 0 on node 'ge_218'"); ge_270 = _assert_scalar_218 = None + select_219 = torch.ops.aten.select.int(device_put_27, 0, 3) + _local_scalar_dense_219 = torch.ops.aten._local_scalar_dense.default(select_219); select_219 = None + ge_271 = _local_scalar_dense_219 >= 0 + _assert_scalar_219 = torch.ops.aten._assert_scalar.default(ge_271, "Runtime assertion failed for expression u219 >= 0 on node 'ge_219'"); ge_271 = _assert_scalar_219 = None + select_220 = torch.ops.aten.select.int(device_put_27, 0, 4) + _local_scalar_dense_220 = torch.ops.aten._local_scalar_dense.default(select_220); select_220 = None + ge_272 = _local_scalar_dense_220 >= 0 + _assert_scalar_220 = torch.ops.aten._assert_scalar.default(ge_272, "Runtime assertion failed for expression u220 >= 0 on node 'ge_220'"); ge_272 = _assert_scalar_220 = None + select_221 = torch.ops.aten.select.int(device_put_27, 0, 5) + _local_scalar_dense_221 = torch.ops.aten._local_scalar_dense.default(select_221); select_221 = None + ge_273 = _local_scalar_dense_221 >= 0 + _assert_scalar_221 = torch.ops.aten._assert_scalar.default(ge_273, "Runtime assertion failed for expression u221 >= 0 on node 'ge_221'"); ge_273 = _assert_scalar_221 = None + select_222 = torch.ops.aten.select.int(device_put_27, 0, 6) + _local_scalar_dense_222 = torch.ops.aten._local_scalar_dense.default(select_222); select_222 = None + ge_274 = _local_scalar_dense_222 >= 0 + _assert_scalar_222 = torch.ops.aten._assert_scalar.default(ge_274, "Runtime assertion failed for expression u222 >= 0 on node 'ge_222'"); ge_274 = _assert_scalar_222 = None + select_223 = torch.ops.aten.select.int(device_put_27, 0, 7); device_put_27 = None + _local_scalar_dense_223 = torch.ops.aten._local_scalar_dense.default(select_223); select_223 = None + ge_275 = _local_scalar_dense_223 >= 0 + _assert_scalar_223 = torch.ops.aten._assert_scalar.default(ge_275, "Runtime assertion failed for expression u223 >= 0 on node 'ge_223'"); ge_275 = _assert_scalar_223 = None + all_to_all_single_40 = torch.ops._c10d_functional.all_to_all_single.default(index_26, [_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223], [_local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215], '1033'); index_26 = None + sym_size_int_52 = torch.ops.aten.sym_size.int(all_to_all_single_40, 0) + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_40); all_to_all_single_40 = None + sym_sum_26 = torch.sym_sum((_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223)) + add_902 = sym_sum_26 + 64; sym_sum_26 = None + add_903 = add_902 + 8; add_902 = None + sub_315 = add_903 - 1; add_903 = None + floordiv_13 = sub_315 // 8; sub_315 = None + mul_659 = floordiv_13 * 8; floordiv_13 = None + cumsum_39 = torch.ops.aten.cumsum.default(wait_tensor_293, 0) + sub_316 = torch.ops.aten.sub.Tensor(cumsum_39, wait_tensor_293); cumsum_39 = None + sum_56 = torch.ops.aten.sum.dim_IntList(view_936, [0]); view_936 = None + clamp_min_13 = torch.ops.aten.clamp_min.default(sum_56, 8); sum_56 = None + add_904 = torch.ops.aten.add.Tensor(clamp_min_13, 8); clamp_min_13 = None + sub_317 = torch.ops.aten.sub.Tensor(add_904, 1); add_904 = None + div_68 = torch.ops.aten.div.Tensor_mode(sub_317, 8, rounding_mode = 'floor'); sub_317 = None + mul_660 = torch.ops.aten.mul.Tensor(div_68, 8); div_68 = None + convert_element_type_770 = torch.ops.prims.convert_element_type.default(mul_660, torch.int32); mul_660 = None + cumsum_40 = torch.ops.aten.cumsum.default(convert_element_type_770, 0) + sub_318 = torch.ops.aten.sub.Tensor(cumsum_40, convert_element_type_770); cumsum_40 = None + full_189 = torch.ops.aten.full.default([mul_659], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_659 = None + triton_kernel_wrapper_functional_proxy_13 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 13, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_293, 'start_index_values_ptr': sub_316, 'write_offsets_ptr': sub_318, 'output_ptr': full_189}, tensors_to_clone = ['output_ptr']); wait_tensor_293 = sub_316 = sub_318 = full_189 = None + getitem_1452 = triton_kernel_wrapper_functional_proxy_13['output_ptr']; triton_kernel_wrapper_functional_proxy_13 = None + cat_121 = torch.ops.aten.cat.default([wait_tensor_294, full_default]); wait_tensor_294 = None + sym_size_int_53 = torch.ops.aten.sym_size.int(cat_121, 0) + sym_sum_27 = torch.sym_sum((1, _local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223)) + index_27 = torch.ops.aten.index.Tensor(cat_121, [getitem_1452]); cat_121 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_772, 16, '1025'); convert_element_type_772 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + split_79 = torch.ops.aten.split.Tensor(wait_tensor_295, 8); wait_tensor_295 = None + getitem_1469 = split_79[0] + getitem_1470 = split_79[1] + getitem_1471 = split_79[2] + getitem_1472 = split_79[3] + getitem_1473 = split_79[4] + getitem_1474 = split_79[5] + getitem_1475 = split_79[6] + getitem_1476 = split_79[7] + getitem_1477 = split_79[8] + getitem_1478 = split_79[9] + getitem_1479 = split_79[10] + getitem_1480 = split_79[11] + getitem_1481 = split_79[12] + getitem_1482 = split_79[13] + getitem_1483 = split_79[14] + getitem_1484 = split_79[15]; split_79 = None + cat_123 = torch.ops.aten.cat.default([getitem_1469, getitem_1470, getitem_1471, getitem_1472, getitem_1473, getitem_1474, getitem_1475, getitem_1476, getitem_1477, getitem_1478, getitem_1479, getitem_1480, getitem_1481, getitem_1482, getitem_1483, getitem_1484], 1); getitem_1469 = getitem_1470 = getitem_1471 = getitem_1472 = getitem_1473 = getitem_1474 = getitem_1475 = getitem_1476 = getitem_1477 = getitem_1478 = getitem_1479 = getitem_1480 = getitem_1481 = getitem_1482 = getitem_1483 = getitem_1484 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_774, 16, '1025'); convert_element_type_774 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + split_80 = torch.ops.aten.split.Tensor(wait_tensor_297, 8); wait_tensor_297 = None + getitem_1485 = split_80[0] + getitem_1486 = split_80[1] + getitem_1487 = split_80[2] + getitem_1488 = split_80[3] + getitem_1489 = split_80[4] + getitem_1490 = split_80[5] + getitem_1491 = split_80[6] + getitem_1492 = split_80[7] + getitem_1493 = split_80[8] + getitem_1494 = split_80[9] + getitem_1495 = split_80[10] + getitem_1496 = split_80[11] + getitem_1497 = split_80[12] + getitem_1498 = split_80[13] + getitem_1499 = split_80[14] + getitem_1500 = split_80[15]; split_80 = None + cat_124 = torch.ops.aten.cat.default([getitem_1485, getitem_1486, getitem_1487, getitem_1488, getitem_1489, getitem_1490, getitem_1491, getitem_1492, getitem_1493, getitem_1494, getitem_1495, getitem_1496, getitem_1497, getitem_1498, getitem_1499, getitem_1500], 1); getitem_1485 = getitem_1486 = getitem_1487 = getitem_1488 = getitem_1489 = getitem_1490 = getitem_1491 = getitem_1492 = getitem_1493 = getitem_1494 = getitem_1495 = getitem_1496 = getitem_1497 = getitem_1498 = getitem_1499 = getitem_1500 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_775, 16, '1025'); convert_element_type_775 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + split_81 = torch.ops.aten.split.Tensor(wait_tensor_298, 8); wait_tensor_298 = None + getitem_1501 = split_81[0] + getitem_1502 = split_81[1] + getitem_1503 = split_81[2] + getitem_1504 = split_81[3] + getitem_1505 = split_81[4] + getitem_1506 = split_81[5] + getitem_1507 = split_81[6] + getitem_1508 = split_81[7] + getitem_1509 = split_81[8] + getitem_1510 = split_81[9] + getitem_1511 = split_81[10] + getitem_1512 = split_81[11] + getitem_1513 = split_81[12] + getitem_1514 = split_81[13] + getitem_1515 = split_81[14] + getitem_1516 = split_81[15]; split_81 = None + cat_125 = torch.ops.aten.cat.default([getitem_1501, getitem_1502, getitem_1503, getitem_1504, getitem_1505, getitem_1506, getitem_1507, getitem_1508, getitem_1509, getitem_1510, getitem_1511, getitem_1512, getitem_1513, getitem_1514, getitem_1515, getitem_1516], 1); getitem_1501 = getitem_1502 = getitem_1503 = getitem_1504 = getitem_1505 = getitem_1506 = getitem_1507 = getitem_1508 = getitem_1509 = getitem_1510 = getitem_1511 = getitem_1512 = getitem_1513 = getitem_1514 = getitem_1515 = getitem_1516 = None + cumsum_41 = torch.ops.aten.cumsum.default(convert_element_type_770, 0, dtype = torch.int32); convert_element_type_770 = None + permute_215 = torch.ops.aten.permute.default(cat_123, [0, 2, 1]); cat_123 = None + _grouped_mm_39 = torch.ops.aten._grouped_mm.default(index_27, permute_215, cumsum_41) + convert_element_type_778 = torch.ops.prims.convert_element_type.default(_grouped_mm_39, torch.float32) + neg_27 = torch.ops.aten.neg.default(convert_element_type_778) + exp_41 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_916 = torch.ops.aten.add.Tensor(exp_41, 1); exp_41 = None + div_69 = torch.ops.aten.div.Tensor(convert_element_type_778, add_916); convert_element_type_778 = add_916 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(div_69, torch.bfloat16); div_69 = None + permute_216 = torch.ops.aten.permute.default(cat_125, [0, 2, 1]); cat_125 = None + _grouped_mm_40 = torch.ops.aten._grouped_mm.default(index_27, permute_216, cumsum_41) + mul_672 = torch.ops.aten.mul.Tensor(convert_element_type_779, _grouped_mm_40); convert_element_type_779 = None + permute_217 = torch.ops.aten.permute.default(cat_124, [0, 2, 1]); cat_124 = None + _grouped_mm_41 = torch.ops.aten._grouped_mm.default(mul_672, permute_217, cumsum_41) + empty_13 = torch.ops.aten.empty.memory_format([sym_size_int_53, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_26 = torch.ops.aten.index_put.default(empty_13, [getitem_1452], _grouped_mm_41); empty_13 = _grouped_mm_41 = None + slice_89 = torch.ops.aten.slice.Tensor(index_put_26, 0, 0, -1); index_put_26 = None + all_to_all_single_41 = torch.ops._c10d_functional.all_to_all_single.default(slice_89, [_local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215], [_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223], '1033'); slice_89 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_41); all_to_all_single_41 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_780, 128, '0'); convert_element_type_780 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_302, [1, 0]); wait_tensor_302 = None + mm_116 = torch.ops.aten.mm.default(view_929, permute_218); permute_218 = None + convert_element_type_783 = torch.ops.prims.convert_element_type.default(mm_116, torch.float32) + neg_28 = torch.ops.aten.neg.default(convert_element_type_783) + exp_42 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_952 = torch.ops.aten.add.Tensor(exp_42, 1); exp_42 = None + div_70 = torch.ops.aten.div.Tensor(convert_element_type_783, add_952); convert_element_type_783 = add_952 = None + convert_element_type_784 = torch.ops.prims.convert_element_type.default(div_70, torch.bfloat16); div_70 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_785, 128, '0'); convert_element_type_785 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + mm_117 = torch.ops.aten.mm.default(view_929, permute_219); permute_219 = None + mul_692 = torch.ops.aten.mul.Tensor(convert_element_type_784, mm_117); convert_element_type_784 = None + convert_element_type_788 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_788, 128, '0'); convert_element_type_788 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); wait_tensor_304 = None + mm_118 = torch.ops.aten.mm.default(mul_692, permute_220); permute_220 = None + index_put_27 = torch.ops.aten.index_put.default(full_default_1, [getitem_1451], wait_tensor_301); wait_tensor_301 = None + view_969 = torch.ops.aten.view.default(mul_654, [-1, 1, 6]); mul_654 = None + view_970 = torch.ops.aten.view.default(index_put_27, [-1, 6, 2048]); index_put_27 = None + convert_element_type_791 = torch.ops.prims.convert_element_type.default(view_970, torch.float32); view_970 = None + bmm_13 = torch.ops.aten.bmm.default(view_969, convert_element_type_791) + convert_element_type_792 = torch.ops.prims.convert_element_type.default(bmm_13, torch.bfloat16); bmm_13 = None + squeeze_13 = torch.ops.aten.squeeze.dim(convert_element_type_792, 1); convert_element_type_792 = None + add_956 = torch.ops.aten.add.Tensor(mm_118, squeeze_13); mm_118 = squeeze_13 = None + view_971 = torch.ops.aten.view.default(add_956, [2, 4096, 2048]); add_956 = None + add_957 = torch.ops.aten.add.Tensor(add_892, view_971); view_971 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 128, '0'); convert_element_type_793 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_957, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_958 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_958); add_958 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_45); convert_element_type_794 = None + mul_696 = torch.ops.aten.mul.Tensor(mul_695, wait_tensor_305); mul_695 = wait_tensor_305 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_696, torch.bfloat16); mul_696 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 128, '0'); convert_element_type_796 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + view_974 = torch.ops.aten.view.default(convert_element_type_795, [8192, 2048]); convert_element_type_795 = None + mm_119 = torch.ops.aten.mm.default(view_974, permute_221); permute_221 = None + view_975 = torch.ops.aten.view.default(mm_119, [2, 4096, 3072]); mm_119 = None + view_976 = torch.ops.aten.view.default(view_975, [2, 4096, -1, 192]); view_975 = None + split_with_sizes_45 = torch.ops.aten.split_with_sizes.default(view_976, [128, 64], -1); view_976 = None + getitem_1549 = split_with_sizes_45[0] + getitem_1550 = split_with_sizes_45[1]; split_with_sizes_45 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(getitem_1550, torch.float32); getitem_1550 = None + view_977 = torch.ops.aten.view.default(convert_element_type_799, [2, 4096, 16, -1, 2]); convert_element_type_799 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_977); view_977 = None + mul_697 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_7); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_697); mul_697 = None + view_979 = torch.ops.aten.view.default(view_as_real_30, [2, 4096, 16, 64]); view_as_real_30 = None + convert_element_type_800 = torch.ops.prims.convert_element_type.default(view_979, torch.bfloat16); view_979 = None + cat_128 = torch.ops.aten.cat.default([getitem_1549, convert_element_type_800], -1); getitem_1549 = convert_element_type_800 = None + convert_element_type_801 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_801, 128, '0'); convert_element_type_801 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + slice_91 = torch.ops.aten.slice.Tensor(wait_tensor_307, 0, 0, 576); wait_tensor_307 = None + permute_222 = torch.ops.aten.permute.default(slice_91, [1, 0]); slice_91 = None + mm_120 = torch.ops.aten.mm.default(view_974, permute_222); permute_222 = None + view_982 = torch.ops.aten.view.default(mm_120, [2, 4096, 576]); mm_120 = None + split_with_sizes_46 = torch.ops.aten.split_with_sizes.default(view_982, [512, 64], -1); view_982 = None + getitem_1551 = split_with_sizes_46[0] + getitem_1552 = split_with_sizes_46[1]; split_with_sizes_46 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(getitem_1552, 2); getitem_1552 = None + convert_element_type_804 = torch.ops.prims.convert_element_type.default(unsqueeze_29, torch.float32); unsqueeze_29 = None + view_983 = torch.ops.aten.view.default(convert_element_type_804, [2, 4096, 1, -1, 2]); convert_element_type_804 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_983); view_983 = None + mul_698 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_7); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_698); mul_698 = None + view_985 = torch.ops.aten.view.default(view_as_real_31, [2, 4096, 1, 64]); view_as_real_31 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_985, torch.bfloat16); view_985 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_806, 128, '0'); convert_element_type_806 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(getitem_1551, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_807, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_959 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_959); add_959 = None + mul_699 = torch.ops.aten.mul.Tensor(convert_element_type_807, rsqrt_46); convert_element_type_807 = None + mul_700 = torch.ops.aten.mul.Tensor(mul_699, wait_tensor_308); mul_699 = wait_tensor_308 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(mul_700, torch.bfloat16); mul_700 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 128, '0'); convert_element_type_809 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_223 = torch.ops.aten.permute.default(wait_tensor_309, [1, 0]); wait_tensor_309 = None + view_988 = torch.ops.aten.view.default(convert_element_type_808, [8192, 512]); convert_element_type_808 = None + mm_121 = torch.ops.aten.mm.default(view_988, permute_223); permute_223 = None + view_989 = torch.ops.aten.view.default(mm_121, [2, 4096, 4096]); mm_121 = None + view_990 = torch.ops.aten.view.default(view_989, [2, 4096, -1, 256]); view_989 = None + split_with_sizes_47 = torch.ops.aten.split_with_sizes.default(view_990, [128, 128], -1); view_990 = None + getitem_1553 = split_with_sizes_47[0] + getitem_1554 = split_with_sizes_47[1]; split_with_sizes_47 = None + expand_15 = torch.ops.aten.expand.default(convert_element_type_805, [-1, -1, 16, -1]); convert_element_type_805 = None + cat_129 = torch.ops.aten.cat.default([getitem_1553, expand_15], -1); getitem_1553 = expand_15 = None + permute_224 = torch.ops.aten.permute.default(cat_128, [0, 2, 1, 3]); cat_128 = None + permute_225 = torch.ops.aten.permute.default(cat_129, [0, 2, 1, 3]); cat_129 = None + permute_226 = torch.ops.aten.permute.default(getitem_1554, [0, 2, 1, 3]); getitem_1554 = None + sdpa_score15 = self.sdpa_score15 + sdpa_mask15 = self.sdpa_mask15 + flex_attention_15 = torch.ops.higher_order.flex_attention(permute_224, permute_225, permute_226, sdpa_score15, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask15), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score15 = sdpa_mask15 = None + getitem_1555 = flex_attention_15[0] + getitem_1556 = flex_attention_15[1]; flex_attention_15 = None + permute_227 = torch.ops.aten.permute.default(getitem_1555, [0, 2, 1, 3]) + view_991 = torch.ops.aten.view.default(permute_227, [2, 4096, -1]); permute_227 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 128, '0'); convert_element_type_812 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + view_993 = torch.ops.aten.view.default(view_991, [8192, 2048]); view_991 = None + mm_122 = torch.ops.aten.mm.default(view_993, permute_228); view_993 = permute_228 = None + view_994 = torch.ops.aten.view.default(mm_122, [2, 4096, 2048]); mm_122 = None + add_960 = torch.ops.aten.add.Tensor(add_957, view_994); view_994 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 128, '0'); convert_element_type_815 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + convert_element_type_816 = torch.ops.prims.convert_element_type.default(add_960, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_816, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_961 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_961); add_961 = None + mul_701 = torch.ops.aten.mul.Tensor(convert_element_type_816, rsqrt_47); convert_element_type_816 = None + mul_702 = torch.ops.aten.mul.Tensor(mul_701, wait_tensor_311); mul_701 = wait_tensor_311 = None + convert_element_type_817 = torch.ops.prims.convert_element_type.default(mul_702, torch.bfloat16); mul_702 = None + view_996 = torch.ops.aten.view.default(convert_element_type_817, [-1, 2048]); convert_element_type_817 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_818, 128, '0'); convert_element_type_818 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + slice_93 = torch.ops.aten.slice.Tensor(wait_tensor_312, 0, 0, 64); wait_tensor_312 = None + permute_229 = torch.ops.aten.permute.default(slice_93, [1, 0]); slice_93 = None + mm_123 = torch.ops.aten.mm.default(view_996, permute_229); permute_229 = None + convert_element_type_821 = torch.ops.prims.convert_element_type.default(mm_123, torch.float32) + amax_14 = torch.ops.aten.amax.default(convert_element_type_821, [1], True) + sub_336 = torch.ops.aten.sub.Tensor(convert_element_type_821, amax_14); convert_element_type_821 = None + exp_43 = torch.ops.aten.exp.default(sub_336); sub_336 = None + sum_57 = torch.ops.aten.sum.dim_IntList(exp_43, [1], True) + div_71 = torch.ops.aten.div.Tensor(exp_43, sum_57); exp_43 = None + add_962 = torch.ops.aten.add.Tensor(div_71, primals_254); primals_254 = None + topk_14 = torch.ops.aten.topk.default(add_962, 6, -1, True, False); add_962 = None + getitem_1559 = topk_14[1]; topk_14 = None + gather_14 = torch.ops.aten.gather.default(div_71, 1, getitem_1559); div_71 = None + mul_703 = torch.ops.aten.mul.Tensor(gather_14, 1.0); gather_14 = None + view_998 = torch.ops.aten.view.default(getitem_1559, [-1]) + histc_28 = torch.ops.aten.histc.default(view_998, 64, 0, 64) + add_963 = torch.ops.aten.add.Tensor(primals_256, histc_28) + sort_14 = torch.ops.aten.sort.stable(view_998, stable = True); view_998 = None + getitem_1561 = sort_14[1]; sort_14 = None + div_72 = torch.ops.aten.div.Tensor_mode(getitem_1561, 6, rounding_mode = 'floor') + index_28 = torch.ops.aten.index.Tensor(view_996, [div_72]) + all_to_all_single_42 = torch.ops._c10d_functional.all_to_all_single.default(histc_28, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_42); all_to_all_single_42 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_313); wait_tensor_313 = None + view_1002 = torch.ops.aten.view.default(histc_28, [8, -1]); histc_28 = None + sum_58 = torch.ops.aten.sum.dim_IntList(view_1002, [1]); view_1002 = None + device_put_28 = torch.ops.prims.device_put.default(sum_58, device(type='cpu'), True); sum_58 = None + view_1003 = torch.ops.aten.view.default(wait_tensor_314, [8, -1]) + sum_59 = torch.ops.aten.sum.dim_IntList(view_1003, [1]) + device_put_29 = torch.ops.prims.device_put.default(sum_59, device(type='cpu')); sum_59 = None + select_224 = torch.ops.aten.select.int(device_put_28, 0, 0) + _local_scalar_dense_224 = torch.ops.aten._local_scalar_dense.default(select_224); select_224 = None + ge_280 = _local_scalar_dense_224 >= 0 + _assert_scalar_224 = torch.ops.aten._assert_scalar.default(ge_280, "Runtime assertion failed for expression u224 >= 0 on node 'ge_224'"); ge_280 = _assert_scalar_224 = None + select_225 = torch.ops.aten.select.int(device_put_28, 0, 1) + _local_scalar_dense_225 = torch.ops.aten._local_scalar_dense.default(select_225); select_225 = None + ge_281 = _local_scalar_dense_225 >= 0 + _assert_scalar_225 = torch.ops.aten._assert_scalar.default(ge_281, "Runtime assertion failed for expression u225 >= 0 on node 'ge_225'"); ge_281 = _assert_scalar_225 = None + select_226 = torch.ops.aten.select.int(device_put_28, 0, 2) + _local_scalar_dense_226 = torch.ops.aten._local_scalar_dense.default(select_226); select_226 = None + ge_282 = _local_scalar_dense_226 >= 0 + _assert_scalar_226 = torch.ops.aten._assert_scalar.default(ge_282, "Runtime assertion failed for expression u226 >= 0 on node 'ge_226'"); ge_282 = _assert_scalar_226 = None + select_227 = torch.ops.aten.select.int(device_put_28, 0, 3) + _local_scalar_dense_227 = torch.ops.aten._local_scalar_dense.default(select_227); select_227 = None + ge_283 = _local_scalar_dense_227 >= 0 + _assert_scalar_227 = torch.ops.aten._assert_scalar.default(ge_283, "Runtime assertion failed for expression u227 >= 0 on node 'ge_227'"); ge_283 = _assert_scalar_227 = None + select_228 = torch.ops.aten.select.int(device_put_28, 0, 4) + _local_scalar_dense_228 = torch.ops.aten._local_scalar_dense.default(select_228); select_228 = None + ge_284 = _local_scalar_dense_228 >= 0 + _assert_scalar_228 = torch.ops.aten._assert_scalar.default(ge_284, "Runtime assertion failed for expression u228 >= 0 on node 'ge_228'"); ge_284 = _assert_scalar_228 = None + select_229 = torch.ops.aten.select.int(device_put_28, 0, 5) + _local_scalar_dense_229 = torch.ops.aten._local_scalar_dense.default(select_229); select_229 = None + ge_285 = _local_scalar_dense_229 >= 0 + _assert_scalar_229 = torch.ops.aten._assert_scalar.default(ge_285, "Runtime assertion failed for expression u229 >= 0 on node 'ge_229'"); ge_285 = _assert_scalar_229 = None + select_230 = torch.ops.aten.select.int(device_put_28, 0, 6) + _local_scalar_dense_230 = torch.ops.aten._local_scalar_dense.default(select_230); select_230 = None + ge_286 = _local_scalar_dense_230 >= 0 + _assert_scalar_230 = torch.ops.aten._assert_scalar.default(ge_286, "Runtime assertion failed for expression u230 >= 0 on node 'ge_230'"); ge_286 = _assert_scalar_230 = None + select_231 = torch.ops.aten.select.int(device_put_28, 0, 7); device_put_28 = None + _local_scalar_dense_231 = torch.ops.aten._local_scalar_dense.default(select_231); select_231 = None + ge_287 = _local_scalar_dense_231 >= 0 + _assert_scalar_231 = torch.ops.aten._assert_scalar.default(ge_287, "Runtime assertion failed for expression u231 >= 0 on node 'ge_231'"); ge_287 = _assert_scalar_231 = None + select_232 = torch.ops.aten.select.int(device_put_29, 0, 0) + _local_scalar_dense_232 = torch.ops.aten._local_scalar_dense.default(select_232); select_232 = None + ge_288 = _local_scalar_dense_232 >= 0 + _assert_scalar_232 = torch.ops.aten._assert_scalar.default(ge_288, "Runtime assertion failed for expression u232 >= 0 on node 'ge_232'"); ge_288 = _assert_scalar_232 = None + select_233 = torch.ops.aten.select.int(device_put_29, 0, 1) + _local_scalar_dense_233 = torch.ops.aten._local_scalar_dense.default(select_233); select_233 = None + ge_289 = _local_scalar_dense_233 >= 0 + _assert_scalar_233 = torch.ops.aten._assert_scalar.default(ge_289, "Runtime assertion failed for expression u233 >= 0 on node 'ge_233'"); ge_289 = _assert_scalar_233 = None + select_234 = torch.ops.aten.select.int(device_put_29, 0, 2) + _local_scalar_dense_234 = torch.ops.aten._local_scalar_dense.default(select_234); select_234 = None + ge_290 = _local_scalar_dense_234 >= 0 + _assert_scalar_234 = torch.ops.aten._assert_scalar.default(ge_290, "Runtime assertion failed for expression u234 >= 0 on node 'ge_234'"); ge_290 = _assert_scalar_234 = None + select_235 = torch.ops.aten.select.int(device_put_29, 0, 3) + _local_scalar_dense_235 = torch.ops.aten._local_scalar_dense.default(select_235); select_235 = None + ge_291 = _local_scalar_dense_235 >= 0 + _assert_scalar_235 = torch.ops.aten._assert_scalar.default(ge_291, "Runtime assertion failed for expression u235 >= 0 on node 'ge_235'"); ge_291 = _assert_scalar_235 = None + select_236 = torch.ops.aten.select.int(device_put_29, 0, 4) + _local_scalar_dense_236 = torch.ops.aten._local_scalar_dense.default(select_236); select_236 = None + ge_292 = _local_scalar_dense_236 >= 0 + _assert_scalar_236 = torch.ops.aten._assert_scalar.default(ge_292, "Runtime assertion failed for expression u236 >= 0 on node 'ge_236'"); ge_292 = _assert_scalar_236 = None + select_237 = torch.ops.aten.select.int(device_put_29, 0, 5) + _local_scalar_dense_237 = torch.ops.aten._local_scalar_dense.default(select_237); select_237 = None + ge_293 = _local_scalar_dense_237 >= 0 + _assert_scalar_237 = torch.ops.aten._assert_scalar.default(ge_293, "Runtime assertion failed for expression u237 >= 0 on node 'ge_237'"); ge_293 = _assert_scalar_237 = None + select_238 = torch.ops.aten.select.int(device_put_29, 0, 6) + _local_scalar_dense_238 = torch.ops.aten._local_scalar_dense.default(select_238); select_238 = None + ge_294 = _local_scalar_dense_238 >= 0 + _assert_scalar_238 = torch.ops.aten._assert_scalar.default(ge_294, "Runtime assertion failed for expression u238 >= 0 on node 'ge_238'"); ge_294 = _assert_scalar_238 = None + select_239 = torch.ops.aten.select.int(device_put_29, 0, 7); device_put_29 = None + _local_scalar_dense_239 = torch.ops.aten._local_scalar_dense.default(select_239); select_239 = None + ge_295 = _local_scalar_dense_239 >= 0 + _assert_scalar_239 = torch.ops.aten._assert_scalar.default(ge_295, "Runtime assertion failed for expression u239 >= 0 on node 'ge_239'"); ge_295 = _assert_scalar_239 = None + all_to_all_single_43 = torch.ops._c10d_functional.all_to_all_single.default(index_28, [_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239], [_local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231], '1033'); index_28 = None + sym_size_int_56 = torch.ops.aten.sym_size.int(all_to_all_single_43, 0) + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_43); all_to_all_single_43 = None + sym_sum_28 = torch.sym_sum((_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239)) + add_970 = sym_sum_28 + 64; sym_sum_28 = None + add_971 = add_970 + 8; add_970 = None + sub_339 = add_971 - 1; add_971 = None + floordiv_14 = sub_339 // 8; sub_339 = None + mul_708 = floordiv_14 * 8; floordiv_14 = None + cumsum_42 = torch.ops.aten.cumsum.default(wait_tensor_314, 0) + sub_340 = torch.ops.aten.sub.Tensor(cumsum_42, wait_tensor_314); cumsum_42 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_1003, [0]); view_1003 = None + clamp_min_14 = torch.ops.aten.clamp_min.default(sum_60, 8); sum_60 = None + add_972 = torch.ops.aten.add.Tensor(clamp_min_14, 8); clamp_min_14 = None + sub_341 = torch.ops.aten.sub.Tensor(add_972, 1); add_972 = None + div_73 = torch.ops.aten.div.Tensor_mode(sub_341, 8, rounding_mode = 'floor'); sub_341 = None + mul_709 = torch.ops.aten.mul.Tensor(div_73, 8); div_73 = None + convert_element_type_824 = torch.ops.prims.convert_element_type.default(mul_709, torch.int32); mul_709 = None + cumsum_43 = torch.ops.aten.cumsum.default(convert_element_type_824, 0) + sub_342 = torch.ops.aten.sub.Tensor(cumsum_43, convert_element_type_824); cumsum_43 = None + full_202 = torch.ops.aten.full.default([mul_708], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_708 = None + triton_kernel_wrapper_functional_proxy_14 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 14, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_314, 'start_index_values_ptr': sub_340, 'write_offsets_ptr': sub_342, 'output_ptr': full_202}, tensors_to_clone = ['output_ptr']); wait_tensor_314 = sub_340 = sub_342 = full_202 = None + getitem_1562 = triton_kernel_wrapper_functional_proxy_14['output_ptr']; triton_kernel_wrapper_functional_proxy_14 = None + cat_130 = torch.ops.aten.cat.default([wait_tensor_315, full_default]); wait_tensor_315 = None + sym_size_int_57 = torch.ops.aten.sym_size.int(cat_130, 0) + sym_sum_29 = torch.sym_sum((1, _local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239)) + index_29 = torch.ops.aten.index.Tensor(cat_130, [getitem_1562]); cat_130 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 16, '1025'); convert_element_type_826 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + split_85 = torch.ops.aten.split.Tensor(wait_tensor_316, 8); wait_tensor_316 = None + getitem_1579 = split_85[0] + getitem_1580 = split_85[1] + getitem_1581 = split_85[2] + getitem_1582 = split_85[3] + getitem_1583 = split_85[4] + getitem_1584 = split_85[5] + getitem_1585 = split_85[6] + getitem_1586 = split_85[7] + getitem_1587 = split_85[8] + getitem_1588 = split_85[9] + getitem_1589 = split_85[10] + getitem_1590 = split_85[11] + getitem_1591 = split_85[12] + getitem_1592 = split_85[13] + getitem_1593 = split_85[14] + getitem_1594 = split_85[15]; split_85 = None + cat_132 = torch.ops.aten.cat.default([getitem_1579, getitem_1580, getitem_1581, getitem_1582, getitem_1583, getitem_1584, getitem_1585, getitem_1586, getitem_1587, getitem_1588, getitem_1589, getitem_1590, getitem_1591, getitem_1592, getitem_1593, getitem_1594], 1); getitem_1579 = getitem_1580 = getitem_1581 = getitem_1582 = getitem_1583 = getitem_1584 = getitem_1585 = getitem_1586 = getitem_1587 = getitem_1588 = getitem_1589 = getitem_1590 = getitem_1591 = getitem_1592 = getitem_1593 = getitem_1594 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_828, 16, '1025'); convert_element_type_828 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + split_86 = torch.ops.aten.split.Tensor(wait_tensor_318, 8); wait_tensor_318 = None + getitem_1595 = split_86[0] + getitem_1596 = split_86[1] + getitem_1597 = split_86[2] + getitem_1598 = split_86[3] + getitem_1599 = split_86[4] + getitem_1600 = split_86[5] + getitem_1601 = split_86[6] + getitem_1602 = split_86[7] + getitem_1603 = split_86[8] + getitem_1604 = split_86[9] + getitem_1605 = split_86[10] + getitem_1606 = split_86[11] + getitem_1607 = split_86[12] + getitem_1608 = split_86[13] + getitem_1609 = split_86[14] + getitem_1610 = split_86[15]; split_86 = None + cat_133 = torch.ops.aten.cat.default([getitem_1595, getitem_1596, getitem_1597, getitem_1598, getitem_1599, getitem_1600, getitem_1601, getitem_1602, getitem_1603, getitem_1604, getitem_1605, getitem_1606, getitem_1607, getitem_1608, getitem_1609, getitem_1610], 1); getitem_1595 = getitem_1596 = getitem_1597 = getitem_1598 = getitem_1599 = getitem_1600 = getitem_1601 = getitem_1602 = getitem_1603 = getitem_1604 = getitem_1605 = getitem_1606 = getitem_1607 = getitem_1608 = getitem_1609 = getitem_1610 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 16, '1025'); convert_element_type_829 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + split_87 = torch.ops.aten.split.Tensor(wait_tensor_319, 8); wait_tensor_319 = None + getitem_1611 = split_87[0] + getitem_1612 = split_87[1] + getitem_1613 = split_87[2] + getitem_1614 = split_87[3] + getitem_1615 = split_87[4] + getitem_1616 = split_87[5] + getitem_1617 = split_87[6] + getitem_1618 = split_87[7] + getitem_1619 = split_87[8] + getitem_1620 = split_87[9] + getitem_1621 = split_87[10] + getitem_1622 = split_87[11] + getitem_1623 = split_87[12] + getitem_1624 = split_87[13] + getitem_1625 = split_87[14] + getitem_1626 = split_87[15]; split_87 = None + cat_134 = torch.ops.aten.cat.default([getitem_1611, getitem_1612, getitem_1613, getitem_1614, getitem_1615, getitem_1616, getitem_1617, getitem_1618, getitem_1619, getitem_1620, getitem_1621, getitem_1622, getitem_1623, getitem_1624, getitem_1625, getitem_1626], 1); getitem_1611 = getitem_1612 = getitem_1613 = getitem_1614 = getitem_1615 = getitem_1616 = getitem_1617 = getitem_1618 = getitem_1619 = getitem_1620 = getitem_1621 = getitem_1622 = getitem_1623 = getitem_1624 = getitem_1625 = getitem_1626 = None + cumsum_44 = torch.ops.aten.cumsum.default(convert_element_type_824, 0, dtype = torch.int32); convert_element_type_824 = None + permute_230 = torch.ops.aten.permute.default(cat_132, [0, 2, 1]); cat_132 = None + _grouped_mm_42 = torch.ops.aten._grouped_mm.default(index_29, permute_230, cumsum_44) + convert_element_type_832 = torch.ops.prims.convert_element_type.default(_grouped_mm_42, torch.float32) + neg_29 = torch.ops.aten.neg.default(convert_element_type_832) + exp_44 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_984 = torch.ops.aten.add.Tensor(exp_44, 1); exp_44 = None + div_74 = torch.ops.aten.div.Tensor(convert_element_type_832, add_984); convert_element_type_832 = add_984 = None + convert_element_type_833 = torch.ops.prims.convert_element_type.default(div_74, torch.bfloat16); div_74 = None + permute_231 = torch.ops.aten.permute.default(cat_134, [0, 2, 1]); cat_134 = None + _grouped_mm_43 = torch.ops.aten._grouped_mm.default(index_29, permute_231, cumsum_44) + mul_721 = torch.ops.aten.mul.Tensor(convert_element_type_833, _grouped_mm_43); convert_element_type_833 = None + permute_232 = torch.ops.aten.permute.default(cat_133, [0, 2, 1]); cat_133 = None + _grouped_mm_44 = torch.ops.aten._grouped_mm.default(mul_721, permute_232, cumsum_44) + empty_14 = torch.ops.aten.empty.memory_format([sym_size_int_57, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_28 = torch.ops.aten.index_put.default(empty_14, [getitem_1562], _grouped_mm_44); empty_14 = _grouped_mm_44 = None + slice_95 = torch.ops.aten.slice.Tensor(index_put_28, 0, 0, -1); index_put_28 = None + all_to_all_single_44 = torch.ops._c10d_functional.all_to_all_single.default(slice_95, [_local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231], [_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239], '1033'); slice_95 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_44); all_to_all_single_44 = None + convert_element_type_834 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_834, 128, '0'); convert_element_type_834 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + mm_124 = torch.ops.aten.mm.default(view_996, permute_233); permute_233 = None + convert_element_type_837 = torch.ops.prims.convert_element_type.default(mm_124, torch.float32) + neg_30 = torch.ops.aten.neg.default(convert_element_type_837) + exp_45 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_1020 = torch.ops.aten.add.Tensor(exp_45, 1); exp_45 = None + div_75 = torch.ops.aten.div.Tensor(convert_element_type_837, add_1020); convert_element_type_837 = add_1020 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(div_75, torch.bfloat16); div_75 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_839, 128, '0'); convert_element_type_839 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_234 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + mm_125 = torch.ops.aten.mm.default(view_996, permute_234); permute_234 = None + mul_741 = torch.ops.aten.mul.Tensor(convert_element_type_838, mm_125); convert_element_type_838 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 128, '0'); convert_element_type_842 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_235 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + mm_126 = torch.ops.aten.mm.default(mul_741, permute_235); permute_235 = None + index_put_29 = torch.ops.aten.index_put.default(full_default_1, [getitem_1561], wait_tensor_322); wait_tensor_322 = None + view_1036 = torch.ops.aten.view.default(mul_703, [-1, 1, 6]); mul_703 = None + view_1037 = torch.ops.aten.view.default(index_put_29, [-1, 6, 2048]); index_put_29 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(view_1037, torch.float32); view_1037 = None + bmm_14 = torch.ops.aten.bmm.default(view_1036, convert_element_type_845) + convert_element_type_846 = torch.ops.prims.convert_element_type.default(bmm_14, torch.bfloat16); bmm_14 = None + squeeze_14 = torch.ops.aten.squeeze.dim(convert_element_type_846, 1); convert_element_type_846 = None + add_1024 = torch.ops.aten.add.Tensor(mm_126, squeeze_14); mm_126 = squeeze_14 = None + view_1038 = torch.ops.aten.view.default(add_1024, [2, 4096, 2048]); add_1024 = None + add_1025 = torch.ops.aten.add.Tensor(add_960, view_1038); view_1038 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 128, '0'); convert_element_type_847 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(add_1025, torch.float32) + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_848, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_1026 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_1026); add_1026 = None + mul_744 = torch.ops.aten.mul.Tensor(convert_element_type_848, rsqrt_48); convert_element_type_848 = None + mul_745 = torch.ops.aten.mul.Tensor(mul_744, wait_tensor_326); mul_744 = wait_tensor_326 = None + convert_element_type_849 = torch.ops.prims.convert_element_type.default(mul_745, torch.bfloat16); mul_745 = None + convert_element_type_850 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_850, 128, '0'); convert_element_type_850 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_236 = torch.ops.aten.permute.default(wait_tensor_327, [1, 0]); wait_tensor_327 = None + view_1041 = torch.ops.aten.view.default(convert_element_type_849, [8192, 2048]); convert_element_type_849 = None + mm_127 = torch.ops.aten.mm.default(view_1041, permute_236); permute_236 = None + view_1042 = torch.ops.aten.view.default(mm_127, [2, 4096, 3072]); mm_127 = None + view_1043 = torch.ops.aten.view.default(view_1042, [2, 4096, -1, 192]); view_1042 = None + split_with_sizes_48 = torch.ops.aten.split_with_sizes.default(view_1043, [128, 64], -1); view_1043 = None + getitem_1659 = split_with_sizes_48[0] + getitem_1660 = split_with_sizes_48[1]; split_with_sizes_48 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(getitem_1660, torch.float32); getitem_1660 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_853, [2, 4096, 16, -1, 2]); convert_element_type_853 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_746 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_7); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_746); mul_746 = None + view_1046 = torch.ops.aten.view.default(view_as_real_32, [2, 4096, 16, 64]); view_as_real_32 = None + convert_element_type_854 = torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + cat_137 = torch.ops.aten.cat.default([getitem_1659, convert_element_type_854], -1); getitem_1659 = convert_element_type_854 = None + convert_element_type_855 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_855, 128, '0'); convert_element_type_855 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + slice_97 = torch.ops.aten.slice.Tensor(wait_tensor_328, 0, 0, 576); wait_tensor_328 = None + permute_237 = torch.ops.aten.permute.default(slice_97, [1, 0]); slice_97 = None + mm_128 = torch.ops.aten.mm.default(view_1041, permute_237); permute_237 = None + view_1049 = torch.ops.aten.view.default(mm_128, [2, 4096, 576]); mm_128 = None + split_with_sizes_49 = torch.ops.aten.split_with_sizes.default(view_1049, [512, 64], -1); view_1049 = None + getitem_1661 = split_with_sizes_49[0] + getitem_1662 = split_with_sizes_49[1]; split_with_sizes_49 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(getitem_1662, 2); getitem_1662 = None + convert_element_type_858 = torch.ops.prims.convert_element_type.default(unsqueeze_31, torch.float32); unsqueeze_31 = None + view_1050 = torch.ops.aten.view.default(convert_element_type_858, [2, 4096, 1, -1, 2]); convert_element_type_858 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1050); view_1050 = None + mul_747 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_7); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_747); mul_747 = None + view_1052 = torch.ops.aten.view.default(view_as_real_33, [2, 4096, 1, 64]); view_as_real_33 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(view_1052, torch.bfloat16); view_1052 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_860, 128, '0'); convert_element_type_860 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(getitem_1661, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_861, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_1027 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_1027); add_1027 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_861, rsqrt_49); convert_element_type_861 = None + mul_749 = torch.ops.aten.mul.Tensor(mul_748, wait_tensor_329); mul_748 = wait_tensor_329 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(mul_749, torch.bfloat16); mul_749 = None + convert_element_type_863 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_863, 128, '0'); convert_element_type_863 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + view_1055 = torch.ops.aten.view.default(convert_element_type_862, [8192, 512]); convert_element_type_862 = None + mm_129 = torch.ops.aten.mm.default(view_1055, permute_238); permute_238 = None + view_1056 = torch.ops.aten.view.default(mm_129, [2, 4096, 4096]); mm_129 = None + view_1057 = torch.ops.aten.view.default(view_1056, [2, 4096, -1, 256]); view_1056 = None + split_with_sizes_50 = torch.ops.aten.split_with_sizes.default(view_1057, [128, 128], -1); view_1057 = None + getitem_1663 = split_with_sizes_50[0] + getitem_1664 = split_with_sizes_50[1]; split_with_sizes_50 = None + expand_16 = torch.ops.aten.expand.default(convert_element_type_859, [-1, -1, 16, -1]); convert_element_type_859 = None + cat_138 = torch.ops.aten.cat.default([getitem_1663, expand_16], -1); getitem_1663 = expand_16 = None + permute_239 = torch.ops.aten.permute.default(cat_137, [0, 2, 1, 3]); cat_137 = None + permute_240 = torch.ops.aten.permute.default(cat_138, [0, 2, 1, 3]); cat_138 = None + permute_241 = torch.ops.aten.permute.default(getitem_1664, [0, 2, 1, 3]); getitem_1664 = None + sdpa_score16 = self.sdpa_score16 + sdpa_mask16 = self.sdpa_mask16 + flex_attention_16 = torch.ops.higher_order.flex_attention(permute_239, permute_240, permute_241, sdpa_score16, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask16), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score16 = sdpa_mask16 = None + getitem_1665 = flex_attention_16[0] + getitem_1666 = flex_attention_16[1]; flex_attention_16 = None + permute_242 = torch.ops.aten.permute.default(getitem_1665, [0, 2, 1, 3]) + view_1058 = torch.ops.aten.view.default(permute_242, [2, 4096, -1]); permute_242 = None + convert_element_type_866 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_866, 128, '0'); convert_element_type_866 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + view_1060 = torch.ops.aten.view.default(view_1058, [8192, 2048]); view_1058 = None + mm_130 = torch.ops.aten.mm.default(view_1060, permute_243); view_1060 = permute_243 = None + view_1061 = torch.ops.aten.view.default(mm_130, [2, 4096, 2048]); mm_130 = None + add_1028 = torch.ops.aten.add.Tensor(add_1025, view_1061); view_1061 = None + convert_element_type_869 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_869, 128, '0'); convert_element_type_869 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + convert_element_type_870 = torch.ops.prims.convert_element_type.default(add_1028, torch.float32) + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_870, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_1029 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_1029); add_1029 = None + mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_870, rsqrt_50); convert_element_type_870 = None + mul_751 = torch.ops.aten.mul.Tensor(mul_750, wait_tensor_332); mul_750 = wait_tensor_332 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(mul_751, torch.bfloat16); mul_751 = None + view_1063 = torch.ops.aten.view.default(convert_element_type_871, [-1, 2048]); convert_element_type_871 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_872, 128, '0'); convert_element_type_872 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + slice_99 = torch.ops.aten.slice.Tensor(wait_tensor_333, 0, 0, 64); wait_tensor_333 = None + permute_244 = torch.ops.aten.permute.default(slice_99, [1, 0]); slice_99 = None + mm_131 = torch.ops.aten.mm.default(view_1063, permute_244); permute_244 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(mm_131, torch.float32) + amax_15 = torch.ops.aten.amax.default(convert_element_type_875, [1], True) + sub_360 = torch.ops.aten.sub.Tensor(convert_element_type_875, amax_15); convert_element_type_875 = None + exp_46 = torch.ops.aten.exp.default(sub_360); sub_360 = None + sum_61 = torch.ops.aten.sum.dim_IntList(exp_46, [1], True) + div_76 = torch.ops.aten.div.Tensor(exp_46, sum_61); exp_46 = None + add_1030 = torch.ops.aten.add.Tensor(div_76, primals_270); primals_270 = None + topk_15 = torch.ops.aten.topk.default(add_1030, 6, -1, True, False); add_1030 = None + getitem_1669 = topk_15[1]; topk_15 = None + gather_15 = torch.ops.aten.gather.default(div_76, 1, getitem_1669); div_76 = None + mul_752 = torch.ops.aten.mul.Tensor(gather_15, 1.0); gather_15 = None + view_1065 = torch.ops.aten.view.default(getitem_1669, [-1]) + histc_30 = torch.ops.aten.histc.default(view_1065, 64, 0, 64) + add_1031 = torch.ops.aten.add.Tensor(primals_272, histc_30) + sort_15 = torch.ops.aten.sort.stable(view_1065, stable = True); view_1065 = None + getitem_1671 = sort_15[1]; sort_15 = None + div_77 = torch.ops.aten.div.Tensor_mode(getitem_1671, 6, rounding_mode = 'floor') + index_30 = torch.ops.aten.index.Tensor(view_1063, [div_77]) + all_to_all_single_45 = torch.ops._c10d_functional.all_to_all_single.default(histc_30, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_45); all_to_all_single_45 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_334); wait_tensor_334 = None + view_1069 = torch.ops.aten.view.default(histc_30, [8, -1]); histc_30 = None + sum_62 = torch.ops.aten.sum.dim_IntList(view_1069, [1]); view_1069 = None + device_put_30 = torch.ops.prims.device_put.default(sum_62, device(type='cpu'), True); sum_62 = None + view_1070 = torch.ops.aten.view.default(wait_tensor_335, [8, -1]) + sum_63 = torch.ops.aten.sum.dim_IntList(view_1070, [1]) + device_put_31 = torch.ops.prims.device_put.default(sum_63, device(type='cpu')); sum_63 = None + select_240 = torch.ops.aten.select.int(device_put_30, 0, 0) + _local_scalar_dense_240 = torch.ops.aten._local_scalar_dense.default(select_240); select_240 = None + ge_300 = _local_scalar_dense_240 >= 0 + _assert_scalar_240 = torch.ops.aten._assert_scalar.default(ge_300, "Runtime assertion failed for expression u240 >= 0 on node 'ge_240'"); ge_300 = _assert_scalar_240 = None + select_241 = torch.ops.aten.select.int(device_put_30, 0, 1) + _local_scalar_dense_241 = torch.ops.aten._local_scalar_dense.default(select_241); select_241 = None + ge_301 = _local_scalar_dense_241 >= 0 + _assert_scalar_241 = torch.ops.aten._assert_scalar.default(ge_301, "Runtime assertion failed for expression u241 >= 0 on node 'ge_241'"); ge_301 = _assert_scalar_241 = None + select_242 = torch.ops.aten.select.int(device_put_30, 0, 2) + _local_scalar_dense_242 = torch.ops.aten._local_scalar_dense.default(select_242); select_242 = None + ge_302 = _local_scalar_dense_242 >= 0 + _assert_scalar_242 = torch.ops.aten._assert_scalar.default(ge_302, "Runtime assertion failed for expression u242 >= 0 on node 'ge_242'"); ge_302 = _assert_scalar_242 = None + select_243 = torch.ops.aten.select.int(device_put_30, 0, 3) + _local_scalar_dense_243 = torch.ops.aten._local_scalar_dense.default(select_243); select_243 = None + ge_303 = _local_scalar_dense_243 >= 0 + _assert_scalar_243 = torch.ops.aten._assert_scalar.default(ge_303, "Runtime assertion failed for expression u243 >= 0 on node 'ge_243'"); ge_303 = _assert_scalar_243 = None + select_244 = torch.ops.aten.select.int(device_put_30, 0, 4) + _local_scalar_dense_244 = torch.ops.aten._local_scalar_dense.default(select_244); select_244 = None + ge_304 = _local_scalar_dense_244 >= 0 + _assert_scalar_244 = torch.ops.aten._assert_scalar.default(ge_304, "Runtime assertion failed for expression u244 >= 0 on node 'ge_244'"); ge_304 = _assert_scalar_244 = None + select_245 = torch.ops.aten.select.int(device_put_30, 0, 5) + _local_scalar_dense_245 = torch.ops.aten._local_scalar_dense.default(select_245); select_245 = None + ge_305 = _local_scalar_dense_245 >= 0 + _assert_scalar_245 = torch.ops.aten._assert_scalar.default(ge_305, "Runtime assertion failed for expression u245 >= 0 on node 'ge_245'"); ge_305 = _assert_scalar_245 = None + select_246 = torch.ops.aten.select.int(device_put_30, 0, 6) + _local_scalar_dense_246 = torch.ops.aten._local_scalar_dense.default(select_246); select_246 = None + ge_306 = _local_scalar_dense_246 >= 0 + _assert_scalar_246 = torch.ops.aten._assert_scalar.default(ge_306, "Runtime assertion failed for expression u246 >= 0 on node 'ge_246'"); ge_306 = _assert_scalar_246 = None + select_247 = torch.ops.aten.select.int(device_put_30, 0, 7); device_put_30 = None + _local_scalar_dense_247 = torch.ops.aten._local_scalar_dense.default(select_247); select_247 = None + ge_307 = _local_scalar_dense_247 >= 0 + _assert_scalar_247 = torch.ops.aten._assert_scalar.default(ge_307, "Runtime assertion failed for expression u247 >= 0 on node 'ge_247'"); ge_307 = _assert_scalar_247 = None + select_248 = torch.ops.aten.select.int(device_put_31, 0, 0) + _local_scalar_dense_248 = torch.ops.aten._local_scalar_dense.default(select_248); select_248 = None + ge_308 = _local_scalar_dense_248 >= 0 + _assert_scalar_248 = torch.ops.aten._assert_scalar.default(ge_308, "Runtime assertion failed for expression u248 >= 0 on node 'ge_248'"); ge_308 = _assert_scalar_248 = None + select_249 = torch.ops.aten.select.int(device_put_31, 0, 1) + _local_scalar_dense_249 = torch.ops.aten._local_scalar_dense.default(select_249); select_249 = None + ge_309 = _local_scalar_dense_249 >= 0 + _assert_scalar_249 = torch.ops.aten._assert_scalar.default(ge_309, "Runtime assertion failed for expression u249 >= 0 on node 'ge_249'"); ge_309 = _assert_scalar_249 = None + select_250 = torch.ops.aten.select.int(device_put_31, 0, 2) + _local_scalar_dense_250 = torch.ops.aten._local_scalar_dense.default(select_250); select_250 = None + ge_310 = _local_scalar_dense_250 >= 0 + _assert_scalar_250 = torch.ops.aten._assert_scalar.default(ge_310, "Runtime assertion failed for expression u250 >= 0 on node 'ge_250'"); ge_310 = _assert_scalar_250 = None + select_251 = torch.ops.aten.select.int(device_put_31, 0, 3) + _local_scalar_dense_251 = torch.ops.aten._local_scalar_dense.default(select_251); select_251 = None + ge_311 = _local_scalar_dense_251 >= 0 + _assert_scalar_251 = torch.ops.aten._assert_scalar.default(ge_311, "Runtime assertion failed for expression u251 >= 0 on node 'ge_251'"); ge_311 = _assert_scalar_251 = None + select_252 = torch.ops.aten.select.int(device_put_31, 0, 4) + _local_scalar_dense_252 = torch.ops.aten._local_scalar_dense.default(select_252); select_252 = None + ge_312 = _local_scalar_dense_252 >= 0 + _assert_scalar_252 = torch.ops.aten._assert_scalar.default(ge_312, "Runtime assertion failed for expression u252 >= 0 on node 'ge_252'"); ge_312 = _assert_scalar_252 = None + select_253 = torch.ops.aten.select.int(device_put_31, 0, 5) + _local_scalar_dense_253 = torch.ops.aten._local_scalar_dense.default(select_253); select_253 = None + ge_313 = _local_scalar_dense_253 >= 0 + _assert_scalar_253 = torch.ops.aten._assert_scalar.default(ge_313, "Runtime assertion failed for expression u253 >= 0 on node 'ge_253'"); ge_313 = _assert_scalar_253 = None + select_254 = torch.ops.aten.select.int(device_put_31, 0, 6) + _local_scalar_dense_254 = torch.ops.aten._local_scalar_dense.default(select_254); select_254 = None + ge_314 = _local_scalar_dense_254 >= 0 + _assert_scalar_254 = torch.ops.aten._assert_scalar.default(ge_314, "Runtime assertion failed for expression u254 >= 0 on node 'ge_254'"); ge_314 = _assert_scalar_254 = None + select_255 = torch.ops.aten.select.int(device_put_31, 0, 7); device_put_31 = None + _local_scalar_dense_255 = torch.ops.aten._local_scalar_dense.default(select_255); select_255 = None + ge_315 = _local_scalar_dense_255 >= 0 + _assert_scalar_255 = torch.ops.aten._assert_scalar.default(ge_315, "Runtime assertion failed for expression u255 >= 0 on node 'ge_255'"); ge_315 = _assert_scalar_255 = None + all_to_all_single_46 = torch.ops._c10d_functional.all_to_all_single.default(index_30, [_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255], [_local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247], '1033'); index_30 = None + sym_size_int_60 = torch.ops.aten.sym_size.int(all_to_all_single_46, 0) + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_46); all_to_all_single_46 = None + sym_sum_30 = torch.sym_sum((_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255)) + add_1038 = sym_sum_30 + 64; sym_sum_30 = None + add_1039 = add_1038 + 8; add_1038 = None + sub_363 = add_1039 - 1; add_1039 = None + floordiv_15 = sub_363 // 8; sub_363 = None + mul_757 = floordiv_15 * 8; floordiv_15 = None + cumsum_45 = torch.ops.aten.cumsum.default(wait_tensor_335, 0) + sub_364 = torch.ops.aten.sub.Tensor(cumsum_45, wait_tensor_335); cumsum_45 = None + sum_64 = torch.ops.aten.sum.dim_IntList(view_1070, [0]); view_1070 = None + clamp_min_15 = torch.ops.aten.clamp_min.default(sum_64, 8); sum_64 = None + add_1040 = torch.ops.aten.add.Tensor(clamp_min_15, 8); clamp_min_15 = None + sub_365 = torch.ops.aten.sub.Tensor(add_1040, 1); add_1040 = None + div_78 = torch.ops.aten.div.Tensor_mode(sub_365, 8, rounding_mode = 'floor'); sub_365 = None + mul_758 = torch.ops.aten.mul.Tensor(div_78, 8); div_78 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(mul_758, torch.int32); mul_758 = None + cumsum_46 = torch.ops.aten.cumsum.default(convert_element_type_878, 0) + sub_366 = torch.ops.aten.sub.Tensor(cumsum_46, convert_element_type_878); cumsum_46 = None + full_215 = torch.ops.aten.full.default([mul_757], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_757 = None + triton_kernel_wrapper_functional_proxy_15 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 15, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_335, 'start_index_values_ptr': sub_364, 'write_offsets_ptr': sub_366, 'output_ptr': full_215}, tensors_to_clone = ['output_ptr']); wait_tensor_335 = sub_364 = sub_366 = full_215 = None + getitem_1672 = triton_kernel_wrapper_functional_proxy_15['output_ptr']; triton_kernel_wrapper_functional_proxy_15 = None + cat_139 = torch.ops.aten.cat.default([wait_tensor_336, full_default]); wait_tensor_336 = None + sym_size_int_61 = torch.ops.aten.sym_size.int(cat_139, 0) + sym_sum_31 = torch.sym_sum((1, _local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255)) + index_31 = torch.ops.aten.index.Tensor(cat_139, [getitem_1672]); cat_139 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_880, 16, '1025'); convert_element_type_880 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + split_91 = torch.ops.aten.split.Tensor(wait_tensor_337, 8); wait_tensor_337 = None + getitem_1689 = split_91[0] + getitem_1690 = split_91[1] + getitem_1691 = split_91[2] + getitem_1692 = split_91[3] + getitem_1693 = split_91[4] + getitem_1694 = split_91[5] + getitem_1695 = split_91[6] + getitem_1696 = split_91[7] + getitem_1697 = split_91[8] + getitem_1698 = split_91[9] + getitem_1699 = split_91[10] + getitem_1700 = split_91[11] + getitem_1701 = split_91[12] + getitem_1702 = split_91[13] + getitem_1703 = split_91[14] + getitem_1704 = split_91[15]; split_91 = None + cat_141 = torch.ops.aten.cat.default([getitem_1689, getitem_1690, getitem_1691, getitem_1692, getitem_1693, getitem_1694, getitem_1695, getitem_1696, getitem_1697, getitem_1698, getitem_1699, getitem_1700, getitem_1701, getitem_1702, getitem_1703, getitem_1704], 1); getitem_1689 = getitem_1690 = getitem_1691 = getitem_1692 = getitem_1693 = getitem_1694 = getitem_1695 = getitem_1696 = getitem_1697 = getitem_1698 = getitem_1699 = getitem_1700 = getitem_1701 = getitem_1702 = getitem_1703 = getitem_1704 = None + convert_element_type_882 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_882, 16, '1025'); convert_element_type_882 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + split_92 = torch.ops.aten.split.Tensor(wait_tensor_339, 8); wait_tensor_339 = None + getitem_1705 = split_92[0] + getitem_1706 = split_92[1] + getitem_1707 = split_92[2] + getitem_1708 = split_92[3] + getitem_1709 = split_92[4] + getitem_1710 = split_92[5] + getitem_1711 = split_92[6] + getitem_1712 = split_92[7] + getitem_1713 = split_92[8] + getitem_1714 = split_92[9] + getitem_1715 = split_92[10] + getitem_1716 = split_92[11] + getitem_1717 = split_92[12] + getitem_1718 = split_92[13] + getitem_1719 = split_92[14] + getitem_1720 = split_92[15]; split_92 = None + cat_142 = torch.ops.aten.cat.default([getitem_1705, getitem_1706, getitem_1707, getitem_1708, getitem_1709, getitem_1710, getitem_1711, getitem_1712, getitem_1713, getitem_1714, getitem_1715, getitem_1716, getitem_1717, getitem_1718, getitem_1719, getitem_1720], 1); getitem_1705 = getitem_1706 = getitem_1707 = getitem_1708 = getitem_1709 = getitem_1710 = getitem_1711 = getitem_1712 = getitem_1713 = getitem_1714 = getitem_1715 = getitem_1716 = getitem_1717 = getitem_1718 = getitem_1719 = getitem_1720 = None + convert_element_type_883 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_883, 16, '1025'); convert_element_type_883 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + split_93 = torch.ops.aten.split.Tensor(wait_tensor_340, 8); wait_tensor_340 = None + getitem_1721 = split_93[0] + getitem_1722 = split_93[1] + getitem_1723 = split_93[2] + getitem_1724 = split_93[3] + getitem_1725 = split_93[4] + getitem_1726 = split_93[5] + getitem_1727 = split_93[6] + getitem_1728 = split_93[7] + getitem_1729 = split_93[8] + getitem_1730 = split_93[9] + getitem_1731 = split_93[10] + getitem_1732 = split_93[11] + getitem_1733 = split_93[12] + getitem_1734 = split_93[13] + getitem_1735 = split_93[14] + getitem_1736 = split_93[15]; split_93 = None + cat_143 = torch.ops.aten.cat.default([getitem_1721, getitem_1722, getitem_1723, getitem_1724, getitem_1725, getitem_1726, getitem_1727, getitem_1728, getitem_1729, getitem_1730, getitem_1731, getitem_1732, getitem_1733, getitem_1734, getitem_1735, getitem_1736], 1); getitem_1721 = getitem_1722 = getitem_1723 = getitem_1724 = getitem_1725 = getitem_1726 = getitem_1727 = getitem_1728 = getitem_1729 = getitem_1730 = getitem_1731 = getitem_1732 = getitem_1733 = getitem_1734 = getitem_1735 = getitem_1736 = None + cumsum_47 = torch.ops.aten.cumsum.default(convert_element_type_878, 0, dtype = torch.int32); convert_element_type_878 = None + permute_245 = torch.ops.aten.permute.default(cat_141, [0, 2, 1]); cat_141 = None + _grouped_mm_45 = torch.ops.aten._grouped_mm.default(index_31, permute_245, cumsum_47) + convert_element_type_886 = torch.ops.prims.convert_element_type.default(_grouped_mm_45, torch.float32) + neg_31 = torch.ops.aten.neg.default(convert_element_type_886) + exp_47 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_1052 = torch.ops.aten.add.Tensor(exp_47, 1); exp_47 = None + div_79 = torch.ops.aten.div.Tensor(convert_element_type_886, add_1052); convert_element_type_886 = add_1052 = None + convert_element_type_887 = torch.ops.prims.convert_element_type.default(div_79, torch.bfloat16); div_79 = None + permute_246 = torch.ops.aten.permute.default(cat_143, [0, 2, 1]); cat_143 = None + _grouped_mm_46 = torch.ops.aten._grouped_mm.default(index_31, permute_246, cumsum_47) + mul_770 = torch.ops.aten.mul.Tensor(convert_element_type_887, _grouped_mm_46); convert_element_type_887 = None + permute_247 = torch.ops.aten.permute.default(cat_142, [0, 2, 1]); cat_142 = None + _grouped_mm_47 = torch.ops.aten._grouped_mm.default(mul_770, permute_247, cumsum_47) + empty_15 = torch.ops.aten.empty.memory_format([sym_size_int_61, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_30 = torch.ops.aten.index_put.default(empty_15, [getitem_1672], _grouped_mm_47); empty_15 = _grouped_mm_47 = None + slice_101 = torch.ops.aten.slice.Tensor(index_put_30, 0, 0, -1); index_put_30 = None + all_to_all_single_47 = torch.ops._c10d_functional.all_to_all_single.default(slice_101, [_local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247], [_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255], '1033'); slice_101 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_47); all_to_all_single_47 = None + convert_element_type_888 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_888, 128, '0'); convert_element_type_888 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_248 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + mm_132 = torch.ops.aten.mm.default(view_1063, permute_248); permute_248 = None + convert_element_type_891 = torch.ops.prims.convert_element_type.default(mm_132, torch.float32) + neg_32 = torch.ops.aten.neg.default(convert_element_type_891) + exp_48 = torch.ops.aten.exp.default(neg_32); neg_32 = None + add_1088 = torch.ops.aten.add.Tensor(exp_48, 1); exp_48 = None + div_80 = torch.ops.aten.div.Tensor(convert_element_type_891, add_1088); convert_element_type_891 = add_1088 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(div_80, torch.bfloat16); div_80 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_893, 128, '0'); convert_element_type_893 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); wait_tensor_345 = None + mm_133 = torch.ops.aten.mm.default(view_1063, permute_249); permute_249 = None + mul_790 = torch.ops.aten.mul.Tensor(convert_element_type_892, mm_133); convert_element_type_892 = None + convert_element_type_896 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_896, 128, '0'); convert_element_type_896 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_346, [1, 0]); wait_tensor_346 = None + mm_134 = torch.ops.aten.mm.default(mul_790, permute_250); permute_250 = None + index_put_31 = torch.ops.aten.index_put.default(full_default_1, [getitem_1671], wait_tensor_343); wait_tensor_343 = None + view_1103 = torch.ops.aten.view.default(mul_752, [-1, 1, 6]); mul_752 = None + view_1104 = torch.ops.aten.view.default(index_put_31, [-1, 6, 2048]); index_put_31 = None + convert_element_type_899 = torch.ops.prims.convert_element_type.default(view_1104, torch.float32); view_1104 = None + bmm_15 = torch.ops.aten.bmm.default(view_1103, convert_element_type_899) + convert_element_type_900 = torch.ops.prims.convert_element_type.default(bmm_15, torch.bfloat16); bmm_15 = None + squeeze_15 = torch.ops.aten.squeeze.dim(convert_element_type_900, 1); convert_element_type_900 = None + add_1092 = torch.ops.aten.add.Tensor(mm_134, squeeze_15); mm_134 = squeeze_15 = None + view_1105 = torch.ops.aten.view.default(add_1092, [2, 4096, 2048]); add_1092 = None + add_1093 = torch.ops.aten.add.Tensor(add_1028, view_1105); view_1105 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 128, '0'); convert_element_type_901 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + convert_element_type_902 = torch.ops.prims.convert_element_type.default(add_1093, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_902, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_1094 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_1094); add_1094 = None + mul_793 = torch.ops.aten.mul.Tensor(convert_element_type_902, rsqrt_51); convert_element_type_902 = None + mul_794 = torch.ops.aten.mul.Tensor(mul_793, wait_tensor_347); mul_793 = wait_tensor_347 = None + convert_element_type_903 = torch.ops.prims.convert_element_type.default(mul_794, torch.bfloat16); mul_794 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_904, 128, '0'); convert_element_type_904 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_348, [1, 0]); wait_tensor_348 = None + view_1108 = torch.ops.aten.view.default(convert_element_type_903, [8192, 2048]); convert_element_type_903 = None + mm_135 = torch.ops.aten.mm.default(view_1108, permute_251); permute_251 = None + view_1109 = torch.ops.aten.view.default(mm_135, [2, 4096, 3072]); mm_135 = None + view_1110 = torch.ops.aten.view.default(view_1109, [2, 4096, -1, 192]); view_1109 = None + split_with_sizes_51 = torch.ops.aten.split_with_sizes.default(view_1110, [128, 64], -1); view_1110 = None + getitem_1769 = split_with_sizes_51[0] + getitem_1770 = split_with_sizes_51[1]; split_with_sizes_51 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(getitem_1770, torch.float32); getitem_1770 = None + view_1111 = torch.ops.aten.view.default(convert_element_type_907, [2, 4096, 16, -1, 2]); convert_element_type_907 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1111); view_1111 = None + mul_795 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_7); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_795); mul_795 = None + view_1113 = torch.ops.aten.view.default(view_as_real_34, [2, 4096, 16, 64]); view_as_real_34 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(view_1113, torch.bfloat16); view_1113 = None + cat_146 = torch.ops.aten.cat.default([getitem_1769, convert_element_type_908], -1); getitem_1769 = convert_element_type_908 = None + convert_element_type_909 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_909, 128, '0'); convert_element_type_909 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + slice_103 = torch.ops.aten.slice.Tensor(wait_tensor_349, 0, 0, 576); wait_tensor_349 = None + permute_252 = torch.ops.aten.permute.default(slice_103, [1, 0]); slice_103 = None + mm_136 = torch.ops.aten.mm.default(view_1108, permute_252); permute_252 = None + view_1116 = torch.ops.aten.view.default(mm_136, [2, 4096, 576]); mm_136 = None + split_with_sizes_52 = torch.ops.aten.split_with_sizes.default(view_1116, [512, 64], -1); view_1116 = None + getitem_1771 = split_with_sizes_52[0] + getitem_1772 = split_with_sizes_52[1]; split_with_sizes_52 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(getitem_1772, 2); getitem_1772 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(unsqueeze_33, torch.float32); unsqueeze_33 = None + view_1117 = torch.ops.aten.view.default(convert_element_type_912, [2, 4096, 1, -1, 2]); convert_element_type_912 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1117); view_1117 = None + mul_796 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_7); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_796); mul_796 = None + view_1119 = torch.ops.aten.view.default(view_as_real_35, [2, 4096, 1, 64]); view_as_real_35 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 128, '0'); convert_element_type_914 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + convert_element_type_915 = torch.ops.prims.convert_element_type.default(getitem_1771, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_915, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_1095 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_1095); add_1095 = None + mul_797 = torch.ops.aten.mul.Tensor(convert_element_type_915, rsqrt_52); convert_element_type_915 = None + mul_798 = torch.ops.aten.mul.Tensor(mul_797, wait_tensor_350); mul_797 = wait_tensor_350 = None + convert_element_type_916 = torch.ops.prims.convert_element_type.default(mul_798, torch.bfloat16); mul_798 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_917, 128, '0'); convert_element_type_917 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + view_1122 = torch.ops.aten.view.default(convert_element_type_916, [8192, 512]); convert_element_type_916 = None + mm_137 = torch.ops.aten.mm.default(view_1122, permute_253); permute_253 = None + view_1123 = torch.ops.aten.view.default(mm_137, [2, 4096, 4096]); mm_137 = None + view_1124 = torch.ops.aten.view.default(view_1123, [2, 4096, -1, 256]); view_1123 = None + split_with_sizes_53 = torch.ops.aten.split_with_sizes.default(view_1124, [128, 128], -1); view_1124 = None + getitem_1773 = split_with_sizes_53[0] + getitem_1774 = split_with_sizes_53[1]; split_with_sizes_53 = None + expand_17 = torch.ops.aten.expand.default(convert_element_type_913, [-1, -1, 16, -1]); convert_element_type_913 = None + cat_147 = torch.ops.aten.cat.default([getitem_1773, expand_17], -1); getitem_1773 = expand_17 = None + permute_254 = torch.ops.aten.permute.default(cat_146, [0, 2, 1, 3]); cat_146 = None + permute_255 = torch.ops.aten.permute.default(cat_147, [0, 2, 1, 3]); cat_147 = None + permute_256 = torch.ops.aten.permute.default(getitem_1774, [0, 2, 1, 3]); getitem_1774 = None + sdpa_score17 = self.sdpa_score17 + sdpa_mask17 = self.sdpa_mask17 + flex_attention_17 = torch.ops.higher_order.flex_attention(permute_254, permute_255, permute_256, sdpa_score17, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask17), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score17 = sdpa_mask17 = None + getitem_1775 = flex_attention_17[0] + getitem_1776 = flex_attention_17[1]; flex_attention_17 = None + permute_257 = torch.ops.aten.permute.default(getitem_1775, [0, 2, 1, 3]) + view_1125 = torch.ops.aten.view.default(permute_257, [2, 4096, -1]); permute_257 = None + convert_element_type_920 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_920, 128, '0'); convert_element_type_920 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_258 = torch.ops.aten.permute.default(wait_tensor_352, [1, 0]); wait_tensor_352 = None + view_1127 = torch.ops.aten.view.default(view_1125, [8192, 2048]); view_1125 = None + mm_138 = torch.ops.aten.mm.default(view_1127, permute_258); view_1127 = permute_258 = None + view_1128 = torch.ops.aten.view.default(mm_138, [2, 4096, 2048]); mm_138 = None + add_1096 = torch.ops.aten.add.Tensor(add_1093, view_1128); view_1128 = None + convert_element_type_923 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_923, 128, '0'); convert_element_type_923 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_924 = torch.ops.prims.convert_element_type.default(add_1096, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_924, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_1097 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_1097); add_1097 = None + mul_799 = torch.ops.aten.mul.Tensor(convert_element_type_924, rsqrt_53); convert_element_type_924 = None + mul_800 = torch.ops.aten.mul.Tensor(mul_799, wait_tensor_353); mul_799 = wait_tensor_353 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(mul_800, torch.bfloat16); mul_800 = None + view_1130 = torch.ops.aten.view.default(convert_element_type_925, [-1, 2048]); convert_element_type_925 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_926, 128, '0'); convert_element_type_926 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + slice_105 = torch.ops.aten.slice.Tensor(wait_tensor_354, 0, 0, 64); wait_tensor_354 = None + permute_259 = torch.ops.aten.permute.default(slice_105, [1, 0]); slice_105 = None + mm_139 = torch.ops.aten.mm.default(view_1130, permute_259); permute_259 = None + convert_element_type_929 = torch.ops.prims.convert_element_type.default(mm_139, torch.float32) + amax_16 = torch.ops.aten.amax.default(convert_element_type_929, [1], True) + sub_384 = torch.ops.aten.sub.Tensor(convert_element_type_929, amax_16); convert_element_type_929 = None + exp_49 = torch.ops.aten.exp.default(sub_384); sub_384 = None + sum_65 = torch.ops.aten.sum.dim_IntList(exp_49, [1], True) + div_81 = torch.ops.aten.div.Tensor(exp_49, sum_65); exp_49 = None + add_1098 = torch.ops.aten.add.Tensor(div_81, primals_286); primals_286 = None + topk_16 = torch.ops.aten.topk.default(add_1098, 6, -1, True, False); add_1098 = None + getitem_1779 = topk_16[1]; topk_16 = None + gather_16 = torch.ops.aten.gather.default(div_81, 1, getitem_1779); div_81 = None + mul_801 = torch.ops.aten.mul.Tensor(gather_16, 1.0); gather_16 = None + view_1132 = torch.ops.aten.view.default(getitem_1779, [-1]) + histc_32 = torch.ops.aten.histc.default(view_1132, 64, 0, 64) + add_1099 = torch.ops.aten.add.Tensor(primals_288, histc_32) + sort_16 = torch.ops.aten.sort.stable(view_1132, stable = True); view_1132 = None + getitem_1781 = sort_16[1]; sort_16 = None + div_82 = torch.ops.aten.div.Tensor_mode(getitem_1781, 6, rounding_mode = 'floor') + index_32 = torch.ops.aten.index.Tensor(view_1130, [div_82]) + all_to_all_single_48 = torch.ops._c10d_functional.all_to_all_single.default(histc_32, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_48); all_to_all_single_48 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_355); wait_tensor_355 = None + view_1136 = torch.ops.aten.view.default(histc_32, [8, -1]); histc_32 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_1136, [1]); view_1136 = None + device_put_32 = torch.ops.prims.device_put.default(sum_66, device(type='cpu'), True); sum_66 = None + view_1137 = torch.ops.aten.view.default(wait_tensor_356, [8, -1]) + sum_67 = torch.ops.aten.sum.dim_IntList(view_1137, [1]) + device_put_33 = torch.ops.prims.device_put.default(sum_67, device(type='cpu')); sum_67 = None + select_256 = torch.ops.aten.select.int(device_put_32, 0, 0) + _local_scalar_dense_256 = torch.ops.aten._local_scalar_dense.default(select_256); select_256 = None + ge_320 = _local_scalar_dense_256 >= 0 + _assert_scalar_256 = torch.ops.aten._assert_scalar.default(ge_320, "Runtime assertion failed for expression u256 >= 0 on node 'ge_256'"); ge_320 = _assert_scalar_256 = None + select_257 = torch.ops.aten.select.int(device_put_32, 0, 1) + _local_scalar_dense_257 = torch.ops.aten._local_scalar_dense.default(select_257); select_257 = None + ge_321 = _local_scalar_dense_257 >= 0 + _assert_scalar_257 = torch.ops.aten._assert_scalar.default(ge_321, "Runtime assertion failed for expression u257 >= 0 on node 'ge_257'"); ge_321 = _assert_scalar_257 = None + select_258 = torch.ops.aten.select.int(device_put_32, 0, 2) + _local_scalar_dense_258 = torch.ops.aten._local_scalar_dense.default(select_258); select_258 = None + ge_322 = _local_scalar_dense_258 >= 0 + _assert_scalar_258 = torch.ops.aten._assert_scalar.default(ge_322, "Runtime assertion failed for expression u258 >= 0 on node 'ge_258'"); ge_322 = _assert_scalar_258 = None + select_259 = torch.ops.aten.select.int(device_put_32, 0, 3) + _local_scalar_dense_259 = torch.ops.aten._local_scalar_dense.default(select_259); select_259 = None + ge_323 = _local_scalar_dense_259 >= 0 + _assert_scalar_259 = torch.ops.aten._assert_scalar.default(ge_323, "Runtime assertion failed for expression u259 >= 0 on node 'ge_259'"); ge_323 = _assert_scalar_259 = None + select_260 = torch.ops.aten.select.int(device_put_32, 0, 4) + _local_scalar_dense_260 = torch.ops.aten._local_scalar_dense.default(select_260); select_260 = None + ge_324 = _local_scalar_dense_260 >= 0 + _assert_scalar_260 = torch.ops.aten._assert_scalar.default(ge_324, "Runtime assertion failed for expression u260 >= 0 on node 'ge_260'"); ge_324 = _assert_scalar_260 = None + select_261 = torch.ops.aten.select.int(device_put_32, 0, 5) + _local_scalar_dense_261 = torch.ops.aten._local_scalar_dense.default(select_261); select_261 = None + ge_325 = _local_scalar_dense_261 >= 0 + _assert_scalar_261 = torch.ops.aten._assert_scalar.default(ge_325, "Runtime assertion failed for expression u261 >= 0 on node 'ge_261'"); ge_325 = _assert_scalar_261 = None + select_262 = torch.ops.aten.select.int(device_put_32, 0, 6) + _local_scalar_dense_262 = torch.ops.aten._local_scalar_dense.default(select_262); select_262 = None + ge_326 = _local_scalar_dense_262 >= 0 + _assert_scalar_262 = torch.ops.aten._assert_scalar.default(ge_326, "Runtime assertion failed for expression u262 >= 0 on node 'ge_262'"); ge_326 = _assert_scalar_262 = None + select_263 = torch.ops.aten.select.int(device_put_32, 0, 7); device_put_32 = None + _local_scalar_dense_263 = torch.ops.aten._local_scalar_dense.default(select_263); select_263 = None + ge_327 = _local_scalar_dense_263 >= 0 + _assert_scalar_263 = torch.ops.aten._assert_scalar.default(ge_327, "Runtime assertion failed for expression u263 >= 0 on node 'ge_263'"); ge_327 = _assert_scalar_263 = None + select_264 = torch.ops.aten.select.int(device_put_33, 0, 0) + _local_scalar_dense_264 = torch.ops.aten._local_scalar_dense.default(select_264); select_264 = None + ge_328 = _local_scalar_dense_264 >= 0 + _assert_scalar_264 = torch.ops.aten._assert_scalar.default(ge_328, "Runtime assertion failed for expression u264 >= 0 on node 'ge_264'"); ge_328 = _assert_scalar_264 = None + select_265 = torch.ops.aten.select.int(device_put_33, 0, 1) + _local_scalar_dense_265 = torch.ops.aten._local_scalar_dense.default(select_265); select_265 = None + ge_329 = _local_scalar_dense_265 >= 0 + _assert_scalar_265 = torch.ops.aten._assert_scalar.default(ge_329, "Runtime assertion failed for expression u265 >= 0 on node 'ge_265'"); ge_329 = _assert_scalar_265 = None + select_266 = torch.ops.aten.select.int(device_put_33, 0, 2) + _local_scalar_dense_266 = torch.ops.aten._local_scalar_dense.default(select_266); select_266 = None + ge_330 = _local_scalar_dense_266 >= 0 + _assert_scalar_266 = torch.ops.aten._assert_scalar.default(ge_330, "Runtime assertion failed for expression u266 >= 0 on node 'ge_266'"); ge_330 = _assert_scalar_266 = None + select_267 = torch.ops.aten.select.int(device_put_33, 0, 3) + _local_scalar_dense_267 = torch.ops.aten._local_scalar_dense.default(select_267); select_267 = None + ge_331 = _local_scalar_dense_267 >= 0 + _assert_scalar_267 = torch.ops.aten._assert_scalar.default(ge_331, "Runtime assertion failed for expression u267 >= 0 on node 'ge_267'"); ge_331 = _assert_scalar_267 = None + select_268 = torch.ops.aten.select.int(device_put_33, 0, 4) + _local_scalar_dense_268 = torch.ops.aten._local_scalar_dense.default(select_268); select_268 = None + ge_332 = _local_scalar_dense_268 >= 0 + _assert_scalar_268 = torch.ops.aten._assert_scalar.default(ge_332, "Runtime assertion failed for expression u268 >= 0 on node 'ge_268'"); ge_332 = _assert_scalar_268 = None + select_269 = torch.ops.aten.select.int(device_put_33, 0, 5) + _local_scalar_dense_269 = torch.ops.aten._local_scalar_dense.default(select_269); select_269 = None + ge_333 = _local_scalar_dense_269 >= 0 + _assert_scalar_269 = torch.ops.aten._assert_scalar.default(ge_333, "Runtime assertion failed for expression u269 >= 0 on node 'ge_269'"); ge_333 = _assert_scalar_269 = None + select_270 = torch.ops.aten.select.int(device_put_33, 0, 6) + _local_scalar_dense_270 = torch.ops.aten._local_scalar_dense.default(select_270); select_270 = None + ge_334 = _local_scalar_dense_270 >= 0 + _assert_scalar_270 = torch.ops.aten._assert_scalar.default(ge_334, "Runtime assertion failed for expression u270 >= 0 on node 'ge_270'"); ge_334 = _assert_scalar_270 = None + select_271 = torch.ops.aten.select.int(device_put_33, 0, 7); device_put_33 = None + _local_scalar_dense_271 = torch.ops.aten._local_scalar_dense.default(select_271); select_271 = None + ge_335 = _local_scalar_dense_271 >= 0 + _assert_scalar_271 = torch.ops.aten._assert_scalar.default(ge_335, "Runtime assertion failed for expression u271 >= 0 on node 'ge_271'"); ge_335 = _assert_scalar_271 = None + all_to_all_single_49 = torch.ops._c10d_functional.all_to_all_single.default(index_32, [_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271], [_local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263], '1033'); index_32 = None + sym_size_int_64 = torch.ops.aten.sym_size.int(all_to_all_single_49, 0) + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_49); all_to_all_single_49 = None + sym_sum_32 = torch.sym_sum((_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271)) + add_1106 = sym_sum_32 + 64; sym_sum_32 = None + add_1107 = add_1106 + 8; add_1106 = None + sub_387 = add_1107 - 1; add_1107 = None + floordiv_16 = sub_387 // 8; sub_387 = None + mul_806 = floordiv_16 * 8; floordiv_16 = None + cumsum_48 = torch.ops.aten.cumsum.default(wait_tensor_356, 0) + sub_388 = torch.ops.aten.sub.Tensor(cumsum_48, wait_tensor_356); cumsum_48 = None + sum_68 = torch.ops.aten.sum.dim_IntList(view_1137, [0]); view_1137 = None + clamp_min_16 = torch.ops.aten.clamp_min.default(sum_68, 8); sum_68 = None + add_1108 = torch.ops.aten.add.Tensor(clamp_min_16, 8); clamp_min_16 = None + sub_389 = torch.ops.aten.sub.Tensor(add_1108, 1); add_1108 = None + div_83 = torch.ops.aten.div.Tensor_mode(sub_389, 8, rounding_mode = 'floor'); sub_389 = None + mul_807 = torch.ops.aten.mul.Tensor(div_83, 8); div_83 = None + convert_element_type_932 = torch.ops.prims.convert_element_type.default(mul_807, torch.int32); mul_807 = None + cumsum_49 = torch.ops.aten.cumsum.default(convert_element_type_932, 0) + sub_390 = torch.ops.aten.sub.Tensor(cumsum_49, convert_element_type_932); cumsum_49 = None + full_228 = torch.ops.aten.full.default([mul_806], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_806 = None + triton_kernel_wrapper_functional_proxy_16 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 16, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_356, 'start_index_values_ptr': sub_388, 'write_offsets_ptr': sub_390, 'output_ptr': full_228}, tensors_to_clone = ['output_ptr']); wait_tensor_356 = sub_388 = sub_390 = full_228 = None + getitem_1782 = triton_kernel_wrapper_functional_proxy_16['output_ptr']; triton_kernel_wrapper_functional_proxy_16 = None + cat_148 = torch.ops.aten.cat.default([wait_tensor_357, full_default]); wait_tensor_357 = None + sym_size_int_65 = torch.ops.aten.sym_size.int(cat_148, 0) + sym_sum_33 = torch.sym_sum((1, _local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271)) + index_33 = torch.ops.aten.index.Tensor(cat_148, [getitem_1782]); cat_148 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_291 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 16, '1025'); convert_element_type_934 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_291); all_gather_into_tensor_291 = None + split_97 = torch.ops.aten.split.Tensor(wait_tensor_358, 8); wait_tensor_358 = None + getitem_1799 = split_97[0] + getitem_1800 = split_97[1] + getitem_1801 = split_97[2] + getitem_1802 = split_97[3] + getitem_1803 = split_97[4] + getitem_1804 = split_97[5] + getitem_1805 = split_97[6] + getitem_1806 = split_97[7] + getitem_1807 = split_97[8] + getitem_1808 = split_97[9] + getitem_1809 = split_97[10] + getitem_1810 = split_97[11] + getitem_1811 = split_97[12] + getitem_1812 = split_97[13] + getitem_1813 = split_97[14] + getitem_1814 = split_97[15]; split_97 = None + cat_150 = torch.ops.aten.cat.default([getitem_1799, getitem_1800, getitem_1801, getitem_1802, getitem_1803, getitem_1804, getitem_1805, getitem_1806, getitem_1807, getitem_1808, getitem_1809, getitem_1810, getitem_1811, getitem_1812, getitem_1813, getitem_1814], 1); getitem_1799 = getitem_1800 = getitem_1801 = getitem_1802 = getitem_1803 = getitem_1804 = getitem_1805 = getitem_1806 = getitem_1807 = getitem_1808 = getitem_1809 = getitem_1810 = getitem_1811 = getitem_1812 = getitem_1813 = getitem_1814 = None + convert_element_type_936 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_293 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_936, 16, '1025'); convert_element_type_936 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_293); all_gather_into_tensor_293 = None + split_98 = torch.ops.aten.split.Tensor(wait_tensor_360, 8); wait_tensor_360 = None + getitem_1815 = split_98[0] + getitem_1816 = split_98[1] + getitem_1817 = split_98[2] + getitem_1818 = split_98[3] + getitem_1819 = split_98[4] + getitem_1820 = split_98[5] + getitem_1821 = split_98[6] + getitem_1822 = split_98[7] + getitem_1823 = split_98[8] + getitem_1824 = split_98[9] + getitem_1825 = split_98[10] + getitem_1826 = split_98[11] + getitem_1827 = split_98[12] + getitem_1828 = split_98[13] + getitem_1829 = split_98[14] + getitem_1830 = split_98[15]; split_98 = None + cat_151 = torch.ops.aten.cat.default([getitem_1815, getitem_1816, getitem_1817, getitem_1818, getitem_1819, getitem_1820, getitem_1821, getitem_1822, getitem_1823, getitem_1824, getitem_1825, getitem_1826, getitem_1827, getitem_1828, getitem_1829, getitem_1830], 1); getitem_1815 = getitem_1816 = getitem_1817 = getitem_1818 = getitem_1819 = getitem_1820 = getitem_1821 = getitem_1822 = getitem_1823 = getitem_1824 = getitem_1825 = getitem_1826 = getitem_1827 = getitem_1828 = getitem_1829 = getitem_1830 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_294 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_937, 16, '1025'); convert_element_type_937 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_294); all_gather_into_tensor_294 = None + split_99 = torch.ops.aten.split.Tensor(wait_tensor_361, 8); wait_tensor_361 = None + getitem_1831 = split_99[0] + getitem_1832 = split_99[1] + getitem_1833 = split_99[2] + getitem_1834 = split_99[3] + getitem_1835 = split_99[4] + getitem_1836 = split_99[5] + getitem_1837 = split_99[6] + getitem_1838 = split_99[7] + getitem_1839 = split_99[8] + getitem_1840 = split_99[9] + getitem_1841 = split_99[10] + getitem_1842 = split_99[11] + getitem_1843 = split_99[12] + getitem_1844 = split_99[13] + getitem_1845 = split_99[14] + getitem_1846 = split_99[15]; split_99 = None + cat_152 = torch.ops.aten.cat.default([getitem_1831, getitem_1832, getitem_1833, getitem_1834, getitem_1835, getitem_1836, getitem_1837, getitem_1838, getitem_1839, getitem_1840, getitem_1841, getitem_1842, getitem_1843, getitem_1844, getitem_1845, getitem_1846], 1); getitem_1831 = getitem_1832 = getitem_1833 = getitem_1834 = getitem_1835 = getitem_1836 = getitem_1837 = getitem_1838 = getitem_1839 = getitem_1840 = getitem_1841 = getitem_1842 = getitem_1843 = getitem_1844 = getitem_1845 = getitem_1846 = None + cumsum_50 = torch.ops.aten.cumsum.default(convert_element_type_932, 0, dtype = torch.int32); convert_element_type_932 = None + permute_260 = torch.ops.aten.permute.default(cat_150, [0, 2, 1]); cat_150 = None + _grouped_mm_48 = torch.ops.aten._grouped_mm.default(index_33, permute_260, cumsum_50) + convert_element_type_940 = torch.ops.prims.convert_element_type.default(_grouped_mm_48, torch.float32) + neg_33 = torch.ops.aten.neg.default(convert_element_type_940) + exp_50 = torch.ops.aten.exp.default(neg_33); neg_33 = None + add_1120 = torch.ops.aten.add.Tensor(exp_50, 1); exp_50 = None + div_84 = torch.ops.aten.div.Tensor(convert_element_type_940, add_1120); convert_element_type_940 = add_1120 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(div_84, torch.bfloat16); div_84 = None + permute_261 = torch.ops.aten.permute.default(cat_152, [0, 2, 1]); cat_152 = None + _grouped_mm_49 = torch.ops.aten._grouped_mm.default(index_33, permute_261, cumsum_50) + mul_819 = torch.ops.aten.mul.Tensor(convert_element_type_941, _grouped_mm_49); convert_element_type_941 = None + permute_262 = torch.ops.aten.permute.default(cat_151, [0, 2, 1]); cat_151 = None + _grouped_mm_50 = torch.ops.aten._grouped_mm.default(mul_819, permute_262, cumsum_50) + empty_16 = torch.ops.aten.empty.memory_format([sym_size_int_65, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_32 = torch.ops.aten.index_put.default(empty_16, [getitem_1782], _grouped_mm_50); empty_16 = _grouped_mm_50 = None + slice_107 = torch.ops.aten.slice.Tensor(index_put_32, 0, 0, -1); index_put_32 = None + all_to_all_single_50 = torch.ops._c10d_functional.all_to_all_single.default(slice_107, [_local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263], [_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271], '1033'); slice_107 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_50); all_to_all_single_50 = None + convert_element_type_942 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_942, 128, '0'); convert_element_type_942 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_365, [1, 0]); wait_tensor_365 = None + mm_140 = torch.ops.aten.mm.default(view_1130, permute_263); permute_263 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(mm_140, torch.float32) + neg_34 = torch.ops.aten.neg.default(convert_element_type_945) + exp_51 = torch.ops.aten.exp.default(neg_34); neg_34 = None + add_1156 = torch.ops.aten.add.Tensor(exp_51, 1); exp_51 = None + div_85 = torch.ops.aten.div.Tensor(convert_element_type_945, add_1156); convert_element_type_945 = add_1156 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(div_85, torch.bfloat16); div_85 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 128, '0'); convert_element_type_947 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_366, [1, 0]); wait_tensor_366 = None + mm_141 = torch.ops.aten.mm.default(view_1130, permute_264); permute_264 = None + mul_839 = torch.ops.aten.mul.Tensor(convert_element_type_946, mm_141); convert_element_type_946 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(primals_294, torch.bfloat16) + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_950, 128, '0'); convert_element_type_950 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_367, [1, 0]); wait_tensor_367 = None + mm_142 = torch.ops.aten.mm.default(mul_839, permute_265); permute_265 = None + index_put_33 = torch.ops.aten.index_put.default(full_default_1, [getitem_1781], wait_tensor_364); wait_tensor_364 = None + view_1170 = torch.ops.aten.view.default(mul_801, [-1, 1, 6]); mul_801 = None + view_1171 = torch.ops.aten.view.default(index_put_33, [-1, 6, 2048]); index_put_33 = None + convert_element_type_953 = torch.ops.prims.convert_element_type.default(view_1171, torch.float32); view_1171 = None + bmm_16 = torch.ops.aten.bmm.default(view_1170, convert_element_type_953) + convert_element_type_954 = torch.ops.prims.convert_element_type.default(bmm_16, torch.bfloat16); bmm_16 = None + squeeze_16 = torch.ops.aten.squeeze.dim(convert_element_type_954, 1); convert_element_type_954 = None + add_1160 = torch.ops.aten.add.Tensor(mm_142, squeeze_16); mm_142 = squeeze_16 = None + view_1172 = torch.ops.aten.view.default(add_1160, [2, 4096, 2048]); add_1160 = None + add_1161 = torch.ops.aten.add.Tensor(add_1096, view_1172); view_1172 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_295, torch.bfloat16) + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 128, '0'); convert_element_type_955 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + convert_element_type_956 = torch.ops.prims.convert_element_type.default(add_1161, torch.float32) + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_956, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_1162 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_1162); add_1162 = None + mul_842 = torch.ops.aten.mul.Tensor(convert_element_type_956, rsqrt_54); convert_element_type_956 = None + mul_843 = torch.ops.aten.mul.Tensor(mul_842, wait_tensor_368); mul_842 = wait_tensor_368 = None + convert_element_type_957 = torch.ops.prims.convert_element_type.default(mul_843, torch.bfloat16); mul_843 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_296, torch.bfloat16) + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 128, '0'); convert_element_type_958 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_369, [1, 0]); wait_tensor_369 = None + view_1175 = torch.ops.aten.view.default(convert_element_type_957, [8192, 2048]); convert_element_type_957 = None + mm_143 = torch.ops.aten.mm.default(view_1175, permute_266); permute_266 = None + view_1176 = torch.ops.aten.view.default(mm_143, [2, 4096, 3072]); mm_143 = None + view_1177 = torch.ops.aten.view.default(view_1176, [2, 4096, -1, 192]); view_1176 = None + split_with_sizes_54 = torch.ops.aten.split_with_sizes.default(view_1177, [128, 64], -1); view_1177 = None + getitem_1879 = split_with_sizes_54[0] + getitem_1880 = split_with_sizes_54[1]; split_with_sizes_54 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(getitem_1880, torch.float32); getitem_1880 = None + view_1178 = torch.ops.aten.view.default(convert_element_type_961, [2, 4096, 16, -1, 2]); convert_element_type_961 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1178); view_1178 = None + mul_844 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_7); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_844); mul_844 = None + view_1180 = torch.ops.aten.view.default(view_as_real_36, [2, 4096, 16, 64]); view_as_real_36 = None + convert_element_type_962 = torch.ops.prims.convert_element_type.default(view_1180, torch.bfloat16); view_1180 = None + cat_155 = torch.ops.aten.cat.default([getitem_1879, convert_element_type_962], -1); getitem_1879 = convert_element_type_962 = None + convert_element_type_963 = torch.ops.prims.convert_element_type.default(primals_297, torch.bfloat16) + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_963, 128, '0'); convert_element_type_963 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + slice_109 = torch.ops.aten.slice.Tensor(wait_tensor_370, 0, 0, 576); wait_tensor_370 = None + permute_267 = torch.ops.aten.permute.default(slice_109, [1, 0]); slice_109 = None + mm_144 = torch.ops.aten.mm.default(view_1175, permute_267); permute_267 = None + view_1183 = torch.ops.aten.view.default(mm_144, [2, 4096, 576]); mm_144 = None + split_with_sizes_55 = torch.ops.aten.split_with_sizes.default(view_1183, [512, 64], -1); view_1183 = None + getitem_1881 = split_with_sizes_55[0] + getitem_1882 = split_with_sizes_55[1]; split_with_sizes_55 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(getitem_1882, 2); getitem_1882 = None + convert_element_type_966 = torch.ops.prims.convert_element_type.default(unsqueeze_35, torch.float32); unsqueeze_35 = None + view_1184 = torch.ops.aten.view.default(convert_element_type_966, [2, 4096, 1, -1, 2]); convert_element_type_966 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1184); view_1184 = None + mul_845 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_7); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_845); mul_845 = None + view_1186 = torch.ops.aten.view.default(view_as_real_37, [2, 4096, 1, 64]); view_as_real_37 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(view_1186, torch.bfloat16); view_1186 = None + convert_element_type_968 = torch.ops.prims.convert_element_type.default(primals_298, torch.bfloat16) + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_968, 128, '0'); convert_element_type_968 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + convert_element_type_969 = torch.ops.prims.convert_element_type.default(getitem_1881, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_969, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_1163 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_1163); add_1163 = None + mul_846 = torch.ops.aten.mul.Tensor(convert_element_type_969, rsqrt_55); convert_element_type_969 = None + mul_847 = torch.ops.aten.mul.Tensor(mul_846, wait_tensor_371); mul_846 = wait_tensor_371 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(mul_847, torch.bfloat16); mul_847 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(primals_299, torch.bfloat16) + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_971, 128, '0'); convert_element_type_971 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + permute_268 = torch.ops.aten.permute.default(wait_tensor_372, [1, 0]); wait_tensor_372 = None + view_1189 = torch.ops.aten.view.default(convert_element_type_970, [8192, 512]); convert_element_type_970 = None + mm_145 = torch.ops.aten.mm.default(view_1189, permute_268); permute_268 = None + view_1190 = torch.ops.aten.view.default(mm_145, [2, 4096, 4096]); mm_145 = None + view_1191 = torch.ops.aten.view.default(view_1190, [2, 4096, -1, 256]); view_1190 = None + split_with_sizes_56 = torch.ops.aten.split_with_sizes.default(view_1191, [128, 128], -1); view_1191 = None + getitem_1883 = split_with_sizes_56[0] + getitem_1884 = split_with_sizes_56[1]; split_with_sizes_56 = None + expand_18 = torch.ops.aten.expand.default(convert_element_type_967, [-1, -1, 16, -1]); convert_element_type_967 = None + cat_156 = torch.ops.aten.cat.default([getitem_1883, expand_18], -1); getitem_1883 = expand_18 = None + permute_269 = torch.ops.aten.permute.default(cat_155, [0, 2, 1, 3]); cat_155 = None + permute_270 = torch.ops.aten.permute.default(cat_156, [0, 2, 1, 3]); cat_156 = None + permute_271 = torch.ops.aten.permute.default(getitem_1884, [0, 2, 1, 3]); getitem_1884 = None + sdpa_score18 = self.sdpa_score18 + sdpa_mask18 = self.sdpa_mask18 + flex_attention_18 = torch.ops.higher_order.flex_attention(permute_269, permute_270, permute_271, sdpa_score18, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask18), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score18 = sdpa_mask18 = None + getitem_1885 = flex_attention_18[0] + getitem_1886 = flex_attention_18[1]; flex_attention_18 = None + permute_272 = torch.ops.aten.permute.default(getitem_1885, [0, 2, 1, 3]) + view_1192 = torch.ops.aten.view.default(permute_272, [2, 4096, -1]); permute_272 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_300, torch.bfloat16) + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 128, '0'); convert_element_type_974 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_373, [1, 0]); wait_tensor_373 = None + view_1194 = torch.ops.aten.view.default(view_1192, [8192, 2048]); view_1192 = None + mm_146 = torch.ops.aten.mm.default(view_1194, permute_273); view_1194 = permute_273 = None + view_1195 = torch.ops.aten.view.default(mm_146, [2, 4096, 2048]); mm_146 = None + add_1164 = torch.ops.aten.add.Tensor(add_1161, view_1195); view_1195 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_301, torch.bfloat16) + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 128, '0'); convert_element_type_977 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_1164, torch.float32) + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_1165 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_1165); add_1165 = None + mul_848 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_56); convert_element_type_978 = None + mul_849 = torch.ops.aten.mul.Tensor(mul_848, wait_tensor_374); mul_848 = wait_tensor_374 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_849, torch.bfloat16); mul_849 = None + view_1197 = torch.ops.aten.view.default(convert_element_type_979, [-1, 2048]); convert_element_type_979 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_303, torch.bfloat16) + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 128, '0'); convert_element_type_980 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + slice_111 = torch.ops.aten.slice.Tensor(wait_tensor_375, 0, 0, 64); wait_tensor_375 = None + permute_274 = torch.ops.aten.permute.default(slice_111, [1, 0]); slice_111 = None + mm_147 = torch.ops.aten.mm.default(view_1197, permute_274); permute_274 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(mm_147, torch.float32) + amax_17 = torch.ops.aten.amax.default(convert_element_type_983, [1], True) + sub_408 = torch.ops.aten.sub.Tensor(convert_element_type_983, amax_17); convert_element_type_983 = None + exp_52 = torch.ops.aten.exp.default(sub_408); sub_408 = None + sum_69 = torch.ops.aten.sum.dim_IntList(exp_52, [1], True) + div_86 = torch.ops.aten.div.Tensor(exp_52, sum_69); exp_52 = None + add_1166 = torch.ops.aten.add.Tensor(div_86, primals_302); primals_302 = None + topk_17 = torch.ops.aten.topk.default(add_1166, 6, -1, True, False); add_1166 = None + getitem_1889 = topk_17[1]; topk_17 = None + gather_17 = torch.ops.aten.gather.default(div_86, 1, getitem_1889); div_86 = None + mul_850 = torch.ops.aten.mul.Tensor(gather_17, 1.0); gather_17 = None + view_1199 = torch.ops.aten.view.default(getitem_1889, [-1]) + histc_34 = torch.ops.aten.histc.default(view_1199, 64, 0, 64) + add_1167 = torch.ops.aten.add.Tensor(primals_304, histc_34) + sort_17 = torch.ops.aten.sort.stable(view_1199, stable = True); view_1199 = None + getitem_1891 = sort_17[1]; sort_17 = None + div_87 = torch.ops.aten.div.Tensor_mode(getitem_1891, 6, rounding_mode = 'floor') + index_34 = torch.ops.aten.index.Tensor(view_1197, [div_87]) + all_to_all_single_51 = torch.ops._c10d_functional.all_to_all_single.default(histc_34, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_51); all_to_all_single_51 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_376); wait_tensor_376 = None + view_1203 = torch.ops.aten.view.default(histc_34, [8, -1]); histc_34 = None + sum_70 = torch.ops.aten.sum.dim_IntList(view_1203, [1]); view_1203 = None + device_put_34 = torch.ops.prims.device_put.default(sum_70, device(type='cpu'), True); sum_70 = None + view_1204 = torch.ops.aten.view.default(wait_tensor_377, [8, -1]) + sum_71 = torch.ops.aten.sum.dim_IntList(view_1204, [1]) + device_put_35 = torch.ops.prims.device_put.default(sum_71, device(type='cpu')); sum_71 = None + select_272 = torch.ops.aten.select.int(device_put_34, 0, 0) + _local_scalar_dense_272 = torch.ops.aten._local_scalar_dense.default(select_272); select_272 = None + ge_340 = _local_scalar_dense_272 >= 0 + _assert_scalar_272 = torch.ops.aten._assert_scalar.default(ge_340, "Runtime assertion failed for expression u272 >= 0 on node 'ge_272'"); ge_340 = _assert_scalar_272 = None + select_273 = torch.ops.aten.select.int(device_put_34, 0, 1) + _local_scalar_dense_273 = torch.ops.aten._local_scalar_dense.default(select_273); select_273 = None + ge_341 = _local_scalar_dense_273 >= 0 + _assert_scalar_273 = torch.ops.aten._assert_scalar.default(ge_341, "Runtime assertion failed for expression u273 >= 0 on node 'ge_273'"); ge_341 = _assert_scalar_273 = None + select_274 = torch.ops.aten.select.int(device_put_34, 0, 2) + _local_scalar_dense_274 = torch.ops.aten._local_scalar_dense.default(select_274); select_274 = None + ge_342 = _local_scalar_dense_274 >= 0 + _assert_scalar_274 = torch.ops.aten._assert_scalar.default(ge_342, "Runtime assertion failed for expression u274 >= 0 on node 'ge_274'"); ge_342 = _assert_scalar_274 = None + select_275 = torch.ops.aten.select.int(device_put_34, 0, 3) + _local_scalar_dense_275 = torch.ops.aten._local_scalar_dense.default(select_275); select_275 = None + ge_343 = _local_scalar_dense_275 >= 0 + _assert_scalar_275 = torch.ops.aten._assert_scalar.default(ge_343, "Runtime assertion failed for expression u275 >= 0 on node 'ge_275'"); ge_343 = _assert_scalar_275 = None + select_276 = torch.ops.aten.select.int(device_put_34, 0, 4) + _local_scalar_dense_276 = torch.ops.aten._local_scalar_dense.default(select_276); select_276 = None + ge_344 = _local_scalar_dense_276 >= 0 + _assert_scalar_276 = torch.ops.aten._assert_scalar.default(ge_344, "Runtime assertion failed for expression u276 >= 0 on node 'ge_276'"); ge_344 = _assert_scalar_276 = None + select_277 = torch.ops.aten.select.int(device_put_34, 0, 5) + _local_scalar_dense_277 = torch.ops.aten._local_scalar_dense.default(select_277); select_277 = None + ge_345 = _local_scalar_dense_277 >= 0 + _assert_scalar_277 = torch.ops.aten._assert_scalar.default(ge_345, "Runtime assertion failed for expression u277 >= 0 on node 'ge_277'"); ge_345 = _assert_scalar_277 = None + select_278 = torch.ops.aten.select.int(device_put_34, 0, 6) + _local_scalar_dense_278 = torch.ops.aten._local_scalar_dense.default(select_278); select_278 = None + ge_346 = _local_scalar_dense_278 >= 0 + _assert_scalar_278 = torch.ops.aten._assert_scalar.default(ge_346, "Runtime assertion failed for expression u278 >= 0 on node 'ge_278'"); ge_346 = _assert_scalar_278 = None + select_279 = torch.ops.aten.select.int(device_put_34, 0, 7); device_put_34 = None + _local_scalar_dense_279 = torch.ops.aten._local_scalar_dense.default(select_279); select_279 = None + ge_347 = _local_scalar_dense_279 >= 0 + _assert_scalar_279 = torch.ops.aten._assert_scalar.default(ge_347, "Runtime assertion failed for expression u279 >= 0 on node 'ge_279'"); ge_347 = _assert_scalar_279 = None + select_280 = torch.ops.aten.select.int(device_put_35, 0, 0) + _local_scalar_dense_280 = torch.ops.aten._local_scalar_dense.default(select_280); select_280 = None + ge_348 = _local_scalar_dense_280 >= 0 + _assert_scalar_280 = torch.ops.aten._assert_scalar.default(ge_348, "Runtime assertion failed for expression u280 >= 0 on node 'ge_280'"); ge_348 = _assert_scalar_280 = None + select_281 = torch.ops.aten.select.int(device_put_35, 0, 1) + _local_scalar_dense_281 = torch.ops.aten._local_scalar_dense.default(select_281); select_281 = None + ge_349 = _local_scalar_dense_281 >= 0 + _assert_scalar_281 = torch.ops.aten._assert_scalar.default(ge_349, "Runtime assertion failed for expression u281 >= 0 on node 'ge_281'"); ge_349 = _assert_scalar_281 = None + select_282 = torch.ops.aten.select.int(device_put_35, 0, 2) + _local_scalar_dense_282 = torch.ops.aten._local_scalar_dense.default(select_282); select_282 = None + ge_350 = _local_scalar_dense_282 >= 0 + _assert_scalar_282 = torch.ops.aten._assert_scalar.default(ge_350, "Runtime assertion failed for expression u282 >= 0 on node 'ge_282'"); ge_350 = _assert_scalar_282 = None + select_283 = torch.ops.aten.select.int(device_put_35, 0, 3) + _local_scalar_dense_283 = torch.ops.aten._local_scalar_dense.default(select_283); select_283 = None + ge_351 = _local_scalar_dense_283 >= 0 + _assert_scalar_283 = torch.ops.aten._assert_scalar.default(ge_351, "Runtime assertion failed for expression u283 >= 0 on node 'ge_283'"); ge_351 = _assert_scalar_283 = None + select_284 = torch.ops.aten.select.int(device_put_35, 0, 4) + _local_scalar_dense_284 = torch.ops.aten._local_scalar_dense.default(select_284); select_284 = None + ge_352 = _local_scalar_dense_284 >= 0 + _assert_scalar_284 = torch.ops.aten._assert_scalar.default(ge_352, "Runtime assertion failed for expression u284 >= 0 on node 'ge_284'"); ge_352 = _assert_scalar_284 = None + select_285 = torch.ops.aten.select.int(device_put_35, 0, 5) + _local_scalar_dense_285 = torch.ops.aten._local_scalar_dense.default(select_285); select_285 = None + ge_353 = _local_scalar_dense_285 >= 0 + _assert_scalar_285 = torch.ops.aten._assert_scalar.default(ge_353, "Runtime assertion failed for expression u285 >= 0 on node 'ge_285'"); ge_353 = _assert_scalar_285 = None + select_286 = torch.ops.aten.select.int(device_put_35, 0, 6) + _local_scalar_dense_286 = torch.ops.aten._local_scalar_dense.default(select_286); select_286 = None + ge_354 = _local_scalar_dense_286 >= 0 + _assert_scalar_286 = torch.ops.aten._assert_scalar.default(ge_354, "Runtime assertion failed for expression u286 >= 0 on node 'ge_286'"); ge_354 = _assert_scalar_286 = None + select_287 = torch.ops.aten.select.int(device_put_35, 0, 7); device_put_35 = None + _local_scalar_dense_287 = torch.ops.aten._local_scalar_dense.default(select_287); select_287 = None + ge_355 = _local_scalar_dense_287 >= 0 + _assert_scalar_287 = torch.ops.aten._assert_scalar.default(ge_355, "Runtime assertion failed for expression u287 >= 0 on node 'ge_287'"); ge_355 = _assert_scalar_287 = None + all_to_all_single_52 = torch.ops._c10d_functional.all_to_all_single.default(index_34, [_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287], [_local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279], '1033'); index_34 = None + sym_size_int_68 = torch.ops.aten.sym_size.int(all_to_all_single_52, 0) + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_52); all_to_all_single_52 = None + sym_sum_34 = torch.sym_sum((_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287)) + add_1174 = sym_sum_34 + 64; sym_sum_34 = None + add_1175 = add_1174 + 8; add_1174 = None + sub_411 = add_1175 - 1; add_1175 = None + floordiv_17 = sub_411 // 8; sub_411 = None + mul_855 = floordiv_17 * 8; floordiv_17 = None + cumsum_51 = torch.ops.aten.cumsum.default(wait_tensor_377, 0) + sub_412 = torch.ops.aten.sub.Tensor(cumsum_51, wait_tensor_377); cumsum_51 = None + sum_72 = torch.ops.aten.sum.dim_IntList(view_1204, [0]); view_1204 = None + clamp_min_17 = torch.ops.aten.clamp_min.default(sum_72, 8); sum_72 = None + add_1176 = torch.ops.aten.add.Tensor(clamp_min_17, 8); clamp_min_17 = None + sub_413 = torch.ops.aten.sub.Tensor(add_1176, 1); add_1176 = None + div_88 = torch.ops.aten.div.Tensor_mode(sub_413, 8, rounding_mode = 'floor'); sub_413 = None + mul_856 = torch.ops.aten.mul.Tensor(div_88, 8); div_88 = None + convert_element_type_986 = torch.ops.prims.convert_element_type.default(mul_856, torch.int32); mul_856 = None + cumsum_52 = torch.ops.aten.cumsum.default(convert_element_type_986, 0) + sub_414 = torch.ops.aten.sub.Tensor(cumsum_52, convert_element_type_986); cumsum_52 = None + full_241 = torch.ops.aten.full.default([mul_855], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_855 = None + triton_kernel_wrapper_functional_proxy_17 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 17, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_377, 'start_index_values_ptr': sub_412, 'write_offsets_ptr': sub_414, 'output_ptr': full_241}, tensors_to_clone = ['output_ptr']); wait_tensor_377 = sub_412 = sub_414 = full_241 = None + getitem_1892 = triton_kernel_wrapper_functional_proxy_17['output_ptr']; triton_kernel_wrapper_functional_proxy_17 = None + cat_157 = torch.ops.aten.cat.default([wait_tensor_378, full_default]); wait_tensor_378 = None + sym_size_int_69 = torch.ops.aten.sym_size.int(cat_157, 0) + sym_sum_35 = torch.sym_sum((1, _local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287)) + index_35 = torch.ops.aten.index.Tensor(cat_157, [getitem_1892]); cat_157 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_305, torch.bfloat16) + all_gather_into_tensor_308 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 16, '1025'); convert_element_type_988 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_308); all_gather_into_tensor_308 = None + split_103 = torch.ops.aten.split.Tensor(wait_tensor_379, 8); wait_tensor_379 = None + getitem_1909 = split_103[0] + getitem_1910 = split_103[1] + getitem_1911 = split_103[2] + getitem_1912 = split_103[3] + getitem_1913 = split_103[4] + getitem_1914 = split_103[5] + getitem_1915 = split_103[6] + getitem_1916 = split_103[7] + getitem_1917 = split_103[8] + getitem_1918 = split_103[9] + getitem_1919 = split_103[10] + getitem_1920 = split_103[11] + getitem_1921 = split_103[12] + getitem_1922 = split_103[13] + getitem_1923 = split_103[14] + getitem_1924 = split_103[15]; split_103 = None + cat_159 = torch.ops.aten.cat.default([getitem_1909, getitem_1910, getitem_1911, getitem_1912, getitem_1913, getitem_1914, getitem_1915, getitem_1916, getitem_1917, getitem_1918, getitem_1919, getitem_1920, getitem_1921, getitem_1922, getitem_1923, getitem_1924], 1); getitem_1909 = getitem_1910 = getitem_1911 = getitem_1912 = getitem_1913 = getitem_1914 = getitem_1915 = getitem_1916 = getitem_1917 = getitem_1918 = getitem_1919 = getitem_1920 = getitem_1921 = getitem_1922 = getitem_1923 = getitem_1924 = None + convert_element_type_990 = torch.ops.prims.convert_element_type.default(primals_306, torch.bfloat16) + all_gather_into_tensor_310 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_990, 16, '1025'); convert_element_type_990 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_310); all_gather_into_tensor_310 = None + split_104 = torch.ops.aten.split.Tensor(wait_tensor_381, 8); wait_tensor_381 = None + getitem_1925 = split_104[0] + getitem_1926 = split_104[1] + getitem_1927 = split_104[2] + getitem_1928 = split_104[3] + getitem_1929 = split_104[4] + getitem_1930 = split_104[5] + getitem_1931 = split_104[6] + getitem_1932 = split_104[7] + getitem_1933 = split_104[8] + getitem_1934 = split_104[9] + getitem_1935 = split_104[10] + getitem_1936 = split_104[11] + getitem_1937 = split_104[12] + getitem_1938 = split_104[13] + getitem_1939 = split_104[14] + getitem_1940 = split_104[15]; split_104 = None + cat_160 = torch.ops.aten.cat.default([getitem_1925, getitem_1926, getitem_1927, getitem_1928, getitem_1929, getitem_1930, getitem_1931, getitem_1932, getitem_1933, getitem_1934, getitem_1935, getitem_1936, getitem_1937, getitem_1938, getitem_1939, getitem_1940], 1); getitem_1925 = getitem_1926 = getitem_1927 = getitem_1928 = getitem_1929 = getitem_1930 = getitem_1931 = getitem_1932 = getitem_1933 = getitem_1934 = getitem_1935 = getitem_1936 = getitem_1937 = getitem_1938 = getitem_1939 = getitem_1940 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_307, torch.bfloat16) + all_gather_into_tensor_311 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 16, '1025'); convert_element_type_991 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_311); all_gather_into_tensor_311 = None + split_105 = torch.ops.aten.split.Tensor(wait_tensor_382, 8); wait_tensor_382 = None + getitem_1941 = split_105[0] + getitem_1942 = split_105[1] + getitem_1943 = split_105[2] + getitem_1944 = split_105[3] + getitem_1945 = split_105[4] + getitem_1946 = split_105[5] + getitem_1947 = split_105[6] + getitem_1948 = split_105[7] + getitem_1949 = split_105[8] + getitem_1950 = split_105[9] + getitem_1951 = split_105[10] + getitem_1952 = split_105[11] + getitem_1953 = split_105[12] + getitem_1954 = split_105[13] + getitem_1955 = split_105[14] + getitem_1956 = split_105[15]; split_105 = None + cat_161 = torch.ops.aten.cat.default([getitem_1941, getitem_1942, getitem_1943, getitem_1944, getitem_1945, getitem_1946, getitem_1947, getitem_1948, getitem_1949, getitem_1950, getitem_1951, getitem_1952, getitem_1953, getitem_1954, getitem_1955, getitem_1956], 1); getitem_1941 = getitem_1942 = getitem_1943 = getitem_1944 = getitem_1945 = getitem_1946 = getitem_1947 = getitem_1948 = getitem_1949 = getitem_1950 = getitem_1951 = getitem_1952 = getitem_1953 = getitem_1954 = getitem_1955 = getitem_1956 = None + cumsum_53 = torch.ops.aten.cumsum.default(convert_element_type_986, 0, dtype = torch.int32); convert_element_type_986 = None + permute_275 = torch.ops.aten.permute.default(cat_159, [0, 2, 1]); cat_159 = None + _grouped_mm_51 = torch.ops.aten._grouped_mm.default(index_35, permute_275, cumsum_53) + convert_element_type_994 = torch.ops.prims.convert_element_type.default(_grouped_mm_51, torch.float32) + neg_35 = torch.ops.aten.neg.default(convert_element_type_994) + exp_53 = torch.ops.aten.exp.default(neg_35); neg_35 = None + add_1188 = torch.ops.aten.add.Tensor(exp_53, 1); exp_53 = None + div_89 = torch.ops.aten.div.Tensor(convert_element_type_994, add_1188); convert_element_type_994 = add_1188 = None + convert_element_type_995 = torch.ops.prims.convert_element_type.default(div_89, torch.bfloat16); div_89 = None + permute_276 = torch.ops.aten.permute.default(cat_161, [0, 2, 1]); cat_161 = None + _grouped_mm_52 = torch.ops.aten._grouped_mm.default(index_35, permute_276, cumsum_53) + mul_868 = torch.ops.aten.mul.Tensor(convert_element_type_995, _grouped_mm_52); convert_element_type_995 = None + permute_277 = torch.ops.aten.permute.default(cat_160, [0, 2, 1]); cat_160 = None + _grouped_mm_53 = torch.ops.aten._grouped_mm.default(mul_868, permute_277, cumsum_53) + empty_17 = torch.ops.aten.empty.memory_format([sym_size_int_69, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_34 = torch.ops.aten.index_put.default(empty_17, [getitem_1892], _grouped_mm_53); empty_17 = _grouped_mm_53 = None + slice_113 = torch.ops.aten.slice.Tensor(index_put_34, 0, 0, -1); index_put_34 = None + all_to_all_single_53 = torch.ops._c10d_functional.all_to_all_single.default(slice_113, [_local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279], [_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287], '1033'); slice_113 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_53); all_to_all_single_53 = None + convert_element_type_996 = torch.ops.prims.convert_element_type.default(primals_308, torch.bfloat16) + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_996, 128, '0'); convert_element_type_996 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_278 = torch.ops.aten.permute.default(wait_tensor_386, [1, 0]); wait_tensor_386 = None + mm_148 = torch.ops.aten.mm.default(view_1197, permute_278); permute_278 = None + convert_element_type_999 = torch.ops.prims.convert_element_type.default(mm_148, torch.float32) + neg_36 = torch.ops.aten.neg.default(convert_element_type_999) + exp_54 = torch.ops.aten.exp.default(neg_36); neg_36 = None + add_1224 = torch.ops.aten.add.Tensor(exp_54, 1); exp_54 = None + div_90 = torch.ops.aten.div.Tensor(convert_element_type_999, add_1224); convert_element_type_999 = add_1224 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(div_90, torch.bfloat16); div_90 = None + convert_element_type_1001 = torch.ops.prims.convert_element_type.default(primals_309, torch.bfloat16) + all_gather_into_tensor_315 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1001, 128, '0'); convert_element_type_1001 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + permute_279 = torch.ops.aten.permute.default(wait_tensor_387, [1, 0]); wait_tensor_387 = None + mm_149 = torch.ops.aten.mm.default(view_1197, permute_279); permute_279 = None + mul_888 = torch.ops.aten.mul.Tensor(convert_element_type_1000, mm_149); convert_element_type_1000 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(primals_310, torch.bfloat16) + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1004, 128, '0'); convert_element_type_1004 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + permute_280 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + mm_150 = torch.ops.aten.mm.default(mul_888, permute_280); permute_280 = None + index_put_35 = torch.ops.aten.index_put.default(full_default_1, [getitem_1891], wait_tensor_385); wait_tensor_385 = None + view_1237 = torch.ops.aten.view.default(mul_850, [-1, 1, 6]); mul_850 = None + view_1238 = torch.ops.aten.view.default(index_put_35, [-1, 6, 2048]); index_put_35 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(view_1238, torch.float32); view_1238 = None + bmm_17 = torch.ops.aten.bmm.default(view_1237, convert_element_type_1007) + convert_element_type_1008 = torch.ops.prims.convert_element_type.default(bmm_17, torch.bfloat16); bmm_17 = None + squeeze_17 = torch.ops.aten.squeeze.dim(convert_element_type_1008, 1); convert_element_type_1008 = None + add_1228 = torch.ops.aten.add.Tensor(mm_150, squeeze_17); mm_150 = squeeze_17 = None + view_1239 = torch.ops.aten.view.default(add_1228, [2, 4096, 2048]); add_1228 = None + add_1229 = torch.ops.aten.add.Tensor(add_1164, view_1239); view_1239 = None + convert_element_type_1009 = torch.ops.prims.convert_element_type.default(primals_311, torch.bfloat16) + all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1009, 128, '0'); convert_element_type_1009 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(add_1229, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1010, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_1230 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_1230); add_1230 = None + mul_891 = torch.ops.aten.mul.Tensor(convert_element_type_1010, rsqrt_57); convert_element_type_1010 = None + mul_892 = torch.ops.aten.mul.Tensor(mul_891, wait_tensor_389); mul_891 = wait_tensor_389 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(mul_892, torch.bfloat16); mul_892 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(primals_312, torch.bfloat16) + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 128, '0'); convert_element_type_1012 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_281 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + view_1242 = torch.ops.aten.view.default(convert_element_type_1011, [8192, 2048]); convert_element_type_1011 = None + mm_151 = torch.ops.aten.mm.default(view_1242, permute_281); permute_281 = None + view_1243 = torch.ops.aten.view.default(mm_151, [2, 4096, 3072]); mm_151 = None + view_1244 = torch.ops.aten.view.default(view_1243, [2, 4096, -1, 192]); view_1243 = None + split_with_sizes_57 = torch.ops.aten.split_with_sizes.default(view_1244, [128, 64], -1); view_1244 = None + getitem_1989 = split_with_sizes_57[0] + getitem_1990 = split_with_sizes_57[1]; split_with_sizes_57 = None + convert_element_type_1015 = torch.ops.prims.convert_element_type.default(getitem_1990, torch.float32); getitem_1990 = None + view_1245 = torch.ops.aten.view.default(convert_element_type_1015, [2, 4096, 16, -1, 2]); convert_element_type_1015 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1245); view_1245 = None + mul_893 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_7); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_893); mul_893 = None + view_1247 = torch.ops.aten.view.default(view_as_real_38, [2, 4096, 16, 64]); view_as_real_38 = None + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1247, torch.bfloat16); view_1247 = None + cat_164 = torch.ops.aten.cat.default([getitem_1989, convert_element_type_1016], -1); getitem_1989 = convert_element_type_1016 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(primals_313, torch.bfloat16) + all_gather_into_tensor_319 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1017, 128, '0'); convert_element_type_1017 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + slice_115 = torch.ops.aten.slice.Tensor(wait_tensor_391, 0, 0, 576); wait_tensor_391 = None + permute_282 = torch.ops.aten.permute.default(slice_115, [1, 0]); slice_115 = None + mm_152 = torch.ops.aten.mm.default(view_1242, permute_282); permute_282 = None + view_1250 = torch.ops.aten.view.default(mm_152, [2, 4096, 576]); mm_152 = None + split_with_sizes_58 = torch.ops.aten.split_with_sizes.default(view_1250, [512, 64], -1); view_1250 = None + getitem_1991 = split_with_sizes_58[0] + getitem_1992 = split_with_sizes_58[1]; split_with_sizes_58 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(getitem_1992, 2); getitem_1992 = None + convert_element_type_1020 = torch.ops.prims.convert_element_type.default(unsqueeze_37, torch.float32); unsqueeze_37 = None + view_1251 = torch.ops.aten.view.default(convert_element_type_1020, [2, 4096, 1, -1, 2]); convert_element_type_1020 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1251); view_1251 = None + mul_894 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_7); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_894); mul_894 = None + view_1253 = torch.ops.aten.view.default(view_as_real_39, [2, 4096, 1, 64]); view_as_real_39 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(view_1253, torch.bfloat16); view_1253 = None + convert_element_type_1022 = torch.ops.prims.convert_element_type.default(primals_314, torch.bfloat16) + all_gather_into_tensor_320 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1022, 128, '0'); convert_element_type_1022 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_1023 = torch.ops.prims.convert_element_type.default(getitem_1991, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1023, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_1231 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_1231); add_1231 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_1023, rsqrt_58); convert_element_type_1023 = None + mul_896 = torch.ops.aten.mul.Tensor(mul_895, wait_tensor_392); mul_895 = wait_tensor_392 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(mul_896, torch.bfloat16); mul_896 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(primals_315, torch.bfloat16) + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1025, 128, '0'); convert_element_type_1025 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_393, [1, 0]); wait_tensor_393 = None + view_1256 = torch.ops.aten.view.default(convert_element_type_1024, [8192, 512]); convert_element_type_1024 = None + mm_153 = torch.ops.aten.mm.default(view_1256, permute_283); permute_283 = None + view_1257 = torch.ops.aten.view.default(mm_153, [2, 4096, 4096]); mm_153 = None + view_1258 = torch.ops.aten.view.default(view_1257, [2, 4096, -1, 256]); view_1257 = None + split_with_sizes_59 = torch.ops.aten.split_with_sizes.default(view_1258, [128, 128], -1); view_1258 = None + getitem_1993 = split_with_sizes_59[0] + getitem_1994 = split_with_sizes_59[1]; split_with_sizes_59 = None + expand_19 = torch.ops.aten.expand.default(convert_element_type_1021, [-1, -1, 16, -1]); convert_element_type_1021 = None + cat_165 = torch.ops.aten.cat.default([getitem_1993, expand_19], -1); getitem_1993 = expand_19 = None + permute_284 = torch.ops.aten.permute.default(cat_164, [0, 2, 1, 3]); cat_164 = None + permute_285 = torch.ops.aten.permute.default(cat_165, [0, 2, 1, 3]); cat_165 = None + permute_286 = torch.ops.aten.permute.default(getitem_1994, [0, 2, 1, 3]); getitem_1994 = None + sdpa_score19 = self.sdpa_score19 + sdpa_mask19 = self.sdpa_mask19 + flex_attention_19 = torch.ops.higher_order.flex_attention(permute_284, permute_285, permute_286, sdpa_score19, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask19), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score19 = sdpa_mask19 = None + getitem_1995 = flex_attention_19[0] + getitem_1996 = flex_attention_19[1]; flex_attention_19 = None + permute_287 = torch.ops.aten.permute.default(getitem_1995, [0, 2, 1, 3]) + view_1259 = torch.ops.aten.view.default(permute_287, [2, 4096, -1]); permute_287 = None + convert_element_type_1028 = torch.ops.prims.convert_element_type.default(primals_316, torch.bfloat16) + all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1028, 128, '0'); convert_element_type_1028 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + view_1261 = torch.ops.aten.view.default(view_1259, [8192, 2048]); view_1259 = None + mm_154 = torch.ops.aten.mm.default(view_1261, permute_288); view_1261 = permute_288 = None + view_1262 = torch.ops.aten.view.default(mm_154, [2, 4096, 2048]); mm_154 = None + add_1232 = torch.ops.aten.add.Tensor(add_1229, view_1262); view_1262 = None + convert_element_type_1031 = torch.ops.prims.convert_element_type.default(primals_317, torch.bfloat16) + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1031, 128, '0'); convert_element_type_1031 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + convert_element_type_1032 = torch.ops.prims.convert_element_type.default(add_1232, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1032, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_1233 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_1233); add_1233 = None + mul_897 = torch.ops.aten.mul.Tensor(convert_element_type_1032, rsqrt_59); convert_element_type_1032 = None + mul_898 = torch.ops.aten.mul.Tensor(mul_897, wait_tensor_395); mul_897 = wait_tensor_395 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(mul_898, torch.bfloat16); mul_898 = None + view_1264 = torch.ops.aten.view.default(convert_element_type_1033, [-1, 2048]); convert_element_type_1033 = None + convert_element_type_1034 = torch.ops.prims.convert_element_type.default(primals_319, torch.bfloat16) + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1034, 128, '0'); convert_element_type_1034 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + slice_117 = torch.ops.aten.slice.Tensor(wait_tensor_396, 0, 0, 64); wait_tensor_396 = None + permute_289 = torch.ops.aten.permute.default(slice_117, [1, 0]); slice_117 = None + mm_155 = torch.ops.aten.mm.default(view_1264, permute_289); permute_289 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(mm_155, torch.float32) + amax_18 = torch.ops.aten.amax.default(convert_element_type_1037, [1], True) + sub_432 = torch.ops.aten.sub.Tensor(convert_element_type_1037, amax_18); convert_element_type_1037 = None + exp_55 = torch.ops.aten.exp.default(sub_432); sub_432 = None + sum_73 = torch.ops.aten.sum.dim_IntList(exp_55, [1], True) + div_91 = torch.ops.aten.div.Tensor(exp_55, sum_73); exp_55 = None + add_1234 = torch.ops.aten.add.Tensor(div_91, primals_318); primals_318 = None + topk_18 = torch.ops.aten.topk.default(add_1234, 6, -1, True, False); add_1234 = None + getitem_1999 = topk_18[1]; topk_18 = None + gather_18 = torch.ops.aten.gather.default(div_91, 1, getitem_1999); div_91 = None + mul_899 = torch.ops.aten.mul.Tensor(gather_18, 1.0); gather_18 = None + view_1266 = torch.ops.aten.view.default(getitem_1999, [-1]) + histc_36 = torch.ops.aten.histc.default(view_1266, 64, 0, 64) + add_1235 = torch.ops.aten.add.Tensor(primals_320, histc_36) + sort_18 = torch.ops.aten.sort.stable(view_1266, stable = True); view_1266 = None + getitem_2001 = sort_18[1]; sort_18 = None + div_92 = torch.ops.aten.div.Tensor_mode(getitem_2001, 6, rounding_mode = 'floor') + index_36 = torch.ops.aten.index.Tensor(view_1264, [div_92]) + all_to_all_single_54 = torch.ops._c10d_functional.all_to_all_single.default(histc_36, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_54); all_to_all_single_54 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_397); wait_tensor_397 = None + view_1270 = torch.ops.aten.view.default(histc_36, [8, -1]); histc_36 = None + sum_74 = torch.ops.aten.sum.dim_IntList(view_1270, [1]); view_1270 = None + device_put_36 = torch.ops.prims.device_put.default(sum_74, device(type='cpu'), True); sum_74 = None + view_1271 = torch.ops.aten.view.default(wait_tensor_398, [8, -1]) + sum_75 = torch.ops.aten.sum.dim_IntList(view_1271, [1]) + device_put_37 = torch.ops.prims.device_put.default(sum_75, device(type='cpu')); sum_75 = None + select_288 = torch.ops.aten.select.int(device_put_36, 0, 0) + _local_scalar_dense_288 = torch.ops.aten._local_scalar_dense.default(select_288); select_288 = None + ge_360 = _local_scalar_dense_288 >= 0 + _assert_scalar_288 = torch.ops.aten._assert_scalar.default(ge_360, "Runtime assertion failed for expression u288 >= 0 on node 'ge_288'"); ge_360 = _assert_scalar_288 = None + select_289 = torch.ops.aten.select.int(device_put_36, 0, 1) + _local_scalar_dense_289 = torch.ops.aten._local_scalar_dense.default(select_289); select_289 = None + ge_361 = _local_scalar_dense_289 >= 0 + _assert_scalar_289 = torch.ops.aten._assert_scalar.default(ge_361, "Runtime assertion failed for expression u289 >= 0 on node 'ge_289'"); ge_361 = _assert_scalar_289 = None + select_290 = torch.ops.aten.select.int(device_put_36, 0, 2) + _local_scalar_dense_290 = torch.ops.aten._local_scalar_dense.default(select_290); select_290 = None + ge_362 = _local_scalar_dense_290 >= 0 + _assert_scalar_290 = torch.ops.aten._assert_scalar.default(ge_362, "Runtime assertion failed for expression u290 >= 0 on node 'ge_290'"); ge_362 = _assert_scalar_290 = None + select_291 = torch.ops.aten.select.int(device_put_36, 0, 3) + _local_scalar_dense_291 = torch.ops.aten._local_scalar_dense.default(select_291); select_291 = None + ge_363 = _local_scalar_dense_291 >= 0 + _assert_scalar_291 = torch.ops.aten._assert_scalar.default(ge_363, "Runtime assertion failed for expression u291 >= 0 on node 'ge_291'"); ge_363 = _assert_scalar_291 = None + select_292 = torch.ops.aten.select.int(device_put_36, 0, 4) + _local_scalar_dense_292 = torch.ops.aten._local_scalar_dense.default(select_292); select_292 = None + ge_364 = _local_scalar_dense_292 >= 0 + _assert_scalar_292 = torch.ops.aten._assert_scalar.default(ge_364, "Runtime assertion failed for expression u292 >= 0 on node 'ge_292'"); ge_364 = _assert_scalar_292 = None + select_293 = torch.ops.aten.select.int(device_put_36, 0, 5) + _local_scalar_dense_293 = torch.ops.aten._local_scalar_dense.default(select_293); select_293 = None + ge_365 = _local_scalar_dense_293 >= 0 + _assert_scalar_293 = torch.ops.aten._assert_scalar.default(ge_365, "Runtime assertion failed for expression u293 >= 0 on node 'ge_293'"); ge_365 = _assert_scalar_293 = None + select_294 = torch.ops.aten.select.int(device_put_36, 0, 6) + _local_scalar_dense_294 = torch.ops.aten._local_scalar_dense.default(select_294); select_294 = None + ge_366 = _local_scalar_dense_294 >= 0 + _assert_scalar_294 = torch.ops.aten._assert_scalar.default(ge_366, "Runtime assertion failed for expression u294 >= 0 on node 'ge_294'"); ge_366 = _assert_scalar_294 = None + select_295 = torch.ops.aten.select.int(device_put_36, 0, 7); device_put_36 = None + _local_scalar_dense_295 = torch.ops.aten._local_scalar_dense.default(select_295); select_295 = None + ge_367 = _local_scalar_dense_295 >= 0 + _assert_scalar_295 = torch.ops.aten._assert_scalar.default(ge_367, "Runtime assertion failed for expression u295 >= 0 on node 'ge_295'"); ge_367 = _assert_scalar_295 = None + select_296 = torch.ops.aten.select.int(device_put_37, 0, 0) + _local_scalar_dense_296 = torch.ops.aten._local_scalar_dense.default(select_296); select_296 = None + ge_368 = _local_scalar_dense_296 >= 0 + _assert_scalar_296 = torch.ops.aten._assert_scalar.default(ge_368, "Runtime assertion failed for expression u296 >= 0 on node 'ge_296'"); ge_368 = _assert_scalar_296 = None + select_297 = torch.ops.aten.select.int(device_put_37, 0, 1) + _local_scalar_dense_297 = torch.ops.aten._local_scalar_dense.default(select_297); select_297 = None + ge_369 = _local_scalar_dense_297 >= 0 + _assert_scalar_297 = torch.ops.aten._assert_scalar.default(ge_369, "Runtime assertion failed for expression u297 >= 0 on node 'ge_297'"); ge_369 = _assert_scalar_297 = None + select_298 = torch.ops.aten.select.int(device_put_37, 0, 2) + _local_scalar_dense_298 = torch.ops.aten._local_scalar_dense.default(select_298); select_298 = None + ge_370 = _local_scalar_dense_298 >= 0 + _assert_scalar_298 = torch.ops.aten._assert_scalar.default(ge_370, "Runtime assertion failed for expression u298 >= 0 on node 'ge_298'"); ge_370 = _assert_scalar_298 = None + select_299 = torch.ops.aten.select.int(device_put_37, 0, 3) + _local_scalar_dense_299 = torch.ops.aten._local_scalar_dense.default(select_299); select_299 = None + ge_371 = _local_scalar_dense_299 >= 0 + _assert_scalar_299 = torch.ops.aten._assert_scalar.default(ge_371, "Runtime assertion failed for expression u299 >= 0 on node 'ge_299'"); ge_371 = _assert_scalar_299 = None + select_300 = torch.ops.aten.select.int(device_put_37, 0, 4) + _local_scalar_dense_300 = torch.ops.aten._local_scalar_dense.default(select_300); select_300 = None + ge_372 = _local_scalar_dense_300 >= 0 + _assert_scalar_300 = torch.ops.aten._assert_scalar.default(ge_372, "Runtime assertion failed for expression u300 >= 0 on node 'ge_300'"); ge_372 = _assert_scalar_300 = None + select_301 = torch.ops.aten.select.int(device_put_37, 0, 5) + _local_scalar_dense_301 = torch.ops.aten._local_scalar_dense.default(select_301); select_301 = None + ge_373 = _local_scalar_dense_301 >= 0 + _assert_scalar_301 = torch.ops.aten._assert_scalar.default(ge_373, "Runtime assertion failed for expression u301 >= 0 on node 'ge_301'"); ge_373 = _assert_scalar_301 = None + select_302 = torch.ops.aten.select.int(device_put_37, 0, 6) + _local_scalar_dense_302 = torch.ops.aten._local_scalar_dense.default(select_302); select_302 = None + ge_374 = _local_scalar_dense_302 >= 0 + _assert_scalar_302 = torch.ops.aten._assert_scalar.default(ge_374, "Runtime assertion failed for expression u302 >= 0 on node 'ge_302'"); ge_374 = _assert_scalar_302 = None + select_303 = torch.ops.aten.select.int(device_put_37, 0, 7); device_put_37 = None + _local_scalar_dense_303 = torch.ops.aten._local_scalar_dense.default(select_303); select_303 = None + ge_375 = _local_scalar_dense_303 >= 0 + _assert_scalar_303 = torch.ops.aten._assert_scalar.default(ge_375, "Runtime assertion failed for expression u303 >= 0 on node 'ge_303'"); ge_375 = _assert_scalar_303 = None + all_to_all_single_55 = torch.ops._c10d_functional.all_to_all_single.default(index_36, [_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303], [_local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295], '1033'); index_36 = None + sym_size_int_72 = torch.ops.aten.sym_size.int(all_to_all_single_55, 0) + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_55); all_to_all_single_55 = None + sym_sum_36 = torch.sym_sum((_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303)) + add_1242 = sym_sum_36 + 64; sym_sum_36 = None + add_1243 = add_1242 + 8; add_1242 = None + sub_435 = add_1243 - 1; add_1243 = None + floordiv_18 = sub_435 // 8; sub_435 = None + mul_904 = floordiv_18 * 8; floordiv_18 = None + cumsum_54 = torch.ops.aten.cumsum.default(wait_tensor_398, 0) + sub_436 = torch.ops.aten.sub.Tensor(cumsum_54, wait_tensor_398); cumsum_54 = None + sum_76 = torch.ops.aten.sum.dim_IntList(view_1271, [0]); view_1271 = None + clamp_min_18 = torch.ops.aten.clamp_min.default(sum_76, 8); sum_76 = None + add_1244 = torch.ops.aten.add.Tensor(clamp_min_18, 8); clamp_min_18 = None + sub_437 = torch.ops.aten.sub.Tensor(add_1244, 1); add_1244 = None + div_93 = torch.ops.aten.div.Tensor_mode(sub_437, 8, rounding_mode = 'floor'); sub_437 = None + mul_905 = torch.ops.aten.mul.Tensor(div_93, 8); div_93 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(mul_905, torch.int32); mul_905 = None + cumsum_55 = torch.ops.aten.cumsum.default(convert_element_type_1040, 0) + sub_438 = torch.ops.aten.sub.Tensor(cumsum_55, convert_element_type_1040); cumsum_55 = None + full_254 = torch.ops.aten.full.default([mul_904], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_904 = None + triton_kernel_wrapper_functional_proxy_18 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 18, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_398, 'start_index_values_ptr': sub_436, 'write_offsets_ptr': sub_438, 'output_ptr': full_254}, tensors_to_clone = ['output_ptr']); wait_tensor_398 = sub_436 = sub_438 = full_254 = None + getitem_2002 = triton_kernel_wrapper_functional_proxy_18['output_ptr']; triton_kernel_wrapper_functional_proxy_18 = None + cat_166 = torch.ops.aten.cat.default([wait_tensor_399, full_default]); wait_tensor_399 = None + sym_size_int_73 = torch.ops.aten.sym_size.int(cat_166, 0) + sym_sum_37 = torch.sym_sum((1, _local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303)) + index_37 = torch.ops.aten.index.Tensor(cat_166, [getitem_2002]); cat_166 = None + convert_element_type_1042 = torch.ops.prims.convert_element_type.default(primals_321, torch.bfloat16) + all_gather_into_tensor_325 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1042, 16, '1025'); convert_element_type_1042 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_325); all_gather_into_tensor_325 = None + split_109 = torch.ops.aten.split.Tensor(wait_tensor_400, 8); wait_tensor_400 = None + getitem_2019 = split_109[0] + getitem_2020 = split_109[1] + getitem_2021 = split_109[2] + getitem_2022 = split_109[3] + getitem_2023 = split_109[4] + getitem_2024 = split_109[5] + getitem_2025 = split_109[6] + getitem_2026 = split_109[7] + getitem_2027 = split_109[8] + getitem_2028 = split_109[9] + getitem_2029 = split_109[10] + getitem_2030 = split_109[11] + getitem_2031 = split_109[12] + getitem_2032 = split_109[13] + getitem_2033 = split_109[14] + getitem_2034 = split_109[15]; split_109 = None + cat_168 = torch.ops.aten.cat.default([getitem_2019, getitem_2020, getitem_2021, getitem_2022, getitem_2023, getitem_2024, getitem_2025, getitem_2026, getitem_2027, getitem_2028, getitem_2029, getitem_2030, getitem_2031, getitem_2032, getitem_2033, getitem_2034], 1); getitem_2019 = getitem_2020 = getitem_2021 = getitem_2022 = getitem_2023 = getitem_2024 = getitem_2025 = getitem_2026 = getitem_2027 = getitem_2028 = getitem_2029 = getitem_2030 = getitem_2031 = getitem_2032 = getitem_2033 = getitem_2034 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(primals_322, torch.bfloat16) + all_gather_into_tensor_327 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1044, 16, '1025'); convert_element_type_1044 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_327); all_gather_into_tensor_327 = None + split_110 = torch.ops.aten.split.Tensor(wait_tensor_402, 8); wait_tensor_402 = None + getitem_2035 = split_110[0] + getitem_2036 = split_110[1] + getitem_2037 = split_110[2] + getitem_2038 = split_110[3] + getitem_2039 = split_110[4] + getitem_2040 = split_110[5] + getitem_2041 = split_110[6] + getitem_2042 = split_110[7] + getitem_2043 = split_110[8] + getitem_2044 = split_110[9] + getitem_2045 = split_110[10] + getitem_2046 = split_110[11] + getitem_2047 = split_110[12] + getitem_2048 = split_110[13] + getitem_2049 = split_110[14] + getitem_2050 = split_110[15]; split_110 = None + cat_169 = torch.ops.aten.cat.default([getitem_2035, getitem_2036, getitem_2037, getitem_2038, getitem_2039, getitem_2040, getitem_2041, getitem_2042, getitem_2043, getitem_2044, getitem_2045, getitem_2046, getitem_2047, getitem_2048, getitem_2049, getitem_2050], 1); getitem_2035 = getitem_2036 = getitem_2037 = getitem_2038 = getitem_2039 = getitem_2040 = getitem_2041 = getitem_2042 = getitem_2043 = getitem_2044 = getitem_2045 = getitem_2046 = getitem_2047 = getitem_2048 = getitem_2049 = getitem_2050 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(primals_323, torch.bfloat16) + all_gather_into_tensor_328 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1045, 16, '1025'); convert_element_type_1045 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_328); all_gather_into_tensor_328 = None + split_111 = torch.ops.aten.split.Tensor(wait_tensor_403, 8); wait_tensor_403 = None + getitem_2051 = split_111[0] + getitem_2052 = split_111[1] + getitem_2053 = split_111[2] + getitem_2054 = split_111[3] + getitem_2055 = split_111[4] + getitem_2056 = split_111[5] + getitem_2057 = split_111[6] + getitem_2058 = split_111[7] + getitem_2059 = split_111[8] + getitem_2060 = split_111[9] + getitem_2061 = split_111[10] + getitem_2062 = split_111[11] + getitem_2063 = split_111[12] + getitem_2064 = split_111[13] + getitem_2065 = split_111[14] + getitem_2066 = split_111[15]; split_111 = None + cat_170 = torch.ops.aten.cat.default([getitem_2051, getitem_2052, getitem_2053, getitem_2054, getitem_2055, getitem_2056, getitem_2057, getitem_2058, getitem_2059, getitem_2060, getitem_2061, getitem_2062, getitem_2063, getitem_2064, getitem_2065, getitem_2066], 1); getitem_2051 = getitem_2052 = getitem_2053 = getitem_2054 = getitem_2055 = getitem_2056 = getitem_2057 = getitem_2058 = getitem_2059 = getitem_2060 = getitem_2061 = getitem_2062 = getitem_2063 = getitem_2064 = getitem_2065 = getitem_2066 = None + cumsum_56 = torch.ops.aten.cumsum.default(convert_element_type_1040, 0, dtype = torch.int32); convert_element_type_1040 = None + permute_290 = torch.ops.aten.permute.default(cat_168, [0, 2, 1]); cat_168 = None + _grouped_mm_54 = torch.ops.aten._grouped_mm.default(index_37, permute_290, cumsum_56) + convert_element_type_1048 = torch.ops.prims.convert_element_type.default(_grouped_mm_54, torch.float32) + neg_37 = torch.ops.aten.neg.default(convert_element_type_1048) + exp_56 = torch.ops.aten.exp.default(neg_37); neg_37 = None + add_1256 = torch.ops.aten.add.Tensor(exp_56, 1); exp_56 = None + div_94 = torch.ops.aten.div.Tensor(convert_element_type_1048, add_1256); convert_element_type_1048 = add_1256 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(div_94, torch.bfloat16); div_94 = None + permute_291 = torch.ops.aten.permute.default(cat_170, [0, 2, 1]); cat_170 = None + _grouped_mm_55 = torch.ops.aten._grouped_mm.default(index_37, permute_291, cumsum_56) + mul_917 = torch.ops.aten.mul.Tensor(convert_element_type_1049, _grouped_mm_55); convert_element_type_1049 = None + permute_292 = torch.ops.aten.permute.default(cat_169, [0, 2, 1]); cat_169 = None + _grouped_mm_56 = torch.ops.aten._grouped_mm.default(mul_917, permute_292, cumsum_56) + empty_18 = torch.ops.aten.empty.memory_format([sym_size_int_73, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_36 = torch.ops.aten.index_put.default(empty_18, [getitem_2002], _grouped_mm_56); empty_18 = _grouped_mm_56 = None + slice_119 = torch.ops.aten.slice.Tensor(index_put_36, 0, 0, -1); index_put_36 = None + all_to_all_single_56 = torch.ops._c10d_functional.all_to_all_single.default(slice_119, [_local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295], [_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303], '1033'); slice_119 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_56); all_to_all_single_56 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(primals_324, torch.bfloat16) + all_gather_into_tensor_331 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1050, 128, '0'); convert_element_type_1050 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + mm_156 = torch.ops.aten.mm.default(view_1264, permute_293); permute_293 = None + convert_element_type_1053 = torch.ops.prims.convert_element_type.default(mm_156, torch.float32) + neg_38 = torch.ops.aten.neg.default(convert_element_type_1053) + exp_57 = torch.ops.aten.exp.default(neg_38); neg_38 = None + add_1292 = torch.ops.aten.add.Tensor(exp_57, 1); exp_57 = None + div_95 = torch.ops.aten.div.Tensor(convert_element_type_1053, add_1292); convert_element_type_1053 = add_1292 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(div_95, torch.bfloat16); div_95 = None + convert_element_type_1055 = torch.ops.prims.convert_element_type.default(primals_325, torch.bfloat16) + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1055, 128, '0'); convert_element_type_1055 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + mm_157 = torch.ops.aten.mm.default(view_1264, permute_294); permute_294 = None + mul_937 = torch.ops.aten.mul.Tensor(convert_element_type_1054, mm_157); convert_element_type_1054 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(primals_326, torch.bfloat16) + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1058, 128, '0'); convert_element_type_1058 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + mm_158 = torch.ops.aten.mm.default(mul_937, permute_295); permute_295 = None + index_put_37 = torch.ops.aten.index_put.default(full_default_1, [getitem_2001], wait_tensor_406); wait_tensor_406 = None + view_1304 = torch.ops.aten.view.default(mul_899, [-1, 1, 6]); mul_899 = None + view_1305 = torch.ops.aten.view.default(index_put_37, [-1, 6, 2048]); index_put_37 = None + convert_element_type_1061 = torch.ops.prims.convert_element_type.default(view_1305, torch.float32); view_1305 = None + bmm_18 = torch.ops.aten.bmm.default(view_1304, convert_element_type_1061) + convert_element_type_1062 = torch.ops.prims.convert_element_type.default(bmm_18, torch.bfloat16); bmm_18 = None + squeeze_18 = torch.ops.aten.squeeze.dim(convert_element_type_1062, 1); convert_element_type_1062 = None + add_1296 = torch.ops.aten.add.Tensor(mm_158, squeeze_18); mm_158 = squeeze_18 = None + view_1306 = torch.ops.aten.view.default(add_1296, [2, 4096, 2048]); add_1296 = None + add_1297 = torch.ops.aten.add.Tensor(add_1232, view_1306); view_1306 = None + convert_element_type_1063 = torch.ops.prims.convert_element_type.default(primals_327, torch.bfloat16) + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1063, 128, '0'); convert_element_type_1063 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + convert_element_type_1064 = torch.ops.prims.convert_element_type.default(add_1297, torch.float32) + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1064, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_1298 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_1298); add_1298 = None + mul_940 = torch.ops.aten.mul.Tensor(convert_element_type_1064, rsqrt_60); convert_element_type_1064 = None + mul_941 = torch.ops.aten.mul.Tensor(mul_940, wait_tensor_410); mul_940 = wait_tensor_410 = None + convert_element_type_1065 = torch.ops.prims.convert_element_type.default(mul_941, torch.bfloat16); mul_941 = None + convert_element_type_1066 = torch.ops.prims.convert_element_type.default(primals_328, torch.bfloat16) + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1066, 128, '0'); convert_element_type_1066 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_411, [1, 0]); wait_tensor_411 = None + view_1309 = torch.ops.aten.view.default(convert_element_type_1065, [8192, 2048]); convert_element_type_1065 = None + mm_159 = torch.ops.aten.mm.default(view_1309, permute_296); permute_296 = None + view_1310 = torch.ops.aten.view.default(mm_159, [2, 4096, 3072]); mm_159 = None + view_1311 = torch.ops.aten.view.default(view_1310, [2, 4096, -1, 192]); view_1310 = None + split_with_sizes_60 = torch.ops.aten.split_with_sizes.default(view_1311, [128, 64], -1); view_1311 = None + getitem_2099 = split_with_sizes_60[0] + getitem_2100 = split_with_sizes_60[1]; split_with_sizes_60 = None + convert_element_type_1069 = torch.ops.prims.convert_element_type.default(getitem_2100, torch.float32); getitem_2100 = None + view_1312 = torch.ops.aten.view.default(convert_element_type_1069, [2, 4096, 16, -1, 2]); convert_element_type_1069 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1312); view_1312 = None + mul_942 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_7); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_942); mul_942 = None + view_1314 = torch.ops.aten.view.default(view_as_real_40, [2, 4096, 16, 64]); view_as_real_40 = None + convert_element_type_1070 = torch.ops.prims.convert_element_type.default(view_1314, torch.bfloat16); view_1314 = None + cat_173 = torch.ops.aten.cat.default([getitem_2099, convert_element_type_1070], -1); getitem_2099 = convert_element_type_1070 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(primals_329, torch.bfloat16) + all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1071, 128, '0'); convert_element_type_1071 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + slice_121 = torch.ops.aten.slice.Tensor(wait_tensor_412, 0, 0, 576); wait_tensor_412 = None + permute_297 = torch.ops.aten.permute.default(slice_121, [1, 0]); slice_121 = None + mm_160 = torch.ops.aten.mm.default(view_1309, permute_297); permute_297 = None + view_1317 = torch.ops.aten.view.default(mm_160, [2, 4096, 576]); mm_160 = None + split_with_sizes_61 = torch.ops.aten.split_with_sizes.default(view_1317, [512, 64], -1); view_1317 = None + getitem_2101 = split_with_sizes_61[0] + getitem_2102 = split_with_sizes_61[1]; split_with_sizes_61 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(getitem_2102, 2); getitem_2102 = None + convert_element_type_1074 = torch.ops.prims.convert_element_type.default(unsqueeze_39, torch.float32); unsqueeze_39 = None + view_1318 = torch.ops.aten.view.default(convert_element_type_1074, [2, 4096, 1, -1, 2]); convert_element_type_1074 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1318); view_1318 = None + mul_943 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_7); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_943); mul_943 = None + view_1320 = torch.ops.aten.view.default(view_as_real_41, [2, 4096, 1, 64]); view_as_real_41 = None + convert_element_type_1075 = torch.ops.prims.convert_element_type.default(view_1320, torch.bfloat16); view_1320 = None + convert_element_type_1076 = torch.ops.prims.convert_element_type.default(primals_330, torch.bfloat16) + all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1076, 128, '0'); convert_element_type_1076 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1077 = torch.ops.prims.convert_element_type.default(getitem_2101, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1077, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_1299 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_1299); add_1299 = None + mul_944 = torch.ops.aten.mul.Tensor(convert_element_type_1077, rsqrt_61); convert_element_type_1077 = None + mul_945 = torch.ops.aten.mul.Tensor(mul_944, wait_tensor_413); mul_944 = wait_tensor_413 = None + convert_element_type_1078 = torch.ops.prims.convert_element_type.default(mul_945, torch.bfloat16); mul_945 = None + convert_element_type_1079 = torch.ops.prims.convert_element_type.default(primals_331, torch.bfloat16) + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1079, 128, '0'); convert_element_type_1079 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + view_1323 = torch.ops.aten.view.default(convert_element_type_1078, [8192, 512]); convert_element_type_1078 = None + mm_161 = torch.ops.aten.mm.default(view_1323, permute_298); permute_298 = None + view_1324 = torch.ops.aten.view.default(mm_161, [2, 4096, 4096]); mm_161 = None + view_1325 = torch.ops.aten.view.default(view_1324, [2, 4096, -1, 256]); view_1324 = None + split_with_sizes_62 = torch.ops.aten.split_with_sizes.default(view_1325, [128, 128], -1); view_1325 = None + getitem_2103 = split_with_sizes_62[0] + getitem_2104 = split_with_sizes_62[1]; split_with_sizes_62 = None + expand_20 = torch.ops.aten.expand.default(convert_element_type_1075, [-1, -1, 16, -1]); convert_element_type_1075 = None + cat_174 = torch.ops.aten.cat.default([getitem_2103, expand_20], -1); getitem_2103 = expand_20 = None + permute_299 = torch.ops.aten.permute.default(cat_173, [0, 2, 1, 3]); cat_173 = None + permute_300 = torch.ops.aten.permute.default(cat_174, [0, 2, 1, 3]); cat_174 = None + permute_301 = torch.ops.aten.permute.default(getitem_2104, [0, 2, 1, 3]); getitem_2104 = None + sdpa_score20 = self.sdpa_score20 + sdpa_mask20 = self.sdpa_mask20 + flex_attention_20 = torch.ops.higher_order.flex_attention(permute_299, permute_300, permute_301, sdpa_score20, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask20), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score20 = sdpa_mask20 = None + getitem_2105 = flex_attention_20[0] + getitem_2106 = flex_attention_20[1]; flex_attention_20 = None + permute_302 = torch.ops.aten.permute.default(getitem_2105, [0, 2, 1, 3]) + view_1326 = torch.ops.aten.view.default(permute_302, [2, 4096, -1]); permute_302 = None + convert_element_type_1082 = torch.ops.prims.convert_element_type.default(primals_332, torch.bfloat16) + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1082, 128, '0'); convert_element_type_1082 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + permute_303 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); wait_tensor_415 = None + view_1328 = torch.ops.aten.view.default(view_1326, [8192, 2048]); view_1326 = None + mm_162 = torch.ops.aten.mm.default(view_1328, permute_303); view_1328 = permute_303 = None + view_1329 = torch.ops.aten.view.default(mm_162, [2, 4096, 2048]); mm_162 = None + add_1300 = torch.ops.aten.add.Tensor(add_1297, view_1329); view_1329 = None + convert_element_type_1085 = torch.ops.prims.convert_element_type.default(primals_333, torch.bfloat16) + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1085, 128, '0'); convert_element_type_1085 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(add_1300, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1086, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_1301 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_1301); add_1301 = None + mul_946 = torch.ops.aten.mul.Tensor(convert_element_type_1086, rsqrt_62); convert_element_type_1086 = None + mul_947 = torch.ops.aten.mul.Tensor(mul_946, wait_tensor_416); mul_946 = wait_tensor_416 = None + convert_element_type_1087 = torch.ops.prims.convert_element_type.default(mul_947, torch.bfloat16); mul_947 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_1087, [-1, 2048]); convert_element_type_1087 = None + convert_element_type_1088 = torch.ops.prims.convert_element_type.default(primals_335, torch.bfloat16) + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1088, 128, '0'); convert_element_type_1088 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + slice_123 = torch.ops.aten.slice.Tensor(wait_tensor_417, 0, 0, 64); wait_tensor_417 = None + permute_304 = torch.ops.aten.permute.default(slice_123, [1, 0]); slice_123 = None + mm_163 = torch.ops.aten.mm.default(view_1331, permute_304); permute_304 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_163, torch.float32) + amax_19 = torch.ops.aten.amax.default(convert_element_type_1091, [1], True) + sub_456 = torch.ops.aten.sub.Tensor(convert_element_type_1091, amax_19); convert_element_type_1091 = None + exp_58 = torch.ops.aten.exp.default(sub_456); sub_456 = None + sum_77 = torch.ops.aten.sum.dim_IntList(exp_58, [1], True) + div_96 = torch.ops.aten.div.Tensor(exp_58, sum_77); exp_58 = None + add_1302 = torch.ops.aten.add.Tensor(div_96, primals_334); primals_334 = None + topk_19 = torch.ops.aten.topk.default(add_1302, 6, -1, True, False); add_1302 = None + getitem_2109 = topk_19[1]; topk_19 = None + gather_19 = torch.ops.aten.gather.default(div_96, 1, getitem_2109); div_96 = None + mul_948 = torch.ops.aten.mul.Tensor(gather_19, 1.0); gather_19 = None + view_1333 = torch.ops.aten.view.default(getitem_2109, [-1]) + histc_38 = torch.ops.aten.histc.default(view_1333, 64, 0, 64) + add_1303 = torch.ops.aten.add.Tensor(primals_336, histc_38) + sort_19 = torch.ops.aten.sort.stable(view_1333, stable = True); view_1333 = None + getitem_2111 = sort_19[1]; sort_19 = None + div_97 = torch.ops.aten.div.Tensor_mode(getitem_2111, 6, rounding_mode = 'floor') + index_38 = torch.ops.aten.index.Tensor(view_1331, [div_97]) + all_to_all_single_57 = torch.ops._c10d_functional.all_to_all_single.default(histc_38, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_57); all_to_all_single_57 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_418); wait_tensor_418 = None + view_1337 = torch.ops.aten.view.default(histc_38, [8, -1]); histc_38 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_1337, [1]); view_1337 = None + device_put_38 = torch.ops.prims.device_put.default(sum_78, device(type='cpu'), True); sum_78 = None + view_1338 = torch.ops.aten.view.default(wait_tensor_419, [8, -1]) + sum_79 = torch.ops.aten.sum.dim_IntList(view_1338, [1]) + device_put_39 = torch.ops.prims.device_put.default(sum_79, device(type='cpu')); sum_79 = None + select_304 = torch.ops.aten.select.int(device_put_38, 0, 0) + _local_scalar_dense_304 = torch.ops.aten._local_scalar_dense.default(select_304); select_304 = None + ge_380 = _local_scalar_dense_304 >= 0 + _assert_scalar_304 = torch.ops.aten._assert_scalar.default(ge_380, "Runtime assertion failed for expression u304 >= 0 on node 'ge_304'"); ge_380 = _assert_scalar_304 = None + select_305 = torch.ops.aten.select.int(device_put_38, 0, 1) + _local_scalar_dense_305 = torch.ops.aten._local_scalar_dense.default(select_305); select_305 = None + ge_381 = _local_scalar_dense_305 >= 0 + _assert_scalar_305 = torch.ops.aten._assert_scalar.default(ge_381, "Runtime assertion failed for expression u305 >= 0 on node 'ge_305'"); ge_381 = _assert_scalar_305 = None + select_306 = torch.ops.aten.select.int(device_put_38, 0, 2) + _local_scalar_dense_306 = torch.ops.aten._local_scalar_dense.default(select_306); select_306 = None + ge_382 = _local_scalar_dense_306 >= 0 + _assert_scalar_306 = torch.ops.aten._assert_scalar.default(ge_382, "Runtime assertion failed for expression u306 >= 0 on node 'ge_306'"); ge_382 = _assert_scalar_306 = None + select_307 = torch.ops.aten.select.int(device_put_38, 0, 3) + _local_scalar_dense_307 = torch.ops.aten._local_scalar_dense.default(select_307); select_307 = None + ge_383 = _local_scalar_dense_307 >= 0 + _assert_scalar_307 = torch.ops.aten._assert_scalar.default(ge_383, "Runtime assertion failed for expression u307 >= 0 on node 'ge_307'"); ge_383 = _assert_scalar_307 = None + select_308 = torch.ops.aten.select.int(device_put_38, 0, 4) + _local_scalar_dense_308 = torch.ops.aten._local_scalar_dense.default(select_308); select_308 = None + ge_384 = _local_scalar_dense_308 >= 0 + _assert_scalar_308 = torch.ops.aten._assert_scalar.default(ge_384, "Runtime assertion failed for expression u308 >= 0 on node 'ge_308'"); ge_384 = _assert_scalar_308 = None + select_309 = torch.ops.aten.select.int(device_put_38, 0, 5) + _local_scalar_dense_309 = torch.ops.aten._local_scalar_dense.default(select_309); select_309 = None + ge_385 = _local_scalar_dense_309 >= 0 + _assert_scalar_309 = torch.ops.aten._assert_scalar.default(ge_385, "Runtime assertion failed for expression u309 >= 0 on node 'ge_309'"); ge_385 = _assert_scalar_309 = None + select_310 = torch.ops.aten.select.int(device_put_38, 0, 6) + _local_scalar_dense_310 = torch.ops.aten._local_scalar_dense.default(select_310); select_310 = None + ge_386 = _local_scalar_dense_310 >= 0 + _assert_scalar_310 = torch.ops.aten._assert_scalar.default(ge_386, "Runtime assertion failed for expression u310 >= 0 on node 'ge_310'"); ge_386 = _assert_scalar_310 = None + select_311 = torch.ops.aten.select.int(device_put_38, 0, 7); device_put_38 = None + _local_scalar_dense_311 = torch.ops.aten._local_scalar_dense.default(select_311); select_311 = None + ge_387 = _local_scalar_dense_311 >= 0 + _assert_scalar_311 = torch.ops.aten._assert_scalar.default(ge_387, "Runtime assertion failed for expression u311 >= 0 on node 'ge_311'"); ge_387 = _assert_scalar_311 = None + select_312 = torch.ops.aten.select.int(device_put_39, 0, 0) + _local_scalar_dense_312 = torch.ops.aten._local_scalar_dense.default(select_312); select_312 = None + ge_388 = _local_scalar_dense_312 >= 0 + _assert_scalar_312 = torch.ops.aten._assert_scalar.default(ge_388, "Runtime assertion failed for expression u312 >= 0 on node 'ge_312'"); ge_388 = _assert_scalar_312 = None + select_313 = torch.ops.aten.select.int(device_put_39, 0, 1) + _local_scalar_dense_313 = torch.ops.aten._local_scalar_dense.default(select_313); select_313 = None + ge_389 = _local_scalar_dense_313 >= 0 + _assert_scalar_313 = torch.ops.aten._assert_scalar.default(ge_389, "Runtime assertion failed for expression u313 >= 0 on node 'ge_313'"); ge_389 = _assert_scalar_313 = None + select_314 = torch.ops.aten.select.int(device_put_39, 0, 2) + _local_scalar_dense_314 = torch.ops.aten._local_scalar_dense.default(select_314); select_314 = None + ge_390 = _local_scalar_dense_314 >= 0 + _assert_scalar_314 = torch.ops.aten._assert_scalar.default(ge_390, "Runtime assertion failed for expression u314 >= 0 on node 'ge_314'"); ge_390 = _assert_scalar_314 = None + select_315 = torch.ops.aten.select.int(device_put_39, 0, 3) + _local_scalar_dense_315 = torch.ops.aten._local_scalar_dense.default(select_315); select_315 = None + ge_391 = _local_scalar_dense_315 >= 0 + _assert_scalar_315 = torch.ops.aten._assert_scalar.default(ge_391, "Runtime assertion failed for expression u315 >= 0 on node 'ge_315'"); ge_391 = _assert_scalar_315 = None + select_316 = torch.ops.aten.select.int(device_put_39, 0, 4) + _local_scalar_dense_316 = torch.ops.aten._local_scalar_dense.default(select_316); select_316 = None + ge_392 = _local_scalar_dense_316 >= 0 + _assert_scalar_316 = torch.ops.aten._assert_scalar.default(ge_392, "Runtime assertion failed for expression u316 >= 0 on node 'ge_316'"); ge_392 = _assert_scalar_316 = None + select_317 = torch.ops.aten.select.int(device_put_39, 0, 5) + _local_scalar_dense_317 = torch.ops.aten._local_scalar_dense.default(select_317); select_317 = None + ge_393 = _local_scalar_dense_317 >= 0 + _assert_scalar_317 = torch.ops.aten._assert_scalar.default(ge_393, "Runtime assertion failed for expression u317 >= 0 on node 'ge_317'"); ge_393 = _assert_scalar_317 = None + select_318 = torch.ops.aten.select.int(device_put_39, 0, 6) + _local_scalar_dense_318 = torch.ops.aten._local_scalar_dense.default(select_318); select_318 = None + ge_394 = _local_scalar_dense_318 >= 0 + _assert_scalar_318 = torch.ops.aten._assert_scalar.default(ge_394, "Runtime assertion failed for expression u318 >= 0 on node 'ge_318'"); ge_394 = _assert_scalar_318 = None + select_319 = torch.ops.aten.select.int(device_put_39, 0, 7); device_put_39 = None + _local_scalar_dense_319 = torch.ops.aten._local_scalar_dense.default(select_319); select_319 = None + ge_395 = _local_scalar_dense_319 >= 0 + _assert_scalar_319 = torch.ops.aten._assert_scalar.default(ge_395, "Runtime assertion failed for expression u319 >= 0 on node 'ge_319'"); ge_395 = _assert_scalar_319 = None + all_to_all_single_58 = torch.ops._c10d_functional.all_to_all_single.default(index_38, [_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319], [_local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311], '1033'); index_38 = None + sym_size_int_76 = torch.ops.aten.sym_size.int(all_to_all_single_58, 0) + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_58); all_to_all_single_58 = None + sym_sum_38 = torch.sym_sum((_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319)) + add_1310 = sym_sum_38 + 64; sym_sum_38 = None + add_1311 = add_1310 + 8; add_1310 = None + sub_459 = add_1311 - 1; add_1311 = None + floordiv_19 = sub_459 // 8; sub_459 = None + mul_953 = floordiv_19 * 8; floordiv_19 = None + cumsum_57 = torch.ops.aten.cumsum.default(wait_tensor_419, 0) + sub_460 = torch.ops.aten.sub.Tensor(cumsum_57, wait_tensor_419); cumsum_57 = None + sum_80 = torch.ops.aten.sum.dim_IntList(view_1338, [0]); view_1338 = None + clamp_min_19 = torch.ops.aten.clamp_min.default(sum_80, 8); sum_80 = None + add_1312 = torch.ops.aten.add.Tensor(clamp_min_19, 8); clamp_min_19 = None + sub_461 = torch.ops.aten.sub.Tensor(add_1312, 1); add_1312 = None + div_98 = torch.ops.aten.div.Tensor_mode(sub_461, 8, rounding_mode = 'floor'); sub_461 = None + mul_954 = torch.ops.aten.mul.Tensor(div_98, 8); div_98 = None + convert_element_type_1094 = torch.ops.prims.convert_element_type.default(mul_954, torch.int32); mul_954 = None + cumsum_58 = torch.ops.aten.cumsum.default(convert_element_type_1094, 0) + sub_462 = torch.ops.aten.sub.Tensor(cumsum_58, convert_element_type_1094); cumsum_58 = None + full_267 = torch.ops.aten.full.default([mul_953], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_953 = None + triton_kernel_wrapper_functional_proxy_19 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 19, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_419, 'start_index_values_ptr': sub_460, 'write_offsets_ptr': sub_462, 'output_ptr': full_267}, tensors_to_clone = ['output_ptr']); wait_tensor_419 = sub_460 = sub_462 = full_267 = None + getitem_2112 = triton_kernel_wrapper_functional_proxy_19['output_ptr']; triton_kernel_wrapper_functional_proxy_19 = None + cat_175 = torch.ops.aten.cat.default([wait_tensor_420, full_default]); wait_tensor_420 = None + sym_size_int_77 = torch.ops.aten.sym_size.int(cat_175, 0) + sym_sum_39 = torch.sym_sum((1, _local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319)) + index_39 = torch.ops.aten.index.Tensor(cat_175, [getitem_2112]); cat_175 = None + convert_element_type_1096 = torch.ops.prims.convert_element_type.default(primals_337, torch.bfloat16) + all_gather_into_tensor_342 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1096, 16, '1025'); convert_element_type_1096 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_342); all_gather_into_tensor_342 = None + split_115 = torch.ops.aten.split.Tensor(wait_tensor_421, 8); wait_tensor_421 = None + getitem_2129 = split_115[0] + getitem_2130 = split_115[1] + getitem_2131 = split_115[2] + getitem_2132 = split_115[3] + getitem_2133 = split_115[4] + getitem_2134 = split_115[5] + getitem_2135 = split_115[6] + getitem_2136 = split_115[7] + getitem_2137 = split_115[8] + getitem_2138 = split_115[9] + getitem_2139 = split_115[10] + getitem_2140 = split_115[11] + getitem_2141 = split_115[12] + getitem_2142 = split_115[13] + getitem_2143 = split_115[14] + getitem_2144 = split_115[15]; split_115 = None + cat_177 = torch.ops.aten.cat.default([getitem_2129, getitem_2130, getitem_2131, getitem_2132, getitem_2133, getitem_2134, getitem_2135, getitem_2136, getitem_2137, getitem_2138, getitem_2139, getitem_2140, getitem_2141, getitem_2142, getitem_2143, getitem_2144], 1); getitem_2129 = getitem_2130 = getitem_2131 = getitem_2132 = getitem_2133 = getitem_2134 = getitem_2135 = getitem_2136 = getitem_2137 = getitem_2138 = getitem_2139 = getitem_2140 = getitem_2141 = getitem_2142 = getitem_2143 = getitem_2144 = None + convert_element_type_1098 = torch.ops.prims.convert_element_type.default(primals_338, torch.bfloat16) + all_gather_into_tensor_344 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1098, 16, '1025'); convert_element_type_1098 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_344); all_gather_into_tensor_344 = None + split_116 = torch.ops.aten.split.Tensor(wait_tensor_423, 8); wait_tensor_423 = None + getitem_2145 = split_116[0] + getitem_2146 = split_116[1] + getitem_2147 = split_116[2] + getitem_2148 = split_116[3] + getitem_2149 = split_116[4] + getitem_2150 = split_116[5] + getitem_2151 = split_116[6] + getitem_2152 = split_116[7] + getitem_2153 = split_116[8] + getitem_2154 = split_116[9] + getitem_2155 = split_116[10] + getitem_2156 = split_116[11] + getitem_2157 = split_116[12] + getitem_2158 = split_116[13] + getitem_2159 = split_116[14] + getitem_2160 = split_116[15]; split_116 = None + cat_178 = torch.ops.aten.cat.default([getitem_2145, getitem_2146, getitem_2147, getitem_2148, getitem_2149, getitem_2150, getitem_2151, getitem_2152, getitem_2153, getitem_2154, getitem_2155, getitem_2156, getitem_2157, getitem_2158, getitem_2159, getitem_2160], 1); getitem_2145 = getitem_2146 = getitem_2147 = getitem_2148 = getitem_2149 = getitem_2150 = getitem_2151 = getitem_2152 = getitem_2153 = getitem_2154 = getitem_2155 = getitem_2156 = getitem_2157 = getitem_2158 = getitem_2159 = getitem_2160 = None + convert_element_type_1099 = torch.ops.prims.convert_element_type.default(primals_339, torch.bfloat16) + all_gather_into_tensor_345 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1099, 16, '1025'); convert_element_type_1099 = None + wait_tensor_424 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_345); all_gather_into_tensor_345 = None + split_117 = torch.ops.aten.split.Tensor(wait_tensor_424, 8); wait_tensor_424 = None + getitem_2161 = split_117[0] + getitem_2162 = split_117[1] + getitem_2163 = split_117[2] + getitem_2164 = split_117[3] + getitem_2165 = split_117[4] + getitem_2166 = split_117[5] + getitem_2167 = split_117[6] + getitem_2168 = split_117[7] + getitem_2169 = split_117[8] + getitem_2170 = split_117[9] + getitem_2171 = split_117[10] + getitem_2172 = split_117[11] + getitem_2173 = split_117[12] + getitem_2174 = split_117[13] + getitem_2175 = split_117[14] + getitem_2176 = split_117[15]; split_117 = None + cat_179 = torch.ops.aten.cat.default([getitem_2161, getitem_2162, getitem_2163, getitem_2164, getitem_2165, getitem_2166, getitem_2167, getitem_2168, getitem_2169, getitem_2170, getitem_2171, getitem_2172, getitem_2173, getitem_2174, getitem_2175, getitem_2176], 1); getitem_2161 = getitem_2162 = getitem_2163 = getitem_2164 = getitem_2165 = getitem_2166 = getitem_2167 = getitem_2168 = getitem_2169 = getitem_2170 = getitem_2171 = getitem_2172 = getitem_2173 = getitem_2174 = getitem_2175 = getitem_2176 = None + cumsum_59 = torch.ops.aten.cumsum.default(convert_element_type_1094, 0, dtype = torch.int32); convert_element_type_1094 = None + permute_305 = torch.ops.aten.permute.default(cat_177, [0, 2, 1]); cat_177 = None + _grouped_mm_57 = torch.ops.aten._grouped_mm.default(index_39, permute_305, cumsum_59) + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(_grouped_mm_57, torch.float32) + neg_39 = torch.ops.aten.neg.default(convert_element_type_1102) + exp_59 = torch.ops.aten.exp.default(neg_39); neg_39 = None + add_1324 = torch.ops.aten.add.Tensor(exp_59, 1); exp_59 = None + div_99 = torch.ops.aten.div.Tensor(convert_element_type_1102, add_1324); convert_element_type_1102 = add_1324 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(div_99, torch.bfloat16); div_99 = None + permute_306 = torch.ops.aten.permute.default(cat_179, [0, 2, 1]); cat_179 = None + _grouped_mm_58 = torch.ops.aten._grouped_mm.default(index_39, permute_306, cumsum_59) + mul_966 = torch.ops.aten.mul.Tensor(convert_element_type_1103, _grouped_mm_58); convert_element_type_1103 = None + permute_307 = torch.ops.aten.permute.default(cat_178, [0, 2, 1]); cat_178 = None + _grouped_mm_59 = torch.ops.aten._grouped_mm.default(mul_966, permute_307, cumsum_59) + empty_19 = torch.ops.aten.empty.memory_format([sym_size_int_77, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_38 = torch.ops.aten.index_put.default(empty_19, [getitem_2112], _grouped_mm_59); empty_19 = _grouped_mm_59 = None + slice_125 = torch.ops.aten.slice.Tensor(index_put_38, 0, 0, -1); index_put_38 = None + all_to_all_single_59 = torch.ops._c10d_functional.all_to_all_single.default(slice_125, [_local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311], [_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319], '1033'); slice_125 = None + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_59); all_to_all_single_59 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(primals_340, torch.bfloat16) + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1104, 128, '0'); convert_element_type_1104 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_428, [1, 0]); wait_tensor_428 = None + mm_164 = torch.ops.aten.mm.default(view_1331, permute_308); permute_308 = None + convert_element_type_1107 = torch.ops.prims.convert_element_type.default(mm_164, torch.float32) + neg_40 = torch.ops.aten.neg.default(convert_element_type_1107) + exp_60 = torch.ops.aten.exp.default(neg_40); neg_40 = None + add_1360 = torch.ops.aten.add.Tensor(exp_60, 1); exp_60 = None + div_100 = torch.ops.aten.div.Tensor(convert_element_type_1107, add_1360); convert_element_type_1107 = add_1360 = None + convert_element_type_1108 = torch.ops.prims.convert_element_type.default(div_100, torch.bfloat16); div_100 = None + convert_element_type_1109 = torch.ops.prims.convert_element_type.default(primals_341, torch.bfloat16) + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1109, 128, '0'); convert_element_type_1109 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_429, [1, 0]); wait_tensor_429 = None + mm_165 = torch.ops.aten.mm.default(view_1331, permute_309); permute_309 = None + mul_986 = torch.ops.aten.mul.Tensor(convert_element_type_1108, mm_165); convert_element_type_1108 = None + convert_element_type_1112 = torch.ops.prims.convert_element_type.default(primals_342, torch.bfloat16) + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1112, 128, '0'); convert_element_type_1112 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_430, [1, 0]); wait_tensor_430 = None + mm_166 = torch.ops.aten.mm.default(mul_986, permute_310); permute_310 = None + index_put_39 = torch.ops.aten.index_put.default(full_default_1, [getitem_2111], wait_tensor_427); wait_tensor_427 = None + view_1371 = torch.ops.aten.view.default(mul_948, [-1, 1, 6]); mul_948 = None + view_1372 = torch.ops.aten.view.default(index_put_39, [-1, 6, 2048]); index_put_39 = None + convert_element_type_1115 = torch.ops.prims.convert_element_type.default(view_1372, torch.float32); view_1372 = None + bmm_19 = torch.ops.aten.bmm.default(view_1371, convert_element_type_1115) + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(bmm_19, torch.bfloat16); bmm_19 = None + squeeze_19 = torch.ops.aten.squeeze.dim(convert_element_type_1116, 1); convert_element_type_1116 = None + add_1364 = torch.ops.aten.add.Tensor(mm_166, squeeze_19); mm_166 = squeeze_19 = None + view_1373 = torch.ops.aten.view.default(add_1364, [2, 4096, 2048]); add_1364 = None + add_1365 = torch.ops.aten.add.Tensor(add_1300, view_1373); view_1373 = None + convert_element_type_1117 = torch.ops.prims.convert_element_type.default(primals_343, torch.bfloat16) + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1117, 128, '0'); convert_element_type_1117 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); all_gather_into_tensor_351 = None + convert_element_type_1118 = torch.ops.prims.convert_element_type.default(add_1365, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1118, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_1366 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_1366); add_1366 = None + mul_989 = torch.ops.aten.mul.Tensor(convert_element_type_1118, rsqrt_63); convert_element_type_1118 = None + mul_990 = torch.ops.aten.mul.Tensor(mul_989, wait_tensor_431); mul_989 = wait_tensor_431 = None + convert_element_type_1119 = torch.ops.prims.convert_element_type.default(mul_990, torch.bfloat16); mul_990 = None + convert_element_type_1120 = torch.ops.prims.convert_element_type.default(primals_344, torch.bfloat16) + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1120, 128, '0'); convert_element_type_1120 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_311 = torch.ops.aten.permute.default(wait_tensor_432, [1, 0]); wait_tensor_432 = None + view_1376 = torch.ops.aten.view.default(convert_element_type_1119, [8192, 2048]); convert_element_type_1119 = None + mm_167 = torch.ops.aten.mm.default(view_1376, permute_311); permute_311 = None + view_1377 = torch.ops.aten.view.default(mm_167, [2, 4096, 3072]); mm_167 = None + view_1378 = torch.ops.aten.view.default(view_1377, [2, 4096, -1, 192]); view_1377 = None + split_with_sizes_63 = torch.ops.aten.split_with_sizes.default(view_1378, [128, 64], -1); view_1378 = None + getitem_2209 = split_with_sizes_63[0] + getitem_2210 = split_with_sizes_63[1]; split_with_sizes_63 = None + convert_element_type_1123 = torch.ops.prims.convert_element_type.default(getitem_2210, torch.float32); getitem_2210 = None + view_1379 = torch.ops.aten.view.default(convert_element_type_1123, [2, 4096, 16, -1, 2]); convert_element_type_1123 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1379); view_1379 = None + mul_991 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_7); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_991); mul_991 = None + view_1381 = torch.ops.aten.view.default(view_as_real_42, [2, 4096, 16, 64]); view_as_real_42 = None + convert_element_type_1124 = torch.ops.prims.convert_element_type.default(view_1381, torch.bfloat16); view_1381 = None + cat_182 = torch.ops.aten.cat.default([getitem_2209, convert_element_type_1124], -1); getitem_2209 = convert_element_type_1124 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(primals_345, torch.bfloat16) + all_gather_into_tensor_353 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1125, 128, '0'); convert_element_type_1125 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + slice_127 = torch.ops.aten.slice.Tensor(wait_tensor_433, 0, 0, 576); wait_tensor_433 = None + permute_312 = torch.ops.aten.permute.default(slice_127, [1, 0]); slice_127 = None + mm_168 = torch.ops.aten.mm.default(view_1376, permute_312); permute_312 = None + view_1384 = torch.ops.aten.view.default(mm_168, [2, 4096, 576]); mm_168 = None + split_with_sizes_64 = torch.ops.aten.split_with_sizes.default(view_1384, [512, 64], -1); view_1384 = None + getitem_2211 = split_with_sizes_64[0] + getitem_2212 = split_with_sizes_64[1]; split_with_sizes_64 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(getitem_2212, 2); getitem_2212 = None + convert_element_type_1128 = torch.ops.prims.convert_element_type.default(unsqueeze_41, torch.float32); unsqueeze_41 = None + view_1385 = torch.ops.aten.view.default(convert_element_type_1128, [2, 4096, 1, -1, 2]); convert_element_type_1128 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1385); view_1385 = None + mul_992 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_7); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_992); mul_992 = None + view_1387 = torch.ops.aten.view.default(view_as_real_43, [2, 4096, 1, 64]); view_as_real_43 = None + convert_element_type_1129 = torch.ops.prims.convert_element_type.default(view_1387, torch.bfloat16); view_1387 = None + convert_element_type_1130 = torch.ops.prims.convert_element_type.default(primals_346, torch.bfloat16) + all_gather_into_tensor_354 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1130, 128, '0'); convert_element_type_1130 = None + wait_tensor_434 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_354); all_gather_into_tensor_354 = None + convert_element_type_1131 = torch.ops.prims.convert_element_type.default(getitem_2211, torch.float32) + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1131, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_1367 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_1367); add_1367 = None + mul_993 = torch.ops.aten.mul.Tensor(convert_element_type_1131, rsqrt_64); convert_element_type_1131 = None + mul_994 = torch.ops.aten.mul.Tensor(mul_993, wait_tensor_434); mul_993 = wait_tensor_434 = None + convert_element_type_1132 = torch.ops.prims.convert_element_type.default(mul_994, torch.bfloat16); mul_994 = None + convert_element_type_1133 = torch.ops.prims.convert_element_type.default(primals_347, torch.bfloat16) + all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1133, 128, '0'); convert_element_type_1133 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_313 = torch.ops.aten.permute.default(wait_tensor_435, [1, 0]); wait_tensor_435 = None + view_1390 = torch.ops.aten.view.default(convert_element_type_1132, [8192, 512]); convert_element_type_1132 = None + mm_169 = torch.ops.aten.mm.default(view_1390, permute_313); permute_313 = None + view_1391 = torch.ops.aten.view.default(mm_169, [2, 4096, 4096]); mm_169 = None + view_1392 = torch.ops.aten.view.default(view_1391, [2, 4096, -1, 256]); view_1391 = None + split_with_sizes_65 = torch.ops.aten.split_with_sizes.default(view_1392, [128, 128], -1); view_1392 = None + getitem_2213 = split_with_sizes_65[0] + getitem_2214 = split_with_sizes_65[1]; split_with_sizes_65 = None + expand_21 = torch.ops.aten.expand.default(convert_element_type_1129, [-1, -1, 16, -1]); convert_element_type_1129 = None + cat_183 = torch.ops.aten.cat.default([getitem_2213, expand_21], -1); getitem_2213 = expand_21 = None + permute_314 = torch.ops.aten.permute.default(cat_182, [0, 2, 1, 3]); cat_182 = None + permute_315 = torch.ops.aten.permute.default(cat_183, [0, 2, 1, 3]); cat_183 = None + permute_316 = torch.ops.aten.permute.default(getitem_2214, [0, 2, 1, 3]); getitem_2214 = None + sdpa_score21 = self.sdpa_score21 + sdpa_mask21 = self.sdpa_mask21 + flex_attention_21 = torch.ops.higher_order.flex_attention(permute_314, permute_315, permute_316, sdpa_score21, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask21), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score21 = sdpa_mask21 = None + getitem_2215 = flex_attention_21[0] + getitem_2216 = flex_attention_21[1]; flex_attention_21 = None + permute_317 = torch.ops.aten.permute.default(getitem_2215, [0, 2, 1, 3]) + view_1393 = torch.ops.aten.view.default(permute_317, [2, 4096, -1]); permute_317 = None + convert_element_type_1136 = torch.ops.prims.convert_element_type.default(primals_348, torch.bfloat16) + all_gather_into_tensor_356 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1136, 128, '0'); convert_element_type_1136 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_356); all_gather_into_tensor_356 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_436, [1, 0]); wait_tensor_436 = None + view_1395 = torch.ops.aten.view.default(view_1393, [8192, 2048]); view_1393 = None + mm_170 = torch.ops.aten.mm.default(view_1395, permute_318); view_1395 = permute_318 = None + view_1396 = torch.ops.aten.view.default(mm_170, [2, 4096, 2048]); mm_170 = None + add_1368 = torch.ops.aten.add.Tensor(add_1365, view_1396); view_1396 = None + convert_element_type_1139 = torch.ops.prims.convert_element_type.default(primals_349, torch.bfloat16) + all_gather_into_tensor_357 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1139, 128, '0'); convert_element_type_1139 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_357); all_gather_into_tensor_357 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(add_1368, torch.float32) + pow_66 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1140, 2) + mean_65 = torch.ops.aten.mean.dim(pow_66, [2], True); pow_66 = None + add_1369 = torch.ops.aten.add.Scalar(mean_65, 1e-05); mean_65 = None + rsqrt_65 = torch.ops.aten.rsqrt.default(add_1369); add_1369 = None + mul_995 = torch.ops.aten.mul.Tensor(convert_element_type_1140, rsqrt_65); convert_element_type_1140 = None + mul_996 = torch.ops.aten.mul.Tensor(mul_995, wait_tensor_437); mul_995 = wait_tensor_437 = None + convert_element_type_1141 = torch.ops.prims.convert_element_type.default(mul_996, torch.bfloat16); mul_996 = None + view_1398 = torch.ops.aten.view.default(convert_element_type_1141, [-1, 2048]); convert_element_type_1141 = None + convert_element_type_1142 = torch.ops.prims.convert_element_type.default(primals_351, torch.bfloat16) + all_gather_into_tensor_358 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1142, 128, '0'); convert_element_type_1142 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_358); all_gather_into_tensor_358 = None + slice_129 = torch.ops.aten.slice.Tensor(wait_tensor_438, 0, 0, 64); wait_tensor_438 = None + permute_319 = torch.ops.aten.permute.default(slice_129, [1, 0]); slice_129 = None + mm_171 = torch.ops.aten.mm.default(view_1398, permute_319); permute_319 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_171, torch.float32) + amax_20 = torch.ops.aten.amax.default(convert_element_type_1145, [1], True) + sub_480 = torch.ops.aten.sub.Tensor(convert_element_type_1145, amax_20); convert_element_type_1145 = None + exp_61 = torch.ops.aten.exp.default(sub_480); sub_480 = None + sum_81 = torch.ops.aten.sum.dim_IntList(exp_61, [1], True) + div_101 = torch.ops.aten.div.Tensor(exp_61, sum_81); exp_61 = None + add_1370 = torch.ops.aten.add.Tensor(div_101, primals_350); primals_350 = None + topk_20 = torch.ops.aten.topk.default(add_1370, 6, -1, True, False); add_1370 = None + getitem_2219 = topk_20[1]; topk_20 = None + gather_20 = torch.ops.aten.gather.default(div_101, 1, getitem_2219); div_101 = None + mul_997 = torch.ops.aten.mul.Tensor(gather_20, 1.0); gather_20 = None + view_1400 = torch.ops.aten.view.default(getitem_2219, [-1]) + histc_40 = torch.ops.aten.histc.default(view_1400, 64, 0, 64) + add_1371 = torch.ops.aten.add.Tensor(primals_352, histc_40) + sort_20 = torch.ops.aten.sort.stable(view_1400, stable = True); view_1400 = None + getitem_2221 = sort_20[1]; sort_20 = None + div_102 = torch.ops.aten.div.Tensor_mode(getitem_2221, 6, rounding_mode = 'floor') + index_40 = torch.ops.aten.index.Tensor(view_1398, [div_102]) + all_to_all_single_60 = torch.ops._c10d_functional.all_to_all_single.default(histc_40, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_60); all_to_all_single_60 = None + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_439); wait_tensor_439 = None + view_1404 = torch.ops.aten.view.default(histc_40, [8, -1]); histc_40 = None + sum_82 = torch.ops.aten.sum.dim_IntList(view_1404, [1]); view_1404 = None + device_put_40 = torch.ops.prims.device_put.default(sum_82, device(type='cpu'), True); sum_82 = None + view_1405 = torch.ops.aten.view.default(wait_tensor_440, [8, -1]) + sum_83 = torch.ops.aten.sum.dim_IntList(view_1405, [1]) + device_put_41 = torch.ops.prims.device_put.default(sum_83, device(type='cpu')); sum_83 = None + select_320 = torch.ops.aten.select.int(device_put_40, 0, 0) + _local_scalar_dense_320 = torch.ops.aten._local_scalar_dense.default(select_320); select_320 = None + ge_400 = _local_scalar_dense_320 >= 0 + _assert_scalar_320 = torch.ops.aten._assert_scalar.default(ge_400, "Runtime assertion failed for expression u320 >= 0 on node 'ge_320'"); ge_400 = _assert_scalar_320 = None + select_321 = torch.ops.aten.select.int(device_put_40, 0, 1) + _local_scalar_dense_321 = torch.ops.aten._local_scalar_dense.default(select_321); select_321 = None + ge_401 = _local_scalar_dense_321 >= 0 + _assert_scalar_321 = torch.ops.aten._assert_scalar.default(ge_401, "Runtime assertion failed for expression u321 >= 0 on node 'ge_321'"); ge_401 = _assert_scalar_321 = None + select_322 = torch.ops.aten.select.int(device_put_40, 0, 2) + _local_scalar_dense_322 = torch.ops.aten._local_scalar_dense.default(select_322); select_322 = None + ge_402 = _local_scalar_dense_322 >= 0 + _assert_scalar_322 = torch.ops.aten._assert_scalar.default(ge_402, "Runtime assertion failed for expression u322 >= 0 on node 'ge_322'"); ge_402 = _assert_scalar_322 = None + select_323 = torch.ops.aten.select.int(device_put_40, 0, 3) + _local_scalar_dense_323 = torch.ops.aten._local_scalar_dense.default(select_323); select_323 = None + ge_403 = _local_scalar_dense_323 >= 0 + _assert_scalar_323 = torch.ops.aten._assert_scalar.default(ge_403, "Runtime assertion failed for expression u323 >= 0 on node 'ge_323'"); ge_403 = _assert_scalar_323 = None + select_324 = torch.ops.aten.select.int(device_put_40, 0, 4) + _local_scalar_dense_324 = torch.ops.aten._local_scalar_dense.default(select_324); select_324 = None + ge_404 = _local_scalar_dense_324 >= 0 + _assert_scalar_324 = torch.ops.aten._assert_scalar.default(ge_404, "Runtime assertion failed for expression u324 >= 0 on node 'ge_324'"); ge_404 = _assert_scalar_324 = None + select_325 = torch.ops.aten.select.int(device_put_40, 0, 5) + _local_scalar_dense_325 = torch.ops.aten._local_scalar_dense.default(select_325); select_325 = None + ge_405 = _local_scalar_dense_325 >= 0 + _assert_scalar_325 = torch.ops.aten._assert_scalar.default(ge_405, "Runtime assertion failed for expression u325 >= 0 on node 'ge_325'"); ge_405 = _assert_scalar_325 = None + select_326 = torch.ops.aten.select.int(device_put_40, 0, 6) + _local_scalar_dense_326 = torch.ops.aten._local_scalar_dense.default(select_326); select_326 = None + ge_406 = _local_scalar_dense_326 >= 0 + _assert_scalar_326 = torch.ops.aten._assert_scalar.default(ge_406, "Runtime assertion failed for expression u326 >= 0 on node 'ge_326'"); ge_406 = _assert_scalar_326 = None + select_327 = torch.ops.aten.select.int(device_put_40, 0, 7); device_put_40 = None + _local_scalar_dense_327 = torch.ops.aten._local_scalar_dense.default(select_327); select_327 = None + ge_407 = _local_scalar_dense_327 >= 0 + _assert_scalar_327 = torch.ops.aten._assert_scalar.default(ge_407, "Runtime assertion failed for expression u327 >= 0 on node 'ge_327'"); ge_407 = _assert_scalar_327 = None + select_328 = torch.ops.aten.select.int(device_put_41, 0, 0) + _local_scalar_dense_328 = torch.ops.aten._local_scalar_dense.default(select_328); select_328 = None + ge_408 = _local_scalar_dense_328 >= 0 + _assert_scalar_328 = torch.ops.aten._assert_scalar.default(ge_408, "Runtime assertion failed for expression u328 >= 0 on node 'ge_328'"); ge_408 = _assert_scalar_328 = None + select_329 = torch.ops.aten.select.int(device_put_41, 0, 1) + _local_scalar_dense_329 = torch.ops.aten._local_scalar_dense.default(select_329); select_329 = None + ge_409 = _local_scalar_dense_329 >= 0 + _assert_scalar_329 = torch.ops.aten._assert_scalar.default(ge_409, "Runtime assertion failed for expression u329 >= 0 on node 'ge_329'"); ge_409 = _assert_scalar_329 = None + select_330 = torch.ops.aten.select.int(device_put_41, 0, 2) + _local_scalar_dense_330 = torch.ops.aten._local_scalar_dense.default(select_330); select_330 = None + ge_410 = _local_scalar_dense_330 >= 0 + _assert_scalar_330 = torch.ops.aten._assert_scalar.default(ge_410, "Runtime assertion failed for expression u330 >= 0 on node 'ge_330'"); ge_410 = _assert_scalar_330 = None + select_331 = torch.ops.aten.select.int(device_put_41, 0, 3) + _local_scalar_dense_331 = torch.ops.aten._local_scalar_dense.default(select_331); select_331 = None + ge_411 = _local_scalar_dense_331 >= 0 + _assert_scalar_331 = torch.ops.aten._assert_scalar.default(ge_411, "Runtime assertion failed for expression u331 >= 0 on node 'ge_331'"); ge_411 = _assert_scalar_331 = None + select_332 = torch.ops.aten.select.int(device_put_41, 0, 4) + _local_scalar_dense_332 = torch.ops.aten._local_scalar_dense.default(select_332); select_332 = None + ge_412 = _local_scalar_dense_332 >= 0 + _assert_scalar_332 = torch.ops.aten._assert_scalar.default(ge_412, "Runtime assertion failed for expression u332 >= 0 on node 'ge_332'"); ge_412 = _assert_scalar_332 = None + select_333 = torch.ops.aten.select.int(device_put_41, 0, 5) + _local_scalar_dense_333 = torch.ops.aten._local_scalar_dense.default(select_333); select_333 = None + ge_413 = _local_scalar_dense_333 >= 0 + _assert_scalar_333 = torch.ops.aten._assert_scalar.default(ge_413, "Runtime assertion failed for expression u333 >= 0 on node 'ge_333'"); ge_413 = _assert_scalar_333 = None + select_334 = torch.ops.aten.select.int(device_put_41, 0, 6) + _local_scalar_dense_334 = torch.ops.aten._local_scalar_dense.default(select_334); select_334 = None + ge_414 = _local_scalar_dense_334 >= 0 + _assert_scalar_334 = torch.ops.aten._assert_scalar.default(ge_414, "Runtime assertion failed for expression u334 >= 0 on node 'ge_334'"); ge_414 = _assert_scalar_334 = None + select_335 = torch.ops.aten.select.int(device_put_41, 0, 7); device_put_41 = None + _local_scalar_dense_335 = torch.ops.aten._local_scalar_dense.default(select_335); select_335 = None + ge_415 = _local_scalar_dense_335 >= 0 + _assert_scalar_335 = torch.ops.aten._assert_scalar.default(ge_415, "Runtime assertion failed for expression u335 >= 0 on node 'ge_335'"); ge_415 = _assert_scalar_335 = None + all_to_all_single_61 = torch.ops._c10d_functional.all_to_all_single.default(index_40, [_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335], [_local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327], '1033'); index_40 = None + sym_size_int_80 = torch.ops.aten.sym_size.int(all_to_all_single_61, 0) + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_61); all_to_all_single_61 = None + sym_sum_40 = torch.sym_sum((_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335)) + add_1378 = sym_sum_40 + 64; sym_sum_40 = None + add_1379 = add_1378 + 8; add_1378 = None + sub_483 = add_1379 - 1; add_1379 = None + floordiv_20 = sub_483 // 8; sub_483 = None + mul_1002 = floordiv_20 * 8; floordiv_20 = None + cumsum_60 = torch.ops.aten.cumsum.default(wait_tensor_440, 0) + sub_484 = torch.ops.aten.sub.Tensor(cumsum_60, wait_tensor_440); cumsum_60 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_1405, [0]); view_1405 = None + clamp_min_20 = torch.ops.aten.clamp_min.default(sum_84, 8); sum_84 = None + add_1380 = torch.ops.aten.add.Tensor(clamp_min_20, 8); clamp_min_20 = None + sub_485 = torch.ops.aten.sub.Tensor(add_1380, 1); add_1380 = None + div_103 = torch.ops.aten.div.Tensor_mode(sub_485, 8, rounding_mode = 'floor'); sub_485 = None + mul_1003 = torch.ops.aten.mul.Tensor(div_103, 8); div_103 = None + convert_element_type_1148 = torch.ops.prims.convert_element_type.default(mul_1003, torch.int32); mul_1003 = None + cumsum_61 = torch.ops.aten.cumsum.default(convert_element_type_1148, 0) + sub_486 = torch.ops.aten.sub.Tensor(cumsum_61, convert_element_type_1148); cumsum_61 = None + full_280 = torch.ops.aten.full.default([mul_1002], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1002 = None + triton_kernel_wrapper_functional_proxy_20 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 20, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_440, 'start_index_values_ptr': sub_484, 'write_offsets_ptr': sub_486, 'output_ptr': full_280}, tensors_to_clone = ['output_ptr']); wait_tensor_440 = sub_484 = sub_486 = full_280 = None + getitem_2222 = triton_kernel_wrapper_functional_proxy_20['output_ptr']; triton_kernel_wrapper_functional_proxy_20 = None + cat_184 = torch.ops.aten.cat.default([wait_tensor_441, full_default]); wait_tensor_441 = None + sym_size_int_81 = torch.ops.aten.sym_size.int(cat_184, 0) + sym_sum_41 = torch.sym_sum((1, _local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335)) + index_41 = torch.ops.aten.index.Tensor(cat_184, [getitem_2222]); cat_184 = None + convert_element_type_1150 = torch.ops.prims.convert_element_type.default(primals_353, torch.bfloat16) + all_gather_into_tensor_359 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1150, 16, '1025'); convert_element_type_1150 = None + wait_tensor_442 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_359); all_gather_into_tensor_359 = None + split_121 = torch.ops.aten.split.Tensor(wait_tensor_442, 8); wait_tensor_442 = None + getitem_2239 = split_121[0] + getitem_2240 = split_121[1] + getitem_2241 = split_121[2] + getitem_2242 = split_121[3] + getitem_2243 = split_121[4] + getitem_2244 = split_121[5] + getitem_2245 = split_121[6] + getitem_2246 = split_121[7] + getitem_2247 = split_121[8] + getitem_2248 = split_121[9] + getitem_2249 = split_121[10] + getitem_2250 = split_121[11] + getitem_2251 = split_121[12] + getitem_2252 = split_121[13] + getitem_2253 = split_121[14] + getitem_2254 = split_121[15]; split_121 = None + cat_186 = torch.ops.aten.cat.default([getitem_2239, getitem_2240, getitem_2241, getitem_2242, getitem_2243, getitem_2244, getitem_2245, getitem_2246, getitem_2247, getitem_2248, getitem_2249, getitem_2250, getitem_2251, getitem_2252, getitem_2253, getitem_2254], 1); getitem_2239 = getitem_2240 = getitem_2241 = getitem_2242 = getitem_2243 = getitem_2244 = getitem_2245 = getitem_2246 = getitem_2247 = getitem_2248 = getitem_2249 = getitem_2250 = getitem_2251 = getitem_2252 = getitem_2253 = getitem_2254 = None + convert_element_type_1152 = torch.ops.prims.convert_element_type.default(primals_354, torch.bfloat16) + all_gather_into_tensor_361 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1152, 16, '1025'); convert_element_type_1152 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_361); all_gather_into_tensor_361 = None + split_122 = torch.ops.aten.split.Tensor(wait_tensor_444, 8); wait_tensor_444 = None + getitem_2255 = split_122[0] + getitem_2256 = split_122[1] + getitem_2257 = split_122[2] + getitem_2258 = split_122[3] + getitem_2259 = split_122[4] + getitem_2260 = split_122[5] + getitem_2261 = split_122[6] + getitem_2262 = split_122[7] + getitem_2263 = split_122[8] + getitem_2264 = split_122[9] + getitem_2265 = split_122[10] + getitem_2266 = split_122[11] + getitem_2267 = split_122[12] + getitem_2268 = split_122[13] + getitem_2269 = split_122[14] + getitem_2270 = split_122[15]; split_122 = None + cat_187 = torch.ops.aten.cat.default([getitem_2255, getitem_2256, getitem_2257, getitem_2258, getitem_2259, getitem_2260, getitem_2261, getitem_2262, getitem_2263, getitem_2264, getitem_2265, getitem_2266, getitem_2267, getitem_2268, getitem_2269, getitem_2270], 1); getitem_2255 = getitem_2256 = getitem_2257 = getitem_2258 = getitem_2259 = getitem_2260 = getitem_2261 = getitem_2262 = getitem_2263 = getitem_2264 = getitem_2265 = getitem_2266 = getitem_2267 = getitem_2268 = getitem_2269 = getitem_2270 = None + convert_element_type_1153 = torch.ops.prims.convert_element_type.default(primals_355, torch.bfloat16) + all_gather_into_tensor_362 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1153, 16, '1025'); convert_element_type_1153 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_362); all_gather_into_tensor_362 = None + split_123 = torch.ops.aten.split.Tensor(wait_tensor_445, 8); wait_tensor_445 = None + getitem_2271 = split_123[0] + getitem_2272 = split_123[1] + getitem_2273 = split_123[2] + getitem_2274 = split_123[3] + getitem_2275 = split_123[4] + getitem_2276 = split_123[5] + getitem_2277 = split_123[6] + getitem_2278 = split_123[7] + getitem_2279 = split_123[8] + getitem_2280 = split_123[9] + getitem_2281 = split_123[10] + getitem_2282 = split_123[11] + getitem_2283 = split_123[12] + getitem_2284 = split_123[13] + getitem_2285 = split_123[14] + getitem_2286 = split_123[15]; split_123 = None + cat_188 = torch.ops.aten.cat.default([getitem_2271, getitem_2272, getitem_2273, getitem_2274, getitem_2275, getitem_2276, getitem_2277, getitem_2278, getitem_2279, getitem_2280, getitem_2281, getitem_2282, getitem_2283, getitem_2284, getitem_2285, getitem_2286], 1); getitem_2271 = getitem_2272 = getitem_2273 = getitem_2274 = getitem_2275 = getitem_2276 = getitem_2277 = getitem_2278 = getitem_2279 = getitem_2280 = getitem_2281 = getitem_2282 = getitem_2283 = getitem_2284 = getitem_2285 = getitem_2286 = None + cumsum_62 = torch.ops.aten.cumsum.default(convert_element_type_1148, 0, dtype = torch.int32); convert_element_type_1148 = None + permute_320 = torch.ops.aten.permute.default(cat_186, [0, 2, 1]); cat_186 = None + _grouped_mm_60 = torch.ops.aten._grouped_mm.default(index_41, permute_320, cumsum_62) + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(_grouped_mm_60, torch.float32) + neg_41 = torch.ops.aten.neg.default(convert_element_type_1156) + exp_62 = torch.ops.aten.exp.default(neg_41); neg_41 = None + add_1392 = torch.ops.aten.add.Tensor(exp_62, 1); exp_62 = None + div_104 = torch.ops.aten.div.Tensor(convert_element_type_1156, add_1392); convert_element_type_1156 = add_1392 = None + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(div_104, torch.bfloat16); div_104 = None + permute_321 = torch.ops.aten.permute.default(cat_188, [0, 2, 1]); cat_188 = None + _grouped_mm_61 = torch.ops.aten._grouped_mm.default(index_41, permute_321, cumsum_62) + mul_1015 = torch.ops.aten.mul.Tensor(convert_element_type_1157, _grouped_mm_61); convert_element_type_1157 = None + permute_322 = torch.ops.aten.permute.default(cat_187, [0, 2, 1]); cat_187 = None + _grouped_mm_62 = torch.ops.aten._grouped_mm.default(mul_1015, permute_322, cumsum_62) + empty_20 = torch.ops.aten.empty.memory_format([sym_size_int_81, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_40 = torch.ops.aten.index_put.default(empty_20, [getitem_2222], _grouped_mm_62); empty_20 = _grouped_mm_62 = None + slice_131 = torch.ops.aten.slice.Tensor(index_put_40, 0, 0, -1); index_put_40 = None + all_to_all_single_62 = torch.ops._c10d_functional.all_to_all_single.default(slice_131, [_local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327], [_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335], '1033'); slice_131 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_62); all_to_all_single_62 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(primals_356, torch.bfloat16) + all_gather_into_tensor_365 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1158, 128, '0'); convert_element_type_1158 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_365); all_gather_into_tensor_365 = None + permute_323 = torch.ops.aten.permute.default(wait_tensor_449, [1, 0]); wait_tensor_449 = None + mm_172 = torch.ops.aten.mm.default(view_1398, permute_323); permute_323 = None + convert_element_type_1161 = torch.ops.prims.convert_element_type.default(mm_172, torch.float32) + neg_42 = torch.ops.aten.neg.default(convert_element_type_1161) + exp_63 = torch.ops.aten.exp.default(neg_42); neg_42 = None + add_1428 = torch.ops.aten.add.Tensor(exp_63, 1); exp_63 = None + div_105 = torch.ops.aten.div.Tensor(convert_element_type_1161, add_1428); convert_element_type_1161 = add_1428 = None + convert_element_type_1162 = torch.ops.prims.convert_element_type.default(div_105, torch.bfloat16); div_105 = None + convert_element_type_1163 = torch.ops.prims.convert_element_type.default(primals_357, torch.bfloat16) + all_gather_into_tensor_366 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1163, 128, '0'); convert_element_type_1163 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_366); all_gather_into_tensor_366 = None + permute_324 = torch.ops.aten.permute.default(wait_tensor_450, [1, 0]); wait_tensor_450 = None + mm_173 = torch.ops.aten.mm.default(view_1398, permute_324); permute_324 = None + mul_1035 = torch.ops.aten.mul.Tensor(convert_element_type_1162, mm_173); convert_element_type_1162 = None + convert_element_type_1166 = torch.ops.prims.convert_element_type.default(primals_358, torch.bfloat16) + all_gather_into_tensor_367 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1166, 128, '0'); convert_element_type_1166 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_367); all_gather_into_tensor_367 = None + permute_325 = torch.ops.aten.permute.default(wait_tensor_451, [1, 0]); wait_tensor_451 = None + mm_174 = torch.ops.aten.mm.default(mul_1035, permute_325); permute_325 = None + index_put_41 = torch.ops.aten.index_put.default(full_default_1, [getitem_2221], wait_tensor_448); wait_tensor_448 = None + view_1438 = torch.ops.aten.view.default(mul_997, [-1, 1, 6]); mul_997 = None + view_1439 = torch.ops.aten.view.default(index_put_41, [-1, 6, 2048]); index_put_41 = None + convert_element_type_1169 = torch.ops.prims.convert_element_type.default(view_1439, torch.float32); view_1439 = None + bmm_20 = torch.ops.aten.bmm.default(view_1438, convert_element_type_1169) + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(bmm_20, torch.bfloat16); bmm_20 = None + squeeze_20 = torch.ops.aten.squeeze.dim(convert_element_type_1170, 1); convert_element_type_1170 = None + add_1432 = torch.ops.aten.add.Tensor(mm_174, squeeze_20); mm_174 = squeeze_20 = None + view_1440 = torch.ops.aten.view.default(add_1432, [2, 4096, 2048]); add_1432 = None + add_1433 = torch.ops.aten.add.Tensor(add_1368, view_1440); view_1440 = None + convert_element_type_1171 = torch.ops.prims.convert_element_type.default(primals_359, torch.bfloat16) + all_gather_into_tensor_368 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1171, 128, '0'); convert_element_type_1171 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_368); all_gather_into_tensor_368 = None + convert_element_type_1172 = torch.ops.prims.convert_element_type.default(add_1433, torch.float32) + pow_67 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1172, 2) + mean_66 = torch.ops.aten.mean.dim(pow_67, [2], True); pow_67 = None + add_1434 = torch.ops.aten.add.Scalar(mean_66, 1e-05); mean_66 = None + rsqrt_66 = torch.ops.aten.rsqrt.default(add_1434); add_1434 = None + mul_1038 = torch.ops.aten.mul.Tensor(convert_element_type_1172, rsqrt_66); convert_element_type_1172 = None + mul_1039 = torch.ops.aten.mul.Tensor(mul_1038, wait_tensor_452); mul_1038 = wait_tensor_452 = None + convert_element_type_1173 = torch.ops.prims.convert_element_type.default(mul_1039, torch.bfloat16); mul_1039 = None + convert_element_type_1174 = torch.ops.prims.convert_element_type.default(primals_360, torch.bfloat16) + all_gather_into_tensor_369 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1174, 128, '0'); convert_element_type_1174 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_369); all_gather_into_tensor_369 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_453, [1, 0]); wait_tensor_453 = None + view_1443 = torch.ops.aten.view.default(convert_element_type_1173, [8192, 2048]); convert_element_type_1173 = None + mm_175 = torch.ops.aten.mm.default(view_1443, permute_326); permute_326 = None + view_1444 = torch.ops.aten.view.default(mm_175, [2, 4096, 3072]); mm_175 = None + view_1445 = torch.ops.aten.view.default(view_1444, [2, 4096, -1, 192]); view_1444 = None + split_with_sizes_66 = torch.ops.aten.split_with_sizes.default(view_1445, [128, 64], -1); view_1445 = None + getitem_2319 = split_with_sizes_66[0] + getitem_2320 = split_with_sizes_66[1]; split_with_sizes_66 = None + convert_element_type_1177 = torch.ops.prims.convert_element_type.default(getitem_2320, torch.float32); getitem_2320 = None + view_1446 = torch.ops.aten.view.default(convert_element_type_1177, [2, 4096, 16, -1, 2]); convert_element_type_1177 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1446); view_1446 = None + mul_1040 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_7); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_1040); mul_1040 = None + view_1448 = torch.ops.aten.view.default(view_as_real_44, [2, 4096, 16, 64]); view_as_real_44 = None + convert_element_type_1178 = torch.ops.prims.convert_element_type.default(view_1448, torch.bfloat16); view_1448 = None + cat_191 = torch.ops.aten.cat.default([getitem_2319, convert_element_type_1178], -1); getitem_2319 = convert_element_type_1178 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(primals_361, torch.bfloat16) + all_gather_into_tensor_370 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1179, 128, '0'); convert_element_type_1179 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_370); all_gather_into_tensor_370 = None + slice_133 = torch.ops.aten.slice.Tensor(wait_tensor_454, 0, 0, 576); wait_tensor_454 = None + permute_327 = torch.ops.aten.permute.default(slice_133, [1, 0]); slice_133 = None + mm_176 = torch.ops.aten.mm.default(view_1443, permute_327); permute_327 = None + view_1451 = torch.ops.aten.view.default(mm_176, [2, 4096, 576]); mm_176 = None + split_with_sizes_67 = torch.ops.aten.split_with_sizes.default(view_1451, [512, 64], -1); view_1451 = None + getitem_2321 = split_with_sizes_67[0] + getitem_2322 = split_with_sizes_67[1]; split_with_sizes_67 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(getitem_2322, 2); getitem_2322 = None + convert_element_type_1182 = torch.ops.prims.convert_element_type.default(unsqueeze_43, torch.float32); unsqueeze_43 = None + view_1452 = torch.ops.aten.view.default(convert_element_type_1182, [2, 4096, 1, -1, 2]); convert_element_type_1182 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1452); view_1452 = None + mul_1041 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_7); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_1041); mul_1041 = None + view_1454 = torch.ops.aten.view.default(view_as_real_45, [2, 4096, 1, 64]); view_as_real_45 = None + convert_element_type_1183 = torch.ops.prims.convert_element_type.default(view_1454, torch.bfloat16); view_1454 = None + convert_element_type_1184 = torch.ops.prims.convert_element_type.default(primals_362, torch.bfloat16) + all_gather_into_tensor_371 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1184, 128, '0'); convert_element_type_1184 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_371); all_gather_into_tensor_371 = None + convert_element_type_1185 = torch.ops.prims.convert_element_type.default(getitem_2321, torch.float32) + pow_68 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1185, 2) + mean_67 = torch.ops.aten.mean.dim(pow_68, [2], True); pow_68 = None + add_1435 = torch.ops.aten.add.Scalar(mean_67, 1e-05); mean_67 = None + rsqrt_67 = torch.ops.aten.rsqrt.default(add_1435); add_1435 = None + mul_1042 = torch.ops.aten.mul.Tensor(convert_element_type_1185, rsqrt_67); convert_element_type_1185 = None + mul_1043 = torch.ops.aten.mul.Tensor(mul_1042, wait_tensor_455); mul_1042 = wait_tensor_455 = None + convert_element_type_1186 = torch.ops.prims.convert_element_type.default(mul_1043, torch.bfloat16); mul_1043 = None + convert_element_type_1187 = torch.ops.prims.convert_element_type.default(primals_363, torch.bfloat16) + all_gather_into_tensor_372 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1187, 128, '0'); convert_element_type_1187 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_372); all_gather_into_tensor_372 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_456, [1, 0]); wait_tensor_456 = None + view_1457 = torch.ops.aten.view.default(convert_element_type_1186, [8192, 512]); convert_element_type_1186 = None + mm_177 = torch.ops.aten.mm.default(view_1457, permute_328); permute_328 = None + view_1458 = torch.ops.aten.view.default(mm_177, [2, 4096, 4096]); mm_177 = None + view_1459 = torch.ops.aten.view.default(view_1458, [2, 4096, -1, 256]); view_1458 = None + split_with_sizes_68 = torch.ops.aten.split_with_sizes.default(view_1459, [128, 128], -1); view_1459 = None + getitem_2323 = split_with_sizes_68[0] + getitem_2324 = split_with_sizes_68[1]; split_with_sizes_68 = None + expand_22 = torch.ops.aten.expand.default(convert_element_type_1183, [-1, -1, 16, -1]); convert_element_type_1183 = None + cat_192 = torch.ops.aten.cat.default([getitem_2323, expand_22], -1); getitem_2323 = expand_22 = None + permute_329 = torch.ops.aten.permute.default(cat_191, [0, 2, 1, 3]); cat_191 = None + permute_330 = torch.ops.aten.permute.default(cat_192, [0, 2, 1, 3]); cat_192 = None + permute_331 = torch.ops.aten.permute.default(getitem_2324, [0, 2, 1, 3]); getitem_2324 = None + sdpa_score22 = self.sdpa_score22 + sdpa_mask22 = self.sdpa_mask22 + flex_attention_22 = torch.ops.higher_order.flex_attention(permute_329, permute_330, permute_331, sdpa_score22, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask22), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score22 = sdpa_mask22 = None + getitem_2325 = flex_attention_22[0] + getitem_2326 = flex_attention_22[1]; flex_attention_22 = None + permute_332 = torch.ops.aten.permute.default(getitem_2325, [0, 2, 1, 3]) + view_1460 = torch.ops.aten.view.default(permute_332, [2, 4096, -1]); permute_332 = None + convert_element_type_1190 = torch.ops.prims.convert_element_type.default(primals_364, torch.bfloat16) + all_gather_into_tensor_373 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1190, 128, '0'); convert_element_type_1190 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_373); all_gather_into_tensor_373 = None + permute_333 = torch.ops.aten.permute.default(wait_tensor_457, [1, 0]); wait_tensor_457 = None + view_1462 = torch.ops.aten.view.default(view_1460, [8192, 2048]); view_1460 = None + mm_178 = torch.ops.aten.mm.default(view_1462, permute_333); view_1462 = permute_333 = None + view_1463 = torch.ops.aten.view.default(mm_178, [2, 4096, 2048]); mm_178 = None + add_1436 = torch.ops.aten.add.Tensor(add_1433, view_1463); view_1463 = None + convert_element_type_1193 = torch.ops.prims.convert_element_type.default(primals_365, torch.bfloat16) + all_gather_into_tensor_374 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1193, 128, '0'); convert_element_type_1193 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_374); all_gather_into_tensor_374 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(add_1436, torch.float32) + pow_69 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1194, 2) + mean_68 = torch.ops.aten.mean.dim(pow_69, [2], True); pow_69 = None + add_1437 = torch.ops.aten.add.Scalar(mean_68, 1e-05); mean_68 = None + rsqrt_68 = torch.ops.aten.rsqrt.default(add_1437); add_1437 = None + mul_1044 = torch.ops.aten.mul.Tensor(convert_element_type_1194, rsqrt_68); convert_element_type_1194 = None + mul_1045 = torch.ops.aten.mul.Tensor(mul_1044, wait_tensor_458); mul_1044 = wait_tensor_458 = None + convert_element_type_1195 = torch.ops.prims.convert_element_type.default(mul_1045, torch.bfloat16); mul_1045 = None + view_1465 = torch.ops.aten.view.default(convert_element_type_1195, [-1, 2048]); convert_element_type_1195 = None + convert_element_type_1196 = torch.ops.prims.convert_element_type.default(primals_367, torch.bfloat16) + all_gather_into_tensor_375 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1196, 128, '0'); convert_element_type_1196 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_375); all_gather_into_tensor_375 = None + slice_135 = torch.ops.aten.slice.Tensor(wait_tensor_459, 0, 0, 64); wait_tensor_459 = None + permute_334 = torch.ops.aten.permute.default(slice_135, [1, 0]); slice_135 = None + mm_179 = torch.ops.aten.mm.default(view_1465, permute_334); permute_334 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_179, torch.float32) + amax_21 = torch.ops.aten.amax.default(convert_element_type_1199, [1], True) + sub_504 = torch.ops.aten.sub.Tensor(convert_element_type_1199, amax_21); convert_element_type_1199 = None + exp_64 = torch.ops.aten.exp.default(sub_504); sub_504 = None + sum_85 = torch.ops.aten.sum.dim_IntList(exp_64, [1], True) + div_106 = torch.ops.aten.div.Tensor(exp_64, sum_85); exp_64 = None + add_1438 = torch.ops.aten.add.Tensor(div_106, primals_366); primals_366 = None + topk_21 = torch.ops.aten.topk.default(add_1438, 6, -1, True, False); add_1438 = None + getitem_2329 = topk_21[1]; topk_21 = None + gather_21 = torch.ops.aten.gather.default(div_106, 1, getitem_2329); div_106 = None + mul_1046 = torch.ops.aten.mul.Tensor(gather_21, 1.0); gather_21 = None + view_1467 = torch.ops.aten.view.default(getitem_2329, [-1]) + histc_42 = torch.ops.aten.histc.default(view_1467, 64, 0, 64) + add_1439 = torch.ops.aten.add.Tensor(primals_368, histc_42) + sort_21 = torch.ops.aten.sort.stable(view_1467, stable = True); view_1467 = None + getitem_2331 = sort_21[1]; sort_21 = None + div_107 = torch.ops.aten.div.Tensor_mode(getitem_2331, 6, rounding_mode = 'floor') + index_42 = torch.ops.aten.index.Tensor(view_1465, [div_107]) + all_to_all_single_63 = torch.ops._c10d_functional.all_to_all_single.default(histc_42, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_460 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_63); all_to_all_single_63 = None + wait_tensor_461 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_460); wait_tensor_460 = None + view_1471 = torch.ops.aten.view.default(histc_42, [8, -1]); histc_42 = None + sum_86 = torch.ops.aten.sum.dim_IntList(view_1471, [1]); view_1471 = None + device_put_42 = torch.ops.prims.device_put.default(sum_86, device(type='cpu'), True); sum_86 = None + view_1472 = torch.ops.aten.view.default(wait_tensor_461, [8, -1]) + sum_87 = torch.ops.aten.sum.dim_IntList(view_1472, [1]) + device_put_43 = torch.ops.prims.device_put.default(sum_87, device(type='cpu')); sum_87 = None + select_336 = torch.ops.aten.select.int(device_put_42, 0, 0) + _local_scalar_dense_336 = torch.ops.aten._local_scalar_dense.default(select_336); select_336 = None + ge_420 = _local_scalar_dense_336 >= 0 + _assert_scalar_336 = torch.ops.aten._assert_scalar.default(ge_420, "Runtime assertion failed for expression u336 >= 0 on node 'ge_336'"); ge_420 = _assert_scalar_336 = None + select_337 = torch.ops.aten.select.int(device_put_42, 0, 1) + _local_scalar_dense_337 = torch.ops.aten._local_scalar_dense.default(select_337); select_337 = None + ge_421 = _local_scalar_dense_337 >= 0 + _assert_scalar_337 = torch.ops.aten._assert_scalar.default(ge_421, "Runtime assertion failed for expression u337 >= 0 on node 'ge_337'"); ge_421 = _assert_scalar_337 = None + select_338 = torch.ops.aten.select.int(device_put_42, 0, 2) + _local_scalar_dense_338 = torch.ops.aten._local_scalar_dense.default(select_338); select_338 = None + ge_422 = _local_scalar_dense_338 >= 0 + _assert_scalar_338 = torch.ops.aten._assert_scalar.default(ge_422, "Runtime assertion failed for expression u338 >= 0 on node 'ge_338'"); ge_422 = _assert_scalar_338 = None + select_339 = torch.ops.aten.select.int(device_put_42, 0, 3) + _local_scalar_dense_339 = torch.ops.aten._local_scalar_dense.default(select_339); select_339 = None + ge_423 = _local_scalar_dense_339 >= 0 + _assert_scalar_339 = torch.ops.aten._assert_scalar.default(ge_423, "Runtime assertion failed for expression u339 >= 0 on node 'ge_339'"); ge_423 = _assert_scalar_339 = None + select_340 = torch.ops.aten.select.int(device_put_42, 0, 4) + _local_scalar_dense_340 = torch.ops.aten._local_scalar_dense.default(select_340); select_340 = None + ge_424 = _local_scalar_dense_340 >= 0 + _assert_scalar_340 = torch.ops.aten._assert_scalar.default(ge_424, "Runtime assertion failed for expression u340 >= 0 on node 'ge_340'"); ge_424 = _assert_scalar_340 = None + select_341 = torch.ops.aten.select.int(device_put_42, 0, 5) + _local_scalar_dense_341 = torch.ops.aten._local_scalar_dense.default(select_341); select_341 = None + ge_425 = _local_scalar_dense_341 >= 0 + _assert_scalar_341 = torch.ops.aten._assert_scalar.default(ge_425, "Runtime assertion failed for expression u341 >= 0 on node 'ge_341'"); ge_425 = _assert_scalar_341 = None + select_342 = torch.ops.aten.select.int(device_put_42, 0, 6) + _local_scalar_dense_342 = torch.ops.aten._local_scalar_dense.default(select_342); select_342 = None + ge_426 = _local_scalar_dense_342 >= 0 + _assert_scalar_342 = torch.ops.aten._assert_scalar.default(ge_426, "Runtime assertion failed for expression u342 >= 0 on node 'ge_342'"); ge_426 = _assert_scalar_342 = None + select_343 = torch.ops.aten.select.int(device_put_42, 0, 7); device_put_42 = None + _local_scalar_dense_343 = torch.ops.aten._local_scalar_dense.default(select_343); select_343 = None + ge_427 = _local_scalar_dense_343 >= 0 + _assert_scalar_343 = torch.ops.aten._assert_scalar.default(ge_427, "Runtime assertion failed for expression u343 >= 0 on node 'ge_343'"); ge_427 = _assert_scalar_343 = None + select_344 = torch.ops.aten.select.int(device_put_43, 0, 0) + _local_scalar_dense_344 = torch.ops.aten._local_scalar_dense.default(select_344); select_344 = None + ge_428 = _local_scalar_dense_344 >= 0 + _assert_scalar_344 = torch.ops.aten._assert_scalar.default(ge_428, "Runtime assertion failed for expression u344 >= 0 on node 'ge_344'"); ge_428 = _assert_scalar_344 = None + select_345 = torch.ops.aten.select.int(device_put_43, 0, 1) + _local_scalar_dense_345 = torch.ops.aten._local_scalar_dense.default(select_345); select_345 = None + ge_429 = _local_scalar_dense_345 >= 0 + _assert_scalar_345 = torch.ops.aten._assert_scalar.default(ge_429, "Runtime assertion failed for expression u345 >= 0 on node 'ge_345'"); ge_429 = _assert_scalar_345 = None + select_346 = torch.ops.aten.select.int(device_put_43, 0, 2) + _local_scalar_dense_346 = torch.ops.aten._local_scalar_dense.default(select_346); select_346 = None + ge_430 = _local_scalar_dense_346 >= 0 + _assert_scalar_346 = torch.ops.aten._assert_scalar.default(ge_430, "Runtime assertion failed for expression u346 >= 0 on node 'ge_346'"); ge_430 = _assert_scalar_346 = None + select_347 = torch.ops.aten.select.int(device_put_43, 0, 3) + _local_scalar_dense_347 = torch.ops.aten._local_scalar_dense.default(select_347); select_347 = None + ge_431 = _local_scalar_dense_347 >= 0 + _assert_scalar_347 = torch.ops.aten._assert_scalar.default(ge_431, "Runtime assertion failed for expression u347 >= 0 on node 'ge_347'"); ge_431 = _assert_scalar_347 = None + select_348 = torch.ops.aten.select.int(device_put_43, 0, 4) + _local_scalar_dense_348 = torch.ops.aten._local_scalar_dense.default(select_348); select_348 = None + ge_432 = _local_scalar_dense_348 >= 0 + _assert_scalar_348 = torch.ops.aten._assert_scalar.default(ge_432, "Runtime assertion failed for expression u348 >= 0 on node 'ge_348'"); ge_432 = _assert_scalar_348 = None + select_349 = torch.ops.aten.select.int(device_put_43, 0, 5) + _local_scalar_dense_349 = torch.ops.aten._local_scalar_dense.default(select_349); select_349 = None + ge_433 = _local_scalar_dense_349 >= 0 + _assert_scalar_349 = torch.ops.aten._assert_scalar.default(ge_433, "Runtime assertion failed for expression u349 >= 0 on node 'ge_349'"); ge_433 = _assert_scalar_349 = None + select_350 = torch.ops.aten.select.int(device_put_43, 0, 6) + _local_scalar_dense_350 = torch.ops.aten._local_scalar_dense.default(select_350); select_350 = None + ge_434 = _local_scalar_dense_350 >= 0 + _assert_scalar_350 = torch.ops.aten._assert_scalar.default(ge_434, "Runtime assertion failed for expression u350 >= 0 on node 'ge_350'"); ge_434 = _assert_scalar_350 = None + select_351 = torch.ops.aten.select.int(device_put_43, 0, 7); device_put_43 = None + _local_scalar_dense_351 = torch.ops.aten._local_scalar_dense.default(select_351); select_351 = None + ge_435 = _local_scalar_dense_351 >= 0 + _assert_scalar_351 = torch.ops.aten._assert_scalar.default(ge_435, "Runtime assertion failed for expression u351 >= 0 on node 'ge_351'"); ge_435 = _assert_scalar_351 = None + all_to_all_single_64 = torch.ops._c10d_functional.all_to_all_single.default(index_42, [_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351], [_local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343], '1033'); index_42 = None + sym_size_int_84 = torch.ops.aten.sym_size.int(all_to_all_single_64, 0) + wait_tensor_462 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_64); all_to_all_single_64 = None + sym_sum_42 = torch.sym_sum((_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351)) + add_1446 = sym_sum_42 + 64; sym_sum_42 = None + add_1447 = add_1446 + 8; add_1446 = None + sub_507 = add_1447 - 1; add_1447 = None + floordiv_21 = sub_507 // 8; sub_507 = None + mul_1051 = floordiv_21 * 8; floordiv_21 = None + cumsum_63 = torch.ops.aten.cumsum.default(wait_tensor_461, 0) + sub_508 = torch.ops.aten.sub.Tensor(cumsum_63, wait_tensor_461); cumsum_63 = None + sum_88 = torch.ops.aten.sum.dim_IntList(view_1472, [0]); view_1472 = None + clamp_min_21 = torch.ops.aten.clamp_min.default(sum_88, 8); sum_88 = None + add_1448 = torch.ops.aten.add.Tensor(clamp_min_21, 8); clamp_min_21 = None + sub_509 = torch.ops.aten.sub.Tensor(add_1448, 1); add_1448 = None + div_108 = torch.ops.aten.div.Tensor_mode(sub_509, 8, rounding_mode = 'floor'); sub_509 = None + mul_1052 = torch.ops.aten.mul.Tensor(div_108, 8); div_108 = None + convert_element_type_1202 = torch.ops.prims.convert_element_type.default(mul_1052, torch.int32); mul_1052 = None + cumsum_64 = torch.ops.aten.cumsum.default(convert_element_type_1202, 0) + sub_510 = torch.ops.aten.sub.Tensor(cumsum_64, convert_element_type_1202); cumsum_64 = None + full_293 = torch.ops.aten.full.default([mul_1051], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1051 = None + triton_kernel_wrapper_functional_proxy_21 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 21, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_461, 'start_index_values_ptr': sub_508, 'write_offsets_ptr': sub_510, 'output_ptr': full_293}, tensors_to_clone = ['output_ptr']); wait_tensor_461 = sub_508 = sub_510 = full_293 = None + getitem_2332 = triton_kernel_wrapper_functional_proxy_21['output_ptr']; triton_kernel_wrapper_functional_proxy_21 = None + cat_193 = torch.ops.aten.cat.default([wait_tensor_462, full_default]); wait_tensor_462 = None + sym_size_int_85 = torch.ops.aten.sym_size.int(cat_193, 0) + sym_sum_43 = torch.sym_sum((1, _local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351)) + index_43 = torch.ops.aten.index.Tensor(cat_193, [getitem_2332]); cat_193 = None + convert_element_type_1204 = torch.ops.prims.convert_element_type.default(primals_369, torch.bfloat16) + all_gather_into_tensor_376 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1204, 16, '1025'); convert_element_type_1204 = None + wait_tensor_463 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_376); all_gather_into_tensor_376 = None + split_127 = torch.ops.aten.split.Tensor(wait_tensor_463, 8); wait_tensor_463 = None + getitem_2349 = split_127[0] + getitem_2350 = split_127[1] + getitem_2351 = split_127[2] + getitem_2352 = split_127[3] + getitem_2353 = split_127[4] + getitem_2354 = split_127[5] + getitem_2355 = split_127[6] + getitem_2356 = split_127[7] + getitem_2357 = split_127[8] + getitem_2358 = split_127[9] + getitem_2359 = split_127[10] + getitem_2360 = split_127[11] + getitem_2361 = split_127[12] + getitem_2362 = split_127[13] + getitem_2363 = split_127[14] + getitem_2364 = split_127[15]; split_127 = None + cat_195 = torch.ops.aten.cat.default([getitem_2349, getitem_2350, getitem_2351, getitem_2352, getitem_2353, getitem_2354, getitem_2355, getitem_2356, getitem_2357, getitem_2358, getitem_2359, getitem_2360, getitem_2361, getitem_2362, getitem_2363, getitem_2364], 1); getitem_2349 = getitem_2350 = getitem_2351 = getitem_2352 = getitem_2353 = getitem_2354 = getitem_2355 = getitem_2356 = getitem_2357 = getitem_2358 = getitem_2359 = getitem_2360 = getitem_2361 = getitem_2362 = getitem_2363 = getitem_2364 = None + convert_element_type_1206 = torch.ops.prims.convert_element_type.default(primals_370, torch.bfloat16) + all_gather_into_tensor_378 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1206, 16, '1025'); convert_element_type_1206 = None + wait_tensor_465 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_378); all_gather_into_tensor_378 = None + split_128 = torch.ops.aten.split.Tensor(wait_tensor_465, 8); wait_tensor_465 = None + getitem_2365 = split_128[0] + getitem_2366 = split_128[1] + getitem_2367 = split_128[2] + getitem_2368 = split_128[3] + getitem_2369 = split_128[4] + getitem_2370 = split_128[5] + getitem_2371 = split_128[6] + getitem_2372 = split_128[7] + getitem_2373 = split_128[8] + getitem_2374 = split_128[9] + getitem_2375 = split_128[10] + getitem_2376 = split_128[11] + getitem_2377 = split_128[12] + getitem_2378 = split_128[13] + getitem_2379 = split_128[14] + getitem_2380 = split_128[15]; split_128 = None + cat_196 = torch.ops.aten.cat.default([getitem_2365, getitem_2366, getitem_2367, getitem_2368, getitem_2369, getitem_2370, getitem_2371, getitem_2372, getitem_2373, getitem_2374, getitem_2375, getitem_2376, getitem_2377, getitem_2378, getitem_2379, getitem_2380], 1); getitem_2365 = getitem_2366 = getitem_2367 = getitem_2368 = getitem_2369 = getitem_2370 = getitem_2371 = getitem_2372 = getitem_2373 = getitem_2374 = getitem_2375 = getitem_2376 = getitem_2377 = getitem_2378 = getitem_2379 = getitem_2380 = None + convert_element_type_1207 = torch.ops.prims.convert_element_type.default(primals_371, torch.bfloat16) + all_gather_into_tensor_379 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1207, 16, '1025'); convert_element_type_1207 = None + wait_tensor_466 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_379); all_gather_into_tensor_379 = None + split_129 = torch.ops.aten.split.Tensor(wait_tensor_466, 8); wait_tensor_466 = None + getitem_2381 = split_129[0] + getitem_2382 = split_129[1] + getitem_2383 = split_129[2] + getitem_2384 = split_129[3] + getitem_2385 = split_129[4] + getitem_2386 = split_129[5] + getitem_2387 = split_129[6] + getitem_2388 = split_129[7] + getitem_2389 = split_129[8] + getitem_2390 = split_129[9] + getitem_2391 = split_129[10] + getitem_2392 = split_129[11] + getitem_2393 = split_129[12] + getitem_2394 = split_129[13] + getitem_2395 = split_129[14] + getitem_2396 = split_129[15]; split_129 = None + cat_197 = torch.ops.aten.cat.default([getitem_2381, getitem_2382, getitem_2383, getitem_2384, getitem_2385, getitem_2386, getitem_2387, getitem_2388, getitem_2389, getitem_2390, getitem_2391, getitem_2392, getitem_2393, getitem_2394, getitem_2395, getitem_2396], 1); getitem_2381 = getitem_2382 = getitem_2383 = getitem_2384 = getitem_2385 = getitem_2386 = getitem_2387 = getitem_2388 = getitem_2389 = getitem_2390 = getitem_2391 = getitem_2392 = getitem_2393 = getitem_2394 = getitem_2395 = getitem_2396 = None + cumsum_65 = torch.ops.aten.cumsum.default(convert_element_type_1202, 0, dtype = torch.int32); convert_element_type_1202 = None + permute_335 = torch.ops.aten.permute.default(cat_195, [0, 2, 1]); cat_195 = None + _grouped_mm_63 = torch.ops.aten._grouped_mm.default(index_43, permute_335, cumsum_65) + convert_element_type_1210 = torch.ops.prims.convert_element_type.default(_grouped_mm_63, torch.float32) + neg_43 = torch.ops.aten.neg.default(convert_element_type_1210) + exp_65 = torch.ops.aten.exp.default(neg_43); neg_43 = None + add_1460 = torch.ops.aten.add.Tensor(exp_65, 1); exp_65 = None + div_109 = torch.ops.aten.div.Tensor(convert_element_type_1210, add_1460); convert_element_type_1210 = add_1460 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(div_109, torch.bfloat16); div_109 = None + permute_336 = torch.ops.aten.permute.default(cat_197, [0, 2, 1]); cat_197 = None + _grouped_mm_64 = torch.ops.aten._grouped_mm.default(index_43, permute_336, cumsum_65) + mul_1064 = torch.ops.aten.mul.Tensor(convert_element_type_1211, _grouped_mm_64); convert_element_type_1211 = None + permute_337 = torch.ops.aten.permute.default(cat_196, [0, 2, 1]); cat_196 = None + _grouped_mm_65 = torch.ops.aten._grouped_mm.default(mul_1064, permute_337, cumsum_65) + empty_21 = torch.ops.aten.empty.memory_format([sym_size_int_85, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_42 = torch.ops.aten.index_put.default(empty_21, [getitem_2332], _grouped_mm_65); empty_21 = _grouped_mm_65 = None + slice_137 = torch.ops.aten.slice.Tensor(index_put_42, 0, 0, -1); index_put_42 = None + all_to_all_single_65 = torch.ops._c10d_functional.all_to_all_single.default(slice_137, [_local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343], [_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351], '1033'); slice_137 = None + wait_tensor_469 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_65); all_to_all_single_65 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(primals_372, torch.bfloat16) + all_gather_into_tensor_382 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1212, 128, '0'); convert_element_type_1212 = None + wait_tensor_470 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_382); all_gather_into_tensor_382 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_470, [1, 0]); wait_tensor_470 = None + mm_180 = torch.ops.aten.mm.default(view_1465, permute_338); permute_338 = None + convert_element_type_1215 = torch.ops.prims.convert_element_type.default(mm_180, torch.float32) + neg_44 = torch.ops.aten.neg.default(convert_element_type_1215) + exp_66 = torch.ops.aten.exp.default(neg_44); neg_44 = None + add_1496 = torch.ops.aten.add.Tensor(exp_66, 1); exp_66 = None + div_110 = torch.ops.aten.div.Tensor(convert_element_type_1215, add_1496); convert_element_type_1215 = add_1496 = None + convert_element_type_1216 = torch.ops.prims.convert_element_type.default(div_110, torch.bfloat16); div_110 = None + convert_element_type_1217 = torch.ops.prims.convert_element_type.default(primals_373, torch.bfloat16) + all_gather_into_tensor_383 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1217, 128, '0'); convert_element_type_1217 = None + wait_tensor_471 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_383); all_gather_into_tensor_383 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_471, [1, 0]); wait_tensor_471 = None + mm_181 = torch.ops.aten.mm.default(view_1465, permute_339); permute_339 = None + mul_1084 = torch.ops.aten.mul.Tensor(convert_element_type_1216, mm_181); convert_element_type_1216 = None + convert_element_type_1220 = torch.ops.prims.convert_element_type.default(primals_374, torch.bfloat16) + all_gather_into_tensor_384 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1220, 128, '0'); convert_element_type_1220 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_384); all_gather_into_tensor_384 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_472, [1, 0]); wait_tensor_472 = None + mm_182 = torch.ops.aten.mm.default(mul_1084, permute_340); permute_340 = None + index_put_43 = torch.ops.aten.index_put.default(full_default_1, [getitem_2331], wait_tensor_469); wait_tensor_469 = None + view_1505 = torch.ops.aten.view.default(mul_1046, [-1, 1, 6]); mul_1046 = None + view_1506 = torch.ops.aten.view.default(index_put_43, [-1, 6, 2048]); index_put_43 = None + convert_element_type_1223 = torch.ops.prims.convert_element_type.default(view_1506, torch.float32); view_1506 = None + bmm_21 = torch.ops.aten.bmm.default(view_1505, convert_element_type_1223) + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(bmm_21, torch.bfloat16); bmm_21 = None + squeeze_21 = torch.ops.aten.squeeze.dim(convert_element_type_1224, 1); convert_element_type_1224 = None + add_1500 = torch.ops.aten.add.Tensor(mm_182, squeeze_21); mm_182 = squeeze_21 = None + view_1507 = torch.ops.aten.view.default(add_1500, [2, 4096, 2048]); add_1500 = None + add_1501 = torch.ops.aten.add.Tensor(add_1436, view_1507); view_1507 = None + convert_element_type_1225 = torch.ops.prims.convert_element_type.default(primals_375, torch.bfloat16) + all_gather_into_tensor_385 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1225, 128, '0'); convert_element_type_1225 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_385); all_gather_into_tensor_385 = None + convert_element_type_1226 = torch.ops.prims.convert_element_type.default(add_1501, torch.float32) + pow_70 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1226, 2) + mean_69 = torch.ops.aten.mean.dim(pow_70, [2], True); pow_70 = None + add_1502 = torch.ops.aten.add.Scalar(mean_69, 1e-05); mean_69 = None + rsqrt_69 = torch.ops.aten.rsqrt.default(add_1502); add_1502 = None + mul_1087 = torch.ops.aten.mul.Tensor(convert_element_type_1226, rsqrt_69); convert_element_type_1226 = None + mul_1088 = torch.ops.aten.mul.Tensor(mul_1087, wait_tensor_473); mul_1087 = wait_tensor_473 = None + convert_element_type_1227 = torch.ops.prims.convert_element_type.default(mul_1088, torch.bfloat16); mul_1088 = None + convert_element_type_1228 = torch.ops.prims.convert_element_type.default(primals_376, torch.bfloat16) + all_gather_into_tensor_386 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1228, 128, '0'); convert_element_type_1228 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_386); all_gather_into_tensor_386 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_474, [1, 0]); wait_tensor_474 = None + view_1510 = torch.ops.aten.view.default(convert_element_type_1227, [8192, 2048]); convert_element_type_1227 = None + mm_183 = torch.ops.aten.mm.default(view_1510, permute_341); permute_341 = None + view_1511 = torch.ops.aten.view.default(mm_183, [2, 4096, 3072]); mm_183 = None + view_1512 = torch.ops.aten.view.default(view_1511, [2, 4096, -1, 192]); view_1511 = None + split_with_sizes_69 = torch.ops.aten.split_with_sizes.default(view_1512, [128, 64], -1); view_1512 = None + getitem_2429 = split_with_sizes_69[0] + getitem_2430 = split_with_sizes_69[1]; split_with_sizes_69 = None + convert_element_type_1231 = torch.ops.prims.convert_element_type.default(getitem_2430, torch.float32); getitem_2430 = None + view_1513 = torch.ops.aten.view.default(convert_element_type_1231, [2, 4096, 16, -1, 2]); convert_element_type_1231 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1513); view_1513 = None + mul_1089 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_7); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_1089); mul_1089 = None + view_1515 = torch.ops.aten.view.default(view_as_real_46, [2, 4096, 16, 64]); view_as_real_46 = None + convert_element_type_1232 = torch.ops.prims.convert_element_type.default(view_1515, torch.bfloat16); view_1515 = None + cat_200 = torch.ops.aten.cat.default([getitem_2429, convert_element_type_1232], -1); getitem_2429 = convert_element_type_1232 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(primals_377, torch.bfloat16) + all_gather_into_tensor_387 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1233, 128, '0'); convert_element_type_1233 = None + wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_387); all_gather_into_tensor_387 = None + slice_139 = torch.ops.aten.slice.Tensor(wait_tensor_475, 0, 0, 576); wait_tensor_475 = None + permute_342 = torch.ops.aten.permute.default(slice_139, [1, 0]); slice_139 = None + mm_184 = torch.ops.aten.mm.default(view_1510, permute_342); permute_342 = None + view_1518 = torch.ops.aten.view.default(mm_184, [2, 4096, 576]); mm_184 = None + split_with_sizes_70 = torch.ops.aten.split_with_sizes.default(view_1518, [512, 64], -1); view_1518 = None + getitem_2431 = split_with_sizes_70[0] + getitem_2432 = split_with_sizes_70[1]; split_with_sizes_70 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(getitem_2432, 2); getitem_2432 = None + convert_element_type_1236 = torch.ops.prims.convert_element_type.default(unsqueeze_45, torch.float32); unsqueeze_45 = None + view_1519 = torch.ops.aten.view.default(convert_element_type_1236, [2, 4096, 1, -1, 2]); convert_element_type_1236 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1519); view_1519 = None + mul_1090 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_7); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_1090); mul_1090 = None + view_1521 = torch.ops.aten.view.default(view_as_real_47, [2, 4096, 1, 64]); view_as_real_47 = None + convert_element_type_1237 = torch.ops.prims.convert_element_type.default(view_1521, torch.bfloat16); view_1521 = None + convert_element_type_1238 = torch.ops.prims.convert_element_type.default(primals_378, torch.bfloat16) + all_gather_into_tensor_388 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1238, 128, '0'); convert_element_type_1238 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_388); all_gather_into_tensor_388 = None + convert_element_type_1239 = torch.ops.prims.convert_element_type.default(getitem_2431, torch.float32) + pow_71 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1239, 2) + mean_70 = torch.ops.aten.mean.dim(pow_71, [2], True); pow_71 = None + add_1503 = torch.ops.aten.add.Scalar(mean_70, 1e-05); mean_70 = None + rsqrt_70 = torch.ops.aten.rsqrt.default(add_1503); add_1503 = None + mul_1091 = torch.ops.aten.mul.Tensor(convert_element_type_1239, rsqrt_70); convert_element_type_1239 = None + mul_1092 = torch.ops.aten.mul.Tensor(mul_1091, wait_tensor_476); mul_1091 = wait_tensor_476 = None + convert_element_type_1240 = torch.ops.prims.convert_element_type.default(mul_1092, torch.bfloat16); mul_1092 = None + convert_element_type_1241 = torch.ops.prims.convert_element_type.default(primals_379, torch.bfloat16) + all_gather_into_tensor_389 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1241, 128, '0'); convert_element_type_1241 = None + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_389); all_gather_into_tensor_389 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_477, [1, 0]); wait_tensor_477 = None + view_1524 = torch.ops.aten.view.default(convert_element_type_1240, [8192, 512]); convert_element_type_1240 = None + mm_185 = torch.ops.aten.mm.default(view_1524, permute_343); permute_343 = None + view_1525 = torch.ops.aten.view.default(mm_185, [2, 4096, 4096]); mm_185 = None + view_1526 = torch.ops.aten.view.default(view_1525, [2, 4096, -1, 256]); view_1525 = None + split_with_sizes_71 = torch.ops.aten.split_with_sizes.default(view_1526, [128, 128], -1); view_1526 = None + getitem_2433 = split_with_sizes_71[0] + getitem_2434 = split_with_sizes_71[1]; split_with_sizes_71 = None + expand_23 = torch.ops.aten.expand.default(convert_element_type_1237, [-1, -1, 16, -1]); convert_element_type_1237 = None + cat_201 = torch.ops.aten.cat.default([getitem_2433, expand_23], -1); getitem_2433 = expand_23 = None + permute_344 = torch.ops.aten.permute.default(cat_200, [0, 2, 1, 3]); cat_200 = None + permute_345 = torch.ops.aten.permute.default(cat_201, [0, 2, 1, 3]); cat_201 = None + permute_346 = torch.ops.aten.permute.default(getitem_2434, [0, 2, 1, 3]); getitem_2434 = None + sdpa_score23 = self.sdpa_score23 + sdpa_mask23 = self.sdpa_mask23 + flex_attention_23 = torch.ops.higher_order.flex_attention(permute_344, permute_345, permute_346, sdpa_score23, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask23), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score23 = sdpa_mask23 = None + getitem_2435 = flex_attention_23[0] + getitem_2436 = flex_attention_23[1]; flex_attention_23 = None + permute_347 = torch.ops.aten.permute.default(getitem_2435, [0, 2, 1, 3]) + view_1527 = torch.ops.aten.view.default(permute_347, [2, 4096, -1]); permute_347 = None + convert_element_type_1244 = torch.ops.prims.convert_element_type.default(primals_380, torch.bfloat16) + all_gather_into_tensor_390 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1244, 128, '0'); convert_element_type_1244 = None + wait_tensor_478 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_390); all_gather_into_tensor_390 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_478, [1, 0]); wait_tensor_478 = None + view_1529 = torch.ops.aten.view.default(view_1527, [8192, 2048]); view_1527 = None + mm_186 = torch.ops.aten.mm.default(view_1529, permute_348); view_1529 = permute_348 = None + view_1530 = torch.ops.aten.view.default(mm_186, [2, 4096, 2048]); mm_186 = None + add_1504 = torch.ops.aten.add.Tensor(add_1501, view_1530); view_1530 = None + convert_element_type_1247 = torch.ops.prims.convert_element_type.default(primals_381, torch.bfloat16) + all_gather_into_tensor_391 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1247, 128, '0'); convert_element_type_1247 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_391); all_gather_into_tensor_391 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(add_1504, torch.float32) + pow_72 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1248, 2) + mean_71 = torch.ops.aten.mean.dim(pow_72, [2], True); pow_72 = None + add_1505 = torch.ops.aten.add.Scalar(mean_71, 1e-05); mean_71 = None + rsqrt_71 = torch.ops.aten.rsqrt.default(add_1505); add_1505 = None + mul_1093 = torch.ops.aten.mul.Tensor(convert_element_type_1248, rsqrt_71); convert_element_type_1248 = None + mul_1094 = torch.ops.aten.mul.Tensor(mul_1093, wait_tensor_479); mul_1093 = wait_tensor_479 = None + convert_element_type_1249 = torch.ops.prims.convert_element_type.default(mul_1094, torch.bfloat16); mul_1094 = None + view_1532 = torch.ops.aten.view.default(convert_element_type_1249, [-1, 2048]); convert_element_type_1249 = None + convert_element_type_1250 = torch.ops.prims.convert_element_type.default(primals_383, torch.bfloat16) + all_gather_into_tensor_392 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1250, 128, '0'); convert_element_type_1250 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_392); all_gather_into_tensor_392 = None + slice_141 = torch.ops.aten.slice.Tensor(wait_tensor_480, 0, 0, 64); wait_tensor_480 = None + permute_349 = torch.ops.aten.permute.default(slice_141, [1, 0]); slice_141 = None + mm_187 = torch.ops.aten.mm.default(view_1532, permute_349); permute_349 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_187, torch.float32) + amax_22 = torch.ops.aten.amax.default(convert_element_type_1253, [1], True) + sub_528 = torch.ops.aten.sub.Tensor(convert_element_type_1253, amax_22); convert_element_type_1253 = None + exp_67 = torch.ops.aten.exp.default(sub_528); sub_528 = None + sum_89 = torch.ops.aten.sum.dim_IntList(exp_67, [1], True) + div_111 = torch.ops.aten.div.Tensor(exp_67, sum_89); exp_67 = None + add_1506 = torch.ops.aten.add.Tensor(div_111, primals_382); primals_382 = None + topk_22 = torch.ops.aten.topk.default(add_1506, 6, -1, True, False); add_1506 = None + getitem_2439 = topk_22[1]; topk_22 = None + gather_22 = torch.ops.aten.gather.default(div_111, 1, getitem_2439); div_111 = None + mul_1095 = torch.ops.aten.mul.Tensor(gather_22, 1.0); gather_22 = None + view_1534 = torch.ops.aten.view.default(getitem_2439, [-1]) + histc_44 = torch.ops.aten.histc.default(view_1534, 64, 0, 64) + add_1507 = torch.ops.aten.add.Tensor(primals_384, histc_44) + sort_22 = torch.ops.aten.sort.stable(view_1534, stable = True); view_1534 = None + getitem_2441 = sort_22[1]; sort_22 = None + div_112 = torch.ops.aten.div.Tensor_mode(getitem_2441, 6, rounding_mode = 'floor') + index_44 = torch.ops.aten.index.Tensor(view_1532, [div_112]) + all_to_all_single_66 = torch.ops._c10d_functional.all_to_all_single.default(histc_44, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_481 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_66); all_to_all_single_66 = None + wait_tensor_482 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_481); wait_tensor_481 = None + view_1538 = torch.ops.aten.view.default(histc_44, [8, -1]); histc_44 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_1538, [1]); view_1538 = None + device_put_44 = torch.ops.prims.device_put.default(sum_90, device(type='cpu'), True); sum_90 = None + view_1539 = torch.ops.aten.view.default(wait_tensor_482, [8, -1]) + sum_91 = torch.ops.aten.sum.dim_IntList(view_1539, [1]) + device_put_45 = torch.ops.prims.device_put.default(sum_91, device(type='cpu')); sum_91 = None + select_352 = torch.ops.aten.select.int(device_put_44, 0, 0) + _local_scalar_dense_352 = torch.ops.aten._local_scalar_dense.default(select_352); select_352 = None + ge_440 = _local_scalar_dense_352 >= 0 + _assert_scalar_352 = torch.ops.aten._assert_scalar.default(ge_440, "Runtime assertion failed for expression u352 >= 0 on node 'ge_352'"); ge_440 = _assert_scalar_352 = None + select_353 = torch.ops.aten.select.int(device_put_44, 0, 1) + _local_scalar_dense_353 = torch.ops.aten._local_scalar_dense.default(select_353); select_353 = None + ge_441 = _local_scalar_dense_353 >= 0 + _assert_scalar_353 = torch.ops.aten._assert_scalar.default(ge_441, "Runtime assertion failed for expression u353 >= 0 on node 'ge_353'"); ge_441 = _assert_scalar_353 = None + select_354 = torch.ops.aten.select.int(device_put_44, 0, 2) + _local_scalar_dense_354 = torch.ops.aten._local_scalar_dense.default(select_354); select_354 = None + ge_442 = _local_scalar_dense_354 >= 0 + _assert_scalar_354 = torch.ops.aten._assert_scalar.default(ge_442, "Runtime assertion failed for expression u354 >= 0 on node 'ge_354'"); ge_442 = _assert_scalar_354 = None + select_355 = torch.ops.aten.select.int(device_put_44, 0, 3) + _local_scalar_dense_355 = torch.ops.aten._local_scalar_dense.default(select_355); select_355 = None + ge_443 = _local_scalar_dense_355 >= 0 + _assert_scalar_355 = torch.ops.aten._assert_scalar.default(ge_443, "Runtime assertion failed for expression u355 >= 0 on node 'ge_355'"); ge_443 = _assert_scalar_355 = None + select_356 = torch.ops.aten.select.int(device_put_44, 0, 4) + _local_scalar_dense_356 = torch.ops.aten._local_scalar_dense.default(select_356); select_356 = None + ge_444 = _local_scalar_dense_356 >= 0 + _assert_scalar_356 = torch.ops.aten._assert_scalar.default(ge_444, "Runtime assertion failed for expression u356 >= 0 on node 'ge_356'"); ge_444 = _assert_scalar_356 = None + select_357 = torch.ops.aten.select.int(device_put_44, 0, 5) + _local_scalar_dense_357 = torch.ops.aten._local_scalar_dense.default(select_357); select_357 = None + ge_445 = _local_scalar_dense_357 >= 0 + _assert_scalar_357 = torch.ops.aten._assert_scalar.default(ge_445, "Runtime assertion failed for expression u357 >= 0 on node 'ge_357'"); ge_445 = _assert_scalar_357 = None + select_358 = torch.ops.aten.select.int(device_put_44, 0, 6) + _local_scalar_dense_358 = torch.ops.aten._local_scalar_dense.default(select_358); select_358 = None + ge_446 = _local_scalar_dense_358 >= 0 + _assert_scalar_358 = torch.ops.aten._assert_scalar.default(ge_446, "Runtime assertion failed for expression u358 >= 0 on node 'ge_358'"); ge_446 = _assert_scalar_358 = None + select_359 = torch.ops.aten.select.int(device_put_44, 0, 7); device_put_44 = None + _local_scalar_dense_359 = torch.ops.aten._local_scalar_dense.default(select_359); select_359 = None + ge_447 = _local_scalar_dense_359 >= 0 + _assert_scalar_359 = torch.ops.aten._assert_scalar.default(ge_447, "Runtime assertion failed for expression u359 >= 0 on node 'ge_359'"); ge_447 = _assert_scalar_359 = None + select_360 = torch.ops.aten.select.int(device_put_45, 0, 0) + _local_scalar_dense_360 = torch.ops.aten._local_scalar_dense.default(select_360); select_360 = None + ge_448 = _local_scalar_dense_360 >= 0 + _assert_scalar_360 = torch.ops.aten._assert_scalar.default(ge_448, "Runtime assertion failed for expression u360 >= 0 on node 'ge_360'"); ge_448 = _assert_scalar_360 = None + select_361 = torch.ops.aten.select.int(device_put_45, 0, 1) + _local_scalar_dense_361 = torch.ops.aten._local_scalar_dense.default(select_361); select_361 = None + ge_449 = _local_scalar_dense_361 >= 0 + _assert_scalar_361 = torch.ops.aten._assert_scalar.default(ge_449, "Runtime assertion failed for expression u361 >= 0 on node 'ge_361'"); ge_449 = _assert_scalar_361 = None + select_362 = torch.ops.aten.select.int(device_put_45, 0, 2) + _local_scalar_dense_362 = torch.ops.aten._local_scalar_dense.default(select_362); select_362 = None + ge_450 = _local_scalar_dense_362 >= 0 + _assert_scalar_362 = torch.ops.aten._assert_scalar.default(ge_450, "Runtime assertion failed for expression u362 >= 0 on node 'ge_362'"); ge_450 = _assert_scalar_362 = None + select_363 = torch.ops.aten.select.int(device_put_45, 0, 3) + _local_scalar_dense_363 = torch.ops.aten._local_scalar_dense.default(select_363); select_363 = None + ge_451 = _local_scalar_dense_363 >= 0 + _assert_scalar_363 = torch.ops.aten._assert_scalar.default(ge_451, "Runtime assertion failed for expression u363 >= 0 on node 'ge_363'"); ge_451 = _assert_scalar_363 = None + select_364 = torch.ops.aten.select.int(device_put_45, 0, 4) + _local_scalar_dense_364 = torch.ops.aten._local_scalar_dense.default(select_364); select_364 = None + ge_452 = _local_scalar_dense_364 >= 0 + _assert_scalar_364 = torch.ops.aten._assert_scalar.default(ge_452, "Runtime assertion failed for expression u364 >= 0 on node 'ge_364'"); ge_452 = _assert_scalar_364 = None + select_365 = torch.ops.aten.select.int(device_put_45, 0, 5) + _local_scalar_dense_365 = torch.ops.aten._local_scalar_dense.default(select_365); select_365 = None + ge_453 = _local_scalar_dense_365 >= 0 + _assert_scalar_365 = torch.ops.aten._assert_scalar.default(ge_453, "Runtime assertion failed for expression u365 >= 0 on node 'ge_365'"); ge_453 = _assert_scalar_365 = None + select_366 = torch.ops.aten.select.int(device_put_45, 0, 6) + _local_scalar_dense_366 = torch.ops.aten._local_scalar_dense.default(select_366); select_366 = None + ge_454 = _local_scalar_dense_366 >= 0 + _assert_scalar_366 = torch.ops.aten._assert_scalar.default(ge_454, "Runtime assertion failed for expression u366 >= 0 on node 'ge_366'"); ge_454 = _assert_scalar_366 = None + select_367 = torch.ops.aten.select.int(device_put_45, 0, 7); device_put_45 = None + _local_scalar_dense_367 = torch.ops.aten._local_scalar_dense.default(select_367); select_367 = None + ge_455 = _local_scalar_dense_367 >= 0 + _assert_scalar_367 = torch.ops.aten._assert_scalar.default(ge_455, "Runtime assertion failed for expression u367 >= 0 on node 'ge_367'"); ge_455 = _assert_scalar_367 = None + all_to_all_single_67 = torch.ops._c10d_functional.all_to_all_single.default(index_44, [_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367], [_local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359], '1033'); index_44 = None + sym_size_int_88 = torch.ops.aten.sym_size.int(all_to_all_single_67, 0) + wait_tensor_483 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_67); all_to_all_single_67 = None + sym_sum_44 = torch.sym_sum((_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367)) + add_1514 = sym_sum_44 + 64; sym_sum_44 = None + add_1515 = add_1514 + 8; add_1514 = None + sub_531 = add_1515 - 1; add_1515 = None + floordiv_22 = sub_531 // 8; sub_531 = None + mul_1100 = floordiv_22 * 8; floordiv_22 = None + cumsum_66 = torch.ops.aten.cumsum.default(wait_tensor_482, 0) + sub_532 = torch.ops.aten.sub.Tensor(cumsum_66, wait_tensor_482); cumsum_66 = None + sum_92 = torch.ops.aten.sum.dim_IntList(view_1539, [0]); view_1539 = None + clamp_min_22 = torch.ops.aten.clamp_min.default(sum_92, 8); sum_92 = None + add_1516 = torch.ops.aten.add.Tensor(clamp_min_22, 8); clamp_min_22 = None + sub_533 = torch.ops.aten.sub.Tensor(add_1516, 1); add_1516 = None + div_113 = torch.ops.aten.div.Tensor_mode(sub_533, 8, rounding_mode = 'floor'); sub_533 = None + mul_1101 = torch.ops.aten.mul.Tensor(div_113, 8); div_113 = None + convert_element_type_1256 = torch.ops.prims.convert_element_type.default(mul_1101, torch.int32); mul_1101 = None + cumsum_67 = torch.ops.aten.cumsum.default(convert_element_type_1256, 0) + sub_534 = torch.ops.aten.sub.Tensor(cumsum_67, convert_element_type_1256); cumsum_67 = None + full_306 = torch.ops.aten.full.default([mul_1100], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1100 = None + triton_kernel_wrapper_functional_proxy_22 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 22, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_482, 'start_index_values_ptr': sub_532, 'write_offsets_ptr': sub_534, 'output_ptr': full_306}, tensors_to_clone = ['output_ptr']); wait_tensor_482 = sub_532 = sub_534 = full_306 = None + getitem_2442 = triton_kernel_wrapper_functional_proxy_22['output_ptr']; triton_kernel_wrapper_functional_proxy_22 = None + cat_202 = torch.ops.aten.cat.default([wait_tensor_483, full_default]); wait_tensor_483 = None + sym_size_int_89 = torch.ops.aten.sym_size.int(cat_202, 0) + sym_sum_45 = torch.sym_sum((1, _local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367)) + index_45 = torch.ops.aten.index.Tensor(cat_202, [getitem_2442]); cat_202 = None + convert_element_type_1258 = torch.ops.prims.convert_element_type.default(primals_385, torch.bfloat16) + all_gather_into_tensor_393 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1258, 16, '1025'); convert_element_type_1258 = None + wait_tensor_484 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_393); all_gather_into_tensor_393 = None + split_133 = torch.ops.aten.split.Tensor(wait_tensor_484, 8); wait_tensor_484 = None + getitem_2459 = split_133[0] + getitem_2460 = split_133[1] + getitem_2461 = split_133[2] + getitem_2462 = split_133[3] + getitem_2463 = split_133[4] + getitem_2464 = split_133[5] + getitem_2465 = split_133[6] + getitem_2466 = split_133[7] + getitem_2467 = split_133[8] + getitem_2468 = split_133[9] + getitem_2469 = split_133[10] + getitem_2470 = split_133[11] + getitem_2471 = split_133[12] + getitem_2472 = split_133[13] + getitem_2473 = split_133[14] + getitem_2474 = split_133[15]; split_133 = None + cat_204 = torch.ops.aten.cat.default([getitem_2459, getitem_2460, getitem_2461, getitem_2462, getitem_2463, getitem_2464, getitem_2465, getitem_2466, getitem_2467, getitem_2468, getitem_2469, getitem_2470, getitem_2471, getitem_2472, getitem_2473, getitem_2474], 1); getitem_2459 = getitem_2460 = getitem_2461 = getitem_2462 = getitem_2463 = getitem_2464 = getitem_2465 = getitem_2466 = getitem_2467 = getitem_2468 = getitem_2469 = getitem_2470 = getitem_2471 = getitem_2472 = getitem_2473 = getitem_2474 = None + convert_element_type_1260 = torch.ops.prims.convert_element_type.default(primals_386, torch.bfloat16) + all_gather_into_tensor_395 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1260, 16, '1025'); convert_element_type_1260 = None + wait_tensor_486 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_395); all_gather_into_tensor_395 = None + split_134 = torch.ops.aten.split.Tensor(wait_tensor_486, 8); wait_tensor_486 = None + getitem_2475 = split_134[0] + getitem_2476 = split_134[1] + getitem_2477 = split_134[2] + getitem_2478 = split_134[3] + getitem_2479 = split_134[4] + getitem_2480 = split_134[5] + getitem_2481 = split_134[6] + getitem_2482 = split_134[7] + getitem_2483 = split_134[8] + getitem_2484 = split_134[9] + getitem_2485 = split_134[10] + getitem_2486 = split_134[11] + getitem_2487 = split_134[12] + getitem_2488 = split_134[13] + getitem_2489 = split_134[14] + getitem_2490 = split_134[15]; split_134 = None + cat_205 = torch.ops.aten.cat.default([getitem_2475, getitem_2476, getitem_2477, getitem_2478, getitem_2479, getitem_2480, getitem_2481, getitem_2482, getitem_2483, getitem_2484, getitem_2485, getitem_2486, getitem_2487, getitem_2488, getitem_2489, getitem_2490], 1); getitem_2475 = getitem_2476 = getitem_2477 = getitem_2478 = getitem_2479 = getitem_2480 = getitem_2481 = getitem_2482 = getitem_2483 = getitem_2484 = getitem_2485 = getitem_2486 = getitem_2487 = getitem_2488 = getitem_2489 = getitem_2490 = None + convert_element_type_1261 = torch.ops.prims.convert_element_type.default(primals_387, torch.bfloat16) + all_gather_into_tensor_396 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1261, 16, '1025'); convert_element_type_1261 = None + wait_tensor_487 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_396); all_gather_into_tensor_396 = None + split_135 = torch.ops.aten.split.Tensor(wait_tensor_487, 8); wait_tensor_487 = None + getitem_2491 = split_135[0] + getitem_2492 = split_135[1] + getitem_2493 = split_135[2] + getitem_2494 = split_135[3] + getitem_2495 = split_135[4] + getitem_2496 = split_135[5] + getitem_2497 = split_135[6] + getitem_2498 = split_135[7] + getitem_2499 = split_135[8] + getitem_2500 = split_135[9] + getitem_2501 = split_135[10] + getitem_2502 = split_135[11] + getitem_2503 = split_135[12] + getitem_2504 = split_135[13] + getitem_2505 = split_135[14] + getitem_2506 = split_135[15]; split_135 = None + cat_206 = torch.ops.aten.cat.default([getitem_2491, getitem_2492, getitem_2493, getitem_2494, getitem_2495, getitem_2496, getitem_2497, getitem_2498, getitem_2499, getitem_2500, getitem_2501, getitem_2502, getitem_2503, getitem_2504, getitem_2505, getitem_2506], 1); getitem_2491 = getitem_2492 = getitem_2493 = getitem_2494 = getitem_2495 = getitem_2496 = getitem_2497 = getitem_2498 = getitem_2499 = getitem_2500 = getitem_2501 = getitem_2502 = getitem_2503 = getitem_2504 = getitem_2505 = getitem_2506 = None + cumsum_68 = torch.ops.aten.cumsum.default(convert_element_type_1256, 0, dtype = torch.int32); convert_element_type_1256 = None + permute_350 = torch.ops.aten.permute.default(cat_204, [0, 2, 1]); cat_204 = None + _grouped_mm_66 = torch.ops.aten._grouped_mm.default(index_45, permute_350, cumsum_68) + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(_grouped_mm_66, torch.float32) + neg_45 = torch.ops.aten.neg.default(convert_element_type_1264) + exp_68 = torch.ops.aten.exp.default(neg_45); neg_45 = None + add_1528 = torch.ops.aten.add.Tensor(exp_68, 1); exp_68 = None + div_114 = torch.ops.aten.div.Tensor(convert_element_type_1264, add_1528); convert_element_type_1264 = add_1528 = None + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(div_114, torch.bfloat16); div_114 = None + permute_351 = torch.ops.aten.permute.default(cat_206, [0, 2, 1]); cat_206 = None + _grouped_mm_67 = torch.ops.aten._grouped_mm.default(index_45, permute_351, cumsum_68) + mul_1113 = torch.ops.aten.mul.Tensor(convert_element_type_1265, _grouped_mm_67); convert_element_type_1265 = None + permute_352 = torch.ops.aten.permute.default(cat_205, [0, 2, 1]); cat_205 = None + _grouped_mm_68 = torch.ops.aten._grouped_mm.default(mul_1113, permute_352, cumsum_68) + empty_22 = torch.ops.aten.empty.memory_format([sym_size_int_89, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_44 = torch.ops.aten.index_put.default(empty_22, [getitem_2442], _grouped_mm_68); empty_22 = _grouped_mm_68 = None + slice_143 = torch.ops.aten.slice.Tensor(index_put_44, 0, 0, -1); index_put_44 = None + all_to_all_single_68 = torch.ops._c10d_functional.all_to_all_single.default(slice_143, [_local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359], [_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367], '1033'); slice_143 = None + wait_tensor_490 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_68); all_to_all_single_68 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(primals_388, torch.bfloat16) + all_gather_into_tensor_399 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1266, 128, '0'); convert_element_type_1266 = None + wait_tensor_491 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_399); all_gather_into_tensor_399 = None + permute_353 = torch.ops.aten.permute.default(wait_tensor_491, [1, 0]); wait_tensor_491 = None + mm_188 = torch.ops.aten.mm.default(view_1532, permute_353); permute_353 = None + convert_element_type_1269 = torch.ops.prims.convert_element_type.default(mm_188, torch.float32) + neg_46 = torch.ops.aten.neg.default(convert_element_type_1269) + exp_69 = torch.ops.aten.exp.default(neg_46); neg_46 = None + add_1564 = torch.ops.aten.add.Tensor(exp_69, 1); exp_69 = None + div_115 = torch.ops.aten.div.Tensor(convert_element_type_1269, add_1564); convert_element_type_1269 = add_1564 = None + convert_element_type_1270 = torch.ops.prims.convert_element_type.default(div_115, torch.bfloat16); div_115 = None + convert_element_type_1271 = torch.ops.prims.convert_element_type.default(primals_389, torch.bfloat16) + all_gather_into_tensor_400 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1271, 128, '0'); convert_element_type_1271 = None + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_400); all_gather_into_tensor_400 = None + permute_354 = torch.ops.aten.permute.default(wait_tensor_492, [1, 0]); wait_tensor_492 = None + mm_189 = torch.ops.aten.mm.default(view_1532, permute_354); permute_354 = None + mul_1133 = torch.ops.aten.mul.Tensor(convert_element_type_1270, mm_189); convert_element_type_1270 = None + convert_element_type_1274 = torch.ops.prims.convert_element_type.default(primals_390, torch.bfloat16) + all_gather_into_tensor_401 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1274, 128, '0'); convert_element_type_1274 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_401); all_gather_into_tensor_401 = None + permute_355 = torch.ops.aten.permute.default(wait_tensor_493, [1, 0]); wait_tensor_493 = None + mm_190 = torch.ops.aten.mm.default(mul_1133, permute_355); permute_355 = None + index_put_45 = torch.ops.aten.index_put.default(full_default_1, [getitem_2441], wait_tensor_490); wait_tensor_490 = None + view_1572 = torch.ops.aten.view.default(mul_1095, [-1, 1, 6]); mul_1095 = None + view_1573 = torch.ops.aten.view.default(index_put_45, [-1, 6, 2048]); index_put_45 = None + convert_element_type_1277 = torch.ops.prims.convert_element_type.default(view_1573, torch.float32); view_1573 = None + bmm_22 = torch.ops.aten.bmm.default(view_1572, convert_element_type_1277) + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(bmm_22, torch.bfloat16); bmm_22 = None + squeeze_22 = torch.ops.aten.squeeze.dim(convert_element_type_1278, 1); convert_element_type_1278 = None + add_1568 = torch.ops.aten.add.Tensor(mm_190, squeeze_22); mm_190 = squeeze_22 = None + view_1574 = torch.ops.aten.view.default(add_1568, [2, 4096, 2048]); add_1568 = None + add_1569 = torch.ops.aten.add.Tensor(add_1504, view_1574); view_1574 = None + convert_element_type_1279 = torch.ops.prims.convert_element_type.default(primals_391, torch.bfloat16) + all_gather_into_tensor_402 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1279, 128, '0'); convert_element_type_1279 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_402); all_gather_into_tensor_402 = None + convert_element_type_1280 = torch.ops.prims.convert_element_type.default(add_1569, torch.float32) + pow_73 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1280, 2) + mean_72 = torch.ops.aten.mean.dim(pow_73, [2], True); pow_73 = None + add_1570 = torch.ops.aten.add.Scalar(mean_72, 1e-05); mean_72 = None + rsqrt_72 = torch.ops.aten.rsqrt.default(add_1570); add_1570 = None + mul_1136 = torch.ops.aten.mul.Tensor(convert_element_type_1280, rsqrt_72); convert_element_type_1280 = None + mul_1137 = torch.ops.aten.mul.Tensor(mul_1136, wait_tensor_494); mul_1136 = wait_tensor_494 = None + convert_element_type_1281 = torch.ops.prims.convert_element_type.default(mul_1137, torch.bfloat16); mul_1137 = None + convert_element_type_1282 = torch.ops.prims.convert_element_type.default(primals_392, torch.bfloat16) + all_gather_into_tensor_403 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1282, 128, '0'); convert_element_type_1282 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_403); all_gather_into_tensor_403 = None + permute_356 = torch.ops.aten.permute.default(wait_tensor_495, [1, 0]); wait_tensor_495 = None + view_1577 = torch.ops.aten.view.default(convert_element_type_1281, [8192, 2048]); convert_element_type_1281 = None + mm_191 = torch.ops.aten.mm.default(view_1577, permute_356); permute_356 = None + view_1578 = torch.ops.aten.view.default(mm_191, [2, 4096, 3072]); mm_191 = None + view_1579 = torch.ops.aten.view.default(view_1578, [2, 4096, -1, 192]); view_1578 = None + split_with_sizes_72 = torch.ops.aten.split_with_sizes.default(view_1579, [128, 64], -1); view_1579 = None + getitem_2539 = split_with_sizes_72[0] + getitem_2540 = split_with_sizes_72[1]; split_with_sizes_72 = None + convert_element_type_1285 = torch.ops.prims.convert_element_type.default(getitem_2540, torch.float32); getitem_2540 = None + view_1580 = torch.ops.aten.view.default(convert_element_type_1285, [2, 4096, 16, -1, 2]); convert_element_type_1285 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1580); view_1580 = None + mul_1138 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_7); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_1138); mul_1138 = None + view_1582 = torch.ops.aten.view.default(view_as_real_48, [2, 4096, 16, 64]); view_as_real_48 = None + convert_element_type_1286 = torch.ops.prims.convert_element_type.default(view_1582, torch.bfloat16); view_1582 = None + cat_209 = torch.ops.aten.cat.default([getitem_2539, convert_element_type_1286], -1); getitem_2539 = convert_element_type_1286 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(primals_393, torch.bfloat16) + all_gather_into_tensor_404 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1287, 128, '0'); convert_element_type_1287 = None + wait_tensor_496 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_404); all_gather_into_tensor_404 = None + slice_145 = torch.ops.aten.slice.Tensor(wait_tensor_496, 0, 0, 576); wait_tensor_496 = None + permute_357 = torch.ops.aten.permute.default(slice_145, [1, 0]); slice_145 = None + mm_192 = torch.ops.aten.mm.default(view_1577, permute_357); permute_357 = None + view_1585 = torch.ops.aten.view.default(mm_192, [2, 4096, 576]); mm_192 = None + split_with_sizes_73 = torch.ops.aten.split_with_sizes.default(view_1585, [512, 64], -1); view_1585 = None + getitem_2541 = split_with_sizes_73[0] + getitem_2542 = split_with_sizes_73[1]; split_with_sizes_73 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(getitem_2542, 2); getitem_2542 = None + convert_element_type_1290 = torch.ops.prims.convert_element_type.default(unsqueeze_47, torch.float32); unsqueeze_47 = None + view_1586 = torch.ops.aten.view.default(convert_element_type_1290, [2, 4096, 1, -1, 2]); convert_element_type_1290 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1586); view_1586 = None + mul_1139 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_7); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_1139); mul_1139 = None + view_1588 = torch.ops.aten.view.default(view_as_real_49, [2, 4096, 1, 64]); view_as_real_49 = None + convert_element_type_1291 = torch.ops.prims.convert_element_type.default(view_1588, torch.bfloat16); view_1588 = None + convert_element_type_1292 = torch.ops.prims.convert_element_type.default(primals_394, torch.bfloat16) + all_gather_into_tensor_405 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1292, 128, '0'); convert_element_type_1292 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_405); all_gather_into_tensor_405 = None + convert_element_type_1293 = torch.ops.prims.convert_element_type.default(getitem_2541, torch.float32) + pow_74 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1293, 2) + mean_73 = torch.ops.aten.mean.dim(pow_74, [2], True); pow_74 = None + add_1571 = torch.ops.aten.add.Scalar(mean_73, 1e-05); mean_73 = None + rsqrt_73 = torch.ops.aten.rsqrt.default(add_1571); add_1571 = None + mul_1140 = torch.ops.aten.mul.Tensor(convert_element_type_1293, rsqrt_73); convert_element_type_1293 = None + mul_1141 = torch.ops.aten.mul.Tensor(mul_1140, wait_tensor_497); mul_1140 = wait_tensor_497 = None + convert_element_type_1294 = torch.ops.prims.convert_element_type.default(mul_1141, torch.bfloat16); mul_1141 = None + convert_element_type_1295 = torch.ops.prims.convert_element_type.default(primals_395, torch.bfloat16) + all_gather_into_tensor_406 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1295, 128, '0'); convert_element_type_1295 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_406); all_gather_into_tensor_406 = None + permute_358 = torch.ops.aten.permute.default(wait_tensor_498, [1, 0]); wait_tensor_498 = None + view_1591 = torch.ops.aten.view.default(convert_element_type_1294, [8192, 512]); convert_element_type_1294 = None + mm_193 = torch.ops.aten.mm.default(view_1591, permute_358); permute_358 = None + view_1592 = torch.ops.aten.view.default(mm_193, [2, 4096, 4096]); mm_193 = None + view_1593 = torch.ops.aten.view.default(view_1592, [2, 4096, -1, 256]); view_1592 = None + split_with_sizes_74 = torch.ops.aten.split_with_sizes.default(view_1593, [128, 128], -1); view_1593 = None + getitem_2543 = split_with_sizes_74[0] + getitem_2544 = split_with_sizes_74[1]; split_with_sizes_74 = None + expand_24 = torch.ops.aten.expand.default(convert_element_type_1291, [-1, -1, 16, -1]); convert_element_type_1291 = None + cat_210 = torch.ops.aten.cat.default([getitem_2543, expand_24], -1); getitem_2543 = expand_24 = None + permute_359 = torch.ops.aten.permute.default(cat_209, [0, 2, 1, 3]); cat_209 = None + permute_360 = torch.ops.aten.permute.default(cat_210, [0, 2, 1, 3]); cat_210 = None + permute_361 = torch.ops.aten.permute.default(getitem_2544, [0, 2, 1, 3]); getitem_2544 = None + sdpa_score24 = self.sdpa_score24 + sdpa_mask24 = self.sdpa_mask24 + flex_attention_24 = torch.ops.higher_order.flex_attention(permute_359, permute_360, permute_361, sdpa_score24, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask24), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score24 = sdpa_mask24 = None + getitem_2545 = flex_attention_24[0] + getitem_2546 = flex_attention_24[1]; flex_attention_24 = None + permute_362 = torch.ops.aten.permute.default(getitem_2545, [0, 2, 1, 3]) + view_1594 = torch.ops.aten.view.default(permute_362, [2, 4096, -1]); permute_362 = None + convert_element_type_1298 = torch.ops.prims.convert_element_type.default(primals_396, torch.bfloat16) + all_gather_into_tensor_407 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1298, 128, '0'); convert_element_type_1298 = None + wait_tensor_499 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_407); all_gather_into_tensor_407 = None + permute_363 = torch.ops.aten.permute.default(wait_tensor_499, [1, 0]); wait_tensor_499 = None + view_1596 = torch.ops.aten.view.default(view_1594, [8192, 2048]); view_1594 = None + mm_194 = torch.ops.aten.mm.default(view_1596, permute_363); view_1596 = permute_363 = None + view_1597 = torch.ops.aten.view.default(mm_194, [2, 4096, 2048]); mm_194 = None + add_1572 = torch.ops.aten.add.Tensor(add_1569, view_1597); view_1597 = None + convert_element_type_1301 = torch.ops.prims.convert_element_type.default(primals_397, torch.bfloat16) + all_gather_into_tensor_408 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1301, 128, '0'); convert_element_type_1301 = None + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_408); all_gather_into_tensor_408 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(add_1572, torch.float32) + pow_75 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1302, 2) + mean_74 = torch.ops.aten.mean.dim(pow_75, [2], True); pow_75 = None + add_1573 = torch.ops.aten.add.Scalar(mean_74, 1e-05); mean_74 = None + rsqrt_74 = torch.ops.aten.rsqrt.default(add_1573); add_1573 = None + mul_1142 = torch.ops.aten.mul.Tensor(convert_element_type_1302, rsqrt_74); convert_element_type_1302 = None + mul_1143 = torch.ops.aten.mul.Tensor(mul_1142, wait_tensor_500); mul_1142 = wait_tensor_500 = None + convert_element_type_1303 = torch.ops.prims.convert_element_type.default(mul_1143, torch.bfloat16); mul_1143 = None + view_1599 = torch.ops.aten.view.default(convert_element_type_1303, [-1, 2048]); convert_element_type_1303 = None + convert_element_type_1304 = torch.ops.prims.convert_element_type.default(primals_399, torch.bfloat16) + all_gather_into_tensor_409 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1304, 128, '0'); convert_element_type_1304 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_409); all_gather_into_tensor_409 = None + slice_147 = torch.ops.aten.slice.Tensor(wait_tensor_501, 0, 0, 64); wait_tensor_501 = None + permute_364 = torch.ops.aten.permute.default(slice_147, [1, 0]); slice_147 = None + mm_195 = torch.ops.aten.mm.default(view_1599, permute_364); permute_364 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_195, torch.float32) + amax_23 = torch.ops.aten.amax.default(convert_element_type_1307, [1], True) + sub_552 = torch.ops.aten.sub.Tensor(convert_element_type_1307, amax_23); convert_element_type_1307 = None + exp_70 = torch.ops.aten.exp.default(sub_552); sub_552 = None + sum_93 = torch.ops.aten.sum.dim_IntList(exp_70, [1], True) + div_116 = torch.ops.aten.div.Tensor(exp_70, sum_93); exp_70 = None + add_1574 = torch.ops.aten.add.Tensor(div_116, primals_398); primals_398 = None + topk_23 = torch.ops.aten.topk.default(add_1574, 6, -1, True, False); add_1574 = None + getitem_2549 = topk_23[1]; topk_23 = None + gather_23 = torch.ops.aten.gather.default(div_116, 1, getitem_2549); div_116 = None + mul_1144 = torch.ops.aten.mul.Tensor(gather_23, 1.0); gather_23 = None + view_1601 = torch.ops.aten.view.default(getitem_2549, [-1]) + histc_46 = torch.ops.aten.histc.default(view_1601, 64, 0, 64) + add_1575 = torch.ops.aten.add.Tensor(primals_400, histc_46) + sort_23 = torch.ops.aten.sort.stable(view_1601, stable = True); view_1601 = None + getitem_2551 = sort_23[1]; sort_23 = None + div_117 = torch.ops.aten.div.Tensor_mode(getitem_2551, 6, rounding_mode = 'floor') + index_46 = torch.ops.aten.index.Tensor(view_1599, [div_117]) + all_to_all_single_69 = torch.ops._c10d_functional.all_to_all_single.default(histc_46, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_502 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_69); all_to_all_single_69 = None + wait_tensor_503 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_502); wait_tensor_502 = None + view_1605 = torch.ops.aten.view.default(histc_46, [8, -1]); histc_46 = None + sum_94 = torch.ops.aten.sum.dim_IntList(view_1605, [1]); view_1605 = None + device_put_46 = torch.ops.prims.device_put.default(sum_94, device(type='cpu'), True); sum_94 = None + view_1606 = torch.ops.aten.view.default(wait_tensor_503, [8, -1]) + sum_95 = torch.ops.aten.sum.dim_IntList(view_1606, [1]) + device_put_47 = torch.ops.prims.device_put.default(sum_95, device(type='cpu')); sum_95 = None + select_368 = torch.ops.aten.select.int(device_put_46, 0, 0) + _local_scalar_dense_368 = torch.ops.aten._local_scalar_dense.default(select_368); select_368 = None + ge_460 = _local_scalar_dense_368 >= 0 + _assert_scalar_368 = torch.ops.aten._assert_scalar.default(ge_460, "Runtime assertion failed for expression u368 >= 0 on node 'ge_368'"); ge_460 = _assert_scalar_368 = None + select_369 = torch.ops.aten.select.int(device_put_46, 0, 1) + _local_scalar_dense_369 = torch.ops.aten._local_scalar_dense.default(select_369); select_369 = None + ge_461 = _local_scalar_dense_369 >= 0 + _assert_scalar_369 = torch.ops.aten._assert_scalar.default(ge_461, "Runtime assertion failed for expression u369 >= 0 on node 'ge_369'"); ge_461 = _assert_scalar_369 = None + select_370 = torch.ops.aten.select.int(device_put_46, 0, 2) + _local_scalar_dense_370 = torch.ops.aten._local_scalar_dense.default(select_370); select_370 = None + ge_462 = _local_scalar_dense_370 >= 0 + _assert_scalar_370 = torch.ops.aten._assert_scalar.default(ge_462, "Runtime assertion failed for expression u370 >= 0 on node 'ge_370'"); ge_462 = _assert_scalar_370 = None + select_371 = torch.ops.aten.select.int(device_put_46, 0, 3) + _local_scalar_dense_371 = torch.ops.aten._local_scalar_dense.default(select_371); select_371 = None + ge_463 = _local_scalar_dense_371 >= 0 + _assert_scalar_371 = torch.ops.aten._assert_scalar.default(ge_463, "Runtime assertion failed for expression u371 >= 0 on node 'ge_371'"); ge_463 = _assert_scalar_371 = None + select_372 = torch.ops.aten.select.int(device_put_46, 0, 4) + _local_scalar_dense_372 = torch.ops.aten._local_scalar_dense.default(select_372); select_372 = None + ge_464 = _local_scalar_dense_372 >= 0 + _assert_scalar_372 = torch.ops.aten._assert_scalar.default(ge_464, "Runtime assertion failed for expression u372 >= 0 on node 'ge_372'"); ge_464 = _assert_scalar_372 = None + select_373 = torch.ops.aten.select.int(device_put_46, 0, 5) + _local_scalar_dense_373 = torch.ops.aten._local_scalar_dense.default(select_373); select_373 = None + ge_465 = _local_scalar_dense_373 >= 0 + _assert_scalar_373 = torch.ops.aten._assert_scalar.default(ge_465, "Runtime assertion failed for expression u373 >= 0 on node 'ge_373'"); ge_465 = _assert_scalar_373 = None + select_374 = torch.ops.aten.select.int(device_put_46, 0, 6) + _local_scalar_dense_374 = torch.ops.aten._local_scalar_dense.default(select_374); select_374 = None + ge_466 = _local_scalar_dense_374 >= 0 + _assert_scalar_374 = torch.ops.aten._assert_scalar.default(ge_466, "Runtime assertion failed for expression u374 >= 0 on node 'ge_374'"); ge_466 = _assert_scalar_374 = None + select_375 = torch.ops.aten.select.int(device_put_46, 0, 7); device_put_46 = None + _local_scalar_dense_375 = torch.ops.aten._local_scalar_dense.default(select_375); select_375 = None + ge_467 = _local_scalar_dense_375 >= 0 + _assert_scalar_375 = torch.ops.aten._assert_scalar.default(ge_467, "Runtime assertion failed for expression u375 >= 0 on node 'ge_375'"); ge_467 = _assert_scalar_375 = None + select_376 = torch.ops.aten.select.int(device_put_47, 0, 0) + _local_scalar_dense_376 = torch.ops.aten._local_scalar_dense.default(select_376); select_376 = None + ge_468 = _local_scalar_dense_376 >= 0 + _assert_scalar_376 = torch.ops.aten._assert_scalar.default(ge_468, "Runtime assertion failed for expression u376 >= 0 on node 'ge_376'"); ge_468 = _assert_scalar_376 = None + select_377 = torch.ops.aten.select.int(device_put_47, 0, 1) + _local_scalar_dense_377 = torch.ops.aten._local_scalar_dense.default(select_377); select_377 = None + ge_469 = _local_scalar_dense_377 >= 0 + _assert_scalar_377 = torch.ops.aten._assert_scalar.default(ge_469, "Runtime assertion failed for expression u377 >= 0 on node 'ge_377'"); ge_469 = _assert_scalar_377 = None + select_378 = torch.ops.aten.select.int(device_put_47, 0, 2) + _local_scalar_dense_378 = torch.ops.aten._local_scalar_dense.default(select_378); select_378 = None + ge_470 = _local_scalar_dense_378 >= 0 + _assert_scalar_378 = torch.ops.aten._assert_scalar.default(ge_470, "Runtime assertion failed for expression u378 >= 0 on node 'ge_378'"); ge_470 = _assert_scalar_378 = None + select_379 = torch.ops.aten.select.int(device_put_47, 0, 3) + _local_scalar_dense_379 = torch.ops.aten._local_scalar_dense.default(select_379); select_379 = None + ge_471 = _local_scalar_dense_379 >= 0 + _assert_scalar_379 = torch.ops.aten._assert_scalar.default(ge_471, "Runtime assertion failed for expression u379 >= 0 on node 'ge_379'"); ge_471 = _assert_scalar_379 = None + select_380 = torch.ops.aten.select.int(device_put_47, 0, 4) + _local_scalar_dense_380 = torch.ops.aten._local_scalar_dense.default(select_380); select_380 = None + ge_472 = _local_scalar_dense_380 >= 0 + _assert_scalar_380 = torch.ops.aten._assert_scalar.default(ge_472, "Runtime assertion failed for expression u380 >= 0 on node 'ge_380'"); ge_472 = _assert_scalar_380 = None + select_381 = torch.ops.aten.select.int(device_put_47, 0, 5) + _local_scalar_dense_381 = torch.ops.aten._local_scalar_dense.default(select_381); select_381 = None + ge_473 = _local_scalar_dense_381 >= 0 + _assert_scalar_381 = torch.ops.aten._assert_scalar.default(ge_473, "Runtime assertion failed for expression u381 >= 0 on node 'ge_381'"); ge_473 = _assert_scalar_381 = None + select_382 = torch.ops.aten.select.int(device_put_47, 0, 6) + _local_scalar_dense_382 = torch.ops.aten._local_scalar_dense.default(select_382); select_382 = None + ge_474 = _local_scalar_dense_382 >= 0 + _assert_scalar_382 = torch.ops.aten._assert_scalar.default(ge_474, "Runtime assertion failed for expression u382 >= 0 on node 'ge_382'"); ge_474 = _assert_scalar_382 = None + select_383 = torch.ops.aten.select.int(device_put_47, 0, 7); device_put_47 = None + _local_scalar_dense_383 = torch.ops.aten._local_scalar_dense.default(select_383); select_383 = None + ge_475 = _local_scalar_dense_383 >= 0 + _assert_scalar_383 = torch.ops.aten._assert_scalar.default(ge_475, "Runtime assertion failed for expression u383 >= 0 on node 'ge_383'"); ge_475 = _assert_scalar_383 = None + all_to_all_single_70 = torch.ops._c10d_functional.all_to_all_single.default(index_46, [_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383], [_local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375], '1033'); index_46 = None + sym_size_int_92 = torch.ops.aten.sym_size.int(all_to_all_single_70, 0) + wait_tensor_504 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_70); all_to_all_single_70 = None + sym_sum_46 = torch.sym_sum((_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383)) + add_1582 = sym_sum_46 + 64; sym_sum_46 = None + add_1583 = add_1582 + 8; add_1582 = None + sub_555 = add_1583 - 1; add_1583 = None + floordiv_23 = sub_555 // 8; sub_555 = None + mul_1149 = floordiv_23 * 8; floordiv_23 = None + cumsum_69 = torch.ops.aten.cumsum.default(wait_tensor_503, 0) + sub_556 = torch.ops.aten.sub.Tensor(cumsum_69, wait_tensor_503); cumsum_69 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_1606, [0]); view_1606 = None + clamp_min_23 = torch.ops.aten.clamp_min.default(sum_96, 8); sum_96 = None + add_1584 = torch.ops.aten.add.Tensor(clamp_min_23, 8); clamp_min_23 = None + sub_557 = torch.ops.aten.sub.Tensor(add_1584, 1); add_1584 = None + div_118 = torch.ops.aten.div.Tensor_mode(sub_557, 8, rounding_mode = 'floor'); sub_557 = None + mul_1150 = torch.ops.aten.mul.Tensor(div_118, 8); div_118 = None + convert_element_type_1310 = torch.ops.prims.convert_element_type.default(mul_1150, torch.int32); mul_1150 = None + cumsum_70 = torch.ops.aten.cumsum.default(convert_element_type_1310, 0) + sub_558 = torch.ops.aten.sub.Tensor(cumsum_70, convert_element_type_1310); cumsum_70 = None + full_319 = torch.ops.aten.full.default([mul_1149], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1149 = None + triton_kernel_wrapper_functional_proxy_23 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 23, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_503, 'start_index_values_ptr': sub_556, 'write_offsets_ptr': sub_558, 'output_ptr': full_319}, tensors_to_clone = ['output_ptr']); wait_tensor_503 = sub_556 = sub_558 = full_319 = None + getitem_2552 = triton_kernel_wrapper_functional_proxy_23['output_ptr']; triton_kernel_wrapper_functional_proxy_23 = None + cat_211 = torch.ops.aten.cat.default([wait_tensor_504, full_default]); wait_tensor_504 = None + sym_size_int_93 = torch.ops.aten.sym_size.int(cat_211, 0) + sym_sum_47 = torch.sym_sum((1, _local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383)) + index_47 = torch.ops.aten.index.Tensor(cat_211, [getitem_2552]); cat_211 = None + convert_element_type_1312 = torch.ops.prims.convert_element_type.default(primals_401, torch.bfloat16) + all_gather_into_tensor_410 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1312, 16, '1025'); convert_element_type_1312 = None + wait_tensor_505 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_410); all_gather_into_tensor_410 = None + split_139 = torch.ops.aten.split.Tensor(wait_tensor_505, 8); wait_tensor_505 = None + getitem_2569 = split_139[0] + getitem_2570 = split_139[1] + getitem_2571 = split_139[2] + getitem_2572 = split_139[3] + getitem_2573 = split_139[4] + getitem_2574 = split_139[5] + getitem_2575 = split_139[6] + getitem_2576 = split_139[7] + getitem_2577 = split_139[8] + getitem_2578 = split_139[9] + getitem_2579 = split_139[10] + getitem_2580 = split_139[11] + getitem_2581 = split_139[12] + getitem_2582 = split_139[13] + getitem_2583 = split_139[14] + getitem_2584 = split_139[15]; split_139 = None + cat_213 = torch.ops.aten.cat.default([getitem_2569, getitem_2570, getitem_2571, getitem_2572, getitem_2573, getitem_2574, getitem_2575, getitem_2576, getitem_2577, getitem_2578, getitem_2579, getitem_2580, getitem_2581, getitem_2582, getitem_2583, getitem_2584], 1); getitem_2569 = getitem_2570 = getitem_2571 = getitem_2572 = getitem_2573 = getitem_2574 = getitem_2575 = getitem_2576 = getitem_2577 = getitem_2578 = getitem_2579 = getitem_2580 = getitem_2581 = getitem_2582 = getitem_2583 = getitem_2584 = None + convert_element_type_1314 = torch.ops.prims.convert_element_type.default(primals_402, torch.bfloat16) + all_gather_into_tensor_412 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1314, 16, '1025'); convert_element_type_1314 = None + wait_tensor_507 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_412); all_gather_into_tensor_412 = None + split_140 = torch.ops.aten.split.Tensor(wait_tensor_507, 8); wait_tensor_507 = None + getitem_2585 = split_140[0] + getitem_2586 = split_140[1] + getitem_2587 = split_140[2] + getitem_2588 = split_140[3] + getitem_2589 = split_140[4] + getitem_2590 = split_140[5] + getitem_2591 = split_140[6] + getitem_2592 = split_140[7] + getitem_2593 = split_140[8] + getitem_2594 = split_140[9] + getitem_2595 = split_140[10] + getitem_2596 = split_140[11] + getitem_2597 = split_140[12] + getitem_2598 = split_140[13] + getitem_2599 = split_140[14] + getitem_2600 = split_140[15]; split_140 = None + cat_214 = torch.ops.aten.cat.default([getitem_2585, getitem_2586, getitem_2587, getitem_2588, getitem_2589, getitem_2590, getitem_2591, getitem_2592, getitem_2593, getitem_2594, getitem_2595, getitem_2596, getitem_2597, getitem_2598, getitem_2599, getitem_2600], 1); getitem_2585 = getitem_2586 = getitem_2587 = getitem_2588 = getitem_2589 = getitem_2590 = getitem_2591 = getitem_2592 = getitem_2593 = getitem_2594 = getitem_2595 = getitem_2596 = getitem_2597 = getitem_2598 = getitem_2599 = getitem_2600 = None + convert_element_type_1315 = torch.ops.prims.convert_element_type.default(primals_403, torch.bfloat16) + all_gather_into_tensor_413 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1315, 16, '1025'); convert_element_type_1315 = None + wait_tensor_508 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_413); all_gather_into_tensor_413 = None + split_141 = torch.ops.aten.split.Tensor(wait_tensor_508, 8); wait_tensor_508 = None + getitem_2601 = split_141[0] + getitem_2602 = split_141[1] + getitem_2603 = split_141[2] + getitem_2604 = split_141[3] + getitem_2605 = split_141[4] + getitem_2606 = split_141[5] + getitem_2607 = split_141[6] + getitem_2608 = split_141[7] + getitem_2609 = split_141[8] + getitem_2610 = split_141[9] + getitem_2611 = split_141[10] + getitem_2612 = split_141[11] + getitem_2613 = split_141[12] + getitem_2614 = split_141[13] + getitem_2615 = split_141[14] + getitem_2616 = split_141[15]; split_141 = None + cat_215 = torch.ops.aten.cat.default([getitem_2601, getitem_2602, getitem_2603, getitem_2604, getitem_2605, getitem_2606, getitem_2607, getitem_2608, getitem_2609, getitem_2610, getitem_2611, getitem_2612, getitem_2613, getitem_2614, getitem_2615, getitem_2616], 1); getitem_2601 = getitem_2602 = getitem_2603 = getitem_2604 = getitem_2605 = getitem_2606 = getitem_2607 = getitem_2608 = getitem_2609 = getitem_2610 = getitem_2611 = getitem_2612 = getitem_2613 = getitem_2614 = getitem_2615 = getitem_2616 = None + cumsum_71 = torch.ops.aten.cumsum.default(convert_element_type_1310, 0, dtype = torch.int32); convert_element_type_1310 = None + permute_365 = torch.ops.aten.permute.default(cat_213, [0, 2, 1]); cat_213 = None + _grouped_mm_69 = torch.ops.aten._grouped_mm.default(index_47, permute_365, cumsum_71) + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(_grouped_mm_69, torch.float32) + neg_47 = torch.ops.aten.neg.default(convert_element_type_1318) + exp_71 = torch.ops.aten.exp.default(neg_47); neg_47 = None + add_1596 = torch.ops.aten.add.Tensor(exp_71, 1); exp_71 = None + div_119 = torch.ops.aten.div.Tensor(convert_element_type_1318, add_1596); convert_element_type_1318 = add_1596 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(div_119, torch.bfloat16); div_119 = None + permute_366 = torch.ops.aten.permute.default(cat_215, [0, 2, 1]); cat_215 = None + _grouped_mm_70 = torch.ops.aten._grouped_mm.default(index_47, permute_366, cumsum_71) + mul_1162 = torch.ops.aten.mul.Tensor(convert_element_type_1319, _grouped_mm_70); convert_element_type_1319 = None + permute_367 = torch.ops.aten.permute.default(cat_214, [0, 2, 1]); cat_214 = None + _grouped_mm_71 = torch.ops.aten._grouped_mm.default(mul_1162, permute_367, cumsum_71) + empty_23 = torch.ops.aten.empty.memory_format([sym_size_int_93, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_46 = torch.ops.aten.index_put.default(empty_23, [getitem_2552], _grouped_mm_71); empty_23 = _grouped_mm_71 = None + slice_149 = torch.ops.aten.slice.Tensor(index_put_46, 0, 0, -1); index_put_46 = None + all_to_all_single_71 = torch.ops._c10d_functional.all_to_all_single.default(slice_149, [_local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375], [_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383], '1033'); slice_149 = None + wait_tensor_511 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_71); all_to_all_single_71 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(primals_404, torch.bfloat16) + all_gather_into_tensor_416 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1320, 128, '0'); convert_element_type_1320 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_416); all_gather_into_tensor_416 = None + permute_368 = torch.ops.aten.permute.default(wait_tensor_512, [1, 0]); wait_tensor_512 = None + mm_196 = torch.ops.aten.mm.default(view_1599, permute_368); permute_368 = None + convert_element_type_1323 = torch.ops.prims.convert_element_type.default(mm_196, torch.float32) + neg_48 = torch.ops.aten.neg.default(convert_element_type_1323) + exp_72 = torch.ops.aten.exp.default(neg_48); neg_48 = None + add_1632 = torch.ops.aten.add.Tensor(exp_72, 1); exp_72 = None + div_120 = torch.ops.aten.div.Tensor(convert_element_type_1323, add_1632); convert_element_type_1323 = add_1632 = None + convert_element_type_1324 = torch.ops.prims.convert_element_type.default(div_120, torch.bfloat16); div_120 = None + convert_element_type_1325 = torch.ops.prims.convert_element_type.default(primals_405, torch.bfloat16) + all_gather_into_tensor_417 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1325, 128, '0'); convert_element_type_1325 = None + wait_tensor_513 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_417); all_gather_into_tensor_417 = None + permute_369 = torch.ops.aten.permute.default(wait_tensor_513, [1, 0]); wait_tensor_513 = None + mm_197 = torch.ops.aten.mm.default(view_1599, permute_369); permute_369 = None + mul_1182 = torch.ops.aten.mul.Tensor(convert_element_type_1324, mm_197); convert_element_type_1324 = None + convert_element_type_1328 = torch.ops.prims.convert_element_type.default(primals_406, torch.bfloat16) + all_gather_into_tensor_418 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1328, 128, '0'); convert_element_type_1328 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_418); all_gather_into_tensor_418 = None + permute_370 = torch.ops.aten.permute.default(wait_tensor_514, [1, 0]); wait_tensor_514 = None + mm_198 = torch.ops.aten.mm.default(mul_1182, permute_370); permute_370 = None + index_put_47 = torch.ops.aten.index_put.default(full_default_1, [getitem_2551], wait_tensor_511); wait_tensor_511 = None + view_1639 = torch.ops.aten.view.default(mul_1144, [-1, 1, 6]); mul_1144 = None + view_1640 = torch.ops.aten.view.default(index_put_47, [-1, 6, 2048]); index_put_47 = None + convert_element_type_1331 = torch.ops.prims.convert_element_type.default(view_1640, torch.float32); view_1640 = None + bmm_23 = torch.ops.aten.bmm.default(view_1639, convert_element_type_1331) + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(bmm_23, torch.bfloat16); bmm_23 = None + squeeze_23 = torch.ops.aten.squeeze.dim(convert_element_type_1332, 1); convert_element_type_1332 = None + add_1636 = torch.ops.aten.add.Tensor(mm_198, squeeze_23); mm_198 = squeeze_23 = None + view_1641 = torch.ops.aten.view.default(add_1636, [2, 4096, 2048]); add_1636 = None + add_1637 = torch.ops.aten.add.Tensor(add_1572, view_1641); view_1641 = None + convert_element_type_1333 = torch.ops.prims.convert_element_type.default(primals_407, torch.bfloat16) + all_gather_into_tensor_419 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1333, 128, '0'); convert_element_type_1333 = None + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_419); all_gather_into_tensor_419 = None + convert_element_type_1334 = torch.ops.prims.convert_element_type.default(add_1637, torch.float32) + pow_76 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1334, 2) + mean_75 = torch.ops.aten.mean.dim(pow_76, [2], True); pow_76 = None + add_1638 = torch.ops.aten.add.Scalar(mean_75, 1e-05); mean_75 = None + rsqrt_75 = torch.ops.aten.rsqrt.default(add_1638); add_1638 = None + mul_1185 = torch.ops.aten.mul.Tensor(convert_element_type_1334, rsqrt_75); convert_element_type_1334 = None + mul_1186 = torch.ops.aten.mul.Tensor(mul_1185, wait_tensor_515); mul_1185 = wait_tensor_515 = None + convert_element_type_1335 = torch.ops.prims.convert_element_type.default(mul_1186, torch.bfloat16); mul_1186 = None + convert_element_type_1336 = torch.ops.prims.convert_element_type.default(primals_408, torch.bfloat16) + all_gather_into_tensor_420 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1336, 128, '0'); convert_element_type_1336 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_420); all_gather_into_tensor_420 = None + permute_371 = torch.ops.aten.permute.default(wait_tensor_516, [1, 0]); wait_tensor_516 = None + view_1644 = torch.ops.aten.view.default(convert_element_type_1335, [8192, 2048]); convert_element_type_1335 = None + mm_199 = torch.ops.aten.mm.default(view_1644, permute_371); permute_371 = None + view_1645 = torch.ops.aten.view.default(mm_199, [2, 4096, 3072]); mm_199 = None + view_1646 = torch.ops.aten.view.default(view_1645, [2, 4096, -1, 192]); view_1645 = None + split_with_sizes_75 = torch.ops.aten.split_with_sizes.default(view_1646, [128, 64], -1); view_1646 = None + getitem_2649 = split_with_sizes_75[0] + getitem_2650 = split_with_sizes_75[1]; split_with_sizes_75 = None + convert_element_type_1339 = torch.ops.prims.convert_element_type.default(getitem_2650, torch.float32); getitem_2650 = None + view_1647 = torch.ops.aten.view.default(convert_element_type_1339, [2, 4096, 16, -1, 2]); convert_element_type_1339 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1647); view_1647 = None + mul_1187 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_7); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_1187); mul_1187 = None + view_1649 = torch.ops.aten.view.default(view_as_real_50, [2, 4096, 16, 64]); view_as_real_50 = None + convert_element_type_1340 = torch.ops.prims.convert_element_type.default(view_1649, torch.bfloat16); view_1649 = None + cat_218 = torch.ops.aten.cat.default([getitem_2649, convert_element_type_1340], -1); getitem_2649 = convert_element_type_1340 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(primals_409, torch.bfloat16) + all_gather_into_tensor_421 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1341, 128, '0'); convert_element_type_1341 = None + wait_tensor_517 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_421); all_gather_into_tensor_421 = None + slice_151 = torch.ops.aten.slice.Tensor(wait_tensor_517, 0, 0, 576); wait_tensor_517 = None + permute_372 = torch.ops.aten.permute.default(slice_151, [1, 0]); slice_151 = None + mm_200 = torch.ops.aten.mm.default(view_1644, permute_372); permute_372 = None + view_1652 = torch.ops.aten.view.default(mm_200, [2, 4096, 576]); mm_200 = None + split_with_sizes_76 = torch.ops.aten.split_with_sizes.default(view_1652, [512, 64], -1); view_1652 = None + getitem_2651 = split_with_sizes_76[0] + getitem_2652 = split_with_sizes_76[1]; split_with_sizes_76 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(getitem_2652, 2); getitem_2652 = None + convert_element_type_1344 = torch.ops.prims.convert_element_type.default(unsqueeze_49, torch.float32); unsqueeze_49 = None + view_1653 = torch.ops.aten.view.default(convert_element_type_1344, [2, 4096, 1, -1, 2]); convert_element_type_1344 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1653); view_1653 = None + mul_1188 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_7); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_1188); mul_1188 = None + view_1655 = torch.ops.aten.view.default(view_as_real_51, [2, 4096, 1, 64]); view_as_real_51 = None + convert_element_type_1345 = torch.ops.prims.convert_element_type.default(view_1655, torch.bfloat16); view_1655 = None + convert_element_type_1346 = torch.ops.prims.convert_element_type.default(primals_410, torch.bfloat16) + all_gather_into_tensor_422 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1346, 128, '0'); convert_element_type_1346 = None + wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_422); all_gather_into_tensor_422 = None + convert_element_type_1347 = torch.ops.prims.convert_element_type.default(getitem_2651, torch.float32) + pow_77 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1347, 2) + mean_76 = torch.ops.aten.mean.dim(pow_77, [2], True); pow_77 = None + add_1639 = torch.ops.aten.add.Scalar(mean_76, 1e-05); mean_76 = None + rsqrt_76 = torch.ops.aten.rsqrt.default(add_1639); add_1639 = None + mul_1189 = torch.ops.aten.mul.Tensor(convert_element_type_1347, rsqrt_76); convert_element_type_1347 = None + mul_1190 = torch.ops.aten.mul.Tensor(mul_1189, wait_tensor_518); mul_1189 = wait_tensor_518 = None + convert_element_type_1348 = torch.ops.prims.convert_element_type.default(mul_1190, torch.bfloat16); mul_1190 = None + convert_element_type_1349 = torch.ops.prims.convert_element_type.default(primals_411, torch.bfloat16) + all_gather_into_tensor_423 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1349, 128, '0'); convert_element_type_1349 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_423); all_gather_into_tensor_423 = None + permute_373 = torch.ops.aten.permute.default(wait_tensor_519, [1, 0]); wait_tensor_519 = None + view_1658 = torch.ops.aten.view.default(convert_element_type_1348, [8192, 512]); convert_element_type_1348 = None + mm_201 = torch.ops.aten.mm.default(view_1658, permute_373); permute_373 = None + view_1659 = torch.ops.aten.view.default(mm_201, [2, 4096, 4096]); mm_201 = None + view_1660 = torch.ops.aten.view.default(view_1659, [2, 4096, -1, 256]); view_1659 = None + split_with_sizes_77 = torch.ops.aten.split_with_sizes.default(view_1660, [128, 128], -1); view_1660 = None + getitem_2653 = split_with_sizes_77[0] + getitem_2654 = split_with_sizes_77[1]; split_with_sizes_77 = None + expand_25 = torch.ops.aten.expand.default(convert_element_type_1345, [-1, -1, 16, -1]); convert_element_type_1345 = None + cat_219 = torch.ops.aten.cat.default([getitem_2653, expand_25], -1); getitem_2653 = expand_25 = None + permute_374 = torch.ops.aten.permute.default(cat_218, [0, 2, 1, 3]); cat_218 = None + permute_375 = torch.ops.aten.permute.default(cat_219, [0, 2, 1, 3]); cat_219 = None + permute_376 = torch.ops.aten.permute.default(getitem_2654, [0, 2, 1, 3]); getitem_2654 = None + sdpa_score25 = self.sdpa_score25 + sdpa_mask25 = self.sdpa_mask25 + flex_attention_25 = torch.ops.higher_order.flex_attention(permute_374, permute_375, permute_376, sdpa_score25, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask25), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score25 = sdpa_mask25 = None + getitem_2655 = flex_attention_25[0] + getitem_2656 = flex_attention_25[1]; flex_attention_25 = None + permute_377 = torch.ops.aten.permute.default(getitem_2655, [0, 2, 1, 3]) + view_1661 = torch.ops.aten.view.default(permute_377, [2, 4096, -1]); permute_377 = None + convert_element_type_1352 = torch.ops.prims.convert_element_type.default(primals_412, torch.bfloat16) + all_gather_into_tensor_424 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1352, 128, '0'); convert_element_type_1352 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_424); all_gather_into_tensor_424 = None + permute_378 = torch.ops.aten.permute.default(wait_tensor_520, [1, 0]); wait_tensor_520 = None + view_1663 = torch.ops.aten.view.default(view_1661, [8192, 2048]); view_1661 = None + mm_202 = torch.ops.aten.mm.default(view_1663, permute_378); view_1663 = permute_378 = None + view_1664 = torch.ops.aten.view.default(mm_202, [2, 4096, 2048]); mm_202 = None + add_1640 = torch.ops.aten.add.Tensor(add_1637, view_1664); view_1664 = None + convert_element_type_1355 = torch.ops.prims.convert_element_type.default(primals_413, torch.bfloat16) + all_gather_into_tensor_425 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1355, 128, '0'); convert_element_type_1355 = None + wait_tensor_521 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_425); all_gather_into_tensor_425 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(add_1640, torch.float32) + pow_78 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1356, 2) + mean_77 = torch.ops.aten.mean.dim(pow_78, [2], True); pow_78 = None + add_1641 = torch.ops.aten.add.Scalar(mean_77, 1e-05); mean_77 = None + rsqrt_77 = torch.ops.aten.rsqrt.default(add_1641); add_1641 = None + mul_1191 = torch.ops.aten.mul.Tensor(convert_element_type_1356, rsqrt_77); convert_element_type_1356 = None + mul_1192 = torch.ops.aten.mul.Tensor(mul_1191, wait_tensor_521); mul_1191 = wait_tensor_521 = None + convert_element_type_1357 = torch.ops.prims.convert_element_type.default(mul_1192, torch.bfloat16); mul_1192 = None + view_1666 = torch.ops.aten.view.default(convert_element_type_1357, [-1, 2048]); convert_element_type_1357 = None + convert_element_type_1358 = torch.ops.prims.convert_element_type.default(primals_415, torch.bfloat16) + all_gather_into_tensor_426 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1358, 128, '0'); convert_element_type_1358 = None + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_426); all_gather_into_tensor_426 = None + slice_153 = torch.ops.aten.slice.Tensor(wait_tensor_522, 0, 0, 64); wait_tensor_522 = None + permute_379 = torch.ops.aten.permute.default(slice_153, [1, 0]); slice_153 = None + mm_203 = torch.ops.aten.mm.default(view_1666, permute_379); permute_379 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_203, torch.float32) + amax_24 = torch.ops.aten.amax.default(convert_element_type_1361, [1], True) + sub_576 = torch.ops.aten.sub.Tensor(convert_element_type_1361, amax_24); convert_element_type_1361 = None + exp_73 = torch.ops.aten.exp.default(sub_576); sub_576 = None + sum_97 = torch.ops.aten.sum.dim_IntList(exp_73, [1], True) + div_121 = torch.ops.aten.div.Tensor(exp_73, sum_97); exp_73 = None + add_1642 = torch.ops.aten.add.Tensor(div_121, primals_414); primals_414 = None + topk_24 = torch.ops.aten.topk.default(add_1642, 6, -1, True, False); add_1642 = None + getitem_2659 = topk_24[1]; topk_24 = None + gather_24 = torch.ops.aten.gather.default(div_121, 1, getitem_2659); div_121 = None + mul_1193 = torch.ops.aten.mul.Tensor(gather_24, 1.0); gather_24 = None + view_1668 = torch.ops.aten.view.default(getitem_2659, [-1]) + histc_48 = torch.ops.aten.histc.default(view_1668, 64, 0, 64) + add_1643 = torch.ops.aten.add.Tensor(primals_416, histc_48) + sort_24 = torch.ops.aten.sort.stable(view_1668, stable = True); view_1668 = None + getitem_2661 = sort_24[1]; sort_24 = None + div_122 = torch.ops.aten.div.Tensor_mode(getitem_2661, 6, rounding_mode = 'floor') + index_48 = torch.ops.aten.index.Tensor(view_1666, [div_122]) + all_to_all_single_72 = torch.ops._c10d_functional.all_to_all_single.default(histc_48, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_523 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_72); all_to_all_single_72 = None + wait_tensor_524 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_523); wait_tensor_523 = None + view_1672 = torch.ops.aten.view.default(histc_48, [8, -1]); histc_48 = None + sum_98 = torch.ops.aten.sum.dim_IntList(view_1672, [1]); view_1672 = None + device_put_48 = torch.ops.prims.device_put.default(sum_98, device(type='cpu'), True); sum_98 = None + view_1673 = torch.ops.aten.view.default(wait_tensor_524, [8, -1]) + sum_99 = torch.ops.aten.sum.dim_IntList(view_1673, [1]) + device_put_49 = torch.ops.prims.device_put.default(sum_99, device(type='cpu')); sum_99 = None + select_384 = torch.ops.aten.select.int(device_put_48, 0, 0) + _local_scalar_dense_384 = torch.ops.aten._local_scalar_dense.default(select_384); select_384 = None + ge_480 = _local_scalar_dense_384 >= 0 + _assert_scalar_384 = torch.ops.aten._assert_scalar.default(ge_480, "Runtime assertion failed for expression u384 >= 0 on node 'ge_384'"); ge_480 = _assert_scalar_384 = None + select_385 = torch.ops.aten.select.int(device_put_48, 0, 1) + _local_scalar_dense_385 = torch.ops.aten._local_scalar_dense.default(select_385); select_385 = None + ge_481 = _local_scalar_dense_385 >= 0 + _assert_scalar_385 = torch.ops.aten._assert_scalar.default(ge_481, "Runtime assertion failed for expression u385 >= 0 on node 'ge_385'"); ge_481 = _assert_scalar_385 = None + select_386 = torch.ops.aten.select.int(device_put_48, 0, 2) + _local_scalar_dense_386 = torch.ops.aten._local_scalar_dense.default(select_386); select_386 = None + ge_482 = _local_scalar_dense_386 >= 0 + _assert_scalar_386 = torch.ops.aten._assert_scalar.default(ge_482, "Runtime assertion failed for expression u386 >= 0 on node 'ge_386'"); ge_482 = _assert_scalar_386 = None + select_387 = torch.ops.aten.select.int(device_put_48, 0, 3) + _local_scalar_dense_387 = torch.ops.aten._local_scalar_dense.default(select_387); select_387 = None + ge_483 = _local_scalar_dense_387 >= 0 + _assert_scalar_387 = torch.ops.aten._assert_scalar.default(ge_483, "Runtime assertion failed for expression u387 >= 0 on node 'ge_387'"); ge_483 = _assert_scalar_387 = None + select_388 = torch.ops.aten.select.int(device_put_48, 0, 4) + _local_scalar_dense_388 = torch.ops.aten._local_scalar_dense.default(select_388); select_388 = None + ge_484 = _local_scalar_dense_388 >= 0 + _assert_scalar_388 = torch.ops.aten._assert_scalar.default(ge_484, "Runtime assertion failed for expression u388 >= 0 on node 'ge_388'"); ge_484 = _assert_scalar_388 = None + select_389 = torch.ops.aten.select.int(device_put_48, 0, 5) + _local_scalar_dense_389 = torch.ops.aten._local_scalar_dense.default(select_389); select_389 = None + ge_485 = _local_scalar_dense_389 >= 0 + _assert_scalar_389 = torch.ops.aten._assert_scalar.default(ge_485, "Runtime assertion failed for expression u389 >= 0 on node 'ge_389'"); ge_485 = _assert_scalar_389 = None + select_390 = torch.ops.aten.select.int(device_put_48, 0, 6) + _local_scalar_dense_390 = torch.ops.aten._local_scalar_dense.default(select_390); select_390 = None + ge_486 = _local_scalar_dense_390 >= 0 + _assert_scalar_390 = torch.ops.aten._assert_scalar.default(ge_486, "Runtime assertion failed for expression u390 >= 0 on node 'ge_390'"); ge_486 = _assert_scalar_390 = None + select_391 = torch.ops.aten.select.int(device_put_48, 0, 7); device_put_48 = None + _local_scalar_dense_391 = torch.ops.aten._local_scalar_dense.default(select_391); select_391 = None + ge_487 = _local_scalar_dense_391 >= 0 + _assert_scalar_391 = torch.ops.aten._assert_scalar.default(ge_487, "Runtime assertion failed for expression u391 >= 0 on node 'ge_391'"); ge_487 = _assert_scalar_391 = None + select_392 = torch.ops.aten.select.int(device_put_49, 0, 0) + _local_scalar_dense_392 = torch.ops.aten._local_scalar_dense.default(select_392); select_392 = None + ge_488 = _local_scalar_dense_392 >= 0 + _assert_scalar_392 = torch.ops.aten._assert_scalar.default(ge_488, "Runtime assertion failed for expression u392 >= 0 on node 'ge_392'"); ge_488 = _assert_scalar_392 = None + select_393 = torch.ops.aten.select.int(device_put_49, 0, 1) + _local_scalar_dense_393 = torch.ops.aten._local_scalar_dense.default(select_393); select_393 = None + ge_489 = _local_scalar_dense_393 >= 0 + _assert_scalar_393 = torch.ops.aten._assert_scalar.default(ge_489, "Runtime assertion failed for expression u393 >= 0 on node 'ge_393'"); ge_489 = _assert_scalar_393 = None + select_394 = torch.ops.aten.select.int(device_put_49, 0, 2) + _local_scalar_dense_394 = torch.ops.aten._local_scalar_dense.default(select_394); select_394 = None + ge_490 = _local_scalar_dense_394 >= 0 + _assert_scalar_394 = torch.ops.aten._assert_scalar.default(ge_490, "Runtime assertion failed for expression u394 >= 0 on node 'ge_394'"); ge_490 = _assert_scalar_394 = None + select_395 = torch.ops.aten.select.int(device_put_49, 0, 3) + _local_scalar_dense_395 = torch.ops.aten._local_scalar_dense.default(select_395); select_395 = None + ge_491 = _local_scalar_dense_395 >= 0 + _assert_scalar_395 = torch.ops.aten._assert_scalar.default(ge_491, "Runtime assertion failed for expression u395 >= 0 on node 'ge_395'"); ge_491 = _assert_scalar_395 = None + select_396 = torch.ops.aten.select.int(device_put_49, 0, 4) + _local_scalar_dense_396 = torch.ops.aten._local_scalar_dense.default(select_396); select_396 = None + ge_492 = _local_scalar_dense_396 >= 0 + _assert_scalar_396 = torch.ops.aten._assert_scalar.default(ge_492, "Runtime assertion failed for expression u396 >= 0 on node 'ge_396'"); ge_492 = _assert_scalar_396 = None + select_397 = torch.ops.aten.select.int(device_put_49, 0, 5) + _local_scalar_dense_397 = torch.ops.aten._local_scalar_dense.default(select_397); select_397 = None + ge_493 = _local_scalar_dense_397 >= 0 + _assert_scalar_397 = torch.ops.aten._assert_scalar.default(ge_493, "Runtime assertion failed for expression u397 >= 0 on node 'ge_397'"); ge_493 = _assert_scalar_397 = None + select_398 = torch.ops.aten.select.int(device_put_49, 0, 6) + _local_scalar_dense_398 = torch.ops.aten._local_scalar_dense.default(select_398); select_398 = None + ge_494 = _local_scalar_dense_398 >= 0 + _assert_scalar_398 = torch.ops.aten._assert_scalar.default(ge_494, "Runtime assertion failed for expression u398 >= 0 on node 'ge_398'"); ge_494 = _assert_scalar_398 = None + select_399 = torch.ops.aten.select.int(device_put_49, 0, 7); device_put_49 = None + _local_scalar_dense_399 = torch.ops.aten._local_scalar_dense.default(select_399); select_399 = None + ge_495 = _local_scalar_dense_399 >= 0 + _assert_scalar_399 = torch.ops.aten._assert_scalar.default(ge_495, "Runtime assertion failed for expression u399 >= 0 on node 'ge_399'"); ge_495 = _assert_scalar_399 = None + all_to_all_single_73 = torch.ops._c10d_functional.all_to_all_single.default(index_48, [_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399], [_local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391], '1033'); index_48 = None + sym_size_int_96 = torch.ops.aten.sym_size.int(all_to_all_single_73, 0) + wait_tensor_525 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_73); all_to_all_single_73 = None + sym_sum_48 = torch.sym_sum((_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399)) + add_1650 = sym_sum_48 + 64; sym_sum_48 = None + add_1651 = add_1650 + 8; add_1650 = None + sub_579 = add_1651 - 1; add_1651 = None + floordiv_24 = sub_579 // 8; sub_579 = None + mul_1198 = floordiv_24 * 8; floordiv_24 = None + cumsum_72 = torch.ops.aten.cumsum.default(wait_tensor_524, 0) + sub_580 = torch.ops.aten.sub.Tensor(cumsum_72, wait_tensor_524); cumsum_72 = None + sum_100 = torch.ops.aten.sum.dim_IntList(view_1673, [0]); view_1673 = None + clamp_min_24 = torch.ops.aten.clamp_min.default(sum_100, 8); sum_100 = None + add_1652 = torch.ops.aten.add.Tensor(clamp_min_24, 8); clamp_min_24 = None + sub_581 = torch.ops.aten.sub.Tensor(add_1652, 1); add_1652 = None + div_123 = torch.ops.aten.div.Tensor_mode(sub_581, 8, rounding_mode = 'floor'); sub_581 = None + mul_1199 = torch.ops.aten.mul.Tensor(div_123, 8); div_123 = None + convert_element_type_1364 = torch.ops.prims.convert_element_type.default(mul_1199, torch.int32); mul_1199 = None + cumsum_73 = torch.ops.aten.cumsum.default(convert_element_type_1364, 0) + sub_582 = torch.ops.aten.sub.Tensor(cumsum_73, convert_element_type_1364); cumsum_73 = None + full_332 = torch.ops.aten.full.default([mul_1198], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1198 = None + triton_kernel_wrapper_functional_proxy_24 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 24, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_524, 'start_index_values_ptr': sub_580, 'write_offsets_ptr': sub_582, 'output_ptr': full_332}, tensors_to_clone = ['output_ptr']); wait_tensor_524 = sub_580 = sub_582 = full_332 = None + getitem_2662 = triton_kernel_wrapper_functional_proxy_24['output_ptr']; triton_kernel_wrapper_functional_proxy_24 = None + cat_220 = torch.ops.aten.cat.default([wait_tensor_525, full_default]); wait_tensor_525 = None + sym_size_int_97 = torch.ops.aten.sym_size.int(cat_220, 0) + sym_sum_49 = torch.sym_sum((1, _local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399)) + index_49 = torch.ops.aten.index.Tensor(cat_220, [getitem_2662]); cat_220 = None + convert_element_type_1366 = torch.ops.prims.convert_element_type.default(primals_417, torch.bfloat16) + all_gather_into_tensor_427 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1366, 16, '1025'); convert_element_type_1366 = None + wait_tensor_526 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_427); all_gather_into_tensor_427 = None + split_145 = torch.ops.aten.split.Tensor(wait_tensor_526, 8); wait_tensor_526 = None + getitem_2679 = split_145[0] + getitem_2680 = split_145[1] + getitem_2681 = split_145[2] + getitem_2682 = split_145[3] + getitem_2683 = split_145[4] + getitem_2684 = split_145[5] + getitem_2685 = split_145[6] + getitem_2686 = split_145[7] + getitem_2687 = split_145[8] + getitem_2688 = split_145[9] + getitem_2689 = split_145[10] + getitem_2690 = split_145[11] + getitem_2691 = split_145[12] + getitem_2692 = split_145[13] + getitem_2693 = split_145[14] + getitem_2694 = split_145[15]; split_145 = None + cat_222 = torch.ops.aten.cat.default([getitem_2679, getitem_2680, getitem_2681, getitem_2682, getitem_2683, getitem_2684, getitem_2685, getitem_2686, getitem_2687, getitem_2688, getitem_2689, getitem_2690, getitem_2691, getitem_2692, getitem_2693, getitem_2694], 1); getitem_2679 = getitem_2680 = getitem_2681 = getitem_2682 = getitem_2683 = getitem_2684 = getitem_2685 = getitem_2686 = getitem_2687 = getitem_2688 = getitem_2689 = getitem_2690 = getitem_2691 = getitem_2692 = getitem_2693 = getitem_2694 = None + convert_element_type_1368 = torch.ops.prims.convert_element_type.default(primals_418, torch.bfloat16) + all_gather_into_tensor_429 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1368, 16, '1025'); convert_element_type_1368 = None + wait_tensor_528 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_429); all_gather_into_tensor_429 = None + split_146 = torch.ops.aten.split.Tensor(wait_tensor_528, 8); wait_tensor_528 = None + getitem_2695 = split_146[0] + getitem_2696 = split_146[1] + getitem_2697 = split_146[2] + getitem_2698 = split_146[3] + getitem_2699 = split_146[4] + getitem_2700 = split_146[5] + getitem_2701 = split_146[6] + getitem_2702 = split_146[7] + getitem_2703 = split_146[8] + getitem_2704 = split_146[9] + getitem_2705 = split_146[10] + getitem_2706 = split_146[11] + getitem_2707 = split_146[12] + getitem_2708 = split_146[13] + getitem_2709 = split_146[14] + getitem_2710 = split_146[15]; split_146 = None + cat_223 = torch.ops.aten.cat.default([getitem_2695, getitem_2696, getitem_2697, getitem_2698, getitem_2699, getitem_2700, getitem_2701, getitem_2702, getitem_2703, getitem_2704, getitem_2705, getitem_2706, getitem_2707, getitem_2708, getitem_2709, getitem_2710], 1); getitem_2695 = getitem_2696 = getitem_2697 = getitem_2698 = getitem_2699 = getitem_2700 = getitem_2701 = getitem_2702 = getitem_2703 = getitem_2704 = getitem_2705 = getitem_2706 = getitem_2707 = getitem_2708 = getitem_2709 = getitem_2710 = None + convert_element_type_1369 = torch.ops.prims.convert_element_type.default(primals_419, torch.bfloat16) + all_gather_into_tensor_430 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1369, 16, '1025'); convert_element_type_1369 = None + wait_tensor_529 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_430); all_gather_into_tensor_430 = None + split_147 = torch.ops.aten.split.Tensor(wait_tensor_529, 8); wait_tensor_529 = None + getitem_2711 = split_147[0] + getitem_2712 = split_147[1] + getitem_2713 = split_147[2] + getitem_2714 = split_147[3] + getitem_2715 = split_147[4] + getitem_2716 = split_147[5] + getitem_2717 = split_147[6] + getitem_2718 = split_147[7] + getitem_2719 = split_147[8] + getitem_2720 = split_147[9] + getitem_2721 = split_147[10] + getitem_2722 = split_147[11] + getitem_2723 = split_147[12] + getitem_2724 = split_147[13] + getitem_2725 = split_147[14] + getitem_2726 = split_147[15]; split_147 = None + cat_224 = torch.ops.aten.cat.default([getitem_2711, getitem_2712, getitem_2713, getitem_2714, getitem_2715, getitem_2716, getitem_2717, getitem_2718, getitem_2719, getitem_2720, getitem_2721, getitem_2722, getitem_2723, getitem_2724, getitem_2725, getitem_2726], 1); getitem_2711 = getitem_2712 = getitem_2713 = getitem_2714 = getitem_2715 = getitem_2716 = getitem_2717 = getitem_2718 = getitem_2719 = getitem_2720 = getitem_2721 = getitem_2722 = getitem_2723 = getitem_2724 = getitem_2725 = getitem_2726 = None + cumsum_74 = torch.ops.aten.cumsum.default(convert_element_type_1364, 0, dtype = torch.int32); convert_element_type_1364 = None + permute_380 = torch.ops.aten.permute.default(cat_222, [0, 2, 1]); cat_222 = None + _grouped_mm_72 = torch.ops.aten._grouped_mm.default(index_49, permute_380, cumsum_74) + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(_grouped_mm_72, torch.float32) + neg_49 = torch.ops.aten.neg.default(convert_element_type_1372) + exp_74 = torch.ops.aten.exp.default(neg_49); neg_49 = None + add_1664 = torch.ops.aten.add.Tensor(exp_74, 1); exp_74 = None + div_124 = torch.ops.aten.div.Tensor(convert_element_type_1372, add_1664); convert_element_type_1372 = add_1664 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(div_124, torch.bfloat16); div_124 = None + permute_381 = torch.ops.aten.permute.default(cat_224, [0, 2, 1]); cat_224 = None + _grouped_mm_73 = torch.ops.aten._grouped_mm.default(index_49, permute_381, cumsum_74) + mul_1211 = torch.ops.aten.mul.Tensor(convert_element_type_1373, _grouped_mm_73); convert_element_type_1373 = None + permute_382 = torch.ops.aten.permute.default(cat_223, [0, 2, 1]); cat_223 = None + _grouped_mm_74 = torch.ops.aten._grouped_mm.default(mul_1211, permute_382, cumsum_74) + empty_24 = torch.ops.aten.empty.memory_format([sym_size_int_97, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_48 = torch.ops.aten.index_put.default(empty_24, [getitem_2662], _grouped_mm_74); empty_24 = _grouped_mm_74 = None + slice_155 = torch.ops.aten.slice.Tensor(index_put_48, 0, 0, -1); index_put_48 = None + all_to_all_single_74 = torch.ops._c10d_functional.all_to_all_single.default(slice_155, [_local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391], [_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399], '1033'); slice_155 = None + wait_tensor_532 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_74); all_to_all_single_74 = None + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(primals_420, torch.bfloat16) + all_gather_into_tensor_433 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1374, 128, '0'); convert_element_type_1374 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_433); all_gather_into_tensor_433 = None + permute_383 = torch.ops.aten.permute.default(wait_tensor_533, [1, 0]); wait_tensor_533 = None + mm_204 = torch.ops.aten.mm.default(view_1666, permute_383); permute_383 = None + convert_element_type_1377 = torch.ops.prims.convert_element_type.default(mm_204, torch.float32) + neg_50 = torch.ops.aten.neg.default(convert_element_type_1377) + exp_75 = torch.ops.aten.exp.default(neg_50); neg_50 = None + add_1700 = torch.ops.aten.add.Tensor(exp_75, 1); exp_75 = None + div_125 = torch.ops.aten.div.Tensor(convert_element_type_1377, add_1700); convert_element_type_1377 = add_1700 = None + convert_element_type_1378 = torch.ops.prims.convert_element_type.default(div_125, torch.bfloat16); div_125 = None + convert_element_type_1379 = torch.ops.prims.convert_element_type.default(primals_421, torch.bfloat16) + all_gather_into_tensor_434 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1379, 128, '0'); convert_element_type_1379 = None + wait_tensor_534 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_434); all_gather_into_tensor_434 = None + permute_384 = torch.ops.aten.permute.default(wait_tensor_534, [1, 0]); wait_tensor_534 = None + mm_205 = torch.ops.aten.mm.default(view_1666, permute_384); permute_384 = None + mul_1231 = torch.ops.aten.mul.Tensor(convert_element_type_1378, mm_205); convert_element_type_1378 = None + convert_element_type_1382 = torch.ops.prims.convert_element_type.default(primals_422, torch.bfloat16) + all_gather_into_tensor_435 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1382, 128, '0'); convert_element_type_1382 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_435); all_gather_into_tensor_435 = None + permute_385 = torch.ops.aten.permute.default(wait_tensor_535, [1, 0]); wait_tensor_535 = None + mm_206 = torch.ops.aten.mm.default(mul_1231, permute_385); permute_385 = None + index_put_49 = torch.ops.aten.index_put.default(full_default_1, [getitem_2661], wait_tensor_532); wait_tensor_532 = None + view_1706 = torch.ops.aten.view.default(mul_1193, [-1, 1, 6]); mul_1193 = None + view_1707 = torch.ops.aten.view.default(index_put_49, [-1, 6, 2048]); index_put_49 = None + convert_element_type_1385 = torch.ops.prims.convert_element_type.default(view_1707, torch.float32); view_1707 = None + bmm_24 = torch.ops.aten.bmm.default(view_1706, convert_element_type_1385) + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(bmm_24, torch.bfloat16); bmm_24 = None + squeeze_24 = torch.ops.aten.squeeze.dim(convert_element_type_1386, 1); convert_element_type_1386 = None + add_1704 = torch.ops.aten.add.Tensor(mm_206, squeeze_24); mm_206 = squeeze_24 = None + view_1708 = torch.ops.aten.view.default(add_1704, [2, 4096, 2048]); add_1704 = None + add_1705 = torch.ops.aten.add.Tensor(add_1640, view_1708); view_1708 = None + convert_element_type_1387 = torch.ops.prims.convert_element_type.default(primals_423, torch.bfloat16) + all_gather_into_tensor_436 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1387, 128, '0'); convert_element_type_1387 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_436); all_gather_into_tensor_436 = None + convert_element_type_1388 = torch.ops.prims.convert_element_type.default(add_1705, torch.float32) + pow_79 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1388, 2) + mean_78 = torch.ops.aten.mean.dim(pow_79, [2], True); pow_79 = None + add_1706 = torch.ops.aten.add.Scalar(mean_78, 1e-05); mean_78 = None + rsqrt_78 = torch.ops.aten.rsqrt.default(add_1706); add_1706 = None + mul_1234 = torch.ops.aten.mul.Tensor(convert_element_type_1388, rsqrt_78); convert_element_type_1388 = None + mul_1235 = torch.ops.aten.mul.Tensor(mul_1234, wait_tensor_536); mul_1234 = wait_tensor_536 = None + convert_element_type_1389 = torch.ops.prims.convert_element_type.default(mul_1235, torch.bfloat16); mul_1235 = None + convert_element_type_1390 = torch.ops.prims.convert_element_type.default(primals_424, torch.bfloat16) + all_gather_into_tensor_437 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1390, 128, '0'); convert_element_type_1390 = None + wait_tensor_537 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_437); all_gather_into_tensor_437 = None + permute_386 = torch.ops.aten.permute.default(wait_tensor_537, [1, 0]); wait_tensor_537 = None + view_1711 = torch.ops.aten.view.default(convert_element_type_1389, [8192, 2048]); convert_element_type_1389 = None + mm_207 = torch.ops.aten.mm.default(view_1711, permute_386); permute_386 = None + view_1712 = torch.ops.aten.view.default(mm_207, [2, 4096, 3072]); mm_207 = None + view_1713 = torch.ops.aten.view.default(view_1712, [2, 4096, -1, 192]); view_1712 = None + split_with_sizes_78 = torch.ops.aten.split_with_sizes.default(view_1713, [128, 64], -1); view_1713 = None + getitem_2759 = split_with_sizes_78[0] + getitem_2760 = split_with_sizes_78[1]; split_with_sizes_78 = None + convert_element_type_1393 = torch.ops.prims.convert_element_type.default(getitem_2760, torch.float32); getitem_2760 = None + view_1714 = torch.ops.aten.view.default(convert_element_type_1393, [2, 4096, 16, -1, 2]); convert_element_type_1393 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_1714); view_1714 = None + mul_1236 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_7); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_1236); mul_1236 = None + view_1716 = torch.ops.aten.view.default(view_as_real_52, [2, 4096, 16, 64]); view_as_real_52 = None + convert_element_type_1394 = torch.ops.prims.convert_element_type.default(view_1716, torch.bfloat16); view_1716 = None + cat_227 = torch.ops.aten.cat.default([getitem_2759, convert_element_type_1394], -1); getitem_2759 = convert_element_type_1394 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(primals_425, torch.bfloat16) + all_gather_into_tensor_438 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1395, 128, '0'); convert_element_type_1395 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_438); all_gather_into_tensor_438 = None + slice_157 = torch.ops.aten.slice.Tensor(wait_tensor_538, 0, 0, 576); wait_tensor_538 = None + permute_387 = torch.ops.aten.permute.default(slice_157, [1, 0]); slice_157 = None + mm_208 = torch.ops.aten.mm.default(view_1711, permute_387); permute_387 = None + view_1719 = torch.ops.aten.view.default(mm_208, [2, 4096, 576]); mm_208 = None + split_with_sizes_79 = torch.ops.aten.split_with_sizes.default(view_1719, [512, 64], -1); view_1719 = None + getitem_2761 = split_with_sizes_79[0] + getitem_2762 = split_with_sizes_79[1]; split_with_sizes_79 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(getitem_2762, 2); getitem_2762 = None + convert_element_type_1398 = torch.ops.prims.convert_element_type.default(unsqueeze_51, torch.float32); unsqueeze_51 = None + view_1720 = torch.ops.aten.view.default(convert_element_type_1398, [2, 4096, 1, -1, 2]); convert_element_type_1398 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1720); view_1720 = None + mul_1237 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_7); view_as_complex_53 = view_7 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_1237); mul_1237 = None + view_1722 = torch.ops.aten.view.default(view_as_real_53, [2, 4096, 1, 64]); view_as_real_53 = None + convert_element_type_1399 = torch.ops.prims.convert_element_type.default(view_1722, torch.bfloat16); view_1722 = None + convert_element_type_1400 = torch.ops.prims.convert_element_type.default(primals_426, torch.bfloat16) + all_gather_into_tensor_439 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1400, 128, '0'); convert_element_type_1400 = None + wait_tensor_539 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_439); all_gather_into_tensor_439 = None + convert_element_type_1401 = torch.ops.prims.convert_element_type.default(getitem_2761, torch.float32) + pow_80 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1401, 2) + mean_79 = torch.ops.aten.mean.dim(pow_80, [2], True); pow_80 = None + add_1707 = torch.ops.aten.add.Scalar(mean_79, 1e-05); mean_79 = None + rsqrt_79 = torch.ops.aten.rsqrt.default(add_1707); add_1707 = None + mul_1238 = torch.ops.aten.mul.Tensor(convert_element_type_1401, rsqrt_79); convert_element_type_1401 = None + mul_1239 = torch.ops.aten.mul.Tensor(mul_1238, wait_tensor_539); mul_1238 = wait_tensor_539 = None + convert_element_type_1402 = torch.ops.prims.convert_element_type.default(mul_1239, torch.bfloat16); mul_1239 = None + convert_element_type_1403 = torch.ops.prims.convert_element_type.default(primals_427, torch.bfloat16) + all_gather_into_tensor_440 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1403, 128, '0'); convert_element_type_1403 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_440); all_gather_into_tensor_440 = None + permute_388 = torch.ops.aten.permute.default(wait_tensor_540, [1, 0]); wait_tensor_540 = None + view_1725 = torch.ops.aten.view.default(convert_element_type_1402, [8192, 512]); convert_element_type_1402 = None + mm_209 = torch.ops.aten.mm.default(view_1725, permute_388); permute_388 = None + view_1726 = torch.ops.aten.view.default(mm_209, [2, 4096, 4096]); mm_209 = None + view_1727 = torch.ops.aten.view.default(view_1726, [2, 4096, -1, 256]); view_1726 = None + split_with_sizes_80 = torch.ops.aten.split_with_sizes.default(view_1727, [128, 128], -1); view_1727 = None + getitem_2763 = split_with_sizes_80[0] + getitem_2764 = split_with_sizes_80[1]; split_with_sizes_80 = None + expand_26 = torch.ops.aten.expand.default(convert_element_type_1399, [-1, -1, 16, -1]); convert_element_type_1399 = None + cat_228 = torch.ops.aten.cat.default([getitem_2763, expand_26], -1); getitem_2763 = expand_26 = None + permute_389 = torch.ops.aten.permute.default(cat_227, [0, 2, 1, 3]); cat_227 = None + permute_390 = torch.ops.aten.permute.default(cat_228, [0, 2, 1, 3]); cat_228 = None + permute_391 = torch.ops.aten.permute.default(getitem_2764, [0, 2, 1, 3]); getitem_2764 = None + sdpa_score26 = self.sdpa_score26 + sdpa_mask26 = self.sdpa_mask26 + flex_attention_26 = torch.ops.higher_order.flex_attention(permute_389, permute_390, permute_391, sdpa_score26, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask26), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score26 = sdpa_mask26 = None + getitem_2765 = flex_attention_26[0] + getitem_2766 = flex_attention_26[1]; flex_attention_26 = None + permute_392 = torch.ops.aten.permute.default(getitem_2765, [0, 2, 1, 3]) + view_1728 = torch.ops.aten.view.default(permute_392, [2, 4096, -1]); permute_392 = None + convert_element_type_1406 = torch.ops.prims.convert_element_type.default(primals_428, torch.bfloat16) + all_gather_into_tensor_441 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1406, 128, '0'); convert_element_type_1406 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_441); all_gather_into_tensor_441 = None + permute_393 = torch.ops.aten.permute.default(wait_tensor_541, [1, 0]); wait_tensor_541 = None + view_1730 = torch.ops.aten.view.default(view_1728, [8192, 2048]); view_1728 = None + mm_210 = torch.ops.aten.mm.default(view_1730, permute_393); view_1730 = permute_393 = None + view_1731 = torch.ops.aten.view.default(mm_210, [2, 4096, 2048]); mm_210 = None + add_1708 = torch.ops.aten.add.Tensor(add_1705, view_1731); view_1731 = None + convert_element_type_1409 = torch.ops.prims.convert_element_type.default(primals_429, torch.bfloat16) + all_gather_into_tensor_442 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1409, 128, '0'); convert_element_type_1409 = None + wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_442); all_gather_into_tensor_442 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(add_1708, torch.float32) + pow_81 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1410, 2) + mean_80 = torch.ops.aten.mean.dim(pow_81, [2], True); pow_81 = None + add_1709 = torch.ops.aten.add.Scalar(mean_80, 1e-05); mean_80 = None + rsqrt_80 = torch.ops.aten.rsqrt.default(add_1709); add_1709 = None + mul_1240 = torch.ops.aten.mul.Tensor(convert_element_type_1410, rsqrt_80); convert_element_type_1410 = None + mul_1241 = torch.ops.aten.mul.Tensor(mul_1240, wait_tensor_542); mul_1240 = wait_tensor_542 = None + convert_element_type_1411 = torch.ops.prims.convert_element_type.default(mul_1241, torch.bfloat16); mul_1241 = None + view_1733 = torch.ops.aten.view.default(convert_element_type_1411, [-1, 2048]); convert_element_type_1411 = None + convert_element_type_1412 = torch.ops.prims.convert_element_type.default(primals_431, torch.bfloat16) + all_gather_into_tensor_443 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1412, 128, '0'); convert_element_type_1412 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_443); all_gather_into_tensor_443 = None + slice_159 = torch.ops.aten.slice.Tensor(wait_tensor_543, 0, 0, 64); wait_tensor_543 = None + permute_394 = torch.ops.aten.permute.default(slice_159, [1, 0]); slice_159 = None + mm_211 = torch.ops.aten.mm.default(view_1733, permute_394); permute_394 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_211, torch.float32) + amax_25 = torch.ops.aten.amax.default(convert_element_type_1415, [1], True) + sub_600 = torch.ops.aten.sub.Tensor(convert_element_type_1415, amax_25); convert_element_type_1415 = None + exp_76 = torch.ops.aten.exp.default(sub_600); sub_600 = None + sum_101 = torch.ops.aten.sum.dim_IntList(exp_76, [1], True) + div_126 = torch.ops.aten.div.Tensor(exp_76, sum_101); exp_76 = None + add_1710 = torch.ops.aten.add.Tensor(div_126, primals_430); primals_430 = None + topk_25 = torch.ops.aten.topk.default(add_1710, 6, -1, True, False); add_1710 = None + getitem_2769 = topk_25[1]; topk_25 = None + gather_25 = torch.ops.aten.gather.default(div_126, 1, getitem_2769); div_126 = None + mul_1242 = torch.ops.aten.mul.Tensor(gather_25, 1.0); gather_25 = None + view_1735 = torch.ops.aten.view.default(getitem_2769, [-1]) + histc_50 = torch.ops.aten.histc.default(view_1735, 64, 0, 64) + add_1711 = torch.ops.aten.add.Tensor(primals_432, histc_50) + sort_25 = torch.ops.aten.sort.stable(view_1735, stable = True); view_1735 = None + getitem_2771 = sort_25[1]; sort_25 = None + div_127 = torch.ops.aten.div.Tensor_mode(getitem_2771, 6, rounding_mode = 'floor') + index_50 = torch.ops.aten.index.Tensor(view_1733, [div_127]) + all_to_all_single_75 = torch.ops._c10d_functional.all_to_all_single.default(histc_50, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '1033') + wait_tensor_544 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_75); all_to_all_single_75 = None + wait_tensor_545 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_544); wait_tensor_544 = None + view_1739 = torch.ops.aten.view.default(histc_50, [8, -1]); histc_50 = None + sum_102 = torch.ops.aten.sum.dim_IntList(view_1739, [1]); view_1739 = None + device_put_50 = torch.ops.prims.device_put.default(sum_102, device(type='cpu'), True); sum_102 = None + view_1740 = torch.ops.aten.view.default(wait_tensor_545, [8, -1]) + sum_103 = torch.ops.aten.sum.dim_IntList(view_1740, [1]) + device_put_51 = torch.ops.prims.device_put.default(sum_103, device(type='cpu')); sum_103 = None + select_400 = torch.ops.aten.select.int(device_put_50, 0, 0) + _local_scalar_dense_400 = torch.ops.aten._local_scalar_dense.default(select_400); select_400 = None + ge_500 = _local_scalar_dense_400 >= 0 + _assert_scalar_400 = torch.ops.aten._assert_scalar.default(ge_500, "Runtime assertion failed for expression u400 >= 0 on node 'ge_400'"); ge_500 = _assert_scalar_400 = None + select_401 = torch.ops.aten.select.int(device_put_50, 0, 1) + _local_scalar_dense_401 = torch.ops.aten._local_scalar_dense.default(select_401); select_401 = None + ge_501 = _local_scalar_dense_401 >= 0 + _assert_scalar_401 = torch.ops.aten._assert_scalar.default(ge_501, "Runtime assertion failed for expression u401 >= 0 on node 'ge_401'"); ge_501 = _assert_scalar_401 = None + select_402 = torch.ops.aten.select.int(device_put_50, 0, 2) + _local_scalar_dense_402 = torch.ops.aten._local_scalar_dense.default(select_402); select_402 = None + ge_502 = _local_scalar_dense_402 >= 0 + _assert_scalar_402 = torch.ops.aten._assert_scalar.default(ge_502, "Runtime assertion failed for expression u402 >= 0 on node 'ge_402'"); ge_502 = _assert_scalar_402 = None + select_403 = torch.ops.aten.select.int(device_put_50, 0, 3) + _local_scalar_dense_403 = torch.ops.aten._local_scalar_dense.default(select_403); select_403 = None + ge_503 = _local_scalar_dense_403 >= 0 + _assert_scalar_403 = torch.ops.aten._assert_scalar.default(ge_503, "Runtime assertion failed for expression u403 >= 0 on node 'ge_403'"); ge_503 = _assert_scalar_403 = None + select_404 = torch.ops.aten.select.int(device_put_50, 0, 4) + _local_scalar_dense_404 = torch.ops.aten._local_scalar_dense.default(select_404); select_404 = None + ge_504 = _local_scalar_dense_404 >= 0 + _assert_scalar_404 = torch.ops.aten._assert_scalar.default(ge_504, "Runtime assertion failed for expression u404 >= 0 on node 'ge_404'"); ge_504 = _assert_scalar_404 = None + select_405 = torch.ops.aten.select.int(device_put_50, 0, 5) + _local_scalar_dense_405 = torch.ops.aten._local_scalar_dense.default(select_405); select_405 = None + ge_505 = _local_scalar_dense_405 >= 0 + _assert_scalar_405 = torch.ops.aten._assert_scalar.default(ge_505, "Runtime assertion failed for expression u405 >= 0 on node 'ge_405'"); ge_505 = _assert_scalar_405 = None + select_406 = torch.ops.aten.select.int(device_put_50, 0, 6) + _local_scalar_dense_406 = torch.ops.aten._local_scalar_dense.default(select_406); select_406 = None + ge_506 = _local_scalar_dense_406 >= 0 + _assert_scalar_406 = torch.ops.aten._assert_scalar.default(ge_506, "Runtime assertion failed for expression u406 >= 0 on node 'ge_406'"); ge_506 = _assert_scalar_406 = None + select_407 = torch.ops.aten.select.int(device_put_50, 0, 7); device_put_50 = None + _local_scalar_dense_407 = torch.ops.aten._local_scalar_dense.default(select_407); select_407 = None + ge_507 = _local_scalar_dense_407 >= 0 + _assert_scalar_407 = torch.ops.aten._assert_scalar.default(ge_507, "Runtime assertion failed for expression u407 >= 0 on node 'ge_407'"); ge_507 = _assert_scalar_407 = None + select_408 = torch.ops.aten.select.int(device_put_51, 0, 0) + _local_scalar_dense_408 = torch.ops.aten._local_scalar_dense.default(select_408); select_408 = None + ge_508 = _local_scalar_dense_408 >= 0 + _assert_scalar_408 = torch.ops.aten._assert_scalar.default(ge_508, "Runtime assertion failed for expression u408 >= 0 on node 'ge_408'"); ge_508 = _assert_scalar_408 = None + select_409 = torch.ops.aten.select.int(device_put_51, 0, 1) + _local_scalar_dense_409 = torch.ops.aten._local_scalar_dense.default(select_409); select_409 = None + ge_509 = _local_scalar_dense_409 >= 0 + _assert_scalar_409 = torch.ops.aten._assert_scalar.default(ge_509, "Runtime assertion failed for expression u409 >= 0 on node 'ge_409'"); ge_509 = _assert_scalar_409 = None + select_410 = torch.ops.aten.select.int(device_put_51, 0, 2) + _local_scalar_dense_410 = torch.ops.aten._local_scalar_dense.default(select_410); select_410 = None + ge_510 = _local_scalar_dense_410 >= 0 + _assert_scalar_410 = torch.ops.aten._assert_scalar.default(ge_510, "Runtime assertion failed for expression u410 >= 0 on node 'ge_410'"); ge_510 = _assert_scalar_410 = None + select_411 = torch.ops.aten.select.int(device_put_51, 0, 3) + _local_scalar_dense_411 = torch.ops.aten._local_scalar_dense.default(select_411); select_411 = None + ge_511 = _local_scalar_dense_411 >= 0 + _assert_scalar_411 = torch.ops.aten._assert_scalar.default(ge_511, "Runtime assertion failed for expression u411 >= 0 on node 'ge_411'"); ge_511 = _assert_scalar_411 = None + select_412 = torch.ops.aten.select.int(device_put_51, 0, 4) + _local_scalar_dense_412 = torch.ops.aten._local_scalar_dense.default(select_412); select_412 = None + ge_512 = _local_scalar_dense_412 >= 0 + _assert_scalar_412 = torch.ops.aten._assert_scalar.default(ge_512, "Runtime assertion failed for expression u412 >= 0 on node 'ge_412'"); ge_512 = _assert_scalar_412 = None + select_413 = torch.ops.aten.select.int(device_put_51, 0, 5) + _local_scalar_dense_413 = torch.ops.aten._local_scalar_dense.default(select_413); select_413 = None + ge_513 = _local_scalar_dense_413 >= 0 + _assert_scalar_413 = torch.ops.aten._assert_scalar.default(ge_513, "Runtime assertion failed for expression u413 >= 0 on node 'ge_413'"); ge_513 = _assert_scalar_413 = None + select_414 = torch.ops.aten.select.int(device_put_51, 0, 6) + _local_scalar_dense_414 = torch.ops.aten._local_scalar_dense.default(select_414); select_414 = None + ge_514 = _local_scalar_dense_414 >= 0 + _assert_scalar_414 = torch.ops.aten._assert_scalar.default(ge_514, "Runtime assertion failed for expression u414 >= 0 on node 'ge_414'"); ge_514 = _assert_scalar_414 = None + select_415 = torch.ops.aten.select.int(device_put_51, 0, 7); device_put_51 = None + _local_scalar_dense_415 = torch.ops.aten._local_scalar_dense.default(select_415); select_415 = None + ge_515 = _local_scalar_dense_415 >= 0 + _assert_scalar_415 = torch.ops.aten._assert_scalar.default(ge_515, "Runtime assertion failed for expression u415 >= 0 on node 'ge_415'"); ge_515 = _assert_scalar_415 = None + all_to_all_single_76 = torch.ops._c10d_functional.all_to_all_single.default(index_50, [_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415], [_local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407], '1033'); index_50 = None + sym_size_int_100 = torch.ops.aten.sym_size.int(all_to_all_single_76, 0) + wait_tensor_546 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_76); all_to_all_single_76 = None + sym_sum_50 = torch.sym_sum((_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415)) + add_1718 = sym_sum_50 + 64; sym_sum_50 = None + add_1719 = add_1718 + 8; add_1718 = None + sub_603 = add_1719 - 1; add_1719 = None + floordiv_25 = sub_603 // 8; sub_603 = None + mul_1247 = floordiv_25 * 8; floordiv_25 = None + cumsum_75 = torch.ops.aten.cumsum.default(wait_tensor_545, 0) + sub_604 = torch.ops.aten.sub.Tensor(cumsum_75, wait_tensor_545); cumsum_75 = None + sum_104 = torch.ops.aten.sum.dim_IntList(view_1740, [0]); view_1740 = None + clamp_min_25 = torch.ops.aten.clamp_min.default(sum_104, 8); sum_104 = None + add_1720 = torch.ops.aten.add.Tensor(clamp_min_25, 8); clamp_min_25 = None + sub_605 = torch.ops.aten.sub.Tensor(add_1720, 1); add_1720 = None + div_128 = torch.ops.aten.div.Tensor_mode(sub_605, 8, rounding_mode = 'floor'); sub_605 = None + mul_1248 = torch.ops.aten.mul.Tensor(div_128, 8); div_128 = None + convert_element_type_1418 = torch.ops.prims.convert_element_type.default(mul_1248, torch.int32); mul_1248 = None + cumsum_76 = torch.ops.aten.cumsum.default(convert_element_type_1418, 0) + sub_606 = torch.ops.aten.sub.Tensor(cumsum_76, convert_element_type_1418); cumsum_76 = None + full_345 = torch.ops.aten.full.default([mul_1247], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1247 = None + triton_kernel_wrapper_functional_proxy_25 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 25, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_545, 'start_index_values_ptr': sub_604, 'write_offsets_ptr': sub_606, 'output_ptr': full_345}, tensors_to_clone = ['output_ptr']); wait_tensor_545 = sub_604 = sub_606 = full_345 = None + getitem_2772 = triton_kernel_wrapper_functional_proxy_25['output_ptr']; triton_kernel_wrapper_functional_proxy_25 = None + cat_229 = torch.ops.aten.cat.default([wait_tensor_546, full_default]); wait_tensor_546 = full_default = None + sym_size_int_101 = torch.ops.aten.sym_size.int(cat_229, 0) + sym_sum_51 = torch.sym_sum((1, _local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415)) + index_51 = torch.ops.aten.index.Tensor(cat_229, [getitem_2772]); cat_229 = None + convert_element_type_1420 = torch.ops.prims.convert_element_type.default(primals_433, torch.bfloat16) + all_gather_into_tensor_444 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1420, 16, '1025'); convert_element_type_1420 = None + wait_tensor_547 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_444); all_gather_into_tensor_444 = None + split_151 = torch.ops.aten.split.Tensor(wait_tensor_547, 8); wait_tensor_547 = None + getitem_2789 = split_151[0] + getitem_2790 = split_151[1] + getitem_2791 = split_151[2] + getitem_2792 = split_151[3] + getitem_2793 = split_151[4] + getitem_2794 = split_151[5] + getitem_2795 = split_151[6] + getitem_2796 = split_151[7] + getitem_2797 = split_151[8] + getitem_2798 = split_151[9] + getitem_2799 = split_151[10] + getitem_2800 = split_151[11] + getitem_2801 = split_151[12] + getitem_2802 = split_151[13] + getitem_2803 = split_151[14] + getitem_2804 = split_151[15]; split_151 = None + cat_231 = torch.ops.aten.cat.default([getitem_2789, getitem_2790, getitem_2791, getitem_2792, getitem_2793, getitem_2794, getitem_2795, getitem_2796, getitem_2797, getitem_2798, getitem_2799, getitem_2800, getitem_2801, getitem_2802, getitem_2803, getitem_2804], 1); getitem_2789 = getitem_2790 = getitem_2791 = getitem_2792 = getitem_2793 = getitem_2794 = getitem_2795 = getitem_2796 = getitem_2797 = getitem_2798 = getitem_2799 = getitem_2800 = getitem_2801 = getitem_2802 = getitem_2803 = getitem_2804 = None + convert_element_type_1422 = torch.ops.prims.convert_element_type.default(primals_434, torch.bfloat16) + all_gather_into_tensor_446 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1422, 16, '1025'); convert_element_type_1422 = None + wait_tensor_549 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_446); all_gather_into_tensor_446 = None + split_152 = torch.ops.aten.split.Tensor(wait_tensor_549, 8); wait_tensor_549 = None + getitem_2805 = split_152[0] + getitem_2806 = split_152[1] + getitem_2807 = split_152[2] + getitem_2808 = split_152[3] + getitem_2809 = split_152[4] + getitem_2810 = split_152[5] + getitem_2811 = split_152[6] + getitem_2812 = split_152[7] + getitem_2813 = split_152[8] + getitem_2814 = split_152[9] + getitem_2815 = split_152[10] + getitem_2816 = split_152[11] + getitem_2817 = split_152[12] + getitem_2818 = split_152[13] + getitem_2819 = split_152[14] + getitem_2820 = split_152[15]; split_152 = None + cat_232 = torch.ops.aten.cat.default([getitem_2805, getitem_2806, getitem_2807, getitem_2808, getitem_2809, getitem_2810, getitem_2811, getitem_2812, getitem_2813, getitem_2814, getitem_2815, getitem_2816, getitem_2817, getitem_2818, getitem_2819, getitem_2820], 1); getitem_2805 = getitem_2806 = getitem_2807 = getitem_2808 = getitem_2809 = getitem_2810 = getitem_2811 = getitem_2812 = getitem_2813 = getitem_2814 = getitem_2815 = getitem_2816 = getitem_2817 = getitem_2818 = getitem_2819 = getitem_2820 = None + convert_element_type_1423 = torch.ops.prims.convert_element_type.default(primals_435, torch.bfloat16) + all_gather_into_tensor_447 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1423, 16, '1025'); convert_element_type_1423 = None + wait_tensor_550 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_447); all_gather_into_tensor_447 = None + split_153 = torch.ops.aten.split.Tensor(wait_tensor_550, 8); wait_tensor_550 = None + getitem_2821 = split_153[0] + getitem_2822 = split_153[1] + getitem_2823 = split_153[2] + getitem_2824 = split_153[3] + getitem_2825 = split_153[4] + getitem_2826 = split_153[5] + getitem_2827 = split_153[6] + getitem_2828 = split_153[7] + getitem_2829 = split_153[8] + getitem_2830 = split_153[9] + getitem_2831 = split_153[10] + getitem_2832 = split_153[11] + getitem_2833 = split_153[12] + getitem_2834 = split_153[13] + getitem_2835 = split_153[14] + getitem_2836 = split_153[15]; split_153 = None + cat_233 = torch.ops.aten.cat.default([getitem_2821, getitem_2822, getitem_2823, getitem_2824, getitem_2825, getitem_2826, getitem_2827, getitem_2828, getitem_2829, getitem_2830, getitem_2831, getitem_2832, getitem_2833, getitem_2834, getitem_2835, getitem_2836], 1); getitem_2821 = getitem_2822 = getitem_2823 = getitem_2824 = getitem_2825 = getitem_2826 = getitem_2827 = getitem_2828 = getitem_2829 = getitem_2830 = getitem_2831 = getitem_2832 = getitem_2833 = getitem_2834 = getitem_2835 = getitem_2836 = None + cumsum_77 = torch.ops.aten.cumsum.default(convert_element_type_1418, 0, dtype = torch.int32); convert_element_type_1418 = None + permute_395 = torch.ops.aten.permute.default(cat_231, [0, 2, 1]); cat_231 = None + _grouped_mm_75 = torch.ops.aten._grouped_mm.default(index_51, permute_395, cumsum_77) + convert_element_type_1426 = torch.ops.prims.convert_element_type.default(_grouped_mm_75, torch.float32) + neg_51 = torch.ops.aten.neg.default(convert_element_type_1426) + exp_77 = torch.ops.aten.exp.default(neg_51); neg_51 = None + add_1732 = torch.ops.aten.add.Tensor(exp_77, 1); exp_77 = None + div_129 = torch.ops.aten.div.Tensor(convert_element_type_1426, add_1732); convert_element_type_1426 = add_1732 = None + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(div_129, torch.bfloat16); div_129 = None + permute_396 = torch.ops.aten.permute.default(cat_233, [0, 2, 1]); cat_233 = None + _grouped_mm_76 = torch.ops.aten._grouped_mm.default(index_51, permute_396, cumsum_77) + mul_1260 = torch.ops.aten.mul.Tensor(convert_element_type_1427, _grouped_mm_76); convert_element_type_1427 = None + permute_397 = torch.ops.aten.permute.default(cat_232, [0, 2, 1]); cat_232 = None + _grouped_mm_77 = torch.ops.aten._grouped_mm.default(mul_1260, permute_397, cumsum_77) + empty_25 = torch.ops.aten.empty.memory_format([sym_size_int_101, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_50 = torch.ops.aten.index_put.default(empty_25, [getitem_2772], _grouped_mm_77); empty_25 = _grouped_mm_77 = None + slice_161 = torch.ops.aten.slice.Tensor(index_put_50, 0, 0, -1); index_put_50 = None + all_to_all_single_77 = torch.ops._c10d_functional.all_to_all_single.default(slice_161, [_local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407], [_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415], '1033'); slice_161 = None + wait_tensor_553 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_77); all_to_all_single_77 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(primals_436, torch.bfloat16) + all_gather_into_tensor_450 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1428, 128, '0'); convert_element_type_1428 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_450); all_gather_into_tensor_450 = None + permute_398 = torch.ops.aten.permute.default(wait_tensor_554, [1, 0]); wait_tensor_554 = None + mm_212 = torch.ops.aten.mm.default(view_1733, permute_398); permute_398 = None + convert_element_type_1431 = torch.ops.prims.convert_element_type.default(mm_212, torch.float32) + neg_52 = torch.ops.aten.neg.default(convert_element_type_1431) + exp_78 = torch.ops.aten.exp.default(neg_52); neg_52 = None + add_1768 = torch.ops.aten.add.Tensor(exp_78, 1); exp_78 = None + div_130 = torch.ops.aten.div.Tensor(convert_element_type_1431, add_1768); convert_element_type_1431 = add_1768 = None + convert_element_type_1432 = torch.ops.prims.convert_element_type.default(div_130, torch.bfloat16); div_130 = None + convert_element_type_1433 = torch.ops.prims.convert_element_type.default(primals_437, torch.bfloat16) + all_gather_into_tensor_451 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1433, 128, '0'); convert_element_type_1433 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_451); all_gather_into_tensor_451 = None + permute_399 = torch.ops.aten.permute.default(wait_tensor_555, [1, 0]); wait_tensor_555 = None + mm_213 = torch.ops.aten.mm.default(view_1733, permute_399); permute_399 = None + mul_1280 = torch.ops.aten.mul.Tensor(convert_element_type_1432, mm_213); convert_element_type_1432 = None + convert_element_type_1436 = torch.ops.prims.convert_element_type.default(primals_438, torch.bfloat16) + all_gather_into_tensor_452 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1436, 128, '0'); convert_element_type_1436 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_452); all_gather_into_tensor_452 = None + permute_400 = torch.ops.aten.permute.default(wait_tensor_556, [1, 0]); wait_tensor_556 = None + mm_214 = torch.ops.aten.mm.default(mul_1280, permute_400); permute_400 = None + index_put_51 = torch.ops.aten.index_put.default(full_default_1, [getitem_2771], wait_tensor_553); full_default_1 = wait_tensor_553 = None + view_1773 = torch.ops.aten.view.default(mul_1242, [-1, 1, 6]); mul_1242 = None + view_1774 = torch.ops.aten.view.default(index_put_51, [-1, 6, 2048]); index_put_51 = None + convert_element_type_1439 = torch.ops.prims.convert_element_type.default(view_1774, torch.float32); view_1774 = None + bmm_25 = torch.ops.aten.bmm.default(view_1773, convert_element_type_1439) + convert_element_type_1440 = torch.ops.prims.convert_element_type.default(bmm_25, torch.bfloat16); bmm_25 = None + squeeze_25 = torch.ops.aten.squeeze.dim(convert_element_type_1440, 1); convert_element_type_1440 = None + add_1772 = torch.ops.aten.add.Tensor(mm_214, squeeze_25); mm_214 = squeeze_25 = None + view_1775 = torch.ops.aten.view.default(add_1772, [2, 4096, 2048]); add_1772 = None + add_1773 = torch.ops.aten.add.Tensor(add_1708, view_1775); view_1775 = None + convert_element_type_1441 = torch.ops.prims.convert_element_type.default(primals_439, torch.bfloat16) + all_gather_into_tensor_453 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1441, 128, '0'); convert_element_type_1441 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_453); all_gather_into_tensor_453 = None + convert_element_type_1442 = torch.ops.prims.convert_element_type.default(add_1773, torch.float32) + pow_82 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1442, 2) + mean_81 = torch.ops.aten.mean.dim(pow_82, [2], True); pow_82 = None + add_1774 = torch.ops.aten.add.Scalar(mean_81, 1.1920928955078125e-07); mean_81 = None + rsqrt_81 = torch.ops.aten.rsqrt.default(add_1774); add_1774 = None + mul_1283 = torch.ops.aten.mul.Tensor(convert_element_type_1442, rsqrt_81); convert_element_type_1442 = None + mul_1284 = torch.ops.aten.mul.Tensor(mul_1283, wait_tensor_557); mul_1283 = wait_tensor_557 = None + convert_element_type_1443 = torch.ops.prims.convert_element_type.default(mul_1284, torch.bfloat16); mul_1284 = None + convert_element_type_1444 = torch.ops.prims.convert_element_type.default(primals_440, torch.bfloat16) + all_gather_into_tensor_454 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1444, 128, '0'); convert_element_type_1444 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_454); all_gather_into_tensor_454 = None + permute_401 = torch.ops.aten.permute.default(wait_tensor_558, [1, 0]); wait_tensor_558 = None + view_1778 = torch.ops.aten.view.default(convert_element_type_1443, [8192, 2048]); convert_element_type_1443 = None + mm_215 = torch.ops.aten.mm.default(view_1778, permute_401); permute_401 = None + view_1779 = torch.ops.aten.view.default(mm_215, [2, 4096, 102400]); mm_215 = None + permute_406 = torch.ops.aten.permute.default(view_1773, [0, 2, 1]); view_1773 = None + permute_407 = torch.ops.aten.permute.default(convert_element_type_1439, [0, 2, 1]); convert_element_type_1439 = None + permute_422 = torch.ops.aten.permute.default(permute_397, [0, 2, 1]); permute_397 = None + permute_426 = torch.ops.aten.permute.default(permute_396, [0, 2, 1]); permute_396 = None + permute_430 = torch.ops.aten.permute.default(permute_395, [0, 2, 1]); permute_395 = None + add_1781 = 0 + sym_size_int_100; sym_size_int_100 = None + full_default_54 = torch.ops.aten.full.default([0, 2048], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + permute_456 = torch.ops.aten.permute.default(view_1706, [0, 2, 1]); view_1706 = None + permute_457 = torch.ops.aten.permute.default(convert_element_type_1385, [0, 2, 1]); convert_element_type_1385 = None + permute_472 = torch.ops.aten.permute.default(permute_382, [0, 2, 1]); permute_382 = None + permute_476 = torch.ops.aten.permute.default(permute_381, [0, 2, 1]); permute_381 = None + permute_480 = torch.ops.aten.permute.default(permute_380, [0, 2, 1]); permute_380 = None + add_1796 = 0 + sym_size_int_96; sym_size_int_96 = None + permute_506 = torch.ops.aten.permute.default(view_1639, [0, 2, 1]); view_1639 = None + permute_507 = torch.ops.aten.permute.default(convert_element_type_1331, [0, 2, 1]); convert_element_type_1331 = None + permute_522 = torch.ops.aten.permute.default(permute_367, [0, 2, 1]); permute_367 = None + permute_526 = torch.ops.aten.permute.default(permute_366, [0, 2, 1]); permute_366 = None + permute_530 = torch.ops.aten.permute.default(permute_365, [0, 2, 1]); permute_365 = None + add_1811 = 0 + sym_size_int_92; sym_size_int_92 = None + permute_556 = torch.ops.aten.permute.default(view_1572, [0, 2, 1]); view_1572 = None + permute_557 = torch.ops.aten.permute.default(convert_element_type_1277, [0, 2, 1]); convert_element_type_1277 = None + permute_572 = torch.ops.aten.permute.default(permute_352, [0, 2, 1]); permute_352 = None + permute_576 = torch.ops.aten.permute.default(permute_351, [0, 2, 1]); permute_351 = None + permute_580 = torch.ops.aten.permute.default(permute_350, [0, 2, 1]); permute_350 = None + add_1826 = 0 + sym_size_int_88; sym_size_int_88 = None + permute_606 = torch.ops.aten.permute.default(view_1505, [0, 2, 1]); view_1505 = None + permute_607 = torch.ops.aten.permute.default(convert_element_type_1223, [0, 2, 1]); convert_element_type_1223 = None + permute_622 = torch.ops.aten.permute.default(permute_337, [0, 2, 1]); permute_337 = None + permute_626 = torch.ops.aten.permute.default(permute_336, [0, 2, 1]); permute_336 = None + permute_630 = torch.ops.aten.permute.default(permute_335, [0, 2, 1]); permute_335 = None + add_1841 = 0 + sym_size_int_84; sym_size_int_84 = None + permute_656 = torch.ops.aten.permute.default(view_1438, [0, 2, 1]); view_1438 = None + permute_657 = torch.ops.aten.permute.default(convert_element_type_1169, [0, 2, 1]); convert_element_type_1169 = None + permute_672 = torch.ops.aten.permute.default(permute_322, [0, 2, 1]); permute_322 = None + permute_676 = torch.ops.aten.permute.default(permute_321, [0, 2, 1]); permute_321 = None + permute_680 = torch.ops.aten.permute.default(permute_320, [0, 2, 1]); permute_320 = None + add_1856 = 0 + sym_size_int_80; sym_size_int_80 = None + permute_706 = torch.ops.aten.permute.default(view_1371, [0, 2, 1]); view_1371 = None + permute_707 = torch.ops.aten.permute.default(convert_element_type_1115, [0, 2, 1]); convert_element_type_1115 = None + permute_722 = torch.ops.aten.permute.default(permute_307, [0, 2, 1]); permute_307 = None + permute_726 = torch.ops.aten.permute.default(permute_306, [0, 2, 1]); permute_306 = None + permute_730 = torch.ops.aten.permute.default(permute_305, [0, 2, 1]); permute_305 = None + add_1871 = 0 + sym_size_int_76; sym_size_int_76 = None + permute_756 = torch.ops.aten.permute.default(view_1304, [0, 2, 1]); view_1304 = None + permute_757 = torch.ops.aten.permute.default(convert_element_type_1061, [0, 2, 1]); convert_element_type_1061 = None + permute_772 = torch.ops.aten.permute.default(permute_292, [0, 2, 1]); permute_292 = None + permute_776 = torch.ops.aten.permute.default(permute_291, [0, 2, 1]); permute_291 = None + permute_780 = torch.ops.aten.permute.default(permute_290, [0, 2, 1]); permute_290 = None + add_1886 = 0 + sym_size_int_72; sym_size_int_72 = None + permute_806 = torch.ops.aten.permute.default(view_1237, [0, 2, 1]); view_1237 = None + permute_807 = torch.ops.aten.permute.default(convert_element_type_1007, [0, 2, 1]); convert_element_type_1007 = None + permute_822 = torch.ops.aten.permute.default(permute_277, [0, 2, 1]); permute_277 = None + permute_826 = torch.ops.aten.permute.default(permute_276, [0, 2, 1]); permute_276 = None + permute_830 = torch.ops.aten.permute.default(permute_275, [0, 2, 1]); permute_275 = None + add_1901 = 0 + sym_size_int_68; sym_size_int_68 = None + permute_856 = torch.ops.aten.permute.default(view_1170, [0, 2, 1]); view_1170 = None + permute_857 = torch.ops.aten.permute.default(convert_element_type_953, [0, 2, 1]); convert_element_type_953 = None + permute_872 = torch.ops.aten.permute.default(permute_262, [0, 2, 1]); permute_262 = None + permute_876 = torch.ops.aten.permute.default(permute_261, [0, 2, 1]); permute_261 = None + permute_880 = torch.ops.aten.permute.default(permute_260, [0, 2, 1]); permute_260 = None + add_1916 = 0 + sym_size_int_64; sym_size_int_64 = None + permute_906 = torch.ops.aten.permute.default(view_1103, [0, 2, 1]); view_1103 = None + permute_907 = torch.ops.aten.permute.default(convert_element_type_899, [0, 2, 1]); convert_element_type_899 = None + permute_922 = torch.ops.aten.permute.default(permute_247, [0, 2, 1]); permute_247 = None + permute_926 = torch.ops.aten.permute.default(permute_246, [0, 2, 1]); permute_246 = None + permute_930 = torch.ops.aten.permute.default(permute_245, [0, 2, 1]); permute_245 = None + add_1931 = 0 + sym_size_int_60; sym_size_int_60 = None + permute_956 = torch.ops.aten.permute.default(view_1036, [0, 2, 1]); view_1036 = None + permute_957 = torch.ops.aten.permute.default(convert_element_type_845, [0, 2, 1]); convert_element_type_845 = None + permute_972 = torch.ops.aten.permute.default(permute_232, [0, 2, 1]); permute_232 = None + permute_976 = torch.ops.aten.permute.default(permute_231, [0, 2, 1]); permute_231 = None + permute_980 = torch.ops.aten.permute.default(permute_230, [0, 2, 1]); permute_230 = None + add_1946 = 0 + sym_size_int_56; sym_size_int_56 = None + permute_1006 = torch.ops.aten.permute.default(view_969, [0, 2, 1]); view_969 = None + permute_1007 = torch.ops.aten.permute.default(convert_element_type_791, [0, 2, 1]); convert_element_type_791 = None + permute_1022 = torch.ops.aten.permute.default(permute_217, [0, 2, 1]); permute_217 = None + permute_1026 = torch.ops.aten.permute.default(permute_216, [0, 2, 1]); permute_216 = None + permute_1030 = torch.ops.aten.permute.default(permute_215, [0, 2, 1]); permute_215 = None + add_1961 = 0 + sym_size_int_52; sym_size_int_52 = None + permute_1056 = torch.ops.aten.permute.default(view_902, [0, 2, 1]); view_902 = None + permute_1057 = torch.ops.aten.permute.default(convert_element_type_737, [0, 2, 1]); convert_element_type_737 = None + permute_1072 = torch.ops.aten.permute.default(permute_202, [0, 2, 1]); permute_202 = None + permute_1076 = torch.ops.aten.permute.default(permute_201, [0, 2, 1]); permute_201 = None + permute_1080 = torch.ops.aten.permute.default(permute_200, [0, 2, 1]); permute_200 = None + add_1976 = 0 + sym_size_int_48; sym_size_int_48 = None + permute_1106 = torch.ops.aten.permute.default(view_835, [0, 2, 1]); view_835 = None + permute_1107 = torch.ops.aten.permute.default(convert_element_type_683, [0, 2, 1]); convert_element_type_683 = None + permute_1122 = torch.ops.aten.permute.default(permute_187, [0, 2, 1]); permute_187 = None + permute_1126 = torch.ops.aten.permute.default(permute_186, [0, 2, 1]); permute_186 = None + permute_1130 = torch.ops.aten.permute.default(permute_185, [0, 2, 1]); permute_185 = None + add_1991 = 0 + sym_size_int_44; sym_size_int_44 = None + permute_1156 = torch.ops.aten.permute.default(view_768, [0, 2, 1]); view_768 = None + permute_1157 = torch.ops.aten.permute.default(convert_element_type_629, [0, 2, 1]); convert_element_type_629 = None + permute_1172 = torch.ops.aten.permute.default(permute_172, [0, 2, 1]); permute_172 = None + permute_1176 = torch.ops.aten.permute.default(permute_171, [0, 2, 1]); permute_171 = None + permute_1180 = torch.ops.aten.permute.default(permute_170, [0, 2, 1]); permute_170 = None + add_2006 = 0 + sym_size_int_40; sym_size_int_40 = None + permute_1206 = torch.ops.aten.permute.default(view_701, [0, 2, 1]); view_701 = None + permute_1207 = torch.ops.aten.permute.default(convert_element_type_575, [0, 2, 1]); convert_element_type_575 = None + permute_1222 = torch.ops.aten.permute.default(permute_157, [0, 2, 1]); permute_157 = None + permute_1226 = torch.ops.aten.permute.default(permute_156, [0, 2, 1]); permute_156 = None + permute_1230 = torch.ops.aten.permute.default(permute_155, [0, 2, 1]); permute_155 = None + add_2021 = 0 + sym_size_int_36; sym_size_int_36 = None + permute_1256 = torch.ops.aten.permute.default(view_634, [0, 2, 1]); view_634 = None + permute_1257 = torch.ops.aten.permute.default(convert_element_type_521, [0, 2, 1]); convert_element_type_521 = None + permute_1272 = torch.ops.aten.permute.default(permute_142, [0, 2, 1]); permute_142 = None + permute_1276 = torch.ops.aten.permute.default(permute_141, [0, 2, 1]); permute_141 = None + permute_1280 = torch.ops.aten.permute.default(permute_140, [0, 2, 1]); permute_140 = None + add_2036 = 0 + sym_size_int_32; sym_size_int_32 = None + permute_1306 = torch.ops.aten.permute.default(view_567, [0, 2, 1]); view_567 = None + permute_1307 = torch.ops.aten.permute.default(convert_element_type_467, [0, 2, 1]); convert_element_type_467 = None + permute_1322 = torch.ops.aten.permute.default(permute_127, [0, 2, 1]); permute_127 = None + permute_1326 = torch.ops.aten.permute.default(permute_126, [0, 2, 1]); permute_126 = None + permute_1330 = torch.ops.aten.permute.default(permute_125, [0, 2, 1]); permute_125 = None + add_2051 = 0 + sym_size_int_28; sym_size_int_28 = None + permute_1356 = torch.ops.aten.permute.default(view_500, [0, 2, 1]); view_500 = None + permute_1357 = torch.ops.aten.permute.default(convert_element_type_413, [0, 2, 1]); convert_element_type_413 = None + permute_1372 = torch.ops.aten.permute.default(permute_112, [0, 2, 1]); permute_112 = None + permute_1376 = torch.ops.aten.permute.default(permute_111, [0, 2, 1]); permute_111 = None + permute_1380 = torch.ops.aten.permute.default(permute_110, [0, 2, 1]); permute_110 = None + add_2066 = 0 + sym_size_int_24; sym_size_int_24 = None + permute_1406 = torch.ops.aten.permute.default(view_433, [0, 2, 1]); view_433 = None + permute_1407 = torch.ops.aten.permute.default(convert_element_type_359, [0, 2, 1]); convert_element_type_359 = None + permute_1422 = torch.ops.aten.permute.default(permute_97, [0, 2, 1]); permute_97 = None + permute_1426 = torch.ops.aten.permute.default(permute_96, [0, 2, 1]); permute_96 = None + permute_1430 = torch.ops.aten.permute.default(permute_95, [0, 2, 1]); permute_95 = None + add_2081 = 0 + sym_size_int_20; sym_size_int_20 = None + permute_1456 = torch.ops.aten.permute.default(view_366, [0, 2, 1]); view_366 = None + permute_1457 = torch.ops.aten.permute.default(convert_element_type_305, [0, 2, 1]); convert_element_type_305 = None + permute_1472 = torch.ops.aten.permute.default(permute_82, [0, 2, 1]); permute_82 = None + permute_1476 = torch.ops.aten.permute.default(permute_81, [0, 2, 1]); permute_81 = None + permute_1480 = torch.ops.aten.permute.default(permute_80, [0, 2, 1]); permute_80 = None + add_2096 = 0 + sym_size_int_16; sym_size_int_16 = None + permute_1506 = torch.ops.aten.permute.default(view_299, [0, 2, 1]); view_299 = None + permute_1507 = torch.ops.aten.permute.default(convert_element_type_251, [0, 2, 1]); convert_element_type_251 = None + permute_1522 = torch.ops.aten.permute.default(permute_67, [0, 2, 1]); permute_67 = None + permute_1526 = torch.ops.aten.permute.default(permute_66, [0, 2, 1]); permute_66 = None + permute_1530 = torch.ops.aten.permute.default(permute_65, [0, 2, 1]); permute_65 = None + add_2111 = 0 + sym_size_int_12; sym_size_int_12 = None + permute_1556 = torch.ops.aten.permute.default(view_232, [0, 2, 1]); view_232 = None + permute_1557 = torch.ops.aten.permute.default(convert_element_type_197, [0, 2, 1]); convert_element_type_197 = None + permute_1572 = torch.ops.aten.permute.default(permute_52, [0, 2, 1]); permute_52 = None + permute_1576 = torch.ops.aten.permute.default(permute_51, [0, 2, 1]); permute_51 = None + permute_1580 = torch.ops.aten.permute.default(permute_50, [0, 2, 1]); permute_50 = None + add_2126 = 0 + sym_size_int_8; sym_size_int_8 = None + permute_1606 = torch.ops.aten.permute.default(view_165, [0, 2, 1]); view_165 = None + permute_1607 = torch.ops.aten.permute.default(convert_element_type_143, [0, 2, 1]); convert_element_type_143 = None + permute_1622 = torch.ops.aten.permute.default(permute_37, [0, 2, 1]); permute_37 = None + permute_1626 = torch.ops.aten.permute.default(permute_36, [0, 2, 1]); permute_36 = None + permute_1630 = torch.ops.aten.permute.default(permute_35, [0, 2, 1]); permute_35 = None + add_2141 = 0 + sym_size_int_4; sym_size_int_4 = None + permute_1656 = torch.ops.aten.permute.default(view_98, [0, 2, 1]); view_98 = None + permute_1657 = torch.ops.aten.permute.default(convert_element_type_89, [0, 2, 1]); convert_element_type_89 = None + permute_1672 = torch.ops.aten.permute.default(permute_22, [0, 2, 1]); permute_22 = None + permute_1676 = torch.ops.aten.permute.default(permute_21, [0, 2, 1]); permute_21 = None + permute_1680 = torch.ops.aten.permute.default(permute_20, [0, 2, 1]); permute_20 = None + add_2156 = 0 + sym_size_int; sym_size_int = None + copy_ = torch.ops.aten.copy_.default(primals_32, add_11); primals_32 = add_11 = copy_ = None + copy__1 = torch.ops.aten.copy_.default(primals_48, add_79); primals_48 = add_79 = copy__1 = None + copy__2 = torch.ops.aten.copy_.default(primals_64, add_147); primals_64 = add_147 = copy__2 = None + copy__3 = torch.ops.aten.copy_.default(primals_80, add_215); primals_80 = add_215 = copy__3 = None + copy__4 = torch.ops.aten.copy_.default(primals_96, add_283); primals_96 = add_283 = copy__4 = None + copy__5 = torch.ops.aten.copy_.default(primals_112, add_351); primals_112 = add_351 = copy__5 = None + copy__6 = torch.ops.aten.copy_.default(primals_128, add_419); primals_128 = add_419 = copy__6 = None + copy__7 = torch.ops.aten.copy_.default(primals_144, add_487); primals_144 = add_487 = copy__7 = None + copy__8 = torch.ops.aten.copy_.default(primals_160, add_555); primals_160 = add_555 = copy__8 = None + copy__9 = torch.ops.aten.copy_.default(primals_176, add_623); primals_176 = add_623 = copy__9 = None + copy__10 = torch.ops.aten.copy_.default(primals_192, add_691); primals_192 = add_691 = copy__10 = None + copy__11 = torch.ops.aten.copy_.default(primals_208, add_759); primals_208 = add_759 = copy__11 = None + copy__12 = torch.ops.aten.copy_.default(primals_224, add_827); primals_224 = add_827 = copy__12 = None + copy__13 = torch.ops.aten.copy_.default(primals_240, add_895); primals_240 = add_895 = copy__13 = None + copy__14 = torch.ops.aten.copy_.default(primals_256, add_963); primals_256 = add_963 = copy__14 = None + copy__15 = torch.ops.aten.copy_.default(primals_272, add_1031); primals_272 = add_1031 = copy__15 = None + copy__16 = torch.ops.aten.copy_.default(primals_288, add_1099); primals_288 = add_1099 = copy__16 = None + copy__17 = torch.ops.aten.copy_.default(primals_304, add_1167); primals_304 = add_1167 = copy__17 = None + copy__18 = torch.ops.aten.copy_.default(primals_320, add_1235); primals_320 = add_1235 = copy__18 = None + copy__19 = torch.ops.aten.copy_.default(primals_336, add_1303); primals_336 = add_1303 = copy__19 = None + copy__20 = torch.ops.aten.copy_.default(primals_352, add_1371); primals_352 = add_1371 = copy__20 = None + copy__21 = torch.ops.aten.copy_.default(primals_368, add_1439); primals_368 = add_1439 = copy__21 = None + copy__22 = torch.ops.aten.copy_.default(primals_384, add_1507); primals_384 = add_1507 = copy__22 = None + copy__23 = torch.ops.aten.copy_.default(primals_400, add_1575); primals_400 = add_1575 = copy__23 = None + copy__24 = torch.ops.aten.copy_.default(primals_416, add_1643); primals_416 = add_1643 = copy__24 = None + copy__25 = torch.ops.aten.copy_.default(primals_432, add_1711); primals_432 = add_1711 = copy__25 = None + return (view_1779, getitem_22, sym_sum_1, _local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7, getitem_132, sym_sum_3, _local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31, _local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23, getitem_242, sym_sum_5, _local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47, _local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39, getitem_352, sym_sum_7, _local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63, _local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55, getitem_462, sym_sum_9, _local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79, _local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71, getitem_572, sym_sum_11, _local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95, _local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87, getitem_682, sym_sum_13, _local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111, _local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103, getitem_792, sym_sum_15, _local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127, _local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119, getitem_902, sym_sum_17, _local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143, _local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135, getitem_1012, sym_sum_19, _local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159, _local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151, getitem_1122, sym_sum_21, _local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175, _local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167, getitem_1232, sym_sum_23, _local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191, _local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183, getitem_1342, sym_sum_25, _local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207, _local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199, getitem_1452, sym_sum_27, _local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223, _local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215, getitem_1562, sym_sum_29, _local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239, _local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231, getitem_1672, sym_sum_31, _local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255, _local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247, getitem_1782, sym_sum_33, _local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271, _local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263, getitem_1892, sym_sum_35, _local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287, _local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279, getitem_2002, sym_sum_37, _local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303, _local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295, getitem_2112, sym_sum_39, _local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319, _local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311, getitem_2222, sym_sum_41, _local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335, _local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327, getitem_2332, sym_sum_43, _local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351, _local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343, getitem_2442, sym_sum_45, _local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367, _local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359, getitem_2552, sym_sum_47, _local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383, _local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375, getitem_2662, sym_sum_49, _local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399, _local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391, getitem_2772, sym_sum_51, _local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415, _local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_31, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_47, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_63, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_79, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_95, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_111, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_127, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_143, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_159, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_175, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_191, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_207, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_223, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_239, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_255, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_271, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_287, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_298, primals_299, primals_300, primals_301, primals_303, primals_305, primals_306, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_316, primals_317, primals_319, primals_321, primals_322, primals_323, primals_324, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, primals_335, primals_337, primals_338, primals_339, primals_340, primals_341, primals_342, primals_343, primals_344, primals_345, primals_346, primals_347, primals_348, primals_349, primals_351, primals_353, primals_354, primals_355, primals_356, primals_357, primals_358, primals_359, primals_360, primals_361, primals_362, primals_363, primals_364, primals_365, primals_367, primals_369, primals_370, primals_371, primals_372, primals_373, primals_374, primals_375, primals_376, primals_377, primals_378, primals_379, primals_380, primals_381, primals_383, primals_385, primals_386, primals_387, primals_388, primals_389, primals_390, primals_391, primals_392, primals_393, primals_394, primals_395, primals_396, primals_397, primals_399, primals_401, primals_402, primals_403, primals_404, primals_405, primals_406, primals_407, primals_408, primals_409, primals_410, primals_411, primals_412, primals_413, primals_415, primals_417, primals_418, primals_419, primals_420, primals_421, primals_422, primals_423, primals_424, primals_425, primals_426, primals_427, primals_428, primals_429, primals_431, primals_433, primals_434, primals_435, primals_436, primals_437, primals_438, primals_439, primals_440, embedding, rsqrt, view_3, getitem_2, rsqrt_1, view_17, permute_3, permute_4, permute_5, getitem_6, getitem_7, mm_3, rsqrt_2, view_26, mm_4, mm_5, view_32, add_5, rsqrt_3, view_36, getitem_11, rsqrt_4, view_50, permute_14, permute_15, permute_16, getitem_15, getitem_16, add_8, rsqrt_5, view_58, mm_11, amax, sum_1, getitem_19, getitem_21, div_2, getitem_22, index_1, cumsum_2, _grouped_mm, _grouped_mm_1, mul_35, mm_12, mm_13, mul_55, add_73, rsqrt_6, view_103, getitem_121, rsqrt_7, view_117, permute_29, permute_30, permute_31, getitem_125, getitem_126, add_76, rsqrt_8, view_125, mm_19, amax_1, sum_5, getitem_129, getitem_131, div_7, getitem_132, index_3, cumsum_5, _grouped_mm_3, _grouped_mm_4, mul_84, mm_20, mm_21, mul_104, add_141, rsqrt_9, view_170, getitem_231, rsqrt_10, view_184, permute_44, permute_45, permute_46, getitem_235, getitem_236, add_144, rsqrt_11, view_192, mm_27, amax_2, sum_9, getitem_239, getitem_241, div_12, getitem_242, index_5, cumsum_8, _grouped_mm_6, _grouped_mm_7, mul_133, mm_28, mm_29, mul_153, add_209, rsqrt_12, view_237, getitem_341, rsqrt_13, view_251, permute_59, permute_60, permute_61, getitem_345, getitem_346, add_212, rsqrt_14, view_259, mm_35, amax_3, sum_13, getitem_349, getitem_351, div_17, getitem_352, index_7, cumsum_11, _grouped_mm_9, _grouped_mm_10, mul_182, mm_36, mm_37, mul_202, add_277, rsqrt_15, view_304, getitem_451, rsqrt_16, view_318, permute_74, permute_75, permute_76, getitem_455, getitem_456, add_280, rsqrt_17, view_326, mm_43, amax_4, sum_17, getitem_459, getitem_461, div_22, getitem_462, index_9, cumsum_14, _grouped_mm_12, _grouped_mm_13, mul_231, mm_44, mm_45, mul_251, add_345, rsqrt_18, view_371, getitem_561, rsqrt_19, view_385, permute_89, permute_90, permute_91, getitem_565, getitem_566, add_348, rsqrt_20, view_393, mm_51, amax_5, sum_21, getitem_569, getitem_571, div_27, getitem_572, index_11, cumsum_17, _grouped_mm_15, _grouped_mm_16, mul_280, mm_52, mm_53, mul_300, add_413, rsqrt_21, view_438, getitem_671, rsqrt_22, view_452, permute_104, permute_105, permute_106, getitem_675, getitem_676, add_416, rsqrt_23, view_460, mm_59, amax_6, sum_25, getitem_679, getitem_681, div_32, getitem_682, index_13, cumsum_20, _grouped_mm_18, _grouped_mm_19, mul_329, mm_60, mm_61, mul_349, add_481, rsqrt_24, view_505, getitem_781, rsqrt_25, view_519, permute_119, permute_120, permute_121, getitem_785, getitem_786, add_484, rsqrt_26, view_527, mm_67, amax_7, sum_29, getitem_789, getitem_791, div_37, getitem_792, index_15, cumsum_23, _grouped_mm_21, _grouped_mm_22, mul_378, mm_68, mm_69, mul_398, add_549, rsqrt_27, view_572, getitem_891, rsqrt_28, view_586, permute_134, permute_135, permute_136, getitem_895, getitem_896, add_552, rsqrt_29, view_594, mm_75, amax_8, sum_33, getitem_899, getitem_901, div_42, getitem_902, index_17, cumsum_26, _grouped_mm_24, _grouped_mm_25, mul_427, mm_76, mm_77, mul_447, add_617, rsqrt_30, view_639, getitem_1001, rsqrt_31, view_653, permute_149, permute_150, permute_151, getitem_1005, getitem_1006, add_620, rsqrt_32, view_661, mm_83, amax_9, sum_37, getitem_1009, getitem_1011, div_47, getitem_1012, index_19, cumsum_29, _grouped_mm_27, _grouped_mm_28, mul_476, mm_84, mm_85, mul_496, add_685, rsqrt_33, view_706, getitem_1111, rsqrt_34, view_720, permute_164, permute_165, permute_166, getitem_1115, getitem_1116, add_688, rsqrt_35, view_728, mm_91, amax_10, sum_41, getitem_1119, getitem_1121, div_52, getitem_1122, index_21, cumsum_32, _grouped_mm_30, _grouped_mm_31, mul_525, mm_92, mm_93, mul_545, add_753, rsqrt_36, view_773, getitem_1221, rsqrt_37, view_787, permute_179, permute_180, permute_181, getitem_1225, getitem_1226, add_756, rsqrt_38, view_795, mm_99, amax_11, sum_45, getitem_1229, getitem_1231, div_57, getitem_1232, index_23, cumsum_35, _grouped_mm_33, _grouped_mm_34, mul_574, mm_100, mm_101, mul_594, add_821, rsqrt_39, view_840, getitem_1331, rsqrt_40, view_854, permute_194, permute_195, permute_196, getitem_1335, getitem_1336, add_824, rsqrt_41, view_862, mm_107, amax_12, sum_49, getitem_1339, getitem_1341, div_62, getitem_1342, index_25, cumsum_38, _grouped_mm_36, _grouped_mm_37, mul_623, mm_108, mm_109, mul_643, add_889, rsqrt_42, view_907, getitem_1441, rsqrt_43, view_921, permute_209, permute_210, permute_211, getitem_1445, getitem_1446, add_892, rsqrt_44, view_929, mm_115, amax_13, sum_53, getitem_1449, getitem_1451, div_67, getitem_1452, index_27, cumsum_41, _grouped_mm_39, _grouped_mm_40, mul_672, mm_116, mm_117, mul_692, add_957, rsqrt_45, view_974, getitem_1551, rsqrt_46, view_988, permute_224, permute_225, permute_226, getitem_1555, getitem_1556, add_960, rsqrt_47, view_996, mm_123, amax_14, sum_57, getitem_1559, getitem_1561, div_72, getitem_1562, index_29, cumsum_44, _grouped_mm_42, _grouped_mm_43, mul_721, mm_124, mm_125, mul_741, add_1025, rsqrt_48, view_1041, getitem_1661, rsqrt_49, view_1055, permute_239, permute_240, permute_241, getitem_1665, getitem_1666, add_1028, rsqrt_50, view_1063, mm_131, amax_15, sum_61, getitem_1669, getitem_1671, div_77, getitem_1672, index_31, cumsum_47, _grouped_mm_45, _grouped_mm_46, mul_770, mm_132, mm_133, mul_790, add_1093, rsqrt_51, view_1108, getitem_1771, rsqrt_52, view_1122, permute_254, permute_255, permute_256, getitem_1775, getitem_1776, add_1096, rsqrt_53, view_1130, mm_139, amax_16, sum_65, getitem_1779, getitem_1781, div_82, getitem_1782, index_33, cumsum_50, _grouped_mm_48, _grouped_mm_49, mul_819, mm_140, mm_141, mul_839, add_1161, rsqrt_54, view_1175, getitem_1881, rsqrt_55, view_1189, permute_269, permute_270, permute_271, getitem_1885, getitem_1886, add_1164, rsqrt_56, view_1197, mm_147, amax_17, sum_69, getitem_1889, getitem_1891, div_87, getitem_1892, index_35, cumsum_53, _grouped_mm_51, _grouped_mm_52, mul_868, mm_148, mm_149, mul_888, add_1229, rsqrt_57, view_1242, getitem_1991, rsqrt_58, view_1256, permute_284, permute_285, permute_286, getitem_1995, getitem_1996, add_1232, rsqrt_59, view_1264, mm_155, amax_18, sum_73, getitem_1999, getitem_2001, div_92, getitem_2002, index_37, cumsum_56, _grouped_mm_54, _grouped_mm_55, mul_917, mm_156, mm_157, mul_937, add_1297, rsqrt_60, view_1309, getitem_2101, rsqrt_61, view_1323, permute_299, permute_300, permute_301, getitem_2105, getitem_2106, add_1300, rsqrt_62, view_1331, mm_163, amax_19, sum_77, getitem_2109, getitem_2111, div_97, getitem_2112, index_39, cumsum_59, _grouped_mm_57, _grouped_mm_58, mul_966, mm_164, mm_165, mul_986, add_1365, rsqrt_63, view_1376, getitem_2211, rsqrt_64, view_1390, permute_314, permute_315, permute_316, getitem_2215, getitem_2216, add_1368, rsqrt_65, view_1398, mm_171, amax_20, sum_81, getitem_2219, getitem_2221, div_102, getitem_2222, index_41, cumsum_62, _grouped_mm_60, _grouped_mm_61, mul_1015, mm_172, mm_173, mul_1035, add_1433, rsqrt_66, view_1443, getitem_2321, rsqrt_67, view_1457, permute_329, permute_330, permute_331, getitem_2325, getitem_2326, add_1436, rsqrt_68, view_1465, mm_179, amax_21, sum_85, getitem_2329, getitem_2331, div_107, getitem_2332, index_43, cumsum_65, _grouped_mm_63, _grouped_mm_64, mul_1064, mm_180, mm_181, mul_1084, add_1501, rsqrt_69, view_1510, getitem_2431, rsqrt_70, view_1524, permute_344, permute_345, permute_346, getitem_2435, getitem_2436, add_1504, rsqrt_71, view_1532, mm_187, amax_22, sum_89, getitem_2439, getitem_2441, div_112, getitem_2442, index_45, cumsum_68, _grouped_mm_66, _grouped_mm_67, mul_1113, mm_188, mm_189, mul_1133, add_1569, rsqrt_72, view_1577, getitem_2541, rsqrt_73, view_1591, permute_359, permute_360, permute_361, getitem_2545, getitem_2546, add_1572, rsqrt_74, view_1599, mm_195, amax_23, sum_93, getitem_2549, getitem_2551, div_117, getitem_2552, index_47, cumsum_71, _grouped_mm_69, _grouped_mm_70, mul_1162, mm_196, mm_197, mul_1182, add_1637, rsqrt_75, view_1644, getitem_2651, rsqrt_76, view_1658, permute_374, permute_375, permute_376, getitem_2655, getitem_2656, add_1640, rsqrt_77, view_1666, mm_203, amax_24, sum_97, getitem_2659, getitem_2661, div_122, getitem_2662, index_49, cumsum_74, _grouped_mm_72, _grouped_mm_73, mul_1211, mm_204, mm_205, mul_1231, add_1705, rsqrt_78, view_1711, getitem_2761, rsqrt_79, view_1725, permute_389, permute_390, permute_391, getitem_2765, getitem_2766, add_1708, rsqrt_80, view_1733, mm_211, amax_25, sum_101, getitem_2769, getitem_2771, div_127, getitem_2772, index_51, cumsum_77, _grouped_mm_75, _grouped_mm_76, mul_1260, mm_212, mm_213, mul_1280, add_1773, rsqrt_81, view_1778, permute_406, permute_407, permute_422, permute_426, permute_430, full_default_54, permute_456, permute_457, permute_472, permute_476, permute_480, permute_506, permute_507, permute_522, permute_526, permute_530, permute_556, permute_557, permute_572, permute_576, permute_580, permute_606, permute_607, permute_622, permute_626, permute_630, permute_656, permute_657, permute_672, permute_676, permute_680, permute_706, permute_707, permute_722, permute_726, permute_730, permute_756, permute_757, permute_772, permute_776, permute_780, permute_806, permute_807, permute_822, permute_826, permute_830, permute_856, permute_857, permute_872, permute_876, permute_880, permute_906, permute_907, permute_922, permute_926, permute_930, permute_956, permute_957, permute_972, permute_976, permute_980, permute_1006, permute_1007, permute_1022, permute_1026, permute_1030, permute_1056, permute_1057, permute_1072, permute_1076, permute_1080, permute_1106, permute_1107, permute_1122, permute_1126, permute_1130, permute_1156, permute_1157, permute_1172, permute_1176, permute_1180, permute_1206, permute_1207, permute_1222, permute_1226, permute_1230, permute_1256, permute_1257, permute_1272, permute_1276, permute_1280, permute_1306, permute_1307, permute_1322, permute_1326, permute_1330, permute_1356, permute_1357, permute_1372, permute_1376, permute_1380, permute_1406, permute_1407, permute_1422, permute_1426, permute_1430, permute_1456, permute_1457, permute_1472, permute_1476, permute_1480, permute_1506, permute_1507, permute_1522, permute_1526, permute_1530, permute_1556, permute_1557, permute_1572, permute_1576, permute_1580, permute_1606, permute_1607, permute_1622, permute_1626, permute_1630, permute_1656, permute_1657, permute_1672, permute_1676, permute_1680, _local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7, _local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23, _local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31, _local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39, _local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47, _local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55, _local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63, _local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71, _local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79, _local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87, _local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95, _local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103, _local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111, _local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119, _local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127, _local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135, _local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143, _local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151, _local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159, _local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167, _local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175, _local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183, _local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191, _local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199, _local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207, _local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215, _local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223, _local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231, _local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239, _local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247, _local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255, _local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263, _local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271, _local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279, _local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287, _local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295, _local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303, _local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311, _local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319, _local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327, _local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335, _local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343, _local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351, _local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359, _local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367, _local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375, _local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383, _local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391, _local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399, _local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407, _local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415, sym_size_int_1, sym_size_int_5, sym_size_int_9, sym_size_int_13, sym_size_int_17, sym_size_int_21, sym_size_int_25, sym_size_int_29, sym_size_int_33, sym_size_int_37, sym_size_int_41, sym_size_int_45, sym_size_int_49, sym_size_int_53, sym_size_int_57, sym_size_int_61, sym_size_int_65, sym_size_int_69, sym_size_int_73, sym_size_int_77, sym_size_int_81, sym_size_int_85, sym_size_int_89, sym_size_int_93, sym_size_int_97, sym_size_int_101, add_1781, add_1796, add_1811, add_1826, add_1841, add_1856, add_1871, add_1886, add_1901, add_1916, add_1931, add_1946, add_1961, add_1976, add_1991, add_2006, add_2021, add_2036, add_2051, add_2066, add_2081, add_2096, add_2111, add_2126, add_2141, add_2156) + +def load_args(reader): + buf0 = reader.storage(None, 6553600, device=device(type='cuda', index=0)) + reader.tensor(buf0, (800, 2048), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 65536, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 4096), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (4096, 32), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf3, (16,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf4, (24, 2048), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf5, (5, 2048), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4,), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf7, (32, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf8, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_9 + buf9 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf9, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_10 + buf10 = reader.storage(None, 32768, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf10, (2, 4096), dtype=torch.int32, is_leaf=True) # primals_11 + buf11 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf11, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_12 + buf12 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf12, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_13 + buf13 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf13, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_14 + buf14 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf14, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_15 + buf15 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf15, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_16 + buf16 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf16, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_17 + buf17 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf17, (16, 2048), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf18, (16,), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 704512, device=device(type='cuda', index=0)) + reader.tensor(buf19, (86, 2048), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 704512, device=device(type='cuda', index=0)) + reader.tensor(buf20, (86, 2048), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 700416, device=device(type='cuda', index=0)) + reader.tensor(buf21, (16, 10944), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16,), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf23, (24, 2048), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf24, (5, 2048), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf25, (4,), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf26, (32, 512), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf27, (16, 2048), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf28, (16,), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf29, (64,), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf30, (1, 2048), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf31, (64,), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf32, (8, 88, 2048), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf33, (8, 128, 1408), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf34, (8, 88, 2048), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf35, (22, 2048), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf36, (22, 2048), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf37, (16, 2816), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf38, (16,), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf39, (24, 2048), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf40, (5, 2048), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf41, (4,), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf42, (32, 512), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf43, (16, 2048), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf44, (16,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf45, (64,), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf46, (1, 2048), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf47, (64,), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf48, (8, 88, 2048), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf49, (8, 128, 1408), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf50, (8, 88, 2048), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf51, (22, 2048), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf52, (22, 2048), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf53, (16, 2816), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf54, (16,), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf55, (24, 2048), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf56, (5, 2048), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf57, (4,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf58, (32, 512), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf59, (16, 2048), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf60, (16,), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf61, (64,), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf62, (1, 2048), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf63, (64,), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf64, (8, 88, 2048), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf65, (8, 128, 1408), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf66, (8, 88, 2048), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf67, (22, 2048), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf68, (22, 2048), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 2816), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf70, (16,), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf71, (24, 2048), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf72, (5, 2048), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf73, (4,), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf74, (32, 512), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf75, (16, 2048), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf76, (16,), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf77, (64,), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf78, (1, 2048), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf79, (64,), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf80, (8, 88, 2048), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf81, (8, 128, 1408), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf82, (8, 88, 2048), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf83, (22, 2048), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf84, (22, 2048), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf85, (16, 2816), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf86, (16,), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf87, (24, 2048), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf88, (5, 2048), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf89, (4,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf90, (32, 512), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf91, (16, 2048), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf92, (16,), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf93, (64,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf94, (1, 2048), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf95, (64,), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf96, (8, 88, 2048), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf97, (8, 128, 1408), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf98, (8, 88, 2048), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf99, (22, 2048), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf100, (22, 2048), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf101, (16, 2816), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf102, (16,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf103, (24, 2048), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf104, (5, 2048), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf105, (4,), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf106, (32, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf107, (16, 2048), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf108, (16,), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf109, (64,), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf110, (1, 2048), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf111, (64,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf112, (8, 88, 2048), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf113, (8, 128, 1408), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf114, (8, 88, 2048), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf115, (22, 2048), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf116, (22, 2048), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf117, (16, 2816), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf118, (16,), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf119, (24, 2048), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf120, (5, 2048), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf121, (4,), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf122, (32, 512), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf123, (16, 2048), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf124, (16,), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf125, (64,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf126, (1, 2048), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf127, (64,), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf128, (8, 88, 2048), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf129, (8, 128, 1408), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf130, (8, 88, 2048), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf131, (22, 2048), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf132, (22, 2048), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf133, (16, 2816), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf134, (16,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf135, (24, 2048), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf136, (5, 2048), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf137, (4,), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf138, (32, 512), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 2048), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16,), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf141, (64,), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf142, (1, 2048), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf143, (64,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf144, (8, 88, 2048), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf145, (8, 128, 1408), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf146, (8, 88, 2048), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf147, (22, 2048), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf148, (22, 2048), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf149, (16, 2816), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf150, (16,), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf151, (24, 2048), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf152, (5, 2048), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf153, (4,), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf154, (32, 512), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf155, (16, 2048), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf156, (16,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf157, (64,), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf158, (1, 2048), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf159, (64,), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf160, (8, 88, 2048), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf161, (8, 128, 1408), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf162, (8, 88, 2048), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf163, (22, 2048), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf164, (22, 2048), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf165, (16, 2816), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf166, (16,), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf167, (24, 2048), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf168, (5, 2048), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf169, (4,), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf170, (32, 512), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf171, (16, 2048), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf172, (16,), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf173, (64,), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf174, (1, 2048), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf175, (64,), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf176, (8, 88, 2048), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf177, (8, 128, 1408), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf178, (8, 88, 2048), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf179, (22, 2048), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf180, (22, 2048), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf181, (16, 2816), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf182, (16,), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf183, (24, 2048), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf184, (5, 2048), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf185, (4,), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf186, (32, 512), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf187, (16, 2048), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf188, (16,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf189, (64,), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf190, (1, 2048), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf191, (64,), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf192, (8, 88, 2048), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf193, (8, 128, 1408), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf194, (8, 88, 2048), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf195, (22, 2048), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf196, (22, 2048), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf197, (16, 2816), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf198, (16,), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf199, (24, 2048), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf200, (5, 2048), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf201, (4,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf202, (32, 512), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf203, (16, 2048), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf204, (16,), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf205, (64,), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf206, (1, 2048), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf207, (64,), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf208, (8, 88, 2048), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf209, (8, 128, 1408), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf210, (8, 88, 2048), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf211, (22, 2048), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf212, (22, 2048), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf213, (16, 2816), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf214, (16,), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf215, (24, 2048), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf216, (5, 2048), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf217, (4,), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf218, (32, 512), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf219, (16, 2048), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf220, (16,), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf221, (64,), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf222, (1, 2048), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf223, (64,), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf224, (8, 88, 2048), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf225, (8, 128, 1408), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf226, (8, 88, 2048), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf227, (22, 2048), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf228, (22, 2048), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf229, (16, 2816), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf230, (16,), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf231, (24, 2048), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf232, (5, 2048), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf233, (4,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf234, (32, 512), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf235, (16, 2048), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf236, (16,), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf237, (64,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf238, (1, 2048), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf239, (64,), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf240, (8, 88, 2048), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf241, (8, 128, 1408), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf242, (8, 88, 2048), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf243, (22, 2048), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf244, (22, 2048), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf245, (16, 2816), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf246, (16,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf247, (24, 2048), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf248, (5, 2048), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4,), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf250, (32, 512), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf251, (16, 2048), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf252, (16,), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf253, (64,), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf254, (1, 2048), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf255, (64,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf256, (8, 88, 2048), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf257, (8, 128, 1408), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf258, (8, 88, 2048), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf259, (22, 2048), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf260, (22, 2048), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf261, (16, 2816), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf262, (16,), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf263, (24, 2048), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf264, (5, 2048), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf265, (4,), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf266, (32, 512), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf267, (16, 2048), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf268, (16,), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf269, (64,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf270, (1, 2048), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf271, (64,), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf272, (8, 88, 2048), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf273, (8, 128, 1408), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf274, (8, 88, 2048), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf275, (22, 2048), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf276, (22, 2048), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf277, (16, 2816), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf278, (16,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf279, (24, 2048), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf280, (5, 2048), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf281, (4,), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf282, (32, 512), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf283, (16, 2048), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf284, (16,), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf285, (64,), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf286, (1, 2048), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf287, (64,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf288, (8, 88, 2048), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf289, (8, 128, 1408), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf290, (8, 88, 2048), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf291, (22, 2048), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf292, (22, 2048), is_leaf=True) # primals_293 + buf293 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf293, (16, 2816), is_leaf=True) # primals_294 + buf294 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf294, (16,), is_leaf=True) # primals_295 + buf295 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf295, (24, 2048), is_leaf=True) # primals_296 + buf296 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf296, (5, 2048), is_leaf=True) # primals_297 + buf297 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf297, (4,), is_leaf=True) # primals_298 + buf298 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf298, (32, 512), is_leaf=True) # primals_299 + buf299 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf299, (16, 2048), is_leaf=True) # primals_300 + buf300 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf300, (16,), is_leaf=True) # primals_301 + buf301 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf301, (64,), is_leaf=True) # primals_302 + buf302 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf302, (1, 2048), is_leaf=True) # primals_303 + buf303 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf303, (64,), is_leaf=True) # primals_304 + buf304 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf304, (8, 88, 2048), is_leaf=True) # primals_305 + buf305 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf305, (8, 128, 1408), is_leaf=True) # primals_306 + buf306 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf306, (8, 88, 2048), is_leaf=True) # primals_307 + buf307 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf307, (22, 2048), is_leaf=True) # primals_308 + buf308 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf308, (22, 2048), is_leaf=True) # primals_309 + buf309 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf309, (16, 2816), is_leaf=True) # primals_310 + buf310 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf310, (16,), is_leaf=True) # primals_311 + buf311 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf311, (24, 2048), is_leaf=True) # primals_312 + buf312 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf312, (5, 2048), is_leaf=True) # primals_313 + buf313 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf313, (4,), is_leaf=True) # primals_314 + buf314 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf314, (32, 512), is_leaf=True) # primals_315 + buf315 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf315, (16, 2048), is_leaf=True) # primals_316 + buf316 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf316, (16,), is_leaf=True) # primals_317 + buf317 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf317, (64,), is_leaf=True) # primals_318 + buf318 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf318, (1, 2048), is_leaf=True) # primals_319 + buf319 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf319, (64,), is_leaf=True) # primals_320 + buf320 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf320, (8, 88, 2048), is_leaf=True) # primals_321 + buf321 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf321, (8, 128, 1408), is_leaf=True) # primals_322 + buf322 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf322, (8, 88, 2048), is_leaf=True) # primals_323 + buf323 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf323, (22, 2048), is_leaf=True) # primals_324 + buf324 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf324, (22, 2048), is_leaf=True) # primals_325 + buf325 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf325, (16, 2816), is_leaf=True) # primals_326 + buf326 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf326, (16,), is_leaf=True) # primals_327 + buf327 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf327, (24, 2048), is_leaf=True) # primals_328 + buf328 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf328, (5, 2048), is_leaf=True) # primals_329 + buf329 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf329, (4,), is_leaf=True) # primals_330 + buf330 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf330, (32, 512), is_leaf=True) # primals_331 + buf331 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf331, (16, 2048), is_leaf=True) # primals_332 + buf332 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf332, (16,), is_leaf=True) # primals_333 + buf333 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf333, (64,), is_leaf=True) # primals_334 + buf334 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf334, (1, 2048), is_leaf=True) # primals_335 + buf335 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf335, (64,), is_leaf=True) # primals_336 + buf336 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf336, (8, 88, 2048), is_leaf=True) # primals_337 + buf337 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf337, (8, 128, 1408), is_leaf=True) # primals_338 + buf338 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf338, (8, 88, 2048), is_leaf=True) # primals_339 + buf339 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf339, (22, 2048), is_leaf=True) # primals_340 + buf340 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf340, (22, 2048), is_leaf=True) # primals_341 + buf341 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf341, (16, 2816), is_leaf=True) # primals_342 + buf342 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf342, (16,), is_leaf=True) # primals_343 + buf343 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf343, (24, 2048), is_leaf=True) # primals_344 + buf344 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf344, (5, 2048), is_leaf=True) # primals_345 + buf345 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf345, (4,), is_leaf=True) # primals_346 + buf346 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf346, (32, 512), is_leaf=True) # primals_347 + buf347 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf347, (16, 2048), is_leaf=True) # primals_348 + buf348 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf348, (16,), is_leaf=True) # primals_349 + buf349 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf349, (64,), is_leaf=True) # primals_350 + buf350 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf350, (1, 2048), is_leaf=True) # primals_351 + buf351 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf351, (64,), is_leaf=True) # primals_352 + buf352 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf352, (8, 88, 2048), is_leaf=True) # primals_353 + buf353 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf353, (8, 128, 1408), is_leaf=True) # primals_354 + buf354 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf354, (8, 88, 2048), is_leaf=True) # primals_355 + buf355 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf355, (22, 2048), is_leaf=True) # primals_356 + buf356 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf356, (22, 2048), is_leaf=True) # primals_357 + buf357 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf357, (16, 2816), is_leaf=True) # primals_358 + buf358 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf358, (16,), is_leaf=True) # primals_359 + buf359 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf359, (24, 2048), is_leaf=True) # primals_360 + buf360 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf360, (5, 2048), is_leaf=True) # primals_361 + buf361 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf361, (4,), is_leaf=True) # primals_362 + buf362 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf362, (32, 512), is_leaf=True) # primals_363 + buf363 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf363, (16, 2048), is_leaf=True) # primals_364 + buf364 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf364, (16,), is_leaf=True) # primals_365 + buf365 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf365, (64,), is_leaf=True) # primals_366 + buf366 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf366, (1, 2048), is_leaf=True) # primals_367 + buf367 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf367, (64,), is_leaf=True) # primals_368 + buf368 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf368, (8, 88, 2048), is_leaf=True) # primals_369 + buf369 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf369, (8, 128, 1408), is_leaf=True) # primals_370 + buf370 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf370, (8, 88, 2048), is_leaf=True) # primals_371 + buf371 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf371, (22, 2048), is_leaf=True) # primals_372 + buf372 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf372, (22, 2048), is_leaf=True) # primals_373 + buf373 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf373, (16, 2816), is_leaf=True) # primals_374 + buf374 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf374, (16,), is_leaf=True) # primals_375 + buf375 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf375, (24, 2048), is_leaf=True) # primals_376 + buf376 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf376, (5, 2048), is_leaf=True) # primals_377 + buf377 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf377, (4,), is_leaf=True) # primals_378 + buf378 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf378, (32, 512), is_leaf=True) # primals_379 + buf379 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf379, (16, 2048), is_leaf=True) # primals_380 + buf380 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf380, (16,), is_leaf=True) # primals_381 + buf381 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf381, (64,), is_leaf=True) # primals_382 + buf382 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf382, (1, 2048), is_leaf=True) # primals_383 + buf383 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf383, (64,), is_leaf=True) # primals_384 + buf384 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf384, (8, 88, 2048), is_leaf=True) # primals_385 + buf385 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf385, (8, 128, 1408), is_leaf=True) # primals_386 + buf386 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf386, (8, 88, 2048), is_leaf=True) # primals_387 + buf387 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf387, (22, 2048), is_leaf=True) # primals_388 + buf388 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf388, (22, 2048), is_leaf=True) # primals_389 + buf389 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf389, (16, 2816), is_leaf=True) # primals_390 + buf390 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf390, (16,), is_leaf=True) # primals_391 + buf391 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf391, (24, 2048), is_leaf=True) # primals_392 + buf392 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf392, (5, 2048), is_leaf=True) # primals_393 + buf393 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf393, (4,), is_leaf=True) # primals_394 + buf394 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf394, (32, 512), is_leaf=True) # primals_395 + buf395 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf395, (16, 2048), is_leaf=True) # primals_396 + buf396 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf396, (16,), is_leaf=True) # primals_397 + buf397 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf397, (64,), is_leaf=True) # primals_398 + buf398 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf398, (1, 2048), is_leaf=True) # primals_399 + buf399 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf399, (64,), is_leaf=True) # primals_400 + buf400 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf400, (8, 88, 2048), is_leaf=True) # primals_401 + buf401 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf401, (8, 128, 1408), is_leaf=True) # primals_402 + buf402 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf402, (8, 88, 2048), is_leaf=True) # primals_403 + buf403 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf403, (22, 2048), is_leaf=True) # primals_404 + buf404 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf404, (22, 2048), is_leaf=True) # primals_405 + buf405 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf405, (16, 2816), is_leaf=True) # primals_406 + buf406 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf406, (16,), is_leaf=True) # primals_407 + buf407 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf407, (24, 2048), is_leaf=True) # primals_408 + buf408 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf408, (5, 2048), is_leaf=True) # primals_409 + buf409 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf409, (4,), is_leaf=True) # primals_410 + buf410 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf410, (32, 512), is_leaf=True) # primals_411 + buf411 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf411, (16, 2048), is_leaf=True) # primals_412 + buf412 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf412, (16,), is_leaf=True) # primals_413 + buf413 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf413, (64,), is_leaf=True) # primals_414 + buf414 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf414, (1, 2048), is_leaf=True) # primals_415 + buf415 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf415, (64,), is_leaf=True) # primals_416 + buf416 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf416, (8, 88, 2048), is_leaf=True) # primals_417 + buf417 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf417, (8, 128, 1408), is_leaf=True) # primals_418 + buf418 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf418, (8, 88, 2048), is_leaf=True) # primals_419 + buf419 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf419, (22, 2048), is_leaf=True) # primals_420 + buf420 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf420, (22, 2048), is_leaf=True) # primals_421 + buf421 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf421, (16, 2816), is_leaf=True) # primals_422 + buf422 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf422, (16,), is_leaf=True) # primals_423 + buf423 = reader.storage(None, 196608, device=device(type='cuda', index=0)) + reader.tensor(buf423, (24, 2048), is_leaf=True) # primals_424 + buf424 = reader.storage(None, 40960, device=device(type='cuda', index=0)) + reader.tensor(buf424, (5, 2048), is_leaf=True) # primals_425 + buf425 = reader.storage(None, 16, device=device(type='cuda', index=0)) + reader.tensor(buf425, (4,), is_leaf=True) # primals_426 + buf426 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf426, (32, 512), is_leaf=True) # primals_427 + buf427 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf427, (16, 2048), is_leaf=True) # primals_428 + buf428 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf428, (16,), is_leaf=True) # primals_429 + buf429 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf429, (64,), is_leaf=True) # primals_430 + buf430 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf430, (1, 2048), is_leaf=True) # primals_431 + buf431 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf431, (64,), is_leaf=True) # primals_432 + buf432 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf432, (8, 88, 2048), is_leaf=True) # primals_433 + buf433 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf433, (8, 128, 1408), is_leaf=True) # primals_434 + buf434 = reader.storage(None, 5767168, device=device(type='cuda', index=0)) + reader.tensor(buf434, (8, 88, 2048), is_leaf=True) # primals_435 + buf435 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf435, (22, 2048), is_leaf=True) # primals_436 + buf436 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf436, (22, 2048), is_leaf=True) # primals_437 + buf437 = reader.storage(None, 180224, device=device(type='cuda', index=0)) + reader.tensor(buf437, (16, 2816), is_leaf=True) # primals_438 + buf438 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf438, (16,), is_leaf=True) # primals_439 + buf439 = reader.storage(None, 6553600, device=device(type='cuda', index=0)) + reader.tensor(buf439, (800, 2048), is_leaf=True) # primals_440 +load_args._version = 0 +mod = Repro() +if __name__ == '__main__': + from torch._dynamo.repro.after_aot import run_repro + from torch._dynamo.repro.after_aot import setup_fake_process_groups + setup_fake_process_groups({'0': {'size': 128, 'rank': 0}, '1033': {'size': 8, 'rank': 0}, '1025': {'size': 16, 'rank': 0}}) + with torch.no_grad(): + run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='real', check_str=None) + # To run it separately, do + # mod, args = run_repro(mod, load_args, accuracy=False, command='get_args', save_dir=None, tracing_mode='real', check_str=None) + # mod(*args) + dist.destroy_process_group() + +# Helper functions for overlap simulator +def get_pg_config(): + """DSv3 128 GPUs: FSDP=128, TP=1, EP=8.""" + return {'0': {'size': 128, 'rank': 0}, '1025': {'size': 16, 'rank': 0}, '1033': {'size': 8, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls16_8.table" + +def get_colls_group_mapping(): + # FSDP "0" → internode (table group "0"), all other groups → intranode (table group "1") + return {'0': '0', '1025': '1', '1033': '1'} diff --git a/autoparallel/tools/overlap_simulator/repro_dsv3_fw_64.py b/autoparallel/tools/overlap_simulator/repro_dsv3_fw_64.py new file mode 100644 index 00000000..a70e98ea --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_dsv3_fw_64.py @@ -0,0 +1,8752 @@ +# fmt: off +# flake8: noqa +# isort: skip_file + +import os +os.environ['PYTORCH_KERNEL_CACHE_PATH'] = '/mnt/mffuse/.cache/torch/kernels' +os.environ['TORCH_DISABLE_ADDR2LINE'] = '1' +os.environ['TORCH_TRACE'] = '/mnt/mffuse/outputs/sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr/torch_trace/' +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' +os.environ['TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE'] = '[${role_name}${rank}|${local_rank}]:' +os.environ['TORCHELASTIC_MAX_RESTARTS'] = '0' +os.environ['TORCHX_INTERNAL_SESSION_ID'] = 'a7cb45e8-8435-4d98-8768-5273c1f06ab2' +os.environ['TORCHX_RUN_PYTHONPATH'] = '' +os.environ['TORCHELASTIC_ERROR_FILE'] = '/tmp/torchelastic_rm4e8tdn/sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr_eyl0q0pn/attempt_0/0/error.json' +os.environ['TORCH_ADDR2LINE_BINARY'] = '/packages/folly.symbolizer/folly-addr2line' +os.environ['TORCHX_JOB_ID'] = 'mast_conda://torchx/sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr' +os.environ['TORCH_NCCL_ASYNC_ERROR_HANDLING'] = '3' +os.environ['TORCHELASTIC_SIGNALS_TO_HANDLE'] = 'SIGTERM,SIGINT,SIGHUP,SIGQUIT' +os.environ['TORCHELASTIC_RUN_ID'] = 'sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr' +os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1' +os.environ['TORCHELASTIC_RESTART_COUNT'] = '0' +os.environ['TORCHELASTIC_USE_AGENT_STORE'] = 'False' +os.environ['PYTORCH_DDP_USE_SIDE_STREAM'] = '0' +os.environ['TORCHINDUCTOR_CACHE_DIR'] = '/tmp/torchinductor_root' +os.environ['TORCH_FR_BUFFER_SIZE'] = '20000' +os.environ['TORCH_NCCL_DUMP_ON_TIMEOUT'] = '1' +os.environ['TORCH_FR_DUMP_TEMP_FILE'] = '/mnt/mffuse_nccl_trace/nccl_trace/sfsdp-dsv3-16b--tp1-bs2-inductor-64-ivankobzarev-gw1hcmpr/v_0/attempt_0/nccl_trace_rank_' +os.environ['TRITON_CACHE_DIR'] = '/tmp/torchinductor_root/triton/0' + +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims +import torch.distributed as dist +from torch.testing._internal.distributed.fake_pg import FakeStore +import triton +import triton.language as tl + +import torch._dynamo.config +import torch._inductor.config +import torch._functorch.config +import torch.fx.experimental._config +torch._dynamo.config.capture_scalar_outputs = True +torch._inductor.config.allow_buffer_reuse = False +torch._inductor.config.reorder_for_compute_comm_overlap = False +torch._inductor.config.reorder_for_peak_memory = False +torch._inductor.config.max_autotune = False +torch._inductor.config.coordinate_descent_tuning = False +torch._inductor.config.deterministic = False +torch._inductor.config.aten_distributed_optimizations.collective_bucketing = True +torch._inductor.config.aten_distributed_optimizations.insert_overlap_deps = True +torch._inductor.config.wrap_inductor_compiled_regions = False +torch._inductor.config.triton.cudagraphs = False +torch._inductor.config.triton.store_cubin = False +torch._inductor.config.test_configs.runtime_triton_dtype_assert = False +torch._functorch.config.functionalize_rng_ops = False +torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = True +torch._functorch.config.unlift_effect_tokens = True +torch._functorch.config.selective_decompose = False + + + +isolate_fails_code_str = None + + + + + +if "__compile_source__" in globals(): + import inspect as __after_aot_inspect + import linecache as __after_aot_linecache + __after_aot_filename = __after_aot_inspect.currentframe().f_code.co_filename + __after_aot_linecache.cache[__after_aot_filename] = ( + len(__compile_source__), + None, + __compile_source__.splitlines(True), + __after_aot_filename, + ) +# torch version: 2.11.0a0+git5ac4d4b +# torch cuda version: 12.4 +# torch git version: 5ac4d4bf3f85e15fdd6676f46b090568ea91e47e + + +# CUDA Info: +# nvcc not found +# GPU Hardware Info: +# NVIDIA H100 80GB HBM3 : 8 + +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.reset_table() + +@triton.jit +def _fill_indices_kernel_0( + tokens_per_expert_group_ptr, + start_index_values_ptr, + write_offsets_ptr, + output_ptr, + experts_per_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # Number of threads per block +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # map programs (blocks) to the experts and loop (grid stride) if needed + for expert_id in range(pid, experts_per_rank, num_programs): + # read this experts write offset + write_offset = tl.load(write_offsets_ptr + expert_id) + + for r in range(num_ranks): + # index into tokens_per_expert_group array + i = r * experts_per_rank + expert_id + + # load start index and number of tokens for this expert-rank pair + start_index = tl.load(start_index_values_ptr + i) + length = tl.load(tokens_per_expert_group_ptr + i) + + # each thread in block processes tokens in parallel + offsets = tl.arange(0, BLOCK_SIZE) + + # tokens are processed in chunks of BLOCK_SIZE + for chunk_start in range(0, length, BLOCK_SIZE): + chunk_offsets = chunk_start + offsets + + # mask valid indices + mask = chunk_offsets < length + + values = start_index + chunk_offsets + + # destination + dest_indices = write_offset + chunk_offsets + + # store + tl.store(output_ptr + dest_indices, values, mask=mask) + + # update write offset for next rank + write_offset += length + +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(_fill_indices_kernel_0) +torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.constant_args={0: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 1: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 2: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 3: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 4: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 5: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 6: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 7: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 8: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 9: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 10: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 11: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 12: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 13: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 14: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 15: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 16: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 17: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 18: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 19: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 20: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 21: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 22: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 23: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 24: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}, 25: {'experts_per_rank': 8, 'num_ranks': 8, 'BLOCK_SIZE': 128}} + +from torch.nn import * +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sdpa_score0 = lambda score, b, h, m, n, *args: score + self.sdpa_mask0 = lambda b, h, m, n, *args: True + self.sdpa_score1 = lambda score, b, h, m, n, *args: score + self.sdpa_mask1 = lambda b, h, m, n, *args: True + self.sdpa_score2 = lambda score, b, h, m, n, *args: score + self.sdpa_mask2 = lambda b, h, m, n, *args: True + self.sdpa_score3 = lambda score, b, h, m, n, *args: score + self.sdpa_mask3 = lambda b, h, m, n, *args: True + self.sdpa_score4 = lambda score, b, h, m, n, *args: score + self.sdpa_mask4 = lambda b, h, m, n, *args: True + self.sdpa_score5 = lambda score, b, h, m, n, *args: score + self.sdpa_mask5 = lambda b, h, m, n, *args: True + self.sdpa_score6 = lambda score, b, h, m, n, *args: score + self.sdpa_mask6 = lambda b, h, m, n, *args: True + self.sdpa_score7 = lambda score, b, h, m, n, *args: score + self.sdpa_mask7 = lambda b, h, m, n, *args: True + self.sdpa_score8 = lambda score, b, h, m, n, *args: score + self.sdpa_mask8 = lambda b, h, m, n, *args: True + self.sdpa_score9 = lambda score, b, h, m, n, *args: score + self.sdpa_mask9 = lambda b, h, m, n, *args: True + self.sdpa_score10 = lambda score, b, h, m, n, *args: score + self.sdpa_mask10 = lambda b, h, m, n, *args: True + self.sdpa_score11 = lambda score, b, h, m, n, *args: score + self.sdpa_mask11 = lambda b, h, m, n, *args: True + self.sdpa_score12 = lambda score, b, h, m, n, *args: score + self.sdpa_mask12 = lambda b, h, m, n, *args: True + self.sdpa_score13 = lambda score, b, h, m, n, *args: score + self.sdpa_mask13 = lambda b, h, m, n, *args: True + self.sdpa_score14 = lambda score, b, h, m, n, *args: score + self.sdpa_mask14 = lambda b, h, m, n, *args: True + self.sdpa_score15 = lambda score, b, h, m, n, *args: score + self.sdpa_mask15 = lambda b, h, m, n, *args: True + self.sdpa_score16 = lambda score, b, h, m, n, *args: score + self.sdpa_mask16 = lambda b, h, m, n, *args: True + self.sdpa_score17 = lambda score, b, h, m, n, *args: score + self.sdpa_mask17 = lambda b, h, m, n, *args: True + self.sdpa_score18 = lambda score, b, h, m, n, *args: score + self.sdpa_mask18 = lambda b, h, m, n, *args: True + self.sdpa_score19 = lambda score, b, h, m, n, *args: score + self.sdpa_mask19 = lambda b, h, m, n, *args: True + self.sdpa_score20 = lambda score, b, h, m, n, *args: score + self.sdpa_mask20 = lambda b, h, m, n, *args: True + self.sdpa_score21 = lambda score, b, h, m, n, *args: score + self.sdpa_mask21 = lambda b, h, m, n, *args: True + self.sdpa_score22 = lambda score, b, h, m, n, *args: score + self.sdpa_mask22 = lambda b, h, m, n, *args: True + self.sdpa_score23 = lambda score, b, h, m, n, *args: score + self.sdpa_mask23 = lambda b, h, m, n, *args: True + self.sdpa_score24 = lambda score, b, h, m, n, *args: score + self.sdpa_mask24 = lambda b, h, m, n, *args: True + self.sdpa_score25 = lambda score, b, h, m, n, *args: score + self.sdpa_mask25 = lambda b, h, m, n, *args: True + self.sdpa_score26 = lambda score, b, h, m, n, *args: score + self.sdpa_mask26 = lambda b, h, m, n, *args: True + + + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_298, primals_299, primals_300, primals_301, primals_302, primals_303, primals_304, primals_305, primals_306, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_316, primals_317, primals_318, primals_319, primals_320, primals_321, primals_322, primals_323, primals_324, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, primals_334, primals_335, primals_336, primals_337, primals_338, primals_339, primals_340, primals_341, primals_342, primals_343, primals_344, primals_345, primals_346, primals_347, primals_348, primals_349, primals_350, primals_351, primals_352, primals_353, primals_354, primals_355, primals_356, primals_357, primals_358, primals_359, primals_360, primals_361, primals_362, primals_363, primals_364, primals_365, primals_366, primals_367, primals_368, primals_369, primals_370, primals_371, primals_372, primals_373, primals_374, primals_375, primals_376, primals_377, primals_378, primals_379, primals_380, primals_381, primals_382, primals_383, primals_384, primals_385, primals_386, primals_387, primals_388, primals_389, primals_390, primals_391, primals_392, primals_393, primals_394, primals_395, primals_396, primals_397, primals_398, primals_399, primals_400, primals_401, primals_402, primals_403, primals_404, primals_405, primals_406, primals_407, primals_408, primals_409, primals_410, primals_411, primals_412, primals_413, primals_414, primals_415, primals_416, primals_417, primals_418, primals_419, primals_420, primals_421, primals_422, primals_423, primals_424, primals_425, primals_426, primals_427, primals_428, primals_429, primals_430, primals_431, primals_432, primals_433, primals_434, primals_435, primals_436, primals_437, primals_438, primals_439, primals_440): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_1, torch.bfloat16) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 64, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + embedding = torch.ops.aten.embedding.default(wait_tensor, primals_2); wait_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 64, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1); mul = wait_tensor_1 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 64, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [8192, 2048]); convert_element_type_3 = None + mm = torch.ops.aten.mm.default(view_3, permute); permute = None + view_4 = torch.ops.aten.view.default(mm, [2, 4096, 3072]); mm = None + view_5 = torch.ops.aten.view.default(view_4, [2, 4096, -1, 192]); view_4 = None + split_with_sizes = torch.ops.aten.split_with_sizes.default(view_5, [128, 64], -1); view_5 = None + getitem = split_with_sizes[0] + getitem_1 = split_with_sizes[1]; split_with_sizes = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(getitem_1, torch.float32); getitem_1 = None + view_6 = torch.ops.aten.view.default(convert_element_type_7, [2, 4096, 16, -1, 2]); convert_element_type_7 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_6); view_6 = None + view_7 = torch.ops.aten.view.default(primals_3, [1, 4096, 1, 32]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_7); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_8 = torch.ops.aten.view.default(view_as_real, [2, 4096, 16, 64]); view_as_real = None + convert_element_type_8 = torch.ops.prims.convert_element_type.default(view_8, torch.bfloat16); view_8 = None + cat = torch.ops.aten.cat.default([getitem, convert_element_type_8], -1); getitem = convert_element_type_8 = None + convert_element_type_9 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_9, 64, '0'); convert_element_type_9 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1); permute_1 = None + view_11 = torch.ops.aten.view.default(mm_1, [2, 4096, 576]); mm_1 = None + split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(view_11, [512, 64], -1); view_11 = None + getitem_2 = split_with_sizes_1[0] + getitem_3 = split_with_sizes_1[1]; split_with_sizes_1 = None + unsqueeze = torch.ops.aten.unsqueeze.default(getitem_3, 2); getitem_3 = None + convert_element_type_12 = torch.ops.prims.convert_element_type.default(unsqueeze, torch.float32); unsqueeze = None + view_12 = torch.ops.aten.view.default(convert_element_type_12, [2, 4096, 1, -1, 2]); convert_element_type_12 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_12); view_12 = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_7); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_14 = torch.ops.aten.view.default(view_as_real_1, [2, 4096, 1, 64]); view_as_real_1 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_14, torch.bfloat16); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_14, 64, '0'); convert_element_type_14 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(getitem_2, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_15, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_1 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_1); add_1 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_15, rsqrt_1); convert_element_type_15 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_4); mul_4 = wait_tensor_4 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 64, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_17 = torch.ops.aten.view.default(convert_element_type_16, [8192, 512]); convert_element_type_16 = None + mm_2 = torch.ops.aten.mm.default(view_17, permute_2); permute_2 = None + view_18 = torch.ops.aten.view.default(mm_2, [2, 4096, 4096]); mm_2 = None + view_19 = torch.ops.aten.view.default(view_18, [2, 4096, -1, 256]); view_18 = None + split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(view_19, [128, 128], -1); view_19 = None + getitem_4 = split_with_sizes_2[0] + getitem_5 = split_with_sizes_2[1]; split_with_sizes_2 = None + expand = torch.ops.aten.expand.default(convert_element_type_13, [-1, -1, 16, -1]); convert_element_type_13 = None + cat_1 = torch.ops.aten.cat.default([getitem_4, expand], -1); getitem_4 = expand = None + permute_3 = torch.ops.aten.permute.default(cat, [0, 2, 1, 3]); cat = None + permute_4 = torch.ops.aten.permute.default(cat_1, [0, 2, 1, 3]); cat_1 = None + permute_5 = torch.ops.aten.permute.default(getitem_5, [0, 2, 1, 3]); getitem_5 = None + sdpa_score0 = self.sdpa_score0 + sdpa_mask0 = self.sdpa_mask0 + flex_attention = torch.ops.higher_order.flex_attention(permute_3, permute_4, permute_5, sdpa_score0, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask0), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score0 = sdpa_mask0 = None + getitem_6 = flex_attention[0] + getitem_7 = flex_attention[1]; flex_attention = None + permute_6 = torch.ops.aten.permute.default(getitem_6, [0, 2, 1, 3]) + view_20 = torch.ops.aten.view.default(permute_6, [2, 4096, -1]); permute_6 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 64, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + view_22 = torch.ops.aten.view.default(view_20, [8192, 2048]); view_20 = None + mm_3 = torch.ops.aten.mm.default(view_22, permute_7); view_22 = permute_7 = None + view_23 = torch.ops.aten.view.default(mm_3, [2, 4096, 2048]) + add_2 = torch.ops.aten.add.Tensor(embedding, view_23); view_23 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 64, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_24 = torch.ops.prims.convert_element_type.default(add_2, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_24, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_3 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_3); add_3 = None + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_24, rsqrt_2); convert_element_type_24 = None + mul_7 = torch.ops.aten.mul.Tensor(mul_6, wait_tensor_7); mul_6 = wait_tensor_7 = None + convert_element_type_25 = torch.ops.prims.convert_element_type.default(mul_7, torch.bfloat16); mul_7 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_26, 64, '0'); convert_element_type_26 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + view_26 = torch.ops.aten.view.default(convert_element_type_25, [8192, 2048]); convert_element_type_25 = None + mm_4 = torch.ops.aten.mm.default(view_26, permute_8); permute_8 = None + view_27 = torch.ops.aten.view.default(mm_4, [2, 4096, 10944]) + convert_element_type_29 = torch.ops.prims.convert_element_type.default(view_27, torch.float32); view_27 = None + neg = torch.ops.aten.neg.default(convert_element_type_29) + exp = torch.ops.aten.exp.default(neg); neg = None + add_4 = torch.ops.aten.add.Tensor(exp, 1); exp = None + div = torch.ops.aten.div.Tensor(convert_element_type_29, add_4); convert_element_type_29 = add_4 = None + convert_element_type_30 = torch.ops.prims.convert_element_type.default(div, torch.bfloat16); div = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 64, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + mm_5 = torch.ops.aten.mm.default(view_26, permute_9); permute_9 = None + view_30 = torch.ops.aten.view.default(mm_5, [2, 4096, 10944]) + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_30, view_30); convert_element_type_30 = view_30 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 64, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_10, [1, 0]); wait_tensor_10 = None + view_32 = torch.ops.aten.view.default(mul_8, [8192, 10944]); mul_8 = None + mm_6 = torch.ops.aten.mm.default(view_32, permute_10); permute_10 = None + view_33 = torch.ops.aten.view.default(mm_6, [2, 4096, 2048]); mm_6 = None + add_5 = torch.ops.aten.add.Tensor(add_2, view_33); add_2 = view_33 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 64, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + convert_element_type_38 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_38, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_9 = torch.ops.aten.mul.Tensor(convert_element_type_38, rsqrt_3); convert_element_type_38 = None + mul_10 = torch.ops.aten.mul.Tensor(mul_9, wait_tensor_11); mul_9 = wait_tensor_11 = None + convert_element_type_39 = torch.ops.prims.convert_element_type.default(mul_10, torch.bfloat16); mul_10 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 64, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + view_36 = torch.ops.aten.view.default(convert_element_type_39, [8192, 2048]); convert_element_type_39 = None + mm_7 = torch.ops.aten.mm.default(view_36, permute_11); permute_11 = None + view_37 = torch.ops.aten.view.default(mm_7, [2, 4096, 3072]); mm_7 = None + view_38 = torch.ops.aten.view.default(view_37, [2, 4096, -1, 192]); view_37 = None + split_with_sizes_3 = torch.ops.aten.split_with_sizes.default(view_38, [128, 64], -1); view_38 = None + getitem_9 = split_with_sizes_3[0] + getitem_10 = split_with_sizes_3[1]; split_with_sizes_3 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(getitem_10, torch.float32); getitem_10 = None + view_39 = torch.ops.aten.view.default(convert_element_type_43, [2, 4096, 16, -1, 2]); convert_element_type_43 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_39); view_39 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_7); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_41 = torch.ops.aten.view.default(view_as_real_2, [2, 4096, 16, 64]); view_as_real_2 = None + convert_element_type_44 = torch.ops.prims.convert_element_type.default(view_41, torch.bfloat16); view_41 = None + cat_2 = torch.ops.aten.cat.default([getitem_9, convert_element_type_44], -1); getitem_9 = convert_element_type_44 = None + convert_element_type_45 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_45, 64, '0'); convert_element_type_45 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + mm_8 = torch.ops.aten.mm.default(view_36, permute_12); permute_12 = None + view_44 = torch.ops.aten.view.default(mm_8, [2, 4096, 576]); mm_8 = None + split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_44, [512, 64], -1); view_44 = None + getitem_11 = split_with_sizes_4[0] + getitem_12 = split_with_sizes_4[1]; split_with_sizes_4 = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(getitem_12, 2); getitem_12 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(unsqueeze_1, torch.float32); unsqueeze_1 = None + view_45 = torch.ops.aten.view.default(convert_element_type_48, [2, 4096, 1, -1, 2]); convert_element_type_48 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_45); view_45 = None + mul_12 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_7); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_12); mul_12 = None + view_47 = torch.ops.aten.view.default(view_as_real_3, [2, 4096, 1, 64]); view_as_real_3 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_47, torch.bfloat16); view_47 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 64, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + convert_element_type_51 = torch.ops.prims.convert_element_type.default(getitem_11, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_51, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_7 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_7); add_7 = None + mul_13 = torch.ops.aten.mul.Tensor(convert_element_type_51, rsqrt_4); convert_element_type_51 = None + mul_14 = torch.ops.aten.mul.Tensor(mul_13, wait_tensor_14); mul_13 = wait_tensor_14 = None + convert_element_type_52 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 64, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_15, [1, 0]); wait_tensor_15 = None + view_50 = torch.ops.aten.view.default(convert_element_type_52, [8192, 512]); convert_element_type_52 = None + mm_9 = torch.ops.aten.mm.default(view_50, permute_13); permute_13 = None + view_51 = torch.ops.aten.view.default(mm_9, [2, 4096, 4096]); mm_9 = None + view_52 = torch.ops.aten.view.default(view_51, [2, 4096, -1, 256]); view_51 = None + split_with_sizes_5 = torch.ops.aten.split_with_sizes.default(view_52, [128, 128], -1); view_52 = None + getitem_13 = split_with_sizes_5[0] + getitem_14 = split_with_sizes_5[1]; split_with_sizes_5 = None + expand_1 = torch.ops.aten.expand.default(convert_element_type_49, [-1, -1, 16, -1]); convert_element_type_49 = None + cat_3 = torch.ops.aten.cat.default([getitem_13, expand_1], -1); getitem_13 = expand_1 = None + permute_14 = torch.ops.aten.permute.default(cat_2, [0, 2, 1, 3]); cat_2 = None + permute_15 = torch.ops.aten.permute.default(cat_3, [0, 2, 1, 3]); cat_3 = None + permute_16 = torch.ops.aten.permute.default(getitem_14, [0, 2, 1, 3]); getitem_14 = None + sdpa_score1 = self.sdpa_score1 + sdpa_mask1 = self.sdpa_mask1 + flex_attention_1 = torch.ops.higher_order.flex_attention(permute_14, permute_15, permute_16, sdpa_score1, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask1), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score1 = sdpa_mask1 = None + getitem_15 = flex_attention_1[0] + getitem_16 = flex_attention_1[1]; flex_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_15, [0, 2, 1, 3]) + view_53 = torch.ops.aten.view.default(permute_17, [2, 4096, -1]); permute_17 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 64, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + view_55 = torch.ops.aten.view.default(view_53, [8192, 2048]); view_53 = None + mm_10 = torch.ops.aten.mm.default(view_55, permute_18); view_55 = permute_18 = None + view_56 = torch.ops.aten.view.default(mm_10, [2, 4096, 2048]); mm_10 = None + add_8 = torch.ops.aten.add.Tensor(add_5, view_56); view_56 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_59, 64, '0'); convert_element_type_59 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(add_8, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_60, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_9 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_9); add_9 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, rsqrt_5); convert_element_type_60 = None + mul_16 = torch.ops.aten.mul.Tensor(mul_15, wait_tensor_17); mul_15 = wait_tensor_17 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(mul_16, torch.bfloat16); mul_16 = None + view_58 = torch.ops.aten.view.default(convert_element_type_61, [-1, 2048]); convert_element_type_61 = None + convert_element_type_62 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_62, 64, '0'); convert_element_type_62 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + mm_11 = torch.ops.aten.mm.default(view_58, permute_19); permute_19 = None + convert_element_type_65 = torch.ops.prims.convert_element_type.default(mm_11, torch.float32) + amax = torch.ops.aten.amax.default(convert_element_type_65, [1], True) + sub = torch.ops.aten.sub.Tensor(convert_element_type_65, amax); convert_element_type_65 = None + exp_1 = torch.ops.aten.exp.default(sub); sub = None + sum_1 = torch.ops.aten.sum.dim_IntList(exp_1, [1], True) + div_1 = torch.ops.aten.div.Tensor(exp_1, sum_1); exp_1 = None + add_10 = torch.ops.aten.add.Tensor(div_1, primals_30); primals_30 = None + topk = torch.ops.aten.topk.default(add_10, 6, -1, True, False); add_10 = None + getitem_19 = topk[1]; topk = None + gather = torch.ops.aten.gather.default(div_1, 1, getitem_19); div_1 = None + mul_17 = torch.ops.aten.mul.Tensor(gather, 1.0); gather = None + view_60 = torch.ops.aten.view.default(getitem_19, [-1]) + histc = torch.ops.aten.histc.default(view_60, 64, 0, 64) + add_11 = torch.ops.aten.add.Tensor(primals_32, histc) + sort = torch.ops.aten.sort.stable(view_60, stable = True); view_60 = None + getitem_21 = sort[1]; sort = None + div_2 = torch.ops.aten.div.Tensor_mode(getitem_21, 6, rounding_mode = 'floor') + index = torch.ops.aten.index.Tensor(view_58, [div_2]) + all_to_all_single = torch.ops._c10d_functional.all_to_all_single.default(histc, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_19); wait_tensor_19 = None + view_64 = torch.ops.aten.view.default(histc, [8, -1]); histc = None + sum_2 = torch.ops.aten.sum.dim_IntList(view_64, [1]); view_64 = None + device_put = torch.ops.prims.device_put.default(sum_2, device(type='cpu'), True); sum_2 = None + view_65 = torch.ops.aten.view.default(wait_tensor_20, [8, -1]) + sum_3 = torch.ops.aten.sum.dim_IntList(view_65, [1]) + device_put_1 = torch.ops.prims.device_put.default(sum_3, device(type='cpu')); sum_3 = None + select = torch.ops.aten.select.int(device_put, 0, 0) + _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select); select = None + ge = _local_scalar_dense >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None + select_1 = torch.ops.aten.select.int(device_put, 0, 1) + _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None + ge_1 = _local_scalar_dense_1 >= 0 + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None + select_2 = torch.ops.aten.select.int(device_put, 0, 2) + _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None + ge_2 = _local_scalar_dense_2 >= 0 + _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None + select_3 = torch.ops.aten.select.int(device_put, 0, 3) + _local_scalar_dense_3 = torch.ops.aten._local_scalar_dense.default(select_3); select_3 = None + ge_3 = _local_scalar_dense_3 >= 0 + _assert_scalar_3 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_3'"); ge_3 = _assert_scalar_3 = None + select_4 = torch.ops.aten.select.int(device_put, 0, 4) + _local_scalar_dense_4 = torch.ops.aten._local_scalar_dense.default(select_4); select_4 = None + ge_4 = _local_scalar_dense_4 >= 0 + _assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u4 >= 0 on node 'ge_4'"); ge_4 = _assert_scalar_4 = None + select_5 = torch.ops.aten.select.int(device_put, 0, 5) + _local_scalar_dense_5 = torch.ops.aten._local_scalar_dense.default(select_5); select_5 = None + ge_5 = _local_scalar_dense_5 >= 0 + _assert_scalar_5 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u5 >= 0 on node 'ge_5'"); ge_5 = _assert_scalar_5 = None + select_6 = torch.ops.aten.select.int(device_put, 0, 6) + _local_scalar_dense_6 = torch.ops.aten._local_scalar_dense.default(select_6); select_6 = None + ge_6 = _local_scalar_dense_6 >= 0 + _assert_scalar_6 = torch.ops.aten._assert_scalar.default(ge_6, "Runtime assertion failed for expression u6 >= 0 on node 'ge_6'"); ge_6 = _assert_scalar_6 = None + select_7 = torch.ops.aten.select.int(device_put, 0, 7); device_put = None + _local_scalar_dense_7 = torch.ops.aten._local_scalar_dense.default(select_7); select_7 = None + ge_7 = _local_scalar_dense_7 >= 0 + _assert_scalar_7 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u7 >= 0 on node 'ge_7'"); ge_7 = _assert_scalar_7 = None + select_8 = torch.ops.aten.select.int(device_put_1, 0, 0) + _local_scalar_dense_8 = torch.ops.aten._local_scalar_dense.default(select_8); select_8 = None + ge_8 = _local_scalar_dense_8 >= 0 + _assert_scalar_8 = torch.ops.aten._assert_scalar.default(ge_8, "Runtime assertion failed for expression u8 >= 0 on node 'ge_8'"); ge_8 = _assert_scalar_8 = None + select_9 = torch.ops.aten.select.int(device_put_1, 0, 1) + _local_scalar_dense_9 = torch.ops.aten._local_scalar_dense.default(select_9); select_9 = None + ge_9 = _local_scalar_dense_9 >= 0 + _assert_scalar_9 = torch.ops.aten._assert_scalar.default(ge_9, "Runtime assertion failed for expression u9 >= 0 on node 'ge_9'"); ge_9 = _assert_scalar_9 = None + select_10 = torch.ops.aten.select.int(device_put_1, 0, 2) + _local_scalar_dense_10 = torch.ops.aten._local_scalar_dense.default(select_10); select_10 = None + ge_10 = _local_scalar_dense_10 >= 0 + _assert_scalar_10 = torch.ops.aten._assert_scalar.default(ge_10, "Runtime assertion failed for expression u10 >= 0 on node 'ge_10'"); ge_10 = _assert_scalar_10 = None + select_11 = torch.ops.aten.select.int(device_put_1, 0, 3) + _local_scalar_dense_11 = torch.ops.aten._local_scalar_dense.default(select_11); select_11 = None + ge_11 = _local_scalar_dense_11 >= 0 + _assert_scalar_11 = torch.ops.aten._assert_scalar.default(ge_11, "Runtime assertion failed for expression u11 >= 0 on node 'ge_11'"); ge_11 = _assert_scalar_11 = None + select_12 = torch.ops.aten.select.int(device_put_1, 0, 4) + _local_scalar_dense_12 = torch.ops.aten._local_scalar_dense.default(select_12); select_12 = None + ge_12 = _local_scalar_dense_12 >= 0 + _assert_scalar_12 = torch.ops.aten._assert_scalar.default(ge_12, "Runtime assertion failed for expression u12 >= 0 on node 'ge_12'"); ge_12 = _assert_scalar_12 = None + select_13 = torch.ops.aten.select.int(device_put_1, 0, 5) + _local_scalar_dense_13 = torch.ops.aten._local_scalar_dense.default(select_13); select_13 = None + ge_13 = _local_scalar_dense_13 >= 0 + _assert_scalar_13 = torch.ops.aten._assert_scalar.default(ge_13, "Runtime assertion failed for expression u13 >= 0 on node 'ge_13'"); ge_13 = _assert_scalar_13 = None + select_14 = torch.ops.aten.select.int(device_put_1, 0, 6) + _local_scalar_dense_14 = torch.ops.aten._local_scalar_dense.default(select_14); select_14 = None + ge_14 = _local_scalar_dense_14 >= 0 + _assert_scalar_14 = torch.ops.aten._assert_scalar.default(ge_14, "Runtime assertion failed for expression u14 >= 0 on node 'ge_14'"); ge_14 = _assert_scalar_14 = None + select_15 = torch.ops.aten.select.int(device_put_1, 0, 7); device_put_1 = None + _local_scalar_dense_15 = torch.ops.aten._local_scalar_dense.default(select_15); select_15 = None + ge_15 = _local_scalar_dense_15 >= 0 + _assert_scalar_15 = torch.ops.aten._assert_scalar.default(ge_15, "Runtime assertion failed for expression u15 >= 0 on node 'ge_15'"); ge_15 = _assert_scalar_15 = None + all_to_all_single_1 = torch.ops._c10d_functional.all_to_all_single.default(index, [_local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15], [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7], '521'); index = None + sym_size_int = torch.ops.aten.sym_size.int(all_to_all_single_1, 0) + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None + sym_sum = torch.sym_sum((_local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense_8, _local_scalar_dense_9)) + add_18 = sym_sum + 64; sym_sum = None + add_19 = add_18 + 8; add_18 = None + sub_3 = add_19 - 1; add_19 = None + floordiv = sub_3 // 8; sub_3 = None + mul_22 = floordiv * 8; floordiv = None + cumsum = torch.ops.aten.cumsum.default(wait_tensor_20, 0) + sub_4 = torch.ops.aten.sub.Tensor(cumsum, wait_tensor_20); cumsum = None + sum_4 = torch.ops.aten.sum.dim_IntList(view_65, [0]); view_65 = None + clamp_min = torch.ops.aten.clamp_min.default(sum_4, 8); sum_4 = None + add_20 = torch.ops.aten.add.Tensor(clamp_min, 8); clamp_min = None + sub_5 = torch.ops.aten.sub.Tensor(add_20, 1); add_20 = None + div_3 = torch.ops.aten.div.Tensor_mode(sub_5, 8, rounding_mode = 'floor'); sub_5 = None + mul_23 = torch.ops.aten.mul.Tensor(div_3, 8); div_3 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(mul_23, torch.int32); mul_23 = None + cumsum_1 = torch.ops.aten.cumsum.default(convert_element_type_68, 0) + sub_6 = torch.ops.aten.sub.Tensor(cumsum_1, convert_element_type_68); cumsum_1 = None + full_20 = torch.ops.aten.full.default([mul_22], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_22 = None + triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_20, 'start_index_values_ptr': sub_4, 'write_offsets_ptr': sub_6, 'output_ptr': full_20}, tensors_to_clone = ['output_ptr']); wait_tensor_20 = sub_4 = sub_6 = full_20 = None + getitem_22 = triton_kernel_wrapper_functional_proxy['output_ptr']; triton_kernel_wrapper_functional_proxy = None + full_default = torch.ops.aten.full.default([1, 2048], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + cat_4 = torch.ops.aten.cat.default([wait_tensor_21, full_default]); wait_tensor_21 = None + sym_size_int_1 = torch.ops.aten.sym_size.int(cat_4, 0) + sym_sum_1 = torch.sym_sum((1, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense_8, _local_scalar_dense_9)) + index_1 = torch.ops.aten.index.Tensor(cat_4, [getitem_22]); cat_4 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 8, '513'); convert_element_type_70 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + convert_element_type_72 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_72, 8, '513'); convert_element_type_72 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 8, '513'); convert_element_type_73 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + cumsum_2 = torch.ops.aten.cumsum.default(convert_element_type_68, 0, dtype = torch.int32); convert_element_type_68 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_22, [0, 2, 1]); wait_tensor_22 = None + _grouped_mm = torch.ops.aten._grouped_mm.default(index_1, permute_20, cumsum_2); permute_20 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(_grouped_mm, torch.float32) + neg_1 = torch.ops.aten.neg.default(convert_element_type_76) + exp_2 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_32 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + div_4 = torch.ops.aten.div.Tensor(convert_element_type_76, add_32); convert_element_type_76 = add_32 = None + convert_element_type_77 = torch.ops.prims.convert_element_type.default(div_4, torch.bfloat16); div_4 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_25, [0, 2, 1]); wait_tensor_25 = None + _grouped_mm_1 = torch.ops.aten._grouped_mm.default(index_1, permute_21, cumsum_2); permute_21 = None + mul_35 = torch.ops.aten.mul.Tensor(convert_element_type_77, _grouped_mm_1); convert_element_type_77 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_24, [0, 2, 1]); wait_tensor_24 = None + _grouped_mm_2 = torch.ops.aten._grouped_mm.default(mul_35, permute_22, cumsum_2); permute_22 = None + empty = torch.ops.aten.empty.memory_format([sym_size_int_1, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put = torch.ops.aten.index_put.default(empty, [getitem_22], _grouped_mm_2); empty = _grouped_mm_2 = None + slice_6 = torch.ops.aten.slice.Tensor(index_put, 0, 0, -1); index_put = None + all_to_all_single_2 = torch.ops._c10d_functional.all_to_all_single.default(slice_6, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7], [_local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15], '521'); slice_6 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_2); all_to_all_single_2 = None + convert_element_type_78 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_78, 64, '0'); convert_element_type_78 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + mm_12 = torch.ops.aten.mm.default(view_58, permute_23); permute_23 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(mm_12, torch.float32) + neg_2 = torch.ops.aten.neg.default(convert_element_type_81) + exp_3 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_68 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + div_5 = torch.ops.aten.div.Tensor(convert_element_type_81, add_68); convert_element_type_81 = add_68 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(div_5, torch.bfloat16); div_5 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 64, '0'); convert_element_type_83 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_13 = torch.ops.aten.mm.default(view_58, permute_24); permute_24 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_82, mm_13); convert_element_type_82 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 64, '0'); convert_element_type_86 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_25 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_14 = torch.ops.aten.mm.default(mul_55, permute_25); permute_25 = None + full_default_1 = torch.ops.aten.full.default([49152, 2048], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_1 = torch.ops.aten.index_put.default(full_default_1, [getitem_21], wait_tensor_28); wait_tensor_28 = None + view_98 = torch.ops.aten.view.default(mul_17, [-1, 1, 6]); mul_17 = None + view_99 = torch.ops.aten.view.default(index_put_1, [-1, 6, 2048]); index_put_1 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(view_99, torch.float32); view_99 = None + bmm = torch.ops.aten.bmm.default(view_98, convert_element_type_89) + convert_element_type_90 = torch.ops.prims.convert_element_type.default(bmm, torch.bfloat16); bmm = None + squeeze = torch.ops.aten.squeeze.dim(convert_element_type_90, 1); convert_element_type_90 = None + add_72 = torch.ops.aten.add.Tensor(mm_14, squeeze); mm_14 = squeeze = None + view_100 = torch.ops.aten.view.default(add_72, [2, 4096, 2048]); add_72 = None + add_73 = torch.ops.aten.add.Tensor(add_8, view_100); view_100 = None + convert_element_type_91 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_91, 64, '0'); convert_element_type_91 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_92, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_74 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_58 = torch.ops.aten.mul.Tensor(convert_element_type_92, rsqrt_6); convert_element_type_92 = None + mul_59 = torch.ops.aten.mul.Tensor(mul_58, wait_tensor_32); mul_58 = wait_tensor_32 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_59, torch.bfloat16); mul_59 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 64, '0'); convert_element_type_94 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_26 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + view_103 = torch.ops.aten.view.default(convert_element_type_93, [8192, 2048]); convert_element_type_93 = None + mm_15 = torch.ops.aten.mm.default(view_103, permute_26); permute_26 = None + view_104 = torch.ops.aten.view.default(mm_15, [2, 4096, 3072]); mm_15 = None + view_105 = torch.ops.aten.view.default(view_104, [2, 4096, -1, 192]); view_104 = None + split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_105, [128, 64], -1); view_105 = None + getitem_23 = split_with_sizes_6[0] + getitem_24 = split_with_sizes_6[1]; split_with_sizes_6 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(getitem_24, torch.float32); getitem_24 = None + view_106 = torch.ops.aten.view.default(convert_element_type_97, [2, 4096, 16, -1, 2]); convert_element_type_97 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_106); view_106 = None + mul_60 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_7); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_60); mul_60 = None + view_108 = torch.ops.aten.view.default(view_as_real_4, [2, 4096, 16, 64]); view_as_real_4 = None + convert_element_type_98 = torch.ops.prims.convert_element_type.default(view_108, torch.bfloat16); view_108 = None + cat_5 = torch.ops.aten.cat.default([getitem_23, convert_element_type_98], -1); getitem_23 = convert_element_type_98 = None + convert_element_type_99 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_99, 64, '0'); convert_element_type_99 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_27 = torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + mm_16 = torch.ops.aten.mm.default(view_103, permute_27); permute_27 = None + view_111 = torch.ops.aten.view.default(mm_16, [2, 4096, 576]); mm_16 = None + split_with_sizes_7 = torch.ops.aten.split_with_sizes.default(view_111, [512, 64], -1); view_111 = None + getitem_25 = split_with_sizes_7[0] + getitem_26 = split_with_sizes_7[1]; split_with_sizes_7 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(getitem_26, 2); getitem_26 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(unsqueeze_3, torch.float32); unsqueeze_3 = None + view_112 = torch.ops.aten.view.default(convert_element_type_102, [2, 4096, 1, -1, 2]); convert_element_type_102 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_112); view_112 = None + mul_61 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_7); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_61); mul_61 = None + view_114 = torch.ops.aten.view.default(view_as_real_5, [2, 4096, 1, 64]); view_as_real_5 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(view_114, torch.bfloat16); view_114 = None + convert_element_type_104 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_104, 64, '0'); convert_element_type_104 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + convert_element_type_105 = torch.ops.prims.convert_element_type.default(getitem_25, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_105, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_75 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_75); add_75 = None + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_105, rsqrt_7); convert_element_type_105 = None + mul_63 = torch.ops.aten.mul.Tensor(mul_62, wait_tensor_35); mul_62 = wait_tensor_35 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(mul_63, torch.bfloat16); mul_63 = None + convert_element_type_107 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_107, 64, '0'); convert_element_type_107 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_28 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + view_117 = torch.ops.aten.view.default(convert_element_type_106, [8192, 512]); convert_element_type_106 = None + mm_17 = torch.ops.aten.mm.default(view_117, permute_28); permute_28 = None + view_118 = torch.ops.aten.view.default(mm_17, [2, 4096, 4096]); mm_17 = None + view_119 = torch.ops.aten.view.default(view_118, [2, 4096, -1, 256]); view_118 = None + split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(view_119, [128, 128], -1); view_119 = None + getitem_27 = split_with_sizes_8[0] + getitem_28 = split_with_sizes_8[1]; split_with_sizes_8 = None + expand_2 = torch.ops.aten.expand.default(convert_element_type_103, [-1, -1, 16, -1]); convert_element_type_103 = None + cat_6 = torch.ops.aten.cat.default([getitem_27, expand_2], -1); getitem_27 = expand_2 = None + permute_29 = torch.ops.aten.permute.default(cat_5, [0, 2, 1, 3]); cat_5 = None + permute_30 = torch.ops.aten.permute.default(cat_6, [0, 2, 1, 3]); cat_6 = None + permute_31 = torch.ops.aten.permute.default(getitem_28, [0, 2, 1, 3]); getitem_28 = None + sdpa_score2 = self.sdpa_score2 + sdpa_mask2 = self.sdpa_mask2 + flex_attention_2 = torch.ops.higher_order.flex_attention(permute_29, permute_30, permute_31, sdpa_score2, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask2), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score2 = sdpa_mask2 = None + getitem_29 = flex_attention_2[0] + getitem_30 = flex_attention_2[1]; flex_attention_2 = None + permute_32 = torch.ops.aten.permute.default(getitem_29, [0, 2, 1, 3]) + view_120 = torch.ops.aten.view.default(permute_32, [2, 4096, -1]); permute_32 = None + convert_element_type_110 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_110, 64, '0'); convert_element_type_110 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + view_122 = torch.ops.aten.view.default(view_120, [8192, 2048]); view_120 = None + mm_18 = torch.ops.aten.mm.default(view_122, permute_33); view_122 = permute_33 = None + view_123 = torch.ops.aten.view.default(mm_18, [2, 4096, 2048]); mm_18 = None + add_76 = torch.ops.aten.add.Tensor(add_73, view_123); view_123 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_113, 64, '0'); convert_element_type_113 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(add_76, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_114, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_77 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_77); add_77 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_114, rsqrt_8); convert_element_type_114 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_38); mul_64 = wait_tensor_38 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + view_125 = torch.ops.aten.view.default(convert_element_type_115, [-1, 2048]); convert_element_type_115 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 64, '0'); convert_element_type_116 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + mm_19 = torch.ops.aten.mm.default(view_125, permute_34); permute_34 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(mm_19, torch.float32) + amax_1 = torch.ops.aten.amax.default(convert_element_type_119, [1], True) + sub_24 = torch.ops.aten.sub.Tensor(convert_element_type_119, amax_1); convert_element_type_119 = None + exp_4 = torch.ops.aten.exp.default(sub_24); sub_24 = None + sum_5 = torch.ops.aten.sum.dim_IntList(exp_4, [1], True) + div_6 = torch.ops.aten.div.Tensor(exp_4, sum_5); exp_4 = None + add_78 = torch.ops.aten.add.Tensor(div_6, primals_46); primals_46 = None + topk_1 = torch.ops.aten.topk.default(add_78, 6, -1, True, False); add_78 = None + getitem_33 = topk_1[1]; topk_1 = None + gather_1 = torch.ops.aten.gather.default(div_6, 1, getitem_33); div_6 = None + mul_66 = torch.ops.aten.mul.Tensor(gather_1, 1.0); gather_1 = None + view_127 = torch.ops.aten.view.default(getitem_33, [-1]) + histc_2 = torch.ops.aten.histc.default(view_127, 64, 0, 64) + add_79 = torch.ops.aten.add.Tensor(primals_48, histc_2) + sort_1 = torch.ops.aten.sort.stable(view_127, stable = True); view_127 = None + getitem_35 = sort_1[1]; sort_1 = None + div_7 = torch.ops.aten.div.Tensor_mode(getitem_35, 6, rounding_mode = 'floor') + index_2 = torch.ops.aten.index.Tensor(view_125, [div_7]) + all_to_all_single_3 = torch.ops._c10d_functional.all_to_all_single.default(histc_2, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_3); all_to_all_single_3 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_40); wait_tensor_40 = None + view_131 = torch.ops.aten.view.default(histc_2, [8, -1]); histc_2 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_131, [1]); view_131 = None + device_put_2 = torch.ops.prims.device_put.default(sum_6, device(type='cpu'), True); sum_6 = None + view_132 = torch.ops.aten.view.default(wait_tensor_41, [8, -1]) + sum_7 = torch.ops.aten.sum.dim_IntList(view_132, [1]) + device_put_3 = torch.ops.prims.device_put.default(sum_7, device(type='cpu')); sum_7 = None + select_16 = torch.ops.aten.select.int(device_put_2, 0, 0) + _local_scalar_dense_16 = torch.ops.aten._local_scalar_dense.default(select_16); select_16 = None + ge_20 = _local_scalar_dense_16 >= 0 + _assert_scalar_16 = torch.ops.aten._assert_scalar.default(ge_20, "Runtime assertion failed for expression u16 >= 0 on node 'ge_16'"); ge_20 = _assert_scalar_16 = None + select_17 = torch.ops.aten.select.int(device_put_2, 0, 1) + _local_scalar_dense_17 = torch.ops.aten._local_scalar_dense.default(select_17); select_17 = None + ge_21 = _local_scalar_dense_17 >= 0 + _assert_scalar_17 = torch.ops.aten._assert_scalar.default(ge_21, "Runtime assertion failed for expression u17 >= 0 on node 'ge_17'"); ge_21 = _assert_scalar_17 = None + select_18 = torch.ops.aten.select.int(device_put_2, 0, 2) + _local_scalar_dense_18 = torch.ops.aten._local_scalar_dense.default(select_18); select_18 = None + ge_22 = _local_scalar_dense_18 >= 0 + _assert_scalar_18 = torch.ops.aten._assert_scalar.default(ge_22, "Runtime assertion failed for expression u18 >= 0 on node 'ge_18'"); ge_22 = _assert_scalar_18 = None + select_19 = torch.ops.aten.select.int(device_put_2, 0, 3) + _local_scalar_dense_19 = torch.ops.aten._local_scalar_dense.default(select_19); select_19 = None + ge_23 = _local_scalar_dense_19 >= 0 + _assert_scalar_19 = torch.ops.aten._assert_scalar.default(ge_23, "Runtime assertion failed for expression u19 >= 0 on node 'ge_19'"); ge_23 = _assert_scalar_19 = None + select_20 = torch.ops.aten.select.int(device_put_2, 0, 4) + _local_scalar_dense_20 = torch.ops.aten._local_scalar_dense.default(select_20); select_20 = None + ge_24 = _local_scalar_dense_20 >= 0 + _assert_scalar_20 = torch.ops.aten._assert_scalar.default(ge_24, "Runtime assertion failed for expression u20 >= 0 on node 'ge_20'"); ge_24 = _assert_scalar_20 = None + select_21 = torch.ops.aten.select.int(device_put_2, 0, 5) + _local_scalar_dense_21 = torch.ops.aten._local_scalar_dense.default(select_21); select_21 = None + ge_25 = _local_scalar_dense_21 >= 0 + _assert_scalar_21 = torch.ops.aten._assert_scalar.default(ge_25, "Runtime assertion failed for expression u21 >= 0 on node 'ge_21'"); ge_25 = _assert_scalar_21 = None + select_22 = torch.ops.aten.select.int(device_put_2, 0, 6) + _local_scalar_dense_22 = torch.ops.aten._local_scalar_dense.default(select_22); select_22 = None + ge_26 = _local_scalar_dense_22 >= 0 + _assert_scalar_22 = torch.ops.aten._assert_scalar.default(ge_26, "Runtime assertion failed for expression u22 >= 0 on node 'ge_22'"); ge_26 = _assert_scalar_22 = None + select_23 = torch.ops.aten.select.int(device_put_2, 0, 7); device_put_2 = None + _local_scalar_dense_23 = torch.ops.aten._local_scalar_dense.default(select_23); select_23 = None + ge_27 = _local_scalar_dense_23 >= 0 + _assert_scalar_23 = torch.ops.aten._assert_scalar.default(ge_27, "Runtime assertion failed for expression u23 >= 0 on node 'ge_23'"); ge_27 = _assert_scalar_23 = None + select_24 = torch.ops.aten.select.int(device_put_3, 0, 0) + _local_scalar_dense_24 = torch.ops.aten._local_scalar_dense.default(select_24); select_24 = None + ge_28 = _local_scalar_dense_24 >= 0 + _assert_scalar_24 = torch.ops.aten._assert_scalar.default(ge_28, "Runtime assertion failed for expression u24 >= 0 on node 'ge_24'"); ge_28 = _assert_scalar_24 = None + select_25 = torch.ops.aten.select.int(device_put_3, 0, 1) + _local_scalar_dense_25 = torch.ops.aten._local_scalar_dense.default(select_25); select_25 = None + ge_29 = _local_scalar_dense_25 >= 0 + _assert_scalar_25 = torch.ops.aten._assert_scalar.default(ge_29, "Runtime assertion failed for expression u25 >= 0 on node 'ge_25'"); ge_29 = _assert_scalar_25 = None + select_26 = torch.ops.aten.select.int(device_put_3, 0, 2) + _local_scalar_dense_26 = torch.ops.aten._local_scalar_dense.default(select_26); select_26 = None + ge_30 = _local_scalar_dense_26 >= 0 + _assert_scalar_26 = torch.ops.aten._assert_scalar.default(ge_30, "Runtime assertion failed for expression u26 >= 0 on node 'ge_26'"); ge_30 = _assert_scalar_26 = None + select_27 = torch.ops.aten.select.int(device_put_3, 0, 3) + _local_scalar_dense_27 = torch.ops.aten._local_scalar_dense.default(select_27); select_27 = None + ge_31 = _local_scalar_dense_27 >= 0 + _assert_scalar_27 = torch.ops.aten._assert_scalar.default(ge_31, "Runtime assertion failed for expression u27 >= 0 on node 'ge_27'"); ge_31 = _assert_scalar_27 = None + select_28 = torch.ops.aten.select.int(device_put_3, 0, 4) + _local_scalar_dense_28 = torch.ops.aten._local_scalar_dense.default(select_28); select_28 = None + ge_32 = _local_scalar_dense_28 >= 0 + _assert_scalar_28 = torch.ops.aten._assert_scalar.default(ge_32, "Runtime assertion failed for expression u28 >= 0 on node 'ge_28'"); ge_32 = _assert_scalar_28 = None + select_29 = torch.ops.aten.select.int(device_put_3, 0, 5) + _local_scalar_dense_29 = torch.ops.aten._local_scalar_dense.default(select_29); select_29 = None + ge_33 = _local_scalar_dense_29 >= 0 + _assert_scalar_29 = torch.ops.aten._assert_scalar.default(ge_33, "Runtime assertion failed for expression u29 >= 0 on node 'ge_29'"); ge_33 = _assert_scalar_29 = None + select_30 = torch.ops.aten.select.int(device_put_3, 0, 6) + _local_scalar_dense_30 = torch.ops.aten._local_scalar_dense.default(select_30); select_30 = None + ge_34 = _local_scalar_dense_30 >= 0 + _assert_scalar_30 = torch.ops.aten._assert_scalar.default(ge_34, "Runtime assertion failed for expression u30 >= 0 on node 'ge_30'"); ge_34 = _assert_scalar_30 = None + select_31 = torch.ops.aten.select.int(device_put_3, 0, 7); device_put_3 = None + _local_scalar_dense_31 = torch.ops.aten._local_scalar_dense.default(select_31); select_31 = None + ge_35 = _local_scalar_dense_31 >= 0 + _assert_scalar_31 = torch.ops.aten._assert_scalar.default(ge_35, "Runtime assertion failed for expression u31 >= 0 on node 'ge_31'"); ge_35 = _assert_scalar_31 = None + all_to_all_single_4 = torch.ops._c10d_functional.all_to_all_single.default(index_2, [_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31], [_local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23], '521'); index_2 = None + sym_size_int_4 = torch.ops.aten.sym_size.int(all_to_all_single_4, 0) + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_4); all_to_all_single_4 = None + sym_sum_2 = torch.sym_sum((_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31)) + add_86 = sym_sum_2 + 64; sym_sum_2 = None + add_87 = add_86 + 8; add_86 = None + sub_27 = add_87 - 1; add_87 = None + floordiv_1 = sub_27 // 8; sub_27 = None + mul_71 = floordiv_1 * 8; floordiv_1 = None + cumsum_3 = torch.ops.aten.cumsum.default(wait_tensor_41, 0) + sub_28 = torch.ops.aten.sub.Tensor(cumsum_3, wait_tensor_41); cumsum_3 = None + sum_8 = torch.ops.aten.sum.dim_IntList(view_132, [0]); view_132 = None + clamp_min_1 = torch.ops.aten.clamp_min.default(sum_8, 8); sum_8 = None + add_88 = torch.ops.aten.add.Tensor(clamp_min_1, 8); clamp_min_1 = None + sub_29 = torch.ops.aten.sub.Tensor(add_88, 1); add_88 = None + div_8 = torch.ops.aten.div.Tensor_mode(sub_29, 8, rounding_mode = 'floor'); sub_29 = None + mul_72 = torch.ops.aten.mul.Tensor(div_8, 8); div_8 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(mul_72, torch.int32); mul_72 = None + cumsum_4 = torch.ops.aten.cumsum.default(convert_element_type_122, 0) + sub_30 = torch.ops.aten.sub.Tensor(cumsum_4, convert_element_type_122); cumsum_4 = None + full_33 = torch.ops.aten.full.default([mul_71], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_71 = None + triton_kernel_wrapper_functional_proxy_1 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 1, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_41, 'start_index_values_ptr': sub_28, 'write_offsets_ptr': sub_30, 'output_ptr': full_33}, tensors_to_clone = ['output_ptr']); wait_tensor_41 = sub_28 = sub_30 = full_33 = None + getitem_36 = triton_kernel_wrapper_functional_proxy_1['output_ptr']; triton_kernel_wrapper_functional_proxy_1 = None + cat_7 = torch.ops.aten.cat.default([wait_tensor_42, full_default]); wait_tensor_42 = None + sym_size_int_5 = torch.ops.aten.sym_size.int(cat_7, 0) + sym_sum_3 = torch.sym_sum((1, _local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31)) + index_3 = torch.ops.aten.index.Tensor(cat_7, [getitem_36]); cat_7 = None + convert_element_type_124 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_124, 8, '513'); convert_element_type_124 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_126, 8, '513'); convert_element_type_126 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 8, '513'); convert_element_type_127 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + cumsum_5 = torch.ops.aten.cumsum.default(convert_element_type_122, 0, dtype = torch.int32); convert_element_type_122 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_43, [0, 2, 1]); wait_tensor_43 = None + _grouped_mm_3 = torch.ops.aten._grouped_mm.default(index_3, permute_35, cumsum_5); permute_35 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(_grouped_mm_3, torch.float32) + neg_3 = torch.ops.aten.neg.default(convert_element_type_130) + exp_5 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_100 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + div_9 = torch.ops.aten.div.Tensor(convert_element_type_130, add_100); convert_element_type_130 = add_100 = None + convert_element_type_131 = torch.ops.prims.convert_element_type.default(div_9, torch.bfloat16); div_9 = None + permute_36 = torch.ops.aten.permute.default(wait_tensor_46, [0, 2, 1]); wait_tensor_46 = None + _grouped_mm_4 = torch.ops.aten._grouped_mm.default(index_3, permute_36, cumsum_5); permute_36 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_131, _grouped_mm_4); convert_element_type_131 = None + permute_37 = torch.ops.aten.permute.default(wait_tensor_45, [0, 2, 1]); wait_tensor_45 = None + _grouped_mm_5 = torch.ops.aten._grouped_mm.default(mul_84, permute_37, cumsum_5); permute_37 = None + empty_1 = torch.ops.aten.empty.memory_format([sym_size_int_5, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_2 = torch.ops.aten.index_put.default(empty_1, [getitem_36], _grouped_mm_5); empty_1 = _grouped_mm_5 = None + slice_10 = torch.ops.aten.slice.Tensor(index_put_2, 0, 0, -1); index_put_2 = None + all_to_all_single_5 = torch.ops._c10d_functional.all_to_all_single.default(slice_10, [_local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23], [_local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31], '521'); slice_10 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_5); all_to_all_single_5 = None + convert_element_type_132 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_132, 64, '0'); convert_element_type_132 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_38 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + mm_20 = torch.ops.aten.mm.default(view_125, permute_38); permute_38 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mm_20, torch.float32) + neg_4 = torch.ops.aten.neg.default(convert_element_type_135) + exp_6 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_136 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + div_10 = torch.ops.aten.div.Tensor(convert_element_type_135, add_136); convert_element_type_135 = add_136 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(div_10, torch.bfloat16); div_10 = None + convert_element_type_137 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_137, 64, '0'); convert_element_type_137 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_39 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + mm_21 = torch.ops.aten.mm.default(view_125, permute_39); permute_39 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_136, mm_21); convert_element_type_136 = None + convert_element_type_140 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_140, 64, '0'); convert_element_type_140 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + mm_22 = torch.ops.aten.mm.default(mul_104, permute_40); permute_40 = None + index_put_3 = torch.ops.aten.index_put.default(full_default_1, [getitem_35], wait_tensor_49); wait_tensor_49 = None + view_165 = torch.ops.aten.view.default(mul_66, [-1, 1, 6]); mul_66 = None + view_166 = torch.ops.aten.view.default(index_put_3, [-1, 6, 2048]); index_put_3 = None + convert_element_type_143 = torch.ops.prims.convert_element_type.default(view_166, torch.float32); view_166 = None + bmm_1 = torch.ops.aten.bmm.default(view_165, convert_element_type_143) + convert_element_type_144 = torch.ops.prims.convert_element_type.default(bmm_1, torch.bfloat16); bmm_1 = None + squeeze_1 = torch.ops.aten.squeeze.dim(convert_element_type_144, 1); convert_element_type_144 = None + add_140 = torch.ops.aten.add.Tensor(mm_22, squeeze_1); mm_22 = squeeze_1 = None + view_167 = torch.ops.aten.view.default(add_140, [2, 4096, 2048]); add_140 = None + add_141 = torch.ops.aten.add.Tensor(add_76, view_167); view_167 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_145, 64, '0'); convert_element_type_145 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(add_141, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_146, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_142 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_142); add_142 = None + mul_107 = torch.ops.aten.mul.Tensor(convert_element_type_146, rsqrt_9); convert_element_type_146 = None + mul_108 = torch.ops.aten.mul.Tensor(mul_107, wait_tensor_53); mul_107 = wait_tensor_53 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(mul_108, torch.bfloat16); mul_108 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_148, 64, '0'); convert_element_type_148 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + view_170 = torch.ops.aten.view.default(convert_element_type_147, [8192, 2048]); convert_element_type_147 = None + mm_23 = torch.ops.aten.mm.default(view_170, permute_41); permute_41 = None + view_171 = torch.ops.aten.view.default(mm_23, [2, 4096, 3072]); mm_23 = None + view_172 = torch.ops.aten.view.default(view_171, [2, 4096, -1, 192]); view_171 = None + split_with_sizes_9 = torch.ops.aten.split_with_sizes.default(view_172, [128, 64], -1); view_172 = None + getitem_37 = split_with_sizes_9[0] + getitem_38 = split_with_sizes_9[1]; split_with_sizes_9 = None + convert_element_type_151 = torch.ops.prims.convert_element_type.default(getitem_38, torch.float32); getitem_38 = None + view_173 = torch.ops.aten.view.default(convert_element_type_151, [2, 4096, 16, -1, 2]); convert_element_type_151 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_173); view_173 = None + mul_109 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_7); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_109); mul_109 = None + view_175 = torch.ops.aten.view.default(view_as_real_6, [2, 4096, 16, 64]); view_as_real_6 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(view_175, torch.bfloat16); view_175 = None + cat_8 = torch.ops.aten.cat.default([getitem_37, convert_element_type_152], -1); getitem_37 = convert_element_type_152 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_153, 64, '0'); convert_element_type_153 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_55, [1, 0]); wait_tensor_55 = None + mm_24 = torch.ops.aten.mm.default(view_170, permute_42); permute_42 = None + view_178 = torch.ops.aten.view.default(mm_24, [2, 4096, 576]); mm_24 = None + split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_178, [512, 64], -1); view_178 = None + getitem_39 = split_with_sizes_10[0] + getitem_40 = split_with_sizes_10[1]; split_with_sizes_10 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(getitem_40, 2); getitem_40 = None + convert_element_type_156 = torch.ops.prims.convert_element_type.default(unsqueeze_5, torch.float32); unsqueeze_5 = None + view_179 = torch.ops.aten.view.default(convert_element_type_156, [2, 4096, 1, -1, 2]); convert_element_type_156 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + mul_110 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_7); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_110); mul_110 = None + view_181 = torch.ops.aten.view.default(view_as_real_7, [2, 4096, 1, 64]); view_as_real_7 = None + convert_element_type_157 = torch.ops.prims.convert_element_type.default(view_181, torch.bfloat16); view_181 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_158, 64, '0'); convert_element_type_158 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(getitem_39, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_159, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_143 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_143); add_143 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_159, rsqrt_10); convert_element_type_159 = None + mul_112 = torch.ops.aten.mul.Tensor(mul_111, wait_tensor_56); mul_111 = wait_tensor_56 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(mul_112, torch.bfloat16); mul_112 = None + convert_element_type_161 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_161, 64, '0'); convert_element_type_161 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + view_184 = torch.ops.aten.view.default(convert_element_type_160, [8192, 512]); convert_element_type_160 = None + mm_25 = torch.ops.aten.mm.default(view_184, permute_43); permute_43 = None + view_185 = torch.ops.aten.view.default(mm_25, [2, 4096, 4096]); mm_25 = None + view_186 = torch.ops.aten.view.default(view_185, [2, 4096, -1, 256]); view_185 = None + split_with_sizes_11 = torch.ops.aten.split_with_sizes.default(view_186, [128, 128], -1); view_186 = None + getitem_41 = split_with_sizes_11[0] + getitem_42 = split_with_sizes_11[1]; split_with_sizes_11 = None + expand_3 = torch.ops.aten.expand.default(convert_element_type_157, [-1, -1, 16, -1]); convert_element_type_157 = None + cat_9 = torch.ops.aten.cat.default([getitem_41, expand_3], -1); getitem_41 = expand_3 = None + permute_44 = torch.ops.aten.permute.default(cat_8, [0, 2, 1, 3]); cat_8 = None + permute_45 = torch.ops.aten.permute.default(cat_9, [0, 2, 1, 3]); cat_9 = None + permute_46 = torch.ops.aten.permute.default(getitem_42, [0, 2, 1, 3]); getitem_42 = None + sdpa_score3 = self.sdpa_score3 + sdpa_mask3 = self.sdpa_mask3 + flex_attention_3 = torch.ops.higher_order.flex_attention(permute_44, permute_45, permute_46, sdpa_score3, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask3), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score3 = sdpa_mask3 = None + getitem_43 = flex_attention_3[0] + getitem_44 = flex_attention_3[1]; flex_attention_3 = None + permute_47 = torch.ops.aten.permute.default(getitem_43, [0, 2, 1, 3]) + view_187 = torch.ops.aten.view.default(permute_47, [2, 4096, -1]); permute_47 = None + convert_element_type_164 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_164, 64, '0'); convert_element_type_164 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_48 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + view_189 = torch.ops.aten.view.default(view_187, [8192, 2048]); view_187 = None + mm_26 = torch.ops.aten.mm.default(view_189, permute_48); view_189 = permute_48 = None + view_190 = torch.ops.aten.view.default(mm_26, [2, 4096, 2048]); mm_26 = None + add_144 = torch.ops.aten.add.Tensor(add_141, view_190); view_190 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_167, 64, '0'); convert_element_type_167 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(add_144, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_168, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_145 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_145); add_145 = None + mul_113 = torch.ops.aten.mul.Tensor(convert_element_type_168, rsqrt_11); convert_element_type_168 = None + mul_114 = torch.ops.aten.mul.Tensor(mul_113, wait_tensor_59); mul_113 = wait_tensor_59 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(mul_114, torch.bfloat16); mul_114 = None + view_192 = torch.ops.aten.view.default(convert_element_type_169, [-1, 2048]); convert_element_type_169 = None + convert_element_type_170 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_170, 64, '0'); convert_element_type_170 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_49 = torch.ops.aten.permute.default(wait_tensor_60, [1, 0]); wait_tensor_60 = None + mm_27 = torch.ops.aten.mm.default(view_192, permute_49); permute_49 = None + convert_element_type_173 = torch.ops.prims.convert_element_type.default(mm_27, torch.float32) + amax_2 = torch.ops.aten.amax.default(convert_element_type_173, [1], True) + sub_48 = torch.ops.aten.sub.Tensor(convert_element_type_173, amax_2); convert_element_type_173 = None + exp_7 = torch.ops.aten.exp.default(sub_48); sub_48 = None + sum_9 = torch.ops.aten.sum.dim_IntList(exp_7, [1], True) + div_11 = torch.ops.aten.div.Tensor(exp_7, sum_9); exp_7 = None + add_146 = torch.ops.aten.add.Tensor(div_11, primals_62); primals_62 = None + topk_2 = torch.ops.aten.topk.default(add_146, 6, -1, True, False); add_146 = None + getitem_47 = topk_2[1]; topk_2 = None + gather_2 = torch.ops.aten.gather.default(div_11, 1, getitem_47); div_11 = None + mul_115 = torch.ops.aten.mul.Tensor(gather_2, 1.0); gather_2 = None + view_194 = torch.ops.aten.view.default(getitem_47, [-1]) + histc_4 = torch.ops.aten.histc.default(view_194, 64, 0, 64) + add_147 = torch.ops.aten.add.Tensor(primals_64, histc_4) + sort_2 = torch.ops.aten.sort.stable(view_194, stable = True); view_194 = None + getitem_49 = sort_2[1]; sort_2 = None + div_12 = torch.ops.aten.div.Tensor_mode(getitem_49, 6, rounding_mode = 'floor') + index_4 = torch.ops.aten.index.Tensor(view_192, [div_12]) + all_to_all_single_6 = torch.ops._c10d_functional.all_to_all_single.default(histc_4, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_6); all_to_all_single_6 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_61); wait_tensor_61 = None + view_198 = torch.ops.aten.view.default(histc_4, [8, -1]); histc_4 = None + sum_10 = torch.ops.aten.sum.dim_IntList(view_198, [1]); view_198 = None + device_put_4 = torch.ops.prims.device_put.default(sum_10, device(type='cpu'), True); sum_10 = None + view_199 = torch.ops.aten.view.default(wait_tensor_62, [8, -1]) + sum_11 = torch.ops.aten.sum.dim_IntList(view_199, [1]) + device_put_5 = torch.ops.prims.device_put.default(sum_11, device(type='cpu')); sum_11 = None + select_32 = torch.ops.aten.select.int(device_put_4, 0, 0) + _local_scalar_dense_32 = torch.ops.aten._local_scalar_dense.default(select_32); select_32 = None + ge_40 = _local_scalar_dense_32 >= 0 + _assert_scalar_32 = torch.ops.aten._assert_scalar.default(ge_40, "Runtime assertion failed for expression u32 >= 0 on node 'ge_32'"); ge_40 = _assert_scalar_32 = None + select_33 = torch.ops.aten.select.int(device_put_4, 0, 1) + _local_scalar_dense_33 = torch.ops.aten._local_scalar_dense.default(select_33); select_33 = None + ge_41 = _local_scalar_dense_33 >= 0 + _assert_scalar_33 = torch.ops.aten._assert_scalar.default(ge_41, "Runtime assertion failed for expression u33 >= 0 on node 'ge_33'"); ge_41 = _assert_scalar_33 = None + select_34 = torch.ops.aten.select.int(device_put_4, 0, 2) + _local_scalar_dense_34 = torch.ops.aten._local_scalar_dense.default(select_34); select_34 = None + ge_42 = _local_scalar_dense_34 >= 0 + _assert_scalar_34 = torch.ops.aten._assert_scalar.default(ge_42, "Runtime assertion failed for expression u34 >= 0 on node 'ge_34'"); ge_42 = _assert_scalar_34 = None + select_35 = torch.ops.aten.select.int(device_put_4, 0, 3) + _local_scalar_dense_35 = torch.ops.aten._local_scalar_dense.default(select_35); select_35 = None + ge_43 = _local_scalar_dense_35 >= 0 + _assert_scalar_35 = torch.ops.aten._assert_scalar.default(ge_43, "Runtime assertion failed for expression u35 >= 0 on node 'ge_35'"); ge_43 = _assert_scalar_35 = None + select_36 = torch.ops.aten.select.int(device_put_4, 0, 4) + _local_scalar_dense_36 = torch.ops.aten._local_scalar_dense.default(select_36); select_36 = None + ge_44 = _local_scalar_dense_36 >= 0 + _assert_scalar_36 = torch.ops.aten._assert_scalar.default(ge_44, "Runtime assertion failed for expression u36 >= 0 on node 'ge_36'"); ge_44 = _assert_scalar_36 = None + select_37 = torch.ops.aten.select.int(device_put_4, 0, 5) + _local_scalar_dense_37 = torch.ops.aten._local_scalar_dense.default(select_37); select_37 = None + ge_45 = _local_scalar_dense_37 >= 0 + _assert_scalar_37 = torch.ops.aten._assert_scalar.default(ge_45, "Runtime assertion failed for expression u37 >= 0 on node 'ge_37'"); ge_45 = _assert_scalar_37 = None + select_38 = torch.ops.aten.select.int(device_put_4, 0, 6) + _local_scalar_dense_38 = torch.ops.aten._local_scalar_dense.default(select_38); select_38 = None + ge_46 = _local_scalar_dense_38 >= 0 + _assert_scalar_38 = torch.ops.aten._assert_scalar.default(ge_46, "Runtime assertion failed for expression u38 >= 0 on node 'ge_38'"); ge_46 = _assert_scalar_38 = None + select_39 = torch.ops.aten.select.int(device_put_4, 0, 7); device_put_4 = None + _local_scalar_dense_39 = torch.ops.aten._local_scalar_dense.default(select_39); select_39 = None + ge_47 = _local_scalar_dense_39 >= 0 + _assert_scalar_39 = torch.ops.aten._assert_scalar.default(ge_47, "Runtime assertion failed for expression u39 >= 0 on node 'ge_39'"); ge_47 = _assert_scalar_39 = None + select_40 = torch.ops.aten.select.int(device_put_5, 0, 0) + _local_scalar_dense_40 = torch.ops.aten._local_scalar_dense.default(select_40); select_40 = None + ge_48 = _local_scalar_dense_40 >= 0 + _assert_scalar_40 = torch.ops.aten._assert_scalar.default(ge_48, "Runtime assertion failed for expression u40 >= 0 on node 'ge_40'"); ge_48 = _assert_scalar_40 = None + select_41 = torch.ops.aten.select.int(device_put_5, 0, 1) + _local_scalar_dense_41 = torch.ops.aten._local_scalar_dense.default(select_41); select_41 = None + ge_49 = _local_scalar_dense_41 >= 0 + _assert_scalar_41 = torch.ops.aten._assert_scalar.default(ge_49, "Runtime assertion failed for expression u41 >= 0 on node 'ge_41'"); ge_49 = _assert_scalar_41 = None + select_42 = torch.ops.aten.select.int(device_put_5, 0, 2) + _local_scalar_dense_42 = torch.ops.aten._local_scalar_dense.default(select_42); select_42 = None + ge_50 = _local_scalar_dense_42 >= 0 + _assert_scalar_42 = torch.ops.aten._assert_scalar.default(ge_50, "Runtime assertion failed for expression u42 >= 0 on node 'ge_42'"); ge_50 = _assert_scalar_42 = None + select_43 = torch.ops.aten.select.int(device_put_5, 0, 3) + _local_scalar_dense_43 = torch.ops.aten._local_scalar_dense.default(select_43); select_43 = None + ge_51 = _local_scalar_dense_43 >= 0 + _assert_scalar_43 = torch.ops.aten._assert_scalar.default(ge_51, "Runtime assertion failed for expression u43 >= 0 on node 'ge_43'"); ge_51 = _assert_scalar_43 = None + select_44 = torch.ops.aten.select.int(device_put_5, 0, 4) + _local_scalar_dense_44 = torch.ops.aten._local_scalar_dense.default(select_44); select_44 = None + ge_52 = _local_scalar_dense_44 >= 0 + _assert_scalar_44 = torch.ops.aten._assert_scalar.default(ge_52, "Runtime assertion failed for expression u44 >= 0 on node 'ge_44'"); ge_52 = _assert_scalar_44 = None + select_45 = torch.ops.aten.select.int(device_put_5, 0, 5) + _local_scalar_dense_45 = torch.ops.aten._local_scalar_dense.default(select_45); select_45 = None + ge_53 = _local_scalar_dense_45 >= 0 + _assert_scalar_45 = torch.ops.aten._assert_scalar.default(ge_53, "Runtime assertion failed for expression u45 >= 0 on node 'ge_45'"); ge_53 = _assert_scalar_45 = None + select_46 = torch.ops.aten.select.int(device_put_5, 0, 6) + _local_scalar_dense_46 = torch.ops.aten._local_scalar_dense.default(select_46); select_46 = None + ge_54 = _local_scalar_dense_46 >= 0 + _assert_scalar_46 = torch.ops.aten._assert_scalar.default(ge_54, "Runtime assertion failed for expression u46 >= 0 on node 'ge_46'"); ge_54 = _assert_scalar_46 = None + select_47 = torch.ops.aten.select.int(device_put_5, 0, 7); device_put_5 = None + _local_scalar_dense_47 = torch.ops.aten._local_scalar_dense.default(select_47); select_47 = None + ge_55 = _local_scalar_dense_47 >= 0 + _assert_scalar_47 = torch.ops.aten._assert_scalar.default(ge_55, "Runtime assertion failed for expression u47 >= 0 on node 'ge_47'"); ge_55 = _assert_scalar_47 = None + all_to_all_single_7 = torch.ops._c10d_functional.all_to_all_single.default(index_4, [_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47], [_local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39], '521'); index_4 = None + sym_size_int_8 = torch.ops.aten.sym_size.int(all_to_all_single_7, 0) + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_7); all_to_all_single_7 = None + sym_sum_4 = torch.sym_sum((_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47)) + add_154 = sym_sum_4 + 64; sym_sum_4 = None + add_155 = add_154 + 8; add_154 = None + sub_51 = add_155 - 1; add_155 = None + floordiv_2 = sub_51 // 8; sub_51 = None + mul_120 = floordiv_2 * 8; floordiv_2 = None + cumsum_6 = torch.ops.aten.cumsum.default(wait_tensor_62, 0) + sub_52 = torch.ops.aten.sub.Tensor(cumsum_6, wait_tensor_62); cumsum_6 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_199, [0]); view_199 = None + clamp_min_2 = torch.ops.aten.clamp_min.default(sum_12, 8); sum_12 = None + add_156 = torch.ops.aten.add.Tensor(clamp_min_2, 8); clamp_min_2 = None + sub_53 = torch.ops.aten.sub.Tensor(add_156, 1); add_156 = None + div_13 = torch.ops.aten.div.Tensor_mode(sub_53, 8, rounding_mode = 'floor'); sub_53 = None + mul_121 = torch.ops.aten.mul.Tensor(div_13, 8); div_13 = None + convert_element_type_176 = torch.ops.prims.convert_element_type.default(mul_121, torch.int32); mul_121 = None + cumsum_7 = torch.ops.aten.cumsum.default(convert_element_type_176, 0) + sub_54 = torch.ops.aten.sub.Tensor(cumsum_7, convert_element_type_176); cumsum_7 = None + full_46 = torch.ops.aten.full.default([mul_120], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_120 = None + triton_kernel_wrapper_functional_proxy_2 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 2, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_62, 'start_index_values_ptr': sub_52, 'write_offsets_ptr': sub_54, 'output_ptr': full_46}, tensors_to_clone = ['output_ptr']); wait_tensor_62 = sub_52 = sub_54 = full_46 = None + getitem_50 = triton_kernel_wrapper_functional_proxy_2['output_ptr']; triton_kernel_wrapper_functional_proxy_2 = None + cat_10 = torch.ops.aten.cat.default([wait_tensor_63, full_default]); wait_tensor_63 = None + sym_size_int_9 = torch.ops.aten.sym_size.int(cat_10, 0) + sym_sum_5 = torch.sym_sum((1, _local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47)) + index_5 = torch.ops.aten.index.Tensor(cat_10, [getitem_50]); cat_10 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_178, 8, '513'); convert_element_type_178 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_180, 8, '513'); convert_element_type_180 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_181, 8, '513'); convert_element_type_181 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + cumsum_8 = torch.ops.aten.cumsum.default(convert_element_type_176, 0, dtype = torch.int32); convert_element_type_176 = None + permute_50 = torch.ops.aten.permute.default(wait_tensor_64, [0, 2, 1]); wait_tensor_64 = None + _grouped_mm_6 = torch.ops.aten._grouped_mm.default(index_5, permute_50, cumsum_8); permute_50 = None + convert_element_type_184 = torch.ops.prims.convert_element_type.default(_grouped_mm_6, torch.float32) + neg_5 = torch.ops.aten.neg.default(convert_element_type_184) + exp_8 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_168 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + div_14 = torch.ops.aten.div.Tensor(convert_element_type_184, add_168); convert_element_type_184 = add_168 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(div_14, torch.bfloat16); div_14 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_67, [0, 2, 1]); wait_tensor_67 = None + _grouped_mm_7 = torch.ops.aten._grouped_mm.default(index_5, permute_51, cumsum_8); permute_51 = None + mul_133 = torch.ops.aten.mul.Tensor(convert_element_type_185, _grouped_mm_7); convert_element_type_185 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_66, [0, 2, 1]); wait_tensor_66 = None + _grouped_mm_8 = torch.ops.aten._grouped_mm.default(mul_133, permute_52, cumsum_8); permute_52 = None + empty_2 = torch.ops.aten.empty.memory_format([sym_size_int_9, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_4 = torch.ops.aten.index_put.default(empty_2, [getitem_50], _grouped_mm_8); empty_2 = _grouped_mm_8 = None + slice_14 = torch.ops.aten.slice.Tensor(index_put_4, 0, 0, -1); index_put_4 = None + all_to_all_single_8 = torch.ops._c10d_functional.all_to_all_single.default(slice_14, [_local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39], [_local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47], '521'); slice_14 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_8); all_to_all_single_8 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_186, 64, '0'); convert_element_type_186 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_28 = torch.ops.aten.mm.default(view_192, permute_53); permute_53 = None + convert_element_type_189 = torch.ops.prims.convert_element_type.default(mm_28, torch.float32) + neg_6 = torch.ops.aten.neg.default(convert_element_type_189) + exp_9 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_204 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + div_15 = torch.ops.aten.div.Tensor(convert_element_type_189, add_204); convert_element_type_189 = add_204 = None + convert_element_type_190 = torch.ops.prims.convert_element_type.default(div_15, torch.bfloat16); div_15 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_191, 64, '0'); convert_element_type_191 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + mm_29 = torch.ops.aten.mm.default(view_192, permute_54); permute_54 = None + mul_153 = torch.ops.aten.mul.Tensor(convert_element_type_190, mm_29); convert_element_type_190 = None + convert_element_type_194 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_194, 64, '0'); convert_element_type_194 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_73, [1, 0]); wait_tensor_73 = None + mm_30 = torch.ops.aten.mm.default(mul_153, permute_55); permute_55 = None + index_put_5 = torch.ops.aten.index_put.default(full_default_1, [getitem_49], wait_tensor_70); wait_tensor_70 = None + view_232 = torch.ops.aten.view.default(mul_115, [-1, 1, 6]); mul_115 = None + view_233 = torch.ops.aten.view.default(index_put_5, [-1, 6, 2048]); index_put_5 = None + convert_element_type_197 = torch.ops.prims.convert_element_type.default(view_233, torch.float32); view_233 = None + bmm_2 = torch.ops.aten.bmm.default(view_232, convert_element_type_197) + convert_element_type_198 = torch.ops.prims.convert_element_type.default(bmm_2, torch.bfloat16); bmm_2 = None + squeeze_2 = torch.ops.aten.squeeze.dim(convert_element_type_198, 1); convert_element_type_198 = None + add_208 = torch.ops.aten.add.Tensor(mm_30, squeeze_2); mm_30 = squeeze_2 = None + view_234 = torch.ops.aten.view.default(add_208, [2, 4096, 2048]); add_208 = None + add_209 = torch.ops.aten.add.Tensor(add_144, view_234); view_234 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 64, '0'); convert_element_type_199 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_209, torch.float32) + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_210 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_210); add_210 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_74); mul_156 = wait_tensor_74 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 64, '0'); convert_element_type_202 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + view_237 = torch.ops.aten.view.default(convert_element_type_201, [8192, 2048]); convert_element_type_201 = None + mm_31 = torch.ops.aten.mm.default(view_237, permute_56); permute_56 = None + view_238 = torch.ops.aten.view.default(mm_31, [2, 4096, 3072]); mm_31 = None + view_239 = torch.ops.aten.view.default(view_238, [2, 4096, -1, 192]); view_238 = None + split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(view_239, [128, 64], -1); view_239 = None + getitem_51 = split_with_sizes_12[0] + getitem_52 = split_with_sizes_12[1]; split_with_sizes_12 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(getitem_52, torch.float32); getitem_52 = None + view_240 = torch.ops.aten.view.default(convert_element_type_205, [2, 4096, 16, -1, 2]); convert_element_type_205 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_240); view_240 = None + mul_158 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_7); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_158); mul_158 = None + view_242 = torch.ops.aten.view.default(view_as_real_8, [2, 4096, 16, 64]); view_as_real_8 = None + convert_element_type_206 = torch.ops.prims.convert_element_type.default(view_242, torch.bfloat16); view_242 = None + cat_11 = torch.ops.aten.cat.default([getitem_51, convert_element_type_206], -1); getitem_51 = convert_element_type_206 = None + convert_element_type_207 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_207, 64, '0'); convert_element_type_207 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + mm_32 = torch.ops.aten.mm.default(view_237, permute_57); permute_57 = None + view_245 = torch.ops.aten.view.default(mm_32, [2, 4096, 576]); mm_32 = None + split_with_sizes_13 = torch.ops.aten.split_with_sizes.default(view_245, [512, 64], -1); view_245 = None + getitem_53 = split_with_sizes_13[0] + getitem_54 = split_with_sizes_13[1]; split_with_sizes_13 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(getitem_54, 2); getitem_54 = None + convert_element_type_210 = torch.ops.prims.convert_element_type.default(unsqueeze_7, torch.float32); unsqueeze_7 = None + view_246 = torch.ops.aten.view.default(convert_element_type_210, [2, 4096, 1, -1, 2]); convert_element_type_210 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_246); view_246 = None + mul_159 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_7); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_159); mul_159 = None + view_248 = torch.ops.aten.view.default(view_as_real_9, [2, 4096, 1, 64]); view_as_real_9 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_248, torch.bfloat16); view_248 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_212, 64, '0'); convert_element_type_212 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(getitem_53, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_213, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_211 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_211); add_211 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_213, rsqrt_13); convert_element_type_213 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_77); mul_160 = wait_tensor_77 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 64, '0'); convert_element_type_215 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_58 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + view_251 = torch.ops.aten.view.default(convert_element_type_214, [8192, 512]); convert_element_type_214 = None + mm_33 = torch.ops.aten.mm.default(view_251, permute_58); permute_58 = None + view_252 = torch.ops.aten.view.default(mm_33, [2, 4096, 4096]); mm_33 = None + view_253 = torch.ops.aten.view.default(view_252, [2, 4096, -1, 256]); view_252 = None + split_with_sizes_14 = torch.ops.aten.split_with_sizes.default(view_253, [128, 128], -1); view_253 = None + getitem_55 = split_with_sizes_14[0] + getitem_56 = split_with_sizes_14[1]; split_with_sizes_14 = None + expand_4 = torch.ops.aten.expand.default(convert_element_type_211, [-1, -1, 16, -1]); convert_element_type_211 = None + cat_12 = torch.ops.aten.cat.default([getitem_55, expand_4], -1); getitem_55 = expand_4 = None + permute_59 = torch.ops.aten.permute.default(cat_11, [0, 2, 1, 3]); cat_11 = None + permute_60 = torch.ops.aten.permute.default(cat_12, [0, 2, 1, 3]); cat_12 = None + permute_61 = torch.ops.aten.permute.default(getitem_56, [0, 2, 1, 3]); getitem_56 = None + sdpa_score4 = self.sdpa_score4 + sdpa_mask4 = self.sdpa_mask4 + flex_attention_4 = torch.ops.higher_order.flex_attention(permute_59, permute_60, permute_61, sdpa_score4, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask4), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score4 = sdpa_mask4 = None + getitem_57 = flex_attention_4[0] + getitem_58 = flex_attention_4[1]; flex_attention_4 = None + permute_62 = torch.ops.aten.permute.default(getitem_57, [0, 2, 1, 3]) + view_254 = torch.ops.aten.view.default(permute_62, [2, 4096, -1]); permute_62 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 64, '0'); convert_element_type_218 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + view_256 = torch.ops.aten.view.default(view_254, [8192, 2048]); view_254 = None + mm_34 = torch.ops.aten.mm.default(view_256, permute_63); view_256 = permute_63 = None + view_257 = torch.ops.aten.view.default(mm_34, [2, 4096, 2048]); mm_34 = None + add_212 = torch.ops.aten.add.Tensor(add_209, view_257); view_257 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 64, '0'); convert_element_type_221 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + convert_element_type_222 = torch.ops.prims.convert_element_type.default(add_212, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_222, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_213 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_213); add_213 = None + mul_162 = torch.ops.aten.mul.Tensor(convert_element_type_222, rsqrt_14); convert_element_type_222 = None + mul_163 = torch.ops.aten.mul.Tensor(mul_162, wait_tensor_80); mul_162 = wait_tensor_80 = None + convert_element_type_223 = torch.ops.prims.convert_element_type.default(mul_163, torch.bfloat16); mul_163 = None + view_259 = torch.ops.aten.view.default(convert_element_type_223, [-1, 2048]); convert_element_type_223 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_224, 64, '0'); convert_element_type_224 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + mm_35 = torch.ops.aten.mm.default(view_259, permute_64); permute_64 = None + convert_element_type_227 = torch.ops.prims.convert_element_type.default(mm_35, torch.float32) + amax_3 = torch.ops.aten.amax.default(convert_element_type_227, [1], True) + sub_72 = torch.ops.aten.sub.Tensor(convert_element_type_227, amax_3); convert_element_type_227 = None + exp_10 = torch.ops.aten.exp.default(sub_72); sub_72 = None + sum_13 = torch.ops.aten.sum.dim_IntList(exp_10, [1], True) + div_16 = torch.ops.aten.div.Tensor(exp_10, sum_13); exp_10 = None + add_214 = torch.ops.aten.add.Tensor(div_16, primals_78); primals_78 = None + topk_3 = torch.ops.aten.topk.default(add_214, 6, -1, True, False); add_214 = None + getitem_61 = topk_3[1]; topk_3 = None + gather_3 = torch.ops.aten.gather.default(div_16, 1, getitem_61); div_16 = None + mul_164 = torch.ops.aten.mul.Tensor(gather_3, 1.0); gather_3 = None + view_261 = torch.ops.aten.view.default(getitem_61, [-1]) + histc_6 = torch.ops.aten.histc.default(view_261, 64, 0, 64) + add_215 = torch.ops.aten.add.Tensor(primals_80, histc_6) + sort_3 = torch.ops.aten.sort.stable(view_261, stable = True); view_261 = None + getitem_63 = sort_3[1]; sort_3 = None + div_17 = torch.ops.aten.div.Tensor_mode(getitem_63, 6, rounding_mode = 'floor') + index_6 = torch.ops.aten.index.Tensor(view_259, [div_17]) + all_to_all_single_9 = torch.ops._c10d_functional.all_to_all_single.default(histc_6, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_9); all_to_all_single_9 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_82); wait_tensor_82 = None + view_265 = torch.ops.aten.view.default(histc_6, [8, -1]); histc_6 = None + sum_14 = torch.ops.aten.sum.dim_IntList(view_265, [1]); view_265 = None + device_put_6 = torch.ops.prims.device_put.default(sum_14, device(type='cpu'), True); sum_14 = None + view_266 = torch.ops.aten.view.default(wait_tensor_83, [8, -1]) + sum_15 = torch.ops.aten.sum.dim_IntList(view_266, [1]) + device_put_7 = torch.ops.prims.device_put.default(sum_15, device(type='cpu')); sum_15 = None + select_48 = torch.ops.aten.select.int(device_put_6, 0, 0) + _local_scalar_dense_48 = torch.ops.aten._local_scalar_dense.default(select_48); select_48 = None + ge_60 = _local_scalar_dense_48 >= 0 + _assert_scalar_48 = torch.ops.aten._assert_scalar.default(ge_60, "Runtime assertion failed for expression u48 >= 0 on node 'ge_48'"); ge_60 = _assert_scalar_48 = None + select_49 = torch.ops.aten.select.int(device_put_6, 0, 1) + _local_scalar_dense_49 = torch.ops.aten._local_scalar_dense.default(select_49); select_49 = None + ge_61 = _local_scalar_dense_49 >= 0 + _assert_scalar_49 = torch.ops.aten._assert_scalar.default(ge_61, "Runtime assertion failed for expression u49 >= 0 on node 'ge_49'"); ge_61 = _assert_scalar_49 = None + select_50 = torch.ops.aten.select.int(device_put_6, 0, 2) + _local_scalar_dense_50 = torch.ops.aten._local_scalar_dense.default(select_50); select_50 = None + ge_62 = _local_scalar_dense_50 >= 0 + _assert_scalar_50 = torch.ops.aten._assert_scalar.default(ge_62, "Runtime assertion failed for expression u50 >= 0 on node 'ge_50'"); ge_62 = _assert_scalar_50 = None + select_51 = torch.ops.aten.select.int(device_put_6, 0, 3) + _local_scalar_dense_51 = torch.ops.aten._local_scalar_dense.default(select_51); select_51 = None + ge_63 = _local_scalar_dense_51 >= 0 + _assert_scalar_51 = torch.ops.aten._assert_scalar.default(ge_63, "Runtime assertion failed for expression u51 >= 0 on node 'ge_51'"); ge_63 = _assert_scalar_51 = None + select_52 = torch.ops.aten.select.int(device_put_6, 0, 4) + _local_scalar_dense_52 = torch.ops.aten._local_scalar_dense.default(select_52); select_52 = None + ge_64 = _local_scalar_dense_52 >= 0 + _assert_scalar_52 = torch.ops.aten._assert_scalar.default(ge_64, "Runtime assertion failed for expression u52 >= 0 on node 'ge_52'"); ge_64 = _assert_scalar_52 = None + select_53 = torch.ops.aten.select.int(device_put_6, 0, 5) + _local_scalar_dense_53 = torch.ops.aten._local_scalar_dense.default(select_53); select_53 = None + ge_65 = _local_scalar_dense_53 >= 0 + _assert_scalar_53 = torch.ops.aten._assert_scalar.default(ge_65, "Runtime assertion failed for expression u53 >= 0 on node 'ge_53'"); ge_65 = _assert_scalar_53 = None + select_54 = torch.ops.aten.select.int(device_put_6, 0, 6) + _local_scalar_dense_54 = torch.ops.aten._local_scalar_dense.default(select_54); select_54 = None + ge_66 = _local_scalar_dense_54 >= 0 + _assert_scalar_54 = torch.ops.aten._assert_scalar.default(ge_66, "Runtime assertion failed for expression u54 >= 0 on node 'ge_54'"); ge_66 = _assert_scalar_54 = None + select_55 = torch.ops.aten.select.int(device_put_6, 0, 7); device_put_6 = None + _local_scalar_dense_55 = torch.ops.aten._local_scalar_dense.default(select_55); select_55 = None + ge_67 = _local_scalar_dense_55 >= 0 + _assert_scalar_55 = torch.ops.aten._assert_scalar.default(ge_67, "Runtime assertion failed for expression u55 >= 0 on node 'ge_55'"); ge_67 = _assert_scalar_55 = None + select_56 = torch.ops.aten.select.int(device_put_7, 0, 0) + _local_scalar_dense_56 = torch.ops.aten._local_scalar_dense.default(select_56); select_56 = None + ge_68 = _local_scalar_dense_56 >= 0 + _assert_scalar_56 = torch.ops.aten._assert_scalar.default(ge_68, "Runtime assertion failed for expression u56 >= 0 on node 'ge_56'"); ge_68 = _assert_scalar_56 = None + select_57 = torch.ops.aten.select.int(device_put_7, 0, 1) + _local_scalar_dense_57 = torch.ops.aten._local_scalar_dense.default(select_57); select_57 = None + ge_69 = _local_scalar_dense_57 >= 0 + _assert_scalar_57 = torch.ops.aten._assert_scalar.default(ge_69, "Runtime assertion failed for expression u57 >= 0 on node 'ge_57'"); ge_69 = _assert_scalar_57 = None + select_58 = torch.ops.aten.select.int(device_put_7, 0, 2) + _local_scalar_dense_58 = torch.ops.aten._local_scalar_dense.default(select_58); select_58 = None + ge_70 = _local_scalar_dense_58 >= 0 + _assert_scalar_58 = torch.ops.aten._assert_scalar.default(ge_70, "Runtime assertion failed for expression u58 >= 0 on node 'ge_58'"); ge_70 = _assert_scalar_58 = None + select_59 = torch.ops.aten.select.int(device_put_7, 0, 3) + _local_scalar_dense_59 = torch.ops.aten._local_scalar_dense.default(select_59); select_59 = None + ge_71 = _local_scalar_dense_59 >= 0 + _assert_scalar_59 = torch.ops.aten._assert_scalar.default(ge_71, "Runtime assertion failed for expression u59 >= 0 on node 'ge_59'"); ge_71 = _assert_scalar_59 = None + select_60 = torch.ops.aten.select.int(device_put_7, 0, 4) + _local_scalar_dense_60 = torch.ops.aten._local_scalar_dense.default(select_60); select_60 = None + ge_72 = _local_scalar_dense_60 >= 0 + _assert_scalar_60 = torch.ops.aten._assert_scalar.default(ge_72, "Runtime assertion failed for expression u60 >= 0 on node 'ge_60'"); ge_72 = _assert_scalar_60 = None + select_61 = torch.ops.aten.select.int(device_put_7, 0, 5) + _local_scalar_dense_61 = torch.ops.aten._local_scalar_dense.default(select_61); select_61 = None + ge_73 = _local_scalar_dense_61 >= 0 + _assert_scalar_61 = torch.ops.aten._assert_scalar.default(ge_73, "Runtime assertion failed for expression u61 >= 0 on node 'ge_61'"); ge_73 = _assert_scalar_61 = None + select_62 = torch.ops.aten.select.int(device_put_7, 0, 6) + _local_scalar_dense_62 = torch.ops.aten._local_scalar_dense.default(select_62); select_62 = None + ge_74 = _local_scalar_dense_62 >= 0 + _assert_scalar_62 = torch.ops.aten._assert_scalar.default(ge_74, "Runtime assertion failed for expression u62 >= 0 on node 'ge_62'"); ge_74 = _assert_scalar_62 = None + select_63 = torch.ops.aten.select.int(device_put_7, 0, 7); device_put_7 = None + _local_scalar_dense_63 = torch.ops.aten._local_scalar_dense.default(select_63); select_63 = None + ge_75 = _local_scalar_dense_63 >= 0 + _assert_scalar_63 = torch.ops.aten._assert_scalar.default(ge_75, "Runtime assertion failed for expression u63 >= 0 on node 'ge_63'"); ge_75 = _assert_scalar_63 = None + all_to_all_single_10 = torch.ops._c10d_functional.all_to_all_single.default(index_6, [_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63], [_local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55], '521'); index_6 = None + sym_size_int_12 = torch.ops.aten.sym_size.int(all_to_all_single_10, 0) + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_10); all_to_all_single_10 = None + sym_sum_6 = torch.sym_sum((_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63)) + add_222 = sym_sum_6 + 64; sym_sum_6 = None + add_223 = add_222 + 8; add_222 = None + sub_75 = add_223 - 1; add_223 = None + floordiv_3 = sub_75 // 8; sub_75 = None + mul_169 = floordiv_3 * 8; floordiv_3 = None + cumsum_9 = torch.ops.aten.cumsum.default(wait_tensor_83, 0) + sub_76 = torch.ops.aten.sub.Tensor(cumsum_9, wait_tensor_83); cumsum_9 = None + sum_16 = torch.ops.aten.sum.dim_IntList(view_266, [0]); view_266 = None + clamp_min_3 = torch.ops.aten.clamp_min.default(sum_16, 8); sum_16 = None + add_224 = torch.ops.aten.add.Tensor(clamp_min_3, 8); clamp_min_3 = None + sub_77 = torch.ops.aten.sub.Tensor(add_224, 1); add_224 = None + div_18 = torch.ops.aten.div.Tensor_mode(sub_77, 8, rounding_mode = 'floor'); sub_77 = None + mul_170 = torch.ops.aten.mul.Tensor(div_18, 8); div_18 = None + convert_element_type_230 = torch.ops.prims.convert_element_type.default(mul_170, torch.int32); mul_170 = None + cumsum_10 = torch.ops.aten.cumsum.default(convert_element_type_230, 0) + sub_78 = torch.ops.aten.sub.Tensor(cumsum_10, convert_element_type_230); cumsum_10 = None + full_59 = torch.ops.aten.full.default([mul_169], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_169 = None + triton_kernel_wrapper_functional_proxy_3 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_83, 'start_index_values_ptr': sub_76, 'write_offsets_ptr': sub_78, 'output_ptr': full_59}, tensors_to_clone = ['output_ptr']); wait_tensor_83 = sub_76 = sub_78 = full_59 = None + getitem_64 = triton_kernel_wrapper_functional_proxy_3['output_ptr']; triton_kernel_wrapper_functional_proxy_3 = None + cat_13 = torch.ops.aten.cat.default([wait_tensor_84, full_default]); wait_tensor_84 = None + sym_size_int_13 = torch.ops.aten.sym_size.int(cat_13, 0) + sym_sum_7 = torch.sym_sum((1, _local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63)) + index_7 = torch.ops.aten.index.Tensor(cat_13, [getitem_64]); cat_13 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 8, '513'); convert_element_type_232 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '513'); convert_element_type_234 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 8, '513'); convert_element_type_235 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + cumsum_11 = torch.ops.aten.cumsum.default(convert_element_type_230, 0, dtype = torch.int32); convert_element_type_230 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_85, [0, 2, 1]); wait_tensor_85 = None + _grouped_mm_9 = torch.ops.aten._grouped_mm.default(index_7, permute_65, cumsum_11); permute_65 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(_grouped_mm_9, torch.float32) + neg_7 = torch.ops.aten.neg.default(convert_element_type_238) + exp_11 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_236 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + div_19 = torch.ops.aten.div.Tensor(convert_element_type_238, add_236); convert_element_type_238 = add_236 = None + convert_element_type_239 = torch.ops.prims.convert_element_type.default(div_19, torch.bfloat16); div_19 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_88, [0, 2, 1]); wait_tensor_88 = None + _grouped_mm_10 = torch.ops.aten._grouped_mm.default(index_7, permute_66, cumsum_11); permute_66 = None + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_239, _grouped_mm_10); convert_element_type_239 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_87, [0, 2, 1]); wait_tensor_87 = None + _grouped_mm_11 = torch.ops.aten._grouped_mm.default(mul_182, permute_67, cumsum_11); permute_67 = None + empty_3 = torch.ops.aten.empty.memory_format([sym_size_int_13, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_6 = torch.ops.aten.index_put.default(empty_3, [getitem_64], _grouped_mm_11); empty_3 = _grouped_mm_11 = None + slice_18 = torch.ops.aten.slice.Tensor(index_put_6, 0, 0, -1); index_put_6 = None + all_to_all_single_11 = torch.ops._c10d_functional.all_to_all_single.default(slice_18, [_local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55], [_local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63], '521'); slice_18 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_11); all_to_all_single_11 = None + convert_element_type_240 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_240, 64, '0'); convert_element_type_240 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + mm_36 = torch.ops.aten.mm.default(view_259, permute_68); permute_68 = None + convert_element_type_243 = torch.ops.prims.convert_element_type.default(mm_36, torch.float32) + neg_8 = torch.ops.aten.neg.default(convert_element_type_243) + exp_12 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_272 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + div_20 = torch.ops.aten.div.Tensor(convert_element_type_243, add_272); convert_element_type_243 = add_272 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(div_20, torch.bfloat16); div_20 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_245, 64, '0'); convert_element_type_245 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_69 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_37 = torch.ops.aten.mm.default(view_259, permute_69); permute_69 = None + mul_202 = torch.ops.aten.mul.Tensor(convert_element_type_244, mm_37); convert_element_type_244 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 64, '0'); convert_element_type_248 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + permute_70 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + mm_38 = torch.ops.aten.mm.default(mul_202, permute_70); permute_70 = None + index_put_7 = torch.ops.aten.index_put.default(full_default_1, [getitem_63], wait_tensor_91); wait_tensor_91 = None + view_299 = torch.ops.aten.view.default(mul_164, [-1, 1, 6]); mul_164 = None + view_300 = torch.ops.aten.view.default(index_put_7, [-1, 6, 2048]); index_put_7 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + bmm_3 = torch.ops.aten.bmm.default(view_299, convert_element_type_251) + convert_element_type_252 = torch.ops.prims.convert_element_type.default(bmm_3, torch.bfloat16); bmm_3 = None + squeeze_3 = torch.ops.aten.squeeze.dim(convert_element_type_252, 1); convert_element_type_252 = None + add_276 = torch.ops.aten.add.Tensor(mm_38, squeeze_3); mm_38 = squeeze_3 = None + view_301 = torch.ops.aten.view.default(add_276, [2, 4096, 2048]); add_276 = None + add_277 = torch.ops.aten.add.Tensor(add_212, view_301); view_301 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 64, '0'); convert_element_type_253 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(add_277, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_254, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_278 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_278); add_278 = None + mul_205 = torch.ops.aten.mul.Tensor(convert_element_type_254, rsqrt_15); convert_element_type_254 = None + mul_206 = torch.ops.aten.mul.Tensor(mul_205, wait_tensor_95); mul_205 = wait_tensor_95 = None + convert_element_type_255 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_256 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_256, 64, '0'); convert_element_type_256 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_71 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + view_304 = torch.ops.aten.view.default(convert_element_type_255, [8192, 2048]); convert_element_type_255 = None + mm_39 = torch.ops.aten.mm.default(view_304, permute_71); permute_71 = None + view_305 = torch.ops.aten.view.default(mm_39, [2, 4096, 3072]); mm_39 = None + view_306 = torch.ops.aten.view.default(view_305, [2, 4096, -1, 192]); view_305 = None + split_with_sizes_15 = torch.ops.aten.split_with_sizes.default(view_306, [128, 64], -1); view_306 = None + getitem_65 = split_with_sizes_15[0] + getitem_66 = split_with_sizes_15[1]; split_with_sizes_15 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(getitem_66, torch.float32); getitem_66 = None + view_307 = torch.ops.aten.view.default(convert_element_type_259, [2, 4096, 16, -1, 2]); convert_element_type_259 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_307); view_307 = None + mul_207 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_7); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_207); mul_207 = None + view_309 = torch.ops.aten.view.default(view_as_real_10, [2, 4096, 16, 64]); view_as_real_10 = None + convert_element_type_260 = torch.ops.prims.convert_element_type.default(view_309, torch.bfloat16); view_309 = None + cat_14 = torch.ops.aten.cat.default([getitem_65, convert_element_type_260], -1); getitem_65 = convert_element_type_260 = None + convert_element_type_261 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_261, 64, '0'); convert_element_type_261 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_72 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + mm_40 = torch.ops.aten.mm.default(view_304, permute_72); permute_72 = None + view_312 = torch.ops.aten.view.default(mm_40, [2, 4096, 576]); mm_40 = None + split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_312, [512, 64], -1); view_312 = None + getitem_67 = split_with_sizes_16[0] + getitem_68 = split_with_sizes_16[1]; split_with_sizes_16 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(getitem_68, 2); getitem_68 = None + convert_element_type_264 = torch.ops.prims.convert_element_type.default(unsqueeze_9, torch.float32); unsqueeze_9 = None + view_313 = torch.ops.aten.view.default(convert_element_type_264, [2, 4096, 1, -1, 2]); convert_element_type_264 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_313); view_313 = None + mul_208 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_7); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_208); mul_208 = None + view_315 = torch.ops.aten.view.default(view_as_real_11, [2, 4096, 1, 64]); view_as_real_11 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(view_315, torch.bfloat16); view_315 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_266, 64, '0'); convert_element_type_266 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(getitem_67, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_267, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_279 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_279); add_279 = None + mul_209 = torch.ops.aten.mul.Tensor(convert_element_type_267, rsqrt_16); convert_element_type_267 = None + mul_210 = torch.ops.aten.mul.Tensor(mul_209, wait_tensor_98); mul_209 = wait_tensor_98 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(mul_210, torch.bfloat16); mul_210 = None + convert_element_type_269 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_269, 64, '0'); convert_element_type_269 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + view_318 = torch.ops.aten.view.default(convert_element_type_268, [8192, 512]); convert_element_type_268 = None + mm_41 = torch.ops.aten.mm.default(view_318, permute_73); permute_73 = None + view_319 = torch.ops.aten.view.default(mm_41, [2, 4096, 4096]); mm_41 = None + view_320 = torch.ops.aten.view.default(view_319, [2, 4096, -1, 256]); view_319 = None + split_with_sizes_17 = torch.ops.aten.split_with_sizes.default(view_320, [128, 128], -1); view_320 = None + getitem_69 = split_with_sizes_17[0] + getitem_70 = split_with_sizes_17[1]; split_with_sizes_17 = None + expand_5 = torch.ops.aten.expand.default(convert_element_type_265, [-1, -1, 16, -1]); convert_element_type_265 = None + cat_15 = torch.ops.aten.cat.default([getitem_69, expand_5], -1); getitem_69 = expand_5 = None + permute_74 = torch.ops.aten.permute.default(cat_14, [0, 2, 1, 3]); cat_14 = None + permute_75 = torch.ops.aten.permute.default(cat_15, [0, 2, 1, 3]); cat_15 = None + permute_76 = torch.ops.aten.permute.default(getitem_70, [0, 2, 1, 3]); getitem_70 = None + sdpa_score5 = self.sdpa_score5 + sdpa_mask5 = self.sdpa_mask5 + flex_attention_5 = torch.ops.higher_order.flex_attention(permute_74, permute_75, permute_76, sdpa_score5, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask5), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score5 = sdpa_mask5 = None + getitem_71 = flex_attention_5[0] + getitem_72 = flex_attention_5[1]; flex_attention_5 = None + permute_77 = torch.ops.aten.permute.default(getitem_71, [0, 2, 1, 3]) + view_321 = torch.ops.aten.view.default(permute_77, [2, 4096, -1]); permute_77 = None + convert_element_type_272 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_272, 64, '0'); convert_element_type_272 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_100, [1, 0]); wait_tensor_100 = None + view_323 = torch.ops.aten.view.default(view_321, [8192, 2048]); view_321 = None + mm_42 = torch.ops.aten.mm.default(view_323, permute_78); view_323 = permute_78 = None + view_324 = torch.ops.aten.view.default(mm_42, [2, 4096, 2048]); mm_42 = None + add_280 = torch.ops.aten.add.Tensor(add_277, view_324); view_324 = None + convert_element_type_275 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_275, 64, '0'); convert_element_type_275 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + convert_element_type_276 = torch.ops.prims.convert_element_type.default(add_280, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_276, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_281 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_281); add_281 = None + mul_211 = torch.ops.aten.mul.Tensor(convert_element_type_276, rsqrt_17); convert_element_type_276 = None + mul_212 = torch.ops.aten.mul.Tensor(mul_211, wait_tensor_101); mul_211 = wait_tensor_101 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(mul_212, torch.bfloat16); mul_212 = None + view_326 = torch.ops.aten.view.default(convert_element_type_277, [-1, 2048]); convert_element_type_277 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_278, 64, '0'); convert_element_type_278 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + mm_43 = torch.ops.aten.mm.default(view_326, permute_79); permute_79 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(mm_43, torch.float32) + amax_4 = torch.ops.aten.amax.default(convert_element_type_281, [1], True) + sub_96 = torch.ops.aten.sub.Tensor(convert_element_type_281, amax_4); convert_element_type_281 = None + exp_13 = torch.ops.aten.exp.default(sub_96); sub_96 = None + sum_17 = torch.ops.aten.sum.dim_IntList(exp_13, [1], True) + div_21 = torch.ops.aten.div.Tensor(exp_13, sum_17); exp_13 = None + add_282 = torch.ops.aten.add.Tensor(div_21, primals_94); primals_94 = None + topk_4 = torch.ops.aten.topk.default(add_282, 6, -1, True, False); add_282 = None + getitem_75 = topk_4[1]; topk_4 = None + gather_4 = torch.ops.aten.gather.default(div_21, 1, getitem_75); div_21 = None + mul_213 = torch.ops.aten.mul.Tensor(gather_4, 1.0); gather_4 = None + view_328 = torch.ops.aten.view.default(getitem_75, [-1]) + histc_8 = torch.ops.aten.histc.default(view_328, 64, 0, 64) + add_283 = torch.ops.aten.add.Tensor(primals_96, histc_8) + sort_4 = torch.ops.aten.sort.stable(view_328, stable = True); view_328 = None + getitem_77 = sort_4[1]; sort_4 = None + div_22 = torch.ops.aten.div.Tensor_mode(getitem_77, 6, rounding_mode = 'floor') + index_8 = torch.ops.aten.index.Tensor(view_326, [div_22]) + all_to_all_single_12 = torch.ops._c10d_functional.all_to_all_single.default(histc_8, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_12); all_to_all_single_12 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_103); wait_tensor_103 = None + view_332 = torch.ops.aten.view.default(histc_8, [8, -1]); histc_8 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_332, [1]); view_332 = None + device_put_8 = torch.ops.prims.device_put.default(sum_18, device(type='cpu'), True); sum_18 = None + view_333 = torch.ops.aten.view.default(wait_tensor_104, [8, -1]) + sum_19 = torch.ops.aten.sum.dim_IntList(view_333, [1]) + device_put_9 = torch.ops.prims.device_put.default(sum_19, device(type='cpu')); sum_19 = None + select_64 = torch.ops.aten.select.int(device_put_8, 0, 0) + _local_scalar_dense_64 = torch.ops.aten._local_scalar_dense.default(select_64); select_64 = None + ge_80 = _local_scalar_dense_64 >= 0 + _assert_scalar_64 = torch.ops.aten._assert_scalar.default(ge_80, "Runtime assertion failed for expression u64 >= 0 on node 'ge_64'"); ge_80 = _assert_scalar_64 = None + select_65 = torch.ops.aten.select.int(device_put_8, 0, 1) + _local_scalar_dense_65 = torch.ops.aten._local_scalar_dense.default(select_65); select_65 = None + ge_81 = _local_scalar_dense_65 >= 0 + _assert_scalar_65 = torch.ops.aten._assert_scalar.default(ge_81, "Runtime assertion failed for expression u65 >= 0 on node 'ge_65'"); ge_81 = _assert_scalar_65 = None + select_66 = torch.ops.aten.select.int(device_put_8, 0, 2) + _local_scalar_dense_66 = torch.ops.aten._local_scalar_dense.default(select_66); select_66 = None + ge_82 = _local_scalar_dense_66 >= 0 + _assert_scalar_66 = torch.ops.aten._assert_scalar.default(ge_82, "Runtime assertion failed for expression u66 >= 0 on node 'ge_66'"); ge_82 = _assert_scalar_66 = None + select_67 = torch.ops.aten.select.int(device_put_8, 0, 3) + _local_scalar_dense_67 = torch.ops.aten._local_scalar_dense.default(select_67); select_67 = None + ge_83 = _local_scalar_dense_67 >= 0 + _assert_scalar_67 = torch.ops.aten._assert_scalar.default(ge_83, "Runtime assertion failed for expression u67 >= 0 on node 'ge_67'"); ge_83 = _assert_scalar_67 = None + select_68 = torch.ops.aten.select.int(device_put_8, 0, 4) + _local_scalar_dense_68 = torch.ops.aten._local_scalar_dense.default(select_68); select_68 = None + ge_84 = _local_scalar_dense_68 >= 0 + _assert_scalar_68 = torch.ops.aten._assert_scalar.default(ge_84, "Runtime assertion failed for expression u68 >= 0 on node 'ge_68'"); ge_84 = _assert_scalar_68 = None + select_69 = torch.ops.aten.select.int(device_put_8, 0, 5) + _local_scalar_dense_69 = torch.ops.aten._local_scalar_dense.default(select_69); select_69 = None + ge_85 = _local_scalar_dense_69 >= 0 + _assert_scalar_69 = torch.ops.aten._assert_scalar.default(ge_85, "Runtime assertion failed for expression u69 >= 0 on node 'ge_69'"); ge_85 = _assert_scalar_69 = None + select_70 = torch.ops.aten.select.int(device_put_8, 0, 6) + _local_scalar_dense_70 = torch.ops.aten._local_scalar_dense.default(select_70); select_70 = None + ge_86 = _local_scalar_dense_70 >= 0 + _assert_scalar_70 = torch.ops.aten._assert_scalar.default(ge_86, "Runtime assertion failed for expression u70 >= 0 on node 'ge_70'"); ge_86 = _assert_scalar_70 = None + select_71 = torch.ops.aten.select.int(device_put_8, 0, 7); device_put_8 = None + _local_scalar_dense_71 = torch.ops.aten._local_scalar_dense.default(select_71); select_71 = None + ge_87 = _local_scalar_dense_71 >= 0 + _assert_scalar_71 = torch.ops.aten._assert_scalar.default(ge_87, "Runtime assertion failed for expression u71 >= 0 on node 'ge_71'"); ge_87 = _assert_scalar_71 = None + select_72 = torch.ops.aten.select.int(device_put_9, 0, 0) + _local_scalar_dense_72 = torch.ops.aten._local_scalar_dense.default(select_72); select_72 = None + ge_88 = _local_scalar_dense_72 >= 0 + _assert_scalar_72 = torch.ops.aten._assert_scalar.default(ge_88, "Runtime assertion failed for expression u72 >= 0 on node 'ge_72'"); ge_88 = _assert_scalar_72 = None + select_73 = torch.ops.aten.select.int(device_put_9, 0, 1) + _local_scalar_dense_73 = torch.ops.aten._local_scalar_dense.default(select_73); select_73 = None + ge_89 = _local_scalar_dense_73 >= 0 + _assert_scalar_73 = torch.ops.aten._assert_scalar.default(ge_89, "Runtime assertion failed for expression u73 >= 0 on node 'ge_73'"); ge_89 = _assert_scalar_73 = None + select_74 = torch.ops.aten.select.int(device_put_9, 0, 2) + _local_scalar_dense_74 = torch.ops.aten._local_scalar_dense.default(select_74); select_74 = None + ge_90 = _local_scalar_dense_74 >= 0 + _assert_scalar_74 = torch.ops.aten._assert_scalar.default(ge_90, "Runtime assertion failed for expression u74 >= 0 on node 'ge_74'"); ge_90 = _assert_scalar_74 = None + select_75 = torch.ops.aten.select.int(device_put_9, 0, 3) + _local_scalar_dense_75 = torch.ops.aten._local_scalar_dense.default(select_75); select_75 = None + ge_91 = _local_scalar_dense_75 >= 0 + _assert_scalar_75 = torch.ops.aten._assert_scalar.default(ge_91, "Runtime assertion failed for expression u75 >= 0 on node 'ge_75'"); ge_91 = _assert_scalar_75 = None + select_76 = torch.ops.aten.select.int(device_put_9, 0, 4) + _local_scalar_dense_76 = torch.ops.aten._local_scalar_dense.default(select_76); select_76 = None + ge_92 = _local_scalar_dense_76 >= 0 + _assert_scalar_76 = torch.ops.aten._assert_scalar.default(ge_92, "Runtime assertion failed for expression u76 >= 0 on node 'ge_76'"); ge_92 = _assert_scalar_76 = None + select_77 = torch.ops.aten.select.int(device_put_9, 0, 5) + _local_scalar_dense_77 = torch.ops.aten._local_scalar_dense.default(select_77); select_77 = None + ge_93 = _local_scalar_dense_77 >= 0 + _assert_scalar_77 = torch.ops.aten._assert_scalar.default(ge_93, "Runtime assertion failed for expression u77 >= 0 on node 'ge_77'"); ge_93 = _assert_scalar_77 = None + select_78 = torch.ops.aten.select.int(device_put_9, 0, 6) + _local_scalar_dense_78 = torch.ops.aten._local_scalar_dense.default(select_78); select_78 = None + ge_94 = _local_scalar_dense_78 >= 0 + _assert_scalar_78 = torch.ops.aten._assert_scalar.default(ge_94, "Runtime assertion failed for expression u78 >= 0 on node 'ge_78'"); ge_94 = _assert_scalar_78 = None + select_79 = torch.ops.aten.select.int(device_put_9, 0, 7); device_put_9 = None + _local_scalar_dense_79 = torch.ops.aten._local_scalar_dense.default(select_79); select_79 = None + ge_95 = _local_scalar_dense_79 >= 0 + _assert_scalar_79 = torch.ops.aten._assert_scalar.default(ge_95, "Runtime assertion failed for expression u79 >= 0 on node 'ge_79'"); ge_95 = _assert_scalar_79 = None + all_to_all_single_13 = torch.ops._c10d_functional.all_to_all_single.default(index_8, [_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79], [_local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71], '521'); index_8 = None + sym_size_int_16 = torch.ops.aten.sym_size.int(all_to_all_single_13, 0) + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_13); all_to_all_single_13 = None + sym_sum_8 = torch.sym_sum((_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79)) + add_290 = sym_sum_8 + 64; sym_sum_8 = None + add_291 = add_290 + 8; add_290 = None + sub_99 = add_291 - 1; add_291 = None + floordiv_4 = sub_99 // 8; sub_99 = None + mul_218 = floordiv_4 * 8; floordiv_4 = None + cumsum_12 = torch.ops.aten.cumsum.default(wait_tensor_104, 0) + sub_100 = torch.ops.aten.sub.Tensor(cumsum_12, wait_tensor_104); cumsum_12 = None + sum_20 = torch.ops.aten.sum.dim_IntList(view_333, [0]); view_333 = None + clamp_min_4 = torch.ops.aten.clamp_min.default(sum_20, 8); sum_20 = None + add_292 = torch.ops.aten.add.Tensor(clamp_min_4, 8); clamp_min_4 = None + sub_101 = torch.ops.aten.sub.Tensor(add_292, 1); add_292 = None + div_23 = torch.ops.aten.div.Tensor_mode(sub_101, 8, rounding_mode = 'floor'); sub_101 = None + mul_219 = torch.ops.aten.mul.Tensor(div_23, 8); div_23 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(mul_219, torch.int32); mul_219 = None + cumsum_13 = torch.ops.aten.cumsum.default(convert_element_type_284, 0) + sub_102 = torch.ops.aten.sub.Tensor(cumsum_13, convert_element_type_284); cumsum_13 = None + full_72 = torch.ops.aten.full.default([mul_218], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_218 = None + triton_kernel_wrapper_functional_proxy_4 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 4, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_104, 'start_index_values_ptr': sub_100, 'write_offsets_ptr': sub_102, 'output_ptr': full_72}, tensors_to_clone = ['output_ptr']); wait_tensor_104 = sub_100 = sub_102 = full_72 = None + getitem_78 = triton_kernel_wrapper_functional_proxy_4['output_ptr']; triton_kernel_wrapper_functional_proxy_4 = None + cat_16 = torch.ops.aten.cat.default([wait_tensor_105, full_default]); wait_tensor_105 = None + sym_size_int_17 = torch.ops.aten.sym_size.int(cat_16, 0) + sym_sum_9 = torch.sym_sum((1, _local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79)) + index_9 = torch.ops.aten.index.Tensor(cat_16, [getitem_78]); cat_16 = None + convert_element_type_286 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '513'); convert_element_type_286 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + convert_element_type_288 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_288, 8, '513'); convert_element_type_288 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + convert_element_type_289 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_289, 8, '513'); convert_element_type_289 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + cumsum_14 = torch.ops.aten.cumsum.default(convert_element_type_284, 0, dtype = torch.int32); convert_element_type_284 = None + permute_80 = torch.ops.aten.permute.default(wait_tensor_106, [0, 2, 1]); wait_tensor_106 = None + _grouped_mm_12 = torch.ops.aten._grouped_mm.default(index_9, permute_80, cumsum_14); permute_80 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(_grouped_mm_12, torch.float32) + neg_9 = torch.ops.aten.neg.default(convert_element_type_292) + exp_14 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_304 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + div_24 = torch.ops.aten.div.Tensor(convert_element_type_292, add_304); convert_element_type_292 = add_304 = None + convert_element_type_293 = torch.ops.prims.convert_element_type.default(div_24, torch.bfloat16); div_24 = None + permute_81 = torch.ops.aten.permute.default(wait_tensor_109, [0, 2, 1]); wait_tensor_109 = None + _grouped_mm_13 = torch.ops.aten._grouped_mm.default(index_9, permute_81, cumsum_14); permute_81 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_293, _grouped_mm_13); convert_element_type_293 = None + permute_82 = torch.ops.aten.permute.default(wait_tensor_108, [0, 2, 1]); wait_tensor_108 = None + _grouped_mm_14 = torch.ops.aten._grouped_mm.default(mul_231, permute_82, cumsum_14); permute_82 = None + empty_4 = torch.ops.aten.empty.memory_format([sym_size_int_17, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_8 = torch.ops.aten.index_put.default(empty_4, [getitem_78], _grouped_mm_14); empty_4 = _grouped_mm_14 = None + slice_22 = torch.ops.aten.slice.Tensor(index_put_8, 0, 0, -1); index_put_8 = None + all_to_all_single_14 = torch.ops._c10d_functional.all_to_all_single.default(slice_22, [_local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71], [_local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79], '521'); slice_22 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_14); all_to_all_single_14 = None + convert_element_type_294 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_294, 64, '0'); convert_element_type_294 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_83 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + mm_44 = torch.ops.aten.mm.default(view_326, permute_83); permute_83 = None + convert_element_type_297 = torch.ops.prims.convert_element_type.default(mm_44, torch.float32) + neg_10 = torch.ops.aten.neg.default(convert_element_type_297) + exp_15 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_340 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + div_25 = torch.ops.aten.div.Tensor(convert_element_type_297, add_340); convert_element_type_297 = add_340 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(div_25, torch.bfloat16); div_25 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_299, 64, '0'); convert_element_type_299 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_114, [1, 0]); wait_tensor_114 = None + mm_45 = torch.ops.aten.mm.default(view_326, permute_84); permute_84 = None + mul_251 = torch.ops.aten.mul.Tensor(convert_element_type_298, mm_45); convert_element_type_298 = None + convert_element_type_302 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_302, 64, '0'); convert_element_type_302 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + mm_46 = torch.ops.aten.mm.default(mul_251, permute_85); permute_85 = None + index_put_9 = torch.ops.aten.index_put.default(full_default_1, [getitem_77], wait_tensor_112); wait_tensor_112 = None + view_366 = torch.ops.aten.view.default(mul_213, [-1, 1, 6]); mul_213 = None + view_367 = torch.ops.aten.view.default(index_put_9, [-1, 6, 2048]); index_put_9 = None + convert_element_type_305 = torch.ops.prims.convert_element_type.default(view_367, torch.float32); view_367 = None + bmm_4 = torch.ops.aten.bmm.default(view_366, convert_element_type_305) + convert_element_type_306 = torch.ops.prims.convert_element_type.default(bmm_4, torch.bfloat16); bmm_4 = None + squeeze_4 = torch.ops.aten.squeeze.dim(convert_element_type_306, 1); convert_element_type_306 = None + add_344 = torch.ops.aten.add.Tensor(mm_46, squeeze_4); mm_46 = squeeze_4 = None + view_368 = torch.ops.aten.view.default(add_344, [2, 4096, 2048]); add_344 = None + add_345 = torch.ops.aten.add.Tensor(add_280, view_368); view_368 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 64, '0'); convert_element_type_307 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_308 = torch.ops.prims.convert_element_type.default(add_345, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_308, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_346 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_346); add_346 = None + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_308, rsqrt_18); convert_element_type_308 = None + mul_255 = torch.ops.aten.mul.Tensor(mul_254, wait_tensor_116); mul_254 = wait_tensor_116 = None + convert_element_type_309 = torch.ops.prims.convert_element_type.default(mul_255, torch.bfloat16); mul_255 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_310, 64, '0'); convert_element_type_310 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_371 = torch.ops.aten.view.default(convert_element_type_309, [8192, 2048]); convert_element_type_309 = None + mm_47 = torch.ops.aten.mm.default(view_371, permute_86); permute_86 = None + view_372 = torch.ops.aten.view.default(mm_47, [2, 4096, 3072]); mm_47 = None + view_373 = torch.ops.aten.view.default(view_372, [2, 4096, -1, 192]); view_372 = None + split_with_sizes_18 = torch.ops.aten.split_with_sizes.default(view_373, [128, 64], -1); view_373 = None + getitem_79 = split_with_sizes_18[0] + getitem_80 = split_with_sizes_18[1]; split_with_sizes_18 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(getitem_80, torch.float32); getitem_80 = None + view_374 = torch.ops.aten.view.default(convert_element_type_313, [2, 4096, 16, -1, 2]); convert_element_type_313 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_374); view_374 = None + mul_256 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_7); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_256); mul_256 = None + view_376 = torch.ops.aten.view.default(view_as_real_12, [2, 4096, 16, 64]); view_as_real_12 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(view_376, torch.bfloat16); view_376 = None + cat_17 = torch.ops.aten.cat.default([getitem_79, convert_element_type_314], -1); getitem_79 = convert_element_type_314 = None + convert_element_type_315 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_315, 64, '0'); convert_element_type_315 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_118, [1, 0]); wait_tensor_118 = None + mm_48 = torch.ops.aten.mm.default(view_371, permute_87); permute_87 = None + view_379 = torch.ops.aten.view.default(mm_48, [2, 4096, 576]); mm_48 = None + split_with_sizes_19 = torch.ops.aten.split_with_sizes.default(view_379, [512, 64], -1); view_379 = None + getitem_81 = split_with_sizes_19[0] + getitem_82 = split_with_sizes_19[1]; split_with_sizes_19 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(getitem_82, 2); getitem_82 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(unsqueeze_11, torch.float32); unsqueeze_11 = None + view_380 = torch.ops.aten.view.default(convert_element_type_318, [2, 4096, 1, -1, 2]); convert_element_type_318 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_380); view_380 = None + mul_257 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_7); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_257); mul_257 = None + view_382 = torch.ops.aten.view.default(view_as_real_13, [2, 4096, 1, 64]); view_as_real_13 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(view_382, torch.bfloat16); view_382 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 64, '0'); convert_element_type_320 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + convert_element_type_321 = torch.ops.prims.convert_element_type.default(getitem_81, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_321, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_347 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_347); add_347 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_321, rsqrt_19); convert_element_type_321 = None + mul_259 = torch.ops.aten.mul.Tensor(mul_258, wait_tensor_119); mul_258 = wait_tensor_119 = None + convert_element_type_322 = torch.ops.prims.convert_element_type.default(mul_259, torch.bfloat16); mul_259 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_323, 64, '0'); convert_element_type_323 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + view_385 = torch.ops.aten.view.default(convert_element_type_322, [8192, 512]); convert_element_type_322 = None + mm_49 = torch.ops.aten.mm.default(view_385, permute_88); permute_88 = None + view_386 = torch.ops.aten.view.default(mm_49, [2, 4096, 4096]); mm_49 = None + view_387 = torch.ops.aten.view.default(view_386, [2, 4096, -1, 256]); view_386 = None + split_with_sizes_20 = torch.ops.aten.split_with_sizes.default(view_387, [128, 128], -1); view_387 = None + getitem_83 = split_with_sizes_20[0] + getitem_84 = split_with_sizes_20[1]; split_with_sizes_20 = None + expand_6 = torch.ops.aten.expand.default(convert_element_type_319, [-1, -1, 16, -1]); convert_element_type_319 = None + cat_18 = torch.ops.aten.cat.default([getitem_83, expand_6], -1); getitem_83 = expand_6 = None + permute_89 = torch.ops.aten.permute.default(cat_17, [0, 2, 1, 3]); cat_17 = None + permute_90 = torch.ops.aten.permute.default(cat_18, [0, 2, 1, 3]); cat_18 = None + permute_91 = torch.ops.aten.permute.default(getitem_84, [0, 2, 1, 3]); getitem_84 = None + sdpa_score6 = self.sdpa_score6 + sdpa_mask6 = self.sdpa_mask6 + flex_attention_6 = torch.ops.higher_order.flex_attention(permute_89, permute_90, permute_91, sdpa_score6, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask6), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score6 = sdpa_mask6 = None + getitem_85 = flex_attention_6[0] + getitem_86 = flex_attention_6[1]; flex_attention_6 = None + permute_92 = torch.ops.aten.permute.default(getitem_85, [0, 2, 1, 3]) + view_388 = torch.ops.aten.view.default(permute_92, [2, 4096, -1]); permute_92 = None + convert_element_type_326 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_326, 64, '0'); convert_element_type_326 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_93 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + view_390 = torch.ops.aten.view.default(view_388, [8192, 2048]); view_388 = None + mm_50 = torch.ops.aten.mm.default(view_390, permute_93); view_390 = permute_93 = None + view_391 = torch.ops.aten.view.default(mm_50, [2, 4096, 2048]); mm_50 = None + add_348 = torch.ops.aten.add.Tensor(add_345, view_391); view_391 = None + convert_element_type_329 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_329, 64, '0'); convert_element_type_329 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + convert_element_type_330 = torch.ops.prims.convert_element_type.default(add_348, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_330, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_349 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_349); add_349 = None + mul_260 = torch.ops.aten.mul.Tensor(convert_element_type_330, rsqrt_20); convert_element_type_330 = None + mul_261 = torch.ops.aten.mul.Tensor(mul_260, wait_tensor_122); mul_260 = wait_tensor_122 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(mul_261, torch.bfloat16); mul_261 = None + view_393 = torch.ops.aten.view.default(convert_element_type_331, [-1, 2048]); convert_element_type_331 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_332, 64, '0'); convert_element_type_332 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_94 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + mm_51 = torch.ops.aten.mm.default(view_393, permute_94); permute_94 = None + convert_element_type_335 = torch.ops.prims.convert_element_type.default(mm_51, torch.float32) + amax_5 = torch.ops.aten.amax.default(convert_element_type_335, [1], True) + sub_120 = torch.ops.aten.sub.Tensor(convert_element_type_335, amax_5); convert_element_type_335 = None + exp_16 = torch.ops.aten.exp.default(sub_120); sub_120 = None + sum_21 = torch.ops.aten.sum.dim_IntList(exp_16, [1], True) + div_26 = torch.ops.aten.div.Tensor(exp_16, sum_21); exp_16 = None + add_350 = torch.ops.aten.add.Tensor(div_26, primals_110); primals_110 = None + topk_5 = torch.ops.aten.topk.default(add_350, 6, -1, True, False); add_350 = None + getitem_89 = topk_5[1]; topk_5 = None + gather_5 = torch.ops.aten.gather.default(div_26, 1, getitem_89); div_26 = None + mul_262 = torch.ops.aten.mul.Tensor(gather_5, 1.0); gather_5 = None + view_395 = torch.ops.aten.view.default(getitem_89, [-1]) + histc_10 = torch.ops.aten.histc.default(view_395, 64, 0, 64) + add_351 = torch.ops.aten.add.Tensor(primals_112, histc_10) + sort_5 = torch.ops.aten.sort.stable(view_395, stable = True); view_395 = None + getitem_91 = sort_5[1]; sort_5 = None + div_27 = torch.ops.aten.div.Tensor_mode(getitem_91, 6, rounding_mode = 'floor') + index_10 = torch.ops.aten.index.Tensor(view_393, [div_27]) + all_to_all_single_15 = torch.ops._c10d_functional.all_to_all_single.default(histc_10, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_15); all_to_all_single_15 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_124); wait_tensor_124 = None + view_399 = torch.ops.aten.view.default(histc_10, [8, -1]); histc_10 = None + sum_22 = torch.ops.aten.sum.dim_IntList(view_399, [1]); view_399 = None + device_put_10 = torch.ops.prims.device_put.default(sum_22, device(type='cpu'), True); sum_22 = None + view_400 = torch.ops.aten.view.default(wait_tensor_125, [8, -1]) + sum_23 = torch.ops.aten.sum.dim_IntList(view_400, [1]) + device_put_11 = torch.ops.prims.device_put.default(sum_23, device(type='cpu')); sum_23 = None + select_80 = torch.ops.aten.select.int(device_put_10, 0, 0) + _local_scalar_dense_80 = torch.ops.aten._local_scalar_dense.default(select_80); select_80 = None + ge_100 = _local_scalar_dense_80 >= 0 + _assert_scalar_80 = torch.ops.aten._assert_scalar.default(ge_100, "Runtime assertion failed for expression u80 >= 0 on node 'ge_80'"); ge_100 = _assert_scalar_80 = None + select_81 = torch.ops.aten.select.int(device_put_10, 0, 1) + _local_scalar_dense_81 = torch.ops.aten._local_scalar_dense.default(select_81); select_81 = None + ge_101 = _local_scalar_dense_81 >= 0 + _assert_scalar_81 = torch.ops.aten._assert_scalar.default(ge_101, "Runtime assertion failed for expression u81 >= 0 on node 'ge_81'"); ge_101 = _assert_scalar_81 = None + select_82 = torch.ops.aten.select.int(device_put_10, 0, 2) + _local_scalar_dense_82 = torch.ops.aten._local_scalar_dense.default(select_82); select_82 = None + ge_102 = _local_scalar_dense_82 >= 0 + _assert_scalar_82 = torch.ops.aten._assert_scalar.default(ge_102, "Runtime assertion failed for expression u82 >= 0 on node 'ge_82'"); ge_102 = _assert_scalar_82 = None + select_83 = torch.ops.aten.select.int(device_put_10, 0, 3) + _local_scalar_dense_83 = torch.ops.aten._local_scalar_dense.default(select_83); select_83 = None + ge_103 = _local_scalar_dense_83 >= 0 + _assert_scalar_83 = torch.ops.aten._assert_scalar.default(ge_103, "Runtime assertion failed for expression u83 >= 0 on node 'ge_83'"); ge_103 = _assert_scalar_83 = None + select_84 = torch.ops.aten.select.int(device_put_10, 0, 4) + _local_scalar_dense_84 = torch.ops.aten._local_scalar_dense.default(select_84); select_84 = None + ge_104 = _local_scalar_dense_84 >= 0 + _assert_scalar_84 = torch.ops.aten._assert_scalar.default(ge_104, "Runtime assertion failed for expression u84 >= 0 on node 'ge_84'"); ge_104 = _assert_scalar_84 = None + select_85 = torch.ops.aten.select.int(device_put_10, 0, 5) + _local_scalar_dense_85 = torch.ops.aten._local_scalar_dense.default(select_85); select_85 = None + ge_105 = _local_scalar_dense_85 >= 0 + _assert_scalar_85 = torch.ops.aten._assert_scalar.default(ge_105, "Runtime assertion failed for expression u85 >= 0 on node 'ge_85'"); ge_105 = _assert_scalar_85 = None + select_86 = torch.ops.aten.select.int(device_put_10, 0, 6) + _local_scalar_dense_86 = torch.ops.aten._local_scalar_dense.default(select_86); select_86 = None + ge_106 = _local_scalar_dense_86 >= 0 + _assert_scalar_86 = torch.ops.aten._assert_scalar.default(ge_106, "Runtime assertion failed for expression u86 >= 0 on node 'ge_86'"); ge_106 = _assert_scalar_86 = None + select_87 = torch.ops.aten.select.int(device_put_10, 0, 7); device_put_10 = None + _local_scalar_dense_87 = torch.ops.aten._local_scalar_dense.default(select_87); select_87 = None + ge_107 = _local_scalar_dense_87 >= 0 + _assert_scalar_87 = torch.ops.aten._assert_scalar.default(ge_107, "Runtime assertion failed for expression u87 >= 0 on node 'ge_87'"); ge_107 = _assert_scalar_87 = None + select_88 = torch.ops.aten.select.int(device_put_11, 0, 0) + _local_scalar_dense_88 = torch.ops.aten._local_scalar_dense.default(select_88); select_88 = None + ge_108 = _local_scalar_dense_88 >= 0 + _assert_scalar_88 = torch.ops.aten._assert_scalar.default(ge_108, "Runtime assertion failed for expression u88 >= 0 on node 'ge_88'"); ge_108 = _assert_scalar_88 = None + select_89 = torch.ops.aten.select.int(device_put_11, 0, 1) + _local_scalar_dense_89 = torch.ops.aten._local_scalar_dense.default(select_89); select_89 = None + ge_109 = _local_scalar_dense_89 >= 0 + _assert_scalar_89 = torch.ops.aten._assert_scalar.default(ge_109, "Runtime assertion failed for expression u89 >= 0 on node 'ge_89'"); ge_109 = _assert_scalar_89 = None + select_90 = torch.ops.aten.select.int(device_put_11, 0, 2) + _local_scalar_dense_90 = torch.ops.aten._local_scalar_dense.default(select_90); select_90 = None + ge_110 = _local_scalar_dense_90 >= 0 + _assert_scalar_90 = torch.ops.aten._assert_scalar.default(ge_110, "Runtime assertion failed for expression u90 >= 0 on node 'ge_90'"); ge_110 = _assert_scalar_90 = None + select_91 = torch.ops.aten.select.int(device_put_11, 0, 3) + _local_scalar_dense_91 = torch.ops.aten._local_scalar_dense.default(select_91); select_91 = None + ge_111 = _local_scalar_dense_91 >= 0 + _assert_scalar_91 = torch.ops.aten._assert_scalar.default(ge_111, "Runtime assertion failed for expression u91 >= 0 on node 'ge_91'"); ge_111 = _assert_scalar_91 = None + select_92 = torch.ops.aten.select.int(device_put_11, 0, 4) + _local_scalar_dense_92 = torch.ops.aten._local_scalar_dense.default(select_92); select_92 = None + ge_112 = _local_scalar_dense_92 >= 0 + _assert_scalar_92 = torch.ops.aten._assert_scalar.default(ge_112, "Runtime assertion failed for expression u92 >= 0 on node 'ge_92'"); ge_112 = _assert_scalar_92 = None + select_93 = torch.ops.aten.select.int(device_put_11, 0, 5) + _local_scalar_dense_93 = torch.ops.aten._local_scalar_dense.default(select_93); select_93 = None + ge_113 = _local_scalar_dense_93 >= 0 + _assert_scalar_93 = torch.ops.aten._assert_scalar.default(ge_113, "Runtime assertion failed for expression u93 >= 0 on node 'ge_93'"); ge_113 = _assert_scalar_93 = None + select_94 = torch.ops.aten.select.int(device_put_11, 0, 6) + _local_scalar_dense_94 = torch.ops.aten._local_scalar_dense.default(select_94); select_94 = None + ge_114 = _local_scalar_dense_94 >= 0 + _assert_scalar_94 = torch.ops.aten._assert_scalar.default(ge_114, "Runtime assertion failed for expression u94 >= 0 on node 'ge_94'"); ge_114 = _assert_scalar_94 = None + select_95 = torch.ops.aten.select.int(device_put_11, 0, 7); device_put_11 = None + _local_scalar_dense_95 = torch.ops.aten._local_scalar_dense.default(select_95); select_95 = None + ge_115 = _local_scalar_dense_95 >= 0 + _assert_scalar_95 = torch.ops.aten._assert_scalar.default(ge_115, "Runtime assertion failed for expression u95 >= 0 on node 'ge_95'"); ge_115 = _assert_scalar_95 = None + all_to_all_single_16 = torch.ops._c10d_functional.all_to_all_single.default(index_10, [_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95], [_local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87], '521'); index_10 = None + sym_size_int_20 = torch.ops.aten.sym_size.int(all_to_all_single_16, 0) + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_16); all_to_all_single_16 = None + sym_sum_10 = torch.sym_sum((_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95)) + add_358 = sym_sum_10 + 64; sym_sum_10 = None + add_359 = add_358 + 8; add_358 = None + sub_123 = add_359 - 1; add_359 = None + floordiv_5 = sub_123 // 8; sub_123 = None + mul_267 = floordiv_5 * 8; floordiv_5 = None + cumsum_15 = torch.ops.aten.cumsum.default(wait_tensor_125, 0) + sub_124 = torch.ops.aten.sub.Tensor(cumsum_15, wait_tensor_125); cumsum_15 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_400, [0]); view_400 = None + clamp_min_5 = torch.ops.aten.clamp_min.default(sum_24, 8); sum_24 = None + add_360 = torch.ops.aten.add.Tensor(clamp_min_5, 8); clamp_min_5 = None + sub_125 = torch.ops.aten.sub.Tensor(add_360, 1); add_360 = None + div_28 = torch.ops.aten.div.Tensor_mode(sub_125, 8, rounding_mode = 'floor'); sub_125 = None + mul_268 = torch.ops.aten.mul.Tensor(div_28, 8); div_28 = None + convert_element_type_338 = torch.ops.prims.convert_element_type.default(mul_268, torch.int32); mul_268 = None + cumsum_16 = torch.ops.aten.cumsum.default(convert_element_type_338, 0) + sub_126 = torch.ops.aten.sub.Tensor(cumsum_16, convert_element_type_338); cumsum_16 = None + full_85 = torch.ops.aten.full.default([mul_267], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_267 = None + triton_kernel_wrapper_functional_proxy_5 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 5, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_125, 'start_index_values_ptr': sub_124, 'write_offsets_ptr': sub_126, 'output_ptr': full_85}, tensors_to_clone = ['output_ptr']); wait_tensor_125 = sub_124 = sub_126 = full_85 = None + getitem_92 = triton_kernel_wrapper_functional_proxy_5['output_ptr']; triton_kernel_wrapper_functional_proxy_5 = None + cat_19 = torch.ops.aten.cat.default([wait_tensor_126, full_default]); wait_tensor_126 = None + sym_size_int_21 = torch.ops.aten.sym_size.int(cat_19, 0) + sym_sum_11 = torch.sym_sum((1, _local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95)) + index_11 = torch.ops.aten.index.Tensor(cat_19, [getitem_92]); cat_19 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 8, '513'); convert_element_type_340 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + convert_element_type_342 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_342, 8, '513'); convert_element_type_342 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_343, 8, '513'); convert_element_type_343 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + cumsum_17 = torch.ops.aten.cumsum.default(convert_element_type_338, 0, dtype = torch.int32); convert_element_type_338 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_127, [0, 2, 1]); wait_tensor_127 = None + _grouped_mm_15 = torch.ops.aten._grouped_mm.default(index_11, permute_95, cumsum_17); permute_95 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(_grouped_mm_15, torch.float32) + neg_11 = torch.ops.aten.neg.default(convert_element_type_346) + exp_17 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_372 = torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + div_29 = torch.ops.aten.div.Tensor(convert_element_type_346, add_372); convert_element_type_346 = add_372 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(div_29, torch.bfloat16); div_29 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_130, [0, 2, 1]); wait_tensor_130 = None + _grouped_mm_16 = torch.ops.aten._grouped_mm.default(index_11, permute_96, cumsum_17); permute_96 = None + mul_280 = torch.ops.aten.mul.Tensor(convert_element_type_347, _grouped_mm_16); convert_element_type_347 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_129, [0, 2, 1]); wait_tensor_129 = None + _grouped_mm_17 = torch.ops.aten._grouped_mm.default(mul_280, permute_97, cumsum_17); permute_97 = None + empty_5 = torch.ops.aten.empty.memory_format([sym_size_int_21, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_10 = torch.ops.aten.index_put.default(empty_5, [getitem_92], _grouped_mm_17); empty_5 = _grouped_mm_17 = None + slice_26 = torch.ops.aten.slice.Tensor(index_put_10, 0, 0, -1); index_put_10 = None + all_to_all_single_17 = torch.ops._c10d_functional.all_to_all_single.default(slice_26, [_local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87], [_local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95], '521'); slice_26 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_17); all_to_all_single_17 = None + convert_element_type_348 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_348, 64, '0'); convert_element_type_348 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_52 = torch.ops.aten.mm.default(view_393, permute_98); permute_98 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(mm_52, torch.float32) + neg_12 = torch.ops.aten.neg.default(convert_element_type_351) + exp_18 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_408 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + div_30 = torch.ops.aten.div.Tensor(convert_element_type_351, add_408); convert_element_type_351 = add_408 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(div_30, torch.bfloat16); div_30 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 64, '0'); convert_element_type_353 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_53 = torch.ops.aten.mm.default(view_393, permute_99); permute_99 = None + mul_300 = torch.ops.aten.mul.Tensor(convert_element_type_352, mm_53); convert_element_type_352 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_356, 64, '0'); convert_element_type_356 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + mm_54 = torch.ops.aten.mm.default(mul_300, permute_100); permute_100 = None + index_put_11 = torch.ops.aten.index_put.default(full_default_1, [getitem_91], wait_tensor_133); wait_tensor_133 = None + view_433 = torch.ops.aten.view.default(mul_262, [-1, 1, 6]); mul_262 = None + view_434 = torch.ops.aten.view.default(index_put_11, [-1, 6, 2048]); index_put_11 = None + convert_element_type_359 = torch.ops.prims.convert_element_type.default(view_434, torch.float32); view_434 = None + bmm_5 = torch.ops.aten.bmm.default(view_433, convert_element_type_359) + convert_element_type_360 = torch.ops.prims.convert_element_type.default(bmm_5, torch.bfloat16); bmm_5 = None + squeeze_5 = torch.ops.aten.squeeze.dim(convert_element_type_360, 1); convert_element_type_360 = None + add_412 = torch.ops.aten.add.Tensor(mm_54, squeeze_5); mm_54 = squeeze_5 = None + view_435 = torch.ops.aten.view.default(add_412, [2, 4096, 2048]); add_412 = None + add_413 = torch.ops.aten.add.Tensor(add_348, view_435); view_435 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 64, '0'); convert_element_type_361 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + convert_element_type_362 = torch.ops.prims.convert_element_type.default(add_413, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_362, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_414 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_414); add_414 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_362, rsqrt_21); convert_element_type_362 = None + mul_304 = torch.ops.aten.mul.Tensor(mul_303, wait_tensor_137); mul_303 = wait_tensor_137 = None + convert_element_type_363 = torch.ops.prims.convert_element_type.default(mul_304, torch.bfloat16); mul_304 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 64, '0'); convert_element_type_364 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + view_438 = torch.ops.aten.view.default(convert_element_type_363, [8192, 2048]); convert_element_type_363 = None + mm_55 = torch.ops.aten.mm.default(view_438, permute_101); permute_101 = None + view_439 = torch.ops.aten.view.default(mm_55, [2, 4096, 3072]); mm_55 = None + view_440 = torch.ops.aten.view.default(view_439, [2, 4096, -1, 192]); view_439 = None + split_with_sizes_21 = torch.ops.aten.split_with_sizes.default(view_440, [128, 64], -1); view_440 = None + getitem_93 = split_with_sizes_21[0] + getitem_94 = split_with_sizes_21[1]; split_with_sizes_21 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(getitem_94, torch.float32); getitem_94 = None + view_441 = torch.ops.aten.view.default(convert_element_type_367, [2, 4096, 16, -1, 2]); convert_element_type_367 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_441); view_441 = None + mul_305 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_7); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_305); mul_305 = None + view_443 = torch.ops.aten.view.default(view_as_real_14, [2, 4096, 16, 64]); view_as_real_14 = None + convert_element_type_368 = torch.ops.prims.convert_element_type.default(view_443, torch.bfloat16); view_443 = None + cat_20 = torch.ops.aten.cat.default([getitem_93, convert_element_type_368], -1); getitem_93 = convert_element_type_368 = None + convert_element_type_369 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_369, 64, '0'); convert_element_type_369 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_102 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + mm_56 = torch.ops.aten.mm.default(view_438, permute_102); permute_102 = None + view_446 = torch.ops.aten.view.default(mm_56, [2, 4096, 576]); mm_56 = None + split_with_sizes_22 = torch.ops.aten.split_with_sizes.default(view_446, [512, 64], -1); view_446 = None + getitem_95 = split_with_sizes_22[0] + getitem_96 = split_with_sizes_22[1]; split_with_sizes_22 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(getitem_96, 2); getitem_96 = None + convert_element_type_372 = torch.ops.prims.convert_element_type.default(unsqueeze_13, torch.float32); unsqueeze_13 = None + view_447 = torch.ops.aten.view.default(convert_element_type_372, [2, 4096, 1, -1, 2]); convert_element_type_372 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_447); view_447 = None + mul_306 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_7); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_306); mul_306 = None + view_449 = torch.ops.aten.view.default(view_as_real_15, [2, 4096, 1, 64]); view_as_real_15 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(view_449, torch.bfloat16); view_449 = None + convert_element_type_374 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_374, 64, '0'); convert_element_type_374 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + convert_element_type_375 = torch.ops.prims.convert_element_type.default(getitem_95, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_375, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_415 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_415); add_415 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_375, rsqrt_22); convert_element_type_375 = None + mul_308 = torch.ops.aten.mul.Tensor(mul_307, wait_tensor_140); mul_307 = wait_tensor_140 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(mul_308, torch.bfloat16); mul_308 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_377, 64, '0'); convert_element_type_377 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_103 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + view_452 = torch.ops.aten.view.default(convert_element_type_376, [8192, 512]); convert_element_type_376 = None + mm_57 = torch.ops.aten.mm.default(view_452, permute_103); permute_103 = None + view_453 = torch.ops.aten.view.default(mm_57, [2, 4096, 4096]); mm_57 = None + view_454 = torch.ops.aten.view.default(view_453, [2, 4096, -1, 256]); view_453 = None + split_with_sizes_23 = torch.ops.aten.split_with_sizes.default(view_454, [128, 128], -1); view_454 = None + getitem_97 = split_with_sizes_23[0] + getitem_98 = split_with_sizes_23[1]; split_with_sizes_23 = None + expand_7 = torch.ops.aten.expand.default(convert_element_type_373, [-1, -1, 16, -1]); convert_element_type_373 = None + cat_21 = torch.ops.aten.cat.default([getitem_97, expand_7], -1); getitem_97 = expand_7 = None + permute_104 = torch.ops.aten.permute.default(cat_20, [0, 2, 1, 3]); cat_20 = None + permute_105 = torch.ops.aten.permute.default(cat_21, [0, 2, 1, 3]); cat_21 = None + permute_106 = torch.ops.aten.permute.default(getitem_98, [0, 2, 1, 3]); getitem_98 = None + sdpa_score7 = self.sdpa_score7 + sdpa_mask7 = self.sdpa_mask7 + flex_attention_7 = torch.ops.higher_order.flex_attention(permute_104, permute_105, permute_106, sdpa_score7, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask7), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score7 = sdpa_mask7 = None + getitem_99 = flex_attention_7[0] + getitem_100 = flex_attention_7[1]; flex_attention_7 = None + permute_107 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_455 = torch.ops.aten.view.default(permute_107, [2, 4096, -1]); permute_107 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 64, '0'); convert_element_type_380 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + view_457 = torch.ops.aten.view.default(view_455, [8192, 2048]); view_455 = None + mm_58 = torch.ops.aten.mm.default(view_457, permute_108); view_457 = permute_108 = None + view_458 = torch.ops.aten.view.default(mm_58, [2, 4096, 2048]); mm_58 = None + add_416 = torch.ops.aten.add.Tensor(add_413, view_458); view_458 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 64, '0'); convert_element_type_383 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_416, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_417 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_417); add_417 = None + mul_309 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_310 = torch.ops.aten.mul.Tensor(mul_309, wait_tensor_143); mul_309 = wait_tensor_143 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_310, torch.bfloat16); mul_310 = None + view_460 = torch.ops.aten.view.default(convert_element_type_385, [-1, 2048]); convert_element_type_385 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 64, '0'); convert_element_type_386 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + mm_59 = torch.ops.aten.mm.default(view_460, permute_109); permute_109 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(mm_59, torch.float32) + amax_6 = torch.ops.aten.amax.default(convert_element_type_389, [1], True) + sub_144 = torch.ops.aten.sub.Tensor(convert_element_type_389, amax_6); convert_element_type_389 = None + exp_19 = torch.ops.aten.exp.default(sub_144); sub_144 = None + sum_25 = torch.ops.aten.sum.dim_IntList(exp_19, [1], True) + div_31 = torch.ops.aten.div.Tensor(exp_19, sum_25); exp_19 = None + add_418 = torch.ops.aten.add.Tensor(div_31, primals_126); primals_126 = None + topk_6 = torch.ops.aten.topk.default(add_418, 6, -1, True, False); add_418 = None + getitem_103 = topk_6[1]; topk_6 = None + gather_6 = torch.ops.aten.gather.default(div_31, 1, getitem_103); div_31 = None + mul_311 = torch.ops.aten.mul.Tensor(gather_6, 1.0); gather_6 = None + view_462 = torch.ops.aten.view.default(getitem_103, [-1]) + histc_12 = torch.ops.aten.histc.default(view_462, 64, 0, 64) + add_419 = torch.ops.aten.add.Tensor(primals_128, histc_12) + sort_6 = torch.ops.aten.sort.stable(view_462, stable = True); view_462 = None + getitem_105 = sort_6[1]; sort_6 = None + div_32 = torch.ops.aten.div.Tensor_mode(getitem_105, 6, rounding_mode = 'floor') + index_12 = torch.ops.aten.index.Tensor(view_460, [div_32]) + all_to_all_single_18 = torch.ops._c10d_functional.all_to_all_single.default(histc_12, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_18); all_to_all_single_18 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_145); wait_tensor_145 = None + view_466 = torch.ops.aten.view.default(histc_12, [8, -1]); histc_12 = None + sum_26 = torch.ops.aten.sum.dim_IntList(view_466, [1]); view_466 = None + device_put_12 = torch.ops.prims.device_put.default(sum_26, device(type='cpu'), True); sum_26 = None + view_467 = torch.ops.aten.view.default(wait_tensor_146, [8, -1]) + sum_27 = torch.ops.aten.sum.dim_IntList(view_467, [1]) + device_put_13 = torch.ops.prims.device_put.default(sum_27, device(type='cpu')); sum_27 = None + select_96 = torch.ops.aten.select.int(device_put_12, 0, 0) + _local_scalar_dense_96 = torch.ops.aten._local_scalar_dense.default(select_96); select_96 = None + ge_120 = _local_scalar_dense_96 >= 0 + _assert_scalar_96 = torch.ops.aten._assert_scalar.default(ge_120, "Runtime assertion failed for expression u96 >= 0 on node 'ge_96'"); ge_120 = _assert_scalar_96 = None + select_97 = torch.ops.aten.select.int(device_put_12, 0, 1) + _local_scalar_dense_97 = torch.ops.aten._local_scalar_dense.default(select_97); select_97 = None + ge_121 = _local_scalar_dense_97 >= 0 + _assert_scalar_97 = torch.ops.aten._assert_scalar.default(ge_121, "Runtime assertion failed for expression u97 >= 0 on node 'ge_97'"); ge_121 = _assert_scalar_97 = None + select_98 = torch.ops.aten.select.int(device_put_12, 0, 2) + _local_scalar_dense_98 = torch.ops.aten._local_scalar_dense.default(select_98); select_98 = None + ge_122 = _local_scalar_dense_98 >= 0 + _assert_scalar_98 = torch.ops.aten._assert_scalar.default(ge_122, "Runtime assertion failed for expression u98 >= 0 on node 'ge_98'"); ge_122 = _assert_scalar_98 = None + select_99 = torch.ops.aten.select.int(device_put_12, 0, 3) + _local_scalar_dense_99 = torch.ops.aten._local_scalar_dense.default(select_99); select_99 = None + ge_123 = _local_scalar_dense_99 >= 0 + _assert_scalar_99 = torch.ops.aten._assert_scalar.default(ge_123, "Runtime assertion failed for expression u99 >= 0 on node 'ge_99'"); ge_123 = _assert_scalar_99 = None + select_100 = torch.ops.aten.select.int(device_put_12, 0, 4) + _local_scalar_dense_100 = torch.ops.aten._local_scalar_dense.default(select_100); select_100 = None + ge_124 = _local_scalar_dense_100 >= 0 + _assert_scalar_100 = torch.ops.aten._assert_scalar.default(ge_124, "Runtime assertion failed for expression u100 >= 0 on node 'ge_100'"); ge_124 = _assert_scalar_100 = None + select_101 = torch.ops.aten.select.int(device_put_12, 0, 5) + _local_scalar_dense_101 = torch.ops.aten._local_scalar_dense.default(select_101); select_101 = None + ge_125 = _local_scalar_dense_101 >= 0 + _assert_scalar_101 = torch.ops.aten._assert_scalar.default(ge_125, "Runtime assertion failed for expression u101 >= 0 on node 'ge_101'"); ge_125 = _assert_scalar_101 = None + select_102 = torch.ops.aten.select.int(device_put_12, 0, 6) + _local_scalar_dense_102 = torch.ops.aten._local_scalar_dense.default(select_102); select_102 = None + ge_126 = _local_scalar_dense_102 >= 0 + _assert_scalar_102 = torch.ops.aten._assert_scalar.default(ge_126, "Runtime assertion failed for expression u102 >= 0 on node 'ge_102'"); ge_126 = _assert_scalar_102 = None + select_103 = torch.ops.aten.select.int(device_put_12, 0, 7); device_put_12 = None + _local_scalar_dense_103 = torch.ops.aten._local_scalar_dense.default(select_103); select_103 = None + ge_127 = _local_scalar_dense_103 >= 0 + _assert_scalar_103 = torch.ops.aten._assert_scalar.default(ge_127, "Runtime assertion failed for expression u103 >= 0 on node 'ge_103'"); ge_127 = _assert_scalar_103 = None + select_104 = torch.ops.aten.select.int(device_put_13, 0, 0) + _local_scalar_dense_104 = torch.ops.aten._local_scalar_dense.default(select_104); select_104 = None + ge_128 = _local_scalar_dense_104 >= 0 + _assert_scalar_104 = torch.ops.aten._assert_scalar.default(ge_128, "Runtime assertion failed for expression u104 >= 0 on node 'ge_104'"); ge_128 = _assert_scalar_104 = None + select_105 = torch.ops.aten.select.int(device_put_13, 0, 1) + _local_scalar_dense_105 = torch.ops.aten._local_scalar_dense.default(select_105); select_105 = None + ge_129 = _local_scalar_dense_105 >= 0 + _assert_scalar_105 = torch.ops.aten._assert_scalar.default(ge_129, "Runtime assertion failed for expression u105 >= 0 on node 'ge_105'"); ge_129 = _assert_scalar_105 = None + select_106 = torch.ops.aten.select.int(device_put_13, 0, 2) + _local_scalar_dense_106 = torch.ops.aten._local_scalar_dense.default(select_106); select_106 = None + ge_130 = _local_scalar_dense_106 >= 0 + _assert_scalar_106 = torch.ops.aten._assert_scalar.default(ge_130, "Runtime assertion failed for expression u106 >= 0 on node 'ge_106'"); ge_130 = _assert_scalar_106 = None + select_107 = torch.ops.aten.select.int(device_put_13, 0, 3) + _local_scalar_dense_107 = torch.ops.aten._local_scalar_dense.default(select_107); select_107 = None + ge_131 = _local_scalar_dense_107 >= 0 + _assert_scalar_107 = torch.ops.aten._assert_scalar.default(ge_131, "Runtime assertion failed for expression u107 >= 0 on node 'ge_107'"); ge_131 = _assert_scalar_107 = None + select_108 = torch.ops.aten.select.int(device_put_13, 0, 4) + _local_scalar_dense_108 = torch.ops.aten._local_scalar_dense.default(select_108); select_108 = None + ge_132 = _local_scalar_dense_108 >= 0 + _assert_scalar_108 = torch.ops.aten._assert_scalar.default(ge_132, "Runtime assertion failed for expression u108 >= 0 on node 'ge_108'"); ge_132 = _assert_scalar_108 = None + select_109 = torch.ops.aten.select.int(device_put_13, 0, 5) + _local_scalar_dense_109 = torch.ops.aten._local_scalar_dense.default(select_109); select_109 = None + ge_133 = _local_scalar_dense_109 >= 0 + _assert_scalar_109 = torch.ops.aten._assert_scalar.default(ge_133, "Runtime assertion failed for expression u109 >= 0 on node 'ge_109'"); ge_133 = _assert_scalar_109 = None + select_110 = torch.ops.aten.select.int(device_put_13, 0, 6) + _local_scalar_dense_110 = torch.ops.aten._local_scalar_dense.default(select_110); select_110 = None + ge_134 = _local_scalar_dense_110 >= 0 + _assert_scalar_110 = torch.ops.aten._assert_scalar.default(ge_134, "Runtime assertion failed for expression u110 >= 0 on node 'ge_110'"); ge_134 = _assert_scalar_110 = None + select_111 = torch.ops.aten.select.int(device_put_13, 0, 7); device_put_13 = None + _local_scalar_dense_111 = torch.ops.aten._local_scalar_dense.default(select_111); select_111 = None + ge_135 = _local_scalar_dense_111 >= 0 + _assert_scalar_111 = torch.ops.aten._assert_scalar.default(ge_135, "Runtime assertion failed for expression u111 >= 0 on node 'ge_111'"); ge_135 = _assert_scalar_111 = None + all_to_all_single_19 = torch.ops._c10d_functional.all_to_all_single.default(index_12, [_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111], [_local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103], '521'); index_12 = None + sym_size_int_24 = torch.ops.aten.sym_size.int(all_to_all_single_19, 0) + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_19); all_to_all_single_19 = None + sym_sum_12 = torch.sym_sum((_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111)) + add_426 = sym_sum_12 + 64; sym_sum_12 = None + add_427 = add_426 + 8; add_426 = None + sub_147 = add_427 - 1; add_427 = None + floordiv_6 = sub_147 // 8; sub_147 = None + mul_316 = floordiv_6 * 8; floordiv_6 = None + cumsum_18 = torch.ops.aten.cumsum.default(wait_tensor_146, 0) + sub_148 = torch.ops.aten.sub.Tensor(cumsum_18, wait_tensor_146); cumsum_18 = None + sum_28 = torch.ops.aten.sum.dim_IntList(view_467, [0]); view_467 = None + clamp_min_6 = torch.ops.aten.clamp_min.default(sum_28, 8); sum_28 = None + add_428 = torch.ops.aten.add.Tensor(clamp_min_6, 8); clamp_min_6 = None + sub_149 = torch.ops.aten.sub.Tensor(add_428, 1); add_428 = None + div_33 = torch.ops.aten.div.Tensor_mode(sub_149, 8, rounding_mode = 'floor'); sub_149 = None + mul_317 = torch.ops.aten.mul.Tensor(div_33, 8); div_33 = None + convert_element_type_392 = torch.ops.prims.convert_element_type.default(mul_317, torch.int32); mul_317 = None + cumsum_19 = torch.ops.aten.cumsum.default(convert_element_type_392, 0) + sub_150 = torch.ops.aten.sub.Tensor(cumsum_19, convert_element_type_392); cumsum_19 = None + full_98 = torch.ops.aten.full.default([mul_316], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_316 = None + triton_kernel_wrapper_functional_proxy_6 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 6, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_146, 'start_index_values_ptr': sub_148, 'write_offsets_ptr': sub_150, 'output_ptr': full_98}, tensors_to_clone = ['output_ptr']); wait_tensor_146 = sub_148 = sub_150 = full_98 = None + getitem_106 = triton_kernel_wrapper_functional_proxy_6['output_ptr']; triton_kernel_wrapper_functional_proxy_6 = None + cat_22 = torch.ops.aten.cat.default([wait_tensor_147, full_default]); wait_tensor_147 = None + sym_size_int_25 = torch.ops.aten.sym_size.int(cat_22, 0) + sym_sum_13 = torch.sym_sum((1, _local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111)) + index_13 = torch.ops.aten.index.Tensor(cat_22, [getitem_106]); cat_22 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 8, '513'); convert_element_type_394 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + convert_element_type_396 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_396, 8, '513'); convert_element_type_396 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 8, '513'); convert_element_type_397 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + cumsum_20 = torch.ops.aten.cumsum.default(convert_element_type_392, 0, dtype = torch.int32); convert_element_type_392 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_148, [0, 2, 1]); wait_tensor_148 = None + _grouped_mm_18 = torch.ops.aten._grouped_mm.default(index_13, permute_110, cumsum_20); permute_110 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(_grouped_mm_18, torch.float32) + neg_13 = torch.ops.aten.neg.default(convert_element_type_400) + exp_20 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_440 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + div_34 = torch.ops.aten.div.Tensor(convert_element_type_400, add_440); convert_element_type_400 = add_440 = None + convert_element_type_401 = torch.ops.prims.convert_element_type.default(div_34, torch.bfloat16); div_34 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_151, [0, 2, 1]); wait_tensor_151 = None + _grouped_mm_19 = torch.ops.aten._grouped_mm.default(index_13, permute_111, cumsum_20); permute_111 = None + mul_329 = torch.ops.aten.mul.Tensor(convert_element_type_401, _grouped_mm_19); convert_element_type_401 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_150, [0, 2, 1]); wait_tensor_150 = None + _grouped_mm_20 = torch.ops.aten._grouped_mm.default(mul_329, permute_112, cumsum_20); permute_112 = None + empty_6 = torch.ops.aten.empty.memory_format([sym_size_int_25, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_12 = torch.ops.aten.index_put.default(empty_6, [getitem_106], _grouped_mm_20); empty_6 = _grouped_mm_20 = None + slice_30 = torch.ops.aten.slice.Tensor(index_put_12, 0, 0, -1); index_put_12 = None + all_to_all_single_20 = torch.ops._c10d_functional.all_to_all_single.default(slice_30, [_local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103], [_local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111], '521'); slice_30 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_20); all_to_all_single_20 = None + convert_element_type_402 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_402, 64, '0'); convert_element_type_402 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_113 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_60 = torch.ops.aten.mm.default(view_460, permute_113); permute_113 = None + convert_element_type_405 = torch.ops.prims.convert_element_type.default(mm_60, torch.float32) + neg_14 = torch.ops.aten.neg.default(convert_element_type_405) + exp_21 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_476 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + div_35 = torch.ops.aten.div.Tensor(convert_element_type_405, add_476); convert_element_type_405 = add_476 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(div_35, torch.bfloat16); div_35 = None + convert_element_type_407 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_407, 64, '0'); convert_element_type_407 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_114 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_61 = torch.ops.aten.mm.default(view_460, permute_114); permute_114 = None + mul_349 = torch.ops.aten.mul.Tensor(convert_element_type_406, mm_61); convert_element_type_406 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_410, 64, '0'); convert_element_type_410 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_115 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + mm_62 = torch.ops.aten.mm.default(mul_349, permute_115); permute_115 = None + index_put_13 = torch.ops.aten.index_put.default(full_default_1, [getitem_105], wait_tensor_154); wait_tensor_154 = None + view_500 = torch.ops.aten.view.default(mul_311, [-1, 1, 6]); mul_311 = None + view_501 = torch.ops.aten.view.default(index_put_13, [-1, 6, 2048]); index_put_13 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(view_501, torch.float32); view_501 = None + bmm_6 = torch.ops.aten.bmm.default(view_500, convert_element_type_413) + convert_element_type_414 = torch.ops.prims.convert_element_type.default(bmm_6, torch.bfloat16); bmm_6 = None + squeeze_6 = torch.ops.aten.squeeze.dim(convert_element_type_414, 1); convert_element_type_414 = None + add_480 = torch.ops.aten.add.Tensor(mm_62, squeeze_6); mm_62 = squeeze_6 = None + view_502 = torch.ops.aten.view.default(add_480, [2, 4096, 2048]); add_480 = None + add_481 = torch.ops.aten.add.Tensor(add_416, view_502); view_502 = None + convert_element_type_415 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_415, 64, '0'); convert_element_type_415 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(add_481, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_416, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_482 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_482); add_482 = None + mul_352 = torch.ops.aten.mul.Tensor(convert_element_type_416, rsqrt_24); convert_element_type_416 = None + mul_353 = torch.ops.aten.mul.Tensor(mul_352, wait_tensor_158); mul_352 = wait_tensor_158 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(mul_353, torch.bfloat16); mul_353 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 64, '0'); convert_element_type_418 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_116 = torch.ops.aten.permute.default(wait_tensor_159, [1, 0]); wait_tensor_159 = None + view_505 = torch.ops.aten.view.default(convert_element_type_417, [8192, 2048]); convert_element_type_417 = None + mm_63 = torch.ops.aten.mm.default(view_505, permute_116); permute_116 = None + view_506 = torch.ops.aten.view.default(mm_63, [2, 4096, 3072]); mm_63 = None + view_507 = torch.ops.aten.view.default(view_506, [2, 4096, -1, 192]); view_506 = None + split_with_sizes_24 = torch.ops.aten.split_with_sizes.default(view_507, [128, 64], -1); view_507 = None + getitem_107 = split_with_sizes_24[0] + getitem_108 = split_with_sizes_24[1]; split_with_sizes_24 = None + convert_element_type_421 = torch.ops.prims.convert_element_type.default(getitem_108, torch.float32); getitem_108 = None + view_508 = torch.ops.aten.view.default(convert_element_type_421, [2, 4096, 16, -1, 2]); convert_element_type_421 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_508); view_508 = None + mul_354 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_7); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_354); mul_354 = None + view_510 = torch.ops.aten.view.default(view_as_real_16, [2, 4096, 16, 64]); view_as_real_16 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_510, torch.bfloat16); view_510 = None + cat_23 = torch.ops.aten.cat.default([getitem_107, convert_element_type_422], -1); getitem_107 = convert_element_type_422 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_423, 64, '0'); convert_element_type_423 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + mm_64 = torch.ops.aten.mm.default(view_505, permute_117); permute_117 = None + view_513 = torch.ops.aten.view.default(mm_64, [2, 4096, 576]); mm_64 = None + split_with_sizes_25 = torch.ops.aten.split_with_sizes.default(view_513, [512, 64], -1); view_513 = None + getitem_109 = split_with_sizes_25[0] + getitem_110 = split_with_sizes_25[1]; split_with_sizes_25 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(getitem_110, 2); getitem_110 = None + convert_element_type_426 = torch.ops.prims.convert_element_type.default(unsqueeze_15, torch.float32); unsqueeze_15 = None + view_514 = torch.ops.aten.view.default(convert_element_type_426, [2, 4096, 1, -1, 2]); convert_element_type_426 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_514); view_514 = None + mul_355 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_7); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_355); mul_355 = None + view_516 = torch.ops.aten.view.default(view_as_real_17, [2, 4096, 1, 64]); view_as_real_17 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(view_516, torch.bfloat16); view_516 = None + convert_element_type_428 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_428, 64, '0'); convert_element_type_428 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_429 = torch.ops.prims.convert_element_type.default(getitem_109, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_429, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_483 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_483); add_483 = None + mul_356 = torch.ops.aten.mul.Tensor(convert_element_type_429, rsqrt_25); convert_element_type_429 = None + mul_357 = torch.ops.aten.mul.Tensor(mul_356, wait_tensor_161); mul_356 = wait_tensor_161 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(mul_357, torch.bfloat16); mul_357 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_431, 64, '0'); convert_element_type_431 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + view_519 = torch.ops.aten.view.default(convert_element_type_430, [8192, 512]); convert_element_type_430 = None + mm_65 = torch.ops.aten.mm.default(view_519, permute_118); permute_118 = None + view_520 = torch.ops.aten.view.default(mm_65, [2, 4096, 4096]); mm_65 = None + view_521 = torch.ops.aten.view.default(view_520, [2, 4096, -1, 256]); view_520 = None + split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_521, [128, 128], -1); view_521 = None + getitem_111 = split_with_sizes_26[0] + getitem_112 = split_with_sizes_26[1]; split_with_sizes_26 = None + expand_8 = torch.ops.aten.expand.default(convert_element_type_427, [-1, -1, 16, -1]); convert_element_type_427 = None + cat_24 = torch.ops.aten.cat.default([getitem_111, expand_8], -1); getitem_111 = expand_8 = None + permute_119 = torch.ops.aten.permute.default(cat_23, [0, 2, 1, 3]); cat_23 = None + permute_120 = torch.ops.aten.permute.default(cat_24, [0, 2, 1, 3]); cat_24 = None + permute_121 = torch.ops.aten.permute.default(getitem_112, [0, 2, 1, 3]); getitem_112 = None + sdpa_score8 = self.sdpa_score8 + sdpa_mask8 = self.sdpa_mask8 + flex_attention_8 = torch.ops.higher_order.flex_attention(permute_119, permute_120, permute_121, sdpa_score8, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask8), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score8 = sdpa_mask8 = None + getitem_113 = flex_attention_8[0] + getitem_114 = flex_attention_8[1]; flex_attention_8 = None + permute_122 = torch.ops.aten.permute.default(getitem_113, [0, 2, 1, 3]) + view_522 = torch.ops.aten.view.default(permute_122, [2, 4096, -1]); permute_122 = None + convert_element_type_434 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_434, 64, '0'); convert_element_type_434 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + view_524 = torch.ops.aten.view.default(view_522, [8192, 2048]); view_522 = None + mm_66 = torch.ops.aten.mm.default(view_524, permute_123); view_524 = permute_123 = None + view_525 = torch.ops.aten.view.default(mm_66, [2, 4096, 2048]); mm_66 = None + add_484 = torch.ops.aten.add.Tensor(add_481, view_525); view_525 = None + convert_element_type_437 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_437, 64, '0'); convert_element_type_437 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_438 = torch.ops.prims.convert_element_type.default(add_484, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_438, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_485 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_485); add_485 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_438, rsqrt_26); convert_element_type_438 = None + mul_359 = torch.ops.aten.mul.Tensor(mul_358, wait_tensor_164); mul_358 = wait_tensor_164 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(mul_359, torch.bfloat16); mul_359 = None + view_527 = torch.ops.aten.view.default(convert_element_type_439, [-1, 2048]); convert_element_type_439 = None + convert_element_type_440 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_440, 64, '0'); convert_element_type_440 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_124 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + mm_67 = torch.ops.aten.mm.default(view_527, permute_124); permute_124 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(mm_67, torch.float32) + amax_7 = torch.ops.aten.amax.default(convert_element_type_443, [1], True) + sub_168 = torch.ops.aten.sub.Tensor(convert_element_type_443, amax_7); convert_element_type_443 = None + exp_22 = torch.ops.aten.exp.default(sub_168); sub_168 = None + sum_29 = torch.ops.aten.sum.dim_IntList(exp_22, [1], True) + div_36 = torch.ops.aten.div.Tensor(exp_22, sum_29); exp_22 = None + add_486 = torch.ops.aten.add.Tensor(div_36, primals_142); primals_142 = None + topk_7 = torch.ops.aten.topk.default(add_486, 6, -1, True, False); add_486 = None + getitem_117 = topk_7[1]; topk_7 = None + gather_7 = torch.ops.aten.gather.default(div_36, 1, getitem_117); div_36 = None + mul_360 = torch.ops.aten.mul.Tensor(gather_7, 1.0); gather_7 = None + view_529 = torch.ops.aten.view.default(getitem_117, [-1]) + histc_14 = torch.ops.aten.histc.default(view_529, 64, 0, 64) + add_487 = torch.ops.aten.add.Tensor(primals_144, histc_14) + sort_7 = torch.ops.aten.sort.stable(view_529, stable = True); view_529 = None + getitem_119 = sort_7[1]; sort_7 = None + div_37 = torch.ops.aten.div.Tensor_mode(getitem_119, 6, rounding_mode = 'floor') + index_14 = torch.ops.aten.index.Tensor(view_527, [div_37]) + all_to_all_single_21 = torch.ops._c10d_functional.all_to_all_single.default(histc_14, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_21); all_to_all_single_21 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_166); wait_tensor_166 = None + view_533 = torch.ops.aten.view.default(histc_14, [8, -1]); histc_14 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_533, [1]); view_533 = None + device_put_14 = torch.ops.prims.device_put.default(sum_30, device(type='cpu'), True); sum_30 = None + view_534 = torch.ops.aten.view.default(wait_tensor_167, [8, -1]) + sum_31 = torch.ops.aten.sum.dim_IntList(view_534, [1]) + device_put_15 = torch.ops.prims.device_put.default(sum_31, device(type='cpu')); sum_31 = None + select_112 = torch.ops.aten.select.int(device_put_14, 0, 0) + _local_scalar_dense_112 = torch.ops.aten._local_scalar_dense.default(select_112); select_112 = None + ge_140 = _local_scalar_dense_112 >= 0 + _assert_scalar_112 = torch.ops.aten._assert_scalar.default(ge_140, "Runtime assertion failed for expression u112 >= 0 on node 'ge_112'"); ge_140 = _assert_scalar_112 = None + select_113 = torch.ops.aten.select.int(device_put_14, 0, 1) + _local_scalar_dense_113 = torch.ops.aten._local_scalar_dense.default(select_113); select_113 = None + ge_141 = _local_scalar_dense_113 >= 0 + _assert_scalar_113 = torch.ops.aten._assert_scalar.default(ge_141, "Runtime assertion failed for expression u113 >= 0 on node 'ge_113'"); ge_141 = _assert_scalar_113 = None + select_114 = torch.ops.aten.select.int(device_put_14, 0, 2) + _local_scalar_dense_114 = torch.ops.aten._local_scalar_dense.default(select_114); select_114 = None + ge_142 = _local_scalar_dense_114 >= 0 + _assert_scalar_114 = torch.ops.aten._assert_scalar.default(ge_142, "Runtime assertion failed for expression u114 >= 0 on node 'ge_114'"); ge_142 = _assert_scalar_114 = None + select_115 = torch.ops.aten.select.int(device_put_14, 0, 3) + _local_scalar_dense_115 = torch.ops.aten._local_scalar_dense.default(select_115); select_115 = None + ge_143 = _local_scalar_dense_115 >= 0 + _assert_scalar_115 = torch.ops.aten._assert_scalar.default(ge_143, "Runtime assertion failed for expression u115 >= 0 on node 'ge_115'"); ge_143 = _assert_scalar_115 = None + select_116 = torch.ops.aten.select.int(device_put_14, 0, 4) + _local_scalar_dense_116 = torch.ops.aten._local_scalar_dense.default(select_116); select_116 = None + ge_144 = _local_scalar_dense_116 >= 0 + _assert_scalar_116 = torch.ops.aten._assert_scalar.default(ge_144, "Runtime assertion failed for expression u116 >= 0 on node 'ge_116'"); ge_144 = _assert_scalar_116 = None + select_117 = torch.ops.aten.select.int(device_put_14, 0, 5) + _local_scalar_dense_117 = torch.ops.aten._local_scalar_dense.default(select_117); select_117 = None + ge_145 = _local_scalar_dense_117 >= 0 + _assert_scalar_117 = torch.ops.aten._assert_scalar.default(ge_145, "Runtime assertion failed for expression u117 >= 0 on node 'ge_117'"); ge_145 = _assert_scalar_117 = None + select_118 = torch.ops.aten.select.int(device_put_14, 0, 6) + _local_scalar_dense_118 = torch.ops.aten._local_scalar_dense.default(select_118); select_118 = None + ge_146 = _local_scalar_dense_118 >= 0 + _assert_scalar_118 = torch.ops.aten._assert_scalar.default(ge_146, "Runtime assertion failed for expression u118 >= 0 on node 'ge_118'"); ge_146 = _assert_scalar_118 = None + select_119 = torch.ops.aten.select.int(device_put_14, 0, 7); device_put_14 = None + _local_scalar_dense_119 = torch.ops.aten._local_scalar_dense.default(select_119); select_119 = None + ge_147 = _local_scalar_dense_119 >= 0 + _assert_scalar_119 = torch.ops.aten._assert_scalar.default(ge_147, "Runtime assertion failed for expression u119 >= 0 on node 'ge_119'"); ge_147 = _assert_scalar_119 = None + select_120 = torch.ops.aten.select.int(device_put_15, 0, 0) + _local_scalar_dense_120 = torch.ops.aten._local_scalar_dense.default(select_120); select_120 = None + ge_148 = _local_scalar_dense_120 >= 0 + _assert_scalar_120 = torch.ops.aten._assert_scalar.default(ge_148, "Runtime assertion failed for expression u120 >= 0 on node 'ge_120'"); ge_148 = _assert_scalar_120 = None + select_121 = torch.ops.aten.select.int(device_put_15, 0, 1) + _local_scalar_dense_121 = torch.ops.aten._local_scalar_dense.default(select_121); select_121 = None + ge_149 = _local_scalar_dense_121 >= 0 + _assert_scalar_121 = torch.ops.aten._assert_scalar.default(ge_149, "Runtime assertion failed for expression u121 >= 0 on node 'ge_121'"); ge_149 = _assert_scalar_121 = None + select_122 = torch.ops.aten.select.int(device_put_15, 0, 2) + _local_scalar_dense_122 = torch.ops.aten._local_scalar_dense.default(select_122); select_122 = None + ge_150 = _local_scalar_dense_122 >= 0 + _assert_scalar_122 = torch.ops.aten._assert_scalar.default(ge_150, "Runtime assertion failed for expression u122 >= 0 on node 'ge_122'"); ge_150 = _assert_scalar_122 = None + select_123 = torch.ops.aten.select.int(device_put_15, 0, 3) + _local_scalar_dense_123 = torch.ops.aten._local_scalar_dense.default(select_123); select_123 = None + ge_151 = _local_scalar_dense_123 >= 0 + _assert_scalar_123 = torch.ops.aten._assert_scalar.default(ge_151, "Runtime assertion failed for expression u123 >= 0 on node 'ge_123'"); ge_151 = _assert_scalar_123 = None + select_124 = torch.ops.aten.select.int(device_put_15, 0, 4) + _local_scalar_dense_124 = torch.ops.aten._local_scalar_dense.default(select_124); select_124 = None + ge_152 = _local_scalar_dense_124 >= 0 + _assert_scalar_124 = torch.ops.aten._assert_scalar.default(ge_152, "Runtime assertion failed for expression u124 >= 0 on node 'ge_124'"); ge_152 = _assert_scalar_124 = None + select_125 = torch.ops.aten.select.int(device_put_15, 0, 5) + _local_scalar_dense_125 = torch.ops.aten._local_scalar_dense.default(select_125); select_125 = None + ge_153 = _local_scalar_dense_125 >= 0 + _assert_scalar_125 = torch.ops.aten._assert_scalar.default(ge_153, "Runtime assertion failed for expression u125 >= 0 on node 'ge_125'"); ge_153 = _assert_scalar_125 = None + select_126 = torch.ops.aten.select.int(device_put_15, 0, 6) + _local_scalar_dense_126 = torch.ops.aten._local_scalar_dense.default(select_126); select_126 = None + ge_154 = _local_scalar_dense_126 >= 0 + _assert_scalar_126 = torch.ops.aten._assert_scalar.default(ge_154, "Runtime assertion failed for expression u126 >= 0 on node 'ge_126'"); ge_154 = _assert_scalar_126 = None + select_127 = torch.ops.aten.select.int(device_put_15, 0, 7); device_put_15 = None + _local_scalar_dense_127 = torch.ops.aten._local_scalar_dense.default(select_127); select_127 = None + ge_155 = _local_scalar_dense_127 >= 0 + _assert_scalar_127 = torch.ops.aten._assert_scalar.default(ge_155, "Runtime assertion failed for expression u127 >= 0 on node 'ge_127'"); ge_155 = _assert_scalar_127 = None + all_to_all_single_22 = torch.ops._c10d_functional.all_to_all_single.default(index_14, [_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127], [_local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119], '521'); index_14 = None + sym_size_int_28 = torch.ops.aten.sym_size.int(all_to_all_single_22, 0) + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_22); all_to_all_single_22 = None + sym_sum_14 = torch.sym_sum((_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127)) + add_494 = sym_sum_14 + 64; sym_sum_14 = None + add_495 = add_494 + 8; add_494 = None + sub_171 = add_495 - 1; add_495 = None + floordiv_7 = sub_171 // 8; sub_171 = None + mul_365 = floordiv_7 * 8; floordiv_7 = None + cumsum_21 = torch.ops.aten.cumsum.default(wait_tensor_167, 0) + sub_172 = torch.ops.aten.sub.Tensor(cumsum_21, wait_tensor_167); cumsum_21 = None + sum_32 = torch.ops.aten.sum.dim_IntList(view_534, [0]); view_534 = None + clamp_min_7 = torch.ops.aten.clamp_min.default(sum_32, 8); sum_32 = None + add_496 = torch.ops.aten.add.Tensor(clamp_min_7, 8); clamp_min_7 = None + sub_173 = torch.ops.aten.sub.Tensor(add_496, 1); add_496 = None + div_38 = torch.ops.aten.div.Tensor_mode(sub_173, 8, rounding_mode = 'floor'); sub_173 = None + mul_366 = torch.ops.aten.mul.Tensor(div_38, 8); div_38 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(mul_366, torch.int32); mul_366 = None + cumsum_22 = torch.ops.aten.cumsum.default(convert_element_type_446, 0) + sub_174 = torch.ops.aten.sub.Tensor(cumsum_22, convert_element_type_446); cumsum_22 = None + full_111 = torch.ops.aten.full.default([mul_365], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_365 = None + triton_kernel_wrapper_functional_proxy_7 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 7, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_167, 'start_index_values_ptr': sub_172, 'write_offsets_ptr': sub_174, 'output_ptr': full_111}, tensors_to_clone = ['output_ptr']); wait_tensor_167 = sub_172 = sub_174 = full_111 = None + getitem_120 = triton_kernel_wrapper_functional_proxy_7['output_ptr']; triton_kernel_wrapper_functional_proxy_7 = None + cat_25 = torch.ops.aten.cat.default([wait_tensor_168, full_default]); wait_tensor_168 = None + sym_size_int_29 = torch.ops.aten.sym_size.int(cat_25, 0) + sym_sum_15 = torch.sym_sum((1, _local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127)) + index_15 = torch.ops.aten.index.Tensor(cat_25, [getitem_120]); cat_25 = None + convert_element_type_448 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_448, 8, '513'); convert_element_type_448 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_450, 8, '513'); convert_element_type_450 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '513'); convert_element_type_451 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + cumsum_23 = torch.ops.aten.cumsum.default(convert_element_type_446, 0, dtype = torch.int32); convert_element_type_446 = None + permute_125 = torch.ops.aten.permute.default(wait_tensor_169, [0, 2, 1]); wait_tensor_169 = None + _grouped_mm_21 = torch.ops.aten._grouped_mm.default(index_15, permute_125, cumsum_23); permute_125 = None + convert_element_type_454 = torch.ops.prims.convert_element_type.default(_grouped_mm_21, torch.float32) + neg_15 = torch.ops.aten.neg.default(convert_element_type_454) + exp_23 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_508 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + div_39 = torch.ops.aten.div.Tensor(convert_element_type_454, add_508); convert_element_type_454 = add_508 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(div_39, torch.bfloat16); div_39 = None + permute_126 = torch.ops.aten.permute.default(wait_tensor_172, [0, 2, 1]); wait_tensor_172 = None + _grouped_mm_22 = torch.ops.aten._grouped_mm.default(index_15, permute_126, cumsum_23); permute_126 = None + mul_378 = torch.ops.aten.mul.Tensor(convert_element_type_455, _grouped_mm_22); convert_element_type_455 = None + permute_127 = torch.ops.aten.permute.default(wait_tensor_171, [0, 2, 1]); wait_tensor_171 = None + _grouped_mm_23 = torch.ops.aten._grouped_mm.default(mul_378, permute_127, cumsum_23); permute_127 = None + empty_7 = torch.ops.aten.empty.memory_format([sym_size_int_29, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_14 = torch.ops.aten.index_put.default(empty_7, [getitem_120], _grouped_mm_23); empty_7 = _grouped_mm_23 = None + slice_34 = torch.ops.aten.slice.Tensor(index_put_14, 0, 0, -1); index_put_14 = None + all_to_all_single_23 = torch.ops._c10d_functional.all_to_all_single.default(slice_34, [_local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119], [_local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127], '521'); slice_34 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_23); all_to_all_single_23 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_456, 64, '0'); convert_element_type_456 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + mm_68 = torch.ops.aten.mm.default(view_527, permute_128); permute_128 = None + convert_element_type_459 = torch.ops.prims.convert_element_type.default(mm_68, torch.float32) + neg_16 = torch.ops.aten.neg.default(convert_element_type_459) + exp_24 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_544 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + div_40 = torch.ops.aten.div.Tensor(convert_element_type_459, add_544); convert_element_type_459 = add_544 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(div_40, torch.bfloat16); div_40 = None + convert_element_type_461 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_461, 64, '0'); convert_element_type_461 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_177, [1, 0]); wait_tensor_177 = None + mm_69 = torch.ops.aten.mm.default(view_527, permute_129); permute_129 = None + mul_398 = torch.ops.aten.mul.Tensor(convert_element_type_460, mm_69); convert_element_type_460 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_464, 64, '0'); convert_element_type_464 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + mm_70 = torch.ops.aten.mm.default(mul_398, permute_130); permute_130 = None + index_put_15 = torch.ops.aten.index_put.default(full_default_1, [getitem_119], wait_tensor_175); wait_tensor_175 = None + view_567 = torch.ops.aten.view.default(mul_360, [-1, 1, 6]); mul_360 = None + view_568 = torch.ops.aten.view.default(index_put_15, [-1, 6, 2048]); index_put_15 = None + convert_element_type_467 = torch.ops.prims.convert_element_type.default(view_568, torch.float32); view_568 = None + bmm_7 = torch.ops.aten.bmm.default(view_567, convert_element_type_467) + convert_element_type_468 = torch.ops.prims.convert_element_type.default(bmm_7, torch.bfloat16); bmm_7 = None + squeeze_7 = torch.ops.aten.squeeze.dim(convert_element_type_468, 1); convert_element_type_468 = None + add_548 = torch.ops.aten.add.Tensor(mm_70, squeeze_7); mm_70 = squeeze_7 = None + view_569 = torch.ops.aten.view.default(add_548, [2, 4096, 2048]); add_548 = None + add_549 = torch.ops.aten.add.Tensor(add_484, view_569); view_569 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 64, '0'); convert_element_type_469 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + convert_element_type_470 = torch.ops.prims.convert_element_type.default(add_549, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_470, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_550 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_550); add_550 = None + mul_401 = torch.ops.aten.mul.Tensor(convert_element_type_470, rsqrt_27); convert_element_type_470 = None + mul_402 = torch.ops.aten.mul.Tensor(mul_401, wait_tensor_179); mul_401 = wait_tensor_179 = None + convert_element_type_471 = torch.ops.prims.convert_element_type.default(mul_402, torch.bfloat16); mul_402 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 64, '0'); convert_element_type_472 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_572 = torch.ops.aten.view.default(convert_element_type_471, [8192, 2048]); convert_element_type_471 = None + mm_71 = torch.ops.aten.mm.default(view_572, permute_131); permute_131 = None + view_573 = torch.ops.aten.view.default(mm_71, [2, 4096, 3072]); mm_71 = None + view_574 = torch.ops.aten.view.default(view_573, [2, 4096, -1, 192]); view_573 = None + split_with_sizes_27 = torch.ops.aten.split_with_sizes.default(view_574, [128, 64], -1); view_574 = None + getitem_121 = split_with_sizes_27[0] + getitem_122 = split_with_sizes_27[1]; split_with_sizes_27 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(getitem_122, torch.float32); getitem_122 = None + view_575 = torch.ops.aten.view.default(convert_element_type_475, [2, 4096, 16, -1, 2]); convert_element_type_475 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_575); view_575 = None + mul_403 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_7); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_403); mul_403 = None + view_577 = torch.ops.aten.view.default(view_as_real_18, [2, 4096, 16, 64]); view_as_real_18 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_577, torch.bfloat16); view_577 = None + cat_26 = torch.ops.aten.cat.default([getitem_121, convert_element_type_476], -1); getitem_121 = convert_element_type_476 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_477, 64, '0'); convert_element_type_477 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + mm_72 = torch.ops.aten.mm.default(view_572, permute_132); permute_132 = None + view_580 = torch.ops.aten.view.default(mm_72, [2, 4096, 576]); mm_72 = None + split_with_sizes_28 = torch.ops.aten.split_with_sizes.default(view_580, [512, 64], -1); view_580 = None + getitem_123 = split_with_sizes_28[0] + getitem_124 = split_with_sizes_28[1]; split_with_sizes_28 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(getitem_124, 2); getitem_124 = None + convert_element_type_480 = torch.ops.prims.convert_element_type.default(unsqueeze_17, torch.float32); unsqueeze_17 = None + view_581 = torch.ops.aten.view.default(convert_element_type_480, [2, 4096, 1, -1, 2]); convert_element_type_480 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_581); view_581 = None + mul_404 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_7); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_404); mul_404 = None + view_583 = torch.ops.aten.view.default(view_as_real_19, [2, 4096, 1, 64]); view_as_real_19 = None + convert_element_type_481 = torch.ops.prims.convert_element_type.default(view_583, torch.bfloat16); view_583 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 64, '0'); convert_element_type_482 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(getitem_123, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_551 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_551); add_551 = None + mul_405 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_28); convert_element_type_483 = None + mul_406 = torch.ops.aten.mul.Tensor(mul_405, wait_tensor_182); mul_405 = wait_tensor_182 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_406, torch.bfloat16); mul_406 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 64, '0'); convert_element_type_485 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + view_586 = torch.ops.aten.view.default(convert_element_type_484, [8192, 512]); convert_element_type_484 = None + mm_73 = torch.ops.aten.mm.default(view_586, permute_133); permute_133 = None + view_587 = torch.ops.aten.view.default(mm_73, [2, 4096, 4096]); mm_73 = None + view_588 = torch.ops.aten.view.default(view_587, [2, 4096, -1, 256]); view_587 = None + split_with_sizes_29 = torch.ops.aten.split_with_sizes.default(view_588, [128, 128], -1); view_588 = None + getitem_125 = split_with_sizes_29[0] + getitem_126 = split_with_sizes_29[1]; split_with_sizes_29 = None + expand_9 = torch.ops.aten.expand.default(convert_element_type_481, [-1, -1, 16, -1]); convert_element_type_481 = None + cat_27 = torch.ops.aten.cat.default([getitem_125, expand_9], -1); getitem_125 = expand_9 = None + permute_134 = torch.ops.aten.permute.default(cat_26, [0, 2, 1, 3]); cat_26 = None + permute_135 = torch.ops.aten.permute.default(cat_27, [0, 2, 1, 3]); cat_27 = None + permute_136 = torch.ops.aten.permute.default(getitem_126, [0, 2, 1, 3]); getitem_126 = None + sdpa_score9 = self.sdpa_score9 + sdpa_mask9 = self.sdpa_mask9 + flex_attention_9 = torch.ops.higher_order.flex_attention(permute_134, permute_135, permute_136, sdpa_score9, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask9), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score9 = sdpa_mask9 = None + getitem_127 = flex_attention_9[0] + getitem_128 = flex_attention_9[1]; flex_attention_9 = None + permute_137 = torch.ops.aten.permute.default(getitem_127, [0, 2, 1, 3]) + view_589 = torch.ops.aten.view.default(permute_137, [2, 4096, -1]); permute_137 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_488, 64, '0'); convert_element_type_488 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_138 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + view_591 = torch.ops.aten.view.default(view_589, [8192, 2048]); view_589 = None + mm_74 = torch.ops.aten.mm.default(view_591, permute_138); view_591 = permute_138 = None + view_592 = torch.ops.aten.view.default(mm_74, [2, 4096, 2048]); mm_74 = None + add_552 = torch.ops.aten.add.Tensor(add_549, view_592); view_592 = None + convert_element_type_491 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_491, 64, '0'); convert_element_type_491 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + convert_element_type_492 = torch.ops.prims.convert_element_type.default(add_552, torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_492, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_553 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_553); add_553 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_492, rsqrt_29); convert_element_type_492 = None + mul_408 = torch.ops.aten.mul.Tensor(mul_407, wait_tensor_185); mul_407 = wait_tensor_185 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(mul_408, torch.bfloat16); mul_408 = None + view_594 = torch.ops.aten.view.default(convert_element_type_493, [-1, 2048]); convert_element_type_493 = None + convert_element_type_494 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_494, 64, '0'); convert_element_type_494 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + mm_75 = torch.ops.aten.mm.default(view_594, permute_139); permute_139 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(mm_75, torch.float32) + amax_8 = torch.ops.aten.amax.default(convert_element_type_497, [1], True) + sub_192 = torch.ops.aten.sub.Tensor(convert_element_type_497, amax_8); convert_element_type_497 = None + exp_25 = torch.ops.aten.exp.default(sub_192); sub_192 = None + sum_33 = torch.ops.aten.sum.dim_IntList(exp_25, [1], True) + div_41 = torch.ops.aten.div.Tensor(exp_25, sum_33); exp_25 = None + add_554 = torch.ops.aten.add.Tensor(div_41, primals_158); primals_158 = None + topk_8 = torch.ops.aten.topk.default(add_554, 6, -1, True, False); add_554 = None + getitem_131 = topk_8[1]; topk_8 = None + gather_8 = torch.ops.aten.gather.default(div_41, 1, getitem_131); div_41 = None + mul_409 = torch.ops.aten.mul.Tensor(gather_8, 1.0); gather_8 = None + view_596 = torch.ops.aten.view.default(getitem_131, [-1]) + histc_16 = torch.ops.aten.histc.default(view_596, 64, 0, 64) + add_555 = torch.ops.aten.add.Tensor(primals_160, histc_16) + sort_8 = torch.ops.aten.sort.stable(view_596, stable = True); view_596 = None + getitem_133 = sort_8[1]; sort_8 = None + div_42 = torch.ops.aten.div.Tensor_mode(getitem_133, 6, rounding_mode = 'floor') + index_16 = torch.ops.aten.index.Tensor(view_594, [div_42]) + all_to_all_single_24 = torch.ops._c10d_functional.all_to_all_single.default(histc_16, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_24); all_to_all_single_24 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_187); wait_tensor_187 = None + view_600 = torch.ops.aten.view.default(histc_16, [8, -1]); histc_16 = None + sum_34 = torch.ops.aten.sum.dim_IntList(view_600, [1]); view_600 = None + device_put_16 = torch.ops.prims.device_put.default(sum_34, device(type='cpu'), True); sum_34 = None + view_601 = torch.ops.aten.view.default(wait_tensor_188, [8, -1]) + sum_35 = torch.ops.aten.sum.dim_IntList(view_601, [1]) + device_put_17 = torch.ops.prims.device_put.default(sum_35, device(type='cpu')); sum_35 = None + select_128 = torch.ops.aten.select.int(device_put_16, 0, 0) + _local_scalar_dense_128 = torch.ops.aten._local_scalar_dense.default(select_128); select_128 = None + ge_160 = _local_scalar_dense_128 >= 0 + _assert_scalar_128 = torch.ops.aten._assert_scalar.default(ge_160, "Runtime assertion failed for expression u128 >= 0 on node 'ge_128'"); ge_160 = _assert_scalar_128 = None + select_129 = torch.ops.aten.select.int(device_put_16, 0, 1) + _local_scalar_dense_129 = torch.ops.aten._local_scalar_dense.default(select_129); select_129 = None + ge_161 = _local_scalar_dense_129 >= 0 + _assert_scalar_129 = torch.ops.aten._assert_scalar.default(ge_161, "Runtime assertion failed for expression u129 >= 0 on node 'ge_129'"); ge_161 = _assert_scalar_129 = None + select_130 = torch.ops.aten.select.int(device_put_16, 0, 2) + _local_scalar_dense_130 = torch.ops.aten._local_scalar_dense.default(select_130); select_130 = None + ge_162 = _local_scalar_dense_130 >= 0 + _assert_scalar_130 = torch.ops.aten._assert_scalar.default(ge_162, "Runtime assertion failed for expression u130 >= 0 on node 'ge_130'"); ge_162 = _assert_scalar_130 = None + select_131 = torch.ops.aten.select.int(device_put_16, 0, 3) + _local_scalar_dense_131 = torch.ops.aten._local_scalar_dense.default(select_131); select_131 = None + ge_163 = _local_scalar_dense_131 >= 0 + _assert_scalar_131 = torch.ops.aten._assert_scalar.default(ge_163, "Runtime assertion failed for expression u131 >= 0 on node 'ge_131'"); ge_163 = _assert_scalar_131 = None + select_132 = torch.ops.aten.select.int(device_put_16, 0, 4) + _local_scalar_dense_132 = torch.ops.aten._local_scalar_dense.default(select_132); select_132 = None + ge_164 = _local_scalar_dense_132 >= 0 + _assert_scalar_132 = torch.ops.aten._assert_scalar.default(ge_164, "Runtime assertion failed for expression u132 >= 0 on node 'ge_132'"); ge_164 = _assert_scalar_132 = None + select_133 = torch.ops.aten.select.int(device_put_16, 0, 5) + _local_scalar_dense_133 = torch.ops.aten._local_scalar_dense.default(select_133); select_133 = None + ge_165 = _local_scalar_dense_133 >= 0 + _assert_scalar_133 = torch.ops.aten._assert_scalar.default(ge_165, "Runtime assertion failed for expression u133 >= 0 on node 'ge_133'"); ge_165 = _assert_scalar_133 = None + select_134 = torch.ops.aten.select.int(device_put_16, 0, 6) + _local_scalar_dense_134 = torch.ops.aten._local_scalar_dense.default(select_134); select_134 = None + ge_166 = _local_scalar_dense_134 >= 0 + _assert_scalar_134 = torch.ops.aten._assert_scalar.default(ge_166, "Runtime assertion failed for expression u134 >= 0 on node 'ge_134'"); ge_166 = _assert_scalar_134 = None + select_135 = torch.ops.aten.select.int(device_put_16, 0, 7); device_put_16 = None + _local_scalar_dense_135 = torch.ops.aten._local_scalar_dense.default(select_135); select_135 = None + ge_167 = _local_scalar_dense_135 >= 0 + _assert_scalar_135 = torch.ops.aten._assert_scalar.default(ge_167, "Runtime assertion failed for expression u135 >= 0 on node 'ge_135'"); ge_167 = _assert_scalar_135 = None + select_136 = torch.ops.aten.select.int(device_put_17, 0, 0) + _local_scalar_dense_136 = torch.ops.aten._local_scalar_dense.default(select_136); select_136 = None + ge_168 = _local_scalar_dense_136 >= 0 + _assert_scalar_136 = torch.ops.aten._assert_scalar.default(ge_168, "Runtime assertion failed for expression u136 >= 0 on node 'ge_136'"); ge_168 = _assert_scalar_136 = None + select_137 = torch.ops.aten.select.int(device_put_17, 0, 1) + _local_scalar_dense_137 = torch.ops.aten._local_scalar_dense.default(select_137); select_137 = None + ge_169 = _local_scalar_dense_137 >= 0 + _assert_scalar_137 = torch.ops.aten._assert_scalar.default(ge_169, "Runtime assertion failed for expression u137 >= 0 on node 'ge_137'"); ge_169 = _assert_scalar_137 = None + select_138 = torch.ops.aten.select.int(device_put_17, 0, 2) + _local_scalar_dense_138 = torch.ops.aten._local_scalar_dense.default(select_138); select_138 = None + ge_170 = _local_scalar_dense_138 >= 0 + _assert_scalar_138 = torch.ops.aten._assert_scalar.default(ge_170, "Runtime assertion failed for expression u138 >= 0 on node 'ge_138'"); ge_170 = _assert_scalar_138 = None + select_139 = torch.ops.aten.select.int(device_put_17, 0, 3) + _local_scalar_dense_139 = torch.ops.aten._local_scalar_dense.default(select_139); select_139 = None + ge_171 = _local_scalar_dense_139 >= 0 + _assert_scalar_139 = torch.ops.aten._assert_scalar.default(ge_171, "Runtime assertion failed for expression u139 >= 0 on node 'ge_139'"); ge_171 = _assert_scalar_139 = None + select_140 = torch.ops.aten.select.int(device_put_17, 0, 4) + _local_scalar_dense_140 = torch.ops.aten._local_scalar_dense.default(select_140); select_140 = None + ge_172 = _local_scalar_dense_140 >= 0 + _assert_scalar_140 = torch.ops.aten._assert_scalar.default(ge_172, "Runtime assertion failed for expression u140 >= 0 on node 'ge_140'"); ge_172 = _assert_scalar_140 = None + select_141 = torch.ops.aten.select.int(device_put_17, 0, 5) + _local_scalar_dense_141 = torch.ops.aten._local_scalar_dense.default(select_141); select_141 = None + ge_173 = _local_scalar_dense_141 >= 0 + _assert_scalar_141 = torch.ops.aten._assert_scalar.default(ge_173, "Runtime assertion failed for expression u141 >= 0 on node 'ge_141'"); ge_173 = _assert_scalar_141 = None + select_142 = torch.ops.aten.select.int(device_put_17, 0, 6) + _local_scalar_dense_142 = torch.ops.aten._local_scalar_dense.default(select_142); select_142 = None + ge_174 = _local_scalar_dense_142 >= 0 + _assert_scalar_142 = torch.ops.aten._assert_scalar.default(ge_174, "Runtime assertion failed for expression u142 >= 0 on node 'ge_142'"); ge_174 = _assert_scalar_142 = None + select_143 = torch.ops.aten.select.int(device_put_17, 0, 7); device_put_17 = None + _local_scalar_dense_143 = torch.ops.aten._local_scalar_dense.default(select_143); select_143 = None + ge_175 = _local_scalar_dense_143 >= 0 + _assert_scalar_143 = torch.ops.aten._assert_scalar.default(ge_175, "Runtime assertion failed for expression u143 >= 0 on node 'ge_143'"); ge_175 = _assert_scalar_143 = None + all_to_all_single_25 = torch.ops._c10d_functional.all_to_all_single.default(index_16, [_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143], [_local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135], '521'); index_16 = None + sym_size_int_32 = torch.ops.aten.sym_size.int(all_to_all_single_25, 0) + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_25); all_to_all_single_25 = None + sym_sum_16 = torch.sym_sum((_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143)) + add_562 = sym_sum_16 + 64; sym_sum_16 = None + add_563 = add_562 + 8; add_562 = None + sub_195 = add_563 - 1; add_563 = None + floordiv_8 = sub_195 // 8; sub_195 = None + mul_414 = floordiv_8 * 8; floordiv_8 = None + cumsum_24 = torch.ops.aten.cumsum.default(wait_tensor_188, 0) + sub_196 = torch.ops.aten.sub.Tensor(cumsum_24, wait_tensor_188); cumsum_24 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_601, [0]); view_601 = None + clamp_min_8 = torch.ops.aten.clamp_min.default(sum_36, 8); sum_36 = None + add_564 = torch.ops.aten.add.Tensor(clamp_min_8, 8); clamp_min_8 = None + sub_197 = torch.ops.aten.sub.Tensor(add_564, 1); add_564 = None + div_43 = torch.ops.aten.div.Tensor_mode(sub_197, 8, rounding_mode = 'floor'); sub_197 = None + mul_415 = torch.ops.aten.mul.Tensor(div_43, 8); div_43 = None + convert_element_type_500 = torch.ops.prims.convert_element_type.default(mul_415, torch.int32); mul_415 = None + cumsum_25 = torch.ops.aten.cumsum.default(convert_element_type_500, 0) + sub_198 = torch.ops.aten.sub.Tensor(cumsum_25, convert_element_type_500); cumsum_25 = None + full_124 = torch.ops.aten.full.default([mul_414], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_414 = None + triton_kernel_wrapper_functional_proxy_8 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 8, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_188, 'start_index_values_ptr': sub_196, 'write_offsets_ptr': sub_198, 'output_ptr': full_124}, tensors_to_clone = ['output_ptr']); wait_tensor_188 = sub_196 = sub_198 = full_124 = None + getitem_134 = triton_kernel_wrapper_functional_proxy_8['output_ptr']; triton_kernel_wrapper_functional_proxy_8 = None + cat_28 = torch.ops.aten.cat.default([wait_tensor_189, full_default]); wait_tensor_189 = None + sym_size_int_33 = torch.ops.aten.sym_size.int(cat_28, 0) + sym_sum_17 = torch.sym_sum((1, _local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143)) + index_17 = torch.ops.aten.index.Tensor(cat_28, [getitem_134]); cat_28 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 8, '513'); convert_element_type_502 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + convert_element_type_504 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_504, 8, '513'); convert_element_type_504 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 8, '513'); convert_element_type_505 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + cumsum_26 = torch.ops.aten.cumsum.default(convert_element_type_500, 0, dtype = torch.int32); convert_element_type_500 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_190, [0, 2, 1]); wait_tensor_190 = None + _grouped_mm_24 = torch.ops.aten._grouped_mm.default(index_17, permute_140, cumsum_26); permute_140 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(_grouped_mm_24, torch.float32) + neg_17 = torch.ops.aten.neg.default(convert_element_type_508) + exp_26 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_576 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + div_44 = torch.ops.aten.div.Tensor(convert_element_type_508, add_576); convert_element_type_508 = add_576 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(div_44, torch.bfloat16); div_44 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_193, [0, 2, 1]); wait_tensor_193 = None + _grouped_mm_25 = torch.ops.aten._grouped_mm.default(index_17, permute_141, cumsum_26); permute_141 = None + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_509, _grouped_mm_25); convert_element_type_509 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_192, [0, 2, 1]); wait_tensor_192 = None + _grouped_mm_26 = torch.ops.aten._grouped_mm.default(mul_427, permute_142, cumsum_26); permute_142 = None + empty_8 = torch.ops.aten.empty.memory_format([sym_size_int_33, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_16 = torch.ops.aten.index_put.default(empty_8, [getitem_134], _grouped_mm_26); empty_8 = _grouped_mm_26 = None + slice_38 = torch.ops.aten.slice.Tensor(index_put_16, 0, 0, -1); index_put_16 = None + all_to_all_single_26 = torch.ops._c10d_functional.all_to_all_single.default(slice_38, [_local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135], [_local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143], '521'); slice_38 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_26); all_to_all_single_26 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_510, 64, '0'); convert_element_type_510 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_76 = torch.ops.aten.mm.default(view_594, permute_143); permute_143 = None + convert_element_type_513 = torch.ops.prims.convert_element_type.default(mm_76, torch.float32) + neg_18 = torch.ops.aten.neg.default(convert_element_type_513) + exp_27 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_612 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + div_45 = torch.ops.aten.div.Tensor(convert_element_type_513, add_612); convert_element_type_513 = add_612 = None + convert_element_type_514 = torch.ops.prims.convert_element_type.default(div_45, torch.bfloat16); div_45 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 64, '0'); convert_element_type_515 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + mm_77 = torch.ops.aten.mm.default(view_594, permute_144); permute_144 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_514, mm_77); convert_element_type_514 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 64, '0'); convert_element_type_518 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + mm_78 = torch.ops.aten.mm.default(mul_447, permute_145); permute_145 = None + index_put_17 = torch.ops.aten.index_put.default(full_default_1, [getitem_133], wait_tensor_196); wait_tensor_196 = None + view_634 = torch.ops.aten.view.default(mul_409, [-1, 1, 6]); mul_409 = None + view_635 = torch.ops.aten.view.default(index_put_17, [-1, 6, 2048]); index_put_17 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_635, torch.float32); view_635 = None + bmm_8 = torch.ops.aten.bmm.default(view_634, convert_element_type_521) + convert_element_type_522 = torch.ops.prims.convert_element_type.default(bmm_8, torch.bfloat16); bmm_8 = None + squeeze_8 = torch.ops.aten.squeeze.dim(convert_element_type_522, 1); convert_element_type_522 = None + add_616 = torch.ops.aten.add.Tensor(mm_78, squeeze_8); mm_78 = squeeze_8 = None + view_636 = torch.ops.aten.view.default(add_616, [2, 4096, 2048]); add_616 = None + add_617 = torch.ops.aten.add.Tensor(add_552, view_636); view_636 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 64, '0'); convert_element_type_523 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + convert_element_type_524 = torch.ops.prims.convert_element_type.default(add_617, torch.float32) + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_524, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_618 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_618); add_618 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_524, rsqrt_30); convert_element_type_524 = None + mul_451 = torch.ops.aten.mul.Tensor(mul_450, wait_tensor_200); mul_450 = wait_tensor_200 = None + convert_element_type_525 = torch.ops.prims.convert_element_type.default(mul_451, torch.bfloat16); mul_451 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 64, '0'); convert_element_type_526 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_146 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + view_639 = torch.ops.aten.view.default(convert_element_type_525, [8192, 2048]); convert_element_type_525 = None + mm_79 = torch.ops.aten.mm.default(view_639, permute_146); permute_146 = None + view_640 = torch.ops.aten.view.default(mm_79, [2, 4096, 3072]); mm_79 = None + view_641 = torch.ops.aten.view.default(view_640, [2, 4096, -1, 192]); view_640 = None + split_with_sizes_30 = torch.ops.aten.split_with_sizes.default(view_641, [128, 64], -1); view_641 = None + getitem_135 = split_with_sizes_30[0] + getitem_136 = split_with_sizes_30[1]; split_with_sizes_30 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(getitem_136, torch.float32); getitem_136 = None + view_642 = torch.ops.aten.view.default(convert_element_type_529, [2, 4096, 16, -1, 2]); convert_element_type_529 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_642); view_642 = None + mul_452 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_7); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_452); mul_452 = None + view_644 = torch.ops.aten.view.default(view_as_real_20, [2, 4096, 16, 64]); view_as_real_20 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(view_644, torch.bfloat16); view_644 = None + cat_29 = torch.ops.aten.cat.default([getitem_135, convert_element_type_530], -1); getitem_135 = convert_element_type_530 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 64, '0'); convert_element_type_531 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_147 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + mm_80 = torch.ops.aten.mm.default(view_639, permute_147); permute_147 = None + view_647 = torch.ops.aten.view.default(mm_80, [2, 4096, 576]); mm_80 = None + split_with_sizes_31 = torch.ops.aten.split_with_sizes.default(view_647, [512, 64], -1); view_647 = None + getitem_137 = split_with_sizes_31[0] + getitem_138 = split_with_sizes_31[1]; split_with_sizes_31 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(getitem_138, 2); getitem_138 = None + convert_element_type_534 = torch.ops.prims.convert_element_type.default(unsqueeze_19, torch.float32); unsqueeze_19 = None + view_648 = torch.ops.aten.view.default(convert_element_type_534, [2, 4096, 1, -1, 2]); convert_element_type_534 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_648); view_648 = None + mul_453 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_7); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_453); mul_453 = None + view_650 = torch.ops.aten.view.default(view_as_real_21, [2, 4096, 1, 64]); view_as_real_21 = None + convert_element_type_535 = torch.ops.prims.convert_element_type.default(view_650, torch.bfloat16); view_650 = None + convert_element_type_536 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_536, 64, '0'); convert_element_type_536 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + convert_element_type_537 = torch.ops.prims.convert_element_type.default(getitem_137, torch.float32) + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_537, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_619 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_619); add_619 = None + mul_454 = torch.ops.aten.mul.Tensor(convert_element_type_537, rsqrt_31); convert_element_type_537 = None + mul_455 = torch.ops.aten.mul.Tensor(mul_454, wait_tensor_203); mul_454 = wait_tensor_203 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(mul_455, torch.bfloat16); mul_455 = None + convert_element_type_539 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_539, 64, '0'); convert_element_type_539 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_148 = torch.ops.aten.permute.default(wait_tensor_204, [1, 0]); wait_tensor_204 = None + view_653 = torch.ops.aten.view.default(convert_element_type_538, [8192, 512]); convert_element_type_538 = None + mm_81 = torch.ops.aten.mm.default(view_653, permute_148); permute_148 = None + view_654 = torch.ops.aten.view.default(mm_81, [2, 4096, 4096]); mm_81 = None + view_655 = torch.ops.aten.view.default(view_654, [2, 4096, -1, 256]); view_654 = None + split_with_sizes_32 = torch.ops.aten.split_with_sizes.default(view_655, [128, 128], -1); view_655 = None + getitem_139 = split_with_sizes_32[0] + getitem_140 = split_with_sizes_32[1]; split_with_sizes_32 = None + expand_10 = torch.ops.aten.expand.default(convert_element_type_535, [-1, -1, 16, -1]); convert_element_type_535 = None + cat_30 = torch.ops.aten.cat.default([getitem_139, expand_10], -1); getitem_139 = expand_10 = None + permute_149 = torch.ops.aten.permute.default(cat_29, [0, 2, 1, 3]); cat_29 = None + permute_150 = torch.ops.aten.permute.default(cat_30, [0, 2, 1, 3]); cat_30 = None + permute_151 = torch.ops.aten.permute.default(getitem_140, [0, 2, 1, 3]); getitem_140 = None + sdpa_score10 = self.sdpa_score10 + sdpa_mask10 = self.sdpa_mask10 + flex_attention_10 = torch.ops.higher_order.flex_attention(permute_149, permute_150, permute_151, sdpa_score10, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask10), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score10 = sdpa_mask10 = None + getitem_141 = flex_attention_10[0] + getitem_142 = flex_attention_10[1]; flex_attention_10 = None + permute_152 = torch.ops.aten.permute.default(getitem_141, [0, 2, 1, 3]) + view_656 = torch.ops.aten.view.default(permute_152, [2, 4096, -1]); permute_152 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_542, 64, '0'); convert_element_type_542 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + view_658 = torch.ops.aten.view.default(view_656, [8192, 2048]); view_656 = None + mm_82 = torch.ops.aten.mm.default(view_658, permute_153); view_658 = permute_153 = None + view_659 = torch.ops.aten.view.default(mm_82, [2, 4096, 2048]); mm_82 = None + add_620 = torch.ops.aten.add.Tensor(add_617, view_659); view_659 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 64, '0'); convert_element_type_545 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + convert_element_type_546 = torch.ops.prims.convert_element_type.default(add_620, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_546, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_621 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_621); add_621 = None + mul_456 = torch.ops.aten.mul.Tensor(convert_element_type_546, rsqrt_32); convert_element_type_546 = None + mul_457 = torch.ops.aten.mul.Tensor(mul_456, wait_tensor_206); mul_456 = wait_tensor_206 = None + convert_element_type_547 = torch.ops.prims.convert_element_type.default(mul_457, torch.bfloat16); mul_457 = None + view_661 = torch.ops.aten.view.default(convert_element_type_547, [-1, 2048]); convert_element_type_547 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 64, '0'); convert_element_type_548 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + mm_83 = torch.ops.aten.mm.default(view_661, permute_154); permute_154 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(mm_83, torch.float32) + amax_9 = torch.ops.aten.amax.default(convert_element_type_551, [1], True) + sub_216 = torch.ops.aten.sub.Tensor(convert_element_type_551, amax_9); convert_element_type_551 = None + exp_28 = torch.ops.aten.exp.default(sub_216); sub_216 = None + sum_37 = torch.ops.aten.sum.dim_IntList(exp_28, [1], True) + div_46 = torch.ops.aten.div.Tensor(exp_28, sum_37); exp_28 = None + add_622 = torch.ops.aten.add.Tensor(div_46, primals_174); primals_174 = None + topk_9 = torch.ops.aten.topk.default(add_622, 6, -1, True, False); add_622 = None + getitem_145 = topk_9[1]; topk_9 = None + gather_9 = torch.ops.aten.gather.default(div_46, 1, getitem_145); div_46 = None + mul_458 = torch.ops.aten.mul.Tensor(gather_9, 1.0); gather_9 = None + view_663 = torch.ops.aten.view.default(getitem_145, [-1]) + histc_18 = torch.ops.aten.histc.default(view_663, 64, 0, 64) + add_623 = torch.ops.aten.add.Tensor(primals_176, histc_18) + sort_9 = torch.ops.aten.sort.stable(view_663, stable = True); view_663 = None + getitem_147 = sort_9[1]; sort_9 = None + div_47 = torch.ops.aten.div.Tensor_mode(getitem_147, 6, rounding_mode = 'floor') + index_18 = torch.ops.aten.index.Tensor(view_661, [div_47]) + all_to_all_single_27 = torch.ops._c10d_functional.all_to_all_single.default(histc_18, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_27); all_to_all_single_27 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_208); wait_tensor_208 = None + view_667 = torch.ops.aten.view.default(histc_18, [8, -1]); histc_18 = None + sum_38 = torch.ops.aten.sum.dim_IntList(view_667, [1]); view_667 = None + device_put_18 = torch.ops.prims.device_put.default(sum_38, device(type='cpu'), True); sum_38 = None + view_668 = torch.ops.aten.view.default(wait_tensor_209, [8, -1]) + sum_39 = torch.ops.aten.sum.dim_IntList(view_668, [1]) + device_put_19 = torch.ops.prims.device_put.default(sum_39, device(type='cpu')); sum_39 = None + select_144 = torch.ops.aten.select.int(device_put_18, 0, 0) + _local_scalar_dense_144 = torch.ops.aten._local_scalar_dense.default(select_144); select_144 = None + ge_180 = _local_scalar_dense_144 >= 0 + _assert_scalar_144 = torch.ops.aten._assert_scalar.default(ge_180, "Runtime assertion failed for expression u144 >= 0 on node 'ge_144'"); ge_180 = _assert_scalar_144 = None + select_145 = torch.ops.aten.select.int(device_put_18, 0, 1) + _local_scalar_dense_145 = torch.ops.aten._local_scalar_dense.default(select_145); select_145 = None + ge_181 = _local_scalar_dense_145 >= 0 + _assert_scalar_145 = torch.ops.aten._assert_scalar.default(ge_181, "Runtime assertion failed for expression u145 >= 0 on node 'ge_145'"); ge_181 = _assert_scalar_145 = None + select_146 = torch.ops.aten.select.int(device_put_18, 0, 2) + _local_scalar_dense_146 = torch.ops.aten._local_scalar_dense.default(select_146); select_146 = None + ge_182 = _local_scalar_dense_146 >= 0 + _assert_scalar_146 = torch.ops.aten._assert_scalar.default(ge_182, "Runtime assertion failed for expression u146 >= 0 on node 'ge_146'"); ge_182 = _assert_scalar_146 = None + select_147 = torch.ops.aten.select.int(device_put_18, 0, 3) + _local_scalar_dense_147 = torch.ops.aten._local_scalar_dense.default(select_147); select_147 = None + ge_183 = _local_scalar_dense_147 >= 0 + _assert_scalar_147 = torch.ops.aten._assert_scalar.default(ge_183, "Runtime assertion failed for expression u147 >= 0 on node 'ge_147'"); ge_183 = _assert_scalar_147 = None + select_148 = torch.ops.aten.select.int(device_put_18, 0, 4) + _local_scalar_dense_148 = torch.ops.aten._local_scalar_dense.default(select_148); select_148 = None + ge_184 = _local_scalar_dense_148 >= 0 + _assert_scalar_148 = torch.ops.aten._assert_scalar.default(ge_184, "Runtime assertion failed for expression u148 >= 0 on node 'ge_148'"); ge_184 = _assert_scalar_148 = None + select_149 = torch.ops.aten.select.int(device_put_18, 0, 5) + _local_scalar_dense_149 = torch.ops.aten._local_scalar_dense.default(select_149); select_149 = None + ge_185 = _local_scalar_dense_149 >= 0 + _assert_scalar_149 = torch.ops.aten._assert_scalar.default(ge_185, "Runtime assertion failed for expression u149 >= 0 on node 'ge_149'"); ge_185 = _assert_scalar_149 = None + select_150 = torch.ops.aten.select.int(device_put_18, 0, 6) + _local_scalar_dense_150 = torch.ops.aten._local_scalar_dense.default(select_150); select_150 = None + ge_186 = _local_scalar_dense_150 >= 0 + _assert_scalar_150 = torch.ops.aten._assert_scalar.default(ge_186, "Runtime assertion failed for expression u150 >= 0 on node 'ge_150'"); ge_186 = _assert_scalar_150 = None + select_151 = torch.ops.aten.select.int(device_put_18, 0, 7); device_put_18 = None + _local_scalar_dense_151 = torch.ops.aten._local_scalar_dense.default(select_151); select_151 = None + ge_187 = _local_scalar_dense_151 >= 0 + _assert_scalar_151 = torch.ops.aten._assert_scalar.default(ge_187, "Runtime assertion failed for expression u151 >= 0 on node 'ge_151'"); ge_187 = _assert_scalar_151 = None + select_152 = torch.ops.aten.select.int(device_put_19, 0, 0) + _local_scalar_dense_152 = torch.ops.aten._local_scalar_dense.default(select_152); select_152 = None + ge_188 = _local_scalar_dense_152 >= 0 + _assert_scalar_152 = torch.ops.aten._assert_scalar.default(ge_188, "Runtime assertion failed for expression u152 >= 0 on node 'ge_152'"); ge_188 = _assert_scalar_152 = None + select_153 = torch.ops.aten.select.int(device_put_19, 0, 1) + _local_scalar_dense_153 = torch.ops.aten._local_scalar_dense.default(select_153); select_153 = None + ge_189 = _local_scalar_dense_153 >= 0 + _assert_scalar_153 = torch.ops.aten._assert_scalar.default(ge_189, "Runtime assertion failed for expression u153 >= 0 on node 'ge_153'"); ge_189 = _assert_scalar_153 = None + select_154 = torch.ops.aten.select.int(device_put_19, 0, 2) + _local_scalar_dense_154 = torch.ops.aten._local_scalar_dense.default(select_154); select_154 = None + ge_190 = _local_scalar_dense_154 >= 0 + _assert_scalar_154 = torch.ops.aten._assert_scalar.default(ge_190, "Runtime assertion failed for expression u154 >= 0 on node 'ge_154'"); ge_190 = _assert_scalar_154 = None + select_155 = torch.ops.aten.select.int(device_put_19, 0, 3) + _local_scalar_dense_155 = torch.ops.aten._local_scalar_dense.default(select_155); select_155 = None + ge_191 = _local_scalar_dense_155 >= 0 + _assert_scalar_155 = torch.ops.aten._assert_scalar.default(ge_191, "Runtime assertion failed for expression u155 >= 0 on node 'ge_155'"); ge_191 = _assert_scalar_155 = None + select_156 = torch.ops.aten.select.int(device_put_19, 0, 4) + _local_scalar_dense_156 = torch.ops.aten._local_scalar_dense.default(select_156); select_156 = None + ge_192 = _local_scalar_dense_156 >= 0 + _assert_scalar_156 = torch.ops.aten._assert_scalar.default(ge_192, "Runtime assertion failed for expression u156 >= 0 on node 'ge_156'"); ge_192 = _assert_scalar_156 = None + select_157 = torch.ops.aten.select.int(device_put_19, 0, 5) + _local_scalar_dense_157 = torch.ops.aten._local_scalar_dense.default(select_157); select_157 = None + ge_193 = _local_scalar_dense_157 >= 0 + _assert_scalar_157 = torch.ops.aten._assert_scalar.default(ge_193, "Runtime assertion failed for expression u157 >= 0 on node 'ge_157'"); ge_193 = _assert_scalar_157 = None + select_158 = torch.ops.aten.select.int(device_put_19, 0, 6) + _local_scalar_dense_158 = torch.ops.aten._local_scalar_dense.default(select_158); select_158 = None + ge_194 = _local_scalar_dense_158 >= 0 + _assert_scalar_158 = torch.ops.aten._assert_scalar.default(ge_194, "Runtime assertion failed for expression u158 >= 0 on node 'ge_158'"); ge_194 = _assert_scalar_158 = None + select_159 = torch.ops.aten.select.int(device_put_19, 0, 7); device_put_19 = None + _local_scalar_dense_159 = torch.ops.aten._local_scalar_dense.default(select_159); select_159 = None + ge_195 = _local_scalar_dense_159 >= 0 + _assert_scalar_159 = torch.ops.aten._assert_scalar.default(ge_195, "Runtime assertion failed for expression u159 >= 0 on node 'ge_159'"); ge_195 = _assert_scalar_159 = None + all_to_all_single_28 = torch.ops._c10d_functional.all_to_all_single.default(index_18, [_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159], [_local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151], '521'); index_18 = None + sym_size_int_36 = torch.ops.aten.sym_size.int(all_to_all_single_28, 0) + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_28); all_to_all_single_28 = None + sym_sum_18 = torch.sym_sum((_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159)) + add_630 = sym_sum_18 + 64; sym_sum_18 = None + add_631 = add_630 + 8; add_630 = None + sub_219 = add_631 - 1; add_631 = None + floordiv_9 = sub_219 // 8; sub_219 = None + mul_463 = floordiv_9 * 8; floordiv_9 = None + cumsum_27 = torch.ops.aten.cumsum.default(wait_tensor_209, 0) + sub_220 = torch.ops.aten.sub.Tensor(cumsum_27, wait_tensor_209); cumsum_27 = None + sum_40 = torch.ops.aten.sum.dim_IntList(view_668, [0]); view_668 = None + clamp_min_9 = torch.ops.aten.clamp_min.default(sum_40, 8); sum_40 = None + add_632 = torch.ops.aten.add.Tensor(clamp_min_9, 8); clamp_min_9 = None + sub_221 = torch.ops.aten.sub.Tensor(add_632, 1); add_632 = None + div_48 = torch.ops.aten.div.Tensor_mode(sub_221, 8, rounding_mode = 'floor'); sub_221 = None + mul_464 = torch.ops.aten.mul.Tensor(div_48, 8); div_48 = None + convert_element_type_554 = torch.ops.prims.convert_element_type.default(mul_464, torch.int32); mul_464 = None + cumsum_28 = torch.ops.aten.cumsum.default(convert_element_type_554, 0) + sub_222 = torch.ops.aten.sub.Tensor(cumsum_28, convert_element_type_554); cumsum_28 = None + full_137 = torch.ops.aten.full.default([mul_463], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_463 = None + triton_kernel_wrapper_functional_proxy_9 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 9, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_209, 'start_index_values_ptr': sub_220, 'write_offsets_ptr': sub_222, 'output_ptr': full_137}, tensors_to_clone = ['output_ptr']); wait_tensor_209 = sub_220 = sub_222 = full_137 = None + getitem_148 = triton_kernel_wrapper_functional_proxy_9['output_ptr']; triton_kernel_wrapper_functional_proxy_9 = None + cat_31 = torch.ops.aten.cat.default([wait_tensor_210, full_default]); wait_tensor_210 = None + sym_size_int_37 = torch.ops.aten.sym_size.int(cat_31, 0) + sym_sum_19 = torch.sym_sum((1, _local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159)) + index_19 = torch.ops.aten.index.Tensor(cat_31, [getitem_148]); cat_31 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 8, '513'); convert_element_type_556 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_558 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_558, 8, '513'); convert_element_type_558 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 8, '513'); convert_element_type_559 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + cumsum_29 = torch.ops.aten.cumsum.default(convert_element_type_554, 0, dtype = torch.int32); convert_element_type_554 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_211, [0, 2, 1]); wait_tensor_211 = None + _grouped_mm_27 = torch.ops.aten._grouped_mm.default(index_19, permute_155, cumsum_29); permute_155 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(_grouped_mm_27, torch.float32) + neg_19 = torch.ops.aten.neg.default(convert_element_type_562) + exp_29 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_644 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + div_49 = torch.ops.aten.div.Tensor(convert_element_type_562, add_644); convert_element_type_562 = add_644 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(div_49, torch.bfloat16); div_49 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_214, [0, 2, 1]); wait_tensor_214 = None + _grouped_mm_28 = torch.ops.aten._grouped_mm.default(index_19, permute_156, cumsum_29); permute_156 = None + mul_476 = torch.ops.aten.mul.Tensor(convert_element_type_563, _grouped_mm_28); convert_element_type_563 = None + permute_157 = torch.ops.aten.permute.default(wait_tensor_213, [0, 2, 1]); wait_tensor_213 = None + _grouped_mm_29 = torch.ops.aten._grouped_mm.default(mul_476, permute_157, cumsum_29); permute_157 = None + empty_9 = torch.ops.aten.empty.memory_format([sym_size_int_37, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_18 = torch.ops.aten.index_put.default(empty_9, [getitem_148], _grouped_mm_29); empty_9 = _grouped_mm_29 = None + slice_42 = torch.ops.aten.slice.Tensor(index_put_18, 0, 0, -1); index_put_18 = None + all_to_all_single_29 = torch.ops._c10d_functional.all_to_all_single.default(slice_42, [_local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151], [_local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159], '521'); slice_42 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_29); all_to_all_single_29 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 64, '0'); convert_element_type_564 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_158 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + mm_84 = torch.ops.aten.mm.default(view_661, permute_158); permute_158 = None + convert_element_type_567 = torch.ops.prims.convert_element_type.default(mm_84, torch.float32) + neg_20 = torch.ops.aten.neg.default(convert_element_type_567) + exp_30 = torch.ops.aten.exp.default(neg_20); neg_20 = None + add_680 = torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + div_50 = torch.ops.aten.div.Tensor(convert_element_type_567, add_680); convert_element_type_567 = add_680 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(div_50, torch.bfloat16); div_50 = None + convert_element_type_569 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_569, 64, '0'); convert_element_type_569 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_159 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_85 = torch.ops.aten.mm.default(view_661, permute_159); permute_159 = None + mul_496 = torch.ops.aten.mul.Tensor(convert_element_type_568, mm_85); convert_element_type_568 = None + convert_element_type_572 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_572, 64, '0'); convert_element_type_572 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_160 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_86 = torch.ops.aten.mm.default(mul_496, permute_160); permute_160 = None + index_put_19 = torch.ops.aten.index_put.default(full_default_1, [getitem_147], wait_tensor_217); wait_tensor_217 = None + view_701 = torch.ops.aten.view.default(mul_458, [-1, 1, 6]); mul_458 = None + view_702 = torch.ops.aten.view.default(index_put_19, [-1, 6, 2048]); index_put_19 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_702, torch.float32); view_702 = None + bmm_9 = torch.ops.aten.bmm.default(view_701, convert_element_type_575) + convert_element_type_576 = torch.ops.prims.convert_element_type.default(bmm_9, torch.bfloat16); bmm_9 = None + squeeze_9 = torch.ops.aten.squeeze.dim(convert_element_type_576, 1); convert_element_type_576 = None + add_684 = torch.ops.aten.add.Tensor(mm_86, squeeze_9); mm_86 = squeeze_9 = None + view_703 = torch.ops.aten.view.default(add_684, [2, 4096, 2048]); add_684 = None + add_685 = torch.ops.aten.add.Tensor(add_620, view_703); view_703 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_577, 64, '0'); convert_element_type_577 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(add_685, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_578, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_686 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_686); add_686 = None + mul_499 = torch.ops.aten.mul.Tensor(convert_element_type_578, rsqrt_33); convert_element_type_578 = None + mul_500 = torch.ops.aten.mul.Tensor(mul_499, wait_tensor_221); mul_499 = wait_tensor_221 = None + convert_element_type_579 = torch.ops.prims.convert_element_type.default(mul_500, torch.bfloat16); mul_500 = None + convert_element_type_580 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_580, 64, '0'); convert_element_type_580 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_222, [1, 0]); wait_tensor_222 = None + view_706 = torch.ops.aten.view.default(convert_element_type_579, [8192, 2048]); convert_element_type_579 = None + mm_87 = torch.ops.aten.mm.default(view_706, permute_161); permute_161 = None + view_707 = torch.ops.aten.view.default(mm_87, [2, 4096, 3072]); mm_87 = None + view_708 = torch.ops.aten.view.default(view_707, [2, 4096, -1, 192]); view_707 = None + split_with_sizes_33 = torch.ops.aten.split_with_sizes.default(view_708, [128, 64], -1); view_708 = None + getitem_149 = split_with_sizes_33[0] + getitem_150 = split_with_sizes_33[1]; split_with_sizes_33 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(getitem_150, torch.float32); getitem_150 = None + view_709 = torch.ops.aten.view.default(convert_element_type_583, [2, 4096, 16, -1, 2]); convert_element_type_583 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_709); view_709 = None + mul_501 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_7); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_501); mul_501 = None + view_711 = torch.ops.aten.view.default(view_as_real_22, [2, 4096, 16, 64]); view_as_real_22 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(view_711, torch.bfloat16); view_711 = None + cat_32 = torch.ops.aten.cat.default([getitem_149, convert_element_type_584], -1); getitem_149 = convert_element_type_584 = None + convert_element_type_585 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_585, 64, '0'); convert_element_type_585 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + mm_88 = torch.ops.aten.mm.default(view_706, permute_162); permute_162 = None + view_714 = torch.ops.aten.view.default(mm_88, [2, 4096, 576]); mm_88 = None + split_with_sizes_34 = torch.ops.aten.split_with_sizes.default(view_714, [512, 64], -1); view_714 = None + getitem_151 = split_with_sizes_34[0] + getitem_152 = split_with_sizes_34[1]; split_with_sizes_34 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(getitem_152, 2); getitem_152 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(unsqueeze_21, torch.float32); unsqueeze_21 = None + view_715 = torch.ops.aten.view.default(convert_element_type_588, [2, 4096, 1, -1, 2]); convert_element_type_588 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_715); view_715 = None + mul_502 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_7); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_502); mul_502 = None + view_717 = torch.ops.aten.view.default(view_as_real_23, [2, 4096, 1, 64]); view_as_real_23 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(view_717, torch.bfloat16); view_717 = None + convert_element_type_590 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_590, 64, '0'); convert_element_type_590 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + convert_element_type_591 = torch.ops.prims.convert_element_type.default(getitem_151, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_591, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_687 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_687); add_687 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_591, rsqrt_34); convert_element_type_591 = None + mul_504 = torch.ops.aten.mul.Tensor(mul_503, wait_tensor_224); mul_503 = wait_tensor_224 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(mul_504, torch.bfloat16); mul_504 = None + convert_element_type_593 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_593, 64, '0'); convert_element_type_593 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_720 = torch.ops.aten.view.default(convert_element_type_592, [8192, 512]); convert_element_type_592 = None + mm_89 = torch.ops.aten.mm.default(view_720, permute_163); permute_163 = None + view_721 = torch.ops.aten.view.default(mm_89, [2, 4096, 4096]); mm_89 = None + view_722 = torch.ops.aten.view.default(view_721, [2, 4096, -1, 256]); view_721 = None + split_with_sizes_35 = torch.ops.aten.split_with_sizes.default(view_722, [128, 128], -1); view_722 = None + getitem_153 = split_with_sizes_35[0] + getitem_154 = split_with_sizes_35[1]; split_with_sizes_35 = None + expand_11 = torch.ops.aten.expand.default(convert_element_type_589, [-1, -1, 16, -1]); convert_element_type_589 = None + cat_33 = torch.ops.aten.cat.default([getitem_153, expand_11], -1); getitem_153 = expand_11 = None + permute_164 = torch.ops.aten.permute.default(cat_32, [0, 2, 1, 3]); cat_32 = None + permute_165 = torch.ops.aten.permute.default(cat_33, [0, 2, 1, 3]); cat_33 = None + permute_166 = torch.ops.aten.permute.default(getitem_154, [0, 2, 1, 3]); getitem_154 = None + sdpa_score11 = self.sdpa_score11 + sdpa_mask11 = self.sdpa_mask11 + flex_attention_11 = torch.ops.higher_order.flex_attention(permute_164, permute_165, permute_166, sdpa_score11, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask11), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score11 = sdpa_mask11 = None + getitem_155 = flex_attention_11[0] + getitem_156 = flex_attention_11[1]; flex_attention_11 = None + permute_167 = torch.ops.aten.permute.default(getitem_155, [0, 2, 1, 3]) + view_723 = torch.ops.aten.view.default(permute_167, [2, 4096, -1]); permute_167 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_596, 64, '0'); convert_element_type_596 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_168 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + view_725 = torch.ops.aten.view.default(view_723, [8192, 2048]); view_723 = None + mm_90 = torch.ops.aten.mm.default(view_725, permute_168); view_725 = permute_168 = None + view_726 = torch.ops.aten.view.default(mm_90, [2, 4096, 2048]); mm_90 = None + add_688 = torch.ops.aten.add.Tensor(add_685, view_726); view_726 = None + convert_element_type_599 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_599, 64, '0'); convert_element_type_599 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + convert_element_type_600 = torch.ops.prims.convert_element_type.default(add_688, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_600, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_689 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_689); add_689 = None + mul_505 = torch.ops.aten.mul.Tensor(convert_element_type_600, rsqrt_35); convert_element_type_600 = None + mul_506 = torch.ops.aten.mul.Tensor(mul_505, wait_tensor_227); mul_505 = wait_tensor_227 = None + convert_element_type_601 = torch.ops.prims.convert_element_type.default(mul_506, torch.bfloat16); mul_506 = None + view_728 = torch.ops.aten.view.default(convert_element_type_601, [-1, 2048]); convert_element_type_601 = None + convert_element_type_602 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_602, 64, '0'); convert_element_type_602 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_169 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + mm_91 = torch.ops.aten.mm.default(view_728, permute_169); permute_169 = None + convert_element_type_605 = torch.ops.prims.convert_element_type.default(mm_91, torch.float32) + amax_10 = torch.ops.aten.amax.default(convert_element_type_605, [1], True) + sub_240 = torch.ops.aten.sub.Tensor(convert_element_type_605, amax_10); convert_element_type_605 = None + exp_31 = torch.ops.aten.exp.default(sub_240); sub_240 = None + sum_41 = torch.ops.aten.sum.dim_IntList(exp_31, [1], True) + div_51 = torch.ops.aten.div.Tensor(exp_31, sum_41); exp_31 = None + add_690 = torch.ops.aten.add.Tensor(div_51, primals_190); primals_190 = None + topk_10 = torch.ops.aten.topk.default(add_690, 6, -1, True, False); add_690 = None + getitem_159 = topk_10[1]; topk_10 = None + gather_10 = torch.ops.aten.gather.default(div_51, 1, getitem_159); div_51 = None + mul_507 = torch.ops.aten.mul.Tensor(gather_10, 1.0); gather_10 = None + view_730 = torch.ops.aten.view.default(getitem_159, [-1]) + histc_20 = torch.ops.aten.histc.default(view_730, 64, 0, 64) + add_691 = torch.ops.aten.add.Tensor(primals_192, histc_20) + sort_10 = torch.ops.aten.sort.stable(view_730, stable = True); view_730 = None + getitem_161 = sort_10[1]; sort_10 = None + div_52 = torch.ops.aten.div.Tensor_mode(getitem_161, 6, rounding_mode = 'floor') + index_20 = torch.ops.aten.index.Tensor(view_728, [div_52]) + all_to_all_single_30 = torch.ops._c10d_functional.all_to_all_single.default(histc_20, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_30); all_to_all_single_30 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_229); wait_tensor_229 = None + view_734 = torch.ops.aten.view.default(histc_20, [8, -1]); histc_20 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_734, [1]); view_734 = None + device_put_20 = torch.ops.prims.device_put.default(sum_42, device(type='cpu'), True); sum_42 = None + view_735 = torch.ops.aten.view.default(wait_tensor_230, [8, -1]) + sum_43 = torch.ops.aten.sum.dim_IntList(view_735, [1]) + device_put_21 = torch.ops.prims.device_put.default(sum_43, device(type='cpu')); sum_43 = None + select_160 = torch.ops.aten.select.int(device_put_20, 0, 0) + _local_scalar_dense_160 = torch.ops.aten._local_scalar_dense.default(select_160); select_160 = None + ge_200 = _local_scalar_dense_160 >= 0 + _assert_scalar_160 = torch.ops.aten._assert_scalar.default(ge_200, "Runtime assertion failed for expression u160 >= 0 on node 'ge_160'"); ge_200 = _assert_scalar_160 = None + select_161 = torch.ops.aten.select.int(device_put_20, 0, 1) + _local_scalar_dense_161 = torch.ops.aten._local_scalar_dense.default(select_161); select_161 = None + ge_201 = _local_scalar_dense_161 >= 0 + _assert_scalar_161 = torch.ops.aten._assert_scalar.default(ge_201, "Runtime assertion failed for expression u161 >= 0 on node 'ge_161'"); ge_201 = _assert_scalar_161 = None + select_162 = torch.ops.aten.select.int(device_put_20, 0, 2) + _local_scalar_dense_162 = torch.ops.aten._local_scalar_dense.default(select_162); select_162 = None + ge_202 = _local_scalar_dense_162 >= 0 + _assert_scalar_162 = torch.ops.aten._assert_scalar.default(ge_202, "Runtime assertion failed for expression u162 >= 0 on node 'ge_162'"); ge_202 = _assert_scalar_162 = None + select_163 = torch.ops.aten.select.int(device_put_20, 0, 3) + _local_scalar_dense_163 = torch.ops.aten._local_scalar_dense.default(select_163); select_163 = None + ge_203 = _local_scalar_dense_163 >= 0 + _assert_scalar_163 = torch.ops.aten._assert_scalar.default(ge_203, "Runtime assertion failed for expression u163 >= 0 on node 'ge_163'"); ge_203 = _assert_scalar_163 = None + select_164 = torch.ops.aten.select.int(device_put_20, 0, 4) + _local_scalar_dense_164 = torch.ops.aten._local_scalar_dense.default(select_164); select_164 = None + ge_204 = _local_scalar_dense_164 >= 0 + _assert_scalar_164 = torch.ops.aten._assert_scalar.default(ge_204, "Runtime assertion failed for expression u164 >= 0 on node 'ge_164'"); ge_204 = _assert_scalar_164 = None + select_165 = torch.ops.aten.select.int(device_put_20, 0, 5) + _local_scalar_dense_165 = torch.ops.aten._local_scalar_dense.default(select_165); select_165 = None + ge_205 = _local_scalar_dense_165 >= 0 + _assert_scalar_165 = torch.ops.aten._assert_scalar.default(ge_205, "Runtime assertion failed for expression u165 >= 0 on node 'ge_165'"); ge_205 = _assert_scalar_165 = None + select_166 = torch.ops.aten.select.int(device_put_20, 0, 6) + _local_scalar_dense_166 = torch.ops.aten._local_scalar_dense.default(select_166); select_166 = None + ge_206 = _local_scalar_dense_166 >= 0 + _assert_scalar_166 = torch.ops.aten._assert_scalar.default(ge_206, "Runtime assertion failed for expression u166 >= 0 on node 'ge_166'"); ge_206 = _assert_scalar_166 = None + select_167 = torch.ops.aten.select.int(device_put_20, 0, 7); device_put_20 = None + _local_scalar_dense_167 = torch.ops.aten._local_scalar_dense.default(select_167); select_167 = None + ge_207 = _local_scalar_dense_167 >= 0 + _assert_scalar_167 = torch.ops.aten._assert_scalar.default(ge_207, "Runtime assertion failed for expression u167 >= 0 on node 'ge_167'"); ge_207 = _assert_scalar_167 = None + select_168 = torch.ops.aten.select.int(device_put_21, 0, 0) + _local_scalar_dense_168 = torch.ops.aten._local_scalar_dense.default(select_168); select_168 = None + ge_208 = _local_scalar_dense_168 >= 0 + _assert_scalar_168 = torch.ops.aten._assert_scalar.default(ge_208, "Runtime assertion failed for expression u168 >= 0 on node 'ge_168'"); ge_208 = _assert_scalar_168 = None + select_169 = torch.ops.aten.select.int(device_put_21, 0, 1) + _local_scalar_dense_169 = torch.ops.aten._local_scalar_dense.default(select_169); select_169 = None + ge_209 = _local_scalar_dense_169 >= 0 + _assert_scalar_169 = torch.ops.aten._assert_scalar.default(ge_209, "Runtime assertion failed for expression u169 >= 0 on node 'ge_169'"); ge_209 = _assert_scalar_169 = None + select_170 = torch.ops.aten.select.int(device_put_21, 0, 2) + _local_scalar_dense_170 = torch.ops.aten._local_scalar_dense.default(select_170); select_170 = None + ge_210 = _local_scalar_dense_170 >= 0 + _assert_scalar_170 = torch.ops.aten._assert_scalar.default(ge_210, "Runtime assertion failed for expression u170 >= 0 on node 'ge_170'"); ge_210 = _assert_scalar_170 = None + select_171 = torch.ops.aten.select.int(device_put_21, 0, 3) + _local_scalar_dense_171 = torch.ops.aten._local_scalar_dense.default(select_171); select_171 = None + ge_211 = _local_scalar_dense_171 >= 0 + _assert_scalar_171 = torch.ops.aten._assert_scalar.default(ge_211, "Runtime assertion failed for expression u171 >= 0 on node 'ge_171'"); ge_211 = _assert_scalar_171 = None + select_172 = torch.ops.aten.select.int(device_put_21, 0, 4) + _local_scalar_dense_172 = torch.ops.aten._local_scalar_dense.default(select_172); select_172 = None + ge_212 = _local_scalar_dense_172 >= 0 + _assert_scalar_172 = torch.ops.aten._assert_scalar.default(ge_212, "Runtime assertion failed for expression u172 >= 0 on node 'ge_172'"); ge_212 = _assert_scalar_172 = None + select_173 = torch.ops.aten.select.int(device_put_21, 0, 5) + _local_scalar_dense_173 = torch.ops.aten._local_scalar_dense.default(select_173); select_173 = None + ge_213 = _local_scalar_dense_173 >= 0 + _assert_scalar_173 = torch.ops.aten._assert_scalar.default(ge_213, "Runtime assertion failed for expression u173 >= 0 on node 'ge_173'"); ge_213 = _assert_scalar_173 = None + select_174 = torch.ops.aten.select.int(device_put_21, 0, 6) + _local_scalar_dense_174 = torch.ops.aten._local_scalar_dense.default(select_174); select_174 = None + ge_214 = _local_scalar_dense_174 >= 0 + _assert_scalar_174 = torch.ops.aten._assert_scalar.default(ge_214, "Runtime assertion failed for expression u174 >= 0 on node 'ge_174'"); ge_214 = _assert_scalar_174 = None + select_175 = torch.ops.aten.select.int(device_put_21, 0, 7); device_put_21 = None + _local_scalar_dense_175 = torch.ops.aten._local_scalar_dense.default(select_175); select_175 = None + ge_215 = _local_scalar_dense_175 >= 0 + _assert_scalar_175 = torch.ops.aten._assert_scalar.default(ge_215, "Runtime assertion failed for expression u175 >= 0 on node 'ge_175'"); ge_215 = _assert_scalar_175 = None + all_to_all_single_31 = torch.ops._c10d_functional.all_to_all_single.default(index_20, [_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175], [_local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167], '521'); index_20 = None + sym_size_int_40 = torch.ops.aten.sym_size.int(all_to_all_single_31, 0) + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_31); all_to_all_single_31 = None + sym_sum_20 = torch.sym_sum((_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175)) + add_698 = sym_sum_20 + 64; sym_sum_20 = None + add_699 = add_698 + 8; add_698 = None + sub_243 = add_699 - 1; add_699 = None + floordiv_10 = sub_243 // 8; sub_243 = None + mul_512 = floordiv_10 * 8; floordiv_10 = None + cumsum_30 = torch.ops.aten.cumsum.default(wait_tensor_230, 0) + sub_244 = torch.ops.aten.sub.Tensor(cumsum_30, wait_tensor_230); cumsum_30 = None + sum_44 = torch.ops.aten.sum.dim_IntList(view_735, [0]); view_735 = None + clamp_min_10 = torch.ops.aten.clamp_min.default(sum_44, 8); sum_44 = None + add_700 = torch.ops.aten.add.Tensor(clamp_min_10, 8); clamp_min_10 = None + sub_245 = torch.ops.aten.sub.Tensor(add_700, 1); add_700 = None + div_53 = torch.ops.aten.div.Tensor_mode(sub_245, 8, rounding_mode = 'floor'); sub_245 = None + mul_513 = torch.ops.aten.mul.Tensor(div_53, 8); div_53 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(mul_513, torch.int32); mul_513 = None + cumsum_31 = torch.ops.aten.cumsum.default(convert_element_type_608, 0) + sub_246 = torch.ops.aten.sub.Tensor(cumsum_31, convert_element_type_608); cumsum_31 = None + full_150 = torch.ops.aten.full.default([mul_512], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_512 = None + triton_kernel_wrapper_functional_proxy_10 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 10, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_230, 'start_index_values_ptr': sub_244, 'write_offsets_ptr': sub_246, 'output_ptr': full_150}, tensors_to_clone = ['output_ptr']); wait_tensor_230 = sub_244 = sub_246 = full_150 = None + getitem_162 = triton_kernel_wrapper_functional_proxy_10['output_ptr']; triton_kernel_wrapper_functional_proxy_10 = None + cat_34 = torch.ops.aten.cat.default([wait_tensor_231, full_default]); wait_tensor_231 = None + sym_size_int_41 = torch.ops.aten.sym_size.int(cat_34, 0) + sym_sum_21 = torch.sym_sum((1, _local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175)) + index_21 = torch.ops.aten.index.Tensor(cat_34, [getitem_162]); cat_34 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_610, 8, '513'); convert_element_type_610 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + convert_element_type_612 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_612, 8, '513'); convert_element_type_612 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + convert_element_type_613 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_613, 8, '513'); convert_element_type_613 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + cumsum_32 = torch.ops.aten.cumsum.default(convert_element_type_608, 0, dtype = torch.int32); convert_element_type_608 = None + permute_170 = torch.ops.aten.permute.default(wait_tensor_232, [0, 2, 1]); wait_tensor_232 = None + _grouped_mm_30 = torch.ops.aten._grouped_mm.default(index_21, permute_170, cumsum_32); permute_170 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(_grouped_mm_30, torch.float32) + neg_21 = torch.ops.aten.neg.default(convert_element_type_616) + exp_32 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_712 = torch.ops.aten.add.Tensor(exp_32, 1); exp_32 = None + div_54 = torch.ops.aten.div.Tensor(convert_element_type_616, add_712); convert_element_type_616 = add_712 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(div_54, torch.bfloat16); div_54 = None + permute_171 = torch.ops.aten.permute.default(wait_tensor_235, [0, 2, 1]); wait_tensor_235 = None + _grouped_mm_31 = torch.ops.aten._grouped_mm.default(index_21, permute_171, cumsum_32); permute_171 = None + mul_525 = torch.ops.aten.mul.Tensor(convert_element_type_617, _grouped_mm_31); convert_element_type_617 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_234, [0, 2, 1]); wait_tensor_234 = None + _grouped_mm_32 = torch.ops.aten._grouped_mm.default(mul_525, permute_172, cumsum_32); permute_172 = None + empty_10 = torch.ops.aten.empty.memory_format([sym_size_int_41, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_20 = torch.ops.aten.index_put.default(empty_10, [getitem_162], _grouped_mm_32); empty_10 = _grouped_mm_32 = None + slice_46 = torch.ops.aten.slice.Tensor(index_put_20, 0, 0, -1); index_put_20 = None + all_to_all_single_32 = torch.ops._c10d_functional.all_to_all_single.default(slice_46, [_local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167], [_local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175], '521'); slice_46 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_32); all_to_all_single_32 = None + convert_element_type_618 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_618, 64, '0'); convert_element_type_618 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + mm_92 = torch.ops.aten.mm.default(view_728, permute_173); permute_173 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mm_92, torch.float32) + neg_22 = torch.ops.aten.neg.default(convert_element_type_621) + exp_33 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_748 = torch.ops.aten.add.Tensor(exp_33, 1); exp_33 = None + div_55 = torch.ops.aten.div.Tensor(convert_element_type_621, add_748); convert_element_type_621 = add_748 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(div_55, torch.bfloat16); div_55 = None + convert_element_type_623 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_623, 64, '0'); convert_element_type_623 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + mm_93 = torch.ops.aten.mm.default(view_728, permute_174); permute_174 = None + mul_545 = torch.ops.aten.mul.Tensor(convert_element_type_622, mm_93); convert_element_type_622 = None + convert_element_type_626 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_626, 64, '0'); convert_element_type_626 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + mm_94 = torch.ops.aten.mm.default(mul_545, permute_175); permute_175 = None + index_put_21 = torch.ops.aten.index_put.default(full_default_1, [getitem_161], wait_tensor_238); wait_tensor_238 = None + view_768 = torch.ops.aten.view.default(mul_507, [-1, 1, 6]); mul_507 = None + view_769 = torch.ops.aten.view.default(index_put_21, [-1, 6, 2048]); index_put_21 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(view_769, torch.float32); view_769 = None + bmm_10 = torch.ops.aten.bmm.default(view_768, convert_element_type_629) + convert_element_type_630 = torch.ops.prims.convert_element_type.default(bmm_10, torch.bfloat16); bmm_10 = None + squeeze_10 = torch.ops.aten.squeeze.dim(convert_element_type_630, 1); convert_element_type_630 = None + add_752 = torch.ops.aten.add.Tensor(mm_94, squeeze_10); mm_94 = squeeze_10 = None + view_770 = torch.ops.aten.view.default(add_752, [2, 4096, 2048]); add_752 = None + add_753 = torch.ops.aten.add.Tensor(add_688, view_770); view_770 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 64, '0'); convert_element_type_631 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + convert_element_type_632 = torch.ops.prims.convert_element_type.default(add_753, torch.float32) + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_632, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_754 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_754); add_754 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_632, rsqrt_36); convert_element_type_632 = None + mul_549 = torch.ops.aten.mul.Tensor(mul_548, wait_tensor_242); mul_548 = wait_tensor_242 = None + convert_element_type_633 = torch.ops.prims.convert_element_type.default(mul_549, torch.bfloat16); mul_549 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 64, '0'); convert_element_type_634 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + view_773 = torch.ops.aten.view.default(convert_element_type_633, [8192, 2048]); convert_element_type_633 = None + mm_95 = torch.ops.aten.mm.default(view_773, permute_176); permute_176 = None + view_774 = torch.ops.aten.view.default(mm_95, [2, 4096, 3072]); mm_95 = None + view_775 = torch.ops.aten.view.default(view_774, [2, 4096, -1, 192]); view_774 = None + split_with_sizes_36 = torch.ops.aten.split_with_sizes.default(view_775, [128, 64], -1); view_775 = None + getitem_163 = split_with_sizes_36[0] + getitem_164 = split_with_sizes_36[1]; split_with_sizes_36 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(getitem_164, torch.float32); getitem_164 = None + view_776 = torch.ops.aten.view.default(convert_element_type_637, [2, 4096, 16, -1, 2]); convert_element_type_637 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_776); view_776 = None + mul_550 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_7); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_550); mul_550 = None + view_778 = torch.ops.aten.view.default(view_as_real_24, [2, 4096, 16, 64]); view_as_real_24 = None + convert_element_type_638 = torch.ops.prims.convert_element_type.default(view_778, torch.bfloat16); view_778 = None + cat_35 = torch.ops.aten.cat.default([getitem_163, convert_element_type_638], -1); getitem_163 = convert_element_type_638 = None + convert_element_type_639 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_639, 64, '0'); convert_element_type_639 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_244, [1, 0]); wait_tensor_244 = None + mm_96 = torch.ops.aten.mm.default(view_773, permute_177); permute_177 = None + view_781 = torch.ops.aten.view.default(mm_96, [2, 4096, 576]); mm_96 = None + split_with_sizes_37 = torch.ops.aten.split_with_sizes.default(view_781, [512, 64], -1); view_781 = None + getitem_165 = split_with_sizes_37[0] + getitem_166 = split_with_sizes_37[1]; split_with_sizes_37 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(getitem_166, 2); getitem_166 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(unsqueeze_23, torch.float32); unsqueeze_23 = None + view_782 = torch.ops.aten.view.default(convert_element_type_642, [2, 4096, 1, -1, 2]); convert_element_type_642 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_782); view_782 = None + mul_551 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_7); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_551); mul_551 = None + view_784 = torch.ops.aten.view.default(view_as_real_25, [2, 4096, 1, 64]); view_as_real_25 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_784, torch.bfloat16); view_784 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 64, '0'); convert_element_type_644 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + convert_element_type_645 = torch.ops.prims.convert_element_type.default(getitem_165, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_645, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_755 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_755); add_755 = None + mul_552 = torch.ops.aten.mul.Tensor(convert_element_type_645, rsqrt_37); convert_element_type_645 = None + mul_553 = torch.ops.aten.mul.Tensor(mul_552, wait_tensor_245); mul_552 = wait_tensor_245 = None + convert_element_type_646 = torch.ops.prims.convert_element_type.default(mul_553, torch.bfloat16); mul_553 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 64, '0'); convert_element_type_647 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + view_787 = torch.ops.aten.view.default(convert_element_type_646, [8192, 512]); convert_element_type_646 = None + mm_97 = torch.ops.aten.mm.default(view_787, permute_178); permute_178 = None + view_788 = torch.ops.aten.view.default(mm_97, [2, 4096, 4096]); mm_97 = None + view_789 = torch.ops.aten.view.default(view_788, [2, 4096, -1, 256]); view_788 = None + split_with_sizes_38 = torch.ops.aten.split_with_sizes.default(view_789, [128, 128], -1); view_789 = None + getitem_167 = split_with_sizes_38[0] + getitem_168 = split_with_sizes_38[1]; split_with_sizes_38 = None + expand_12 = torch.ops.aten.expand.default(convert_element_type_643, [-1, -1, 16, -1]); convert_element_type_643 = None + cat_36 = torch.ops.aten.cat.default([getitem_167, expand_12], -1); getitem_167 = expand_12 = None + permute_179 = torch.ops.aten.permute.default(cat_35, [0, 2, 1, 3]); cat_35 = None + permute_180 = torch.ops.aten.permute.default(cat_36, [0, 2, 1, 3]); cat_36 = None + permute_181 = torch.ops.aten.permute.default(getitem_168, [0, 2, 1, 3]); getitem_168 = None + sdpa_score12 = self.sdpa_score12 + sdpa_mask12 = self.sdpa_mask12 + flex_attention_12 = torch.ops.higher_order.flex_attention(permute_179, permute_180, permute_181, sdpa_score12, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask12), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score12 = sdpa_mask12 = None + getitem_169 = flex_attention_12[0] + getitem_170 = flex_attention_12[1]; flex_attention_12 = None + permute_182 = torch.ops.aten.permute.default(getitem_169, [0, 2, 1, 3]) + view_790 = torch.ops.aten.view.default(permute_182, [2, 4096, -1]); permute_182 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 64, '0'); convert_element_type_650 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + view_792 = torch.ops.aten.view.default(view_790, [8192, 2048]); view_790 = None + mm_98 = torch.ops.aten.mm.default(view_792, permute_183); view_792 = permute_183 = None + view_793 = torch.ops.aten.view.default(mm_98, [2, 4096, 2048]); mm_98 = None + add_756 = torch.ops.aten.add.Tensor(add_753, view_793); view_793 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_653, 64, '0'); convert_element_type_653 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(add_756, torch.float32) + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_654, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_757 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_757); add_757 = None + mul_554 = torch.ops.aten.mul.Tensor(convert_element_type_654, rsqrt_38); convert_element_type_654 = None + mul_555 = torch.ops.aten.mul.Tensor(mul_554, wait_tensor_248); mul_554 = wait_tensor_248 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(mul_555, torch.bfloat16); mul_555 = None + view_795 = torch.ops.aten.view.default(convert_element_type_655, [-1, 2048]); convert_element_type_655 = None + convert_element_type_656 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_656, 64, '0'); convert_element_type_656 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_249, [1, 0]); wait_tensor_249 = None + mm_99 = torch.ops.aten.mm.default(view_795, permute_184); permute_184 = None + convert_element_type_659 = torch.ops.prims.convert_element_type.default(mm_99, torch.float32) + amax_11 = torch.ops.aten.amax.default(convert_element_type_659, [1], True) + sub_264 = torch.ops.aten.sub.Tensor(convert_element_type_659, amax_11); convert_element_type_659 = None + exp_34 = torch.ops.aten.exp.default(sub_264); sub_264 = None + sum_45 = torch.ops.aten.sum.dim_IntList(exp_34, [1], True) + div_56 = torch.ops.aten.div.Tensor(exp_34, sum_45); exp_34 = None + add_758 = torch.ops.aten.add.Tensor(div_56, primals_206); primals_206 = None + topk_11 = torch.ops.aten.topk.default(add_758, 6, -1, True, False); add_758 = None + getitem_173 = topk_11[1]; topk_11 = None + gather_11 = torch.ops.aten.gather.default(div_56, 1, getitem_173); div_56 = None + mul_556 = torch.ops.aten.mul.Tensor(gather_11, 1.0); gather_11 = None + view_797 = torch.ops.aten.view.default(getitem_173, [-1]) + histc_22 = torch.ops.aten.histc.default(view_797, 64, 0, 64) + add_759 = torch.ops.aten.add.Tensor(primals_208, histc_22) + sort_11 = torch.ops.aten.sort.stable(view_797, stable = True); view_797 = None + getitem_175 = sort_11[1]; sort_11 = None + div_57 = torch.ops.aten.div.Tensor_mode(getitem_175, 6, rounding_mode = 'floor') + index_22 = torch.ops.aten.index.Tensor(view_795, [div_57]) + all_to_all_single_33 = torch.ops._c10d_functional.all_to_all_single.default(histc_22, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_33); all_to_all_single_33 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_250); wait_tensor_250 = None + view_801 = torch.ops.aten.view.default(histc_22, [8, -1]); histc_22 = None + sum_46 = torch.ops.aten.sum.dim_IntList(view_801, [1]); view_801 = None + device_put_22 = torch.ops.prims.device_put.default(sum_46, device(type='cpu'), True); sum_46 = None + view_802 = torch.ops.aten.view.default(wait_tensor_251, [8, -1]) + sum_47 = torch.ops.aten.sum.dim_IntList(view_802, [1]) + device_put_23 = torch.ops.prims.device_put.default(sum_47, device(type='cpu')); sum_47 = None + select_176 = torch.ops.aten.select.int(device_put_22, 0, 0) + _local_scalar_dense_176 = torch.ops.aten._local_scalar_dense.default(select_176); select_176 = None + ge_220 = _local_scalar_dense_176 >= 0 + _assert_scalar_176 = torch.ops.aten._assert_scalar.default(ge_220, "Runtime assertion failed for expression u176 >= 0 on node 'ge_176'"); ge_220 = _assert_scalar_176 = None + select_177 = torch.ops.aten.select.int(device_put_22, 0, 1) + _local_scalar_dense_177 = torch.ops.aten._local_scalar_dense.default(select_177); select_177 = None + ge_221 = _local_scalar_dense_177 >= 0 + _assert_scalar_177 = torch.ops.aten._assert_scalar.default(ge_221, "Runtime assertion failed for expression u177 >= 0 on node 'ge_177'"); ge_221 = _assert_scalar_177 = None + select_178 = torch.ops.aten.select.int(device_put_22, 0, 2) + _local_scalar_dense_178 = torch.ops.aten._local_scalar_dense.default(select_178); select_178 = None + ge_222 = _local_scalar_dense_178 >= 0 + _assert_scalar_178 = torch.ops.aten._assert_scalar.default(ge_222, "Runtime assertion failed for expression u178 >= 0 on node 'ge_178'"); ge_222 = _assert_scalar_178 = None + select_179 = torch.ops.aten.select.int(device_put_22, 0, 3) + _local_scalar_dense_179 = torch.ops.aten._local_scalar_dense.default(select_179); select_179 = None + ge_223 = _local_scalar_dense_179 >= 0 + _assert_scalar_179 = torch.ops.aten._assert_scalar.default(ge_223, "Runtime assertion failed for expression u179 >= 0 on node 'ge_179'"); ge_223 = _assert_scalar_179 = None + select_180 = torch.ops.aten.select.int(device_put_22, 0, 4) + _local_scalar_dense_180 = torch.ops.aten._local_scalar_dense.default(select_180); select_180 = None + ge_224 = _local_scalar_dense_180 >= 0 + _assert_scalar_180 = torch.ops.aten._assert_scalar.default(ge_224, "Runtime assertion failed for expression u180 >= 0 on node 'ge_180'"); ge_224 = _assert_scalar_180 = None + select_181 = torch.ops.aten.select.int(device_put_22, 0, 5) + _local_scalar_dense_181 = torch.ops.aten._local_scalar_dense.default(select_181); select_181 = None + ge_225 = _local_scalar_dense_181 >= 0 + _assert_scalar_181 = torch.ops.aten._assert_scalar.default(ge_225, "Runtime assertion failed for expression u181 >= 0 on node 'ge_181'"); ge_225 = _assert_scalar_181 = None + select_182 = torch.ops.aten.select.int(device_put_22, 0, 6) + _local_scalar_dense_182 = torch.ops.aten._local_scalar_dense.default(select_182); select_182 = None + ge_226 = _local_scalar_dense_182 >= 0 + _assert_scalar_182 = torch.ops.aten._assert_scalar.default(ge_226, "Runtime assertion failed for expression u182 >= 0 on node 'ge_182'"); ge_226 = _assert_scalar_182 = None + select_183 = torch.ops.aten.select.int(device_put_22, 0, 7); device_put_22 = None + _local_scalar_dense_183 = torch.ops.aten._local_scalar_dense.default(select_183); select_183 = None + ge_227 = _local_scalar_dense_183 >= 0 + _assert_scalar_183 = torch.ops.aten._assert_scalar.default(ge_227, "Runtime assertion failed for expression u183 >= 0 on node 'ge_183'"); ge_227 = _assert_scalar_183 = None + select_184 = torch.ops.aten.select.int(device_put_23, 0, 0) + _local_scalar_dense_184 = torch.ops.aten._local_scalar_dense.default(select_184); select_184 = None + ge_228 = _local_scalar_dense_184 >= 0 + _assert_scalar_184 = torch.ops.aten._assert_scalar.default(ge_228, "Runtime assertion failed for expression u184 >= 0 on node 'ge_184'"); ge_228 = _assert_scalar_184 = None + select_185 = torch.ops.aten.select.int(device_put_23, 0, 1) + _local_scalar_dense_185 = torch.ops.aten._local_scalar_dense.default(select_185); select_185 = None + ge_229 = _local_scalar_dense_185 >= 0 + _assert_scalar_185 = torch.ops.aten._assert_scalar.default(ge_229, "Runtime assertion failed for expression u185 >= 0 on node 'ge_185'"); ge_229 = _assert_scalar_185 = None + select_186 = torch.ops.aten.select.int(device_put_23, 0, 2) + _local_scalar_dense_186 = torch.ops.aten._local_scalar_dense.default(select_186); select_186 = None + ge_230 = _local_scalar_dense_186 >= 0 + _assert_scalar_186 = torch.ops.aten._assert_scalar.default(ge_230, "Runtime assertion failed for expression u186 >= 0 on node 'ge_186'"); ge_230 = _assert_scalar_186 = None + select_187 = torch.ops.aten.select.int(device_put_23, 0, 3) + _local_scalar_dense_187 = torch.ops.aten._local_scalar_dense.default(select_187); select_187 = None + ge_231 = _local_scalar_dense_187 >= 0 + _assert_scalar_187 = torch.ops.aten._assert_scalar.default(ge_231, "Runtime assertion failed for expression u187 >= 0 on node 'ge_187'"); ge_231 = _assert_scalar_187 = None + select_188 = torch.ops.aten.select.int(device_put_23, 0, 4) + _local_scalar_dense_188 = torch.ops.aten._local_scalar_dense.default(select_188); select_188 = None + ge_232 = _local_scalar_dense_188 >= 0 + _assert_scalar_188 = torch.ops.aten._assert_scalar.default(ge_232, "Runtime assertion failed for expression u188 >= 0 on node 'ge_188'"); ge_232 = _assert_scalar_188 = None + select_189 = torch.ops.aten.select.int(device_put_23, 0, 5) + _local_scalar_dense_189 = torch.ops.aten._local_scalar_dense.default(select_189); select_189 = None + ge_233 = _local_scalar_dense_189 >= 0 + _assert_scalar_189 = torch.ops.aten._assert_scalar.default(ge_233, "Runtime assertion failed for expression u189 >= 0 on node 'ge_189'"); ge_233 = _assert_scalar_189 = None + select_190 = torch.ops.aten.select.int(device_put_23, 0, 6) + _local_scalar_dense_190 = torch.ops.aten._local_scalar_dense.default(select_190); select_190 = None + ge_234 = _local_scalar_dense_190 >= 0 + _assert_scalar_190 = torch.ops.aten._assert_scalar.default(ge_234, "Runtime assertion failed for expression u190 >= 0 on node 'ge_190'"); ge_234 = _assert_scalar_190 = None + select_191 = torch.ops.aten.select.int(device_put_23, 0, 7); device_put_23 = None + _local_scalar_dense_191 = torch.ops.aten._local_scalar_dense.default(select_191); select_191 = None + ge_235 = _local_scalar_dense_191 >= 0 + _assert_scalar_191 = torch.ops.aten._assert_scalar.default(ge_235, "Runtime assertion failed for expression u191 >= 0 on node 'ge_191'"); ge_235 = _assert_scalar_191 = None + all_to_all_single_34 = torch.ops._c10d_functional.all_to_all_single.default(index_22, [_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191], [_local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183], '521'); index_22 = None + sym_size_int_44 = torch.ops.aten.sym_size.int(all_to_all_single_34, 0) + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_34); all_to_all_single_34 = None + sym_sum_22 = torch.sym_sum((_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191)) + add_766 = sym_sum_22 + 64; sym_sum_22 = None + add_767 = add_766 + 8; add_766 = None + sub_267 = add_767 - 1; add_767 = None + floordiv_11 = sub_267 // 8; sub_267 = None + mul_561 = floordiv_11 * 8; floordiv_11 = None + cumsum_33 = torch.ops.aten.cumsum.default(wait_tensor_251, 0) + sub_268 = torch.ops.aten.sub.Tensor(cumsum_33, wait_tensor_251); cumsum_33 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_802, [0]); view_802 = None + clamp_min_11 = torch.ops.aten.clamp_min.default(sum_48, 8); sum_48 = None + add_768 = torch.ops.aten.add.Tensor(clamp_min_11, 8); clamp_min_11 = None + sub_269 = torch.ops.aten.sub.Tensor(add_768, 1); add_768 = None + div_58 = torch.ops.aten.div.Tensor_mode(sub_269, 8, rounding_mode = 'floor'); sub_269 = None + mul_562 = torch.ops.aten.mul.Tensor(div_58, 8); div_58 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(mul_562, torch.int32); mul_562 = None + cumsum_34 = torch.ops.aten.cumsum.default(convert_element_type_662, 0) + sub_270 = torch.ops.aten.sub.Tensor(cumsum_34, convert_element_type_662); cumsum_34 = None + full_163 = torch.ops.aten.full.default([mul_561], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_561 = None + triton_kernel_wrapper_functional_proxy_11 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 11, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_251, 'start_index_values_ptr': sub_268, 'write_offsets_ptr': sub_270, 'output_ptr': full_163}, tensors_to_clone = ['output_ptr']); wait_tensor_251 = sub_268 = sub_270 = full_163 = None + getitem_176 = triton_kernel_wrapper_functional_proxy_11['output_ptr']; triton_kernel_wrapper_functional_proxy_11 = None + cat_37 = torch.ops.aten.cat.default([wait_tensor_252, full_default]); wait_tensor_252 = None + sym_size_int_45 = torch.ops.aten.sym_size.int(cat_37, 0) + sym_sum_23 = torch.sym_sum((1, _local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191)) + index_23 = torch.ops.aten.index.Tensor(cat_37, [getitem_176]); cat_37 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 8, '513'); convert_element_type_664 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + convert_element_type_666 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_666, 8, '513'); convert_element_type_666 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 8, '513'); convert_element_type_667 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + cumsum_35 = torch.ops.aten.cumsum.default(convert_element_type_662, 0, dtype = torch.int32); convert_element_type_662 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_253, [0, 2, 1]); wait_tensor_253 = None + _grouped_mm_33 = torch.ops.aten._grouped_mm.default(index_23, permute_185, cumsum_35); permute_185 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(_grouped_mm_33, torch.float32) + neg_23 = torch.ops.aten.neg.default(convert_element_type_670) + exp_35 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_780 = torch.ops.aten.add.Tensor(exp_35, 1); exp_35 = None + div_59 = torch.ops.aten.div.Tensor(convert_element_type_670, add_780); convert_element_type_670 = add_780 = None + convert_element_type_671 = torch.ops.prims.convert_element_type.default(div_59, torch.bfloat16); div_59 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_256, [0, 2, 1]); wait_tensor_256 = None + _grouped_mm_34 = torch.ops.aten._grouped_mm.default(index_23, permute_186, cumsum_35); permute_186 = None + mul_574 = torch.ops.aten.mul.Tensor(convert_element_type_671, _grouped_mm_34); convert_element_type_671 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_255, [0, 2, 1]); wait_tensor_255 = None + _grouped_mm_35 = torch.ops.aten._grouped_mm.default(mul_574, permute_187, cumsum_35); permute_187 = None + empty_11 = torch.ops.aten.empty.memory_format([sym_size_int_45, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_22 = torch.ops.aten.index_put.default(empty_11, [getitem_176], _grouped_mm_35); empty_11 = _grouped_mm_35 = None + slice_50 = torch.ops.aten.slice.Tensor(index_put_22, 0, 0, -1); index_put_22 = None + all_to_all_single_35 = torch.ops._c10d_functional.all_to_all_single.default(slice_50, [_local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183], [_local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191], '521'); slice_50 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_35); all_to_all_single_35 = None + convert_element_type_672 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_672, 64, '0'); convert_element_type_672 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_100 = torch.ops.aten.mm.default(view_795, permute_188); permute_188 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(mm_100, torch.float32) + neg_24 = torch.ops.aten.neg.default(convert_element_type_675) + exp_36 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_816 = torch.ops.aten.add.Tensor(exp_36, 1); exp_36 = None + div_60 = torch.ops.aten.div.Tensor(convert_element_type_675, add_816); convert_element_type_675 = add_816 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(div_60, torch.bfloat16); div_60 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 64, '0'); convert_element_type_677 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + mm_101 = torch.ops.aten.mm.default(view_795, permute_189); permute_189 = None + mul_594 = torch.ops.aten.mul.Tensor(convert_element_type_676, mm_101); convert_element_type_676 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 64, '0'); convert_element_type_680 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_190 = torch.ops.aten.permute.default(wait_tensor_262, [1, 0]); wait_tensor_262 = None + mm_102 = torch.ops.aten.mm.default(mul_594, permute_190); permute_190 = None + index_put_23 = torch.ops.aten.index_put.default(full_default_1, [getitem_175], wait_tensor_259); wait_tensor_259 = None + view_835 = torch.ops.aten.view.default(mul_556, [-1, 1, 6]); mul_556 = None + view_836 = torch.ops.aten.view.default(index_put_23, [-1, 6, 2048]); index_put_23 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(view_836, torch.float32); view_836 = None + bmm_11 = torch.ops.aten.bmm.default(view_835, convert_element_type_683) + convert_element_type_684 = torch.ops.prims.convert_element_type.default(bmm_11, torch.bfloat16); bmm_11 = None + squeeze_11 = torch.ops.aten.squeeze.dim(convert_element_type_684, 1); convert_element_type_684 = None + add_820 = torch.ops.aten.add.Tensor(mm_102, squeeze_11); mm_102 = squeeze_11 = None + view_837 = torch.ops.aten.view.default(add_820, [2, 4096, 2048]); add_820 = None + add_821 = torch.ops.aten.add.Tensor(add_756, view_837); view_837 = None + convert_element_type_685 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_685, 64, '0'); convert_element_type_685 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(add_821, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_686, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_822 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_822); add_822 = None + mul_597 = torch.ops.aten.mul.Tensor(convert_element_type_686, rsqrt_39); convert_element_type_686 = None + mul_598 = torch.ops.aten.mul.Tensor(mul_597, wait_tensor_263); mul_597 = wait_tensor_263 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_598, torch.bfloat16); mul_598 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 64, '0'); convert_element_type_688 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_191 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + view_840 = torch.ops.aten.view.default(convert_element_type_687, [8192, 2048]); convert_element_type_687 = None + mm_103 = torch.ops.aten.mm.default(view_840, permute_191); permute_191 = None + view_841 = torch.ops.aten.view.default(mm_103, [2, 4096, 3072]); mm_103 = None + view_842 = torch.ops.aten.view.default(view_841, [2, 4096, -1, 192]); view_841 = None + split_with_sizes_39 = torch.ops.aten.split_with_sizes.default(view_842, [128, 64], -1); view_842 = None + getitem_177 = split_with_sizes_39[0] + getitem_178 = split_with_sizes_39[1]; split_with_sizes_39 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(getitem_178, torch.float32); getitem_178 = None + view_843 = torch.ops.aten.view.default(convert_element_type_691, [2, 4096, 16, -1, 2]); convert_element_type_691 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_843); view_843 = None + mul_599 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_7); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_599); mul_599 = None + view_845 = torch.ops.aten.view.default(view_as_real_26, [2, 4096, 16, 64]); view_as_real_26 = None + convert_element_type_692 = torch.ops.prims.convert_element_type.default(view_845, torch.bfloat16); view_845 = None + cat_38 = torch.ops.aten.cat.default([getitem_177, convert_element_type_692], -1); getitem_177 = convert_element_type_692 = None + convert_element_type_693 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_693, 64, '0'); convert_element_type_693 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + permute_192 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_104 = torch.ops.aten.mm.default(view_840, permute_192); permute_192 = None + view_848 = torch.ops.aten.view.default(mm_104, [2, 4096, 576]); mm_104 = None + split_with_sizes_40 = torch.ops.aten.split_with_sizes.default(view_848, [512, 64], -1); view_848 = None + getitem_179 = split_with_sizes_40[0] + getitem_180 = split_with_sizes_40[1]; split_with_sizes_40 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(getitem_180, 2); getitem_180 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(unsqueeze_25, torch.float32); unsqueeze_25 = None + view_849 = torch.ops.aten.view.default(convert_element_type_696, [2, 4096, 1, -1, 2]); convert_element_type_696 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_849); view_849 = None + mul_600 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_7); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_600); mul_600 = None + view_851 = torch.ops.aten.view.default(view_as_real_27, [2, 4096, 1, 64]); view_as_real_27 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(view_851, torch.bfloat16); view_851 = None + convert_element_type_698 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_698, 64, '0'); convert_element_type_698 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + convert_element_type_699 = torch.ops.prims.convert_element_type.default(getitem_179, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_699, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_823 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_823); add_823 = None + mul_601 = torch.ops.aten.mul.Tensor(convert_element_type_699, rsqrt_40); convert_element_type_699 = None + mul_602 = torch.ops.aten.mul.Tensor(mul_601, wait_tensor_266); mul_601 = wait_tensor_266 = None + convert_element_type_700 = torch.ops.prims.convert_element_type.default(mul_602, torch.bfloat16); mul_602 = None + convert_element_type_701 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_701, 64, '0'); convert_element_type_701 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_193 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + view_854 = torch.ops.aten.view.default(convert_element_type_700, [8192, 512]); convert_element_type_700 = None + mm_105 = torch.ops.aten.mm.default(view_854, permute_193); permute_193 = None + view_855 = torch.ops.aten.view.default(mm_105, [2, 4096, 4096]); mm_105 = None + view_856 = torch.ops.aten.view.default(view_855, [2, 4096, -1, 256]); view_855 = None + split_with_sizes_41 = torch.ops.aten.split_with_sizes.default(view_856, [128, 128], -1); view_856 = None + getitem_181 = split_with_sizes_41[0] + getitem_182 = split_with_sizes_41[1]; split_with_sizes_41 = None + expand_13 = torch.ops.aten.expand.default(convert_element_type_697, [-1, -1, 16, -1]); convert_element_type_697 = None + cat_39 = torch.ops.aten.cat.default([getitem_181, expand_13], -1); getitem_181 = expand_13 = None + permute_194 = torch.ops.aten.permute.default(cat_38, [0, 2, 1, 3]); cat_38 = None + permute_195 = torch.ops.aten.permute.default(cat_39, [0, 2, 1, 3]); cat_39 = None + permute_196 = torch.ops.aten.permute.default(getitem_182, [0, 2, 1, 3]); getitem_182 = None + sdpa_score13 = self.sdpa_score13 + sdpa_mask13 = self.sdpa_mask13 + flex_attention_13 = torch.ops.higher_order.flex_attention(permute_194, permute_195, permute_196, sdpa_score13, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask13), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score13 = sdpa_mask13 = None + getitem_183 = flex_attention_13[0] + getitem_184 = flex_attention_13[1]; flex_attention_13 = None + permute_197 = torch.ops.aten.permute.default(getitem_183, [0, 2, 1, 3]) + view_857 = torch.ops.aten.view.default(permute_197, [2, 4096, -1]); permute_197 = None + convert_element_type_704 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_704, 64, '0'); convert_element_type_704 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + view_859 = torch.ops.aten.view.default(view_857, [8192, 2048]); view_857 = None + mm_106 = torch.ops.aten.mm.default(view_859, permute_198); view_859 = permute_198 = None + view_860 = torch.ops.aten.view.default(mm_106, [2, 4096, 2048]); mm_106 = None + add_824 = torch.ops.aten.add.Tensor(add_821, view_860); view_860 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_707, 64, '0'); convert_element_type_707 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(add_824, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_708, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_825 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_825); add_825 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_708, rsqrt_41); convert_element_type_708 = None + mul_604 = torch.ops.aten.mul.Tensor(mul_603, wait_tensor_269); mul_603 = wait_tensor_269 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(mul_604, torch.bfloat16); mul_604 = None + view_862 = torch.ops.aten.view.default(convert_element_type_709, [-1, 2048]); convert_element_type_709 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 64, '0'); convert_element_type_710 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + mm_107 = torch.ops.aten.mm.default(view_862, permute_199); permute_199 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(mm_107, torch.float32) + amax_12 = torch.ops.aten.amax.default(convert_element_type_713, [1], True) + sub_288 = torch.ops.aten.sub.Tensor(convert_element_type_713, amax_12); convert_element_type_713 = None + exp_37 = torch.ops.aten.exp.default(sub_288); sub_288 = None + sum_49 = torch.ops.aten.sum.dim_IntList(exp_37, [1], True) + div_61 = torch.ops.aten.div.Tensor(exp_37, sum_49); exp_37 = None + add_826 = torch.ops.aten.add.Tensor(div_61, primals_222); primals_222 = None + topk_12 = torch.ops.aten.topk.default(add_826, 6, -1, True, False); add_826 = None + getitem_187 = topk_12[1]; topk_12 = None + gather_12 = torch.ops.aten.gather.default(div_61, 1, getitem_187); div_61 = None + mul_605 = torch.ops.aten.mul.Tensor(gather_12, 1.0); gather_12 = None + view_864 = torch.ops.aten.view.default(getitem_187, [-1]) + histc_24 = torch.ops.aten.histc.default(view_864, 64, 0, 64) + add_827 = torch.ops.aten.add.Tensor(primals_224, histc_24) + sort_12 = torch.ops.aten.sort.stable(view_864, stable = True); view_864 = None + getitem_189 = sort_12[1]; sort_12 = None + div_62 = torch.ops.aten.div.Tensor_mode(getitem_189, 6, rounding_mode = 'floor') + index_24 = torch.ops.aten.index.Tensor(view_862, [div_62]) + all_to_all_single_36 = torch.ops._c10d_functional.all_to_all_single.default(histc_24, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_36); all_to_all_single_36 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_271); wait_tensor_271 = None + view_868 = torch.ops.aten.view.default(histc_24, [8, -1]); histc_24 = None + sum_50 = torch.ops.aten.sum.dim_IntList(view_868, [1]); view_868 = None + device_put_24 = torch.ops.prims.device_put.default(sum_50, device(type='cpu'), True); sum_50 = None + view_869 = torch.ops.aten.view.default(wait_tensor_272, [8, -1]) + sum_51 = torch.ops.aten.sum.dim_IntList(view_869, [1]) + device_put_25 = torch.ops.prims.device_put.default(sum_51, device(type='cpu')); sum_51 = None + select_192 = torch.ops.aten.select.int(device_put_24, 0, 0) + _local_scalar_dense_192 = torch.ops.aten._local_scalar_dense.default(select_192); select_192 = None + ge_240 = _local_scalar_dense_192 >= 0 + _assert_scalar_192 = torch.ops.aten._assert_scalar.default(ge_240, "Runtime assertion failed for expression u192 >= 0 on node 'ge_192'"); ge_240 = _assert_scalar_192 = None + select_193 = torch.ops.aten.select.int(device_put_24, 0, 1) + _local_scalar_dense_193 = torch.ops.aten._local_scalar_dense.default(select_193); select_193 = None + ge_241 = _local_scalar_dense_193 >= 0 + _assert_scalar_193 = torch.ops.aten._assert_scalar.default(ge_241, "Runtime assertion failed for expression u193 >= 0 on node 'ge_193'"); ge_241 = _assert_scalar_193 = None + select_194 = torch.ops.aten.select.int(device_put_24, 0, 2) + _local_scalar_dense_194 = torch.ops.aten._local_scalar_dense.default(select_194); select_194 = None + ge_242 = _local_scalar_dense_194 >= 0 + _assert_scalar_194 = torch.ops.aten._assert_scalar.default(ge_242, "Runtime assertion failed for expression u194 >= 0 on node 'ge_194'"); ge_242 = _assert_scalar_194 = None + select_195 = torch.ops.aten.select.int(device_put_24, 0, 3) + _local_scalar_dense_195 = torch.ops.aten._local_scalar_dense.default(select_195); select_195 = None + ge_243 = _local_scalar_dense_195 >= 0 + _assert_scalar_195 = torch.ops.aten._assert_scalar.default(ge_243, "Runtime assertion failed for expression u195 >= 0 on node 'ge_195'"); ge_243 = _assert_scalar_195 = None + select_196 = torch.ops.aten.select.int(device_put_24, 0, 4) + _local_scalar_dense_196 = torch.ops.aten._local_scalar_dense.default(select_196); select_196 = None + ge_244 = _local_scalar_dense_196 >= 0 + _assert_scalar_196 = torch.ops.aten._assert_scalar.default(ge_244, "Runtime assertion failed for expression u196 >= 0 on node 'ge_196'"); ge_244 = _assert_scalar_196 = None + select_197 = torch.ops.aten.select.int(device_put_24, 0, 5) + _local_scalar_dense_197 = torch.ops.aten._local_scalar_dense.default(select_197); select_197 = None + ge_245 = _local_scalar_dense_197 >= 0 + _assert_scalar_197 = torch.ops.aten._assert_scalar.default(ge_245, "Runtime assertion failed for expression u197 >= 0 on node 'ge_197'"); ge_245 = _assert_scalar_197 = None + select_198 = torch.ops.aten.select.int(device_put_24, 0, 6) + _local_scalar_dense_198 = torch.ops.aten._local_scalar_dense.default(select_198); select_198 = None + ge_246 = _local_scalar_dense_198 >= 0 + _assert_scalar_198 = torch.ops.aten._assert_scalar.default(ge_246, "Runtime assertion failed for expression u198 >= 0 on node 'ge_198'"); ge_246 = _assert_scalar_198 = None + select_199 = torch.ops.aten.select.int(device_put_24, 0, 7); device_put_24 = None + _local_scalar_dense_199 = torch.ops.aten._local_scalar_dense.default(select_199); select_199 = None + ge_247 = _local_scalar_dense_199 >= 0 + _assert_scalar_199 = torch.ops.aten._assert_scalar.default(ge_247, "Runtime assertion failed for expression u199 >= 0 on node 'ge_199'"); ge_247 = _assert_scalar_199 = None + select_200 = torch.ops.aten.select.int(device_put_25, 0, 0) + _local_scalar_dense_200 = torch.ops.aten._local_scalar_dense.default(select_200); select_200 = None + ge_248 = _local_scalar_dense_200 >= 0 + _assert_scalar_200 = torch.ops.aten._assert_scalar.default(ge_248, "Runtime assertion failed for expression u200 >= 0 on node 'ge_200'"); ge_248 = _assert_scalar_200 = None + select_201 = torch.ops.aten.select.int(device_put_25, 0, 1) + _local_scalar_dense_201 = torch.ops.aten._local_scalar_dense.default(select_201); select_201 = None + ge_249 = _local_scalar_dense_201 >= 0 + _assert_scalar_201 = torch.ops.aten._assert_scalar.default(ge_249, "Runtime assertion failed for expression u201 >= 0 on node 'ge_201'"); ge_249 = _assert_scalar_201 = None + select_202 = torch.ops.aten.select.int(device_put_25, 0, 2) + _local_scalar_dense_202 = torch.ops.aten._local_scalar_dense.default(select_202); select_202 = None + ge_250 = _local_scalar_dense_202 >= 0 + _assert_scalar_202 = torch.ops.aten._assert_scalar.default(ge_250, "Runtime assertion failed for expression u202 >= 0 on node 'ge_202'"); ge_250 = _assert_scalar_202 = None + select_203 = torch.ops.aten.select.int(device_put_25, 0, 3) + _local_scalar_dense_203 = torch.ops.aten._local_scalar_dense.default(select_203); select_203 = None + ge_251 = _local_scalar_dense_203 >= 0 + _assert_scalar_203 = torch.ops.aten._assert_scalar.default(ge_251, "Runtime assertion failed for expression u203 >= 0 on node 'ge_203'"); ge_251 = _assert_scalar_203 = None + select_204 = torch.ops.aten.select.int(device_put_25, 0, 4) + _local_scalar_dense_204 = torch.ops.aten._local_scalar_dense.default(select_204); select_204 = None + ge_252 = _local_scalar_dense_204 >= 0 + _assert_scalar_204 = torch.ops.aten._assert_scalar.default(ge_252, "Runtime assertion failed for expression u204 >= 0 on node 'ge_204'"); ge_252 = _assert_scalar_204 = None + select_205 = torch.ops.aten.select.int(device_put_25, 0, 5) + _local_scalar_dense_205 = torch.ops.aten._local_scalar_dense.default(select_205); select_205 = None + ge_253 = _local_scalar_dense_205 >= 0 + _assert_scalar_205 = torch.ops.aten._assert_scalar.default(ge_253, "Runtime assertion failed for expression u205 >= 0 on node 'ge_205'"); ge_253 = _assert_scalar_205 = None + select_206 = torch.ops.aten.select.int(device_put_25, 0, 6) + _local_scalar_dense_206 = torch.ops.aten._local_scalar_dense.default(select_206); select_206 = None + ge_254 = _local_scalar_dense_206 >= 0 + _assert_scalar_206 = torch.ops.aten._assert_scalar.default(ge_254, "Runtime assertion failed for expression u206 >= 0 on node 'ge_206'"); ge_254 = _assert_scalar_206 = None + select_207 = torch.ops.aten.select.int(device_put_25, 0, 7); device_put_25 = None + _local_scalar_dense_207 = torch.ops.aten._local_scalar_dense.default(select_207); select_207 = None + ge_255 = _local_scalar_dense_207 >= 0 + _assert_scalar_207 = torch.ops.aten._assert_scalar.default(ge_255, "Runtime assertion failed for expression u207 >= 0 on node 'ge_207'"); ge_255 = _assert_scalar_207 = None + all_to_all_single_37 = torch.ops._c10d_functional.all_to_all_single.default(index_24, [_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207], [_local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199], '521'); index_24 = None + sym_size_int_48 = torch.ops.aten.sym_size.int(all_to_all_single_37, 0) + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_37); all_to_all_single_37 = None + sym_sum_24 = torch.sym_sum((_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207)) + add_834 = sym_sum_24 + 64; sym_sum_24 = None + add_835 = add_834 + 8; add_834 = None + sub_291 = add_835 - 1; add_835 = None + floordiv_12 = sub_291 // 8; sub_291 = None + mul_610 = floordiv_12 * 8; floordiv_12 = None + cumsum_36 = torch.ops.aten.cumsum.default(wait_tensor_272, 0) + sub_292 = torch.ops.aten.sub.Tensor(cumsum_36, wait_tensor_272); cumsum_36 = None + sum_52 = torch.ops.aten.sum.dim_IntList(view_869, [0]); view_869 = None + clamp_min_12 = torch.ops.aten.clamp_min.default(sum_52, 8); sum_52 = None + add_836 = torch.ops.aten.add.Tensor(clamp_min_12, 8); clamp_min_12 = None + sub_293 = torch.ops.aten.sub.Tensor(add_836, 1); add_836 = None + div_63 = torch.ops.aten.div.Tensor_mode(sub_293, 8, rounding_mode = 'floor'); sub_293 = None + mul_611 = torch.ops.aten.mul.Tensor(div_63, 8); div_63 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(mul_611, torch.int32); mul_611 = None + cumsum_37 = torch.ops.aten.cumsum.default(convert_element_type_716, 0) + sub_294 = torch.ops.aten.sub.Tensor(cumsum_37, convert_element_type_716); cumsum_37 = None + full_176 = torch.ops.aten.full.default([mul_610], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_610 = None + triton_kernel_wrapper_functional_proxy_12 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 12, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_272, 'start_index_values_ptr': sub_292, 'write_offsets_ptr': sub_294, 'output_ptr': full_176}, tensors_to_clone = ['output_ptr']); wait_tensor_272 = sub_292 = sub_294 = full_176 = None + getitem_190 = triton_kernel_wrapper_functional_proxy_12['output_ptr']; triton_kernel_wrapper_functional_proxy_12 = None + cat_40 = torch.ops.aten.cat.default([wait_tensor_273, full_default]); wait_tensor_273 = None + sym_size_int_49 = torch.ops.aten.sym_size.int(cat_40, 0) + sym_sum_25 = torch.sym_sum((1, _local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207)) + index_25 = torch.ops.aten.index.Tensor(cat_40, [getitem_190]); cat_40 = None + convert_element_type_718 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_718, 8, '513'); convert_element_type_718 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_720, 8, '513'); convert_element_type_720 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 8, '513'); convert_element_type_721 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + cumsum_38 = torch.ops.aten.cumsum.default(convert_element_type_716, 0, dtype = torch.int32); convert_element_type_716 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_274, [0, 2, 1]); wait_tensor_274 = None + _grouped_mm_36 = torch.ops.aten._grouped_mm.default(index_25, permute_200, cumsum_38); permute_200 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(_grouped_mm_36, torch.float32) + neg_25 = torch.ops.aten.neg.default(convert_element_type_724) + exp_38 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_848 = torch.ops.aten.add.Tensor(exp_38, 1); exp_38 = None + div_64 = torch.ops.aten.div.Tensor(convert_element_type_724, add_848); convert_element_type_724 = add_848 = None + convert_element_type_725 = torch.ops.prims.convert_element_type.default(div_64, torch.bfloat16); div_64 = None + permute_201 = torch.ops.aten.permute.default(wait_tensor_277, [0, 2, 1]); wait_tensor_277 = None + _grouped_mm_37 = torch.ops.aten._grouped_mm.default(index_25, permute_201, cumsum_38); permute_201 = None + mul_623 = torch.ops.aten.mul.Tensor(convert_element_type_725, _grouped_mm_37); convert_element_type_725 = None + permute_202 = torch.ops.aten.permute.default(wait_tensor_276, [0, 2, 1]); wait_tensor_276 = None + _grouped_mm_38 = torch.ops.aten._grouped_mm.default(mul_623, permute_202, cumsum_38); permute_202 = None + empty_12 = torch.ops.aten.empty.memory_format([sym_size_int_49, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_24 = torch.ops.aten.index_put.default(empty_12, [getitem_190], _grouped_mm_38); empty_12 = _grouped_mm_38 = None + slice_54 = torch.ops.aten.slice.Tensor(index_put_24, 0, 0, -1); index_put_24 = None + all_to_all_single_38 = torch.ops._c10d_functional.all_to_all_single.default(slice_54, [_local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199], [_local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207], '521'); slice_54 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_38); all_to_all_single_38 = None + convert_element_type_726 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_726, 64, '0'); convert_element_type_726 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_203 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + mm_108 = torch.ops.aten.mm.default(view_862, permute_203); permute_203 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mm_108, torch.float32) + neg_26 = torch.ops.aten.neg.default(convert_element_type_729) + exp_39 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_884 = torch.ops.aten.add.Tensor(exp_39, 1); exp_39 = None + div_65 = torch.ops.aten.div.Tensor(convert_element_type_729, add_884); convert_element_type_729 = add_884 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(div_65, torch.bfloat16); div_65 = None + convert_element_type_731 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_731, 64, '0'); convert_element_type_731 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_204 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + mm_109 = torch.ops.aten.mm.default(view_862, permute_204); permute_204 = None + mul_643 = torch.ops.aten.mul.Tensor(convert_element_type_730, mm_109); convert_element_type_730 = None + convert_element_type_734 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_734, 64, '0'); convert_element_type_734 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + mm_110 = torch.ops.aten.mm.default(mul_643, permute_205); permute_205 = None + index_put_25 = torch.ops.aten.index_put.default(full_default_1, [getitem_189], wait_tensor_280); wait_tensor_280 = None + view_902 = torch.ops.aten.view.default(mul_605, [-1, 1, 6]); mul_605 = None + view_903 = torch.ops.aten.view.default(index_put_25, [-1, 6, 2048]); index_put_25 = None + convert_element_type_737 = torch.ops.prims.convert_element_type.default(view_903, torch.float32); view_903 = None + bmm_12 = torch.ops.aten.bmm.default(view_902, convert_element_type_737) + convert_element_type_738 = torch.ops.prims.convert_element_type.default(bmm_12, torch.bfloat16); bmm_12 = None + squeeze_12 = torch.ops.aten.squeeze.dim(convert_element_type_738, 1); convert_element_type_738 = None + add_888 = torch.ops.aten.add.Tensor(mm_110, squeeze_12); mm_110 = squeeze_12 = None + view_904 = torch.ops.aten.view.default(add_888, [2, 4096, 2048]); add_888 = None + add_889 = torch.ops.aten.add.Tensor(add_824, view_904); view_904 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_739, 64, '0'); convert_element_type_739 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(add_889, torch.float32) + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_740, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_890 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_890); add_890 = None + mul_646 = torch.ops.aten.mul.Tensor(convert_element_type_740, rsqrt_42); convert_element_type_740 = None + mul_647 = torch.ops.aten.mul.Tensor(mul_646, wait_tensor_284); mul_646 = wait_tensor_284 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(mul_647, torch.bfloat16); mul_647 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_742, 64, '0'); convert_element_type_742 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + view_907 = torch.ops.aten.view.default(convert_element_type_741, [8192, 2048]); convert_element_type_741 = None + mm_111 = torch.ops.aten.mm.default(view_907, permute_206); permute_206 = None + view_908 = torch.ops.aten.view.default(mm_111, [2, 4096, 3072]); mm_111 = None + view_909 = torch.ops.aten.view.default(view_908, [2, 4096, -1, 192]); view_908 = None + split_with_sizes_42 = torch.ops.aten.split_with_sizes.default(view_909, [128, 64], -1); view_909 = None + getitem_191 = split_with_sizes_42[0] + getitem_192 = split_with_sizes_42[1]; split_with_sizes_42 = None + convert_element_type_745 = torch.ops.prims.convert_element_type.default(getitem_192, torch.float32); getitem_192 = None + view_910 = torch.ops.aten.view.default(convert_element_type_745, [2, 4096, 16, -1, 2]); convert_element_type_745 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_910); view_910 = None + mul_648 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_7); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_648); mul_648 = None + view_912 = torch.ops.aten.view.default(view_as_real_28, [2, 4096, 16, 64]); view_as_real_28 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(view_912, torch.bfloat16); view_912 = None + cat_41 = torch.ops.aten.cat.default([getitem_191, convert_element_type_746], -1); getitem_191 = convert_element_type_746 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_747, 64, '0'); convert_element_type_747 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + mm_112 = torch.ops.aten.mm.default(view_907, permute_207); permute_207 = None + view_915 = torch.ops.aten.view.default(mm_112, [2, 4096, 576]); mm_112 = None + split_with_sizes_43 = torch.ops.aten.split_with_sizes.default(view_915, [512, 64], -1); view_915 = None + getitem_193 = split_with_sizes_43[0] + getitem_194 = split_with_sizes_43[1]; split_with_sizes_43 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(getitem_194, 2); getitem_194 = None + convert_element_type_750 = torch.ops.prims.convert_element_type.default(unsqueeze_27, torch.float32); unsqueeze_27 = None + view_916 = torch.ops.aten.view.default(convert_element_type_750, [2, 4096, 1, -1, 2]); convert_element_type_750 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_916); view_916 = None + mul_649 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_7); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_649); mul_649 = None + view_918 = torch.ops.aten.view.default(view_as_real_29, [2, 4096, 1, 64]); view_as_real_29 = None + convert_element_type_751 = torch.ops.prims.convert_element_type.default(view_918, torch.bfloat16); view_918 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_752, 64, '0'); convert_element_type_752 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(getitem_193, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_753, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_891 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_891); add_891 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_753, rsqrt_43); convert_element_type_753 = None + mul_651 = torch.ops.aten.mul.Tensor(mul_650, wait_tensor_287); mul_650 = wait_tensor_287 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(mul_651, torch.bfloat16); mul_651 = None + convert_element_type_755 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_755, 64, '0'); convert_element_type_755 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + view_921 = torch.ops.aten.view.default(convert_element_type_754, [8192, 512]); convert_element_type_754 = None + mm_113 = torch.ops.aten.mm.default(view_921, permute_208); permute_208 = None + view_922 = torch.ops.aten.view.default(mm_113, [2, 4096, 4096]); mm_113 = None + view_923 = torch.ops.aten.view.default(view_922, [2, 4096, -1, 256]); view_922 = None + split_with_sizes_44 = torch.ops.aten.split_with_sizes.default(view_923, [128, 128], -1); view_923 = None + getitem_195 = split_with_sizes_44[0] + getitem_196 = split_with_sizes_44[1]; split_with_sizes_44 = None + expand_14 = torch.ops.aten.expand.default(convert_element_type_751, [-1, -1, 16, -1]); convert_element_type_751 = None + cat_42 = torch.ops.aten.cat.default([getitem_195, expand_14], -1); getitem_195 = expand_14 = None + permute_209 = torch.ops.aten.permute.default(cat_41, [0, 2, 1, 3]); cat_41 = None + permute_210 = torch.ops.aten.permute.default(cat_42, [0, 2, 1, 3]); cat_42 = None + permute_211 = torch.ops.aten.permute.default(getitem_196, [0, 2, 1, 3]); getitem_196 = None + sdpa_score14 = self.sdpa_score14 + sdpa_mask14 = self.sdpa_mask14 + flex_attention_14 = torch.ops.higher_order.flex_attention(permute_209, permute_210, permute_211, sdpa_score14, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask14), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score14 = sdpa_mask14 = None + getitem_197 = flex_attention_14[0] + getitem_198 = flex_attention_14[1]; flex_attention_14 = None + permute_212 = torch.ops.aten.permute.default(getitem_197, [0, 2, 1, 3]) + view_924 = torch.ops.aten.view.default(permute_212, [2, 4096, -1]); permute_212 = None + convert_element_type_758 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_758, 64, '0'); convert_element_type_758 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_213 = torch.ops.aten.permute.default(wait_tensor_289, [1, 0]); wait_tensor_289 = None + view_926 = torch.ops.aten.view.default(view_924, [8192, 2048]); view_924 = None + mm_114 = torch.ops.aten.mm.default(view_926, permute_213); view_926 = permute_213 = None + view_927 = torch.ops.aten.view.default(mm_114, [2, 4096, 2048]); mm_114 = None + add_892 = torch.ops.aten.add.Tensor(add_889, view_927); view_927 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_761, 64, '0'); convert_element_type_761 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(add_892, torch.float32) + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_762, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_893 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_893); add_893 = None + mul_652 = torch.ops.aten.mul.Tensor(convert_element_type_762, rsqrt_44); convert_element_type_762 = None + mul_653 = torch.ops.aten.mul.Tensor(mul_652, wait_tensor_290); mul_652 = wait_tensor_290 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(mul_653, torch.bfloat16); mul_653 = None + view_929 = torch.ops.aten.view.default(convert_element_type_763, [-1, 2048]); convert_element_type_763 = None + convert_element_type_764 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_764, 64, '0'); convert_element_type_764 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_214 = torch.ops.aten.permute.default(wait_tensor_291, [1, 0]); wait_tensor_291 = None + mm_115 = torch.ops.aten.mm.default(view_929, permute_214); permute_214 = None + convert_element_type_767 = torch.ops.prims.convert_element_type.default(mm_115, torch.float32) + amax_13 = torch.ops.aten.amax.default(convert_element_type_767, [1], True) + sub_312 = torch.ops.aten.sub.Tensor(convert_element_type_767, amax_13); convert_element_type_767 = None + exp_40 = torch.ops.aten.exp.default(sub_312); sub_312 = None + sum_53 = torch.ops.aten.sum.dim_IntList(exp_40, [1], True) + div_66 = torch.ops.aten.div.Tensor(exp_40, sum_53); exp_40 = None + add_894 = torch.ops.aten.add.Tensor(div_66, primals_238); primals_238 = None + topk_13 = torch.ops.aten.topk.default(add_894, 6, -1, True, False); add_894 = None + getitem_201 = topk_13[1]; topk_13 = None + gather_13 = torch.ops.aten.gather.default(div_66, 1, getitem_201); div_66 = None + mul_654 = torch.ops.aten.mul.Tensor(gather_13, 1.0); gather_13 = None + view_931 = torch.ops.aten.view.default(getitem_201, [-1]) + histc_26 = torch.ops.aten.histc.default(view_931, 64, 0, 64) + add_895 = torch.ops.aten.add.Tensor(primals_240, histc_26) + sort_13 = torch.ops.aten.sort.stable(view_931, stable = True); view_931 = None + getitem_203 = sort_13[1]; sort_13 = None + div_67 = torch.ops.aten.div.Tensor_mode(getitem_203, 6, rounding_mode = 'floor') + index_26 = torch.ops.aten.index.Tensor(view_929, [div_67]) + all_to_all_single_39 = torch.ops._c10d_functional.all_to_all_single.default(histc_26, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_39); all_to_all_single_39 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_292); wait_tensor_292 = None + view_935 = torch.ops.aten.view.default(histc_26, [8, -1]); histc_26 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_935, [1]); view_935 = None + device_put_26 = torch.ops.prims.device_put.default(sum_54, device(type='cpu'), True); sum_54 = None + view_936 = torch.ops.aten.view.default(wait_tensor_293, [8, -1]) + sum_55 = torch.ops.aten.sum.dim_IntList(view_936, [1]) + device_put_27 = torch.ops.prims.device_put.default(sum_55, device(type='cpu')); sum_55 = None + select_208 = torch.ops.aten.select.int(device_put_26, 0, 0) + _local_scalar_dense_208 = torch.ops.aten._local_scalar_dense.default(select_208); select_208 = None + ge_260 = _local_scalar_dense_208 >= 0 + _assert_scalar_208 = torch.ops.aten._assert_scalar.default(ge_260, "Runtime assertion failed for expression u208 >= 0 on node 'ge_208'"); ge_260 = _assert_scalar_208 = None + select_209 = torch.ops.aten.select.int(device_put_26, 0, 1) + _local_scalar_dense_209 = torch.ops.aten._local_scalar_dense.default(select_209); select_209 = None + ge_261 = _local_scalar_dense_209 >= 0 + _assert_scalar_209 = torch.ops.aten._assert_scalar.default(ge_261, "Runtime assertion failed for expression u209 >= 0 on node 'ge_209'"); ge_261 = _assert_scalar_209 = None + select_210 = torch.ops.aten.select.int(device_put_26, 0, 2) + _local_scalar_dense_210 = torch.ops.aten._local_scalar_dense.default(select_210); select_210 = None + ge_262 = _local_scalar_dense_210 >= 0 + _assert_scalar_210 = torch.ops.aten._assert_scalar.default(ge_262, "Runtime assertion failed for expression u210 >= 0 on node 'ge_210'"); ge_262 = _assert_scalar_210 = None + select_211 = torch.ops.aten.select.int(device_put_26, 0, 3) + _local_scalar_dense_211 = torch.ops.aten._local_scalar_dense.default(select_211); select_211 = None + ge_263 = _local_scalar_dense_211 >= 0 + _assert_scalar_211 = torch.ops.aten._assert_scalar.default(ge_263, "Runtime assertion failed for expression u211 >= 0 on node 'ge_211'"); ge_263 = _assert_scalar_211 = None + select_212 = torch.ops.aten.select.int(device_put_26, 0, 4) + _local_scalar_dense_212 = torch.ops.aten._local_scalar_dense.default(select_212); select_212 = None + ge_264 = _local_scalar_dense_212 >= 0 + _assert_scalar_212 = torch.ops.aten._assert_scalar.default(ge_264, "Runtime assertion failed for expression u212 >= 0 on node 'ge_212'"); ge_264 = _assert_scalar_212 = None + select_213 = torch.ops.aten.select.int(device_put_26, 0, 5) + _local_scalar_dense_213 = torch.ops.aten._local_scalar_dense.default(select_213); select_213 = None + ge_265 = _local_scalar_dense_213 >= 0 + _assert_scalar_213 = torch.ops.aten._assert_scalar.default(ge_265, "Runtime assertion failed for expression u213 >= 0 on node 'ge_213'"); ge_265 = _assert_scalar_213 = None + select_214 = torch.ops.aten.select.int(device_put_26, 0, 6) + _local_scalar_dense_214 = torch.ops.aten._local_scalar_dense.default(select_214); select_214 = None + ge_266 = _local_scalar_dense_214 >= 0 + _assert_scalar_214 = torch.ops.aten._assert_scalar.default(ge_266, "Runtime assertion failed for expression u214 >= 0 on node 'ge_214'"); ge_266 = _assert_scalar_214 = None + select_215 = torch.ops.aten.select.int(device_put_26, 0, 7); device_put_26 = None + _local_scalar_dense_215 = torch.ops.aten._local_scalar_dense.default(select_215); select_215 = None + ge_267 = _local_scalar_dense_215 >= 0 + _assert_scalar_215 = torch.ops.aten._assert_scalar.default(ge_267, "Runtime assertion failed for expression u215 >= 0 on node 'ge_215'"); ge_267 = _assert_scalar_215 = None + select_216 = torch.ops.aten.select.int(device_put_27, 0, 0) + _local_scalar_dense_216 = torch.ops.aten._local_scalar_dense.default(select_216); select_216 = None + ge_268 = _local_scalar_dense_216 >= 0 + _assert_scalar_216 = torch.ops.aten._assert_scalar.default(ge_268, "Runtime assertion failed for expression u216 >= 0 on node 'ge_216'"); ge_268 = _assert_scalar_216 = None + select_217 = torch.ops.aten.select.int(device_put_27, 0, 1) + _local_scalar_dense_217 = torch.ops.aten._local_scalar_dense.default(select_217); select_217 = None + ge_269 = _local_scalar_dense_217 >= 0 + _assert_scalar_217 = torch.ops.aten._assert_scalar.default(ge_269, "Runtime assertion failed for expression u217 >= 0 on node 'ge_217'"); ge_269 = _assert_scalar_217 = None + select_218 = torch.ops.aten.select.int(device_put_27, 0, 2) + _local_scalar_dense_218 = torch.ops.aten._local_scalar_dense.default(select_218); select_218 = None + ge_270 = _local_scalar_dense_218 >= 0 + _assert_scalar_218 = torch.ops.aten._assert_scalar.default(ge_270, "Runtime assertion failed for expression u218 >= 0 on node 'ge_218'"); ge_270 = _assert_scalar_218 = None + select_219 = torch.ops.aten.select.int(device_put_27, 0, 3) + _local_scalar_dense_219 = torch.ops.aten._local_scalar_dense.default(select_219); select_219 = None + ge_271 = _local_scalar_dense_219 >= 0 + _assert_scalar_219 = torch.ops.aten._assert_scalar.default(ge_271, "Runtime assertion failed for expression u219 >= 0 on node 'ge_219'"); ge_271 = _assert_scalar_219 = None + select_220 = torch.ops.aten.select.int(device_put_27, 0, 4) + _local_scalar_dense_220 = torch.ops.aten._local_scalar_dense.default(select_220); select_220 = None + ge_272 = _local_scalar_dense_220 >= 0 + _assert_scalar_220 = torch.ops.aten._assert_scalar.default(ge_272, "Runtime assertion failed for expression u220 >= 0 on node 'ge_220'"); ge_272 = _assert_scalar_220 = None + select_221 = torch.ops.aten.select.int(device_put_27, 0, 5) + _local_scalar_dense_221 = torch.ops.aten._local_scalar_dense.default(select_221); select_221 = None + ge_273 = _local_scalar_dense_221 >= 0 + _assert_scalar_221 = torch.ops.aten._assert_scalar.default(ge_273, "Runtime assertion failed for expression u221 >= 0 on node 'ge_221'"); ge_273 = _assert_scalar_221 = None + select_222 = torch.ops.aten.select.int(device_put_27, 0, 6) + _local_scalar_dense_222 = torch.ops.aten._local_scalar_dense.default(select_222); select_222 = None + ge_274 = _local_scalar_dense_222 >= 0 + _assert_scalar_222 = torch.ops.aten._assert_scalar.default(ge_274, "Runtime assertion failed for expression u222 >= 0 on node 'ge_222'"); ge_274 = _assert_scalar_222 = None + select_223 = torch.ops.aten.select.int(device_put_27, 0, 7); device_put_27 = None + _local_scalar_dense_223 = torch.ops.aten._local_scalar_dense.default(select_223); select_223 = None + ge_275 = _local_scalar_dense_223 >= 0 + _assert_scalar_223 = torch.ops.aten._assert_scalar.default(ge_275, "Runtime assertion failed for expression u223 >= 0 on node 'ge_223'"); ge_275 = _assert_scalar_223 = None + all_to_all_single_40 = torch.ops._c10d_functional.all_to_all_single.default(index_26, [_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223], [_local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215], '521'); index_26 = None + sym_size_int_52 = torch.ops.aten.sym_size.int(all_to_all_single_40, 0) + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_40); all_to_all_single_40 = None + sym_sum_26 = torch.sym_sum((_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223)) + add_902 = sym_sum_26 + 64; sym_sum_26 = None + add_903 = add_902 + 8; add_902 = None + sub_315 = add_903 - 1; add_903 = None + floordiv_13 = sub_315 // 8; sub_315 = None + mul_659 = floordiv_13 * 8; floordiv_13 = None + cumsum_39 = torch.ops.aten.cumsum.default(wait_tensor_293, 0) + sub_316 = torch.ops.aten.sub.Tensor(cumsum_39, wait_tensor_293); cumsum_39 = None + sum_56 = torch.ops.aten.sum.dim_IntList(view_936, [0]); view_936 = None + clamp_min_13 = torch.ops.aten.clamp_min.default(sum_56, 8); sum_56 = None + add_904 = torch.ops.aten.add.Tensor(clamp_min_13, 8); clamp_min_13 = None + sub_317 = torch.ops.aten.sub.Tensor(add_904, 1); add_904 = None + div_68 = torch.ops.aten.div.Tensor_mode(sub_317, 8, rounding_mode = 'floor'); sub_317 = None + mul_660 = torch.ops.aten.mul.Tensor(div_68, 8); div_68 = None + convert_element_type_770 = torch.ops.prims.convert_element_type.default(mul_660, torch.int32); mul_660 = None + cumsum_40 = torch.ops.aten.cumsum.default(convert_element_type_770, 0) + sub_318 = torch.ops.aten.sub.Tensor(cumsum_40, convert_element_type_770); cumsum_40 = None + full_189 = torch.ops.aten.full.default([mul_659], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_659 = None + triton_kernel_wrapper_functional_proxy_13 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 13, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_293, 'start_index_values_ptr': sub_316, 'write_offsets_ptr': sub_318, 'output_ptr': full_189}, tensors_to_clone = ['output_ptr']); wait_tensor_293 = sub_316 = sub_318 = full_189 = None + getitem_204 = triton_kernel_wrapper_functional_proxy_13['output_ptr']; triton_kernel_wrapper_functional_proxy_13 = None + cat_43 = torch.ops.aten.cat.default([wait_tensor_294, full_default]); wait_tensor_294 = None + sym_size_int_53 = torch.ops.aten.sym_size.int(cat_43, 0) + sym_sum_27 = torch.sym_sum((1, _local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223)) + index_27 = torch.ops.aten.index.Tensor(cat_43, [getitem_204]); cat_43 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_772, 8, '513'); convert_element_type_772 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_774, 8, '513'); convert_element_type_774 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_775, 8, '513'); convert_element_type_775 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + cumsum_41 = torch.ops.aten.cumsum.default(convert_element_type_770, 0, dtype = torch.int32); convert_element_type_770 = None + permute_215 = torch.ops.aten.permute.default(wait_tensor_295, [0, 2, 1]); wait_tensor_295 = None + _grouped_mm_39 = torch.ops.aten._grouped_mm.default(index_27, permute_215, cumsum_41); permute_215 = None + convert_element_type_778 = torch.ops.prims.convert_element_type.default(_grouped_mm_39, torch.float32) + neg_27 = torch.ops.aten.neg.default(convert_element_type_778) + exp_41 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_916 = torch.ops.aten.add.Tensor(exp_41, 1); exp_41 = None + div_69 = torch.ops.aten.div.Tensor(convert_element_type_778, add_916); convert_element_type_778 = add_916 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(div_69, torch.bfloat16); div_69 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_298, [0, 2, 1]); wait_tensor_298 = None + _grouped_mm_40 = torch.ops.aten._grouped_mm.default(index_27, permute_216, cumsum_41); permute_216 = None + mul_672 = torch.ops.aten.mul.Tensor(convert_element_type_779, _grouped_mm_40); convert_element_type_779 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_297, [0, 2, 1]); wait_tensor_297 = None + _grouped_mm_41 = torch.ops.aten._grouped_mm.default(mul_672, permute_217, cumsum_41); permute_217 = None + empty_13 = torch.ops.aten.empty.memory_format([sym_size_int_53, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_26 = torch.ops.aten.index_put.default(empty_13, [getitem_204], _grouped_mm_41); empty_13 = _grouped_mm_41 = None + slice_58 = torch.ops.aten.slice.Tensor(index_put_26, 0, 0, -1); index_put_26 = None + all_to_all_single_41 = torch.ops._c10d_functional.all_to_all_single.default(slice_58, [_local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215], [_local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223], '521'); slice_58 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_41); all_to_all_single_41 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_780, 64, '0'); convert_element_type_780 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_302, [1, 0]); wait_tensor_302 = None + mm_116 = torch.ops.aten.mm.default(view_929, permute_218); permute_218 = None + convert_element_type_783 = torch.ops.prims.convert_element_type.default(mm_116, torch.float32) + neg_28 = torch.ops.aten.neg.default(convert_element_type_783) + exp_42 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_952 = torch.ops.aten.add.Tensor(exp_42, 1); exp_42 = None + div_70 = torch.ops.aten.div.Tensor(convert_element_type_783, add_952); convert_element_type_783 = add_952 = None + convert_element_type_784 = torch.ops.prims.convert_element_type.default(div_70, torch.bfloat16); div_70 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_785, 64, '0'); convert_element_type_785 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + mm_117 = torch.ops.aten.mm.default(view_929, permute_219); permute_219 = None + mul_692 = torch.ops.aten.mul.Tensor(convert_element_type_784, mm_117); convert_element_type_784 = None + convert_element_type_788 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_788, 64, '0'); convert_element_type_788 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); wait_tensor_304 = None + mm_118 = torch.ops.aten.mm.default(mul_692, permute_220); permute_220 = None + index_put_27 = torch.ops.aten.index_put.default(full_default_1, [getitem_203], wait_tensor_301); wait_tensor_301 = None + view_969 = torch.ops.aten.view.default(mul_654, [-1, 1, 6]); mul_654 = None + view_970 = torch.ops.aten.view.default(index_put_27, [-1, 6, 2048]); index_put_27 = None + convert_element_type_791 = torch.ops.prims.convert_element_type.default(view_970, torch.float32); view_970 = None + bmm_13 = torch.ops.aten.bmm.default(view_969, convert_element_type_791) + convert_element_type_792 = torch.ops.prims.convert_element_type.default(bmm_13, torch.bfloat16); bmm_13 = None + squeeze_13 = torch.ops.aten.squeeze.dim(convert_element_type_792, 1); convert_element_type_792 = None + add_956 = torch.ops.aten.add.Tensor(mm_118, squeeze_13); mm_118 = squeeze_13 = None + view_971 = torch.ops.aten.view.default(add_956, [2, 4096, 2048]); add_956 = None + add_957 = torch.ops.aten.add.Tensor(add_892, view_971); view_971 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 64, '0'); convert_element_type_793 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_957, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_958 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_958); add_958 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_45); convert_element_type_794 = None + mul_696 = torch.ops.aten.mul.Tensor(mul_695, wait_tensor_305); mul_695 = wait_tensor_305 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_696, torch.bfloat16); mul_696 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 64, '0'); convert_element_type_796 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + view_974 = torch.ops.aten.view.default(convert_element_type_795, [8192, 2048]); convert_element_type_795 = None + mm_119 = torch.ops.aten.mm.default(view_974, permute_221); permute_221 = None + view_975 = torch.ops.aten.view.default(mm_119, [2, 4096, 3072]); mm_119 = None + view_976 = torch.ops.aten.view.default(view_975, [2, 4096, -1, 192]); view_975 = None + split_with_sizes_45 = torch.ops.aten.split_with_sizes.default(view_976, [128, 64], -1); view_976 = None + getitem_205 = split_with_sizes_45[0] + getitem_206 = split_with_sizes_45[1]; split_with_sizes_45 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(getitem_206, torch.float32); getitem_206 = None + view_977 = torch.ops.aten.view.default(convert_element_type_799, [2, 4096, 16, -1, 2]); convert_element_type_799 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_977); view_977 = None + mul_697 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_7); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_697); mul_697 = None + view_979 = torch.ops.aten.view.default(view_as_real_30, [2, 4096, 16, 64]); view_as_real_30 = None + convert_element_type_800 = torch.ops.prims.convert_element_type.default(view_979, torch.bfloat16); view_979 = None + cat_44 = torch.ops.aten.cat.default([getitem_205, convert_element_type_800], -1); getitem_205 = convert_element_type_800 = None + convert_element_type_801 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_801, 64, '0'); convert_element_type_801 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_307, [1, 0]); wait_tensor_307 = None + mm_120 = torch.ops.aten.mm.default(view_974, permute_222); permute_222 = None + view_982 = torch.ops.aten.view.default(mm_120, [2, 4096, 576]); mm_120 = None + split_with_sizes_46 = torch.ops.aten.split_with_sizes.default(view_982, [512, 64], -1); view_982 = None + getitem_207 = split_with_sizes_46[0] + getitem_208 = split_with_sizes_46[1]; split_with_sizes_46 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(getitem_208, 2); getitem_208 = None + convert_element_type_804 = torch.ops.prims.convert_element_type.default(unsqueeze_29, torch.float32); unsqueeze_29 = None + view_983 = torch.ops.aten.view.default(convert_element_type_804, [2, 4096, 1, -1, 2]); convert_element_type_804 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_983); view_983 = None + mul_698 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_7); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_698); mul_698 = None + view_985 = torch.ops.aten.view.default(view_as_real_31, [2, 4096, 1, 64]); view_as_real_31 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_985, torch.bfloat16); view_985 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_806, 64, '0'); convert_element_type_806 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(getitem_207, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_807, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_959 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_959); add_959 = None + mul_699 = torch.ops.aten.mul.Tensor(convert_element_type_807, rsqrt_46); convert_element_type_807 = None + mul_700 = torch.ops.aten.mul.Tensor(mul_699, wait_tensor_308); mul_699 = wait_tensor_308 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(mul_700, torch.bfloat16); mul_700 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 64, '0'); convert_element_type_809 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_223 = torch.ops.aten.permute.default(wait_tensor_309, [1, 0]); wait_tensor_309 = None + view_988 = torch.ops.aten.view.default(convert_element_type_808, [8192, 512]); convert_element_type_808 = None + mm_121 = torch.ops.aten.mm.default(view_988, permute_223); permute_223 = None + view_989 = torch.ops.aten.view.default(mm_121, [2, 4096, 4096]); mm_121 = None + view_990 = torch.ops.aten.view.default(view_989, [2, 4096, -1, 256]); view_989 = None + split_with_sizes_47 = torch.ops.aten.split_with_sizes.default(view_990, [128, 128], -1); view_990 = None + getitem_209 = split_with_sizes_47[0] + getitem_210 = split_with_sizes_47[1]; split_with_sizes_47 = None + expand_15 = torch.ops.aten.expand.default(convert_element_type_805, [-1, -1, 16, -1]); convert_element_type_805 = None + cat_45 = torch.ops.aten.cat.default([getitem_209, expand_15], -1); getitem_209 = expand_15 = None + permute_224 = torch.ops.aten.permute.default(cat_44, [0, 2, 1, 3]); cat_44 = None + permute_225 = torch.ops.aten.permute.default(cat_45, [0, 2, 1, 3]); cat_45 = None + permute_226 = torch.ops.aten.permute.default(getitem_210, [0, 2, 1, 3]); getitem_210 = None + sdpa_score15 = self.sdpa_score15 + sdpa_mask15 = self.sdpa_mask15 + flex_attention_15 = torch.ops.higher_order.flex_attention(permute_224, permute_225, permute_226, sdpa_score15, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask15), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score15 = sdpa_mask15 = None + getitem_211 = flex_attention_15[0] + getitem_212 = flex_attention_15[1]; flex_attention_15 = None + permute_227 = torch.ops.aten.permute.default(getitem_211, [0, 2, 1, 3]) + view_991 = torch.ops.aten.view.default(permute_227, [2, 4096, -1]); permute_227 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 64, '0'); convert_element_type_812 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + view_993 = torch.ops.aten.view.default(view_991, [8192, 2048]); view_991 = None + mm_122 = torch.ops.aten.mm.default(view_993, permute_228); view_993 = permute_228 = None + view_994 = torch.ops.aten.view.default(mm_122, [2, 4096, 2048]); mm_122 = None + add_960 = torch.ops.aten.add.Tensor(add_957, view_994); view_994 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 64, '0'); convert_element_type_815 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + convert_element_type_816 = torch.ops.prims.convert_element_type.default(add_960, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_816, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_961 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_961); add_961 = None + mul_701 = torch.ops.aten.mul.Tensor(convert_element_type_816, rsqrt_47); convert_element_type_816 = None + mul_702 = torch.ops.aten.mul.Tensor(mul_701, wait_tensor_311); mul_701 = wait_tensor_311 = None + convert_element_type_817 = torch.ops.prims.convert_element_type.default(mul_702, torch.bfloat16); mul_702 = None + view_996 = torch.ops.aten.view.default(convert_element_type_817, [-1, 2048]); convert_element_type_817 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_818, 64, '0'); convert_element_type_818 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_312, [1, 0]); wait_tensor_312 = None + mm_123 = torch.ops.aten.mm.default(view_996, permute_229); permute_229 = None + convert_element_type_821 = torch.ops.prims.convert_element_type.default(mm_123, torch.float32) + amax_14 = torch.ops.aten.amax.default(convert_element_type_821, [1], True) + sub_336 = torch.ops.aten.sub.Tensor(convert_element_type_821, amax_14); convert_element_type_821 = None + exp_43 = torch.ops.aten.exp.default(sub_336); sub_336 = None + sum_57 = torch.ops.aten.sum.dim_IntList(exp_43, [1], True) + div_71 = torch.ops.aten.div.Tensor(exp_43, sum_57); exp_43 = None + add_962 = torch.ops.aten.add.Tensor(div_71, primals_254); primals_254 = None + topk_14 = torch.ops.aten.topk.default(add_962, 6, -1, True, False); add_962 = None + getitem_215 = topk_14[1]; topk_14 = None + gather_14 = torch.ops.aten.gather.default(div_71, 1, getitem_215); div_71 = None + mul_703 = torch.ops.aten.mul.Tensor(gather_14, 1.0); gather_14 = None + view_998 = torch.ops.aten.view.default(getitem_215, [-1]) + histc_28 = torch.ops.aten.histc.default(view_998, 64, 0, 64) + add_963 = torch.ops.aten.add.Tensor(primals_256, histc_28) + sort_14 = torch.ops.aten.sort.stable(view_998, stable = True); view_998 = None + getitem_217 = sort_14[1]; sort_14 = None + div_72 = torch.ops.aten.div.Tensor_mode(getitem_217, 6, rounding_mode = 'floor') + index_28 = torch.ops.aten.index.Tensor(view_996, [div_72]) + all_to_all_single_42 = torch.ops._c10d_functional.all_to_all_single.default(histc_28, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_42); all_to_all_single_42 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_313); wait_tensor_313 = None + view_1002 = torch.ops.aten.view.default(histc_28, [8, -1]); histc_28 = None + sum_58 = torch.ops.aten.sum.dim_IntList(view_1002, [1]); view_1002 = None + device_put_28 = torch.ops.prims.device_put.default(sum_58, device(type='cpu'), True); sum_58 = None + view_1003 = torch.ops.aten.view.default(wait_tensor_314, [8, -1]) + sum_59 = torch.ops.aten.sum.dim_IntList(view_1003, [1]) + device_put_29 = torch.ops.prims.device_put.default(sum_59, device(type='cpu')); sum_59 = None + select_224 = torch.ops.aten.select.int(device_put_28, 0, 0) + _local_scalar_dense_224 = torch.ops.aten._local_scalar_dense.default(select_224); select_224 = None + ge_280 = _local_scalar_dense_224 >= 0 + _assert_scalar_224 = torch.ops.aten._assert_scalar.default(ge_280, "Runtime assertion failed for expression u224 >= 0 on node 'ge_224'"); ge_280 = _assert_scalar_224 = None + select_225 = torch.ops.aten.select.int(device_put_28, 0, 1) + _local_scalar_dense_225 = torch.ops.aten._local_scalar_dense.default(select_225); select_225 = None + ge_281 = _local_scalar_dense_225 >= 0 + _assert_scalar_225 = torch.ops.aten._assert_scalar.default(ge_281, "Runtime assertion failed for expression u225 >= 0 on node 'ge_225'"); ge_281 = _assert_scalar_225 = None + select_226 = torch.ops.aten.select.int(device_put_28, 0, 2) + _local_scalar_dense_226 = torch.ops.aten._local_scalar_dense.default(select_226); select_226 = None + ge_282 = _local_scalar_dense_226 >= 0 + _assert_scalar_226 = torch.ops.aten._assert_scalar.default(ge_282, "Runtime assertion failed for expression u226 >= 0 on node 'ge_226'"); ge_282 = _assert_scalar_226 = None + select_227 = torch.ops.aten.select.int(device_put_28, 0, 3) + _local_scalar_dense_227 = torch.ops.aten._local_scalar_dense.default(select_227); select_227 = None + ge_283 = _local_scalar_dense_227 >= 0 + _assert_scalar_227 = torch.ops.aten._assert_scalar.default(ge_283, "Runtime assertion failed for expression u227 >= 0 on node 'ge_227'"); ge_283 = _assert_scalar_227 = None + select_228 = torch.ops.aten.select.int(device_put_28, 0, 4) + _local_scalar_dense_228 = torch.ops.aten._local_scalar_dense.default(select_228); select_228 = None + ge_284 = _local_scalar_dense_228 >= 0 + _assert_scalar_228 = torch.ops.aten._assert_scalar.default(ge_284, "Runtime assertion failed for expression u228 >= 0 on node 'ge_228'"); ge_284 = _assert_scalar_228 = None + select_229 = torch.ops.aten.select.int(device_put_28, 0, 5) + _local_scalar_dense_229 = torch.ops.aten._local_scalar_dense.default(select_229); select_229 = None + ge_285 = _local_scalar_dense_229 >= 0 + _assert_scalar_229 = torch.ops.aten._assert_scalar.default(ge_285, "Runtime assertion failed for expression u229 >= 0 on node 'ge_229'"); ge_285 = _assert_scalar_229 = None + select_230 = torch.ops.aten.select.int(device_put_28, 0, 6) + _local_scalar_dense_230 = torch.ops.aten._local_scalar_dense.default(select_230); select_230 = None + ge_286 = _local_scalar_dense_230 >= 0 + _assert_scalar_230 = torch.ops.aten._assert_scalar.default(ge_286, "Runtime assertion failed for expression u230 >= 0 on node 'ge_230'"); ge_286 = _assert_scalar_230 = None + select_231 = torch.ops.aten.select.int(device_put_28, 0, 7); device_put_28 = None + _local_scalar_dense_231 = torch.ops.aten._local_scalar_dense.default(select_231); select_231 = None + ge_287 = _local_scalar_dense_231 >= 0 + _assert_scalar_231 = torch.ops.aten._assert_scalar.default(ge_287, "Runtime assertion failed for expression u231 >= 0 on node 'ge_231'"); ge_287 = _assert_scalar_231 = None + select_232 = torch.ops.aten.select.int(device_put_29, 0, 0) + _local_scalar_dense_232 = torch.ops.aten._local_scalar_dense.default(select_232); select_232 = None + ge_288 = _local_scalar_dense_232 >= 0 + _assert_scalar_232 = torch.ops.aten._assert_scalar.default(ge_288, "Runtime assertion failed for expression u232 >= 0 on node 'ge_232'"); ge_288 = _assert_scalar_232 = None + select_233 = torch.ops.aten.select.int(device_put_29, 0, 1) + _local_scalar_dense_233 = torch.ops.aten._local_scalar_dense.default(select_233); select_233 = None + ge_289 = _local_scalar_dense_233 >= 0 + _assert_scalar_233 = torch.ops.aten._assert_scalar.default(ge_289, "Runtime assertion failed for expression u233 >= 0 on node 'ge_233'"); ge_289 = _assert_scalar_233 = None + select_234 = torch.ops.aten.select.int(device_put_29, 0, 2) + _local_scalar_dense_234 = torch.ops.aten._local_scalar_dense.default(select_234); select_234 = None + ge_290 = _local_scalar_dense_234 >= 0 + _assert_scalar_234 = torch.ops.aten._assert_scalar.default(ge_290, "Runtime assertion failed for expression u234 >= 0 on node 'ge_234'"); ge_290 = _assert_scalar_234 = None + select_235 = torch.ops.aten.select.int(device_put_29, 0, 3) + _local_scalar_dense_235 = torch.ops.aten._local_scalar_dense.default(select_235); select_235 = None + ge_291 = _local_scalar_dense_235 >= 0 + _assert_scalar_235 = torch.ops.aten._assert_scalar.default(ge_291, "Runtime assertion failed for expression u235 >= 0 on node 'ge_235'"); ge_291 = _assert_scalar_235 = None + select_236 = torch.ops.aten.select.int(device_put_29, 0, 4) + _local_scalar_dense_236 = torch.ops.aten._local_scalar_dense.default(select_236); select_236 = None + ge_292 = _local_scalar_dense_236 >= 0 + _assert_scalar_236 = torch.ops.aten._assert_scalar.default(ge_292, "Runtime assertion failed for expression u236 >= 0 on node 'ge_236'"); ge_292 = _assert_scalar_236 = None + select_237 = torch.ops.aten.select.int(device_put_29, 0, 5) + _local_scalar_dense_237 = torch.ops.aten._local_scalar_dense.default(select_237); select_237 = None + ge_293 = _local_scalar_dense_237 >= 0 + _assert_scalar_237 = torch.ops.aten._assert_scalar.default(ge_293, "Runtime assertion failed for expression u237 >= 0 on node 'ge_237'"); ge_293 = _assert_scalar_237 = None + select_238 = torch.ops.aten.select.int(device_put_29, 0, 6) + _local_scalar_dense_238 = torch.ops.aten._local_scalar_dense.default(select_238); select_238 = None + ge_294 = _local_scalar_dense_238 >= 0 + _assert_scalar_238 = torch.ops.aten._assert_scalar.default(ge_294, "Runtime assertion failed for expression u238 >= 0 on node 'ge_238'"); ge_294 = _assert_scalar_238 = None + select_239 = torch.ops.aten.select.int(device_put_29, 0, 7); device_put_29 = None + _local_scalar_dense_239 = torch.ops.aten._local_scalar_dense.default(select_239); select_239 = None + ge_295 = _local_scalar_dense_239 >= 0 + _assert_scalar_239 = torch.ops.aten._assert_scalar.default(ge_295, "Runtime assertion failed for expression u239 >= 0 on node 'ge_239'"); ge_295 = _assert_scalar_239 = None + all_to_all_single_43 = torch.ops._c10d_functional.all_to_all_single.default(index_28, [_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239], [_local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231], '521'); index_28 = None + sym_size_int_56 = torch.ops.aten.sym_size.int(all_to_all_single_43, 0) + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_43); all_to_all_single_43 = None + sym_sum_28 = torch.sym_sum((_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239)) + add_970 = sym_sum_28 + 64; sym_sum_28 = None + add_971 = add_970 + 8; add_970 = None + sub_339 = add_971 - 1; add_971 = None + floordiv_14 = sub_339 // 8; sub_339 = None + mul_708 = floordiv_14 * 8; floordiv_14 = None + cumsum_42 = torch.ops.aten.cumsum.default(wait_tensor_314, 0) + sub_340 = torch.ops.aten.sub.Tensor(cumsum_42, wait_tensor_314); cumsum_42 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_1003, [0]); view_1003 = None + clamp_min_14 = torch.ops.aten.clamp_min.default(sum_60, 8); sum_60 = None + add_972 = torch.ops.aten.add.Tensor(clamp_min_14, 8); clamp_min_14 = None + sub_341 = torch.ops.aten.sub.Tensor(add_972, 1); add_972 = None + div_73 = torch.ops.aten.div.Tensor_mode(sub_341, 8, rounding_mode = 'floor'); sub_341 = None + mul_709 = torch.ops.aten.mul.Tensor(div_73, 8); div_73 = None + convert_element_type_824 = torch.ops.prims.convert_element_type.default(mul_709, torch.int32); mul_709 = None + cumsum_43 = torch.ops.aten.cumsum.default(convert_element_type_824, 0) + sub_342 = torch.ops.aten.sub.Tensor(cumsum_43, convert_element_type_824); cumsum_43 = None + full_202 = torch.ops.aten.full.default([mul_708], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_708 = None + triton_kernel_wrapper_functional_proxy_14 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 14, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_314, 'start_index_values_ptr': sub_340, 'write_offsets_ptr': sub_342, 'output_ptr': full_202}, tensors_to_clone = ['output_ptr']); wait_tensor_314 = sub_340 = sub_342 = full_202 = None + getitem_218 = triton_kernel_wrapper_functional_proxy_14['output_ptr']; triton_kernel_wrapper_functional_proxy_14 = None + cat_46 = torch.ops.aten.cat.default([wait_tensor_315, full_default]); wait_tensor_315 = None + sym_size_int_57 = torch.ops.aten.sym_size.int(cat_46, 0) + sym_sum_29 = torch.sym_sum((1, _local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239)) + index_29 = torch.ops.aten.index.Tensor(cat_46, [getitem_218]); cat_46 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 8, '513'); convert_element_type_826 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_828, 8, '513'); convert_element_type_828 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 8, '513'); convert_element_type_829 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + cumsum_44 = torch.ops.aten.cumsum.default(convert_element_type_824, 0, dtype = torch.int32); convert_element_type_824 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_316, [0, 2, 1]); wait_tensor_316 = None + _grouped_mm_42 = torch.ops.aten._grouped_mm.default(index_29, permute_230, cumsum_44); permute_230 = None + convert_element_type_832 = torch.ops.prims.convert_element_type.default(_grouped_mm_42, torch.float32) + neg_29 = torch.ops.aten.neg.default(convert_element_type_832) + exp_44 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_984 = torch.ops.aten.add.Tensor(exp_44, 1); exp_44 = None + div_74 = torch.ops.aten.div.Tensor(convert_element_type_832, add_984); convert_element_type_832 = add_984 = None + convert_element_type_833 = torch.ops.prims.convert_element_type.default(div_74, torch.bfloat16); div_74 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_319, [0, 2, 1]); wait_tensor_319 = None + _grouped_mm_43 = torch.ops.aten._grouped_mm.default(index_29, permute_231, cumsum_44); permute_231 = None + mul_721 = torch.ops.aten.mul.Tensor(convert_element_type_833, _grouped_mm_43); convert_element_type_833 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_318, [0, 2, 1]); wait_tensor_318 = None + _grouped_mm_44 = torch.ops.aten._grouped_mm.default(mul_721, permute_232, cumsum_44); permute_232 = None + empty_14 = torch.ops.aten.empty.memory_format([sym_size_int_57, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_28 = torch.ops.aten.index_put.default(empty_14, [getitem_218], _grouped_mm_44); empty_14 = _grouped_mm_44 = None + slice_62 = torch.ops.aten.slice.Tensor(index_put_28, 0, 0, -1); index_put_28 = None + all_to_all_single_44 = torch.ops._c10d_functional.all_to_all_single.default(slice_62, [_local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231], [_local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239], '521'); slice_62 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_44); all_to_all_single_44 = None + convert_element_type_834 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_834, 64, '0'); convert_element_type_834 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + mm_124 = torch.ops.aten.mm.default(view_996, permute_233); permute_233 = None + convert_element_type_837 = torch.ops.prims.convert_element_type.default(mm_124, torch.float32) + neg_30 = torch.ops.aten.neg.default(convert_element_type_837) + exp_45 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_1020 = torch.ops.aten.add.Tensor(exp_45, 1); exp_45 = None + div_75 = torch.ops.aten.div.Tensor(convert_element_type_837, add_1020); convert_element_type_837 = add_1020 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(div_75, torch.bfloat16); div_75 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_839, 64, '0'); convert_element_type_839 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_234 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + mm_125 = torch.ops.aten.mm.default(view_996, permute_234); permute_234 = None + mul_741 = torch.ops.aten.mul.Tensor(convert_element_type_838, mm_125); convert_element_type_838 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 64, '0'); convert_element_type_842 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_235 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + mm_126 = torch.ops.aten.mm.default(mul_741, permute_235); permute_235 = None + index_put_29 = torch.ops.aten.index_put.default(full_default_1, [getitem_217], wait_tensor_322); wait_tensor_322 = None + view_1036 = torch.ops.aten.view.default(mul_703, [-1, 1, 6]); mul_703 = None + view_1037 = torch.ops.aten.view.default(index_put_29, [-1, 6, 2048]); index_put_29 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(view_1037, torch.float32); view_1037 = None + bmm_14 = torch.ops.aten.bmm.default(view_1036, convert_element_type_845) + convert_element_type_846 = torch.ops.prims.convert_element_type.default(bmm_14, torch.bfloat16); bmm_14 = None + squeeze_14 = torch.ops.aten.squeeze.dim(convert_element_type_846, 1); convert_element_type_846 = None + add_1024 = torch.ops.aten.add.Tensor(mm_126, squeeze_14); mm_126 = squeeze_14 = None + view_1038 = torch.ops.aten.view.default(add_1024, [2, 4096, 2048]); add_1024 = None + add_1025 = torch.ops.aten.add.Tensor(add_960, view_1038); view_1038 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 64, '0'); convert_element_type_847 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(add_1025, torch.float32) + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_848, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_1026 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_1026); add_1026 = None + mul_744 = torch.ops.aten.mul.Tensor(convert_element_type_848, rsqrt_48); convert_element_type_848 = None + mul_745 = torch.ops.aten.mul.Tensor(mul_744, wait_tensor_326); mul_744 = wait_tensor_326 = None + convert_element_type_849 = torch.ops.prims.convert_element_type.default(mul_745, torch.bfloat16); mul_745 = None + convert_element_type_850 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_850, 64, '0'); convert_element_type_850 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_236 = torch.ops.aten.permute.default(wait_tensor_327, [1, 0]); wait_tensor_327 = None + view_1041 = torch.ops.aten.view.default(convert_element_type_849, [8192, 2048]); convert_element_type_849 = None + mm_127 = torch.ops.aten.mm.default(view_1041, permute_236); permute_236 = None + view_1042 = torch.ops.aten.view.default(mm_127, [2, 4096, 3072]); mm_127 = None + view_1043 = torch.ops.aten.view.default(view_1042, [2, 4096, -1, 192]); view_1042 = None + split_with_sizes_48 = torch.ops.aten.split_with_sizes.default(view_1043, [128, 64], -1); view_1043 = None + getitem_219 = split_with_sizes_48[0] + getitem_220 = split_with_sizes_48[1]; split_with_sizes_48 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(getitem_220, torch.float32); getitem_220 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_853, [2, 4096, 16, -1, 2]); convert_element_type_853 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_746 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_7); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_746); mul_746 = None + view_1046 = torch.ops.aten.view.default(view_as_real_32, [2, 4096, 16, 64]); view_as_real_32 = None + convert_element_type_854 = torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + cat_47 = torch.ops.aten.cat.default([getitem_219, convert_element_type_854], -1); getitem_219 = convert_element_type_854 = None + convert_element_type_855 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_855, 64, '0'); convert_element_type_855 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_237 = torch.ops.aten.permute.default(wait_tensor_328, [1, 0]); wait_tensor_328 = None + mm_128 = torch.ops.aten.mm.default(view_1041, permute_237); permute_237 = None + view_1049 = torch.ops.aten.view.default(mm_128, [2, 4096, 576]); mm_128 = None + split_with_sizes_49 = torch.ops.aten.split_with_sizes.default(view_1049, [512, 64], -1); view_1049 = None + getitem_221 = split_with_sizes_49[0] + getitem_222 = split_with_sizes_49[1]; split_with_sizes_49 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(getitem_222, 2); getitem_222 = None + convert_element_type_858 = torch.ops.prims.convert_element_type.default(unsqueeze_31, torch.float32); unsqueeze_31 = None + view_1050 = torch.ops.aten.view.default(convert_element_type_858, [2, 4096, 1, -1, 2]); convert_element_type_858 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1050); view_1050 = None + mul_747 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_7); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_747); mul_747 = None + view_1052 = torch.ops.aten.view.default(view_as_real_33, [2, 4096, 1, 64]); view_as_real_33 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(view_1052, torch.bfloat16); view_1052 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_860, 64, '0'); convert_element_type_860 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(getitem_221, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_861, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_1027 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_1027); add_1027 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_861, rsqrt_49); convert_element_type_861 = None + mul_749 = torch.ops.aten.mul.Tensor(mul_748, wait_tensor_329); mul_748 = wait_tensor_329 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(mul_749, torch.bfloat16); mul_749 = None + convert_element_type_863 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_863, 64, '0'); convert_element_type_863 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + view_1055 = torch.ops.aten.view.default(convert_element_type_862, [8192, 512]); convert_element_type_862 = None + mm_129 = torch.ops.aten.mm.default(view_1055, permute_238); permute_238 = None + view_1056 = torch.ops.aten.view.default(mm_129, [2, 4096, 4096]); mm_129 = None + view_1057 = torch.ops.aten.view.default(view_1056, [2, 4096, -1, 256]); view_1056 = None + split_with_sizes_50 = torch.ops.aten.split_with_sizes.default(view_1057, [128, 128], -1); view_1057 = None + getitem_223 = split_with_sizes_50[0] + getitem_224 = split_with_sizes_50[1]; split_with_sizes_50 = None + expand_16 = torch.ops.aten.expand.default(convert_element_type_859, [-1, -1, 16, -1]); convert_element_type_859 = None + cat_48 = torch.ops.aten.cat.default([getitem_223, expand_16], -1); getitem_223 = expand_16 = None + permute_239 = torch.ops.aten.permute.default(cat_47, [0, 2, 1, 3]); cat_47 = None + permute_240 = torch.ops.aten.permute.default(cat_48, [0, 2, 1, 3]); cat_48 = None + permute_241 = torch.ops.aten.permute.default(getitem_224, [0, 2, 1, 3]); getitem_224 = None + sdpa_score16 = self.sdpa_score16 + sdpa_mask16 = self.sdpa_mask16 + flex_attention_16 = torch.ops.higher_order.flex_attention(permute_239, permute_240, permute_241, sdpa_score16, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask16), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score16 = sdpa_mask16 = None + getitem_225 = flex_attention_16[0] + getitem_226 = flex_attention_16[1]; flex_attention_16 = None + permute_242 = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]) + view_1058 = torch.ops.aten.view.default(permute_242, [2, 4096, -1]); permute_242 = None + convert_element_type_866 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_866, 64, '0'); convert_element_type_866 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + view_1060 = torch.ops.aten.view.default(view_1058, [8192, 2048]); view_1058 = None + mm_130 = torch.ops.aten.mm.default(view_1060, permute_243); view_1060 = permute_243 = None + view_1061 = torch.ops.aten.view.default(mm_130, [2, 4096, 2048]); mm_130 = None + add_1028 = torch.ops.aten.add.Tensor(add_1025, view_1061); view_1061 = None + convert_element_type_869 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_869, 64, '0'); convert_element_type_869 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + convert_element_type_870 = torch.ops.prims.convert_element_type.default(add_1028, torch.float32) + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_870, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_1029 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_1029); add_1029 = None + mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_870, rsqrt_50); convert_element_type_870 = None + mul_751 = torch.ops.aten.mul.Tensor(mul_750, wait_tensor_332); mul_750 = wait_tensor_332 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(mul_751, torch.bfloat16); mul_751 = None + view_1063 = torch.ops.aten.view.default(convert_element_type_871, [-1, 2048]); convert_element_type_871 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_872, 64, '0'); convert_element_type_872 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_333, [1, 0]); wait_tensor_333 = None + mm_131 = torch.ops.aten.mm.default(view_1063, permute_244); permute_244 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(mm_131, torch.float32) + amax_15 = torch.ops.aten.amax.default(convert_element_type_875, [1], True) + sub_360 = torch.ops.aten.sub.Tensor(convert_element_type_875, amax_15); convert_element_type_875 = None + exp_46 = torch.ops.aten.exp.default(sub_360); sub_360 = None + sum_61 = torch.ops.aten.sum.dim_IntList(exp_46, [1], True) + div_76 = torch.ops.aten.div.Tensor(exp_46, sum_61); exp_46 = None + add_1030 = torch.ops.aten.add.Tensor(div_76, primals_270); primals_270 = None + topk_15 = torch.ops.aten.topk.default(add_1030, 6, -1, True, False); add_1030 = None + getitem_229 = topk_15[1]; topk_15 = None + gather_15 = torch.ops.aten.gather.default(div_76, 1, getitem_229); div_76 = None + mul_752 = torch.ops.aten.mul.Tensor(gather_15, 1.0); gather_15 = None + view_1065 = torch.ops.aten.view.default(getitem_229, [-1]) + histc_30 = torch.ops.aten.histc.default(view_1065, 64, 0, 64) + add_1031 = torch.ops.aten.add.Tensor(primals_272, histc_30) + sort_15 = torch.ops.aten.sort.stable(view_1065, stable = True); view_1065 = None + getitem_231 = sort_15[1]; sort_15 = None + div_77 = torch.ops.aten.div.Tensor_mode(getitem_231, 6, rounding_mode = 'floor') + index_30 = torch.ops.aten.index.Tensor(view_1063, [div_77]) + all_to_all_single_45 = torch.ops._c10d_functional.all_to_all_single.default(histc_30, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_45); all_to_all_single_45 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_334); wait_tensor_334 = None + view_1069 = torch.ops.aten.view.default(histc_30, [8, -1]); histc_30 = None + sum_62 = torch.ops.aten.sum.dim_IntList(view_1069, [1]); view_1069 = None + device_put_30 = torch.ops.prims.device_put.default(sum_62, device(type='cpu'), True); sum_62 = None + view_1070 = torch.ops.aten.view.default(wait_tensor_335, [8, -1]) + sum_63 = torch.ops.aten.sum.dim_IntList(view_1070, [1]) + device_put_31 = torch.ops.prims.device_put.default(sum_63, device(type='cpu')); sum_63 = None + select_240 = torch.ops.aten.select.int(device_put_30, 0, 0) + _local_scalar_dense_240 = torch.ops.aten._local_scalar_dense.default(select_240); select_240 = None + ge_300 = _local_scalar_dense_240 >= 0 + _assert_scalar_240 = torch.ops.aten._assert_scalar.default(ge_300, "Runtime assertion failed for expression u240 >= 0 on node 'ge_240'"); ge_300 = _assert_scalar_240 = None + select_241 = torch.ops.aten.select.int(device_put_30, 0, 1) + _local_scalar_dense_241 = torch.ops.aten._local_scalar_dense.default(select_241); select_241 = None + ge_301 = _local_scalar_dense_241 >= 0 + _assert_scalar_241 = torch.ops.aten._assert_scalar.default(ge_301, "Runtime assertion failed for expression u241 >= 0 on node 'ge_241'"); ge_301 = _assert_scalar_241 = None + select_242 = torch.ops.aten.select.int(device_put_30, 0, 2) + _local_scalar_dense_242 = torch.ops.aten._local_scalar_dense.default(select_242); select_242 = None + ge_302 = _local_scalar_dense_242 >= 0 + _assert_scalar_242 = torch.ops.aten._assert_scalar.default(ge_302, "Runtime assertion failed for expression u242 >= 0 on node 'ge_242'"); ge_302 = _assert_scalar_242 = None + select_243 = torch.ops.aten.select.int(device_put_30, 0, 3) + _local_scalar_dense_243 = torch.ops.aten._local_scalar_dense.default(select_243); select_243 = None + ge_303 = _local_scalar_dense_243 >= 0 + _assert_scalar_243 = torch.ops.aten._assert_scalar.default(ge_303, "Runtime assertion failed for expression u243 >= 0 on node 'ge_243'"); ge_303 = _assert_scalar_243 = None + select_244 = torch.ops.aten.select.int(device_put_30, 0, 4) + _local_scalar_dense_244 = torch.ops.aten._local_scalar_dense.default(select_244); select_244 = None + ge_304 = _local_scalar_dense_244 >= 0 + _assert_scalar_244 = torch.ops.aten._assert_scalar.default(ge_304, "Runtime assertion failed for expression u244 >= 0 on node 'ge_244'"); ge_304 = _assert_scalar_244 = None + select_245 = torch.ops.aten.select.int(device_put_30, 0, 5) + _local_scalar_dense_245 = torch.ops.aten._local_scalar_dense.default(select_245); select_245 = None + ge_305 = _local_scalar_dense_245 >= 0 + _assert_scalar_245 = torch.ops.aten._assert_scalar.default(ge_305, "Runtime assertion failed for expression u245 >= 0 on node 'ge_245'"); ge_305 = _assert_scalar_245 = None + select_246 = torch.ops.aten.select.int(device_put_30, 0, 6) + _local_scalar_dense_246 = torch.ops.aten._local_scalar_dense.default(select_246); select_246 = None + ge_306 = _local_scalar_dense_246 >= 0 + _assert_scalar_246 = torch.ops.aten._assert_scalar.default(ge_306, "Runtime assertion failed for expression u246 >= 0 on node 'ge_246'"); ge_306 = _assert_scalar_246 = None + select_247 = torch.ops.aten.select.int(device_put_30, 0, 7); device_put_30 = None + _local_scalar_dense_247 = torch.ops.aten._local_scalar_dense.default(select_247); select_247 = None + ge_307 = _local_scalar_dense_247 >= 0 + _assert_scalar_247 = torch.ops.aten._assert_scalar.default(ge_307, "Runtime assertion failed for expression u247 >= 0 on node 'ge_247'"); ge_307 = _assert_scalar_247 = None + select_248 = torch.ops.aten.select.int(device_put_31, 0, 0) + _local_scalar_dense_248 = torch.ops.aten._local_scalar_dense.default(select_248); select_248 = None + ge_308 = _local_scalar_dense_248 >= 0 + _assert_scalar_248 = torch.ops.aten._assert_scalar.default(ge_308, "Runtime assertion failed for expression u248 >= 0 on node 'ge_248'"); ge_308 = _assert_scalar_248 = None + select_249 = torch.ops.aten.select.int(device_put_31, 0, 1) + _local_scalar_dense_249 = torch.ops.aten._local_scalar_dense.default(select_249); select_249 = None + ge_309 = _local_scalar_dense_249 >= 0 + _assert_scalar_249 = torch.ops.aten._assert_scalar.default(ge_309, "Runtime assertion failed for expression u249 >= 0 on node 'ge_249'"); ge_309 = _assert_scalar_249 = None + select_250 = torch.ops.aten.select.int(device_put_31, 0, 2) + _local_scalar_dense_250 = torch.ops.aten._local_scalar_dense.default(select_250); select_250 = None + ge_310 = _local_scalar_dense_250 >= 0 + _assert_scalar_250 = torch.ops.aten._assert_scalar.default(ge_310, "Runtime assertion failed for expression u250 >= 0 on node 'ge_250'"); ge_310 = _assert_scalar_250 = None + select_251 = torch.ops.aten.select.int(device_put_31, 0, 3) + _local_scalar_dense_251 = torch.ops.aten._local_scalar_dense.default(select_251); select_251 = None + ge_311 = _local_scalar_dense_251 >= 0 + _assert_scalar_251 = torch.ops.aten._assert_scalar.default(ge_311, "Runtime assertion failed for expression u251 >= 0 on node 'ge_251'"); ge_311 = _assert_scalar_251 = None + select_252 = torch.ops.aten.select.int(device_put_31, 0, 4) + _local_scalar_dense_252 = torch.ops.aten._local_scalar_dense.default(select_252); select_252 = None + ge_312 = _local_scalar_dense_252 >= 0 + _assert_scalar_252 = torch.ops.aten._assert_scalar.default(ge_312, "Runtime assertion failed for expression u252 >= 0 on node 'ge_252'"); ge_312 = _assert_scalar_252 = None + select_253 = torch.ops.aten.select.int(device_put_31, 0, 5) + _local_scalar_dense_253 = torch.ops.aten._local_scalar_dense.default(select_253); select_253 = None + ge_313 = _local_scalar_dense_253 >= 0 + _assert_scalar_253 = torch.ops.aten._assert_scalar.default(ge_313, "Runtime assertion failed for expression u253 >= 0 on node 'ge_253'"); ge_313 = _assert_scalar_253 = None + select_254 = torch.ops.aten.select.int(device_put_31, 0, 6) + _local_scalar_dense_254 = torch.ops.aten._local_scalar_dense.default(select_254); select_254 = None + ge_314 = _local_scalar_dense_254 >= 0 + _assert_scalar_254 = torch.ops.aten._assert_scalar.default(ge_314, "Runtime assertion failed for expression u254 >= 0 on node 'ge_254'"); ge_314 = _assert_scalar_254 = None + select_255 = torch.ops.aten.select.int(device_put_31, 0, 7); device_put_31 = None + _local_scalar_dense_255 = torch.ops.aten._local_scalar_dense.default(select_255); select_255 = None + ge_315 = _local_scalar_dense_255 >= 0 + _assert_scalar_255 = torch.ops.aten._assert_scalar.default(ge_315, "Runtime assertion failed for expression u255 >= 0 on node 'ge_255'"); ge_315 = _assert_scalar_255 = None + all_to_all_single_46 = torch.ops._c10d_functional.all_to_all_single.default(index_30, [_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255], [_local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247], '521'); index_30 = None + sym_size_int_60 = torch.ops.aten.sym_size.int(all_to_all_single_46, 0) + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_46); all_to_all_single_46 = None + sym_sum_30 = torch.sym_sum((_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255)) + add_1038 = sym_sum_30 + 64; sym_sum_30 = None + add_1039 = add_1038 + 8; add_1038 = None + sub_363 = add_1039 - 1; add_1039 = None + floordiv_15 = sub_363 // 8; sub_363 = None + mul_757 = floordiv_15 * 8; floordiv_15 = None + cumsum_45 = torch.ops.aten.cumsum.default(wait_tensor_335, 0) + sub_364 = torch.ops.aten.sub.Tensor(cumsum_45, wait_tensor_335); cumsum_45 = None + sum_64 = torch.ops.aten.sum.dim_IntList(view_1070, [0]); view_1070 = None + clamp_min_15 = torch.ops.aten.clamp_min.default(sum_64, 8); sum_64 = None + add_1040 = torch.ops.aten.add.Tensor(clamp_min_15, 8); clamp_min_15 = None + sub_365 = torch.ops.aten.sub.Tensor(add_1040, 1); add_1040 = None + div_78 = torch.ops.aten.div.Tensor_mode(sub_365, 8, rounding_mode = 'floor'); sub_365 = None + mul_758 = torch.ops.aten.mul.Tensor(div_78, 8); div_78 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(mul_758, torch.int32); mul_758 = None + cumsum_46 = torch.ops.aten.cumsum.default(convert_element_type_878, 0) + sub_366 = torch.ops.aten.sub.Tensor(cumsum_46, convert_element_type_878); cumsum_46 = None + full_215 = torch.ops.aten.full.default([mul_757], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_757 = None + triton_kernel_wrapper_functional_proxy_15 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 15, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_335, 'start_index_values_ptr': sub_364, 'write_offsets_ptr': sub_366, 'output_ptr': full_215}, tensors_to_clone = ['output_ptr']); wait_tensor_335 = sub_364 = sub_366 = full_215 = None + getitem_232 = triton_kernel_wrapper_functional_proxy_15['output_ptr']; triton_kernel_wrapper_functional_proxy_15 = None + cat_49 = torch.ops.aten.cat.default([wait_tensor_336, full_default]); wait_tensor_336 = None + sym_size_int_61 = torch.ops.aten.sym_size.int(cat_49, 0) + sym_sum_31 = torch.sym_sum((1, _local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255)) + index_31 = torch.ops.aten.index.Tensor(cat_49, [getitem_232]); cat_49 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_880, 8, '513'); convert_element_type_880 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + convert_element_type_882 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_882, 8, '513'); convert_element_type_882 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_883 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_883, 8, '513'); convert_element_type_883 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + cumsum_47 = torch.ops.aten.cumsum.default(convert_element_type_878, 0, dtype = torch.int32); convert_element_type_878 = None + permute_245 = torch.ops.aten.permute.default(wait_tensor_337, [0, 2, 1]); wait_tensor_337 = None + _grouped_mm_45 = torch.ops.aten._grouped_mm.default(index_31, permute_245, cumsum_47); permute_245 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(_grouped_mm_45, torch.float32) + neg_31 = torch.ops.aten.neg.default(convert_element_type_886) + exp_47 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_1052 = torch.ops.aten.add.Tensor(exp_47, 1); exp_47 = None + div_79 = torch.ops.aten.div.Tensor(convert_element_type_886, add_1052); convert_element_type_886 = add_1052 = None + convert_element_type_887 = torch.ops.prims.convert_element_type.default(div_79, torch.bfloat16); div_79 = None + permute_246 = torch.ops.aten.permute.default(wait_tensor_340, [0, 2, 1]); wait_tensor_340 = None + _grouped_mm_46 = torch.ops.aten._grouped_mm.default(index_31, permute_246, cumsum_47); permute_246 = None + mul_770 = torch.ops.aten.mul.Tensor(convert_element_type_887, _grouped_mm_46); convert_element_type_887 = None + permute_247 = torch.ops.aten.permute.default(wait_tensor_339, [0, 2, 1]); wait_tensor_339 = None + _grouped_mm_47 = torch.ops.aten._grouped_mm.default(mul_770, permute_247, cumsum_47); permute_247 = None + empty_15 = torch.ops.aten.empty.memory_format([sym_size_int_61, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_30 = torch.ops.aten.index_put.default(empty_15, [getitem_232], _grouped_mm_47); empty_15 = _grouped_mm_47 = None + slice_66 = torch.ops.aten.slice.Tensor(index_put_30, 0, 0, -1); index_put_30 = None + all_to_all_single_47 = torch.ops._c10d_functional.all_to_all_single.default(slice_66, [_local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247], [_local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255], '521'); slice_66 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_47); all_to_all_single_47 = None + convert_element_type_888 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_888, 64, '0'); convert_element_type_888 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_248 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + mm_132 = torch.ops.aten.mm.default(view_1063, permute_248); permute_248 = None + convert_element_type_891 = torch.ops.prims.convert_element_type.default(mm_132, torch.float32) + neg_32 = torch.ops.aten.neg.default(convert_element_type_891) + exp_48 = torch.ops.aten.exp.default(neg_32); neg_32 = None + add_1088 = torch.ops.aten.add.Tensor(exp_48, 1); exp_48 = None + div_80 = torch.ops.aten.div.Tensor(convert_element_type_891, add_1088); convert_element_type_891 = add_1088 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(div_80, torch.bfloat16); div_80 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_893, 64, '0'); convert_element_type_893 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); wait_tensor_345 = None + mm_133 = torch.ops.aten.mm.default(view_1063, permute_249); permute_249 = None + mul_790 = torch.ops.aten.mul.Tensor(convert_element_type_892, mm_133); convert_element_type_892 = None + convert_element_type_896 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_896, 64, '0'); convert_element_type_896 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_346, [1, 0]); wait_tensor_346 = None + mm_134 = torch.ops.aten.mm.default(mul_790, permute_250); permute_250 = None + index_put_31 = torch.ops.aten.index_put.default(full_default_1, [getitem_231], wait_tensor_343); wait_tensor_343 = None + view_1103 = torch.ops.aten.view.default(mul_752, [-1, 1, 6]); mul_752 = None + view_1104 = torch.ops.aten.view.default(index_put_31, [-1, 6, 2048]); index_put_31 = None + convert_element_type_899 = torch.ops.prims.convert_element_type.default(view_1104, torch.float32); view_1104 = None + bmm_15 = torch.ops.aten.bmm.default(view_1103, convert_element_type_899) + convert_element_type_900 = torch.ops.prims.convert_element_type.default(bmm_15, torch.bfloat16); bmm_15 = None + squeeze_15 = torch.ops.aten.squeeze.dim(convert_element_type_900, 1); convert_element_type_900 = None + add_1092 = torch.ops.aten.add.Tensor(mm_134, squeeze_15); mm_134 = squeeze_15 = None + view_1105 = torch.ops.aten.view.default(add_1092, [2, 4096, 2048]); add_1092 = None + add_1093 = torch.ops.aten.add.Tensor(add_1028, view_1105); view_1105 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 64, '0'); convert_element_type_901 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + convert_element_type_902 = torch.ops.prims.convert_element_type.default(add_1093, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_902, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_1094 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_1094); add_1094 = None + mul_793 = torch.ops.aten.mul.Tensor(convert_element_type_902, rsqrt_51); convert_element_type_902 = None + mul_794 = torch.ops.aten.mul.Tensor(mul_793, wait_tensor_347); mul_793 = wait_tensor_347 = None + convert_element_type_903 = torch.ops.prims.convert_element_type.default(mul_794, torch.bfloat16); mul_794 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_904, 64, '0'); convert_element_type_904 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_348, [1, 0]); wait_tensor_348 = None + view_1108 = torch.ops.aten.view.default(convert_element_type_903, [8192, 2048]); convert_element_type_903 = None + mm_135 = torch.ops.aten.mm.default(view_1108, permute_251); permute_251 = None + view_1109 = torch.ops.aten.view.default(mm_135, [2, 4096, 3072]); mm_135 = None + view_1110 = torch.ops.aten.view.default(view_1109, [2, 4096, -1, 192]); view_1109 = None + split_with_sizes_51 = torch.ops.aten.split_with_sizes.default(view_1110, [128, 64], -1); view_1110 = None + getitem_233 = split_with_sizes_51[0] + getitem_234 = split_with_sizes_51[1]; split_with_sizes_51 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(getitem_234, torch.float32); getitem_234 = None + view_1111 = torch.ops.aten.view.default(convert_element_type_907, [2, 4096, 16, -1, 2]); convert_element_type_907 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1111); view_1111 = None + mul_795 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_7); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_795); mul_795 = None + view_1113 = torch.ops.aten.view.default(view_as_real_34, [2, 4096, 16, 64]); view_as_real_34 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(view_1113, torch.bfloat16); view_1113 = None + cat_50 = torch.ops.aten.cat.default([getitem_233, convert_element_type_908], -1); getitem_233 = convert_element_type_908 = None + convert_element_type_909 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_909, 64, '0'); convert_element_type_909 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_349, [1, 0]); wait_tensor_349 = None + mm_136 = torch.ops.aten.mm.default(view_1108, permute_252); permute_252 = None + view_1116 = torch.ops.aten.view.default(mm_136, [2, 4096, 576]); mm_136 = None + split_with_sizes_52 = torch.ops.aten.split_with_sizes.default(view_1116, [512, 64], -1); view_1116 = None + getitem_235 = split_with_sizes_52[0] + getitem_236 = split_with_sizes_52[1]; split_with_sizes_52 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(getitem_236, 2); getitem_236 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(unsqueeze_33, torch.float32); unsqueeze_33 = None + view_1117 = torch.ops.aten.view.default(convert_element_type_912, [2, 4096, 1, -1, 2]); convert_element_type_912 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1117); view_1117 = None + mul_796 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_7); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_796); mul_796 = None + view_1119 = torch.ops.aten.view.default(view_as_real_35, [2, 4096, 1, 64]); view_as_real_35 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 64, '0'); convert_element_type_914 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + convert_element_type_915 = torch.ops.prims.convert_element_type.default(getitem_235, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_915, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_1095 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_1095); add_1095 = None + mul_797 = torch.ops.aten.mul.Tensor(convert_element_type_915, rsqrt_52); convert_element_type_915 = None + mul_798 = torch.ops.aten.mul.Tensor(mul_797, wait_tensor_350); mul_797 = wait_tensor_350 = None + convert_element_type_916 = torch.ops.prims.convert_element_type.default(mul_798, torch.bfloat16); mul_798 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_917, 64, '0'); convert_element_type_917 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + view_1122 = torch.ops.aten.view.default(convert_element_type_916, [8192, 512]); convert_element_type_916 = None + mm_137 = torch.ops.aten.mm.default(view_1122, permute_253); permute_253 = None + view_1123 = torch.ops.aten.view.default(mm_137, [2, 4096, 4096]); mm_137 = None + view_1124 = torch.ops.aten.view.default(view_1123, [2, 4096, -1, 256]); view_1123 = None + split_with_sizes_53 = torch.ops.aten.split_with_sizes.default(view_1124, [128, 128], -1); view_1124 = None + getitem_237 = split_with_sizes_53[0] + getitem_238 = split_with_sizes_53[1]; split_with_sizes_53 = None + expand_17 = torch.ops.aten.expand.default(convert_element_type_913, [-1, -1, 16, -1]); convert_element_type_913 = None + cat_51 = torch.ops.aten.cat.default([getitem_237, expand_17], -1); getitem_237 = expand_17 = None + permute_254 = torch.ops.aten.permute.default(cat_50, [0, 2, 1, 3]); cat_50 = None + permute_255 = torch.ops.aten.permute.default(cat_51, [0, 2, 1, 3]); cat_51 = None + permute_256 = torch.ops.aten.permute.default(getitem_238, [0, 2, 1, 3]); getitem_238 = None + sdpa_score17 = self.sdpa_score17 + sdpa_mask17 = self.sdpa_mask17 + flex_attention_17 = torch.ops.higher_order.flex_attention(permute_254, permute_255, permute_256, sdpa_score17, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask17), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score17 = sdpa_mask17 = None + getitem_239 = flex_attention_17[0] + getitem_240 = flex_attention_17[1]; flex_attention_17 = None + permute_257 = torch.ops.aten.permute.default(getitem_239, [0, 2, 1, 3]) + view_1125 = torch.ops.aten.view.default(permute_257, [2, 4096, -1]); permute_257 = None + convert_element_type_920 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_920, 64, '0'); convert_element_type_920 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_258 = torch.ops.aten.permute.default(wait_tensor_352, [1, 0]); wait_tensor_352 = None + view_1127 = torch.ops.aten.view.default(view_1125, [8192, 2048]); view_1125 = None + mm_138 = torch.ops.aten.mm.default(view_1127, permute_258); view_1127 = permute_258 = None + view_1128 = torch.ops.aten.view.default(mm_138, [2, 4096, 2048]); mm_138 = None + add_1096 = torch.ops.aten.add.Tensor(add_1093, view_1128); view_1128 = None + convert_element_type_923 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_923, 64, '0'); convert_element_type_923 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_924 = torch.ops.prims.convert_element_type.default(add_1096, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_924, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_1097 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_1097); add_1097 = None + mul_799 = torch.ops.aten.mul.Tensor(convert_element_type_924, rsqrt_53); convert_element_type_924 = None + mul_800 = torch.ops.aten.mul.Tensor(mul_799, wait_tensor_353); mul_799 = wait_tensor_353 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(mul_800, torch.bfloat16); mul_800 = None + view_1130 = torch.ops.aten.view.default(convert_element_type_925, [-1, 2048]); convert_element_type_925 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_926, 64, '0'); convert_element_type_926 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_259 = torch.ops.aten.permute.default(wait_tensor_354, [1, 0]); wait_tensor_354 = None + mm_139 = torch.ops.aten.mm.default(view_1130, permute_259); permute_259 = None + convert_element_type_929 = torch.ops.prims.convert_element_type.default(mm_139, torch.float32) + amax_16 = torch.ops.aten.amax.default(convert_element_type_929, [1], True) + sub_384 = torch.ops.aten.sub.Tensor(convert_element_type_929, amax_16); convert_element_type_929 = None + exp_49 = torch.ops.aten.exp.default(sub_384); sub_384 = None + sum_65 = torch.ops.aten.sum.dim_IntList(exp_49, [1], True) + div_81 = torch.ops.aten.div.Tensor(exp_49, sum_65); exp_49 = None + add_1098 = torch.ops.aten.add.Tensor(div_81, primals_286); primals_286 = None + topk_16 = torch.ops.aten.topk.default(add_1098, 6, -1, True, False); add_1098 = None + getitem_243 = topk_16[1]; topk_16 = None + gather_16 = torch.ops.aten.gather.default(div_81, 1, getitem_243); div_81 = None + mul_801 = torch.ops.aten.mul.Tensor(gather_16, 1.0); gather_16 = None + view_1132 = torch.ops.aten.view.default(getitem_243, [-1]) + histc_32 = torch.ops.aten.histc.default(view_1132, 64, 0, 64) + add_1099 = torch.ops.aten.add.Tensor(primals_288, histc_32) + sort_16 = torch.ops.aten.sort.stable(view_1132, stable = True); view_1132 = None + getitem_245 = sort_16[1]; sort_16 = None + div_82 = torch.ops.aten.div.Tensor_mode(getitem_245, 6, rounding_mode = 'floor') + index_32 = torch.ops.aten.index.Tensor(view_1130, [div_82]) + all_to_all_single_48 = torch.ops._c10d_functional.all_to_all_single.default(histc_32, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_48); all_to_all_single_48 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_355); wait_tensor_355 = None + view_1136 = torch.ops.aten.view.default(histc_32, [8, -1]); histc_32 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_1136, [1]); view_1136 = None + device_put_32 = torch.ops.prims.device_put.default(sum_66, device(type='cpu'), True); sum_66 = None + view_1137 = torch.ops.aten.view.default(wait_tensor_356, [8, -1]) + sum_67 = torch.ops.aten.sum.dim_IntList(view_1137, [1]) + device_put_33 = torch.ops.prims.device_put.default(sum_67, device(type='cpu')); sum_67 = None + select_256 = torch.ops.aten.select.int(device_put_32, 0, 0) + _local_scalar_dense_256 = torch.ops.aten._local_scalar_dense.default(select_256); select_256 = None + ge_320 = _local_scalar_dense_256 >= 0 + _assert_scalar_256 = torch.ops.aten._assert_scalar.default(ge_320, "Runtime assertion failed for expression u256 >= 0 on node 'ge_256'"); ge_320 = _assert_scalar_256 = None + select_257 = torch.ops.aten.select.int(device_put_32, 0, 1) + _local_scalar_dense_257 = torch.ops.aten._local_scalar_dense.default(select_257); select_257 = None + ge_321 = _local_scalar_dense_257 >= 0 + _assert_scalar_257 = torch.ops.aten._assert_scalar.default(ge_321, "Runtime assertion failed for expression u257 >= 0 on node 'ge_257'"); ge_321 = _assert_scalar_257 = None + select_258 = torch.ops.aten.select.int(device_put_32, 0, 2) + _local_scalar_dense_258 = torch.ops.aten._local_scalar_dense.default(select_258); select_258 = None + ge_322 = _local_scalar_dense_258 >= 0 + _assert_scalar_258 = torch.ops.aten._assert_scalar.default(ge_322, "Runtime assertion failed for expression u258 >= 0 on node 'ge_258'"); ge_322 = _assert_scalar_258 = None + select_259 = torch.ops.aten.select.int(device_put_32, 0, 3) + _local_scalar_dense_259 = torch.ops.aten._local_scalar_dense.default(select_259); select_259 = None + ge_323 = _local_scalar_dense_259 >= 0 + _assert_scalar_259 = torch.ops.aten._assert_scalar.default(ge_323, "Runtime assertion failed for expression u259 >= 0 on node 'ge_259'"); ge_323 = _assert_scalar_259 = None + select_260 = torch.ops.aten.select.int(device_put_32, 0, 4) + _local_scalar_dense_260 = torch.ops.aten._local_scalar_dense.default(select_260); select_260 = None + ge_324 = _local_scalar_dense_260 >= 0 + _assert_scalar_260 = torch.ops.aten._assert_scalar.default(ge_324, "Runtime assertion failed for expression u260 >= 0 on node 'ge_260'"); ge_324 = _assert_scalar_260 = None + select_261 = torch.ops.aten.select.int(device_put_32, 0, 5) + _local_scalar_dense_261 = torch.ops.aten._local_scalar_dense.default(select_261); select_261 = None + ge_325 = _local_scalar_dense_261 >= 0 + _assert_scalar_261 = torch.ops.aten._assert_scalar.default(ge_325, "Runtime assertion failed for expression u261 >= 0 on node 'ge_261'"); ge_325 = _assert_scalar_261 = None + select_262 = torch.ops.aten.select.int(device_put_32, 0, 6) + _local_scalar_dense_262 = torch.ops.aten._local_scalar_dense.default(select_262); select_262 = None + ge_326 = _local_scalar_dense_262 >= 0 + _assert_scalar_262 = torch.ops.aten._assert_scalar.default(ge_326, "Runtime assertion failed for expression u262 >= 0 on node 'ge_262'"); ge_326 = _assert_scalar_262 = None + select_263 = torch.ops.aten.select.int(device_put_32, 0, 7); device_put_32 = None + _local_scalar_dense_263 = torch.ops.aten._local_scalar_dense.default(select_263); select_263 = None + ge_327 = _local_scalar_dense_263 >= 0 + _assert_scalar_263 = torch.ops.aten._assert_scalar.default(ge_327, "Runtime assertion failed for expression u263 >= 0 on node 'ge_263'"); ge_327 = _assert_scalar_263 = None + select_264 = torch.ops.aten.select.int(device_put_33, 0, 0) + _local_scalar_dense_264 = torch.ops.aten._local_scalar_dense.default(select_264); select_264 = None + ge_328 = _local_scalar_dense_264 >= 0 + _assert_scalar_264 = torch.ops.aten._assert_scalar.default(ge_328, "Runtime assertion failed for expression u264 >= 0 on node 'ge_264'"); ge_328 = _assert_scalar_264 = None + select_265 = torch.ops.aten.select.int(device_put_33, 0, 1) + _local_scalar_dense_265 = torch.ops.aten._local_scalar_dense.default(select_265); select_265 = None + ge_329 = _local_scalar_dense_265 >= 0 + _assert_scalar_265 = torch.ops.aten._assert_scalar.default(ge_329, "Runtime assertion failed for expression u265 >= 0 on node 'ge_265'"); ge_329 = _assert_scalar_265 = None + select_266 = torch.ops.aten.select.int(device_put_33, 0, 2) + _local_scalar_dense_266 = torch.ops.aten._local_scalar_dense.default(select_266); select_266 = None + ge_330 = _local_scalar_dense_266 >= 0 + _assert_scalar_266 = torch.ops.aten._assert_scalar.default(ge_330, "Runtime assertion failed for expression u266 >= 0 on node 'ge_266'"); ge_330 = _assert_scalar_266 = None + select_267 = torch.ops.aten.select.int(device_put_33, 0, 3) + _local_scalar_dense_267 = torch.ops.aten._local_scalar_dense.default(select_267); select_267 = None + ge_331 = _local_scalar_dense_267 >= 0 + _assert_scalar_267 = torch.ops.aten._assert_scalar.default(ge_331, "Runtime assertion failed for expression u267 >= 0 on node 'ge_267'"); ge_331 = _assert_scalar_267 = None + select_268 = torch.ops.aten.select.int(device_put_33, 0, 4) + _local_scalar_dense_268 = torch.ops.aten._local_scalar_dense.default(select_268); select_268 = None + ge_332 = _local_scalar_dense_268 >= 0 + _assert_scalar_268 = torch.ops.aten._assert_scalar.default(ge_332, "Runtime assertion failed for expression u268 >= 0 on node 'ge_268'"); ge_332 = _assert_scalar_268 = None + select_269 = torch.ops.aten.select.int(device_put_33, 0, 5) + _local_scalar_dense_269 = torch.ops.aten._local_scalar_dense.default(select_269); select_269 = None + ge_333 = _local_scalar_dense_269 >= 0 + _assert_scalar_269 = torch.ops.aten._assert_scalar.default(ge_333, "Runtime assertion failed for expression u269 >= 0 on node 'ge_269'"); ge_333 = _assert_scalar_269 = None + select_270 = torch.ops.aten.select.int(device_put_33, 0, 6) + _local_scalar_dense_270 = torch.ops.aten._local_scalar_dense.default(select_270); select_270 = None + ge_334 = _local_scalar_dense_270 >= 0 + _assert_scalar_270 = torch.ops.aten._assert_scalar.default(ge_334, "Runtime assertion failed for expression u270 >= 0 on node 'ge_270'"); ge_334 = _assert_scalar_270 = None + select_271 = torch.ops.aten.select.int(device_put_33, 0, 7); device_put_33 = None + _local_scalar_dense_271 = torch.ops.aten._local_scalar_dense.default(select_271); select_271 = None + ge_335 = _local_scalar_dense_271 >= 0 + _assert_scalar_271 = torch.ops.aten._assert_scalar.default(ge_335, "Runtime assertion failed for expression u271 >= 0 on node 'ge_271'"); ge_335 = _assert_scalar_271 = None + all_to_all_single_49 = torch.ops._c10d_functional.all_to_all_single.default(index_32, [_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271], [_local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263], '521'); index_32 = None + sym_size_int_64 = torch.ops.aten.sym_size.int(all_to_all_single_49, 0) + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_49); all_to_all_single_49 = None + sym_sum_32 = torch.sym_sum((_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271)) + add_1106 = sym_sum_32 + 64; sym_sum_32 = None + add_1107 = add_1106 + 8; add_1106 = None + sub_387 = add_1107 - 1; add_1107 = None + floordiv_16 = sub_387 // 8; sub_387 = None + mul_806 = floordiv_16 * 8; floordiv_16 = None + cumsum_48 = torch.ops.aten.cumsum.default(wait_tensor_356, 0) + sub_388 = torch.ops.aten.sub.Tensor(cumsum_48, wait_tensor_356); cumsum_48 = None + sum_68 = torch.ops.aten.sum.dim_IntList(view_1137, [0]); view_1137 = None + clamp_min_16 = torch.ops.aten.clamp_min.default(sum_68, 8); sum_68 = None + add_1108 = torch.ops.aten.add.Tensor(clamp_min_16, 8); clamp_min_16 = None + sub_389 = torch.ops.aten.sub.Tensor(add_1108, 1); add_1108 = None + div_83 = torch.ops.aten.div.Tensor_mode(sub_389, 8, rounding_mode = 'floor'); sub_389 = None + mul_807 = torch.ops.aten.mul.Tensor(div_83, 8); div_83 = None + convert_element_type_932 = torch.ops.prims.convert_element_type.default(mul_807, torch.int32); mul_807 = None + cumsum_49 = torch.ops.aten.cumsum.default(convert_element_type_932, 0) + sub_390 = torch.ops.aten.sub.Tensor(cumsum_49, convert_element_type_932); cumsum_49 = None + full_228 = torch.ops.aten.full.default([mul_806], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_806 = None + triton_kernel_wrapper_functional_proxy_16 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 16, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_356, 'start_index_values_ptr': sub_388, 'write_offsets_ptr': sub_390, 'output_ptr': full_228}, tensors_to_clone = ['output_ptr']); wait_tensor_356 = sub_388 = sub_390 = full_228 = None + getitem_246 = triton_kernel_wrapper_functional_proxy_16['output_ptr']; triton_kernel_wrapper_functional_proxy_16 = None + cat_52 = torch.ops.aten.cat.default([wait_tensor_357, full_default]); wait_tensor_357 = None + sym_size_int_65 = torch.ops.aten.sym_size.int(cat_52, 0) + sym_sum_33 = torch.sym_sum((1, _local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271)) + index_33 = torch.ops.aten.index.Tensor(cat_52, [getitem_246]); cat_52 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_291 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 8, '513'); convert_element_type_934 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_291); all_gather_into_tensor_291 = None + convert_element_type_936 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_293 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_936, 8, '513'); convert_element_type_936 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_293); all_gather_into_tensor_293 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_294 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_937, 8, '513'); convert_element_type_937 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_294); all_gather_into_tensor_294 = None + cumsum_50 = torch.ops.aten.cumsum.default(convert_element_type_932, 0, dtype = torch.int32); convert_element_type_932 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_358, [0, 2, 1]); wait_tensor_358 = None + _grouped_mm_48 = torch.ops.aten._grouped_mm.default(index_33, permute_260, cumsum_50); permute_260 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(_grouped_mm_48, torch.float32) + neg_33 = torch.ops.aten.neg.default(convert_element_type_940) + exp_50 = torch.ops.aten.exp.default(neg_33); neg_33 = None + add_1120 = torch.ops.aten.add.Tensor(exp_50, 1); exp_50 = None + div_84 = torch.ops.aten.div.Tensor(convert_element_type_940, add_1120); convert_element_type_940 = add_1120 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(div_84, torch.bfloat16); div_84 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_361, [0, 2, 1]); wait_tensor_361 = None + _grouped_mm_49 = torch.ops.aten._grouped_mm.default(index_33, permute_261, cumsum_50); permute_261 = None + mul_819 = torch.ops.aten.mul.Tensor(convert_element_type_941, _grouped_mm_49); convert_element_type_941 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_360, [0, 2, 1]); wait_tensor_360 = None + _grouped_mm_50 = torch.ops.aten._grouped_mm.default(mul_819, permute_262, cumsum_50); permute_262 = None + empty_16 = torch.ops.aten.empty.memory_format([sym_size_int_65, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_32 = torch.ops.aten.index_put.default(empty_16, [getitem_246], _grouped_mm_50); empty_16 = _grouped_mm_50 = None + slice_70 = torch.ops.aten.slice.Tensor(index_put_32, 0, 0, -1); index_put_32 = None + all_to_all_single_50 = torch.ops._c10d_functional.all_to_all_single.default(slice_70, [_local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263], [_local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271], '521'); slice_70 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_50); all_to_all_single_50 = None + convert_element_type_942 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_942, 64, '0'); convert_element_type_942 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_365, [1, 0]); wait_tensor_365 = None + mm_140 = torch.ops.aten.mm.default(view_1130, permute_263); permute_263 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(mm_140, torch.float32) + neg_34 = torch.ops.aten.neg.default(convert_element_type_945) + exp_51 = torch.ops.aten.exp.default(neg_34); neg_34 = None + add_1156 = torch.ops.aten.add.Tensor(exp_51, 1); exp_51 = None + div_85 = torch.ops.aten.div.Tensor(convert_element_type_945, add_1156); convert_element_type_945 = add_1156 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(div_85, torch.bfloat16); div_85 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 64, '0'); convert_element_type_947 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_366, [1, 0]); wait_tensor_366 = None + mm_141 = torch.ops.aten.mm.default(view_1130, permute_264); permute_264 = None + mul_839 = torch.ops.aten.mul.Tensor(convert_element_type_946, mm_141); convert_element_type_946 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(primals_294, torch.bfloat16) + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_950, 64, '0'); convert_element_type_950 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_367, [1, 0]); wait_tensor_367 = None + mm_142 = torch.ops.aten.mm.default(mul_839, permute_265); permute_265 = None + index_put_33 = torch.ops.aten.index_put.default(full_default_1, [getitem_245], wait_tensor_364); wait_tensor_364 = None + view_1170 = torch.ops.aten.view.default(mul_801, [-1, 1, 6]); mul_801 = None + view_1171 = torch.ops.aten.view.default(index_put_33, [-1, 6, 2048]); index_put_33 = None + convert_element_type_953 = torch.ops.prims.convert_element_type.default(view_1171, torch.float32); view_1171 = None + bmm_16 = torch.ops.aten.bmm.default(view_1170, convert_element_type_953) + convert_element_type_954 = torch.ops.prims.convert_element_type.default(bmm_16, torch.bfloat16); bmm_16 = None + squeeze_16 = torch.ops.aten.squeeze.dim(convert_element_type_954, 1); convert_element_type_954 = None + add_1160 = torch.ops.aten.add.Tensor(mm_142, squeeze_16); mm_142 = squeeze_16 = None + view_1172 = torch.ops.aten.view.default(add_1160, [2, 4096, 2048]); add_1160 = None + add_1161 = torch.ops.aten.add.Tensor(add_1096, view_1172); view_1172 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_295, torch.bfloat16) + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 64, '0'); convert_element_type_955 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + convert_element_type_956 = torch.ops.prims.convert_element_type.default(add_1161, torch.float32) + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_956, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_1162 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_1162); add_1162 = None + mul_842 = torch.ops.aten.mul.Tensor(convert_element_type_956, rsqrt_54); convert_element_type_956 = None + mul_843 = torch.ops.aten.mul.Tensor(mul_842, wait_tensor_368); mul_842 = wait_tensor_368 = None + convert_element_type_957 = torch.ops.prims.convert_element_type.default(mul_843, torch.bfloat16); mul_843 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_296, torch.bfloat16) + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 64, '0'); convert_element_type_958 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_369, [1, 0]); wait_tensor_369 = None + view_1175 = torch.ops.aten.view.default(convert_element_type_957, [8192, 2048]); convert_element_type_957 = None + mm_143 = torch.ops.aten.mm.default(view_1175, permute_266); permute_266 = None + view_1176 = torch.ops.aten.view.default(mm_143, [2, 4096, 3072]); mm_143 = None + view_1177 = torch.ops.aten.view.default(view_1176, [2, 4096, -1, 192]); view_1176 = None + split_with_sizes_54 = torch.ops.aten.split_with_sizes.default(view_1177, [128, 64], -1); view_1177 = None + getitem_247 = split_with_sizes_54[0] + getitem_248 = split_with_sizes_54[1]; split_with_sizes_54 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(getitem_248, torch.float32); getitem_248 = None + view_1178 = torch.ops.aten.view.default(convert_element_type_961, [2, 4096, 16, -1, 2]); convert_element_type_961 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1178); view_1178 = None + mul_844 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_7); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_844); mul_844 = None + view_1180 = torch.ops.aten.view.default(view_as_real_36, [2, 4096, 16, 64]); view_as_real_36 = None + convert_element_type_962 = torch.ops.prims.convert_element_type.default(view_1180, torch.bfloat16); view_1180 = None + cat_53 = torch.ops.aten.cat.default([getitem_247, convert_element_type_962], -1); getitem_247 = convert_element_type_962 = None + convert_element_type_963 = torch.ops.prims.convert_element_type.default(primals_297, torch.bfloat16) + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_963, 64, '0'); convert_element_type_963 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + permute_267 = torch.ops.aten.permute.default(wait_tensor_370, [1, 0]); wait_tensor_370 = None + mm_144 = torch.ops.aten.mm.default(view_1175, permute_267); permute_267 = None + view_1183 = torch.ops.aten.view.default(mm_144, [2, 4096, 576]); mm_144 = None + split_with_sizes_55 = torch.ops.aten.split_with_sizes.default(view_1183, [512, 64], -1); view_1183 = None + getitem_249 = split_with_sizes_55[0] + getitem_250 = split_with_sizes_55[1]; split_with_sizes_55 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(getitem_250, 2); getitem_250 = None + convert_element_type_966 = torch.ops.prims.convert_element_type.default(unsqueeze_35, torch.float32); unsqueeze_35 = None + view_1184 = torch.ops.aten.view.default(convert_element_type_966, [2, 4096, 1, -1, 2]); convert_element_type_966 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1184); view_1184 = None + mul_845 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_7); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_845); mul_845 = None + view_1186 = torch.ops.aten.view.default(view_as_real_37, [2, 4096, 1, 64]); view_as_real_37 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(view_1186, torch.bfloat16); view_1186 = None + convert_element_type_968 = torch.ops.prims.convert_element_type.default(primals_298, torch.bfloat16) + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_968, 64, '0'); convert_element_type_968 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + convert_element_type_969 = torch.ops.prims.convert_element_type.default(getitem_249, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_969, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_1163 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_1163); add_1163 = None + mul_846 = torch.ops.aten.mul.Tensor(convert_element_type_969, rsqrt_55); convert_element_type_969 = None + mul_847 = torch.ops.aten.mul.Tensor(mul_846, wait_tensor_371); mul_846 = wait_tensor_371 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(mul_847, torch.bfloat16); mul_847 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(primals_299, torch.bfloat16) + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_971, 64, '0'); convert_element_type_971 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + permute_268 = torch.ops.aten.permute.default(wait_tensor_372, [1, 0]); wait_tensor_372 = None + view_1189 = torch.ops.aten.view.default(convert_element_type_970, [8192, 512]); convert_element_type_970 = None + mm_145 = torch.ops.aten.mm.default(view_1189, permute_268); permute_268 = None + view_1190 = torch.ops.aten.view.default(mm_145, [2, 4096, 4096]); mm_145 = None + view_1191 = torch.ops.aten.view.default(view_1190, [2, 4096, -1, 256]); view_1190 = None + split_with_sizes_56 = torch.ops.aten.split_with_sizes.default(view_1191, [128, 128], -1); view_1191 = None + getitem_251 = split_with_sizes_56[0] + getitem_252 = split_with_sizes_56[1]; split_with_sizes_56 = None + expand_18 = torch.ops.aten.expand.default(convert_element_type_967, [-1, -1, 16, -1]); convert_element_type_967 = None + cat_54 = torch.ops.aten.cat.default([getitem_251, expand_18], -1); getitem_251 = expand_18 = None + permute_269 = torch.ops.aten.permute.default(cat_53, [0, 2, 1, 3]); cat_53 = None + permute_270 = torch.ops.aten.permute.default(cat_54, [0, 2, 1, 3]); cat_54 = None + permute_271 = torch.ops.aten.permute.default(getitem_252, [0, 2, 1, 3]); getitem_252 = None + sdpa_score18 = self.sdpa_score18 + sdpa_mask18 = self.sdpa_mask18 + flex_attention_18 = torch.ops.higher_order.flex_attention(permute_269, permute_270, permute_271, sdpa_score18, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask18), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score18 = sdpa_mask18 = None + getitem_253 = flex_attention_18[0] + getitem_254 = flex_attention_18[1]; flex_attention_18 = None + permute_272 = torch.ops.aten.permute.default(getitem_253, [0, 2, 1, 3]) + view_1192 = torch.ops.aten.view.default(permute_272, [2, 4096, -1]); permute_272 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_300, torch.bfloat16) + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 64, '0'); convert_element_type_974 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_373, [1, 0]); wait_tensor_373 = None + view_1194 = torch.ops.aten.view.default(view_1192, [8192, 2048]); view_1192 = None + mm_146 = torch.ops.aten.mm.default(view_1194, permute_273); view_1194 = permute_273 = None + view_1195 = torch.ops.aten.view.default(mm_146, [2, 4096, 2048]); mm_146 = None + add_1164 = torch.ops.aten.add.Tensor(add_1161, view_1195); view_1195 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_301, torch.bfloat16) + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 64, '0'); convert_element_type_977 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_1164, torch.float32) + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_1165 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_1165); add_1165 = None + mul_848 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_56); convert_element_type_978 = None + mul_849 = torch.ops.aten.mul.Tensor(mul_848, wait_tensor_374); mul_848 = wait_tensor_374 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_849, torch.bfloat16); mul_849 = None + view_1197 = torch.ops.aten.view.default(convert_element_type_979, [-1, 2048]); convert_element_type_979 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_303, torch.bfloat16) + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 64, '0'); convert_element_type_980 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_375, [1, 0]); wait_tensor_375 = None + mm_147 = torch.ops.aten.mm.default(view_1197, permute_274); permute_274 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(mm_147, torch.float32) + amax_17 = torch.ops.aten.amax.default(convert_element_type_983, [1], True) + sub_408 = torch.ops.aten.sub.Tensor(convert_element_type_983, amax_17); convert_element_type_983 = None + exp_52 = torch.ops.aten.exp.default(sub_408); sub_408 = None + sum_69 = torch.ops.aten.sum.dim_IntList(exp_52, [1], True) + div_86 = torch.ops.aten.div.Tensor(exp_52, sum_69); exp_52 = None + add_1166 = torch.ops.aten.add.Tensor(div_86, primals_302); primals_302 = None + topk_17 = torch.ops.aten.topk.default(add_1166, 6, -1, True, False); add_1166 = None + getitem_257 = topk_17[1]; topk_17 = None + gather_17 = torch.ops.aten.gather.default(div_86, 1, getitem_257); div_86 = None + mul_850 = torch.ops.aten.mul.Tensor(gather_17, 1.0); gather_17 = None + view_1199 = torch.ops.aten.view.default(getitem_257, [-1]) + histc_34 = torch.ops.aten.histc.default(view_1199, 64, 0, 64) + add_1167 = torch.ops.aten.add.Tensor(primals_304, histc_34) + sort_17 = torch.ops.aten.sort.stable(view_1199, stable = True); view_1199 = None + getitem_259 = sort_17[1]; sort_17 = None + div_87 = torch.ops.aten.div.Tensor_mode(getitem_259, 6, rounding_mode = 'floor') + index_34 = torch.ops.aten.index.Tensor(view_1197, [div_87]) + all_to_all_single_51 = torch.ops._c10d_functional.all_to_all_single.default(histc_34, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_51); all_to_all_single_51 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_376); wait_tensor_376 = None + view_1203 = torch.ops.aten.view.default(histc_34, [8, -1]); histc_34 = None + sum_70 = torch.ops.aten.sum.dim_IntList(view_1203, [1]); view_1203 = None + device_put_34 = torch.ops.prims.device_put.default(sum_70, device(type='cpu'), True); sum_70 = None + view_1204 = torch.ops.aten.view.default(wait_tensor_377, [8, -1]) + sum_71 = torch.ops.aten.sum.dim_IntList(view_1204, [1]) + device_put_35 = torch.ops.prims.device_put.default(sum_71, device(type='cpu')); sum_71 = None + select_272 = torch.ops.aten.select.int(device_put_34, 0, 0) + _local_scalar_dense_272 = torch.ops.aten._local_scalar_dense.default(select_272); select_272 = None + ge_340 = _local_scalar_dense_272 >= 0 + _assert_scalar_272 = torch.ops.aten._assert_scalar.default(ge_340, "Runtime assertion failed for expression u272 >= 0 on node 'ge_272'"); ge_340 = _assert_scalar_272 = None + select_273 = torch.ops.aten.select.int(device_put_34, 0, 1) + _local_scalar_dense_273 = torch.ops.aten._local_scalar_dense.default(select_273); select_273 = None + ge_341 = _local_scalar_dense_273 >= 0 + _assert_scalar_273 = torch.ops.aten._assert_scalar.default(ge_341, "Runtime assertion failed for expression u273 >= 0 on node 'ge_273'"); ge_341 = _assert_scalar_273 = None + select_274 = torch.ops.aten.select.int(device_put_34, 0, 2) + _local_scalar_dense_274 = torch.ops.aten._local_scalar_dense.default(select_274); select_274 = None + ge_342 = _local_scalar_dense_274 >= 0 + _assert_scalar_274 = torch.ops.aten._assert_scalar.default(ge_342, "Runtime assertion failed for expression u274 >= 0 on node 'ge_274'"); ge_342 = _assert_scalar_274 = None + select_275 = torch.ops.aten.select.int(device_put_34, 0, 3) + _local_scalar_dense_275 = torch.ops.aten._local_scalar_dense.default(select_275); select_275 = None + ge_343 = _local_scalar_dense_275 >= 0 + _assert_scalar_275 = torch.ops.aten._assert_scalar.default(ge_343, "Runtime assertion failed for expression u275 >= 0 on node 'ge_275'"); ge_343 = _assert_scalar_275 = None + select_276 = torch.ops.aten.select.int(device_put_34, 0, 4) + _local_scalar_dense_276 = torch.ops.aten._local_scalar_dense.default(select_276); select_276 = None + ge_344 = _local_scalar_dense_276 >= 0 + _assert_scalar_276 = torch.ops.aten._assert_scalar.default(ge_344, "Runtime assertion failed for expression u276 >= 0 on node 'ge_276'"); ge_344 = _assert_scalar_276 = None + select_277 = torch.ops.aten.select.int(device_put_34, 0, 5) + _local_scalar_dense_277 = torch.ops.aten._local_scalar_dense.default(select_277); select_277 = None + ge_345 = _local_scalar_dense_277 >= 0 + _assert_scalar_277 = torch.ops.aten._assert_scalar.default(ge_345, "Runtime assertion failed for expression u277 >= 0 on node 'ge_277'"); ge_345 = _assert_scalar_277 = None + select_278 = torch.ops.aten.select.int(device_put_34, 0, 6) + _local_scalar_dense_278 = torch.ops.aten._local_scalar_dense.default(select_278); select_278 = None + ge_346 = _local_scalar_dense_278 >= 0 + _assert_scalar_278 = torch.ops.aten._assert_scalar.default(ge_346, "Runtime assertion failed for expression u278 >= 0 on node 'ge_278'"); ge_346 = _assert_scalar_278 = None + select_279 = torch.ops.aten.select.int(device_put_34, 0, 7); device_put_34 = None + _local_scalar_dense_279 = torch.ops.aten._local_scalar_dense.default(select_279); select_279 = None + ge_347 = _local_scalar_dense_279 >= 0 + _assert_scalar_279 = torch.ops.aten._assert_scalar.default(ge_347, "Runtime assertion failed for expression u279 >= 0 on node 'ge_279'"); ge_347 = _assert_scalar_279 = None + select_280 = torch.ops.aten.select.int(device_put_35, 0, 0) + _local_scalar_dense_280 = torch.ops.aten._local_scalar_dense.default(select_280); select_280 = None + ge_348 = _local_scalar_dense_280 >= 0 + _assert_scalar_280 = torch.ops.aten._assert_scalar.default(ge_348, "Runtime assertion failed for expression u280 >= 0 on node 'ge_280'"); ge_348 = _assert_scalar_280 = None + select_281 = torch.ops.aten.select.int(device_put_35, 0, 1) + _local_scalar_dense_281 = torch.ops.aten._local_scalar_dense.default(select_281); select_281 = None + ge_349 = _local_scalar_dense_281 >= 0 + _assert_scalar_281 = torch.ops.aten._assert_scalar.default(ge_349, "Runtime assertion failed for expression u281 >= 0 on node 'ge_281'"); ge_349 = _assert_scalar_281 = None + select_282 = torch.ops.aten.select.int(device_put_35, 0, 2) + _local_scalar_dense_282 = torch.ops.aten._local_scalar_dense.default(select_282); select_282 = None + ge_350 = _local_scalar_dense_282 >= 0 + _assert_scalar_282 = torch.ops.aten._assert_scalar.default(ge_350, "Runtime assertion failed for expression u282 >= 0 on node 'ge_282'"); ge_350 = _assert_scalar_282 = None + select_283 = torch.ops.aten.select.int(device_put_35, 0, 3) + _local_scalar_dense_283 = torch.ops.aten._local_scalar_dense.default(select_283); select_283 = None + ge_351 = _local_scalar_dense_283 >= 0 + _assert_scalar_283 = torch.ops.aten._assert_scalar.default(ge_351, "Runtime assertion failed for expression u283 >= 0 on node 'ge_283'"); ge_351 = _assert_scalar_283 = None + select_284 = torch.ops.aten.select.int(device_put_35, 0, 4) + _local_scalar_dense_284 = torch.ops.aten._local_scalar_dense.default(select_284); select_284 = None + ge_352 = _local_scalar_dense_284 >= 0 + _assert_scalar_284 = torch.ops.aten._assert_scalar.default(ge_352, "Runtime assertion failed for expression u284 >= 0 on node 'ge_284'"); ge_352 = _assert_scalar_284 = None + select_285 = torch.ops.aten.select.int(device_put_35, 0, 5) + _local_scalar_dense_285 = torch.ops.aten._local_scalar_dense.default(select_285); select_285 = None + ge_353 = _local_scalar_dense_285 >= 0 + _assert_scalar_285 = torch.ops.aten._assert_scalar.default(ge_353, "Runtime assertion failed for expression u285 >= 0 on node 'ge_285'"); ge_353 = _assert_scalar_285 = None + select_286 = torch.ops.aten.select.int(device_put_35, 0, 6) + _local_scalar_dense_286 = torch.ops.aten._local_scalar_dense.default(select_286); select_286 = None + ge_354 = _local_scalar_dense_286 >= 0 + _assert_scalar_286 = torch.ops.aten._assert_scalar.default(ge_354, "Runtime assertion failed for expression u286 >= 0 on node 'ge_286'"); ge_354 = _assert_scalar_286 = None + select_287 = torch.ops.aten.select.int(device_put_35, 0, 7); device_put_35 = None + _local_scalar_dense_287 = torch.ops.aten._local_scalar_dense.default(select_287); select_287 = None + ge_355 = _local_scalar_dense_287 >= 0 + _assert_scalar_287 = torch.ops.aten._assert_scalar.default(ge_355, "Runtime assertion failed for expression u287 >= 0 on node 'ge_287'"); ge_355 = _assert_scalar_287 = None + all_to_all_single_52 = torch.ops._c10d_functional.all_to_all_single.default(index_34, [_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287], [_local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279], '521'); index_34 = None + sym_size_int_68 = torch.ops.aten.sym_size.int(all_to_all_single_52, 0) + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_52); all_to_all_single_52 = None + sym_sum_34 = torch.sym_sum((_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287)) + add_1174 = sym_sum_34 + 64; sym_sum_34 = None + add_1175 = add_1174 + 8; add_1174 = None + sub_411 = add_1175 - 1; add_1175 = None + floordiv_17 = sub_411 // 8; sub_411 = None + mul_855 = floordiv_17 * 8; floordiv_17 = None + cumsum_51 = torch.ops.aten.cumsum.default(wait_tensor_377, 0) + sub_412 = torch.ops.aten.sub.Tensor(cumsum_51, wait_tensor_377); cumsum_51 = None + sum_72 = torch.ops.aten.sum.dim_IntList(view_1204, [0]); view_1204 = None + clamp_min_17 = torch.ops.aten.clamp_min.default(sum_72, 8); sum_72 = None + add_1176 = torch.ops.aten.add.Tensor(clamp_min_17, 8); clamp_min_17 = None + sub_413 = torch.ops.aten.sub.Tensor(add_1176, 1); add_1176 = None + div_88 = torch.ops.aten.div.Tensor_mode(sub_413, 8, rounding_mode = 'floor'); sub_413 = None + mul_856 = torch.ops.aten.mul.Tensor(div_88, 8); div_88 = None + convert_element_type_986 = torch.ops.prims.convert_element_type.default(mul_856, torch.int32); mul_856 = None + cumsum_52 = torch.ops.aten.cumsum.default(convert_element_type_986, 0) + sub_414 = torch.ops.aten.sub.Tensor(cumsum_52, convert_element_type_986); cumsum_52 = None + full_241 = torch.ops.aten.full.default([mul_855], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_855 = None + triton_kernel_wrapper_functional_proxy_17 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 17, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_377, 'start_index_values_ptr': sub_412, 'write_offsets_ptr': sub_414, 'output_ptr': full_241}, tensors_to_clone = ['output_ptr']); wait_tensor_377 = sub_412 = sub_414 = full_241 = None + getitem_260 = triton_kernel_wrapper_functional_proxy_17['output_ptr']; triton_kernel_wrapper_functional_proxy_17 = None + cat_55 = torch.ops.aten.cat.default([wait_tensor_378, full_default]); wait_tensor_378 = None + sym_size_int_69 = torch.ops.aten.sym_size.int(cat_55, 0) + sym_sum_35 = torch.sym_sum((1, _local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287)) + index_35 = torch.ops.aten.index.Tensor(cat_55, [getitem_260]); cat_55 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_305, torch.bfloat16) + all_gather_into_tensor_308 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 8, '513'); convert_element_type_988 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_308); all_gather_into_tensor_308 = None + convert_element_type_990 = torch.ops.prims.convert_element_type.default(primals_306, torch.bfloat16) + all_gather_into_tensor_310 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_990, 8, '513'); convert_element_type_990 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_310); all_gather_into_tensor_310 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_307, torch.bfloat16) + all_gather_into_tensor_311 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 8, '513'); convert_element_type_991 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_311); all_gather_into_tensor_311 = None + cumsum_53 = torch.ops.aten.cumsum.default(convert_element_type_986, 0, dtype = torch.int32); convert_element_type_986 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_379, [0, 2, 1]); wait_tensor_379 = None + _grouped_mm_51 = torch.ops.aten._grouped_mm.default(index_35, permute_275, cumsum_53); permute_275 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(_grouped_mm_51, torch.float32) + neg_35 = torch.ops.aten.neg.default(convert_element_type_994) + exp_53 = torch.ops.aten.exp.default(neg_35); neg_35 = None + add_1188 = torch.ops.aten.add.Tensor(exp_53, 1); exp_53 = None + div_89 = torch.ops.aten.div.Tensor(convert_element_type_994, add_1188); convert_element_type_994 = add_1188 = None + convert_element_type_995 = torch.ops.prims.convert_element_type.default(div_89, torch.bfloat16); div_89 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_382, [0, 2, 1]); wait_tensor_382 = None + _grouped_mm_52 = torch.ops.aten._grouped_mm.default(index_35, permute_276, cumsum_53); permute_276 = None + mul_868 = torch.ops.aten.mul.Tensor(convert_element_type_995, _grouped_mm_52); convert_element_type_995 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_381, [0, 2, 1]); wait_tensor_381 = None + _grouped_mm_53 = torch.ops.aten._grouped_mm.default(mul_868, permute_277, cumsum_53); permute_277 = None + empty_17 = torch.ops.aten.empty.memory_format([sym_size_int_69, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_34 = torch.ops.aten.index_put.default(empty_17, [getitem_260], _grouped_mm_53); empty_17 = _grouped_mm_53 = None + slice_74 = torch.ops.aten.slice.Tensor(index_put_34, 0, 0, -1); index_put_34 = None + all_to_all_single_53 = torch.ops._c10d_functional.all_to_all_single.default(slice_74, [_local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279], [_local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287], '521'); slice_74 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_53); all_to_all_single_53 = None + convert_element_type_996 = torch.ops.prims.convert_element_type.default(primals_308, torch.bfloat16) + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_996, 64, '0'); convert_element_type_996 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_278 = torch.ops.aten.permute.default(wait_tensor_386, [1, 0]); wait_tensor_386 = None + mm_148 = torch.ops.aten.mm.default(view_1197, permute_278); permute_278 = None + convert_element_type_999 = torch.ops.prims.convert_element_type.default(mm_148, torch.float32) + neg_36 = torch.ops.aten.neg.default(convert_element_type_999) + exp_54 = torch.ops.aten.exp.default(neg_36); neg_36 = None + add_1224 = torch.ops.aten.add.Tensor(exp_54, 1); exp_54 = None + div_90 = torch.ops.aten.div.Tensor(convert_element_type_999, add_1224); convert_element_type_999 = add_1224 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(div_90, torch.bfloat16); div_90 = None + convert_element_type_1001 = torch.ops.prims.convert_element_type.default(primals_309, torch.bfloat16) + all_gather_into_tensor_315 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1001, 64, '0'); convert_element_type_1001 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + permute_279 = torch.ops.aten.permute.default(wait_tensor_387, [1, 0]); wait_tensor_387 = None + mm_149 = torch.ops.aten.mm.default(view_1197, permute_279); permute_279 = None + mul_888 = torch.ops.aten.mul.Tensor(convert_element_type_1000, mm_149); convert_element_type_1000 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(primals_310, torch.bfloat16) + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1004, 64, '0'); convert_element_type_1004 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + permute_280 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + mm_150 = torch.ops.aten.mm.default(mul_888, permute_280); permute_280 = None + index_put_35 = torch.ops.aten.index_put.default(full_default_1, [getitem_259], wait_tensor_385); wait_tensor_385 = None + view_1237 = torch.ops.aten.view.default(mul_850, [-1, 1, 6]); mul_850 = None + view_1238 = torch.ops.aten.view.default(index_put_35, [-1, 6, 2048]); index_put_35 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(view_1238, torch.float32); view_1238 = None + bmm_17 = torch.ops.aten.bmm.default(view_1237, convert_element_type_1007) + convert_element_type_1008 = torch.ops.prims.convert_element_type.default(bmm_17, torch.bfloat16); bmm_17 = None + squeeze_17 = torch.ops.aten.squeeze.dim(convert_element_type_1008, 1); convert_element_type_1008 = None + add_1228 = torch.ops.aten.add.Tensor(mm_150, squeeze_17); mm_150 = squeeze_17 = None + view_1239 = torch.ops.aten.view.default(add_1228, [2, 4096, 2048]); add_1228 = None + add_1229 = torch.ops.aten.add.Tensor(add_1164, view_1239); view_1239 = None + convert_element_type_1009 = torch.ops.prims.convert_element_type.default(primals_311, torch.bfloat16) + all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1009, 64, '0'); convert_element_type_1009 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(add_1229, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1010, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_1230 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_1230); add_1230 = None + mul_891 = torch.ops.aten.mul.Tensor(convert_element_type_1010, rsqrt_57); convert_element_type_1010 = None + mul_892 = torch.ops.aten.mul.Tensor(mul_891, wait_tensor_389); mul_891 = wait_tensor_389 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(mul_892, torch.bfloat16); mul_892 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(primals_312, torch.bfloat16) + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 64, '0'); convert_element_type_1012 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_281 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + view_1242 = torch.ops.aten.view.default(convert_element_type_1011, [8192, 2048]); convert_element_type_1011 = None + mm_151 = torch.ops.aten.mm.default(view_1242, permute_281); permute_281 = None + view_1243 = torch.ops.aten.view.default(mm_151, [2, 4096, 3072]); mm_151 = None + view_1244 = torch.ops.aten.view.default(view_1243, [2, 4096, -1, 192]); view_1243 = None + split_with_sizes_57 = torch.ops.aten.split_with_sizes.default(view_1244, [128, 64], -1); view_1244 = None + getitem_261 = split_with_sizes_57[0] + getitem_262 = split_with_sizes_57[1]; split_with_sizes_57 = None + convert_element_type_1015 = torch.ops.prims.convert_element_type.default(getitem_262, torch.float32); getitem_262 = None + view_1245 = torch.ops.aten.view.default(convert_element_type_1015, [2, 4096, 16, -1, 2]); convert_element_type_1015 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1245); view_1245 = None + mul_893 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_7); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_893); mul_893 = None + view_1247 = torch.ops.aten.view.default(view_as_real_38, [2, 4096, 16, 64]); view_as_real_38 = None + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1247, torch.bfloat16); view_1247 = None + cat_56 = torch.ops.aten.cat.default([getitem_261, convert_element_type_1016], -1); getitem_261 = convert_element_type_1016 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(primals_313, torch.bfloat16) + all_gather_into_tensor_319 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1017, 64, '0'); convert_element_type_1017 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_391, [1, 0]); wait_tensor_391 = None + mm_152 = torch.ops.aten.mm.default(view_1242, permute_282); permute_282 = None + view_1250 = torch.ops.aten.view.default(mm_152, [2, 4096, 576]); mm_152 = None + split_with_sizes_58 = torch.ops.aten.split_with_sizes.default(view_1250, [512, 64], -1); view_1250 = None + getitem_263 = split_with_sizes_58[0] + getitem_264 = split_with_sizes_58[1]; split_with_sizes_58 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(getitem_264, 2); getitem_264 = None + convert_element_type_1020 = torch.ops.prims.convert_element_type.default(unsqueeze_37, torch.float32); unsqueeze_37 = None + view_1251 = torch.ops.aten.view.default(convert_element_type_1020, [2, 4096, 1, -1, 2]); convert_element_type_1020 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1251); view_1251 = None + mul_894 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_7); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_894); mul_894 = None + view_1253 = torch.ops.aten.view.default(view_as_real_39, [2, 4096, 1, 64]); view_as_real_39 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(view_1253, torch.bfloat16); view_1253 = None + convert_element_type_1022 = torch.ops.prims.convert_element_type.default(primals_314, torch.bfloat16) + all_gather_into_tensor_320 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1022, 64, '0'); convert_element_type_1022 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_1023 = torch.ops.prims.convert_element_type.default(getitem_263, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1023, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_1231 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_1231); add_1231 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_1023, rsqrt_58); convert_element_type_1023 = None + mul_896 = torch.ops.aten.mul.Tensor(mul_895, wait_tensor_392); mul_895 = wait_tensor_392 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(mul_896, torch.bfloat16); mul_896 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(primals_315, torch.bfloat16) + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1025, 64, '0'); convert_element_type_1025 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_393, [1, 0]); wait_tensor_393 = None + view_1256 = torch.ops.aten.view.default(convert_element_type_1024, [8192, 512]); convert_element_type_1024 = None + mm_153 = torch.ops.aten.mm.default(view_1256, permute_283); permute_283 = None + view_1257 = torch.ops.aten.view.default(mm_153, [2, 4096, 4096]); mm_153 = None + view_1258 = torch.ops.aten.view.default(view_1257, [2, 4096, -1, 256]); view_1257 = None + split_with_sizes_59 = torch.ops.aten.split_with_sizes.default(view_1258, [128, 128], -1); view_1258 = None + getitem_265 = split_with_sizes_59[0] + getitem_266 = split_with_sizes_59[1]; split_with_sizes_59 = None + expand_19 = torch.ops.aten.expand.default(convert_element_type_1021, [-1, -1, 16, -1]); convert_element_type_1021 = None + cat_57 = torch.ops.aten.cat.default([getitem_265, expand_19], -1); getitem_265 = expand_19 = None + permute_284 = torch.ops.aten.permute.default(cat_56, [0, 2, 1, 3]); cat_56 = None + permute_285 = torch.ops.aten.permute.default(cat_57, [0, 2, 1, 3]); cat_57 = None + permute_286 = torch.ops.aten.permute.default(getitem_266, [0, 2, 1, 3]); getitem_266 = None + sdpa_score19 = self.sdpa_score19 + sdpa_mask19 = self.sdpa_mask19 + flex_attention_19 = torch.ops.higher_order.flex_attention(permute_284, permute_285, permute_286, sdpa_score19, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask19), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score19 = sdpa_mask19 = None + getitem_267 = flex_attention_19[0] + getitem_268 = flex_attention_19[1]; flex_attention_19 = None + permute_287 = torch.ops.aten.permute.default(getitem_267, [0, 2, 1, 3]) + view_1259 = torch.ops.aten.view.default(permute_287, [2, 4096, -1]); permute_287 = None + convert_element_type_1028 = torch.ops.prims.convert_element_type.default(primals_316, torch.bfloat16) + all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1028, 64, '0'); convert_element_type_1028 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + view_1261 = torch.ops.aten.view.default(view_1259, [8192, 2048]); view_1259 = None + mm_154 = torch.ops.aten.mm.default(view_1261, permute_288); view_1261 = permute_288 = None + view_1262 = torch.ops.aten.view.default(mm_154, [2, 4096, 2048]); mm_154 = None + add_1232 = torch.ops.aten.add.Tensor(add_1229, view_1262); view_1262 = None + convert_element_type_1031 = torch.ops.prims.convert_element_type.default(primals_317, torch.bfloat16) + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1031, 64, '0'); convert_element_type_1031 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + convert_element_type_1032 = torch.ops.prims.convert_element_type.default(add_1232, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1032, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_1233 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_1233); add_1233 = None + mul_897 = torch.ops.aten.mul.Tensor(convert_element_type_1032, rsqrt_59); convert_element_type_1032 = None + mul_898 = torch.ops.aten.mul.Tensor(mul_897, wait_tensor_395); mul_897 = wait_tensor_395 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(mul_898, torch.bfloat16); mul_898 = None + view_1264 = torch.ops.aten.view.default(convert_element_type_1033, [-1, 2048]); convert_element_type_1033 = None + convert_element_type_1034 = torch.ops.prims.convert_element_type.default(primals_319, torch.bfloat16) + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1034, 64, '0'); convert_element_type_1034 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + permute_289 = torch.ops.aten.permute.default(wait_tensor_396, [1, 0]); wait_tensor_396 = None + mm_155 = torch.ops.aten.mm.default(view_1264, permute_289); permute_289 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(mm_155, torch.float32) + amax_18 = torch.ops.aten.amax.default(convert_element_type_1037, [1], True) + sub_432 = torch.ops.aten.sub.Tensor(convert_element_type_1037, amax_18); convert_element_type_1037 = None + exp_55 = torch.ops.aten.exp.default(sub_432); sub_432 = None + sum_73 = torch.ops.aten.sum.dim_IntList(exp_55, [1], True) + div_91 = torch.ops.aten.div.Tensor(exp_55, sum_73); exp_55 = None + add_1234 = torch.ops.aten.add.Tensor(div_91, primals_318); primals_318 = None + topk_18 = torch.ops.aten.topk.default(add_1234, 6, -1, True, False); add_1234 = None + getitem_271 = topk_18[1]; topk_18 = None + gather_18 = torch.ops.aten.gather.default(div_91, 1, getitem_271); div_91 = None + mul_899 = torch.ops.aten.mul.Tensor(gather_18, 1.0); gather_18 = None + view_1266 = torch.ops.aten.view.default(getitem_271, [-1]) + histc_36 = torch.ops.aten.histc.default(view_1266, 64, 0, 64) + add_1235 = torch.ops.aten.add.Tensor(primals_320, histc_36) + sort_18 = torch.ops.aten.sort.stable(view_1266, stable = True); view_1266 = None + getitem_273 = sort_18[1]; sort_18 = None + div_92 = torch.ops.aten.div.Tensor_mode(getitem_273, 6, rounding_mode = 'floor') + index_36 = torch.ops.aten.index.Tensor(view_1264, [div_92]) + all_to_all_single_54 = torch.ops._c10d_functional.all_to_all_single.default(histc_36, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_54); all_to_all_single_54 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_397); wait_tensor_397 = None + view_1270 = torch.ops.aten.view.default(histc_36, [8, -1]); histc_36 = None + sum_74 = torch.ops.aten.sum.dim_IntList(view_1270, [1]); view_1270 = None + device_put_36 = torch.ops.prims.device_put.default(sum_74, device(type='cpu'), True); sum_74 = None + view_1271 = torch.ops.aten.view.default(wait_tensor_398, [8, -1]) + sum_75 = torch.ops.aten.sum.dim_IntList(view_1271, [1]) + device_put_37 = torch.ops.prims.device_put.default(sum_75, device(type='cpu')); sum_75 = None + select_288 = torch.ops.aten.select.int(device_put_36, 0, 0) + _local_scalar_dense_288 = torch.ops.aten._local_scalar_dense.default(select_288); select_288 = None + ge_360 = _local_scalar_dense_288 >= 0 + _assert_scalar_288 = torch.ops.aten._assert_scalar.default(ge_360, "Runtime assertion failed for expression u288 >= 0 on node 'ge_288'"); ge_360 = _assert_scalar_288 = None + select_289 = torch.ops.aten.select.int(device_put_36, 0, 1) + _local_scalar_dense_289 = torch.ops.aten._local_scalar_dense.default(select_289); select_289 = None + ge_361 = _local_scalar_dense_289 >= 0 + _assert_scalar_289 = torch.ops.aten._assert_scalar.default(ge_361, "Runtime assertion failed for expression u289 >= 0 on node 'ge_289'"); ge_361 = _assert_scalar_289 = None + select_290 = torch.ops.aten.select.int(device_put_36, 0, 2) + _local_scalar_dense_290 = torch.ops.aten._local_scalar_dense.default(select_290); select_290 = None + ge_362 = _local_scalar_dense_290 >= 0 + _assert_scalar_290 = torch.ops.aten._assert_scalar.default(ge_362, "Runtime assertion failed for expression u290 >= 0 on node 'ge_290'"); ge_362 = _assert_scalar_290 = None + select_291 = torch.ops.aten.select.int(device_put_36, 0, 3) + _local_scalar_dense_291 = torch.ops.aten._local_scalar_dense.default(select_291); select_291 = None + ge_363 = _local_scalar_dense_291 >= 0 + _assert_scalar_291 = torch.ops.aten._assert_scalar.default(ge_363, "Runtime assertion failed for expression u291 >= 0 on node 'ge_291'"); ge_363 = _assert_scalar_291 = None + select_292 = torch.ops.aten.select.int(device_put_36, 0, 4) + _local_scalar_dense_292 = torch.ops.aten._local_scalar_dense.default(select_292); select_292 = None + ge_364 = _local_scalar_dense_292 >= 0 + _assert_scalar_292 = torch.ops.aten._assert_scalar.default(ge_364, "Runtime assertion failed for expression u292 >= 0 on node 'ge_292'"); ge_364 = _assert_scalar_292 = None + select_293 = torch.ops.aten.select.int(device_put_36, 0, 5) + _local_scalar_dense_293 = torch.ops.aten._local_scalar_dense.default(select_293); select_293 = None + ge_365 = _local_scalar_dense_293 >= 0 + _assert_scalar_293 = torch.ops.aten._assert_scalar.default(ge_365, "Runtime assertion failed for expression u293 >= 0 on node 'ge_293'"); ge_365 = _assert_scalar_293 = None + select_294 = torch.ops.aten.select.int(device_put_36, 0, 6) + _local_scalar_dense_294 = torch.ops.aten._local_scalar_dense.default(select_294); select_294 = None + ge_366 = _local_scalar_dense_294 >= 0 + _assert_scalar_294 = torch.ops.aten._assert_scalar.default(ge_366, "Runtime assertion failed for expression u294 >= 0 on node 'ge_294'"); ge_366 = _assert_scalar_294 = None + select_295 = torch.ops.aten.select.int(device_put_36, 0, 7); device_put_36 = None + _local_scalar_dense_295 = torch.ops.aten._local_scalar_dense.default(select_295); select_295 = None + ge_367 = _local_scalar_dense_295 >= 0 + _assert_scalar_295 = torch.ops.aten._assert_scalar.default(ge_367, "Runtime assertion failed for expression u295 >= 0 on node 'ge_295'"); ge_367 = _assert_scalar_295 = None + select_296 = torch.ops.aten.select.int(device_put_37, 0, 0) + _local_scalar_dense_296 = torch.ops.aten._local_scalar_dense.default(select_296); select_296 = None + ge_368 = _local_scalar_dense_296 >= 0 + _assert_scalar_296 = torch.ops.aten._assert_scalar.default(ge_368, "Runtime assertion failed for expression u296 >= 0 on node 'ge_296'"); ge_368 = _assert_scalar_296 = None + select_297 = torch.ops.aten.select.int(device_put_37, 0, 1) + _local_scalar_dense_297 = torch.ops.aten._local_scalar_dense.default(select_297); select_297 = None + ge_369 = _local_scalar_dense_297 >= 0 + _assert_scalar_297 = torch.ops.aten._assert_scalar.default(ge_369, "Runtime assertion failed for expression u297 >= 0 on node 'ge_297'"); ge_369 = _assert_scalar_297 = None + select_298 = torch.ops.aten.select.int(device_put_37, 0, 2) + _local_scalar_dense_298 = torch.ops.aten._local_scalar_dense.default(select_298); select_298 = None + ge_370 = _local_scalar_dense_298 >= 0 + _assert_scalar_298 = torch.ops.aten._assert_scalar.default(ge_370, "Runtime assertion failed for expression u298 >= 0 on node 'ge_298'"); ge_370 = _assert_scalar_298 = None + select_299 = torch.ops.aten.select.int(device_put_37, 0, 3) + _local_scalar_dense_299 = torch.ops.aten._local_scalar_dense.default(select_299); select_299 = None + ge_371 = _local_scalar_dense_299 >= 0 + _assert_scalar_299 = torch.ops.aten._assert_scalar.default(ge_371, "Runtime assertion failed for expression u299 >= 0 on node 'ge_299'"); ge_371 = _assert_scalar_299 = None + select_300 = torch.ops.aten.select.int(device_put_37, 0, 4) + _local_scalar_dense_300 = torch.ops.aten._local_scalar_dense.default(select_300); select_300 = None + ge_372 = _local_scalar_dense_300 >= 0 + _assert_scalar_300 = torch.ops.aten._assert_scalar.default(ge_372, "Runtime assertion failed for expression u300 >= 0 on node 'ge_300'"); ge_372 = _assert_scalar_300 = None + select_301 = torch.ops.aten.select.int(device_put_37, 0, 5) + _local_scalar_dense_301 = torch.ops.aten._local_scalar_dense.default(select_301); select_301 = None + ge_373 = _local_scalar_dense_301 >= 0 + _assert_scalar_301 = torch.ops.aten._assert_scalar.default(ge_373, "Runtime assertion failed for expression u301 >= 0 on node 'ge_301'"); ge_373 = _assert_scalar_301 = None + select_302 = torch.ops.aten.select.int(device_put_37, 0, 6) + _local_scalar_dense_302 = torch.ops.aten._local_scalar_dense.default(select_302); select_302 = None + ge_374 = _local_scalar_dense_302 >= 0 + _assert_scalar_302 = torch.ops.aten._assert_scalar.default(ge_374, "Runtime assertion failed for expression u302 >= 0 on node 'ge_302'"); ge_374 = _assert_scalar_302 = None + select_303 = torch.ops.aten.select.int(device_put_37, 0, 7); device_put_37 = None + _local_scalar_dense_303 = torch.ops.aten._local_scalar_dense.default(select_303); select_303 = None + ge_375 = _local_scalar_dense_303 >= 0 + _assert_scalar_303 = torch.ops.aten._assert_scalar.default(ge_375, "Runtime assertion failed for expression u303 >= 0 on node 'ge_303'"); ge_375 = _assert_scalar_303 = None + all_to_all_single_55 = torch.ops._c10d_functional.all_to_all_single.default(index_36, [_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303], [_local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295], '521'); index_36 = None + sym_size_int_72 = torch.ops.aten.sym_size.int(all_to_all_single_55, 0) + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_55); all_to_all_single_55 = None + sym_sum_36 = torch.sym_sum((_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303)) + add_1242 = sym_sum_36 + 64; sym_sum_36 = None + add_1243 = add_1242 + 8; add_1242 = None + sub_435 = add_1243 - 1; add_1243 = None + floordiv_18 = sub_435 // 8; sub_435 = None + mul_904 = floordiv_18 * 8; floordiv_18 = None + cumsum_54 = torch.ops.aten.cumsum.default(wait_tensor_398, 0) + sub_436 = torch.ops.aten.sub.Tensor(cumsum_54, wait_tensor_398); cumsum_54 = None + sum_76 = torch.ops.aten.sum.dim_IntList(view_1271, [0]); view_1271 = None + clamp_min_18 = torch.ops.aten.clamp_min.default(sum_76, 8); sum_76 = None + add_1244 = torch.ops.aten.add.Tensor(clamp_min_18, 8); clamp_min_18 = None + sub_437 = torch.ops.aten.sub.Tensor(add_1244, 1); add_1244 = None + div_93 = torch.ops.aten.div.Tensor_mode(sub_437, 8, rounding_mode = 'floor'); sub_437 = None + mul_905 = torch.ops.aten.mul.Tensor(div_93, 8); div_93 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(mul_905, torch.int32); mul_905 = None + cumsum_55 = torch.ops.aten.cumsum.default(convert_element_type_1040, 0) + sub_438 = torch.ops.aten.sub.Tensor(cumsum_55, convert_element_type_1040); cumsum_55 = None + full_254 = torch.ops.aten.full.default([mul_904], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_904 = None + triton_kernel_wrapper_functional_proxy_18 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 18, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_398, 'start_index_values_ptr': sub_436, 'write_offsets_ptr': sub_438, 'output_ptr': full_254}, tensors_to_clone = ['output_ptr']); wait_tensor_398 = sub_436 = sub_438 = full_254 = None + getitem_274 = triton_kernel_wrapper_functional_proxy_18['output_ptr']; triton_kernel_wrapper_functional_proxy_18 = None + cat_58 = torch.ops.aten.cat.default([wait_tensor_399, full_default]); wait_tensor_399 = None + sym_size_int_73 = torch.ops.aten.sym_size.int(cat_58, 0) + sym_sum_37 = torch.sym_sum((1, _local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303)) + index_37 = torch.ops.aten.index.Tensor(cat_58, [getitem_274]); cat_58 = None + convert_element_type_1042 = torch.ops.prims.convert_element_type.default(primals_321, torch.bfloat16) + all_gather_into_tensor_325 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1042, 8, '513'); convert_element_type_1042 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_325); all_gather_into_tensor_325 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(primals_322, torch.bfloat16) + all_gather_into_tensor_327 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1044, 8, '513'); convert_element_type_1044 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_327); all_gather_into_tensor_327 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(primals_323, torch.bfloat16) + all_gather_into_tensor_328 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1045, 8, '513'); convert_element_type_1045 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_328); all_gather_into_tensor_328 = None + cumsum_56 = torch.ops.aten.cumsum.default(convert_element_type_1040, 0, dtype = torch.int32); convert_element_type_1040 = None + permute_290 = torch.ops.aten.permute.default(wait_tensor_400, [0, 2, 1]); wait_tensor_400 = None + _grouped_mm_54 = torch.ops.aten._grouped_mm.default(index_37, permute_290, cumsum_56); permute_290 = None + convert_element_type_1048 = torch.ops.prims.convert_element_type.default(_grouped_mm_54, torch.float32) + neg_37 = torch.ops.aten.neg.default(convert_element_type_1048) + exp_56 = torch.ops.aten.exp.default(neg_37); neg_37 = None + add_1256 = torch.ops.aten.add.Tensor(exp_56, 1); exp_56 = None + div_94 = torch.ops.aten.div.Tensor(convert_element_type_1048, add_1256); convert_element_type_1048 = add_1256 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(div_94, torch.bfloat16); div_94 = None + permute_291 = torch.ops.aten.permute.default(wait_tensor_403, [0, 2, 1]); wait_tensor_403 = None + _grouped_mm_55 = torch.ops.aten._grouped_mm.default(index_37, permute_291, cumsum_56); permute_291 = None + mul_917 = torch.ops.aten.mul.Tensor(convert_element_type_1049, _grouped_mm_55); convert_element_type_1049 = None + permute_292 = torch.ops.aten.permute.default(wait_tensor_402, [0, 2, 1]); wait_tensor_402 = None + _grouped_mm_56 = torch.ops.aten._grouped_mm.default(mul_917, permute_292, cumsum_56); permute_292 = None + empty_18 = torch.ops.aten.empty.memory_format([sym_size_int_73, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_36 = torch.ops.aten.index_put.default(empty_18, [getitem_274], _grouped_mm_56); empty_18 = _grouped_mm_56 = None + slice_78 = torch.ops.aten.slice.Tensor(index_put_36, 0, 0, -1); index_put_36 = None + all_to_all_single_56 = torch.ops._c10d_functional.all_to_all_single.default(slice_78, [_local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295], [_local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303], '521'); slice_78 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_56); all_to_all_single_56 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(primals_324, torch.bfloat16) + all_gather_into_tensor_331 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1050, 64, '0'); convert_element_type_1050 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + mm_156 = torch.ops.aten.mm.default(view_1264, permute_293); permute_293 = None + convert_element_type_1053 = torch.ops.prims.convert_element_type.default(mm_156, torch.float32) + neg_38 = torch.ops.aten.neg.default(convert_element_type_1053) + exp_57 = torch.ops.aten.exp.default(neg_38); neg_38 = None + add_1292 = torch.ops.aten.add.Tensor(exp_57, 1); exp_57 = None + div_95 = torch.ops.aten.div.Tensor(convert_element_type_1053, add_1292); convert_element_type_1053 = add_1292 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(div_95, torch.bfloat16); div_95 = None + convert_element_type_1055 = torch.ops.prims.convert_element_type.default(primals_325, torch.bfloat16) + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1055, 64, '0'); convert_element_type_1055 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + mm_157 = torch.ops.aten.mm.default(view_1264, permute_294); permute_294 = None + mul_937 = torch.ops.aten.mul.Tensor(convert_element_type_1054, mm_157); convert_element_type_1054 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(primals_326, torch.bfloat16) + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1058, 64, '0'); convert_element_type_1058 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + mm_158 = torch.ops.aten.mm.default(mul_937, permute_295); permute_295 = None + index_put_37 = torch.ops.aten.index_put.default(full_default_1, [getitem_273], wait_tensor_406); wait_tensor_406 = None + view_1304 = torch.ops.aten.view.default(mul_899, [-1, 1, 6]); mul_899 = None + view_1305 = torch.ops.aten.view.default(index_put_37, [-1, 6, 2048]); index_put_37 = None + convert_element_type_1061 = torch.ops.prims.convert_element_type.default(view_1305, torch.float32); view_1305 = None + bmm_18 = torch.ops.aten.bmm.default(view_1304, convert_element_type_1061) + convert_element_type_1062 = torch.ops.prims.convert_element_type.default(bmm_18, torch.bfloat16); bmm_18 = None + squeeze_18 = torch.ops.aten.squeeze.dim(convert_element_type_1062, 1); convert_element_type_1062 = None + add_1296 = torch.ops.aten.add.Tensor(mm_158, squeeze_18); mm_158 = squeeze_18 = None + view_1306 = torch.ops.aten.view.default(add_1296, [2, 4096, 2048]); add_1296 = None + add_1297 = torch.ops.aten.add.Tensor(add_1232, view_1306); view_1306 = None + convert_element_type_1063 = torch.ops.prims.convert_element_type.default(primals_327, torch.bfloat16) + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1063, 64, '0'); convert_element_type_1063 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + convert_element_type_1064 = torch.ops.prims.convert_element_type.default(add_1297, torch.float32) + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1064, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_1298 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_1298); add_1298 = None + mul_940 = torch.ops.aten.mul.Tensor(convert_element_type_1064, rsqrt_60); convert_element_type_1064 = None + mul_941 = torch.ops.aten.mul.Tensor(mul_940, wait_tensor_410); mul_940 = wait_tensor_410 = None + convert_element_type_1065 = torch.ops.prims.convert_element_type.default(mul_941, torch.bfloat16); mul_941 = None + convert_element_type_1066 = torch.ops.prims.convert_element_type.default(primals_328, torch.bfloat16) + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1066, 64, '0'); convert_element_type_1066 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_411, [1, 0]); wait_tensor_411 = None + view_1309 = torch.ops.aten.view.default(convert_element_type_1065, [8192, 2048]); convert_element_type_1065 = None + mm_159 = torch.ops.aten.mm.default(view_1309, permute_296); permute_296 = None + view_1310 = torch.ops.aten.view.default(mm_159, [2, 4096, 3072]); mm_159 = None + view_1311 = torch.ops.aten.view.default(view_1310, [2, 4096, -1, 192]); view_1310 = None + split_with_sizes_60 = torch.ops.aten.split_with_sizes.default(view_1311, [128, 64], -1); view_1311 = None + getitem_275 = split_with_sizes_60[0] + getitem_276 = split_with_sizes_60[1]; split_with_sizes_60 = None + convert_element_type_1069 = torch.ops.prims.convert_element_type.default(getitem_276, torch.float32); getitem_276 = None + view_1312 = torch.ops.aten.view.default(convert_element_type_1069, [2, 4096, 16, -1, 2]); convert_element_type_1069 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1312); view_1312 = None + mul_942 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_7); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_942); mul_942 = None + view_1314 = torch.ops.aten.view.default(view_as_real_40, [2, 4096, 16, 64]); view_as_real_40 = None + convert_element_type_1070 = torch.ops.prims.convert_element_type.default(view_1314, torch.bfloat16); view_1314 = None + cat_59 = torch.ops.aten.cat.default([getitem_275, convert_element_type_1070], -1); getitem_275 = convert_element_type_1070 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(primals_329, torch.bfloat16) + all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1071, 64, '0'); convert_element_type_1071 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_412, [1, 0]); wait_tensor_412 = None + mm_160 = torch.ops.aten.mm.default(view_1309, permute_297); permute_297 = None + view_1317 = torch.ops.aten.view.default(mm_160, [2, 4096, 576]); mm_160 = None + split_with_sizes_61 = torch.ops.aten.split_with_sizes.default(view_1317, [512, 64], -1); view_1317 = None + getitem_277 = split_with_sizes_61[0] + getitem_278 = split_with_sizes_61[1]; split_with_sizes_61 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(getitem_278, 2); getitem_278 = None + convert_element_type_1074 = torch.ops.prims.convert_element_type.default(unsqueeze_39, torch.float32); unsqueeze_39 = None + view_1318 = torch.ops.aten.view.default(convert_element_type_1074, [2, 4096, 1, -1, 2]); convert_element_type_1074 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1318); view_1318 = None + mul_943 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_7); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_943); mul_943 = None + view_1320 = torch.ops.aten.view.default(view_as_real_41, [2, 4096, 1, 64]); view_as_real_41 = None + convert_element_type_1075 = torch.ops.prims.convert_element_type.default(view_1320, torch.bfloat16); view_1320 = None + convert_element_type_1076 = torch.ops.prims.convert_element_type.default(primals_330, torch.bfloat16) + all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1076, 64, '0'); convert_element_type_1076 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1077 = torch.ops.prims.convert_element_type.default(getitem_277, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1077, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_1299 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_1299); add_1299 = None + mul_944 = torch.ops.aten.mul.Tensor(convert_element_type_1077, rsqrt_61); convert_element_type_1077 = None + mul_945 = torch.ops.aten.mul.Tensor(mul_944, wait_tensor_413); mul_944 = wait_tensor_413 = None + convert_element_type_1078 = torch.ops.prims.convert_element_type.default(mul_945, torch.bfloat16); mul_945 = None + convert_element_type_1079 = torch.ops.prims.convert_element_type.default(primals_331, torch.bfloat16) + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1079, 64, '0'); convert_element_type_1079 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + view_1323 = torch.ops.aten.view.default(convert_element_type_1078, [8192, 512]); convert_element_type_1078 = None + mm_161 = torch.ops.aten.mm.default(view_1323, permute_298); permute_298 = None + view_1324 = torch.ops.aten.view.default(mm_161, [2, 4096, 4096]); mm_161 = None + view_1325 = torch.ops.aten.view.default(view_1324, [2, 4096, -1, 256]); view_1324 = None + split_with_sizes_62 = torch.ops.aten.split_with_sizes.default(view_1325, [128, 128], -1); view_1325 = None + getitem_279 = split_with_sizes_62[0] + getitem_280 = split_with_sizes_62[1]; split_with_sizes_62 = None + expand_20 = torch.ops.aten.expand.default(convert_element_type_1075, [-1, -1, 16, -1]); convert_element_type_1075 = None + cat_60 = torch.ops.aten.cat.default([getitem_279, expand_20], -1); getitem_279 = expand_20 = None + permute_299 = torch.ops.aten.permute.default(cat_59, [0, 2, 1, 3]); cat_59 = None + permute_300 = torch.ops.aten.permute.default(cat_60, [0, 2, 1, 3]); cat_60 = None + permute_301 = torch.ops.aten.permute.default(getitem_280, [0, 2, 1, 3]); getitem_280 = None + sdpa_score20 = self.sdpa_score20 + sdpa_mask20 = self.sdpa_mask20 + flex_attention_20 = torch.ops.higher_order.flex_attention(permute_299, permute_300, permute_301, sdpa_score20, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask20), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score20 = sdpa_mask20 = None + getitem_281 = flex_attention_20[0] + getitem_282 = flex_attention_20[1]; flex_attention_20 = None + permute_302 = torch.ops.aten.permute.default(getitem_281, [0, 2, 1, 3]) + view_1326 = torch.ops.aten.view.default(permute_302, [2, 4096, -1]); permute_302 = None + convert_element_type_1082 = torch.ops.prims.convert_element_type.default(primals_332, torch.bfloat16) + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1082, 64, '0'); convert_element_type_1082 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + permute_303 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); wait_tensor_415 = None + view_1328 = torch.ops.aten.view.default(view_1326, [8192, 2048]); view_1326 = None + mm_162 = torch.ops.aten.mm.default(view_1328, permute_303); view_1328 = permute_303 = None + view_1329 = torch.ops.aten.view.default(mm_162, [2, 4096, 2048]); mm_162 = None + add_1300 = torch.ops.aten.add.Tensor(add_1297, view_1329); view_1329 = None + convert_element_type_1085 = torch.ops.prims.convert_element_type.default(primals_333, torch.bfloat16) + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1085, 64, '0'); convert_element_type_1085 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(add_1300, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1086, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_1301 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_1301); add_1301 = None + mul_946 = torch.ops.aten.mul.Tensor(convert_element_type_1086, rsqrt_62); convert_element_type_1086 = None + mul_947 = torch.ops.aten.mul.Tensor(mul_946, wait_tensor_416); mul_946 = wait_tensor_416 = None + convert_element_type_1087 = torch.ops.prims.convert_element_type.default(mul_947, torch.bfloat16); mul_947 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_1087, [-1, 2048]); convert_element_type_1087 = None + convert_element_type_1088 = torch.ops.prims.convert_element_type.default(primals_335, torch.bfloat16) + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1088, 64, '0'); convert_element_type_1088 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_417, [1, 0]); wait_tensor_417 = None + mm_163 = torch.ops.aten.mm.default(view_1331, permute_304); permute_304 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_163, torch.float32) + amax_19 = torch.ops.aten.amax.default(convert_element_type_1091, [1], True) + sub_456 = torch.ops.aten.sub.Tensor(convert_element_type_1091, amax_19); convert_element_type_1091 = None + exp_58 = torch.ops.aten.exp.default(sub_456); sub_456 = None + sum_77 = torch.ops.aten.sum.dim_IntList(exp_58, [1], True) + div_96 = torch.ops.aten.div.Tensor(exp_58, sum_77); exp_58 = None + add_1302 = torch.ops.aten.add.Tensor(div_96, primals_334); primals_334 = None + topk_19 = torch.ops.aten.topk.default(add_1302, 6, -1, True, False); add_1302 = None + getitem_285 = topk_19[1]; topk_19 = None + gather_19 = torch.ops.aten.gather.default(div_96, 1, getitem_285); div_96 = None + mul_948 = torch.ops.aten.mul.Tensor(gather_19, 1.0); gather_19 = None + view_1333 = torch.ops.aten.view.default(getitem_285, [-1]) + histc_38 = torch.ops.aten.histc.default(view_1333, 64, 0, 64) + add_1303 = torch.ops.aten.add.Tensor(primals_336, histc_38) + sort_19 = torch.ops.aten.sort.stable(view_1333, stable = True); view_1333 = None + getitem_287 = sort_19[1]; sort_19 = None + div_97 = torch.ops.aten.div.Tensor_mode(getitem_287, 6, rounding_mode = 'floor') + index_38 = torch.ops.aten.index.Tensor(view_1331, [div_97]) + all_to_all_single_57 = torch.ops._c10d_functional.all_to_all_single.default(histc_38, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_57); all_to_all_single_57 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_418); wait_tensor_418 = None + view_1337 = torch.ops.aten.view.default(histc_38, [8, -1]); histc_38 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_1337, [1]); view_1337 = None + device_put_38 = torch.ops.prims.device_put.default(sum_78, device(type='cpu'), True); sum_78 = None + view_1338 = torch.ops.aten.view.default(wait_tensor_419, [8, -1]) + sum_79 = torch.ops.aten.sum.dim_IntList(view_1338, [1]) + device_put_39 = torch.ops.prims.device_put.default(sum_79, device(type='cpu')); sum_79 = None + select_304 = torch.ops.aten.select.int(device_put_38, 0, 0) + _local_scalar_dense_304 = torch.ops.aten._local_scalar_dense.default(select_304); select_304 = None + ge_380 = _local_scalar_dense_304 >= 0 + _assert_scalar_304 = torch.ops.aten._assert_scalar.default(ge_380, "Runtime assertion failed for expression u304 >= 0 on node 'ge_304'"); ge_380 = _assert_scalar_304 = None + select_305 = torch.ops.aten.select.int(device_put_38, 0, 1) + _local_scalar_dense_305 = torch.ops.aten._local_scalar_dense.default(select_305); select_305 = None + ge_381 = _local_scalar_dense_305 >= 0 + _assert_scalar_305 = torch.ops.aten._assert_scalar.default(ge_381, "Runtime assertion failed for expression u305 >= 0 on node 'ge_305'"); ge_381 = _assert_scalar_305 = None + select_306 = torch.ops.aten.select.int(device_put_38, 0, 2) + _local_scalar_dense_306 = torch.ops.aten._local_scalar_dense.default(select_306); select_306 = None + ge_382 = _local_scalar_dense_306 >= 0 + _assert_scalar_306 = torch.ops.aten._assert_scalar.default(ge_382, "Runtime assertion failed for expression u306 >= 0 on node 'ge_306'"); ge_382 = _assert_scalar_306 = None + select_307 = torch.ops.aten.select.int(device_put_38, 0, 3) + _local_scalar_dense_307 = torch.ops.aten._local_scalar_dense.default(select_307); select_307 = None + ge_383 = _local_scalar_dense_307 >= 0 + _assert_scalar_307 = torch.ops.aten._assert_scalar.default(ge_383, "Runtime assertion failed for expression u307 >= 0 on node 'ge_307'"); ge_383 = _assert_scalar_307 = None + select_308 = torch.ops.aten.select.int(device_put_38, 0, 4) + _local_scalar_dense_308 = torch.ops.aten._local_scalar_dense.default(select_308); select_308 = None + ge_384 = _local_scalar_dense_308 >= 0 + _assert_scalar_308 = torch.ops.aten._assert_scalar.default(ge_384, "Runtime assertion failed for expression u308 >= 0 on node 'ge_308'"); ge_384 = _assert_scalar_308 = None + select_309 = torch.ops.aten.select.int(device_put_38, 0, 5) + _local_scalar_dense_309 = torch.ops.aten._local_scalar_dense.default(select_309); select_309 = None + ge_385 = _local_scalar_dense_309 >= 0 + _assert_scalar_309 = torch.ops.aten._assert_scalar.default(ge_385, "Runtime assertion failed for expression u309 >= 0 on node 'ge_309'"); ge_385 = _assert_scalar_309 = None + select_310 = torch.ops.aten.select.int(device_put_38, 0, 6) + _local_scalar_dense_310 = torch.ops.aten._local_scalar_dense.default(select_310); select_310 = None + ge_386 = _local_scalar_dense_310 >= 0 + _assert_scalar_310 = torch.ops.aten._assert_scalar.default(ge_386, "Runtime assertion failed for expression u310 >= 0 on node 'ge_310'"); ge_386 = _assert_scalar_310 = None + select_311 = torch.ops.aten.select.int(device_put_38, 0, 7); device_put_38 = None + _local_scalar_dense_311 = torch.ops.aten._local_scalar_dense.default(select_311); select_311 = None + ge_387 = _local_scalar_dense_311 >= 0 + _assert_scalar_311 = torch.ops.aten._assert_scalar.default(ge_387, "Runtime assertion failed for expression u311 >= 0 on node 'ge_311'"); ge_387 = _assert_scalar_311 = None + select_312 = torch.ops.aten.select.int(device_put_39, 0, 0) + _local_scalar_dense_312 = torch.ops.aten._local_scalar_dense.default(select_312); select_312 = None + ge_388 = _local_scalar_dense_312 >= 0 + _assert_scalar_312 = torch.ops.aten._assert_scalar.default(ge_388, "Runtime assertion failed for expression u312 >= 0 on node 'ge_312'"); ge_388 = _assert_scalar_312 = None + select_313 = torch.ops.aten.select.int(device_put_39, 0, 1) + _local_scalar_dense_313 = torch.ops.aten._local_scalar_dense.default(select_313); select_313 = None + ge_389 = _local_scalar_dense_313 >= 0 + _assert_scalar_313 = torch.ops.aten._assert_scalar.default(ge_389, "Runtime assertion failed for expression u313 >= 0 on node 'ge_313'"); ge_389 = _assert_scalar_313 = None + select_314 = torch.ops.aten.select.int(device_put_39, 0, 2) + _local_scalar_dense_314 = torch.ops.aten._local_scalar_dense.default(select_314); select_314 = None + ge_390 = _local_scalar_dense_314 >= 0 + _assert_scalar_314 = torch.ops.aten._assert_scalar.default(ge_390, "Runtime assertion failed for expression u314 >= 0 on node 'ge_314'"); ge_390 = _assert_scalar_314 = None + select_315 = torch.ops.aten.select.int(device_put_39, 0, 3) + _local_scalar_dense_315 = torch.ops.aten._local_scalar_dense.default(select_315); select_315 = None + ge_391 = _local_scalar_dense_315 >= 0 + _assert_scalar_315 = torch.ops.aten._assert_scalar.default(ge_391, "Runtime assertion failed for expression u315 >= 0 on node 'ge_315'"); ge_391 = _assert_scalar_315 = None + select_316 = torch.ops.aten.select.int(device_put_39, 0, 4) + _local_scalar_dense_316 = torch.ops.aten._local_scalar_dense.default(select_316); select_316 = None + ge_392 = _local_scalar_dense_316 >= 0 + _assert_scalar_316 = torch.ops.aten._assert_scalar.default(ge_392, "Runtime assertion failed for expression u316 >= 0 on node 'ge_316'"); ge_392 = _assert_scalar_316 = None + select_317 = torch.ops.aten.select.int(device_put_39, 0, 5) + _local_scalar_dense_317 = torch.ops.aten._local_scalar_dense.default(select_317); select_317 = None + ge_393 = _local_scalar_dense_317 >= 0 + _assert_scalar_317 = torch.ops.aten._assert_scalar.default(ge_393, "Runtime assertion failed for expression u317 >= 0 on node 'ge_317'"); ge_393 = _assert_scalar_317 = None + select_318 = torch.ops.aten.select.int(device_put_39, 0, 6) + _local_scalar_dense_318 = torch.ops.aten._local_scalar_dense.default(select_318); select_318 = None + ge_394 = _local_scalar_dense_318 >= 0 + _assert_scalar_318 = torch.ops.aten._assert_scalar.default(ge_394, "Runtime assertion failed for expression u318 >= 0 on node 'ge_318'"); ge_394 = _assert_scalar_318 = None + select_319 = torch.ops.aten.select.int(device_put_39, 0, 7); device_put_39 = None + _local_scalar_dense_319 = torch.ops.aten._local_scalar_dense.default(select_319); select_319 = None + ge_395 = _local_scalar_dense_319 >= 0 + _assert_scalar_319 = torch.ops.aten._assert_scalar.default(ge_395, "Runtime assertion failed for expression u319 >= 0 on node 'ge_319'"); ge_395 = _assert_scalar_319 = None + all_to_all_single_58 = torch.ops._c10d_functional.all_to_all_single.default(index_38, [_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319], [_local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311], '521'); index_38 = None + sym_size_int_76 = torch.ops.aten.sym_size.int(all_to_all_single_58, 0) + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_58); all_to_all_single_58 = None + sym_sum_38 = torch.sym_sum((_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319)) + add_1310 = sym_sum_38 + 64; sym_sum_38 = None + add_1311 = add_1310 + 8; add_1310 = None + sub_459 = add_1311 - 1; add_1311 = None + floordiv_19 = sub_459 // 8; sub_459 = None + mul_953 = floordiv_19 * 8; floordiv_19 = None + cumsum_57 = torch.ops.aten.cumsum.default(wait_tensor_419, 0) + sub_460 = torch.ops.aten.sub.Tensor(cumsum_57, wait_tensor_419); cumsum_57 = None + sum_80 = torch.ops.aten.sum.dim_IntList(view_1338, [0]); view_1338 = None + clamp_min_19 = torch.ops.aten.clamp_min.default(sum_80, 8); sum_80 = None + add_1312 = torch.ops.aten.add.Tensor(clamp_min_19, 8); clamp_min_19 = None + sub_461 = torch.ops.aten.sub.Tensor(add_1312, 1); add_1312 = None + div_98 = torch.ops.aten.div.Tensor_mode(sub_461, 8, rounding_mode = 'floor'); sub_461 = None + mul_954 = torch.ops.aten.mul.Tensor(div_98, 8); div_98 = None + convert_element_type_1094 = torch.ops.prims.convert_element_type.default(mul_954, torch.int32); mul_954 = None + cumsum_58 = torch.ops.aten.cumsum.default(convert_element_type_1094, 0) + sub_462 = torch.ops.aten.sub.Tensor(cumsum_58, convert_element_type_1094); cumsum_58 = None + full_267 = torch.ops.aten.full.default([mul_953], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_953 = None + triton_kernel_wrapper_functional_proxy_19 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 19, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_419, 'start_index_values_ptr': sub_460, 'write_offsets_ptr': sub_462, 'output_ptr': full_267}, tensors_to_clone = ['output_ptr']); wait_tensor_419 = sub_460 = sub_462 = full_267 = None + getitem_288 = triton_kernel_wrapper_functional_proxy_19['output_ptr']; triton_kernel_wrapper_functional_proxy_19 = None + cat_61 = torch.ops.aten.cat.default([wait_tensor_420, full_default]); wait_tensor_420 = None + sym_size_int_77 = torch.ops.aten.sym_size.int(cat_61, 0) + sym_sum_39 = torch.sym_sum((1, _local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319)) + index_39 = torch.ops.aten.index.Tensor(cat_61, [getitem_288]); cat_61 = None + convert_element_type_1096 = torch.ops.prims.convert_element_type.default(primals_337, torch.bfloat16) + all_gather_into_tensor_342 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1096, 8, '513'); convert_element_type_1096 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_342); all_gather_into_tensor_342 = None + convert_element_type_1098 = torch.ops.prims.convert_element_type.default(primals_338, torch.bfloat16) + all_gather_into_tensor_344 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1098, 8, '513'); convert_element_type_1098 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_344); all_gather_into_tensor_344 = None + convert_element_type_1099 = torch.ops.prims.convert_element_type.default(primals_339, torch.bfloat16) + all_gather_into_tensor_345 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1099, 8, '513'); convert_element_type_1099 = None + wait_tensor_424 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_345); all_gather_into_tensor_345 = None + cumsum_59 = torch.ops.aten.cumsum.default(convert_element_type_1094, 0, dtype = torch.int32); convert_element_type_1094 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_421, [0, 2, 1]); wait_tensor_421 = None + _grouped_mm_57 = torch.ops.aten._grouped_mm.default(index_39, permute_305, cumsum_59); permute_305 = None + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(_grouped_mm_57, torch.float32) + neg_39 = torch.ops.aten.neg.default(convert_element_type_1102) + exp_59 = torch.ops.aten.exp.default(neg_39); neg_39 = None + add_1324 = torch.ops.aten.add.Tensor(exp_59, 1); exp_59 = None + div_99 = torch.ops.aten.div.Tensor(convert_element_type_1102, add_1324); convert_element_type_1102 = add_1324 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(div_99, torch.bfloat16); div_99 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_424, [0, 2, 1]); wait_tensor_424 = None + _grouped_mm_58 = torch.ops.aten._grouped_mm.default(index_39, permute_306, cumsum_59); permute_306 = None + mul_966 = torch.ops.aten.mul.Tensor(convert_element_type_1103, _grouped_mm_58); convert_element_type_1103 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_423, [0, 2, 1]); wait_tensor_423 = None + _grouped_mm_59 = torch.ops.aten._grouped_mm.default(mul_966, permute_307, cumsum_59); permute_307 = None + empty_19 = torch.ops.aten.empty.memory_format([sym_size_int_77, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_38 = torch.ops.aten.index_put.default(empty_19, [getitem_288], _grouped_mm_59); empty_19 = _grouped_mm_59 = None + slice_82 = torch.ops.aten.slice.Tensor(index_put_38, 0, 0, -1); index_put_38 = None + all_to_all_single_59 = torch.ops._c10d_functional.all_to_all_single.default(slice_82, [_local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311], [_local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319], '521'); slice_82 = None + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_59); all_to_all_single_59 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(primals_340, torch.bfloat16) + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1104, 64, '0'); convert_element_type_1104 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_428, [1, 0]); wait_tensor_428 = None + mm_164 = torch.ops.aten.mm.default(view_1331, permute_308); permute_308 = None + convert_element_type_1107 = torch.ops.prims.convert_element_type.default(mm_164, torch.float32) + neg_40 = torch.ops.aten.neg.default(convert_element_type_1107) + exp_60 = torch.ops.aten.exp.default(neg_40); neg_40 = None + add_1360 = torch.ops.aten.add.Tensor(exp_60, 1); exp_60 = None + div_100 = torch.ops.aten.div.Tensor(convert_element_type_1107, add_1360); convert_element_type_1107 = add_1360 = None + convert_element_type_1108 = torch.ops.prims.convert_element_type.default(div_100, torch.bfloat16); div_100 = None + convert_element_type_1109 = torch.ops.prims.convert_element_type.default(primals_341, torch.bfloat16) + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1109, 64, '0'); convert_element_type_1109 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_429, [1, 0]); wait_tensor_429 = None + mm_165 = torch.ops.aten.mm.default(view_1331, permute_309); permute_309 = None + mul_986 = torch.ops.aten.mul.Tensor(convert_element_type_1108, mm_165); convert_element_type_1108 = None + convert_element_type_1112 = torch.ops.prims.convert_element_type.default(primals_342, torch.bfloat16) + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1112, 64, '0'); convert_element_type_1112 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_430, [1, 0]); wait_tensor_430 = None + mm_166 = torch.ops.aten.mm.default(mul_986, permute_310); permute_310 = None + index_put_39 = torch.ops.aten.index_put.default(full_default_1, [getitem_287], wait_tensor_427); wait_tensor_427 = None + view_1371 = torch.ops.aten.view.default(mul_948, [-1, 1, 6]); mul_948 = None + view_1372 = torch.ops.aten.view.default(index_put_39, [-1, 6, 2048]); index_put_39 = None + convert_element_type_1115 = torch.ops.prims.convert_element_type.default(view_1372, torch.float32); view_1372 = None + bmm_19 = torch.ops.aten.bmm.default(view_1371, convert_element_type_1115) + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(bmm_19, torch.bfloat16); bmm_19 = None + squeeze_19 = torch.ops.aten.squeeze.dim(convert_element_type_1116, 1); convert_element_type_1116 = None + add_1364 = torch.ops.aten.add.Tensor(mm_166, squeeze_19); mm_166 = squeeze_19 = None + view_1373 = torch.ops.aten.view.default(add_1364, [2, 4096, 2048]); add_1364 = None + add_1365 = torch.ops.aten.add.Tensor(add_1300, view_1373); view_1373 = None + convert_element_type_1117 = torch.ops.prims.convert_element_type.default(primals_343, torch.bfloat16) + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1117, 64, '0'); convert_element_type_1117 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); all_gather_into_tensor_351 = None + convert_element_type_1118 = torch.ops.prims.convert_element_type.default(add_1365, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1118, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_1366 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_1366); add_1366 = None + mul_989 = torch.ops.aten.mul.Tensor(convert_element_type_1118, rsqrt_63); convert_element_type_1118 = None + mul_990 = torch.ops.aten.mul.Tensor(mul_989, wait_tensor_431); mul_989 = wait_tensor_431 = None + convert_element_type_1119 = torch.ops.prims.convert_element_type.default(mul_990, torch.bfloat16); mul_990 = None + convert_element_type_1120 = torch.ops.prims.convert_element_type.default(primals_344, torch.bfloat16) + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1120, 64, '0'); convert_element_type_1120 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_311 = torch.ops.aten.permute.default(wait_tensor_432, [1, 0]); wait_tensor_432 = None + view_1376 = torch.ops.aten.view.default(convert_element_type_1119, [8192, 2048]); convert_element_type_1119 = None + mm_167 = torch.ops.aten.mm.default(view_1376, permute_311); permute_311 = None + view_1377 = torch.ops.aten.view.default(mm_167, [2, 4096, 3072]); mm_167 = None + view_1378 = torch.ops.aten.view.default(view_1377, [2, 4096, -1, 192]); view_1377 = None + split_with_sizes_63 = torch.ops.aten.split_with_sizes.default(view_1378, [128, 64], -1); view_1378 = None + getitem_289 = split_with_sizes_63[0] + getitem_290 = split_with_sizes_63[1]; split_with_sizes_63 = None + convert_element_type_1123 = torch.ops.prims.convert_element_type.default(getitem_290, torch.float32); getitem_290 = None + view_1379 = torch.ops.aten.view.default(convert_element_type_1123, [2, 4096, 16, -1, 2]); convert_element_type_1123 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1379); view_1379 = None + mul_991 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_7); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_991); mul_991 = None + view_1381 = torch.ops.aten.view.default(view_as_real_42, [2, 4096, 16, 64]); view_as_real_42 = None + convert_element_type_1124 = torch.ops.prims.convert_element_type.default(view_1381, torch.bfloat16); view_1381 = None + cat_62 = torch.ops.aten.cat.default([getitem_289, convert_element_type_1124], -1); getitem_289 = convert_element_type_1124 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(primals_345, torch.bfloat16) + all_gather_into_tensor_353 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1125, 64, '0'); convert_element_type_1125 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + permute_312 = torch.ops.aten.permute.default(wait_tensor_433, [1, 0]); wait_tensor_433 = None + mm_168 = torch.ops.aten.mm.default(view_1376, permute_312); permute_312 = None + view_1384 = torch.ops.aten.view.default(mm_168, [2, 4096, 576]); mm_168 = None + split_with_sizes_64 = torch.ops.aten.split_with_sizes.default(view_1384, [512, 64], -1); view_1384 = None + getitem_291 = split_with_sizes_64[0] + getitem_292 = split_with_sizes_64[1]; split_with_sizes_64 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(getitem_292, 2); getitem_292 = None + convert_element_type_1128 = torch.ops.prims.convert_element_type.default(unsqueeze_41, torch.float32); unsqueeze_41 = None + view_1385 = torch.ops.aten.view.default(convert_element_type_1128, [2, 4096, 1, -1, 2]); convert_element_type_1128 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1385); view_1385 = None + mul_992 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_7); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_992); mul_992 = None + view_1387 = torch.ops.aten.view.default(view_as_real_43, [2, 4096, 1, 64]); view_as_real_43 = None + convert_element_type_1129 = torch.ops.prims.convert_element_type.default(view_1387, torch.bfloat16); view_1387 = None + convert_element_type_1130 = torch.ops.prims.convert_element_type.default(primals_346, torch.bfloat16) + all_gather_into_tensor_354 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1130, 64, '0'); convert_element_type_1130 = None + wait_tensor_434 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_354); all_gather_into_tensor_354 = None + convert_element_type_1131 = torch.ops.prims.convert_element_type.default(getitem_291, torch.float32) + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1131, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_1367 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_1367); add_1367 = None + mul_993 = torch.ops.aten.mul.Tensor(convert_element_type_1131, rsqrt_64); convert_element_type_1131 = None + mul_994 = torch.ops.aten.mul.Tensor(mul_993, wait_tensor_434); mul_993 = wait_tensor_434 = None + convert_element_type_1132 = torch.ops.prims.convert_element_type.default(mul_994, torch.bfloat16); mul_994 = None + convert_element_type_1133 = torch.ops.prims.convert_element_type.default(primals_347, torch.bfloat16) + all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1133, 64, '0'); convert_element_type_1133 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_313 = torch.ops.aten.permute.default(wait_tensor_435, [1, 0]); wait_tensor_435 = None + view_1390 = torch.ops.aten.view.default(convert_element_type_1132, [8192, 512]); convert_element_type_1132 = None + mm_169 = torch.ops.aten.mm.default(view_1390, permute_313); permute_313 = None + view_1391 = torch.ops.aten.view.default(mm_169, [2, 4096, 4096]); mm_169 = None + view_1392 = torch.ops.aten.view.default(view_1391, [2, 4096, -1, 256]); view_1391 = None + split_with_sizes_65 = torch.ops.aten.split_with_sizes.default(view_1392, [128, 128], -1); view_1392 = None + getitem_293 = split_with_sizes_65[0] + getitem_294 = split_with_sizes_65[1]; split_with_sizes_65 = None + expand_21 = torch.ops.aten.expand.default(convert_element_type_1129, [-1, -1, 16, -1]); convert_element_type_1129 = None + cat_63 = torch.ops.aten.cat.default([getitem_293, expand_21], -1); getitem_293 = expand_21 = None + permute_314 = torch.ops.aten.permute.default(cat_62, [0, 2, 1, 3]); cat_62 = None + permute_315 = torch.ops.aten.permute.default(cat_63, [0, 2, 1, 3]); cat_63 = None + permute_316 = torch.ops.aten.permute.default(getitem_294, [0, 2, 1, 3]); getitem_294 = None + sdpa_score21 = self.sdpa_score21 + sdpa_mask21 = self.sdpa_mask21 + flex_attention_21 = torch.ops.higher_order.flex_attention(permute_314, permute_315, permute_316, sdpa_score21, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask21), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score21 = sdpa_mask21 = None + getitem_295 = flex_attention_21[0] + getitem_296 = flex_attention_21[1]; flex_attention_21 = None + permute_317 = torch.ops.aten.permute.default(getitem_295, [0, 2, 1, 3]) + view_1393 = torch.ops.aten.view.default(permute_317, [2, 4096, -1]); permute_317 = None + convert_element_type_1136 = torch.ops.prims.convert_element_type.default(primals_348, torch.bfloat16) + all_gather_into_tensor_356 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1136, 64, '0'); convert_element_type_1136 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_356); all_gather_into_tensor_356 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_436, [1, 0]); wait_tensor_436 = None + view_1395 = torch.ops.aten.view.default(view_1393, [8192, 2048]); view_1393 = None + mm_170 = torch.ops.aten.mm.default(view_1395, permute_318); view_1395 = permute_318 = None + view_1396 = torch.ops.aten.view.default(mm_170, [2, 4096, 2048]); mm_170 = None + add_1368 = torch.ops.aten.add.Tensor(add_1365, view_1396); view_1396 = None + convert_element_type_1139 = torch.ops.prims.convert_element_type.default(primals_349, torch.bfloat16) + all_gather_into_tensor_357 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1139, 64, '0'); convert_element_type_1139 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_357); all_gather_into_tensor_357 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(add_1368, torch.float32) + pow_66 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1140, 2) + mean_65 = torch.ops.aten.mean.dim(pow_66, [2], True); pow_66 = None + add_1369 = torch.ops.aten.add.Scalar(mean_65, 1e-05); mean_65 = None + rsqrt_65 = torch.ops.aten.rsqrt.default(add_1369); add_1369 = None + mul_995 = torch.ops.aten.mul.Tensor(convert_element_type_1140, rsqrt_65); convert_element_type_1140 = None + mul_996 = torch.ops.aten.mul.Tensor(mul_995, wait_tensor_437); mul_995 = wait_tensor_437 = None + convert_element_type_1141 = torch.ops.prims.convert_element_type.default(mul_996, torch.bfloat16); mul_996 = None + view_1398 = torch.ops.aten.view.default(convert_element_type_1141, [-1, 2048]); convert_element_type_1141 = None + convert_element_type_1142 = torch.ops.prims.convert_element_type.default(primals_351, torch.bfloat16) + all_gather_into_tensor_358 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1142, 64, '0'); convert_element_type_1142 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_358); all_gather_into_tensor_358 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_438, [1, 0]); wait_tensor_438 = None + mm_171 = torch.ops.aten.mm.default(view_1398, permute_319); permute_319 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_171, torch.float32) + amax_20 = torch.ops.aten.amax.default(convert_element_type_1145, [1], True) + sub_480 = torch.ops.aten.sub.Tensor(convert_element_type_1145, amax_20); convert_element_type_1145 = None + exp_61 = torch.ops.aten.exp.default(sub_480); sub_480 = None + sum_81 = torch.ops.aten.sum.dim_IntList(exp_61, [1], True) + div_101 = torch.ops.aten.div.Tensor(exp_61, sum_81); exp_61 = None + add_1370 = torch.ops.aten.add.Tensor(div_101, primals_350); primals_350 = None + topk_20 = torch.ops.aten.topk.default(add_1370, 6, -1, True, False); add_1370 = None + getitem_299 = topk_20[1]; topk_20 = None + gather_20 = torch.ops.aten.gather.default(div_101, 1, getitem_299); div_101 = None + mul_997 = torch.ops.aten.mul.Tensor(gather_20, 1.0); gather_20 = None + view_1400 = torch.ops.aten.view.default(getitem_299, [-1]) + histc_40 = torch.ops.aten.histc.default(view_1400, 64, 0, 64) + add_1371 = torch.ops.aten.add.Tensor(primals_352, histc_40) + sort_20 = torch.ops.aten.sort.stable(view_1400, stable = True); view_1400 = None + getitem_301 = sort_20[1]; sort_20 = None + div_102 = torch.ops.aten.div.Tensor_mode(getitem_301, 6, rounding_mode = 'floor') + index_40 = torch.ops.aten.index.Tensor(view_1398, [div_102]) + all_to_all_single_60 = torch.ops._c10d_functional.all_to_all_single.default(histc_40, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_60); all_to_all_single_60 = None + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_439); wait_tensor_439 = None + view_1404 = torch.ops.aten.view.default(histc_40, [8, -1]); histc_40 = None + sum_82 = torch.ops.aten.sum.dim_IntList(view_1404, [1]); view_1404 = None + device_put_40 = torch.ops.prims.device_put.default(sum_82, device(type='cpu'), True); sum_82 = None + view_1405 = torch.ops.aten.view.default(wait_tensor_440, [8, -1]) + sum_83 = torch.ops.aten.sum.dim_IntList(view_1405, [1]) + device_put_41 = torch.ops.prims.device_put.default(sum_83, device(type='cpu')); sum_83 = None + select_320 = torch.ops.aten.select.int(device_put_40, 0, 0) + _local_scalar_dense_320 = torch.ops.aten._local_scalar_dense.default(select_320); select_320 = None + ge_400 = _local_scalar_dense_320 >= 0 + _assert_scalar_320 = torch.ops.aten._assert_scalar.default(ge_400, "Runtime assertion failed for expression u320 >= 0 on node 'ge_320'"); ge_400 = _assert_scalar_320 = None + select_321 = torch.ops.aten.select.int(device_put_40, 0, 1) + _local_scalar_dense_321 = torch.ops.aten._local_scalar_dense.default(select_321); select_321 = None + ge_401 = _local_scalar_dense_321 >= 0 + _assert_scalar_321 = torch.ops.aten._assert_scalar.default(ge_401, "Runtime assertion failed for expression u321 >= 0 on node 'ge_321'"); ge_401 = _assert_scalar_321 = None + select_322 = torch.ops.aten.select.int(device_put_40, 0, 2) + _local_scalar_dense_322 = torch.ops.aten._local_scalar_dense.default(select_322); select_322 = None + ge_402 = _local_scalar_dense_322 >= 0 + _assert_scalar_322 = torch.ops.aten._assert_scalar.default(ge_402, "Runtime assertion failed for expression u322 >= 0 on node 'ge_322'"); ge_402 = _assert_scalar_322 = None + select_323 = torch.ops.aten.select.int(device_put_40, 0, 3) + _local_scalar_dense_323 = torch.ops.aten._local_scalar_dense.default(select_323); select_323 = None + ge_403 = _local_scalar_dense_323 >= 0 + _assert_scalar_323 = torch.ops.aten._assert_scalar.default(ge_403, "Runtime assertion failed for expression u323 >= 0 on node 'ge_323'"); ge_403 = _assert_scalar_323 = None + select_324 = torch.ops.aten.select.int(device_put_40, 0, 4) + _local_scalar_dense_324 = torch.ops.aten._local_scalar_dense.default(select_324); select_324 = None + ge_404 = _local_scalar_dense_324 >= 0 + _assert_scalar_324 = torch.ops.aten._assert_scalar.default(ge_404, "Runtime assertion failed for expression u324 >= 0 on node 'ge_324'"); ge_404 = _assert_scalar_324 = None + select_325 = torch.ops.aten.select.int(device_put_40, 0, 5) + _local_scalar_dense_325 = torch.ops.aten._local_scalar_dense.default(select_325); select_325 = None + ge_405 = _local_scalar_dense_325 >= 0 + _assert_scalar_325 = torch.ops.aten._assert_scalar.default(ge_405, "Runtime assertion failed for expression u325 >= 0 on node 'ge_325'"); ge_405 = _assert_scalar_325 = None + select_326 = torch.ops.aten.select.int(device_put_40, 0, 6) + _local_scalar_dense_326 = torch.ops.aten._local_scalar_dense.default(select_326); select_326 = None + ge_406 = _local_scalar_dense_326 >= 0 + _assert_scalar_326 = torch.ops.aten._assert_scalar.default(ge_406, "Runtime assertion failed for expression u326 >= 0 on node 'ge_326'"); ge_406 = _assert_scalar_326 = None + select_327 = torch.ops.aten.select.int(device_put_40, 0, 7); device_put_40 = None + _local_scalar_dense_327 = torch.ops.aten._local_scalar_dense.default(select_327); select_327 = None + ge_407 = _local_scalar_dense_327 >= 0 + _assert_scalar_327 = torch.ops.aten._assert_scalar.default(ge_407, "Runtime assertion failed for expression u327 >= 0 on node 'ge_327'"); ge_407 = _assert_scalar_327 = None + select_328 = torch.ops.aten.select.int(device_put_41, 0, 0) + _local_scalar_dense_328 = torch.ops.aten._local_scalar_dense.default(select_328); select_328 = None + ge_408 = _local_scalar_dense_328 >= 0 + _assert_scalar_328 = torch.ops.aten._assert_scalar.default(ge_408, "Runtime assertion failed for expression u328 >= 0 on node 'ge_328'"); ge_408 = _assert_scalar_328 = None + select_329 = torch.ops.aten.select.int(device_put_41, 0, 1) + _local_scalar_dense_329 = torch.ops.aten._local_scalar_dense.default(select_329); select_329 = None + ge_409 = _local_scalar_dense_329 >= 0 + _assert_scalar_329 = torch.ops.aten._assert_scalar.default(ge_409, "Runtime assertion failed for expression u329 >= 0 on node 'ge_329'"); ge_409 = _assert_scalar_329 = None + select_330 = torch.ops.aten.select.int(device_put_41, 0, 2) + _local_scalar_dense_330 = torch.ops.aten._local_scalar_dense.default(select_330); select_330 = None + ge_410 = _local_scalar_dense_330 >= 0 + _assert_scalar_330 = torch.ops.aten._assert_scalar.default(ge_410, "Runtime assertion failed for expression u330 >= 0 on node 'ge_330'"); ge_410 = _assert_scalar_330 = None + select_331 = torch.ops.aten.select.int(device_put_41, 0, 3) + _local_scalar_dense_331 = torch.ops.aten._local_scalar_dense.default(select_331); select_331 = None + ge_411 = _local_scalar_dense_331 >= 0 + _assert_scalar_331 = torch.ops.aten._assert_scalar.default(ge_411, "Runtime assertion failed for expression u331 >= 0 on node 'ge_331'"); ge_411 = _assert_scalar_331 = None + select_332 = torch.ops.aten.select.int(device_put_41, 0, 4) + _local_scalar_dense_332 = torch.ops.aten._local_scalar_dense.default(select_332); select_332 = None + ge_412 = _local_scalar_dense_332 >= 0 + _assert_scalar_332 = torch.ops.aten._assert_scalar.default(ge_412, "Runtime assertion failed for expression u332 >= 0 on node 'ge_332'"); ge_412 = _assert_scalar_332 = None + select_333 = torch.ops.aten.select.int(device_put_41, 0, 5) + _local_scalar_dense_333 = torch.ops.aten._local_scalar_dense.default(select_333); select_333 = None + ge_413 = _local_scalar_dense_333 >= 0 + _assert_scalar_333 = torch.ops.aten._assert_scalar.default(ge_413, "Runtime assertion failed for expression u333 >= 0 on node 'ge_333'"); ge_413 = _assert_scalar_333 = None + select_334 = torch.ops.aten.select.int(device_put_41, 0, 6) + _local_scalar_dense_334 = torch.ops.aten._local_scalar_dense.default(select_334); select_334 = None + ge_414 = _local_scalar_dense_334 >= 0 + _assert_scalar_334 = torch.ops.aten._assert_scalar.default(ge_414, "Runtime assertion failed for expression u334 >= 0 on node 'ge_334'"); ge_414 = _assert_scalar_334 = None + select_335 = torch.ops.aten.select.int(device_put_41, 0, 7); device_put_41 = None + _local_scalar_dense_335 = torch.ops.aten._local_scalar_dense.default(select_335); select_335 = None + ge_415 = _local_scalar_dense_335 >= 0 + _assert_scalar_335 = torch.ops.aten._assert_scalar.default(ge_415, "Runtime assertion failed for expression u335 >= 0 on node 'ge_335'"); ge_415 = _assert_scalar_335 = None + all_to_all_single_61 = torch.ops._c10d_functional.all_to_all_single.default(index_40, [_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335], [_local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327], '521'); index_40 = None + sym_size_int_80 = torch.ops.aten.sym_size.int(all_to_all_single_61, 0) + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_61); all_to_all_single_61 = None + sym_sum_40 = torch.sym_sum((_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335)) + add_1378 = sym_sum_40 + 64; sym_sum_40 = None + add_1379 = add_1378 + 8; add_1378 = None + sub_483 = add_1379 - 1; add_1379 = None + floordiv_20 = sub_483 // 8; sub_483 = None + mul_1002 = floordiv_20 * 8; floordiv_20 = None + cumsum_60 = torch.ops.aten.cumsum.default(wait_tensor_440, 0) + sub_484 = torch.ops.aten.sub.Tensor(cumsum_60, wait_tensor_440); cumsum_60 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_1405, [0]); view_1405 = None + clamp_min_20 = torch.ops.aten.clamp_min.default(sum_84, 8); sum_84 = None + add_1380 = torch.ops.aten.add.Tensor(clamp_min_20, 8); clamp_min_20 = None + sub_485 = torch.ops.aten.sub.Tensor(add_1380, 1); add_1380 = None + div_103 = torch.ops.aten.div.Tensor_mode(sub_485, 8, rounding_mode = 'floor'); sub_485 = None + mul_1003 = torch.ops.aten.mul.Tensor(div_103, 8); div_103 = None + convert_element_type_1148 = torch.ops.prims.convert_element_type.default(mul_1003, torch.int32); mul_1003 = None + cumsum_61 = torch.ops.aten.cumsum.default(convert_element_type_1148, 0) + sub_486 = torch.ops.aten.sub.Tensor(cumsum_61, convert_element_type_1148); cumsum_61 = None + full_280 = torch.ops.aten.full.default([mul_1002], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1002 = None + triton_kernel_wrapper_functional_proxy_20 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 20, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_440, 'start_index_values_ptr': sub_484, 'write_offsets_ptr': sub_486, 'output_ptr': full_280}, tensors_to_clone = ['output_ptr']); wait_tensor_440 = sub_484 = sub_486 = full_280 = None + getitem_302 = triton_kernel_wrapper_functional_proxy_20['output_ptr']; triton_kernel_wrapper_functional_proxy_20 = None + cat_64 = torch.ops.aten.cat.default([wait_tensor_441, full_default]); wait_tensor_441 = None + sym_size_int_81 = torch.ops.aten.sym_size.int(cat_64, 0) + sym_sum_41 = torch.sym_sum((1, _local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335)) + index_41 = torch.ops.aten.index.Tensor(cat_64, [getitem_302]); cat_64 = None + convert_element_type_1150 = torch.ops.prims.convert_element_type.default(primals_353, torch.bfloat16) + all_gather_into_tensor_359 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1150, 8, '513'); convert_element_type_1150 = None + wait_tensor_442 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_359); all_gather_into_tensor_359 = None + convert_element_type_1152 = torch.ops.prims.convert_element_type.default(primals_354, torch.bfloat16) + all_gather_into_tensor_361 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1152, 8, '513'); convert_element_type_1152 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_361); all_gather_into_tensor_361 = None + convert_element_type_1153 = torch.ops.prims.convert_element_type.default(primals_355, torch.bfloat16) + all_gather_into_tensor_362 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1153, 8, '513'); convert_element_type_1153 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_362); all_gather_into_tensor_362 = None + cumsum_62 = torch.ops.aten.cumsum.default(convert_element_type_1148, 0, dtype = torch.int32); convert_element_type_1148 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_442, [0, 2, 1]); wait_tensor_442 = None + _grouped_mm_60 = torch.ops.aten._grouped_mm.default(index_41, permute_320, cumsum_62); permute_320 = None + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(_grouped_mm_60, torch.float32) + neg_41 = torch.ops.aten.neg.default(convert_element_type_1156) + exp_62 = torch.ops.aten.exp.default(neg_41); neg_41 = None + add_1392 = torch.ops.aten.add.Tensor(exp_62, 1); exp_62 = None + div_104 = torch.ops.aten.div.Tensor(convert_element_type_1156, add_1392); convert_element_type_1156 = add_1392 = None + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(div_104, torch.bfloat16); div_104 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_445, [0, 2, 1]); wait_tensor_445 = None + _grouped_mm_61 = torch.ops.aten._grouped_mm.default(index_41, permute_321, cumsum_62); permute_321 = None + mul_1015 = torch.ops.aten.mul.Tensor(convert_element_type_1157, _grouped_mm_61); convert_element_type_1157 = None + permute_322 = torch.ops.aten.permute.default(wait_tensor_444, [0, 2, 1]); wait_tensor_444 = None + _grouped_mm_62 = torch.ops.aten._grouped_mm.default(mul_1015, permute_322, cumsum_62); permute_322 = None + empty_20 = torch.ops.aten.empty.memory_format([sym_size_int_81, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_40 = torch.ops.aten.index_put.default(empty_20, [getitem_302], _grouped_mm_62); empty_20 = _grouped_mm_62 = None + slice_86 = torch.ops.aten.slice.Tensor(index_put_40, 0, 0, -1); index_put_40 = None + all_to_all_single_62 = torch.ops._c10d_functional.all_to_all_single.default(slice_86, [_local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327], [_local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335], '521'); slice_86 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_62); all_to_all_single_62 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(primals_356, torch.bfloat16) + all_gather_into_tensor_365 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1158, 64, '0'); convert_element_type_1158 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_365); all_gather_into_tensor_365 = None + permute_323 = torch.ops.aten.permute.default(wait_tensor_449, [1, 0]); wait_tensor_449 = None + mm_172 = torch.ops.aten.mm.default(view_1398, permute_323); permute_323 = None + convert_element_type_1161 = torch.ops.prims.convert_element_type.default(mm_172, torch.float32) + neg_42 = torch.ops.aten.neg.default(convert_element_type_1161) + exp_63 = torch.ops.aten.exp.default(neg_42); neg_42 = None + add_1428 = torch.ops.aten.add.Tensor(exp_63, 1); exp_63 = None + div_105 = torch.ops.aten.div.Tensor(convert_element_type_1161, add_1428); convert_element_type_1161 = add_1428 = None + convert_element_type_1162 = torch.ops.prims.convert_element_type.default(div_105, torch.bfloat16); div_105 = None + convert_element_type_1163 = torch.ops.prims.convert_element_type.default(primals_357, torch.bfloat16) + all_gather_into_tensor_366 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1163, 64, '0'); convert_element_type_1163 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_366); all_gather_into_tensor_366 = None + permute_324 = torch.ops.aten.permute.default(wait_tensor_450, [1, 0]); wait_tensor_450 = None + mm_173 = torch.ops.aten.mm.default(view_1398, permute_324); permute_324 = None + mul_1035 = torch.ops.aten.mul.Tensor(convert_element_type_1162, mm_173); convert_element_type_1162 = None + convert_element_type_1166 = torch.ops.prims.convert_element_type.default(primals_358, torch.bfloat16) + all_gather_into_tensor_367 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1166, 64, '0'); convert_element_type_1166 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_367); all_gather_into_tensor_367 = None + permute_325 = torch.ops.aten.permute.default(wait_tensor_451, [1, 0]); wait_tensor_451 = None + mm_174 = torch.ops.aten.mm.default(mul_1035, permute_325); permute_325 = None + index_put_41 = torch.ops.aten.index_put.default(full_default_1, [getitem_301], wait_tensor_448); wait_tensor_448 = None + view_1438 = torch.ops.aten.view.default(mul_997, [-1, 1, 6]); mul_997 = None + view_1439 = torch.ops.aten.view.default(index_put_41, [-1, 6, 2048]); index_put_41 = None + convert_element_type_1169 = torch.ops.prims.convert_element_type.default(view_1439, torch.float32); view_1439 = None + bmm_20 = torch.ops.aten.bmm.default(view_1438, convert_element_type_1169) + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(bmm_20, torch.bfloat16); bmm_20 = None + squeeze_20 = torch.ops.aten.squeeze.dim(convert_element_type_1170, 1); convert_element_type_1170 = None + add_1432 = torch.ops.aten.add.Tensor(mm_174, squeeze_20); mm_174 = squeeze_20 = None + view_1440 = torch.ops.aten.view.default(add_1432, [2, 4096, 2048]); add_1432 = None + add_1433 = torch.ops.aten.add.Tensor(add_1368, view_1440); view_1440 = None + convert_element_type_1171 = torch.ops.prims.convert_element_type.default(primals_359, torch.bfloat16) + all_gather_into_tensor_368 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1171, 64, '0'); convert_element_type_1171 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_368); all_gather_into_tensor_368 = None + convert_element_type_1172 = torch.ops.prims.convert_element_type.default(add_1433, torch.float32) + pow_67 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1172, 2) + mean_66 = torch.ops.aten.mean.dim(pow_67, [2], True); pow_67 = None + add_1434 = torch.ops.aten.add.Scalar(mean_66, 1e-05); mean_66 = None + rsqrt_66 = torch.ops.aten.rsqrt.default(add_1434); add_1434 = None + mul_1038 = torch.ops.aten.mul.Tensor(convert_element_type_1172, rsqrt_66); convert_element_type_1172 = None + mul_1039 = torch.ops.aten.mul.Tensor(mul_1038, wait_tensor_452); mul_1038 = wait_tensor_452 = None + convert_element_type_1173 = torch.ops.prims.convert_element_type.default(mul_1039, torch.bfloat16); mul_1039 = None + convert_element_type_1174 = torch.ops.prims.convert_element_type.default(primals_360, torch.bfloat16) + all_gather_into_tensor_369 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1174, 64, '0'); convert_element_type_1174 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_369); all_gather_into_tensor_369 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_453, [1, 0]); wait_tensor_453 = None + view_1443 = torch.ops.aten.view.default(convert_element_type_1173, [8192, 2048]); convert_element_type_1173 = None + mm_175 = torch.ops.aten.mm.default(view_1443, permute_326); permute_326 = None + view_1444 = torch.ops.aten.view.default(mm_175, [2, 4096, 3072]); mm_175 = None + view_1445 = torch.ops.aten.view.default(view_1444, [2, 4096, -1, 192]); view_1444 = None + split_with_sizes_66 = torch.ops.aten.split_with_sizes.default(view_1445, [128, 64], -1); view_1445 = None + getitem_303 = split_with_sizes_66[0] + getitem_304 = split_with_sizes_66[1]; split_with_sizes_66 = None + convert_element_type_1177 = torch.ops.prims.convert_element_type.default(getitem_304, torch.float32); getitem_304 = None + view_1446 = torch.ops.aten.view.default(convert_element_type_1177, [2, 4096, 16, -1, 2]); convert_element_type_1177 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1446); view_1446 = None + mul_1040 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_7); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_1040); mul_1040 = None + view_1448 = torch.ops.aten.view.default(view_as_real_44, [2, 4096, 16, 64]); view_as_real_44 = None + convert_element_type_1178 = torch.ops.prims.convert_element_type.default(view_1448, torch.bfloat16); view_1448 = None + cat_65 = torch.ops.aten.cat.default([getitem_303, convert_element_type_1178], -1); getitem_303 = convert_element_type_1178 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(primals_361, torch.bfloat16) + all_gather_into_tensor_370 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1179, 64, '0'); convert_element_type_1179 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_370); all_gather_into_tensor_370 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_454, [1, 0]); wait_tensor_454 = None + mm_176 = torch.ops.aten.mm.default(view_1443, permute_327); permute_327 = None + view_1451 = torch.ops.aten.view.default(mm_176, [2, 4096, 576]); mm_176 = None + split_with_sizes_67 = torch.ops.aten.split_with_sizes.default(view_1451, [512, 64], -1); view_1451 = None + getitem_305 = split_with_sizes_67[0] + getitem_306 = split_with_sizes_67[1]; split_with_sizes_67 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(getitem_306, 2); getitem_306 = None + convert_element_type_1182 = torch.ops.prims.convert_element_type.default(unsqueeze_43, torch.float32); unsqueeze_43 = None + view_1452 = torch.ops.aten.view.default(convert_element_type_1182, [2, 4096, 1, -1, 2]); convert_element_type_1182 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1452); view_1452 = None + mul_1041 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_7); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_1041); mul_1041 = None + view_1454 = torch.ops.aten.view.default(view_as_real_45, [2, 4096, 1, 64]); view_as_real_45 = None + convert_element_type_1183 = torch.ops.prims.convert_element_type.default(view_1454, torch.bfloat16); view_1454 = None + convert_element_type_1184 = torch.ops.prims.convert_element_type.default(primals_362, torch.bfloat16) + all_gather_into_tensor_371 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1184, 64, '0'); convert_element_type_1184 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_371); all_gather_into_tensor_371 = None + convert_element_type_1185 = torch.ops.prims.convert_element_type.default(getitem_305, torch.float32) + pow_68 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1185, 2) + mean_67 = torch.ops.aten.mean.dim(pow_68, [2], True); pow_68 = None + add_1435 = torch.ops.aten.add.Scalar(mean_67, 1e-05); mean_67 = None + rsqrt_67 = torch.ops.aten.rsqrt.default(add_1435); add_1435 = None + mul_1042 = torch.ops.aten.mul.Tensor(convert_element_type_1185, rsqrt_67); convert_element_type_1185 = None + mul_1043 = torch.ops.aten.mul.Tensor(mul_1042, wait_tensor_455); mul_1042 = wait_tensor_455 = None + convert_element_type_1186 = torch.ops.prims.convert_element_type.default(mul_1043, torch.bfloat16); mul_1043 = None + convert_element_type_1187 = torch.ops.prims.convert_element_type.default(primals_363, torch.bfloat16) + all_gather_into_tensor_372 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1187, 64, '0'); convert_element_type_1187 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_372); all_gather_into_tensor_372 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_456, [1, 0]); wait_tensor_456 = None + view_1457 = torch.ops.aten.view.default(convert_element_type_1186, [8192, 512]); convert_element_type_1186 = None + mm_177 = torch.ops.aten.mm.default(view_1457, permute_328); permute_328 = None + view_1458 = torch.ops.aten.view.default(mm_177, [2, 4096, 4096]); mm_177 = None + view_1459 = torch.ops.aten.view.default(view_1458, [2, 4096, -1, 256]); view_1458 = None + split_with_sizes_68 = torch.ops.aten.split_with_sizes.default(view_1459, [128, 128], -1); view_1459 = None + getitem_307 = split_with_sizes_68[0] + getitem_308 = split_with_sizes_68[1]; split_with_sizes_68 = None + expand_22 = torch.ops.aten.expand.default(convert_element_type_1183, [-1, -1, 16, -1]); convert_element_type_1183 = None + cat_66 = torch.ops.aten.cat.default([getitem_307, expand_22], -1); getitem_307 = expand_22 = None + permute_329 = torch.ops.aten.permute.default(cat_65, [0, 2, 1, 3]); cat_65 = None + permute_330 = torch.ops.aten.permute.default(cat_66, [0, 2, 1, 3]); cat_66 = None + permute_331 = torch.ops.aten.permute.default(getitem_308, [0, 2, 1, 3]); getitem_308 = None + sdpa_score22 = self.sdpa_score22 + sdpa_mask22 = self.sdpa_mask22 + flex_attention_22 = torch.ops.higher_order.flex_attention(permute_329, permute_330, permute_331, sdpa_score22, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask22), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score22 = sdpa_mask22 = None + getitem_309 = flex_attention_22[0] + getitem_310 = flex_attention_22[1]; flex_attention_22 = None + permute_332 = torch.ops.aten.permute.default(getitem_309, [0, 2, 1, 3]) + view_1460 = torch.ops.aten.view.default(permute_332, [2, 4096, -1]); permute_332 = None + convert_element_type_1190 = torch.ops.prims.convert_element_type.default(primals_364, torch.bfloat16) + all_gather_into_tensor_373 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1190, 64, '0'); convert_element_type_1190 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_373); all_gather_into_tensor_373 = None + permute_333 = torch.ops.aten.permute.default(wait_tensor_457, [1, 0]); wait_tensor_457 = None + view_1462 = torch.ops.aten.view.default(view_1460, [8192, 2048]); view_1460 = None + mm_178 = torch.ops.aten.mm.default(view_1462, permute_333); view_1462 = permute_333 = None + view_1463 = torch.ops.aten.view.default(mm_178, [2, 4096, 2048]); mm_178 = None + add_1436 = torch.ops.aten.add.Tensor(add_1433, view_1463); view_1463 = None + convert_element_type_1193 = torch.ops.prims.convert_element_type.default(primals_365, torch.bfloat16) + all_gather_into_tensor_374 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1193, 64, '0'); convert_element_type_1193 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_374); all_gather_into_tensor_374 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(add_1436, torch.float32) + pow_69 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1194, 2) + mean_68 = torch.ops.aten.mean.dim(pow_69, [2], True); pow_69 = None + add_1437 = torch.ops.aten.add.Scalar(mean_68, 1e-05); mean_68 = None + rsqrt_68 = torch.ops.aten.rsqrt.default(add_1437); add_1437 = None + mul_1044 = torch.ops.aten.mul.Tensor(convert_element_type_1194, rsqrt_68); convert_element_type_1194 = None + mul_1045 = torch.ops.aten.mul.Tensor(mul_1044, wait_tensor_458); mul_1044 = wait_tensor_458 = None + convert_element_type_1195 = torch.ops.prims.convert_element_type.default(mul_1045, torch.bfloat16); mul_1045 = None + view_1465 = torch.ops.aten.view.default(convert_element_type_1195, [-1, 2048]); convert_element_type_1195 = None + convert_element_type_1196 = torch.ops.prims.convert_element_type.default(primals_367, torch.bfloat16) + all_gather_into_tensor_375 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1196, 64, '0'); convert_element_type_1196 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_375); all_gather_into_tensor_375 = None + permute_334 = torch.ops.aten.permute.default(wait_tensor_459, [1, 0]); wait_tensor_459 = None + mm_179 = torch.ops.aten.mm.default(view_1465, permute_334); permute_334 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_179, torch.float32) + amax_21 = torch.ops.aten.amax.default(convert_element_type_1199, [1], True) + sub_504 = torch.ops.aten.sub.Tensor(convert_element_type_1199, amax_21); convert_element_type_1199 = None + exp_64 = torch.ops.aten.exp.default(sub_504); sub_504 = None + sum_85 = torch.ops.aten.sum.dim_IntList(exp_64, [1], True) + div_106 = torch.ops.aten.div.Tensor(exp_64, sum_85); exp_64 = None + add_1438 = torch.ops.aten.add.Tensor(div_106, primals_366); primals_366 = None + topk_21 = torch.ops.aten.topk.default(add_1438, 6, -1, True, False); add_1438 = None + getitem_313 = topk_21[1]; topk_21 = None + gather_21 = torch.ops.aten.gather.default(div_106, 1, getitem_313); div_106 = None + mul_1046 = torch.ops.aten.mul.Tensor(gather_21, 1.0); gather_21 = None + view_1467 = torch.ops.aten.view.default(getitem_313, [-1]) + histc_42 = torch.ops.aten.histc.default(view_1467, 64, 0, 64) + add_1439 = torch.ops.aten.add.Tensor(primals_368, histc_42) + sort_21 = torch.ops.aten.sort.stable(view_1467, stable = True); view_1467 = None + getitem_315 = sort_21[1]; sort_21 = None + div_107 = torch.ops.aten.div.Tensor_mode(getitem_315, 6, rounding_mode = 'floor') + index_42 = torch.ops.aten.index.Tensor(view_1465, [div_107]) + all_to_all_single_63 = torch.ops._c10d_functional.all_to_all_single.default(histc_42, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_460 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_63); all_to_all_single_63 = None + wait_tensor_461 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_460); wait_tensor_460 = None + view_1471 = torch.ops.aten.view.default(histc_42, [8, -1]); histc_42 = None + sum_86 = torch.ops.aten.sum.dim_IntList(view_1471, [1]); view_1471 = None + device_put_42 = torch.ops.prims.device_put.default(sum_86, device(type='cpu'), True); sum_86 = None + view_1472 = torch.ops.aten.view.default(wait_tensor_461, [8, -1]) + sum_87 = torch.ops.aten.sum.dim_IntList(view_1472, [1]) + device_put_43 = torch.ops.prims.device_put.default(sum_87, device(type='cpu')); sum_87 = None + select_336 = torch.ops.aten.select.int(device_put_42, 0, 0) + _local_scalar_dense_336 = torch.ops.aten._local_scalar_dense.default(select_336); select_336 = None + ge_420 = _local_scalar_dense_336 >= 0 + _assert_scalar_336 = torch.ops.aten._assert_scalar.default(ge_420, "Runtime assertion failed for expression u336 >= 0 on node 'ge_336'"); ge_420 = _assert_scalar_336 = None + select_337 = torch.ops.aten.select.int(device_put_42, 0, 1) + _local_scalar_dense_337 = torch.ops.aten._local_scalar_dense.default(select_337); select_337 = None + ge_421 = _local_scalar_dense_337 >= 0 + _assert_scalar_337 = torch.ops.aten._assert_scalar.default(ge_421, "Runtime assertion failed for expression u337 >= 0 on node 'ge_337'"); ge_421 = _assert_scalar_337 = None + select_338 = torch.ops.aten.select.int(device_put_42, 0, 2) + _local_scalar_dense_338 = torch.ops.aten._local_scalar_dense.default(select_338); select_338 = None + ge_422 = _local_scalar_dense_338 >= 0 + _assert_scalar_338 = torch.ops.aten._assert_scalar.default(ge_422, "Runtime assertion failed for expression u338 >= 0 on node 'ge_338'"); ge_422 = _assert_scalar_338 = None + select_339 = torch.ops.aten.select.int(device_put_42, 0, 3) + _local_scalar_dense_339 = torch.ops.aten._local_scalar_dense.default(select_339); select_339 = None + ge_423 = _local_scalar_dense_339 >= 0 + _assert_scalar_339 = torch.ops.aten._assert_scalar.default(ge_423, "Runtime assertion failed for expression u339 >= 0 on node 'ge_339'"); ge_423 = _assert_scalar_339 = None + select_340 = torch.ops.aten.select.int(device_put_42, 0, 4) + _local_scalar_dense_340 = torch.ops.aten._local_scalar_dense.default(select_340); select_340 = None + ge_424 = _local_scalar_dense_340 >= 0 + _assert_scalar_340 = torch.ops.aten._assert_scalar.default(ge_424, "Runtime assertion failed for expression u340 >= 0 on node 'ge_340'"); ge_424 = _assert_scalar_340 = None + select_341 = torch.ops.aten.select.int(device_put_42, 0, 5) + _local_scalar_dense_341 = torch.ops.aten._local_scalar_dense.default(select_341); select_341 = None + ge_425 = _local_scalar_dense_341 >= 0 + _assert_scalar_341 = torch.ops.aten._assert_scalar.default(ge_425, "Runtime assertion failed for expression u341 >= 0 on node 'ge_341'"); ge_425 = _assert_scalar_341 = None + select_342 = torch.ops.aten.select.int(device_put_42, 0, 6) + _local_scalar_dense_342 = torch.ops.aten._local_scalar_dense.default(select_342); select_342 = None + ge_426 = _local_scalar_dense_342 >= 0 + _assert_scalar_342 = torch.ops.aten._assert_scalar.default(ge_426, "Runtime assertion failed for expression u342 >= 0 on node 'ge_342'"); ge_426 = _assert_scalar_342 = None + select_343 = torch.ops.aten.select.int(device_put_42, 0, 7); device_put_42 = None + _local_scalar_dense_343 = torch.ops.aten._local_scalar_dense.default(select_343); select_343 = None + ge_427 = _local_scalar_dense_343 >= 0 + _assert_scalar_343 = torch.ops.aten._assert_scalar.default(ge_427, "Runtime assertion failed for expression u343 >= 0 on node 'ge_343'"); ge_427 = _assert_scalar_343 = None + select_344 = torch.ops.aten.select.int(device_put_43, 0, 0) + _local_scalar_dense_344 = torch.ops.aten._local_scalar_dense.default(select_344); select_344 = None + ge_428 = _local_scalar_dense_344 >= 0 + _assert_scalar_344 = torch.ops.aten._assert_scalar.default(ge_428, "Runtime assertion failed for expression u344 >= 0 on node 'ge_344'"); ge_428 = _assert_scalar_344 = None + select_345 = torch.ops.aten.select.int(device_put_43, 0, 1) + _local_scalar_dense_345 = torch.ops.aten._local_scalar_dense.default(select_345); select_345 = None + ge_429 = _local_scalar_dense_345 >= 0 + _assert_scalar_345 = torch.ops.aten._assert_scalar.default(ge_429, "Runtime assertion failed for expression u345 >= 0 on node 'ge_345'"); ge_429 = _assert_scalar_345 = None + select_346 = torch.ops.aten.select.int(device_put_43, 0, 2) + _local_scalar_dense_346 = torch.ops.aten._local_scalar_dense.default(select_346); select_346 = None + ge_430 = _local_scalar_dense_346 >= 0 + _assert_scalar_346 = torch.ops.aten._assert_scalar.default(ge_430, "Runtime assertion failed for expression u346 >= 0 on node 'ge_346'"); ge_430 = _assert_scalar_346 = None + select_347 = torch.ops.aten.select.int(device_put_43, 0, 3) + _local_scalar_dense_347 = torch.ops.aten._local_scalar_dense.default(select_347); select_347 = None + ge_431 = _local_scalar_dense_347 >= 0 + _assert_scalar_347 = torch.ops.aten._assert_scalar.default(ge_431, "Runtime assertion failed for expression u347 >= 0 on node 'ge_347'"); ge_431 = _assert_scalar_347 = None + select_348 = torch.ops.aten.select.int(device_put_43, 0, 4) + _local_scalar_dense_348 = torch.ops.aten._local_scalar_dense.default(select_348); select_348 = None + ge_432 = _local_scalar_dense_348 >= 0 + _assert_scalar_348 = torch.ops.aten._assert_scalar.default(ge_432, "Runtime assertion failed for expression u348 >= 0 on node 'ge_348'"); ge_432 = _assert_scalar_348 = None + select_349 = torch.ops.aten.select.int(device_put_43, 0, 5) + _local_scalar_dense_349 = torch.ops.aten._local_scalar_dense.default(select_349); select_349 = None + ge_433 = _local_scalar_dense_349 >= 0 + _assert_scalar_349 = torch.ops.aten._assert_scalar.default(ge_433, "Runtime assertion failed for expression u349 >= 0 on node 'ge_349'"); ge_433 = _assert_scalar_349 = None + select_350 = torch.ops.aten.select.int(device_put_43, 0, 6) + _local_scalar_dense_350 = torch.ops.aten._local_scalar_dense.default(select_350); select_350 = None + ge_434 = _local_scalar_dense_350 >= 0 + _assert_scalar_350 = torch.ops.aten._assert_scalar.default(ge_434, "Runtime assertion failed for expression u350 >= 0 on node 'ge_350'"); ge_434 = _assert_scalar_350 = None + select_351 = torch.ops.aten.select.int(device_put_43, 0, 7); device_put_43 = None + _local_scalar_dense_351 = torch.ops.aten._local_scalar_dense.default(select_351); select_351 = None + ge_435 = _local_scalar_dense_351 >= 0 + _assert_scalar_351 = torch.ops.aten._assert_scalar.default(ge_435, "Runtime assertion failed for expression u351 >= 0 on node 'ge_351'"); ge_435 = _assert_scalar_351 = None + all_to_all_single_64 = torch.ops._c10d_functional.all_to_all_single.default(index_42, [_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351], [_local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343], '521'); index_42 = None + sym_size_int_84 = torch.ops.aten.sym_size.int(all_to_all_single_64, 0) + wait_tensor_462 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_64); all_to_all_single_64 = None + sym_sum_42 = torch.sym_sum((_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351)) + add_1446 = sym_sum_42 + 64; sym_sum_42 = None + add_1447 = add_1446 + 8; add_1446 = None + sub_507 = add_1447 - 1; add_1447 = None + floordiv_21 = sub_507 // 8; sub_507 = None + mul_1051 = floordiv_21 * 8; floordiv_21 = None + cumsum_63 = torch.ops.aten.cumsum.default(wait_tensor_461, 0) + sub_508 = torch.ops.aten.sub.Tensor(cumsum_63, wait_tensor_461); cumsum_63 = None + sum_88 = torch.ops.aten.sum.dim_IntList(view_1472, [0]); view_1472 = None + clamp_min_21 = torch.ops.aten.clamp_min.default(sum_88, 8); sum_88 = None + add_1448 = torch.ops.aten.add.Tensor(clamp_min_21, 8); clamp_min_21 = None + sub_509 = torch.ops.aten.sub.Tensor(add_1448, 1); add_1448 = None + div_108 = torch.ops.aten.div.Tensor_mode(sub_509, 8, rounding_mode = 'floor'); sub_509 = None + mul_1052 = torch.ops.aten.mul.Tensor(div_108, 8); div_108 = None + convert_element_type_1202 = torch.ops.prims.convert_element_type.default(mul_1052, torch.int32); mul_1052 = None + cumsum_64 = torch.ops.aten.cumsum.default(convert_element_type_1202, 0) + sub_510 = torch.ops.aten.sub.Tensor(cumsum_64, convert_element_type_1202); cumsum_64 = None + full_293 = torch.ops.aten.full.default([mul_1051], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1051 = None + triton_kernel_wrapper_functional_proxy_21 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 21, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_461, 'start_index_values_ptr': sub_508, 'write_offsets_ptr': sub_510, 'output_ptr': full_293}, tensors_to_clone = ['output_ptr']); wait_tensor_461 = sub_508 = sub_510 = full_293 = None + getitem_316 = triton_kernel_wrapper_functional_proxy_21['output_ptr']; triton_kernel_wrapper_functional_proxy_21 = None + cat_67 = torch.ops.aten.cat.default([wait_tensor_462, full_default]); wait_tensor_462 = None + sym_size_int_85 = torch.ops.aten.sym_size.int(cat_67, 0) + sym_sum_43 = torch.sym_sum((1, _local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351)) + index_43 = torch.ops.aten.index.Tensor(cat_67, [getitem_316]); cat_67 = None + convert_element_type_1204 = torch.ops.prims.convert_element_type.default(primals_369, torch.bfloat16) + all_gather_into_tensor_376 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1204, 8, '513'); convert_element_type_1204 = None + wait_tensor_463 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_376); all_gather_into_tensor_376 = None + convert_element_type_1206 = torch.ops.prims.convert_element_type.default(primals_370, torch.bfloat16) + all_gather_into_tensor_378 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1206, 8, '513'); convert_element_type_1206 = None + wait_tensor_465 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_378); all_gather_into_tensor_378 = None + convert_element_type_1207 = torch.ops.prims.convert_element_type.default(primals_371, torch.bfloat16) + all_gather_into_tensor_379 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1207, 8, '513'); convert_element_type_1207 = None + wait_tensor_466 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_379); all_gather_into_tensor_379 = None + cumsum_65 = torch.ops.aten.cumsum.default(convert_element_type_1202, 0, dtype = torch.int32); convert_element_type_1202 = None + permute_335 = torch.ops.aten.permute.default(wait_tensor_463, [0, 2, 1]); wait_tensor_463 = None + _grouped_mm_63 = torch.ops.aten._grouped_mm.default(index_43, permute_335, cumsum_65); permute_335 = None + convert_element_type_1210 = torch.ops.prims.convert_element_type.default(_grouped_mm_63, torch.float32) + neg_43 = torch.ops.aten.neg.default(convert_element_type_1210) + exp_65 = torch.ops.aten.exp.default(neg_43); neg_43 = None + add_1460 = torch.ops.aten.add.Tensor(exp_65, 1); exp_65 = None + div_109 = torch.ops.aten.div.Tensor(convert_element_type_1210, add_1460); convert_element_type_1210 = add_1460 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(div_109, torch.bfloat16); div_109 = None + permute_336 = torch.ops.aten.permute.default(wait_tensor_466, [0, 2, 1]); wait_tensor_466 = None + _grouped_mm_64 = torch.ops.aten._grouped_mm.default(index_43, permute_336, cumsum_65); permute_336 = None + mul_1064 = torch.ops.aten.mul.Tensor(convert_element_type_1211, _grouped_mm_64); convert_element_type_1211 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_465, [0, 2, 1]); wait_tensor_465 = None + _grouped_mm_65 = torch.ops.aten._grouped_mm.default(mul_1064, permute_337, cumsum_65); permute_337 = None + empty_21 = torch.ops.aten.empty.memory_format([sym_size_int_85, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_42 = torch.ops.aten.index_put.default(empty_21, [getitem_316], _grouped_mm_65); empty_21 = _grouped_mm_65 = None + slice_90 = torch.ops.aten.slice.Tensor(index_put_42, 0, 0, -1); index_put_42 = None + all_to_all_single_65 = torch.ops._c10d_functional.all_to_all_single.default(slice_90, [_local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343], [_local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351], '521'); slice_90 = None + wait_tensor_469 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_65); all_to_all_single_65 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(primals_372, torch.bfloat16) + all_gather_into_tensor_382 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1212, 64, '0'); convert_element_type_1212 = None + wait_tensor_470 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_382); all_gather_into_tensor_382 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_470, [1, 0]); wait_tensor_470 = None + mm_180 = torch.ops.aten.mm.default(view_1465, permute_338); permute_338 = None + convert_element_type_1215 = torch.ops.prims.convert_element_type.default(mm_180, torch.float32) + neg_44 = torch.ops.aten.neg.default(convert_element_type_1215) + exp_66 = torch.ops.aten.exp.default(neg_44); neg_44 = None + add_1496 = torch.ops.aten.add.Tensor(exp_66, 1); exp_66 = None + div_110 = torch.ops.aten.div.Tensor(convert_element_type_1215, add_1496); convert_element_type_1215 = add_1496 = None + convert_element_type_1216 = torch.ops.prims.convert_element_type.default(div_110, torch.bfloat16); div_110 = None + convert_element_type_1217 = torch.ops.prims.convert_element_type.default(primals_373, torch.bfloat16) + all_gather_into_tensor_383 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1217, 64, '0'); convert_element_type_1217 = None + wait_tensor_471 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_383); all_gather_into_tensor_383 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_471, [1, 0]); wait_tensor_471 = None + mm_181 = torch.ops.aten.mm.default(view_1465, permute_339); permute_339 = None + mul_1084 = torch.ops.aten.mul.Tensor(convert_element_type_1216, mm_181); convert_element_type_1216 = None + convert_element_type_1220 = torch.ops.prims.convert_element_type.default(primals_374, torch.bfloat16) + all_gather_into_tensor_384 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1220, 64, '0'); convert_element_type_1220 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_384); all_gather_into_tensor_384 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_472, [1, 0]); wait_tensor_472 = None + mm_182 = torch.ops.aten.mm.default(mul_1084, permute_340); permute_340 = None + index_put_43 = torch.ops.aten.index_put.default(full_default_1, [getitem_315], wait_tensor_469); wait_tensor_469 = None + view_1505 = torch.ops.aten.view.default(mul_1046, [-1, 1, 6]); mul_1046 = None + view_1506 = torch.ops.aten.view.default(index_put_43, [-1, 6, 2048]); index_put_43 = None + convert_element_type_1223 = torch.ops.prims.convert_element_type.default(view_1506, torch.float32); view_1506 = None + bmm_21 = torch.ops.aten.bmm.default(view_1505, convert_element_type_1223) + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(bmm_21, torch.bfloat16); bmm_21 = None + squeeze_21 = torch.ops.aten.squeeze.dim(convert_element_type_1224, 1); convert_element_type_1224 = None + add_1500 = torch.ops.aten.add.Tensor(mm_182, squeeze_21); mm_182 = squeeze_21 = None + view_1507 = torch.ops.aten.view.default(add_1500, [2, 4096, 2048]); add_1500 = None + add_1501 = torch.ops.aten.add.Tensor(add_1436, view_1507); view_1507 = None + convert_element_type_1225 = torch.ops.prims.convert_element_type.default(primals_375, torch.bfloat16) + all_gather_into_tensor_385 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1225, 64, '0'); convert_element_type_1225 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_385); all_gather_into_tensor_385 = None + convert_element_type_1226 = torch.ops.prims.convert_element_type.default(add_1501, torch.float32) + pow_70 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1226, 2) + mean_69 = torch.ops.aten.mean.dim(pow_70, [2], True); pow_70 = None + add_1502 = torch.ops.aten.add.Scalar(mean_69, 1e-05); mean_69 = None + rsqrt_69 = torch.ops.aten.rsqrt.default(add_1502); add_1502 = None + mul_1087 = torch.ops.aten.mul.Tensor(convert_element_type_1226, rsqrt_69); convert_element_type_1226 = None + mul_1088 = torch.ops.aten.mul.Tensor(mul_1087, wait_tensor_473); mul_1087 = wait_tensor_473 = None + convert_element_type_1227 = torch.ops.prims.convert_element_type.default(mul_1088, torch.bfloat16); mul_1088 = None + convert_element_type_1228 = torch.ops.prims.convert_element_type.default(primals_376, torch.bfloat16) + all_gather_into_tensor_386 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1228, 64, '0'); convert_element_type_1228 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_386); all_gather_into_tensor_386 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_474, [1, 0]); wait_tensor_474 = None + view_1510 = torch.ops.aten.view.default(convert_element_type_1227, [8192, 2048]); convert_element_type_1227 = None + mm_183 = torch.ops.aten.mm.default(view_1510, permute_341); permute_341 = None + view_1511 = torch.ops.aten.view.default(mm_183, [2, 4096, 3072]); mm_183 = None + view_1512 = torch.ops.aten.view.default(view_1511, [2, 4096, -1, 192]); view_1511 = None + split_with_sizes_69 = torch.ops.aten.split_with_sizes.default(view_1512, [128, 64], -1); view_1512 = None + getitem_317 = split_with_sizes_69[0] + getitem_318 = split_with_sizes_69[1]; split_with_sizes_69 = None + convert_element_type_1231 = torch.ops.prims.convert_element_type.default(getitem_318, torch.float32); getitem_318 = None + view_1513 = torch.ops.aten.view.default(convert_element_type_1231, [2, 4096, 16, -1, 2]); convert_element_type_1231 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1513); view_1513 = None + mul_1089 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_7); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_1089); mul_1089 = None + view_1515 = torch.ops.aten.view.default(view_as_real_46, [2, 4096, 16, 64]); view_as_real_46 = None + convert_element_type_1232 = torch.ops.prims.convert_element_type.default(view_1515, torch.bfloat16); view_1515 = None + cat_68 = torch.ops.aten.cat.default([getitem_317, convert_element_type_1232], -1); getitem_317 = convert_element_type_1232 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(primals_377, torch.bfloat16) + all_gather_into_tensor_387 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1233, 64, '0'); convert_element_type_1233 = None + wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_387); all_gather_into_tensor_387 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_475, [1, 0]); wait_tensor_475 = None + mm_184 = torch.ops.aten.mm.default(view_1510, permute_342); permute_342 = None + view_1518 = torch.ops.aten.view.default(mm_184, [2, 4096, 576]); mm_184 = None + split_with_sizes_70 = torch.ops.aten.split_with_sizes.default(view_1518, [512, 64], -1); view_1518 = None + getitem_319 = split_with_sizes_70[0] + getitem_320 = split_with_sizes_70[1]; split_with_sizes_70 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(getitem_320, 2); getitem_320 = None + convert_element_type_1236 = torch.ops.prims.convert_element_type.default(unsqueeze_45, torch.float32); unsqueeze_45 = None + view_1519 = torch.ops.aten.view.default(convert_element_type_1236, [2, 4096, 1, -1, 2]); convert_element_type_1236 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1519); view_1519 = None + mul_1090 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_7); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_1090); mul_1090 = None + view_1521 = torch.ops.aten.view.default(view_as_real_47, [2, 4096, 1, 64]); view_as_real_47 = None + convert_element_type_1237 = torch.ops.prims.convert_element_type.default(view_1521, torch.bfloat16); view_1521 = None + convert_element_type_1238 = torch.ops.prims.convert_element_type.default(primals_378, torch.bfloat16) + all_gather_into_tensor_388 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1238, 64, '0'); convert_element_type_1238 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_388); all_gather_into_tensor_388 = None + convert_element_type_1239 = torch.ops.prims.convert_element_type.default(getitem_319, torch.float32) + pow_71 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1239, 2) + mean_70 = torch.ops.aten.mean.dim(pow_71, [2], True); pow_71 = None + add_1503 = torch.ops.aten.add.Scalar(mean_70, 1e-05); mean_70 = None + rsqrt_70 = torch.ops.aten.rsqrt.default(add_1503); add_1503 = None + mul_1091 = torch.ops.aten.mul.Tensor(convert_element_type_1239, rsqrt_70); convert_element_type_1239 = None + mul_1092 = torch.ops.aten.mul.Tensor(mul_1091, wait_tensor_476); mul_1091 = wait_tensor_476 = None + convert_element_type_1240 = torch.ops.prims.convert_element_type.default(mul_1092, torch.bfloat16); mul_1092 = None + convert_element_type_1241 = torch.ops.prims.convert_element_type.default(primals_379, torch.bfloat16) + all_gather_into_tensor_389 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1241, 64, '0'); convert_element_type_1241 = None + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_389); all_gather_into_tensor_389 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_477, [1, 0]); wait_tensor_477 = None + view_1524 = torch.ops.aten.view.default(convert_element_type_1240, [8192, 512]); convert_element_type_1240 = None + mm_185 = torch.ops.aten.mm.default(view_1524, permute_343); permute_343 = None + view_1525 = torch.ops.aten.view.default(mm_185, [2, 4096, 4096]); mm_185 = None + view_1526 = torch.ops.aten.view.default(view_1525, [2, 4096, -1, 256]); view_1525 = None + split_with_sizes_71 = torch.ops.aten.split_with_sizes.default(view_1526, [128, 128], -1); view_1526 = None + getitem_321 = split_with_sizes_71[0] + getitem_322 = split_with_sizes_71[1]; split_with_sizes_71 = None + expand_23 = torch.ops.aten.expand.default(convert_element_type_1237, [-1, -1, 16, -1]); convert_element_type_1237 = None + cat_69 = torch.ops.aten.cat.default([getitem_321, expand_23], -1); getitem_321 = expand_23 = None + permute_344 = torch.ops.aten.permute.default(cat_68, [0, 2, 1, 3]); cat_68 = None + permute_345 = torch.ops.aten.permute.default(cat_69, [0, 2, 1, 3]); cat_69 = None + permute_346 = torch.ops.aten.permute.default(getitem_322, [0, 2, 1, 3]); getitem_322 = None + sdpa_score23 = self.sdpa_score23 + sdpa_mask23 = self.sdpa_mask23 + flex_attention_23 = torch.ops.higher_order.flex_attention(permute_344, permute_345, permute_346, sdpa_score23, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask23), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score23 = sdpa_mask23 = None + getitem_323 = flex_attention_23[0] + getitem_324 = flex_attention_23[1]; flex_attention_23 = None + permute_347 = torch.ops.aten.permute.default(getitem_323, [0, 2, 1, 3]) + view_1527 = torch.ops.aten.view.default(permute_347, [2, 4096, -1]); permute_347 = None + convert_element_type_1244 = torch.ops.prims.convert_element_type.default(primals_380, torch.bfloat16) + all_gather_into_tensor_390 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1244, 64, '0'); convert_element_type_1244 = None + wait_tensor_478 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_390); all_gather_into_tensor_390 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_478, [1, 0]); wait_tensor_478 = None + view_1529 = torch.ops.aten.view.default(view_1527, [8192, 2048]); view_1527 = None + mm_186 = torch.ops.aten.mm.default(view_1529, permute_348); view_1529 = permute_348 = None + view_1530 = torch.ops.aten.view.default(mm_186, [2, 4096, 2048]); mm_186 = None + add_1504 = torch.ops.aten.add.Tensor(add_1501, view_1530); view_1530 = None + convert_element_type_1247 = torch.ops.prims.convert_element_type.default(primals_381, torch.bfloat16) + all_gather_into_tensor_391 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1247, 64, '0'); convert_element_type_1247 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_391); all_gather_into_tensor_391 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(add_1504, torch.float32) + pow_72 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1248, 2) + mean_71 = torch.ops.aten.mean.dim(pow_72, [2], True); pow_72 = None + add_1505 = torch.ops.aten.add.Scalar(mean_71, 1e-05); mean_71 = None + rsqrt_71 = torch.ops.aten.rsqrt.default(add_1505); add_1505 = None + mul_1093 = torch.ops.aten.mul.Tensor(convert_element_type_1248, rsqrt_71); convert_element_type_1248 = None + mul_1094 = torch.ops.aten.mul.Tensor(mul_1093, wait_tensor_479); mul_1093 = wait_tensor_479 = None + convert_element_type_1249 = torch.ops.prims.convert_element_type.default(mul_1094, torch.bfloat16); mul_1094 = None + view_1532 = torch.ops.aten.view.default(convert_element_type_1249, [-1, 2048]); convert_element_type_1249 = None + convert_element_type_1250 = torch.ops.prims.convert_element_type.default(primals_383, torch.bfloat16) + all_gather_into_tensor_392 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1250, 64, '0'); convert_element_type_1250 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_392); all_gather_into_tensor_392 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_480, [1, 0]); wait_tensor_480 = None + mm_187 = torch.ops.aten.mm.default(view_1532, permute_349); permute_349 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_187, torch.float32) + amax_22 = torch.ops.aten.amax.default(convert_element_type_1253, [1], True) + sub_528 = torch.ops.aten.sub.Tensor(convert_element_type_1253, amax_22); convert_element_type_1253 = None + exp_67 = torch.ops.aten.exp.default(sub_528); sub_528 = None + sum_89 = torch.ops.aten.sum.dim_IntList(exp_67, [1], True) + div_111 = torch.ops.aten.div.Tensor(exp_67, sum_89); exp_67 = None + add_1506 = torch.ops.aten.add.Tensor(div_111, primals_382); primals_382 = None + topk_22 = torch.ops.aten.topk.default(add_1506, 6, -1, True, False); add_1506 = None + getitem_327 = topk_22[1]; topk_22 = None + gather_22 = torch.ops.aten.gather.default(div_111, 1, getitem_327); div_111 = None + mul_1095 = torch.ops.aten.mul.Tensor(gather_22, 1.0); gather_22 = None + view_1534 = torch.ops.aten.view.default(getitem_327, [-1]) + histc_44 = torch.ops.aten.histc.default(view_1534, 64, 0, 64) + add_1507 = torch.ops.aten.add.Tensor(primals_384, histc_44) + sort_22 = torch.ops.aten.sort.stable(view_1534, stable = True); view_1534 = None + getitem_329 = sort_22[1]; sort_22 = None + div_112 = torch.ops.aten.div.Tensor_mode(getitem_329, 6, rounding_mode = 'floor') + index_44 = torch.ops.aten.index.Tensor(view_1532, [div_112]) + all_to_all_single_66 = torch.ops._c10d_functional.all_to_all_single.default(histc_44, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_481 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_66); all_to_all_single_66 = None + wait_tensor_482 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_481); wait_tensor_481 = None + view_1538 = torch.ops.aten.view.default(histc_44, [8, -1]); histc_44 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_1538, [1]); view_1538 = None + device_put_44 = torch.ops.prims.device_put.default(sum_90, device(type='cpu'), True); sum_90 = None + view_1539 = torch.ops.aten.view.default(wait_tensor_482, [8, -1]) + sum_91 = torch.ops.aten.sum.dim_IntList(view_1539, [1]) + device_put_45 = torch.ops.prims.device_put.default(sum_91, device(type='cpu')); sum_91 = None + select_352 = torch.ops.aten.select.int(device_put_44, 0, 0) + _local_scalar_dense_352 = torch.ops.aten._local_scalar_dense.default(select_352); select_352 = None + ge_440 = _local_scalar_dense_352 >= 0 + _assert_scalar_352 = torch.ops.aten._assert_scalar.default(ge_440, "Runtime assertion failed for expression u352 >= 0 on node 'ge_352'"); ge_440 = _assert_scalar_352 = None + select_353 = torch.ops.aten.select.int(device_put_44, 0, 1) + _local_scalar_dense_353 = torch.ops.aten._local_scalar_dense.default(select_353); select_353 = None + ge_441 = _local_scalar_dense_353 >= 0 + _assert_scalar_353 = torch.ops.aten._assert_scalar.default(ge_441, "Runtime assertion failed for expression u353 >= 0 on node 'ge_353'"); ge_441 = _assert_scalar_353 = None + select_354 = torch.ops.aten.select.int(device_put_44, 0, 2) + _local_scalar_dense_354 = torch.ops.aten._local_scalar_dense.default(select_354); select_354 = None + ge_442 = _local_scalar_dense_354 >= 0 + _assert_scalar_354 = torch.ops.aten._assert_scalar.default(ge_442, "Runtime assertion failed for expression u354 >= 0 on node 'ge_354'"); ge_442 = _assert_scalar_354 = None + select_355 = torch.ops.aten.select.int(device_put_44, 0, 3) + _local_scalar_dense_355 = torch.ops.aten._local_scalar_dense.default(select_355); select_355 = None + ge_443 = _local_scalar_dense_355 >= 0 + _assert_scalar_355 = torch.ops.aten._assert_scalar.default(ge_443, "Runtime assertion failed for expression u355 >= 0 on node 'ge_355'"); ge_443 = _assert_scalar_355 = None + select_356 = torch.ops.aten.select.int(device_put_44, 0, 4) + _local_scalar_dense_356 = torch.ops.aten._local_scalar_dense.default(select_356); select_356 = None + ge_444 = _local_scalar_dense_356 >= 0 + _assert_scalar_356 = torch.ops.aten._assert_scalar.default(ge_444, "Runtime assertion failed for expression u356 >= 0 on node 'ge_356'"); ge_444 = _assert_scalar_356 = None + select_357 = torch.ops.aten.select.int(device_put_44, 0, 5) + _local_scalar_dense_357 = torch.ops.aten._local_scalar_dense.default(select_357); select_357 = None + ge_445 = _local_scalar_dense_357 >= 0 + _assert_scalar_357 = torch.ops.aten._assert_scalar.default(ge_445, "Runtime assertion failed for expression u357 >= 0 on node 'ge_357'"); ge_445 = _assert_scalar_357 = None + select_358 = torch.ops.aten.select.int(device_put_44, 0, 6) + _local_scalar_dense_358 = torch.ops.aten._local_scalar_dense.default(select_358); select_358 = None + ge_446 = _local_scalar_dense_358 >= 0 + _assert_scalar_358 = torch.ops.aten._assert_scalar.default(ge_446, "Runtime assertion failed for expression u358 >= 0 on node 'ge_358'"); ge_446 = _assert_scalar_358 = None + select_359 = torch.ops.aten.select.int(device_put_44, 0, 7); device_put_44 = None + _local_scalar_dense_359 = torch.ops.aten._local_scalar_dense.default(select_359); select_359 = None + ge_447 = _local_scalar_dense_359 >= 0 + _assert_scalar_359 = torch.ops.aten._assert_scalar.default(ge_447, "Runtime assertion failed for expression u359 >= 0 on node 'ge_359'"); ge_447 = _assert_scalar_359 = None + select_360 = torch.ops.aten.select.int(device_put_45, 0, 0) + _local_scalar_dense_360 = torch.ops.aten._local_scalar_dense.default(select_360); select_360 = None + ge_448 = _local_scalar_dense_360 >= 0 + _assert_scalar_360 = torch.ops.aten._assert_scalar.default(ge_448, "Runtime assertion failed for expression u360 >= 0 on node 'ge_360'"); ge_448 = _assert_scalar_360 = None + select_361 = torch.ops.aten.select.int(device_put_45, 0, 1) + _local_scalar_dense_361 = torch.ops.aten._local_scalar_dense.default(select_361); select_361 = None + ge_449 = _local_scalar_dense_361 >= 0 + _assert_scalar_361 = torch.ops.aten._assert_scalar.default(ge_449, "Runtime assertion failed for expression u361 >= 0 on node 'ge_361'"); ge_449 = _assert_scalar_361 = None + select_362 = torch.ops.aten.select.int(device_put_45, 0, 2) + _local_scalar_dense_362 = torch.ops.aten._local_scalar_dense.default(select_362); select_362 = None + ge_450 = _local_scalar_dense_362 >= 0 + _assert_scalar_362 = torch.ops.aten._assert_scalar.default(ge_450, "Runtime assertion failed for expression u362 >= 0 on node 'ge_362'"); ge_450 = _assert_scalar_362 = None + select_363 = torch.ops.aten.select.int(device_put_45, 0, 3) + _local_scalar_dense_363 = torch.ops.aten._local_scalar_dense.default(select_363); select_363 = None + ge_451 = _local_scalar_dense_363 >= 0 + _assert_scalar_363 = torch.ops.aten._assert_scalar.default(ge_451, "Runtime assertion failed for expression u363 >= 0 on node 'ge_363'"); ge_451 = _assert_scalar_363 = None + select_364 = torch.ops.aten.select.int(device_put_45, 0, 4) + _local_scalar_dense_364 = torch.ops.aten._local_scalar_dense.default(select_364); select_364 = None + ge_452 = _local_scalar_dense_364 >= 0 + _assert_scalar_364 = torch.ops.aten._assert_scalar.default(ge_452, "Runtime assertion failed for expression u364 >= 0 on node 'ge_364'"); ge_452 = _assert_scalar_364 = None + select_365 = torch.ops.aten.select.int(device_put_45, 0, 5) + _local_scalar_dense_365 = torch.ops.aten._local_scalar_dense.default(select_365); select_365 = None + ge_453 = _local_scalar_dense_365 >= 0 + _assert_scalar_365 = torch.ops.aten._assert_scalar.default(ge_453, "Runtime assertion failed for expression u365 >= 0 on node 'ge_365'"); ge_453 = _assert_scalar_365 = None + select_366 = torch.ops.aten.select.int(device_put_45, 0, 6) + _local_scalar_dense_366 = torch.ops.aten._local_scalar_dense.default(select_366); select_366 = None + ge_454 = _local_scalar_dense_366 >= 0 + _assert_scalar_366 = torch.ops.aten._assert_scalar.default(ge_454, "Runtime assertion failed for expression u366 >= 0 on node 'ge_366'"); ge_454 = _assert_scalar_366 = None + select_367 = torch.ops.aten.select.int(device_put_45, 0, 7); device_put_45 = None + _local_scalar_dense_367 = torch.ops.aten._local_scalar_dense.default(select_367); select_367 = None + ge_455 = _local_scalar_dense_367 >= 0 + _assert_scalar_367 = torch.ops.aten._assert_scalar.default(ge_455, "Runtime assertion failed for expression u367 >= 0 on node 'ge_367'"); ge_455 = _assert_scalar_367 = None + all_to_all_single_67 = torch.ops._c10d_functional.all_to_all_single.default(index_44, [_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367], [_local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359], '521'); index_44 = None + sym_size_int_88 = torch.ops.aten.sym_size.int(all_to_all_single_67, 0) + wait_tensor_483 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_67); all_to_all_single_67 = None + sym_sum_44 = torch.sym_sum((_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367)) + add_1514 = sym_sum_44 + 64; sym_sum_44 = None + add_1515 = add_1514 + 8; add_1514 = None + sub_531 = add_1515 - 1; add_1515 = None + floordiv_22 = sub_531 // 8; sub_531 = None + mul_1100 = floordiv_22 * 8; floordiv_22 = None + cumsum_66 = torch.ops.aten.cumsum.default(wait_tensor_482, 0) + sub_532 = torch.ops.aten.sub.Tensor(cumsum_66, wait_tensor_482); cumsum_66 = None + sum_92 = torch.ops.aten.sum.dim_IntList(view_1539, [0]); view_1539 = None + clamp_min_22 = torch.ops.aten.clamp_min.default(sum_92, 8); sum_92 = None + add_1516 = torch.ops.aten.add.Tensor(clamp_min_22, 8); clamp_min_22 = None + sub_533 = torch.ops.aten.sub.Tensor(add_1516, 1); add_1516 = None + div_113 = torch.ops.aten.div.Tensor_mode(sub_533, 8, rounding_mode = 'floor'); sub_533 = None + mul_1101 = torch.ops.aten.mul.Tensor(div_113, 8); div_113 = None + convert_element_type_1256 = torch.ops.prims.convert_element_type.default(mul_1101, torch.int32); mul_1101 = None + cumsum_67 = torch.ops.aten.cumsum.default(convert_element_type_1256, 0) + sub_534 = torch.ops.aten.sub.Tensor(cumsum_67, convert_element_type_1256); cumsum_67 = None + full_306 = torch.ops.aten.full.default([mul_1100], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1100 = None + triton_kernel_wrapper_functional_proxy_22 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 22, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_482, 'start_index_values_ptr': sub_532, 'write_offsets_ptr': sub_534, 'output_ptr': full_306}, tensors_to_clone = ['output_ptr']); wait_tensor_482 = sub_532 = sub_534 = full_306 = None + getitem_330 = triton_kernel_wrapper_functional_proxy_22['output_ptr']; triton_kernel_wrapper_functional_proxy_22 = None + cat_70 = torch.ops.aten.cat.default([wait_tensor_483, full_default]); wait_tensor_483 = None + sym_size_int_89 = torch.ops.aten.sym_size.int(cat_70, 0) + sym_sum_45 = torch.sym_sum((1, _local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367)) + index_45 = torch.ops.aten.index.Tensor(cat_70, [getitem_330]); cat_70 = None + convert_element_type_1258 = torch.ops.prims.convert_element_type.default(primals_385, torch.bfloat16) + all_gather_into_tensor_393 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1258, 8, '513'); convert_element_type_1258 = None + wait_tensor_484 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_393); all_gather_into_tensor_393 = None + convert_element_type_1260 = torch.ops.prims.convert_element_type.default(primals_386, torch.bfloat16) + all_gather_into_tensor_395 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1260, 8, '513'); convert_element_type_1260 = None + wait_tensor_486 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_395); all_gather_into_tensor_395 = None + convert_element_type_1261 = torch.ops.prims.convert_element_type.default(primals_387, torch.bfloat16) + all_gather_into_tensor_396 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1261, 8, '513'); convert_element_type_1261 = None + wait_tensor_487 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_396); all_gather_into_tensor_396 = None + cumsum_68 = torch.ops.aten.cumsum.default(convert_element_type_1256, 0, dtype = torch.int32); convert_element_type_1256 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_484, [0, 2, 1]); wait_tensor_484 = None + _grouped_mm_66 = torch.ops.aten._grouped_mm.default(index_45, permute_350, cumsum_68); permute_350 = None + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(_grouped_mm_66, torch.float32) + neg_45 = torch.ops.aten.neg.default(convert_element_type_1264) + exp_68 = torch.ops.aten.exp.default(neg_45); neg_45 = None + add_1528 = torch.ops.aten.add.Tensor(exp_68, 1); exp_68 = None + div_114 = torch.ops.aten.div.Tensor(convert_element_type_1264, add_1528); convert_element_type_1264 = add_1528 = None + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(div_114, torch.bfloat16); div_114 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_487, [0, 2, 1]); wait_tensor_487 = None + _grouped_mm_67 = torch.ops.aten._grouped_mm.default(index_45, permute_351, cumsum_68); permute_351 = None + mul_1113 = torch.ops.aten.mul.Tensor(convert_element_type_1265, _grouped_mm_67); convert_element_type_1265 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_486, [0, 2, 1]); wait_tensor_486 = None + _grouped_mm_68 = torch.ops.aten._grouped_mm.default(mul_1113, permute_352, cumsum_68); permute_352 = None + empty_22 = torch.ops.aten.empty.memory_format([sym_size_int_89, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_44 = torch.ops.aten.index_put.default(empty_22, [getitem_330], _grouped_mm_68); empty_22 = _grouped_mm_68 = None + slice_94 = torch.ops.aten.slice.Tensor(index_put_44, 0, 0, -1); index_put_44 = None + all_to_all_single_68 = torch.ops._c10d_functional.all_to_all_single.default(slice_94, [_local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359], [_local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367], '521'); slice_94 = None + wait_tensor_490 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_68); all_to_all_single_68 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(primals_388, torch.bfloat16) + all_gather_into_tensor_399 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1266, 64, '0'); convert_element_type_1266 = None + wait_tensor_491 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_399); all_gather_into_tensor_399 = None + permute_353 = torch.ops.aten.permute.default(wait_tensor_491, [1, 0]); wait_tensor_491 = None + mm_188 = torch.ops.aten.mm.default(view_1532, permute_353); permute_353 = None + convert_element_type_1269 = torch.ops.prims.convert_element_type.default(mm_188, torch.float32) + neg_46 = torch.ops.aten.neg.default(convert_element_type_1269) + exp_69 = torch.ops.aten.exp.default(neg_46); neg_46 = None + add_1564 = torch.ops.aten.add.Tensor(exp_69, 1); exp_69 = None + div_115 = torch.ops.aten.div.Tensor(convert_element_type_1269, add_1564); convert_element_type_1269 = add_1564 = None + convert_element_type_1270 = torch.ops.prims.convert_element_type.default(div_115, torch.bfloat16); div_115 = None + convert_element_type_1271 = torch.ops.prims.convert_element_type.default(primals_389, torch.bfloat16) + all_gather_into_tensor_400 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1271, 64, '0'); convert_element_type_1271 = None + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_400); all_gather_into_tensor_400 = None + permute_354 = torch.ops.aten.permute.default(wait_tensor_492, [1, 0]); wait_tensor_492 = None + mm_189 = torch.ops.aten.mm.default(view_1532, permute_354); permute_354 = None + mul_1133 = torch.ops.aten.mul.Tensor(convert_element_type_1270, mm_189); convert_element_type_1270 = None + convert_element_type_1274 = torch.ops.prims.convert_element_type.default(primals_390, torch.bfloat16) + all_gather_into_tensor_401 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1274, 64, '0'); convert_element_type_1274 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_401); all_gather_into_tensor_401 = None + permute_355 = torch.ops.aten.permute.default(wait_tensor_493, [1, 0]); wait_tensor_493 = None + mm_190 = torch.ops.aten.mm.default(mul_1133, permute_355); permute_355 = None + index_put_45 = torch.ops.aten.index_put.default(full_default_1, [getitem_329], wait_tensor_490); wait_tensor_490 = None + view_1572 = torch.ops.aten.view.default(mul_1095, [-1, 1, 6]); mul_1095 = None + view_1573 = torch.ops.aten.view.default(index_put_45, [-1, 6, 2048]); index_put_45 = None + convert_element_type_1277 = torch.ops.prims.convert_element_type.default(view_1573, torch.float32); view_1573 = None + bmm_22 = torch.ops.aten.bmm.default(view_1572, convert_element_type_1277) + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(bmm_22, torch.bfloat16); bmm_22 = None + squeeze_22 = torch.ops.aten.squeeze.dim(convert_element_type_1278, 1); convert_element_type_1278 = None + add_1568 = torch.ops.aten.add.Tensor(mm_190, squeeze_22); mm_190 = squeeze_22 = None + view_1574 = torch.ops.aten.view.default(add_1568, [2, 4096, 2048]); add_1568 = None + add_1569 = torch.ops.aten.add.Tensor(add_1504, view_1574); view_1574 = None + convert_element_type_1279 = torch.ops.prims.convert_element_type.default(primals_391, torch.bfloat16) + all_gather_into_tensor_402 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1279, 64, '0'); convert_element_type_1279 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_402); all_gather_into_tensor_402 = None + convert_element_type_1280 = torch.ops.prims.convert_element_type.default(add_1569, torch.float32) + pow_73 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1280, 2) + mean_72 = torch.ops.aten.mean.dim(pow_73, [2], True); pow_73 = None + add_1570 = torch.ops.aten.add.Scalar(mean_72, 1e-05); mean_72 = None + rsqrt_72 = torch.ops.aten.rsqrt.default(add_1570); add_1570 = None + mul_1136 = torch.ops.aten.mul.Tensor(convert_element_type_1280, rsqrt_72); convert_element_type_1280 = None + mul_1137 = torch.ops.aten.mul.Tensor(mul_1136, wait_tensor_494); mul_1136 = wait_tensor_494 = None + convert_element_type_1281 = torch.ops.prims.convert_element_type.default(mul_1137, torch.bfloat16); mul_1137 = None + convert_element_type_1282 = torch.ops.prims.convert_element_type.default(primals_392, torch.bfloat16) + all_gather_into_tensor_403 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1282, 64, '0'); convert_element_type_1282 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_403); all_gather_into_tensor_403 = None + permute_356 = torch.ops.aten.permute.default(wait_tensor_495, [1, 0]); wait_tensor_495 = None + view_1577 = torch.ops.aten.view.default(convert_element_type_1281, [8192, 2048]); convert_element_type_1281 = None + mm_191 = torch.ops.aten.mm.default(view_1577, permute_356); permute_356 = None + view_1578 = torch.ops.aten.view.default(mm_191, [2, 4096, 3072]); mm_191 = None + view_1579 = torch.ops.aten.view.default(view_1578, [2, 4096, -1, 192]); view_1578 = None + split_with_sizes_72 = torch.ops.aten.split_with_sizes.default(view_1579, [128, 64], -1); view_1579 = None + getitem_331 = split_with_sizes_72[0] + getitem_332 = split_with_sizes_72[1]; split_with_sizes_72 = None + convert_element_type_1285 = torch.ops.prims.convert_element_type.default(getitem_332, torch.float32); getitem_332 = None + view_1580 = torch.ops.aten.view.default(convert_element_type_1285, [2, 4096, 16, -1, 2]); convert_element_type_1285 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1580); view_1580 = None + mul_1138 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_7); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_1138); mul_1138 = None + view_1582 = torch.ops.aten.view.default(view_as_real_48, [2, 4096, 16, 64]); view_as_real_48 = None + convert_element_type_1286 = torch.ops.prims.convert_element_type.default(view_1582, torch.bfloat16); view_1582 = None + cat_71 = torch.ops.aten.cat.default([getitem_331, convert_element_type_1286], -1); getitem_331 = convert_element_type_1286 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(primals_393, torch.bfloat16) + all_gather_into_tensor_404 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1287, 64, '0'); convert_element_type_1287 = None + wait_tensor_496 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_404); all_gather_into_tensor_404 = None + permute_357 = torch.ops.aten.permute.default(wait_tensor_496, [1, 0]); wait_tensor_496 = None + mm_192 = torch.ops.aten.mm.default(view_1577, permute_357); permute_357 = None + view_1585 = torch.ops.aten.view.default(mm_192, [2, 4096, 576]); mm_192 = None + split_with_sizes_73 = torch.ops.aten.split_with_sizes.default(view_1585, [512, 64], -1); view_1585 = None + getitem_333 = split_with_sizes_73[0] + getitem_334 = split_with_sizes_73[1]; split_with_sizes_73 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(getitem_334, 2); getitem_334 = None + convert_element_type_1290 = torch.ops.prims.convert_element_type.default(unsqueeze_47, torch.float32); unsqueeze_47 = None + view_1586 = torch.ops.aten.view.default(convert_element_type_1290, [2, 4096, 1, -1, 2]); convert_element_type_1290 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1586); view_1586 = None + mul_1139 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_7); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_1139); mul_1139 = None + view_1588 = torch.ops.aten.view.default(view_as_real_49, [2, 4096, 1, 64]); view_as_real_49 = None + convert_element_type_1291 = torch.ops.prims.convert_element_type.default(view_1588, torch.bfloat16); view_1588 = None + convert_element_type_1292 = torch.ops.prims.convert_element_type.default(primals_394, torch.bfloat16) + all_gather_into_tensor_405 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1292, 64, '0'); convert_element_type_1292 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_405); all_gather_into_tensor_405 = None + convert_element_type_1293 = torch.ops.prims.convert_element_type.default(getitem_333, torch.float32) + pow_74 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1293, 2) + mean_73 = torch.ops.aten.mean.dim(pow_74, [2], True); pow_74 = None + add_1571 = torch.ops.aten.add.Scalar(mean_73, 1e-05); mean_73 = None + rsqrt_73 = torch.ops.aten.rsqrt.default(add_1571); add_1571 = None + mul_1140 = torch.ops.aten.mul.Tensor(convert_element_type_1293, rsqrt_73); convert_element_type_1293 = None + mul_1141 = torch.ops.aten.mul.Tensor(mul_1140, wait_tensor_497); mul_1140 = wait_tensor_497 = None + convert_element_type_1294 = torch.ops.prims.convert_element_type.default(mul_1141, torch.bfloat16); mul_1141 = None + convert_element_type_1295 = torch.ops.prims.convert_element_type.default(primals_395, torch.bfloat16) + all_gather_into_tensor_406 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1295, 64, '0'); convert_element_type_1295 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_406); all_gather_into_tensor_406 = None + permute_358 = torch.ops.aten.permute.default(wait_tensor_498, [1, 0]); wait_tensor_498 = None + view_1591 = torch.ops.aten.view.default(convert_element_type_1294, [8192, 512]); convert_element_type_1294 = None + mm_193 = torch.ops.aten.mm.default(view_1591, permute_358); permute_358 = None + view_1592 = torch.ops.aten.view.default(mm_193, [2, 4096, 4096]); mm_193 = None + view_1593 = torch.ops.aten.view.default(view_1592, [2, 4096, -1, 256]); view_1592 = None + split_with_sizes_74 = torch.ops.aten.split_with_sizes.default(view_1593, [128, 128], -1); view_1593 = None + getitem_335 = split_with_sizes_74[0] + getitem_336 = split_with_sizes_74[1]; split_with_sizes_74 = None + expand_24 = torch.ops.aten.expand.default(convert_element_type_1291, [-1, -1, 16, -1]); convert_element_type_1291 = None + cat_72 = torch.ops.aten.cat.default([getitem_335, expand_24], -1); getitem_335 = expand_24 = None + permute_359 = torch.ops.aten.permute.default(cat_71, [0, 2, 1, 3]); cat_71 = None + permute_360 = torch.ops.aten.permute.default(cat_72, [0, 2, 1, 3]); cat_72 = None + permute_361 = torch.ops.aten.permute.default(getitem_336, [0, 2, 1, 3]); getitem_336 = None + sdpa_score24 = self.sdpa_score24 + sdpa_mask24 = self.sdpa_mask24 + flex_attention_24 = torch.ops.higher_order.flex_attention(permute_359, permute_360, permute_361, sdpa_score24, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask24), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score24 = sdpa_mask24 = None + getitem_337 = flex_attention_24[0] + getitem_338 = flex_attention_24[1]; flex_attention_24 = None + permute_362 = torch.ops.aten.permute.default(getitem_337, [0, 2, 1, 3]) + view_1594 = torch.ops.aten.view.default(permute_362, [2, 4096, -1]); permute_362 = None + convert_element_type_1298 = torch.ops.prims.convert_element_type.default(primals_396, torch.bfloat16) + all_gather_into_tensor_407 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1298, 64, '0'); convert_element_type_1298 = None + wait_tensor_499 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_407); all_gather_into_tensor_407 = None + permute_363 = torch.ops.aten.permute.default(wait_tensor_499, [1, 0]); wait_tensor_499 = None + view_1596 = torch.ops.aten.view.default(view_1594, [8192, 2048]); view_1594 = None + mm_194 = torch.ops.aten.mm.default(view_1596, permute_363); view_1596 = permute_363 = None + view_1597 = torch.ops.aten.view.default(mm_194, [2, 4096, 2048]); mm_194 = None + add_1572 = torch.ops.aten.add.Tensor(add_1569, view_1597); view_1597 = None + convert_element_type_1301 = torch.ops.prims.convert_element_type.default(primals_397, torch.bfloat16) + all_gather_into_tensor_408 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1301, 64, '0'); convert_element_type_1301 = None + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_408); all_gather_into_tensor_408 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(add_1572, torch.float32) + pow_75 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1302, 2) + mean_74 = torch.ops.aten.mean.dim(pow_75, [2], True); pow_75 = None + add_1573 = torch.ops.aten.add.Scalar(mean_74, 1e-05); mean_74 = None + rsqrt_74 = torch.ops.aten.rsqrt.default(add_1573); add_1573 = None + mul_1142 = torch.ops.aten.mul.Tensor(convert_element_type_1302, rsqrt_74); convert_element_type_1302 = None + mul_1143 = torch.ops.aten.mul.Tensor(mul_1142, wait_tensor_500); mul_1142 = wait_tensor_500 = None + convert_element_type_1303 = torch.ops.prims.convert_element_type.default(mul_1143, torch.bfloat16); mul_1143 = None + view_1599 = torch.ops.aten.view.default(convert_element_type_1303, [-1, 2048]); convert_element_type_1303 = None + convert_element_type_1304 = torch.ops.prims.convert_element_type.default(primals_399, torch.bfloat16) + all_gather_into_tensor_409 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1304, 64, '0'); convert_element_type_1304 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_409); all_gather_into_tensor_409 = None + permute_364 = torch.ops.aten.permute.default(wait_tensor_501, [1, 0]); wait_tensor_501 = None + mm_195 = torch.ops.aten.mm.default(view_1599, permute_364); permute_364 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_195, torch.float32) + amax_23 = torch.ops.aten.amax.default(convert_element_type_1307, [1], True) + sub_552 = torch.ops.aten.sub.Tensor(convert_element_type_1307, amax_23); convert_element_type_1307 = None + exp_70 = torch.ops.aten.exp.default(sub_552); sub_552 = None + sum_93 = torch.ops.aten.sum.dim_IntList(exp_70, [1], True) + div_116 = torch.ops.aten.div.Tensor(exp_70, sum_93); exp_70 = None + add_1574 = torch.ops.aten.add.Tensor(div_116, primals_398); primals_398 = None + topk_23 = torch.ops.aten.topk.default(add_1574, 6, -1, True, False); add_1574 = None + getitem_341 = topk_23[1]; topk_23 = None + gather_23 = torch.ops.aten.gather.default(div_116, 1, getitem_341); div_116 = None + mul_1144 = torch.ops.aten.mul.Tensor(gather_23, 1.0); gather_23 = None + view_1601 = torch.ops.aten.view.default(getitem_341, [-1]) + histc_46 = torch.ops.aten.histc.default(view_1601, 64, 0, 64) + add_1575 = torch.ops.aten.add.Tensor(primals_400, histc_46) + sort_23 = torch.ops.aten.sort.stable(view_1601, stable = True); view_1601 = None + getitem_343 = sort_23[1]; sort_23 = None + div_117 = torch.ops.aten.div.Tensor_mode(getitem_343, 6, rounding_mode = 'floor') + index_46 = torch.ops.aten.index.Tensor(view_1599, [div_117]) + all_to_all_single_69 = torch.ops._c10d_functional.all_to_all_single.default(histc_46, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_502 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_69); all_to_all_single_69 = None + wait_tensor_503 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_502); wait_tensor_502 = None + view_1605 = torch.ops.aten.view.default(histc_46, [8, -1]); histc_46 = None + sum_94 = torch.ops.aten.sum.dim_IntList(view_1605, [1]); view_1605 = None + device_put_46 = torch.ops.prims.device_put.default(sum_94, device(type='cpu'), True); sum_94 = None + view_1606 = torch.ops.aten.view.default(wait_tensor_503, [8, -1]) + sum_95 = torch.ops.aten.sum.dim_IntList(view_1606, [1]) + device_put_47 = torch.ops.prims.device_put.default(sum_95, device(type='cpu')); sum_95 = None + select_368 = torch.ops.aten.select.int(device_put_46, 0, 0) + _local_scalar_dense_368 = torch.ops.aten._local_scalar_dense.default(select_368); select_368 = None + ge_460 = _local_scalar_dense_368 >= 0 + _assert_scalar_368 = torch.ops.aten._assert_scalar.default(ge_460, "Runtime assertion failed for expression u368 >= 0 on node 'ge_368'"); ge_460 = _assert_scalar_368 = None + select_369 = torch.ops.aten.select.int(device_put_46, 0, 1) + _local_scalar_dense_369 = torch.ops.aten._local_scalar_dense.default(select_369); select_369 = None + ge_461 = _local_scalar_dense_369 >= 0 + _assert_scalar_369 = torch.ops.aten._assert_scalar.default(ge_461, "Runtime assertion failed for expression u369 >= 0 on node 'ge_369'"); ge_461 = _assert_scalar_369 = None + select_370 = torch.ops.aten.select.int(device_put_46, 0, 2) + _local_scalar_dense_370 = torch.ops.aten._local_scalar_dense.default(select_370); select_370 = None + ge_462 = _local_scalar_dense_370 >= 0 + _assert_scalar_370 = torch.ops.aten._assert_scalar.default(ge_462, "Runtime assertion failed for expression u370 >= 0 on node 'ge_370'"); ge_462 = _assert_scalar_370 = None + select_371 = torch.ops.aten.select.int(device_put_46, 0, 3) + _local_scalar_dense_371 = torch.ops.aten._local_scalar_dense.default(select_371); select_371 = None + ge_463 = _local_scalar_dense_371 >= 0 + _assert_scalar_371 = torch.ops.aten._assert_scalar.default(ge_463, "Runtime assertion failed for expression u371 >= 0 on node 'ge_371'"); ge_463 = _assert_scalar_371 = None + select_372 = torch.ops.aten.select.int(device_put_46, 0, 4) + _local_scalar_dense_372 = torch.ops.aten._local_scalar_dense.default(select_372); select_372 = None + ge_464 = _local_scalar_dense_372 >= 0 + _assert_scalar_372 = torch.ops.aten._assert_scalar.default(ge_464, "Runtime assertion failed for expression u372 >= 0 on node 'ge_372'"); ge_464 = _assert_scalar_372 = None + select_373 = torch.ops.aten.select.int(device_put_46, 0, 5) + _local_scalar_dense_373 = torch.ops.aten._local_scalar_dense.default(select_373); select_373 = None + ge_465 = _local_scalar_dense_373 >= 0 + _assert_scalar_373 = torch.ops.aten._assert_scalar.default(ge_465, "Runtime assertion failed for expression u373 >= 0 on node 'ge_373'"); ge_465 = _assert_scalar_373 = None + select_374 = torch.ops.aten.select.int(device_put_46, 0, 6) + _local_scalar_dense_374 = torch.ops.aten._local_scalar_dense.default(select_374); select_374 = None + ge_466 = _local_scalar_dense_374 >= 0 + _assert_scalar_374 = torch.ops.aten._assert_scalar.default(ge_466, "Runtime assertion failed for expression u374 >= 0 on node 'ge_374'"); ge_466 = _assert_scalar_374 = None + select_375 = torch.ops.aten.select.int(device_put_46, 0, 7); device_put_46 = None + _local_scalar_dense_375 = torch.ops.aten._local_scalar_dense.default(select_375); select_375 = None + ge_467 = _local_scalar_dense_375 >= 0 + _assert_scalar_375 = torch.ops.aten._assert_scalar.default(ge_467, "Runtime assertion failed for expression u375 >= 0 on node 'ge_375'"); ge_467 = _assert_scalar_375 = None + select_376 = torch.ops.aten.select.int(device_put_47, 0, 0) + _local_scalar_dense_376 = torch.ops.aten._local_scalar_dense.default(select_376); select_376 = None + ge_468 = _local_scalar_dense_376 >= 0 + _assert_scalar_376 = torch.ops.aten._assert_scalar.default(ge_468, "Runtime assertion failed for expression u376 >= 0 on node 'ge_376'"); ge_468 = _assert_scalar_376 = None + select_377 = torch.ops.aten.select.int(device_put_47, 0, 1) + _local_scalar_dense_377 = torch.ops.aten._local_scalar_dense.default(select_377); select_377 = None + ge_469 = _local_scalar_dense_377 >= 0 + _assert_scalar_377 = torch.ops.aten._assert_scalar.default(ge_469, "Runtime assertion failed for expression u377 >= 0 on node 'ge_377'"); ge_469 = _assert_scalar_377 = None + select_378 = torch.ops.aten.select.int(device_put_47, 0, 2) + _local_scalar_dense_378 = torch.ops.aten._local_scalar_dense.default(select_378); select_378 = None + ge_470 = _local_scalar_dense_378 >= 0 + _assert_scalar_378 = torch.ops.aten._assert_scalar.default(ge_470, "Runtime assertion failed for expression u378 >= 0 on node 'ge_378'"); ge_470 = _assert_scalar_378 = None + select_379 = torch.ops.aten.select.int(device_put_47, 0, 3) + _local_scalar_dense_379 = torch.ops.aten._local_scalar_dense.default(select_379); select_379 = None + ge_471 = _local_scalar_dense_379 >= 0 + _assert_scalar_379 = torch.ops.aten._assert_scalar.default(ge_471, "Runtime assertion failed for expression u379 >= 0 on node 'ge_379'"); ge_471 = _assert_scalar_379 = None + select_380 = torch.ops.aten.select.int(device_put_47, 0, 4) + _local_scalar_dense_380 = torch.ops.aten._local_scalar_dense.default(select_380); select_380 = None + ge_472 = _local_scalar_dense_380 >= 0 + _assert_scalar_380 = torch.ops.aten._assert_scalar.default(ge_472, "Runtime assertion failed for expression u380 >= 0 on node 'ge_380'"); ge_472 = _assert_scalar_380 = None + select_381 = torch.ops.aten.select.int(device_put_47, 0, 5) + _local_scalar_dense_381 = torch.ops.aten._local_scalar_dense.default(select_381); select_381 = None + ge_473 = _local_scalar_dense_381 >= 0 + _assert_scalar_381 = torch.ops.aten._assert_scalar.default(ge_473, "Runtime assertion failed for expression u381 >= 0 on node 'ge_381'"); ge_473 = _assert_scalar_381 = None + select_382 = torch.ops.aten.select.int(device_put_47, 0, 6) + _local_scalar_dense_382 = torch.ops.aten._local_scalar_dense.default(select_382); select_382 = None + ge_474 = _local_scalar_dense_382 >= 0 + _assert_scalar_382 = torch.ops.aten._assert_scalar.default(ge_474, "Runtime assertion failed for expression u382 >= 0 on node 'ge_382'"); ge_474 = _assert_scalar_382 = None + select_383 = torch.ops.aten.select.int(device_put_47, 0, 7); device_put_47 = None + _local_scalar_dense_383 = torch.ops.aten._local_scalar_dense.default(select_383); select_383 = None + ge_475 = _local_scalar_dense_383 >= 0 + _assert_scalar_383 = torch.ops.aten._assert_scalar.default(ge_475, "Runtime assertion failed for expression u383 >= 0 on node 'ge_383'"); ge_475 = _assert_scalar_383 = None + all_to_all_single_70 = torch.ops._c10d_functional.all_to_all_single.default(index_46, [_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383], [_local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375], '521'); index_46 = None + sym_size_int_92 = torch.ops.aten.sym_size.int(all_to_all_single_70, 0) + wait_tensor_504 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_70); all_to_all_single_70 = None + sym_sum_46 = torch.sym_sum((_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383)) + add_1582 = sym_sum_46 + 64; sym_sum_46 = None + add_1583 = add_1582 + 8; add_1582 = None + sub_555 = add_1583 - 1; add_1583 = None + floordiv_23 = sub_555 // 8; sub_555 = None + mul_1149 = floordiv_23 * 8; floordiv_23 = None + cumsum_69 = torch.ops.aten.cumsum.default(wait_tensor_503, 0) + sub_556 = torch.ops.aten.sub.Tensor(cumsum_69, wait_tensor_503); cumsum_69 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_1606, [0]); view_1606 = None + clamp_min_23 = torch.ops.aten.clamp_min.default(sum_96, 8); sum_96 = None + add_1584 = torch.ops.aten.add.Tensor(clamp_min_23, 8); clamp_min_23 = None + sub_557 = torch.ops.aten.sub.Tensor(add_1584, 1); add_1584 = None + div_118 = torch.ops.aten.div.Tensor_mode(sub_557, 8, rounding_mode = 'floor'); sub_557 = None + mul_1150 = torch.ops.aten.mul.Tensor(div_118, 8); div_118 = None + convert_element_type_1310 = torch.ops.prims.convert_element_type.default(mul_1150, torch.int32); mul_1150 = None + cumsum_70 = torch.ops.aten.cumsum.default(convert_element_type_1310, 0) + sub_558 = torch.ops.aten.sub.Tensor(cumsum_70, convert_element_type_1310); cumsum_70 = None + full_319 = torch.ops.aten.full.default([mul_1149], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1149 = None + triton_kernel_wrapper_functional_proxy_23 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 23, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_503, 'start_index_values_ptr': sub_556, 'write_offsets_ptr': sub_558, 'output_ptr': full_319}, tensors_to_clone = ['output_ptr']); wait_tensor_503 = sub_556 = sub_558 = full_319 = None + getitem_344 = triton_kernel_wrapper_functional_proxy_23['output_ptr']; triton_kernel_wrapper_functional_proxy_23 = None + cat_73 = torch.ops.aten.cat.default([wait_tensor_504, full_default]); wait_tensor_504 = None + sym_size_int_93 = torch.ops.aten.sym_size.int(cat_73, 0) + sym_sum_47 = torch.sym_sum((1, _local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383)) + index_47 = torch.ops.aten.index.Tensor(cat_73, [getitem_344]); cat_73 = None + convert_element_type_1312 = torch.ops.prims.convert_element_type.default(primals_401, torch.bfloat16) + all_gather_into_tensor_410 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1312, 8, '513'); convert_element_type_1312 = None + wait_tensor_505 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_410); all_gather_into_tensor_410 = None + convert_element_type_1314 = torch.ops.prims.convert_element_type.default(primals_402, torch.bfloat16) + all_gather_into_tensor_412 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1314, 8, '513'); convert_element_type_1314 = None + wait_tensor_507 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_412); all_gather_into_tensor_412 = None + convert_element_type_1315 = torch.ops.prims.convert_element_type.default(primals_403, torch.bfloat16) + all_gather_into_tensor_413 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1315, 8, '513'); convert_element_type_1315 = None + wait_tensor_508 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_413); all_gather_into_tensor_413 = None + cumsum_71 = torch.ops.aten.cumsum.default(convert_element_type_1310, 0, dtype = torch.int32); convert_element_type_1310 = None + permute_365 = torch.ops.aten.permute.default(wait_tensor_505, [0, 2, 1]); wait_tensor_505 = None + _grouped_mm_69 = torch.ops.aten._grouped_mm.default(index_47, permute_365, cumsum_71); permute_365 = None + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(_grouped_mm_69, torch.float32) + neg_47 = torch.ops.aten.neg.default(convert_element_type_1318) + exp_71 = torch.ops.aten.exp.default(neg_47); neg_47 = None + add_1596 = torch.ops.aten.add.Tensor(exp_71, 1); exp_71 = None + div_119 = torch.ops.aten.div.Tensor(convert_element_type_1318, add_1596); convert_element_type_1318 = add_1596 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(div_119, torch.bfloat16); div_119 = None + permute_366 = torch.ops.aten.permute.default(wait_tensor_508, [0, 2, 1]); wait_tensor_508 = None + _grouped_mm_70 = torch.ops.aten._grouped_mm.default(index_47, permute_366, cumsum_71); permute_366 = None + mul_1162 = torch.ops.aten.mul.Tensor(convert_element_type_1319, _grouped_mm_70); convert_element_type_1319 = None + permute_367 = torch.ops.aten.permute.default(wait_tensor_507, [0, 2, 1]); wait_tensor_507 = None + _grouped_mm_71 = torch.ops.aten._grouped_mm.default(mul_1162, permute_367, cumsum_71); permute_367 = None + empty_23 = torch.ops.aten.empty.memory_format([sym_size_int_93, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_46 = torch.ops.aten.index_put.default(empty_23, [getitem_344], _grouped_mm_71); empty_23 = _grouped_mm_71 = None + slice_98 = torch.ops.aten.slice.Tensor(index_put_46, 0, 0, -1); index_put_46 = None + all_to_all_single_71 = torch.ops._c10d_functional.all_to_all_single.default(slice_98, [_local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375], [_local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383], '521'); slice_98 = None + wait_tensor_511 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_71); all_to_all_single_71 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(primals_404, torch.bfloat16) + all_gather_into_tensor_416 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1320, 64, '0'); convert_element_type_1320 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_416); all_gather_into_tensor_416 = None + permute_368 = torch.ops.aten.permute.default(wait_tensor_512, [1, 0]); wait_tensor_512 = None + mm_196 = torch.ops.aten.mm.default(view_1599, permute_368); permute_368 = None + convert_element_type_1323 = torch.ops.prims.convert_element_type.default(mm_196, torch.float32) + neg_48 = torch.ops.aten.neg.default(convert_element_type_1323) + exp_72 = torch.ops.aten.exp.default(neg_48); neg_48 = None + add_1632 = torch.ops.aten.add.Tensor(exp_72, 1); exp_72 = None + div_120 = torch.ops.aten.div.Tensor(convert_element_type_1323, add_1632); convert_element_type_1323 = add_1632 = None + convert_element_type_1324 = torch.ops.prims.convert_element_type.default(div_120, torch.bfloat16); div_120 = None + convert_element_type_1325 = torch.ops.prims.convert_element_type.default(primals_405, torch.bfloat16) + all_gather_into_tensor_417 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1325, 64, '0'); convert_element_type_1325 = None + wait_tensor_513 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_417); all_gather_into_tensor_417 = None + permute_369 = torch.ops.aten.permute.default(wait_tensor_513, [1, 0]); wait_tensor_513 = None + mm_197 = torch.ops.aten.mm.default(view_1599, permute_369); permute_369 = None + mul_1182 = torch.ops.aten.mul.Tensor(convert_element_type_1324, mm_197); convert_element_type_1324 = None + convert_element_type_1328 = torch.ops.prims.convert_element_type.default(primals_406, torch.bfloat16) + all_gather_into_tensor_418 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1328, 64, '0'); convert_element_type_1328 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_418); all_gather_into_tensor_418 = None + permute_370 = torch.ops.aten.permute.default(wait_tensor_514, [1, 0]); wait_tensor_514 = None + mm_198 = torch.ops.aten.mm.default(mul_1182, permute_370); permute_370 = None + index_put_47 = torch.ops.aten.index_put.default(full_default_1, [getitem_343], wait_tensor_511); wait_tensor_511 = None + view_1639 = torch.ops.aten.view.default(mul_1144, [-1, 1, 6]); mul_1144 = None + view_1640 = torch.ops.aten.view.default(index_put_47, [-1, 6, 2048]); index_put_47 = None + convert_element_type_1331 = torch.ops.prims.convert_element_type.default(view_1640, torch.float32); view_1640 = None + bmm_23 = torch.ops.aten.bmm.default(view_1639, convert_element_type_1331) + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(bmm_23, torch.bfloat16); bmm_23 = None + squeeze_23 = torch.ops.aten.squeeze.dim(convert_element_type_1332, 1); convert_element_type_1332 = None + add_1636 = torch.ops.aten.add.Tensor(mm_198, squeeze_23); mm_198 = squeeze_23 = None + view_1641 = torch.ops.aten.view.default(add_1636, [2, 4096, 2048]); add_1636 = None + add_1637 = torch.ops.aten.add.Tensor(add_1572, view_1641); view_1641 = None + convert_element_type_1333 = torch.ops.prims.convert_element_type.default(primals_407, torch.bfloat16) + all_gather_into_tensor_419 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1333, 64, '0'); convert_element_type_1333 = None + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_419); all_gather_into_tensor_419 = None + convert_element_type_1334 = torch.ops.prims.convert_element_type.default(add_1637, torch.float32) + pow_76 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1334, 2) + mean_75 = torch.ops.aten.mean.dim(pow_76, [2], True); pow_76 = None + add_1638 = torch.ops.aten.add.Scalar(mean_75, 1e-05); mean_75 = None + rsqrt_75 = torch.ops.aten.rsqrt.default(add_1638); add_1638 = None + mul_1185 = torch.ops.aten.mul.Tensor(convert_element_type_1334, rsqrt_75); convert_element_type_1334 = None + mul_1186 = torch.ops.aten.mul.Tensor(mul_1185, wait_tensor_515); mul_1185 = wait_tensor_515 = None + convert_element_type_1335 = torch.ops.prims.convert_element_type.default(mul_1186, torch.bfloat16); mul_1186 = None + convert_element_type_1336 = torch.ops.prims.convert_element_type.default(primals_408, torch.bfloat16) + all_gather_into_tensor_420 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1336, 64, '0'); convert_element_type_1336 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_420); all_gather_into_tensor_420 = None + permute_371 = torch.ops.aten.permute.default(wait_tensor_516, [1, 0]); wait_tensor_516 = None + view_1644 = torch.ops.aten.view.default(convert_element_type_1335, [8192, 2048]); convert_element_type_1335 = None + mm_199 = torch.ops.aten.mm.default(view_1644, permute_371); permute_371 = None + view_1645 = torch.ops.aten.view.default(mm_199, [2, 4096, 3072]); mm_199 = None + view_1646 = torch.ops.aten.view.default(view_1645, [2, 4096, -1, 192]); view_1645 = None + split_with_sizes_75 = torch.ops.aten.split_with_sizes.default(view_1646, [128, 64], -1); view_1646 = None + getitem_345 = split_with_sizes_75[0] + getitem_346 = split_with_sizes_75[1]; split_with_sizes_75 = None + convert_element_type_1339 = torch.ops.prims.convert_element_type.default(getitem_346, torch.float32); getitem_346 = None + view_1647 = torch.ops.aten.view.default(convert_element_type_1339, [2, 4096, 16, -1, 2]); convert_element_type_1339 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1647); view_1647 = None + mul_1187 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_7); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_1187); mul_1187 = None + view_1649 = torch.ops.aten.view.default(view_as_real_50, [2, 4096, 16, 64]); view_as_real_50 = None + convert_element_type_1340 = torch.ops.prims.convert_element_type.default(view_1649, torch.bfloat16); view_1649 = None + cat_74 = torch.ops.aten.cat.default([getitem_345, convert_element_type_1340], -1); getitem_345 = convert_element_type_1340 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(primals_409, torch.bfloat16) + all_gather_into_tensor_421 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1341, 64, '0'); convert_element_type_1341 = None + wait_tensor_517 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_421); all_gather_into_tensor_421 = None + permute_372 = torch.ops.aten.permute.default(wait_tensor_517, [1, 0]); wait_tensor_517 = None + mm_200 = torch.ops.aten.mm.default(view_1644, permute_372); permute_372 = None + view_1652 = torch.ops.aten.view.default(mm_200, [2, 4096, 576]); mm_200 = None + split_with_sizes_76 = torch.ops.aten.split_with_sizes.default(view_1652, [512, 64], -1); view_1652 = None + getitem_347 = split_with_sizes_76[0] + getitem_348 = split_with_sizes_76[1]; split_with_sizes_76 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(getitem_348, 2); getitem_348 = None + convert_element_type_1344 = torch.ops.prims.convert_element_type.default(unsqueeze_49, torch.float32); unsqueeze_49 = None + view_1653 = torch.ops.aten.view.default(convert_element_type_1344, [2, 4096, 1, -1, 2]); convert_element_type_1344 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1653); view_1653 = None + mul_1188 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_7); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_1188); mul_1188 = None + view_1655 = torch.ops.aten.view.default(view_as_real_51, [2, 4096, 1, 64]); view_as_real_51 = None + convert_element_type_1345 = torch.ops.prims.convert_element_type.default(view_1655, torch.bfloat16); view_1655 = None + convert_element_type_1346 = torch.ops.prims.convert_element_type.default(primals_410, torch.bfloat16) + all_gather_into_tensor_422 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1346, 64, '0'); convert_element_type_1346 = None + wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_422); all_gather_into_tensor_422 = None + convert_element_type_1347 = torch.ops.prims.convert_element_type.default(getitem_347, torch.float32) + pow_77 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1347, 2) + mean_76 = torch.ops.aten.mean.dim(pow_77, [2], True); pow_77 = None + add_1639 = torch.ops.aten.add.Scalar(mean_76, 1e-05); mean_76 = None + rsqrt_76 = torch.ops.aten.rsqrt.default(add_1639); add_1639 = None + mul_1189 = torch.ops.aten.mul.Tensor(convert_element_type_1347, rsqrt_76); convert_element_type_1347 = None + mul_1190 = torch.ops.aten.mul.Tensor(mul_1189, wait_tensor_518); mul_1189 = wait_tensor_518 = None + convert_element_type_1348 = torch.ops.prims.convert_element_type.default(mul_1190, torch.bfloat16); mul_1190 = None + convert_element_type_1349 = torch.ops.prims.convert_element_type.default(primals_411, torch.bfloat16) + all_gather_into_tensor_423 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1349, 64, '0'); convert_element_type_1349 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_423); all_gather_into_tensor_423 = None + permute_373 = torch.ops.aten.permute.default(wait_tensor_519, [1, 0]); wait_tensor_519 = None + view_1658 = torch.ops.aten.view.default(convert_element_type_1348, [8192, 512]); convert_element_type_1348 = None + mm_201 = torch.ops.aten.mm.default(view_1658, permute_373); permute_373 = None + view_1659 = torch.ops.aten.view.default(mm_201, [2, 4096, 4096]); mm_201 = None + view_1660 = torch.ops.aten.view.default(view_1659, [2, 4096, -1, 256]); view_1659 = None + split_with_sizes_77 = torch.ops.aten.split_with_sizes.default(view_1660, [128, 128], -1); view_1660 = None + getitem_349 = split_with_sizes_77[0] + getitem_350 = split_with_sizes_77[1]; split_with_sizes_77 = None + expand_25 = torch.ops.aten.expand.default(convert_element_type_1345, [-1, -1, 16, -1]); convert_element_type_1345 = None + cat_75 = torch.ops.aten.cat.default([getitem_349, expand_25], -1); getitem_349 = expand_25 = None + permute_374 = torch.ops.aten.permute.default(cat_74, [0, 2, 1, 3]); cat_74 = None + permute_375 = torch.ops.aten.permute.default(cat_75, [0, 2, 1, 3]); cat_75 = None + permute_376 = torch.ops.aten.permute.default(getitem_350, [0, 2, 1, 3]); getitem_350 = None + sdpa_score25 = self.sdpa_score25 + sdpa_mask25 = self.sdpa_mask25 + flex_attention_25 = torch.ops.higher_order.flex_attention(permute_374, permute_375, permute_376, sdpa_score25, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask25), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score25 = sdpa_mask25 = None + getitem_351 = flex_attention_25[0] + getitem_352 = flex_attention_25[1]; flex_attention_25 = None + permute_377 = torch.ops.aten.permute.default(getitem_351, [0, 2, 1, 3]) + view_1661 = torch.ops.aten.view.default(permute_377, [2, 4096, -1]); permute_377 = None + convert_element_type_1352 = torch.ops.prims.convert_element_type.default(primals_412, torch.bfloat16) + all_gather_into_tensor_424 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1352, 64, '0'); convert_element_type_1352 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_424); all_gather_into_tensor_424 = None + permute_378 = torch.ops.aten.permute.default(wait_tensor_520, [1, 0]); wait_tensor_520 = None + view_1663 = torch.ops.aten.view.default(view_1661, [8192, 2048]); view_1661 = None + mm_202 = torch.ops.aten.mm.default(view_1663, permute_378); view_1663 = permute_378 = None + view_1664 = torch.ops.aten.view.default(mm_202, [2, 4096, 2048]); mm_202 = None + add_1640 = torch.ops.aten.add.Tensor(add_1637, view_1664); view_1664 = None + convert_element_type_1355 = torch.ops.prims.convert_element_type.default(primals_413, torch.bfloat16) + all_gather_into_tensor_425 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1355, 64, '0'); convert_element_type_1355 = None + wait_tensor_521 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_425); all_gather_into_tensor_425 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(add_1640, torch.float32) + pow_78 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1356, 2) + mean_77 = torch.ops.aten.mean.dim(pow_78, [2], True); pow_78 = None + add_1641 = torch.ops.aten.add.Scalar(mean_77, 1e-05); mean_77 = None + rsqrt_77 = torch.ops.aten.rsqrt.default(add_1641); add_1641 = None + mul_1191 = torch.ops.aten.mul.Tensor(convert_element_type_1356, rsqrt_77); convert_element_type_1356 = None + mul_1192 = torch.ops.aten.mul.Tensor(mul_1191, wait_tensor_521); mul_1191 = wait_tensor_521 = None + convert_element_type_1357 = torch.ops.prims.convert_element_type.default(mul_1192, torch.bfloat16); mul_1192 = None + view_1666 = torch.ops.aten.view.default(convert_element_type_1357, [-1, 2048]); convert_element_type_1357 = None + convert_element_type_1358 = torch.ops.prims.convert_element_type.default(primals_415, torch.bfloat16) + all_gather_into_tensor_426 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1358, 64, '0'); convert_element_type_1358 = None + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_426); all_gather_into_tensor_426 = None + permute_379 = torch.ops.aten.permute.default(wait_tensor_522, [1, 0]); wait_tensor_522 = None + mm_203 = torch.ops.aten.mm.default(view_1666, permute_379); permute_379 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_203, torch.float32) + amax_24 = torch.ops.aten.amax.default(convert_element_type_1361, [1], True) + sub_576 = torch.ops.aten.sub.Tensor(convert_element_type_1361, amax_24); convert_element_type_1361 = None + exp_73 = torch.ops.aten.exp.default(sub_576); sub_576 = None + sum_97 = torch.ops.aten.sum.dim_IntList(exp_73, [1], True) + div_121 = torch.ops.aten.div.Tensor(exp_73, sum_97); exp_73 = None + add_1642 = torch.ops.aten.add.Tensor(div_121, primals_414); primals_414 = None + topk_24 = torch.ops.aten.topk.default(add_1642, 6, -1, True, False); add_1642 = None + getitem_355 = topk_24[1]; topk_24 = None + gather_24 = torch.ops.aten.gather.default(div_121, 1, getitem_355); div_121 = None + mul_1193 = torch.ops.aten.mul.Tensor(gather_24, 1.0); gather_24 = None + view_1668 = torch.ops.aten.view.default(getitem_355, [-1]) + histc_48 = torch.ops.aten.histc.default(view_1668, 64, 0, 64) + add_1643 = torch.ops.aten.add.Tensor(primals_416, histc_48) + sort_24 = torch.ops.aten.sort.stable(view_1668, stable = True); view_1668 = None + getitem_357 = sort_24[1]; sort_24 = None + div_122 = torch.ops.aten.div.Tensor_mode(getitem_357, 6, rounding_mode = 'floor') + index_48 = torch.ops.aten.index.Tensor(view_1666, [div_122]) + all_to_all_single_72 = torch.ops._c10d_functional.all_to_all_single.default(histc_48, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_523 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_72); all_to_all_single_72 = None + wait_tensor_524 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_523); wait_tensor_523 = None + view_1672 = torch.ops.aten.view.default(histc_48, [8, -1]); histc_48 = None + sum_98 = torch.ops.aten.sum.dim_IntList(view_1672, [1]); view_1672 = None + device_put_48 = torch.ops.prims.device_put.default(sum_98, device(type='cpu'), True); sum_98 = None + view_1673 = torch.ops.aten.view.default(wait_tensor_524, [8, -1]) + sum_99 = torch.ops.aten.sum.dim_IntList(view_1673, [1]) + device_put_49 = torch.ops.prims.device_put.default(sum_99, device(type='cpu')); sum_99 = None + select_384 = torch.ops.aten.select.int(device_put_48, 0, 0) + _local_scalar_dense_384 = torch.ops.aten._local_scalar_dense.default(select_384); select_384 = None + ge_480 = _local_scalar_dense_384 >= 0 + _assert_scalar_384 = torch.ops.aten._assert_scalar.default(ge_480, "Runtime assertion failed for expression u384 >= 0 on node 'ge_384'"); ge_480 = _assert_scalar_384 = None + select_385 = torch.ops.aten.select.int(device_put_48, 0, 1) + _local_scalar_dense_385 = torch.ops.aten._local_scalar_dense.default(select_385); select_385 = None + ge_481 = _local_scalar_dense_385 >= 0 + _assert_scalar_385 = torch.ops.aten._assert_scalar.default(ge_481, "Runtime assertion failed for expression u385 >= 0 on node 'ge_385'"); ge_481 = _assert_scalar_385 = None + select_386 = torch.ops.aten.select.int(device_put_48, 0, 2) + _local_scalar_dense_386 = torch.ops.aten._local_scalar_dense.default(select_386); select_386 = None + ge_482 = _local_scalar_dense_386 >= 0 + _assert_scalar_386 = torch.ops.aten._assert_scalar.default(ge_482, "Runtime assertion failed for expression u386 >= 0 on node 'ge_386'"); ge_482 = _assert_scalar_386 = None + select_387 = torch.ops.aten.select.int(device_put_48, 0, 3) + _local_scalar_dense_387 = torch.ops.aten._local_scalar_dense.default(select_387); select_387 = None + ge_483 = _local_scalar_dense_387 >= 0 + _assert_scalar_387 = torch.ops.aten._assert_scalar.default(ge_483, "Runtime assertion failed for expression u387 >= 0 on node 'ge_387'"); ge_483 = _assert_scalar_387 = None + select_388 = torch.ops.aten.select.int(device_put_48, 0, 4) + _local_scalar_dense_388 = torch.ops.aten._local_scalar_dense.default(select_388); select_388 = None + ge_484 = _local_scalar_dense_388 >= 0 + _assert_scalar_388 = torch.ops.aten._assert_scalar.default(ge_484, "Runtime assertion failed for expression u388 >= 0 on node 'ge_388'"); ge_484 = _assert_scalar_388 = None + select_389 = torch.ops.aten.select.int(device_put_48, 0, 5) + _local_scalar_dense_389 = torch.ops.aten._local_scalar_dense.default(select_389); select_389 = None + ge_485 = _local_scalar_dense_389 >= 0 + _assert_scalar_389 = torch.ops.aten._assert_scalar.default(ge_485, "Runtime assertion failed for expression u389 >= 0 on node 'ge_389'"); ge_485 = _assert_scalar_389 = None + select_390 = torch.ops.aten.select.int(device_put_48, 0, 6) + _local_scalar_dense_390 = torch.ops.aten._local_scalar_dense.default(select_390); select_390 = None + ge_486 = _local_scalar_dense_390 >= 0 + _assert_scalar_390 = torch.ops.aten._assert_scalar.default(ge_486, "Runtime assertion failed for expression u390 >= 0 on node 'ge_390'"); ge_486 = _assert_scalar_390 = None + select_391 = torch.ops.aten.select.int(device_put_48, 0, 7); device_put_48 = None + _local_scalar_dense_391 = torch.ops.aten._local_scalar_dense.default(select_391); select_391 = None + ge_487 = _local_scalar_dense_391 >= 0 + _assert_scalar_391 = torch.ops.aten._assert_scalar.default(ge_487, "Runtime assertion failed for expression u391 >= 0 on node 'ge_391'"); ge_487 = _assert_scalar_391 = None + select_392 = torch.ops.aten.select.int(device_put_49, 0, 0) + _local_scalar_dense_392 = torch.ops.aten._local_scalar_dense.default(select_392); select_392 = None + ge_488 = _local_scalar_dense_392 >= 0 + _assert_scalar_392 = torch.ops.aten._assert_scalar.default(ge_488, "Runtime assertion failed for expression u392 >= 0 on node 'ge_392'"); ge_488 = _assert_scalar_392 = None + select_393 = torch.ops.aten.select.int(device_put_49, 0, 1) + _local_scalar_dense_393 = torch.ops.aten._local_scalar_dense.default(select_393); select_393 = None + ge_489 = _local_scalar_dense_393 >= 0 + _assert_scalar_393 = torch.ops.aten._assert_scalar.default(ge_489, "Runtime assertion failed for expression u393 >= 0 on node 'ge_393'"); ge_489 = _assert_scalar_393 = None + select_394 = torch.ops.aten.select.int(device_put_49, 0, 2) + _local_scalar_dense_394 = torch.ops.aten._local_scalar_dense.default(select_394); select_394 = None + ge_490 = _local_scalar_dense_394 >= 0 + _assert_scalar_394 = torch.ops.aten._assert_scalar.default(ge_490, "Runtime assertion failed for expression u394 >= 0 on node 'ge_394'"); ge_490 = _assert_scalar_394 = None + select_395 = torch.ops.aten.select.int(device_put_49, 0, 3) + _local_scalar_dense_395 = torch.ops.aten._local_scalar_dense.default(select_395); select_395 = None + ge_491 = _local_scalar_dense_395 >= 0 + _assert_scalar_395 = torch.ops.aten._assert_scalar.default(ge_491, "Runtime assertion failed for expression u395 >= 0 on node 'ge_395'"); ge_491 = _assert_scalar_395 = None + select_396 = torch.ops.aten.select.int(device_put_49, 0, 4) + _local_scalar_dense_396 = torch.ops.aten._local_scalar_dense.default(select_396); select_396 = None + ge_492 = _local_scalar_dense_396 >= 0 + _assert_scalar_396 = torch.ops.aten._assert_scalar.default(ge_492, "Runtime assertion failed for expression u396 >= 0 on node 'ge_396'"); ge_492 = _assert_scalar_396 = None + select_397 = torch.ops.aten.select.int(device_put_49, 0, 5) + _local_scalar_dense_397 = torch.ops.aten._local_scalar_dense.default(select_397); select_397 = None + ge_493 = _local_scalar_dense_397 >= 0 + _assert_scalar_397 = torch.ops.aten._assert_scalar.default(ge_493, "Runtime assertion failed for expression u397 >= 0 on node 'ge_397'"); ge_493 = _assert_scalar_397 = None + select_398 = torch.ops.aten.select.int(device_put_49, 0, 6) + _local_scalar_dense_398 = torch.ops.aten._local_scalar_dense.default(select_398); select_398 = None + ge_494 = _local_scalar_dense_398 >= 0 + _assert_scalar_398 = torch.ops.aten._assert_scalar.default(ge_494, "Runtime assertion failed for expression u398 >= 0 on node 'ge_398'"); ge_494 = _assert_scalar_398 = None + select_399 = torch.ops.aten.select.int(device_put_49, 0, 7); device_put_49 = None + _local_scalar_dense_399 = torch.ops.aten._local_scalar_dense.default(select_399); select_399 = None + ge_495 = _local_scalar_dense_399 >= 0 + _assert_scalar_399 = torch.ops.aten._assert_scalar.default(ge_495, "Runtime assertion failed for expression u399 >= 0 on node 'ge_399'"); ge_495 = _assert_scalar_399 = None + all_to_all_single_73 = torch.ops._c10d_functional.all_to_all_single.default(index_48, [_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399], [_local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391], '521'); index_48 = None + sym_size_int_96 = torch.ops.aten.sym_size.int(all_to_all_single_73, 0) + wait_tensor_525 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_73); all_to_all_single_73 = None + sym_sum_48 = torch.sym_sum((_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399)) + add_1650 = sym_sum_48 + 64; sym_sum_48 = None + add_1651 = add_1650 + 8; add_1650 = None + sub_579 = add_1651 - 1; add_1651 = None + floordiv_24 = sub_579 // 8; sub_579 = None + mul_1198 = floordiv_24 * 8; floordiv_24 = None + cumsum_72 = torch.ops.aten.cumsum.default(wait_tensor_524, 0) + sub_580 = torch.ops.aten.sub.Tensor(cumsum_72, wait_tensor_524); cumsum_72 = None + sum_100 = torch.ops.aten.sum.dim_IntList(view_1673, [0]); view_1673 = None + clamp_min_24 = torch.ops.aten.clamp_min.default(sum_100, 8); sum_100 = None + add_1652 = torch.ops.aten.add.Tensor(clamp_min_24, 8); clamp_min_24 = None + sub_581 = torch.ops.aten.sub.Tensor(add_1652, 1); add_1652 = None + div_123 = torch.ops.aten.div.Tensor_mode(sub_581, 8, rounding_mode = 'floor'); sub_581 = None + mul_1199 = torch.ops.aten.mul.Tensor(div_123, 8); div_123 = None + convert_element_type_1364 = torch.ops.prims.convert_element_type.default(mul_1199, torch.int32); mul_1199 = None + cumsum_73 = torch.ops.aten.cumsum.default(convert_element_type_1364, 0) + sub_582 = torch.ops.aten.sub.Tensor(cumsum_73, convert_element_type_1364); cumsum_73 = None + full_332 = torch.ops.aten.full.default([mul_1198], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1198 = None + triton_kernel_wrapper_functional_proxy_24 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 24, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_524, 'start_index_values_ptr': sub_580, 'write_offsets_ptr': sub_582, 'output_ptr': full_332}, tensors_to_clone = ['output_ptr']); wait_tensor_524 = sub_580 = sub_582 = full_332 = None + getitem_358 = triton_kernel_wrapper_functional_proxy_24['output_ptr']; triton_kernel_wrapper_functional_proxy_24 = None + cat_76 = torch.ops.aten.cat.default([wait_tensor_525, full_default]); wait_tensor_525 = None + sym_size_int_97 = torch.ops.aten.sym_size.int(cat_76, 0) + sym_sum_49 = torch.sym_sum((1, _local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399)) + index_49 = torch.ops.aten.index.Tensor(cat_76, [getitem_358]); cat_76 = None + convert_element_type_1366 = torch.ops.prims.convert_element_type.default(primals_417, torch.bfloat16) + all_gather_into_tensor_427 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1366, 8, '513'); convert_element_type_1366 = None + wait_tensor_526 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_427); all_gather_into_tensor_427 = None + convert_element_type_1368 = torch.ops.prims.convert_element_type.default(primals_418, torch.bfloat16) + all_gather_into_tensor_429 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1368, 8, '513'); convert_element_type_1368 = None + wait_tensor_528 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_429); all_gather_into_tensor_429 = None + convert_element_type_1369 = torch.ops.prims.convert_element_type.default(primals_419, torch.bfloat16) + all_gather_into_tensor_430 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1369, 8, '513'); convert_element_type_1369 = None + wait_tensor_529 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_430); all_gather_into_tensor_430 = None + cumsum_74 = torch.ops.aten.cumsum.default(convert_element_type_1364, 0, dtype = torch.int32); convert_element_type_1364 = None + permute_380 = torch.ops.aten.permute.default(wait_tensor_526, [0, 2, 1]); wait_tensor_526 = None + _grouped_mm_72 = torch.ops.aten._grouped_mm.default(index_49, permute_380, cumsum_74); permute_380 = None + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(_grouped_mm_72, torch.float32) + neg_49 = torch.ops.aten.neg.default(convert_element_type_1372) + exp_74 = torch.ops.aten.exp.default(neg_49); neg_49 = None + add_1664 = torch.ops.aten.add.Tensor(exp_74, 1); exp_74 = None + div_124 = torch.ops.aten.div.Tensor(convert_element_type_1372, add_1664); convert_element_type_1372 = add_1664 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(div_124, torch.bfloat16); div_124 = None + permute_381 = torch.ops.aten.permute.default(wait_tensor_529, [0, 2, 1]); wait_tensor_529 = None + _grouped_mm_73 = torch.ops.aten._grouped_mm.default(index_49, permute_381, cumsum_74); permute_381 = None + mul_1211 = torch.ops.aten.mul.Tensor(convert_element_type_1373, _grouped_mm_73); convert_element_type_1373 = None + permute_382 = torch.ops.aten.permute.default(wait_tensor_528, [0, 2, 1]); wait_tensor_528 = None + _grouped_mm_74 = torch.ops.aten._grouped_mm.default(mul_1211, permute_382, cumsum_74); permute_382 = None + empty_24 = torch.ops.aten.empty.memory_format([sym_size_int_97, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_48 = torch.ops.aten.index_put.default(empty_24, [getitem_358], _grouped_mm_74); empty_24 = _grouped_mm_74 = None + slice_102 = torch.ops.aten.slice.Tensor(index_put_48, 0, 0, -1); index_put_48 = None + all_to_all_single_74 = torch.ops._c10d_functional.all_to_all_single.default(slice_102, [_local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391], [_local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399], '521'); slice_102 = None + wait_tensor_532 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_74); all_to_all_single_74 = None + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(primals_420, torch.bfloat16) + all_gather_into_tensor_433 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1374, 64, '0'); convert_element_type_1374 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_433); all_gather_into_tensor_433 = None + permute_383 = torch.ops.aten.permute.default(wait_tensor_533, [1, 0]); wait_tensor_533 = None + mm_204 = torch.ops.aten.mm.default(view_1666, permute_383); permute_383 = None + convert_element_type_1377 = torch.ops.prims.convert_element_type.default(mm_204, torch.float32) + neg_50 = torch.ops.aten.neg.default(convert_element_type_1377) + exp_75 = torch.ops.aten.exp.default(neg_50); neg_50 = None + add_1700 = torch.ops.aten.add.Tensor(exp_75, 1); exp_75 = None + div_125 = torch.ops.aten.div.Tensor(convert_element_type_1377, add_1700); convert_element_type_1377 = add_1700 = None + convert_element_type_1378 = torch.ops.prims.convert_element_type.default(div_125, torch.bfloat16); div_125 = None + convert_element_type_1379 = torch.ops.prims.convert_element_type.default(primals_421, torch.bfloat16) + all_gather_into_tensor_434 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1379, 64, '0'); convert_element_type_1379 = None + wait_tensor_534 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_434); all_gather_into_tensor_434 = None + permute_384 = torch.ops.aten.permute.default(wait_tensor_534, [1, 0]); wait_tensor_534 = None + mm_205 = torch.ops.aten.mm.default(view_1666, permute_384); permute_384 = None + mul_1231 = torch.ops.aten.mul.Tensor(convert_element_type_1378, mm_205); convert_element_type_1378 = None + convert_element_type_1382 = torch.ops.prims.convert_element_type.default(primals_422, torch.bfloat16) + all_gather_into_tensor_435 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1382, 64, '0'); convert_element_type_1382 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_435); all_gather_into_tensor_435 = None + permute_385 = torch.ops.aten.permute.default(wait_tensor_535, [1, 0]); wait_tensor_535 = None + mm_206 = torch.ops.aten.mm.default(mul_1231, permute_385); permute_385 = None + index_put_49 = torch.ops.aten.index_put.default(full_default_1, [getitem_357], wait_tensor_532); wait_tensor_532 = None + view_1706 = torch.ops.aten.view.default(mul_1193, [-1, 1, 6]); mul_1193 = None + view_1707 = torch.ops.aten.view.default(index_put_49, [-1, 6, 2048]); index_put_49 = None + convert_element_type_1385 = torch.ops.prims.convert_element_type.default(view_1707, torch.float32); view_1707 = None + bmm_24 = torch.ops.aten.bmm.default(view_1706, convert_element_type_1385) + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(bmm_24, torch.bfloat16); bmm_24 = None + squeeze_24 = torch.ops.aten.squeeze.dim(convert_element_type_1386, 1); convert_element_type_1386 = None + add_1704 = torch.ops.aten.add.Tensor(mm_206, squeeze_24); mm_206 = squeeze_24 = None + view_1708 = torch.ops.aten.view.default(add_1704, [2, 4096, 2048]); add_1704 = None + add_1705 = torch.ops.aten.add.Tensor(add_1640, view_1708); view_1708 = None + convert_element_type_1387 = torch.ops.prims.convert_element_type.default(primals_423, torch.bfloat16) + all_gather_into_tensor_436 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1387, 64, '0'); convert_element_type_1387 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_436); all_gather_into_tensor_436 = None + convert_element_type_1388 = torch.ops.prims.convert_element_type.default(add_1705, torch.float32) + pow_79 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1388, 2) + mean_78 = torch.ops.aten.mean.dim(pow_79, [2], True); pow_79 = None + add_1706 = torch.ops.aten.add.Scalar(mean_78, 1e-05); mean_78 = None + rsqrt_78 = torch.ops.aten.rsqrt.default(add_1706); add_1706 = None + mul_1234 = torch.ops.aten.mul.Tensor(convert_element_type_1388, rsqrt_78); convert_element_type_1388 = None + mul_1235 = torch.ops.aten.mul.Tensor(mul_1234, wait_tensor_536); mul_1234 = wait_tensor_536 = None + convert_element_type_1389 = torch.ops.prims.convert_element_type.default(mul_1235, torch.bfloat16); mul_1235 = None + convert_element_type_1390 = torch.ops.prims.convert_element_type.default(primals_424, torch.bfloat16) + all_gather_into_tensor_437 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1390, 64, '0'); convert_element_type_1390 = None + wait_tensor_537 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_437); all_gather_into_tensor_437 = None + permute_386 = torch.ops.aten.permute.default(wait_tensor_537, [1, 0]); wait_tensor_537 = None + view_1711 = torch.ops.aten.view.default(convert_element_type_1389, [8192, 2048]); convert_element_type_1389 = None + mm_207 = torch.ops.aten.mm.default(view_1711, permute_386); permute_386 = None + view_1712 = torch.ops.aten.view.default(mm_207, [2, 4096, 3072]); mm_207 = None + view_1713 = torch.ops.aten.view.default(view_1712, [2, 4096, -1, 192]); view_1712 = None + split_with_sizes_78 = torch.ops.aten.split_with_sizes.default(view_1713, [128, 64], -1); view_1713 = None + getitem_359 = split_with_sizes_78[0] + getitem_360 = split_with_sizes_78[1]; split_with_sizes_78 = None + convert_element_type_1393 = torch.ops.prims.convert_element_type.default(getitem_360, torch.float32); getitem_360 = None + view_1714 = torch.ops.aten.view.default(convert_element_type_1393, [2, 4096, 16, -1, 2]); convert_element_type_1393 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_1714); view_1714 = None + mul_1236 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_7); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_1236); mul_1236 = None + view_1716 = torch.ops.aten.view.default(view_as_real_52, [2, 4096, 16, 64]); view_as_real_52 = None + convert_element_type_1394 = torch.ops.prims.convert_element_type.default(view_1716, torch.bfloat16); view_1716 = None + cat_77 = torch.ops.aten.cat.default([getitem_359, convert_element_type_1394], -1); getitem_359 = convert_element_type_1394 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(primals_425, torch.bfloat16) + all_gather_into_tensor_438 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1395, 64, '0'); convert_element_type_1395 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_438); all_gather_into_tensor_438 = None + permute_387 = torch.ops.aten.permute.default(wait_tensor_538, [1, 0]); wait_tensor_538 = None + mm_208 = torch.ops.aten.mm.default(view_1711, permute_387); permute_387 = None + view_1719 = torch.ops.aten.view.default(mm_208, [2, 4096, 576]); mm_208 = None + split_with_sizes_79 = torch.ops.aten.split_with_sizes.default(view_1719, [512, 64], -1); view_1719 = None + getitem_361 = split_with_sizes_79[0] + getitem_362 = split_with_sizes_79[1]; split_with_sizes_79 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(getitem_362, 2); getitem_362 = None + convert_element_type_1398 = torch.ops.prims.convert_element_type.default(unsqueeze_51, torch.float32); unsqueeze_51 = None + view_1720 = torch.ops.aten.view.default(convert_element_type_1398, [2, 4096, 1, -1, 2]); convert_element_type_1398 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1720); view_1720 = None + mul_1237 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_7); view_as_complex_53 = view_7 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_1237); mul_1237 = None + view_1722 = torch.ops.aten.view.default(view_as_real_53, [2, 4096, 1, 64]); view_as_real_53 = None + convert_element_type_1399 = torch.ops.prims.convert_element_type.default(view_1722, torch.bfloat16); view_1722 = None + convert_element_type_1400 = torch.ops.prims.convert_element_type.default(primals_426, torch.bfloat16) + all_gather_into_tensor_439 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1400, 64, '0'); convert_element_type_1400 = None + wait_tensor_539 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_439); all_gather_into_tensor_439 = None + convert_element_type_1401 = torch.ops.prims.convert_element_type.default(getitem_361, torch.float32) + pow_80 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1401, 2) + mean_79 = torch.ops.aten.mean.dim(pow_80, [2], True); pow_80 = None + add_1707 = torch.ops.aten.add.Scalar(mean_79, 1e-05); mean_79 = None + rsqrt_79 = torch.ops.aten.rsqrt.default(add_1707); add_1707 = None + mul_1238 = torch.ops.aten.mul.Tensor(convert_element_type_1401, rsqrt_79); convert_element_type_1401 = None + mul_1239 = torch.ops.aten.mul.Tensor(mul_1238, wait_tensor_539); mul_1238 = wait_tensor_539 = None + convert_element_type_1402 = torch.ops.prims.convert_element_type.default(mul_1239, torch.bfloat16); mul_1239 = None + convert_element_type_1403 = torch.ops.prims.convert_element_type.default(primals_427, torch.bfloat16) + all_gather_into_tensor_440 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1403, 64, '0'); convert_element_type_1403 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_440); all_gather_into_tensor_440 = None + permute_388 = torch.ops.aten.permute.default(wait_tensor_540, [1, 0]); wait_tensor_540 = None + view_1725 = torch.ops.aten.view.default(convert_element_type_1402, [8192, 512]); convert_element_type_1402 = None + mm_209 = torch.ops.aten.mm.default(view_1725, permute_388); permute_388 = None + view_1726 = torch.ops.aten.view.default(mm_209, [2, 4096, 4096]); mm_209 = None + view_1727 = torch.ops.aten.view.default(view_1726, [2, 4096, -1, 256]); view_1726 = None + split_with_sizes_80 = torch.ops.aten.split_with_sizes.default(view_1727, [128, 128], -1); view_1727 = None + getitem_363 = split_with_sizes_80[0] + getitem_364 = split_with_sizes_80[1]; split_with_sizes_80 = None + expand_26 = torch.ops.aten.expand.default(convert_element_type_1399, [-1, -1, 16, -1]); convert_element_type_1399 = None + cat_78 = torch.ops.aten.cat.default([getitem_363, expand_26], -1); getitem_363 = expand_26 = None + permute_389 = torch.ops.aten.permute.default(cat_77, [0, 2, 1, 3]); cat_77 = None + permute_390 = torch.ops.aten.permute.default(cat_78, [0, 2, 1, 3]); cat_78 = None + permute_391 = torch.ops.aten.permute.default(getitem_364, [0, 2, 1, 3]); getitem_364 = None + sdpa_score26 = self.sdpa_score26 + sdpa_mask26 = self.sdpa_mask26 + flex_attention_26 = torch.ops.higher_order.flex_attention(permute_389, permute_390, permute_391, sdpa_score26, (4096, 4096, primals_10, primals_9, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, 128, 128, sdpa_mask26), 0.07216878364870322, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (primals_11,)); sdpa_score26 = sdpa_mask26 = None + getitem_365 = flex_attention_26[0] + getitem_366 = flex_attention_26[1]; flex_attention_26 = None + permute_392 = torch.ops.aten.permute.default(getitem_365, [0, 2, 1, 3]) + view_1728 = torch.ops.aten.view.default(permute_392, [2, 4096, -1]); permute_392 = None + convert_element_type_1406 = torch.ops.prims.convert_element_type.default(primals_428, torch.bfloat16) + all_gather_into_tensor_441 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1406, 64, '0'); convert_element_type_1406 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_441); all_gather_into_tensor_441 = None + permute_393 = torch.ops.aten.permute.default(wait_tensor_541, [1, 0]); wait_tensor_541 = None + view_1730 = torch.ops.aten.view.default(view_1728, [8192, 2048]); view_1728 = None + mm_210 = torch.ops.aten.mm.default(view_1730, permute_393); view_1730 = permute_393 = None + view_1731 = torch.ops.aten.view.default(mm_210, [2, 4096, 2048]); mm_210 = None + add_1708 = torch.ops.aten.add.Tensor(add_1705, view_1731); view_1731 = None + convert_element_type_1409 = torch.ops.prims.convert_element_type.default(primals_429, torch.bfloat16) + all_gather_into_tensor_442 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1409, 64, '0'); convert_element_type_1409 = None + wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_442); all_gather_into_tensor_442 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(add_1708, torch.float32) + pow_81 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1410, 2) + mean_80 = torch.ops.aten.mean.dim(pow_81, [2], True); pow_81 = None + add_1709 = torch.ops.aten.add.Scalar(mean_80, 1e-05); mean_80 = None + rsqrt_80 = torch.ops.aten.rsqrt.default(add_1709); add_1709 = None + mul_1240 = torch.ops.aten.mul.Tensor(convert_element_type_1410, rsqrt_80); convert_element_type_1410 = None + mul_1241 = torch.ops.aten.mul.Tensor(mul_1240, wait_tensor_542); mul_1240 = wait_tensor_542 = None + convert_element_type_1411 = torch.ops.prims.convert_element_type.default(mul_1241, torch.bfloat16); mul_1241 = None + view_1733 = torch.ops.aten.view.default(convert_element_type_1411, [-1, 2048]); convert_element_type_1411 = None + convert_element_type_1412 = torch.ops.prims.convert_element_type.default(primals_431, torch.bfloat16) + all_gather_into_tensor_443 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1412, 64, '0'); convert_element_type_1412 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_443); all_gather_into_tensor_443 = None + permute_394 = torch.ops.aten.permute.default(wait_tensor_543, [1, 0]); wait_tensor_543 = None + mm_211 = torch.ops.aten.mm.default(view_1733, permute_394); permute_394 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_211, torch.float32) + amax_25 = torch.ops.aten.amax.default(convert_element_type_1415, [1], True) + sub_600 = torch.ops.aten.sub.Tensor(convert_element_type_1415, amax_25); convert_element_type_1415 = None + exp_76 = torch.ops.aten.exp.default(sub_600); sub_600 = None + sum_101 = torch.ops.aten.sum.dim_IntList(exp_76, [1], True) + div_126 = torch.ops.aten.div.Tensor(exp_76, sum_101); exp_76 = None + add_1710 = torch.ops.aten.add.Tensor(div_126, primals_430); primals_430 = None + topk_25 = torch.ops.aten.topk.default(add_1710, 6, -1, True, False); add_1710 = None + getitem_369 = topk_25[1]; topk_25 = None + gather_25 = torch.ops.aten.gather.default(div_126, 1, getitem_369); div_126 = None + mul_1242 = torch.ops.aten.mul.Tensor(gather_25, 1.0); gather_25 = None + view_1735 = torch.ops.aten.view.default(getitem_369, [-1]) + histc_50 = torch.ops.aten.histc.default(view_1735, 64, 0, 64) + add_1711 = torch.ops.aten.add.Tensor(primals_432, histc_50) + sort_25 = torch.ops.aten.sort.stable(view_1735, stable = True); view_1735 = None + getitem_371 = sort_25[1]; sort_25 = None + div_127 = torch.ops.aten.div.Tensor_mode(getitem_371, 6, rounding_mode = 'floor') + index_50 = torch.ops.aten.index.Tensor(view_1733, [div_127]) + all_to_all_single_75 = torch.ops._c10d_functional.all_to_all_single.default(histc_50, [8, 8, 8, 8, 8, 8, 8, 8], [8, 8, 8, 8, 8, 8, 8, 8], '521') + wait_tensor_544 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_75); all_to_all_single_75 = None + wait_tensor_545 = torch.ops._c10d_functional.wait_tensor.default(wait_tensor_544); wait_tensor_544 = None + view_1739 = torch.ops.aten.view.default(histc_50, [8, -1]); histc_50 = None + sum_102 = torch.ops.aten.sum.dim_IntList(view_1739, [1]); view_1739 = None + device_put_50 = torch.ops.prims.device_put.default(sum_102, device(type='cpu'), True); sum_102 = None + view_1740 = torch.ops.aten.view.default(wait_tensor_545, [8, -1]) + sum_103 = torch.ops.aten.sum.dim_IntList(view_1740, [1]) + device_put_51 = torch.ops.prims.device_put.default(sum_103, device(type='cpu')); sum_103 = None + select_400 = torch.ops.aten.select.int(device_put_50, 0, 0) + _local_scalar_dense_400 = torch.ops.aten._local_scalar_dense.default(select_400); select_400 = None + ge_500 = _local_scalar_dense_400 >= 0 + _assert_scalar_400 = torch.ops.aten._assert_scalar.default(ge_500, "Runtime assertion failed for expression u400 >= 0 on node 'ge_400'"); ge_500 = _assert_scalar_400 = None + select_401 = torch.ops.aten.select.int(device_put_50, 0, 1) + _local_scalar_dense_401 = torch.ops.aten._local_scalar_dense.default(select_401); select_401 = None + ge_501 = _local_scalar_dense_401 >= 0 + _assert_scalar_401 = torch.ops.aten._assert_scalar.default(ge_501, "Runtime assertion failed for expression u401 >= 0 on node 'ge_401'"); ge_501 = _assert_scalar_401 = None + select_402 = torch.ops.aten.select.int(device_put_50, 0, 2) + _local_scalar_dense_402 = torch.ops.aten._local_scalar_dense.default(select_402); select_402 = None + ge_502 = _local_scalar_dense_402 >= 0 + _assert_scalar_402 = torch.ops.aten._assert_scalar.default(ge_502, "Runtime assertion failed for expression u402 >= 0 on node 'ge_402'"); ge_502 = _assert_scalar_402 = None + select_403 = torch.ops.aten.select.int(device_put_50, 0, 3) + _local_scalar_dense_403 = torch.ops.aten._local_scalar_dense.default(select_403); select_403 = None + ge_503 = _local_scalar_dense_403 >= 0 + _assert_scalar_403 = torch.ops.aten._assert_scalar.default(ge_503, "Runtime assertion failed for expression u403 >= 0 on node 'ge_403'"); ge_503 = _assert_scalar_403 = None + select_404 = torch.ops.aten.select.int(device_put_50, 0, 4) + _local_scalar_dense_404 = torch.ops.aten._local_scalar_dense.default(select_404); select_404 = None + ge_504 = _local_scalar_dense_404 >= 0 + _assert_scalar_404 = torch.ops.aten._assert_scalar.default(ge_504, "Runtime assertion failed for expression u404 >= 0 on node 'ge_404'"); ge_504 = _assert_scalar_404 = None + select_405 = torch.ops.aten.select.int(device_put_50, 0, 5) + _local_scalar_dense_405 = torch.ops.aten._local_scalar_dense.default(select_405); select_405 = None + ge_505 = _local_scalar_dense_405 >= 0 + _assert_scalar_405 = torch.ops.aten._assert_scalar.default(ge_505, "Runtime assertion failed for expression u405 >= 0 on node 'ge_405'"); ge_505 = _assert_scalar_405 = None + select_406 = torch.ops.aten.select.int(device_put_50, 0, 6) + _local_scalar_dense_406 = torch.ops.aten._local_scalar_dense.default(select_406); select_406 = None + ge_506 = _local_scalar_dense_406 >= 0 + _assert_scalar_406 = torch.ops.aten._assert_scalar.default(ge_506, "Runtime assertion failed for expression u406 >= 0 on node 'ge_406'"); ge_506 = _assert_scalar_406 = None + select_407 = torch.ops.aten.select.int(device_put_50, 0, 7); device_put_50 = None + _local_scalar_dense_407 = torch.ops.aten._local_scalar_dense.default(select_407); select_407 = None + ge_507 = _local_scalar_dense_407 >= 0 + _assert_scalar_407 = torch.ops.aten._assert_scalar.default(ge_507, "Runtime assertion failed for expression u407 >= 0 on node 'ge_407'"); ge_507 = _assert_scalar_407 = None + select_408 = torch.ops.aten.select.int(device_put_51, 0, 0) + _local_scalar_dense_408 = torch.ops.aten._local_scalar_dense.default(select_408); select_408 = None + ge_508 = _local_scalar_dense_408 >= 0 + _assert_scalar_408 = torch.ops.aten._assert_scalar.default(ge_508, "Runtime assertion failed for expression u408 >= 0 on node 'ge_408'"); ge_508 = _assert_scalar_408 = None + select_409 = torch.ops.aten.select.int(device_put_51, 0, 1) + _local_scalar_dense_409 = torch.ops.aten._local_scalar_dense.default(select_409); select_409 = None + ge_509 = _local_scalar_dense_409 >= 0 + _assert_scalar_409 = torch.ops.aten._assert_scalar.default(ge_509, "Runtime assertion failed for expression u409 >= 0 on node 'ge_409'"); ge_509 = _assert_scalar_409 = None + select_410 = torch.ops.aten.select.int(device_put_51, 0, 2) + _local_scalar_dense_410 = torch.ops.aten._local_scalar_dense.default(select_410); select_410 = None + ge_510 = _local_scalar_dense_410 >= 0 + _assert_scalar_410 = torch.ops.aten._assert_scalar.default(ge_510, "Runtime assertion failed for expression u410 >= 0 on node 'ge_410'"); ge_510 = _assert_scalar_410 = None + select_411 = torch.ops.aten.select.int(device_put_51, 0, 3) + _local_scalar_dense_411 = torch.ops.aten._local_scalar_dense.default(select_411); select_411 = None + ge_511 = _local_scalar_dense_411 >= 0 + _assert_scalar_411 = torch.ops.aten._assert_scalar.default(ge_511, "Runtime assertion failed for expression u411 >= 0 on node 'ge_411'"); ge_511 = _assert_scalar_411 = None + select_412 = torch.ops.aten.select.int(device_put_51, 0, 4) + _local_scalar_dense_412 = torch.ops.aten._local_scalar_dense.default(select_412); select_412 = None + ge_512 = _local_scalar_dense_412 >= 0 + _assert_scalar_412 = torch.ops.aten._assert_scalar.default(ge_512, "Runtime assertion failed for expression u412 >= 0 on node 'ge_412'"); ge_512 = _assert_scalar_412 = None + select_413 = torch.ops.aten.select.int(device_put_51, 0, 5) + _local_scalar_dense_413 = torch.ops.aten._local_scalar_dense.default(select_413); select_413 = None + ge_513 = _local_scalar_dense_413 >= 0 + _assert_scalar_413 = torch.ops.aten._assert_scalar.default(ge_513, "Runtime assertion failed for expression u413 >= 0 on node 'ge_413'"); ge_513 = _assert_scalar_413 = None + select_414 = torch.ops.aten.select.int(device_put_51, 0, 6) + _local_scalar_dense_414 = torch.ops.aten._local_scalar_dense.default(select_414); select_414 = None + ge_514 = _local_scalar_dense_414 >= 0 + _assert_scalar_414 = torch.ops.aten._assert_scalar.default(ge_514, "Runtime assertion failed for expression u414 >= 0 on node 'ge_414'"); ge_514 = _assert_scalar_414 = None + select_415 = torch.ops.aten.select.int(device_put_51, 0, 7); device_put_51 = None + _local_scalar_dense_415 = torch.ops.aten._local_scalar_dense.default(select_415); select_415 = None + ge_515 = _local_scalar_dense_415 >= 0 + _assert_scalar_415 = torch.ops.aten._assert_scalar.default(ge_515, "Runtime assertion failed for expression u415 >= 0 on node 'ge_415'"); ge_515 = _assert_scalar_415 = None + all_to_all_single_76 = torch.ops._c10d_functional.all_to_all_single.default(index_50, [_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415], [_local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407], '521'); index_50 = None + sym_size_int_100 = torch.ops.aten.sym_size.int(all_to_all_single_76, 0) + wait_tensor_546 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_76); all_to_all_single_76 = None + sym_sum_50 = torch.sym_sum((_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415)) + add_1718 = sym_sum_50 + 64; sym_sum_50 = None + add_1719 = add_1718 + 8; add_1718 = None + sub_603 = add_1719 - 1; add_1719 = None + floordiv_25 = sub_603 // 8; sub_603 = None + mul_1247 = floordiv_25 * 8; floordiv_25 = None + cumsum_75 = torch.ops.aten.cumsum.default(wait_tensor_545, 0) + sub_604 = torch.ops.aten.sub.Tensor(cumsum_75, wait_tensor_545); cumsum_75 = None + sum_104 = torch.ops.aten.sum.dim_IntList(view_1740, [0]); view_1740 = None + clamp_min_25 = torch.ops.aten.clamp_min.default(sum_104, 8); sum_104 = None + add_1720 = torch.ops.aten.add.Tensor(clamp_min_25, 8); clamp_min_25 = None + sub_605 = torch.ops.aten.sub.Tensor(add_1720, 1); add_1720 = None + div_128 = torch.ops.aten.div.Tensor_mode(sub_605, 8, rounding_mode = 'floor'); sub_605 = None + mul_1248 = torch.ops.aten.mul.Tensor(div_128, 8); div_128 = None + convert_element_type_1418 = torch.ops.prims.convert_element_type.default(mul_1248, torch.int32); mul_1248 = None + cumsum_76 = torch.ops.aten.cumsum.default(convert_element_type_1418, 0) + sub_606 = torch.ops.aten.sub.Tensor(cumsum_76, convert_element_type_1418); cumsum_76 = None + full_345 = torch.ops.aten.full.default([mul_1247], -1, dtype = torch.int32, device = device(type='cuda', index=0), pin_memory = False); mul_1247 = None + triton_kernel_wrapper_functional_proxy_25 = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 25, grid = [(8, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'tokens_per_expert_group_ptr': wait_tensor_545, 'start_index_values_ptr': sub_604, 'write_offsets_ptr': sub_606, 'output_ptr': full_345}, tensors_to_clone = ['output_ptr']); wait_tensor_545 = sub_604 = sub_606 = full_345 = None + getitem_372 = triton_kernel_wrapper_functional_proxy_25['output_ptr']; triton_kernel_wrapper_functional_proxy_25 = None + cat_79 = torch.ops.aten.cat.default([wait_tensor_546, full_default]); wait_tensor_546 = full_default = None + sym_size_int_101 = torch.ops.aten.sym_size.int(cat_79, 0) + sym_sum_51 = torch.sym_sum((1, _local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415)) + index_51 = torch.ops.aten.index.Tensor(cat_79, [getitem_372]); cat_79 = None + convert_element_type_1420 = torch.ops.prims.convert_element_type.default(primals_433, torch.bfloat16) + all_gather_into_tensor_444 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1420, 8, '513'); convert_element_type_1420 = None + wait_tensor_547 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_444); all_gather_into_tensor_444 = None + convert_element_type_1422 = torch.ops.prims.convert_element_type.default(primals_434, torch.bfloat16) + all_gather_into_tensor_446 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1422, 8, '513'); convert_element_type_1422 = None + wait_tensor_549 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_446); all_gather_into_tensor_446 = None + convert_element_type_1423 = torch.ops.prims.convert_element_type.default(primals_435, torch.bfloat16) + all_gather_into_tensor_447 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1423, 8, '513'); convert_element_type_1423 = None + wait_tensor_550 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_447); all_gather_into_tensor_447 = None + cumsum_77 = torch.ops.aten.cumsum.default(convert_element_type_1418, 0, dtype = torch.int32); convert_element_type_1418 = None + permute_395 = torch.ops.aten.permute.default(wait_tensor_547, [0, 2, 1]); wait_tensor_547 = None + _grouped_mm_75 = torch.ops.aten._grouped_mm.default(index_51, permute_395, cumsum_77); permute_395 = None + convert_element_type_1426 = torch.ops.prims.convert_element_type.default(_grouped_mm_75, torch.float32) + neg_51 = torch.ops.aten.neg.default(convert_element_type_1426) + exp_77 = torch.ops.aten.exp.default(neg_51); neg_51 = None + add_1732 = torch.ops.aten.add.Tensor(exp_77, 1); exp_77 = None + div_129 = torch.ops.aten.div.Tensor(convert_element_type_1426, add_1732); convert_element_type_1426 = add_1732 = None + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(div_129, torch.bfloat16); div_129 = None + permute_396 = torch.ops.aten.permute.default(wait_tensor_550, [0, 2, 1]); wait_tensor_550 = None + _grouped_mm_76 = torch.ops.aten._grouped_mm.default(index_51, permute_396, cumsum_77); permute_396 = None + mul_1260 = torch.ops.aten.mul.Tensor(convert_element_type_1427, _grouped_mm_76); convert_element_type_1427 = None + permute_397 = torch.ops.aten.permute.default(wait_tensor_549, [0, 2, 1]); wait_tensor_549 = None + _grouped_mm_77 = torch.ops.aten._grouped_mm.default(mul_1260, permute_397, cumsum_77); permute_397 = None + empty_25 = torch.ops.aten.empty.memory_format([sym_size_int_101, 2048], dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_50 = torch.ops.aten.index_put.default(empty_25, [getitem_372], _grouped_mm_77); empty_25 = _grouped_mm_77 = None + slice_106 = torch.ops.aten.slice.Tensor(index_put_50, 0, 0, -1); index_put_50 = None + all_to_all_single_77 = torch.ops._c10d_functional.all_to_all_single.default(slice_106, [_local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407], [_local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415], '521'); slice_106 = None + wait_tensor_553 = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_77); all_to_all_single_77 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(primals_436, torch.bfloat16) + all_gather_into_tensor_450 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1428, 64, '0'); convert_element_type_1428 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_450); all_gather_into_tensor_450 = None + permute_398 = torch.ops.aten.permute.default(wait_tensor_554, [1, 0]); wait_tensor_554 = None + mm_212 = torch.ops.aten.mm.default(view_1733, permute_398); permute_398 = None + convert_element_type_1431 = torch.ops.prims.convert_element_type.default(mm_212, torch.float32) + neg_52 = torch.ops.aten.neg.default(convert_element_type_1431) + exp_78 = torch.ops.aten.exp.default(neg_52); neg_52 = None + add_1768 = torch.ops.aten.add.Tensor(exp_78, 1); exp_78 = None + div_130 = torch.ops.aten.div.Tensor(convert_element_type_1431, add_1768); convert_element_type_1431 = add_1768 = None + convert_element_type_1432 = torch.ops.prims.convert_element_type.default(div_130, torch.bfloat16); div_130 = None + convert_element_type_1433 = torch.ops.prims.convert_element_type.default(primals_437, torch.bfloat16) + all_gather_into_tensor_451 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1433, 64, '0'); convert_element_type_1433 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_451); all_gather_into_tensor_451 = None + permute_399 = torch.ops.aten.permute.default(wait_tensor_555, [1, 0]); wait_tensor_555 = None + mm_213 = torch.ops.aten.mm.default(view_1733, permute_399); permute_399 = None + mul_1280 = torch.ops.aten.mul.Tensor(convert_element_type_1432, mm_213); convert_element_type_1432 = None + convert_element_type_1436 = torch.ops.prims.convert_element_type.default(primals_438, torch.bfloat16) + all_gather_into_tensor_452 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1436, 64, '0'); convert_element_type_1436 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_452); all_gather_into_tensor_452 = None + permute_400 = torch.ops.aten.permute.default(wait_tensor_556, [1, 0]); wait_tensor_556 = None + mm_214 = torch.ops.aten.mm.default(mul_1280, permute_400); permute_400 = None + index_put_51 = torch.ops.aten.index_put.default(full_default_1, [getitem_371], wait_tensor_553); full_default_1 = wait_tensor_553 = None + view_1773 = torch.ops.aten.view.default(mul_1242, [-1, 1, 6]); mul_1242 = None + view_1774 = torch.ops.aten.view.default(index_put_51, [-1, 6, 2048]); index_put_51 = None + convert_element_type_1439 = torch.ops.prims.convert_element_type.default(view_1774, torch.float32); view_1774 = None + bmm_25 = torch.ops.aten.bmm.default(view_1773, convert_element_type_1439) + convert_element_type_1440 = torch.ops.prims.convert_element_type.default(bmm_25, torch.bfloat16); bmm_25 = None + squeeze_25 = torch.ops.aten.squeeze.dim(convert_element_type_1440, 1); convert_element_type_1440 = None + add_1772 = torch.ops.aten.add.Tensor(mm_214, squeeze_25); mm_214 = squeeze_25 = None + view_1775 = torch.ops.aten.view.default(add_1772, [2, 4096, 2048]); add_1772 = None + add_1773 = torch.ops.aten.add.Tensor(add_1708, view_1775); view_1775 = None + convert_element_type_1441 = torch.ops.prims.convert_element_type.default(primals_439, torch.bfloat16) + all_gather_into_tensor_453 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1441, 64, '0'); convert_element_type_1441 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_453); all_gather_into_tensor_453 = None + convert_element_type_1442 = torch.ops.prims.convert_element_type.default(add_1773, torch.float32) + pow_82 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1442, 2) + mean_81 = torch.ops.aten.mean.dim(pow_82, [2], True); pow_82 = None + add_1774 = torch.ops.aten.add.Scalar(mean_81, 1.1920928955078125e-07); mean_81 = None + rsqrt_81 = torch.ops.aten.rsqrt.default(add_1774); add_1774 = None + mul_1283 = torch.ops.aten.mul.Tensor(convert_element_type_1442, rsqrt_81); convert_element_type_1442 = None + mul_1284 = torch.ops.aten.mul.Tensor(mul_1283, wait_tensor_557); mul_1283 = wait_tensor_557 = None + convert_element_type_1443 = torch.ops.prims.convert_element_type.default(mul_1284, torch.bfloat16); mul_1284 = None + convert_element_type_1444 = torch.ops.prims.convert_element_type.default(primals_440, torch.bfloat16) + all_gather_into_tensor_454 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1444, 64, '0'); convert_element_type_1444 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_454); all_gather_into_tensor_454 = None + permute_401 = torch.ops.aten.permute.default(wait_tensor_558, [1, 0]); wait_tensor_558 = None + view_1778 = torch.ops.aten.view.default(convert_element_type_1443, [8192, 2048]); convert_element_type_1443 = None + mm_215 = torch.ops.aten.mm.default(view_1778, permute_401); permute_401 = None + view_1779 = torch.ops.aten.view.default(mm_215, [2, 4096, 102400]); mm_215 = None + permute_406 = torch.ops.aten.permute.default(view_1773, [0, 2, 1]); view_1773 = None + permute_407 = torch.ops.aten.permute.default(convert_element_type_1439, [0, 2, 1]); convert_element_type_1439 = None + add_1781 = 0 + sym_size_int_100; sym_size_int_100 = None + permute_456 = torch.ops.aten.permute.default(view_1706, [0, 2, 1]); view_1706 = None + permute_457 = torch.ops.aten.permute.default(convert_element_type_1385, [0, 2, 1]); convert_element_type_1385 = None + add_1796 = 0 + sym_size_int_96; sym_size_int_96 = None + permute_506 = torch.ops.aten.permute.default(view_1639, [0, 2, 1]); view_1639 = None + permute_507 = torch.ops.aten.permute.default(convert_element_type_1331, [0, 2, 1]); convert_element_type_1331 = None + add_1811 = 0 + sym_size_int_92; sym_size_int_92 = None + permute_556 = torch.ops.aten.permute.default(view_1572, [0, 2, 1]); view_1572 = None + permute_557 = torch.ops.aten.permute.default(convert_element_type_1277, [0, 2, 1]); convert_element_type_1277 = None + add_1826 = 0 + sym_size_int_88; sym_size_int_88 = None + permute_606 = torch.ops.aten.permute.default(view_1505, [0, 2, 1]); view_1505 = None + permute_607 = torch.ops.aten.permute.default(convert_element_type_1223, [0, 2, 1]); convert_element_type_1223 = None + add_1841 = 0 + sym_size_int_84; sym_size_int_84 = None + permute_656 = torch.ops.aten.permute.default(view_1438, [0, 2, 1]); view_1438 = None + permute_657 = torch.ops.aten.permute.default(convert_element_type_1169, [0, 2, 1]); convert_element_type_1169 = None + add_1856 = 0 + sym_size_int_80; sym_size_int_80 = None + permute_706 = torch.ops.aten.permute.default(view_1371, [0, 2, 1]); view_1371 = None + permute_707 = torch.ops.aten.permute.default(convert_element_type_1115, [0, 2, 1]); convert_element_type_1115 = None + add_1871 = 0 + sym_size_int_76; sym_size_int_76 = None + permute_756 = torch.ops.aten.permute.default(view_1304, [0, 2, 1]); view_1304 = None + permute_757 = torch.ops.aten.permute.default(convert_element_type_1061, [0, 2, 1]); convert_element_type_1061 = None + add_1886 = 0 + sym_size_int_72; sym_size_int_72 = None + permute_806 = torch.ops.aten.permute.default(view_1237, [0, 2, 1]); view_1237 = None + permute_807 = torch.ops.aten.permute.default(convert_element_type_1007, [0, 2, 1]); convert_element_type_1007 = None + add_1901 = 0 + sym_size_int_68; sym_size_int_68 = None + permute_856 = torch.ops.aten.permute.default(view_1170, [0, 2, 1]); view_1170 = None + permute_857 = torch.ops.aten.permute.default(convert_element_type_953, [0, 2, 1]); convert_element_type_953 = None + add_1916 = 0 + sym_size_int_64; sym_size_int_64 = None + permute_906 = torch.ops.aten.permute.default(view_1103, [0, 2, 1]); view_1103 = None + permute_907 = torch.ops.aten.permute.default(convert_element_type_899, [0, 2, 1]); convert_element_type_899 = None + add_1931 = 0 + sym_size_int_60; sym_size_int_60 = None + permute_956 = torch.ops.aten.permute.default(view_1036, [0, 2, 1]); view_1036 = None + permute_957 = torch.ops.aten.permute.default(convert_element_type_845, [0, 2, 1]); convert_element_type_845 = None + add_1946 = 0 + sym_size_int_56; sym_size_int_56 = None + permute_1006 = torch.ops.aten.permute.default(view_969, [0, 2, 1]); view_969 = None + permute_1007 = torch.ops.aten.permute.default(convert_element_type_791, [0, 2, 1]); convert_element_type_791 = None + add_1961 = 0 + sym_size_int_52; sym_size_int_52 = None + permute_1056 = torch.ops.aten.permute.default(view_902, [0, 2, 1]); view_902 = None + permute_1057 = torch.ops.aten.permute.default(convert_element_type_737, [0, 2, 1]); convert_element_type_737 = None + add_1976 = 0 + sym_size_int_48; sym_size_int_48 = None + permute_1106 = torch.ops.aten.permute.default(view_835, [0, 2, 1]); view_835 = None + permute_1107 = torch.ops.aten.permute.default(convert_element_type_683, [0, 2, 1]); convert_element_type_683 = None + add_1991 = 0 + sym_size_int_44; sym_size_int_44 = None + permute_1156 = torch.ops.aten.permute.default(view_768, [0, 2, 1]); view_768 = None + permute_1157 = torch.ops.aten.permute.default(convert_element_type_629, [0, 2, 1]); convert_element_type_629 = None + add_2006 = 0 + sym_size_int_40; sym_size_int_40 = None + permute_1206 = torch.ops.aten.permute.default(view_701, [0, 2, 1]); view_701 = None + permute_1207 = torch.ops.aten.permute.default(convert_element_type_575, [0, 2, 1]); convert_element_type_575 = None + add_2021 = 0 + sym_size_int_36; sym_size_int_36 = None + permute_1256 = torch.ops.aten.permute.default(view_634, [0, 2, 1]); view_634 = None + permute_1257 = torch.ops.aten.permute.default(convert_element_type_521, [0, 2, 1]); convert_element_type_521 = None + add_2036 = 0 + sym_size_int_32; sym_size_int_32 = None + permute_1306 = torch.ops.aten.permute.default(view_567, [0, 2, 1]); view_567 = None + permute_1307 = torch.ops.aten.permute.default(convert_element_type_467, [0, 2, 1]); convert_element_type_467 = None + add_2051 = 0 + sym_size_int_28; sym_size_int_28 = None + permute_1356 = torch.ops.aten.permute.default(view_500, [0, 2, 1]); view_500 = None + permute_1357 = torch.ops.aten.permute.default(convert_element_type_413, [0, 2, 1]); convert_element_type_413 = None + add_2066 = 0 + sym_size_int_24; sym_size_int_24 = None + permute_1406 = torch.ops.aten.permute.default(view_433, [0, 2, 1]); view_433 = None + permute_1407 = torch.ops.aten.permute.default(convert_element_type_359, [0, 2, 1]); convert_element_type_359 = None + add_2081 = 0 + sym_size_int_20; sym_size_int_20 = None + permute_1456 = torch.ops.aten.permute.default(view_366, [0, 2, 1]); view_366 = None + permute_1457 = torch.ops.aten.permute.default(convert_element_type_305, [0, 2, 1]); convert_element_type_305 = None + add_2096 = 0 + sym_size_int_16; sym_size_int_16 = None + permute_1506 = torch.ops.aten.permute.default(view_299, [0, 2, 1]); view_299 = None + permute_1507 = torch.ops.aten.permute.default(convert_element_type_251, [0, 2, 1]); convert_element_type_251 = None + add_2111 = 0 + sym_size_int_12; sym_size_int_12 = None + permute_1556 = torch.ops.aten.permute.default(view_232, [0, 2, 1]); view_232 = None + permute_1557 = torch.ops.aten.permute.default(convert_element_type_197, [0, 2, 1]); convert_element_type_197 = None + add_2126 = 0 + sym_size_int_8; sym_size_int_8 = None + permute_1606 = torch.ops.aten.permute.default(view_165, [0, 2, 1]); view_165 = None + permute_1607 = torch.ops.aten.permute.default(convert_element_type_143, [0, 2, 1]); convert_element_type_143 = None + add_2141 = 0 + sym_size_int_4; sym_size_int_4 = None + permute_1656 = torch.ops.aten.permute.default(view_98, [0, 2, 1]); view_98 = None + permute_1657 = torch.ops.aten.permute.default(convert_element_type_89, [0, 2, 1]); convert_element_type_89 = None + add_2156 = 0 + sym_size_int; sym_size_int = None + copy_ = torch.ops.aten.copy_.default(primals_32, add_11); primals_32 = add_11 = copy_ = None + copy__1 = torch.ops.aten.copy_.default(primals_48, add_79); primals_48 = add_79 = copy__1 = None + copy__2 = torch.ops.aten.copy_.default(primals_64, add_147); primals_64 = add_147 = copy__2 = None + copy__3 = torch.ops.aten.copy_.default(primals_80, add_215); primals_80 = add_215 = copy__3 = None + copy__4 = torch.ops.aten.copy_.default(primals_96, add_283); primals_96 = add_283 = copy__4 = None + copy__5 = torch.ops.aten.copy_.default(primals_112, add_351); primals_112 = add_351 = copy__5 = None + copy__6 = torch.ops.aten.copy_.default(primals_128, add_419); primals_128 = add_419 = copy__6 = None + copy__7 = torch.ops.aten.copy_.default(primals_144, add_487); primals_144 = add_487 = copy__7 = None + copy__8 = torch.ops.aten.copy_.default(primals_160, add_555); primals_160 = add_555 = copy__8 = None + copy__9 = torch.ops.aten.copy_.default(primals_176, add_623); primals_176 = add_623 = copy__9 = None + copy__10 = torch.ops.aten.copy_.default(primals_192, add_691); primals_192 = add_691 = copy__10 = None + copy__11 = torch.ops.aten.copy_.default(primals_208, add_759); primals_208 = add_759 = copy__11 = None + copy__12 = torch.ops.aten.copy_.default(primals_224, add_827); primals_224 = add_827 = copy__12 = None + copy__13 = torch.ops.aten.copy_.default(primals_240, add_895); primals_240 = add_895 = copy__13 = None + copy__14 = torch.ops.aten.copy_.default(primals_256, add_963); primals_256 = add_963 = copy__14 = None + copy__15 = torch.ops.aten.copy_.default(primals_272, add_1031); primals_272 = add_1031 = copy__15 = None + copy__16 = torch.ops.aten.copy_.default(primals_288, add_1099); primals_288 = add_1099 = copy__16 = None + copy__17 = torch.ops.aten.copy_.default(primals_304, add_1167); primals_304 = add_1167 = copy__17 = None + copy__18 = torch.ops.aten.copy_.default(primals_320, add_1235); primals_320 = add_1235 = copy__18 = None + copy__19 = torch.ops.aten.copy_.default(primals_336, add_1303); primals_336 = add_1303 = copy__19 = None + copy__20 = torch.ops.aten.copy_.default(primals_352, add_1371); primals_352 = add_1371 = copy__20 = None + copy__21 = torch.ops.aten.copy_.default(primals_368, add_1439); primals_368 = add_1439 = copy__21 = None + copy__22 = torch.ops.aten.copy_.default(primals_384, add_1507); primals_384 = add_1507 = copy__22 = None + copy__23 = torch.ops.aten.copy_.default(primals_400, add_1575); primals_400 = add_1575 = copy__23 = None + copy__24 = torch.ops.aten.copy_.default(primals_416, add_1643); primals_416 = add_1643 = copy__24 = None + copy__25 = torch.ops.aten.copy_.default(primals_432, add_1711); primals_432 = add_1711 = copy__25 = None + return (view_1779, getitem_22, sym_sum_1, _local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7, getitem_36, sym_sum_3, _local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31, _local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23, getitem_50, sym_sum_5, _local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47, _local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39, getitem_64, sym_sum_7, _local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63, _local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55, getitem_78, sym_sum_9, _local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79, _local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71, getitem_92, sym_sum_11, _local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95, _local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87, getitem_106, sym_sum_13, _local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111, _local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103, getitem_120, sym_sum_15, _local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127, _local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119, getitem_134, sym_sum_17, _local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143, _local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135, getitem_148, sym_sum_19, _local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159, _local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151, getitem_162, sym_sum_21, _local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175, _local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167, getitem_176, sym_sum_23, _local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191, _local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183, getitem_190, sym_sum_25, _local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207, _local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199, getitem_204, sym_sum_27, _local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223, _local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215, getitem_218, sym_sum_29, _local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239, _local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231, getitem_232, sym_sum_31, _local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255, _local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247, getitem_246, sym_sum_33, _local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271, _local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263, getitem_260, sym_sum_35, _local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287, _local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279, getitem_274, sym_sum_37, _local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303, _local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295, getitem_288, sym_sum_39, _local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319, _local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311, getitem_302, sym_sum_41, _local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335, _local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327, getitem_316, sym_sum_43, _local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351, _local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343, getitem_330, sym_sum_45, _local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367, _local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359, getitem_344, sym_sum_47, _local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383, _local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375, getitem_358, sym_sum_49, _local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399, _local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391, getitem_372, sym_sum_51, _local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415, _local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_31, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_47, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_63, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_79, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_95, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_111, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_127, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_143, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_159, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_175, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_191, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_207, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_223, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_239, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_255, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_271, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_287, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_298, primals_299, primals_300, primals_301, primals_303, primals_305, primals_306, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_316, primals_317, primals_319, primals_321, primals_322, primals_323, primals_324, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, primals_335, primals_337, primals_338, primals_339, primals_340, primals_341, primals_342, primals_343, primals_344, primals_345, primals_346, primals_347, primals_348, primals_349, primals_351, primals_353, primals_354, primals_355, primals_356, primals_357, primals_358, primals_359, primals_360, primals_361, primals_362, primals_363, primals_364, primals_365, primals_367, primals_369, primals_370, primals_371, primals_372, primals_373, primals_374, primals_375, primals_376, primals_377, primals_378, primals_379, primals_380, primals_381, primals_383, primals_385, primals_386, primals_387, primals_388, primals_389, primals_390, primals_391, primals_392, primals_393, primals_394, primals_395, primals_396, primals_397, primals_399, primals_401, primals_402, primals_403, primals_404, primals_405, primals_406, primals_407, primals_408, primals_409, primals_410, primals_411, primals_412, primals_413, primals_415, primals_417, primals_418, primals_419, primals_420, primals_421, primals_422, primals_423, primals_424, primals_425, primals_426, primals_427, primals_428, primals_429, primals_431, primals_433, primals_434, primals_435, primals_436, primals_437, primals_438, primals_439, primals_440, embedding, rsqrt, view_3, getitem_2, rsqrt_1, view_17, permute_3, permute_4, permute_5, getitem_6, getitem_7, mm_3, rsqrt_2, view_26, mm_4, mm_5, view_32, add_5, rsqrt_3, view_36, getitem_11, rsqrt_4, view_50, permute_14, permute_15, permute_16, getitem_15, getitem_16, add_8, rsqrt_5, view_58, mm_11, amax, sum_1, getitem_19, getitem_21, div_2, getitem_22, index_1, cumsum_2, _grouped_mm, _grouped_mm_1, mul_35, mm_12, mm_13, mul_55, add_73, rsqrt_6, view_103, getitem_25, rsqrt_7, view_117, permute_29, permute_30, permute_31, getitem_29, getitem_30, add_76, rsqrt_8, view_125, mm_19, amax_1, sum_5, getitem_33, getitem_35, div_7, getitem_36, index_3, cumsum_5, _grouped_mm_3, _grouped_mm_4, mul_84, mm_20, mm_21, mul_104, add_141, rsqrt_9, view_170, getitem_39, rsqrt_10, view_184, permute_44, permute_45, permute_46, getitem_43, getitem_44, add_144, rsqrt_11, view_192, mm_27, amax_2, sum_9, getitem_47, getitem_49, div_12, getitem_50, index_5, cumsum_8, _grouped_mm_6, _grouped_mm_7, mul_133, mm_28, mm_29, mul_153, add_209, rsqrt_12, view_237, getitem_53, rsqrt_13, view_251, permute_59, permute_60, permute_61, getitem_57, getitem_58, add_212, rsqrt_14, view_259, mm_35, amax_3, sum_13, getitem_61, getitem_63, div_17, getitem_64, index_7, cumsum_11, _grouped_mm_9, _grouped_mm_10, mul_182, mm_36, mm_37, mul_202, add_277, rsqrt_15, view_304, getitem_67, rsqrt_16, view_318, permute_74, permute_75, permute_76, getitem_71, getitem_72, add_280, rsqrt_17, view_326, mm_43, amax_4, sum_17, getitem_75, getitem_77, div_22, getitem_78, index_9, cumsum_14, _grouped_mm_12, _grouped_mm_13, mul_231, mm_44, mm_45, mul_251, add_345, rsqrt_18, view_371, getitem_81, rsqrt_19, view_385, permute_89, permute_90, permute_91, getitem_85, getitem_86, add_348, rsqrt_20, view_393, mm_51, amax_5, sum_21, getitem_89, getitem_91, div_27, getitem_92, index_11, cumsum_17, _grouped_mm_15, _grouped_mm_16, mul_280, mm_52, mm_53, mul_300, add_413, rsqrt_21, view_438, getitem_95, rsqrt_22, view_452, permute_104, permute_105, permute_106, getitem_99, getitem_100, add_416, rsqrt_23, view_460, mm_59, amax_6, sum_25, getitem_103, getitem_105, div_32, getitem_106, index_13, cumsum_20, _grouped_mm_18, _grouped_mm_19, mul_329, mm_60, mm_61, mul_349, add_481, rsqrt_24, view_505, getitem_109, rsqrt_25, view_519, permute_119, permute_120, permute_121, getitem_113, getitem_114, add_484, rsqrt_26, view_527, mm_67, amax_7, sum_29, getitem_117, getitem_119, div_37, getitem_120, index_15, cumsum_23, _grouped_mm_21, _grouped_mm_22, mul_378, mm_68, mm_69, mul_398, add_549, rsqrt_27, view_572, getitem_123, rsqrt_28, view_586, permute_134, permute_135, permute_136, getitem_127, getitem_128, add_552, rsqrt_29, view_594, mm_75, amax_8, sum_33, getitem_131, getitem_133, div_42, getitem_134, index_17, cumsum_26, _grouped_mm_24, _grouped_mm_25, mul_427, mm_76, mm_77, mul_447, add_617, rsqrt_30, view_639, getitem_137, rsqrt_31, view_653, permute_149, permute_150, permute_151, getitem_141, getitem_142, add_620, rsqrt_32, view_661, mm_83, amax_9, sum_37, getitem_145, getitem_147, div_47, getitem_148, index_19, cumsum_29, _grouped_mm_27, _grouped_mm_28, mul_476, mm_84, mm_85, mul_496, add_685, rsqrt_33, view_706, getitem_151, rsqrt_34, view_720, permute_164, permute_165, permute_166, getitem_155, getitem_156, add_688, rsqrt_35, view_728, mm_91, amax_10, sum_41, getitem_159, getitem_161, div_52, getitem_162, index_21, cumsum_32, _grouped_mm_30, _grouped_mm_31, mul_525, mm_92, mm_93, mul_545, add_753, rsqrt_36, view_773, getitem_165, rsqrt_37, view_787, permute_179, permute_180, permute_181, getitem_169, getitem_170, add_756, rsqrt_38, view_795, mm_99, amax_11, sum_45, getitem_173, getitem_175, div_57, getitem_176, index_23, cumsum_35, _grouped_mm_33, _grouped_mm_34, mul_574, mm_100, mm_101, mul_594, add_821, rsqrt_39, view_840, getitem_179, rsqrt_40, view_854, permute_194, permute_195, permute_196, getitem_183, getitem_184, add_824, rsqrt_41, view_862, mm_107, amax_12, sum_49, getitem_187, getitem_189, div_62, getitem_190, index_25, cumsum_38, _grouped_mm_36, _grouped_mm_37, mul_623, mm_108, mm_109, mul_643, add_889, rsqrt_42, view_907, getitem_193, rsqrt_43, view_921, permute_209, permute_210, permute_211, getitem_197, getitem_198, add_892, rsqrt_44, view_929, mm_115, amax_13, sum_53, getitem_201, getitem_203, div_67, getitem_204, index_27, cumsum_41, _grouped_mm_39, _grouped_mm_40, mul_672, mm_116, mm_117, mul_692, add_957, rsqrt_45, view_974, getitem_207, rsqrt_46, view_988, permute_224, permute_225, permute_226, getitem_211, getitem_212, add_960, rsqrt_47, view_996, mm_123, amax_14, sum_57, getitem_215, getitem_217, div_72, getitem_218, index_29, cumsum_44, _grouped_mm_42, _grouped_mm_43, mul_721, mm_124, mm_125, mul_741, add_1025, rsqrt_48, view_1041, getitem_221, rsqrt_49, view_1055, permute_239, permute_240, permute_241, getitem_225, getitem_226, add_1028, rsqrt_50, view_1063, mm_131, amax_15, sum_61, getitem_229, getitem_231, div_77, getitem_232, index_31, cumsum_47, _grouped_mm_45, _grouped_mm_46, mul_770, mm_132, mm_133, mul_790, add_1093, rsqrt_51, view_1108, getitem_235, rsqrt_52, view_1122, permute_254, permute_255, permute_256, getitem_239, getitem_240, add_1096, rsqrt_53, view_1130, mm_139, amax_16, sum_65, getitem_243, getitem_245, div_82, getitem_246, index_33, cumsum_50, _grouped_mm_48, _grouped_mm_49, mul_819, mm_140, mm_141, mul_839, add_1161, rsqrt_54, view_1175, getitem_249, rsqrt_55, view_1189, permute_269, permute_270, permute_271, getitem_253, getitem_254, add_1164, rsqrt_56, view_1197, mm_147, amax_17, sum_69, getitem_257, getitem_259, div_87, getitem_260, index_35, cumsum_53, _grouped_mm_51, _grouped_mm_52, mul_868, mm_148, mm_149, mul_888, add_1229, rsqrt_57, view_1242, getitem_263, rsqrt_58, view_1256, permute_284, permute_285, permute_286, getitem_267, getitem_268, add_1232, rsqrt_59, view_1264, mm_155, amax_18, sum_73, getitem_271, getitem_273, div_92, getitem_274, index_37, cumsum_56, _grouped_mm_54, _grouped_mm_55, mul_917, mm_156, mm_157, mul_937, add_1297, rsqrt_60, view_1309, getitem_277, rsqrt_61, view_1323, permute_299, permute_300, permute_301, getitem_281, getitem_282, add_1300, rsqrt_62, view_1331, mm_163, amax_19, sum_77, getitem_285, getitem_287, div_97, getitem_288, index_39, cumsum_59, _grouped_mm_57, _grouped_mm_58, mul_966, mm_164, mm_165, mul_986, add_1365, rsqrt_63, view_1376, getitem_291, rsqrt_64, view_1390, permute_314, permute_315, permute_316, getitem_295, getitem_296, add_1368, rsqrt_65, view_1398, mm_171, amax_20, sum_81, getitem_299, getitem_301, div_102, getitem_302, index_41, cumsum_62, _grouped_mm_60, _grouped_mm_61, mul_1015, mm_172, mm_173, mul_1035, add_1433, rsqrt_66, view_1443, getitem_305, rsqrt_67, view_1457, permute_329, permute_330, permute_331, getitem_309, getitem_310, add_1436, rsqrt_68, view_1465, mm_179, amax_21, sum_85, getitem_313, getitem_315, div_107, getitem_316, index_43, cumsum_65, _grouped_mm_63, _grouped_mm_64, mul_1064, mm_180, mm_181, mul_1084, add_1501, rsqrt_69, view_1510, getitem_319, rsqrt_70, view_1524, permute_344, permute_345, permute_346, getitem_323, getitem_324, add_1504, rsqrt_71, view_1532, mm_187, amax_22, sum_89, getitem_327, getitem_329, div_112, getitem_330, index_45, cumsum_68, _grouped_mm_66, _grouped_mm_67, mul_1113, mm_188, mm_189, mul_1133, add_1569, rsqrt_72, view_1577, getitem_333, rsqrt_73, view_1591, permute_359, permute_360, permute_361, getitem_337, getitem_338, add_1572, rsqrt_74, view_1599, mm_195, amax_23, sum_93, getitem_341, getitem_343, div_117, getitem_344, index_47, cumsum_71, _grouped_mm_69, _grouped_mm_70, mul_1162, mm_196, mm_197, mul_1182, add_1637, rsqrt_75, view_1644, getitem_347, rsqrt_76, view_1658, permute_374, permute_375, permute_376, getitem_351, getitem_352, add_1640, rsqrt_77, view_1666, mm_203, amax_24, sum_97, getitem_355, getitem_357, div_122, getitem_358, index_49, cumsum_74, _grouped_mm_72, _grouped_mm_73, mul_1211, mm_204, mm_205, mul_1231, add_1705, rsqrt_78, view_1711, getitem_361, rsqrt_79, view_1725, permute_389, permute_390, permute_391, getitem_365, getitem_366, add_1708, rsqrt_80, view_1733, mm_211, amax_25, sum_101, getitem_369, getitem_371, div_127, getitem_372, index_51, cumsum_77, _grouped_mm_75, _grouped_mm_76, mul_1260, mm_212, mm_213, mul_1280, add_1773, rsqrt_81, view_1778, permute_406, permute_407, permute_456, permute_457, permute_506, permute_507, permute_556, permute_557, permute_606, permute_607, permute_656, permute_657, permute_706, permute_707, permute_756, permute_757, permute_806, permute_807, permute_856, permute_857, permute_906, permute_907, permute_956, permute_957, permute_1006, permute_1007, permute_1056, permute_1057, permute_1106, permute_1107, permute_1156, permute_1157, permute_1206, permute_1207, permute_1256, permute_1257, permute_1306, permute_1307, permute_1356, permute_1357, permute_1406, permute_1407, permute_1456, permute_1457, permute_1506, permute_1507, permute_1556, permute_1557, permute_1606, permute_1607, permute_1656, permute_1657, _local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2, _local_scalar_dense_3, _local_scalar_dense_4, _local_scalar_dense_5, _local_scalar_dense_6, _local_scalar_dense_7, _local_scalar_dense_8, _local_scalar_dense_9, _local_scalar_dense_10, _local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13, _local_scalar_dense_14, _local_scalar_dense_15, _local_scalar_dense_16, _local_scalar_dense_17, _local_scalar_dense_18, _local_scalar_dense_19, _local_scalar_dense_20, _local_scalar_dense_21, _local_scalar_dense_22, _local_scalar_dense_23, _local_scalar_dense_24, _local_scalar_dense_25, _local_scalar_dense_26, _local_scalar_dense_27, _local_scalar_dense_28, _local_scalar_dense_29, _local_scalar_dense_30, _local_scalar_dense_31, _local_scalar_dense_32, _local_scalar_dense_33, _local_scalar_dense_34, _local_scalar_dense_35, _local_scalar_dense_36, _local_scalar_dense_37, _local_scalar_dense_38, _local_scalar_dense_39, _local_scalar_dense_40, _local_scalar_dense_41, _local_scalar_dense_42, _local_scalar_dense_43, _local_scalar_dense_44, _local_scalar_dense_45, _local_scalar_dense_46, _local_scalar_dense_47, _local_scalar_dense_48, _local_scalar_dense_49, _local_scalar_dense_50, _local_scalar_dense_51, _local_scalar_dense_52, _local_scalar_dense_53, _local_scalar_dense_54, _local_scalar_dense_55, _local_scalar_dense_56, _local_scalar_dense_57, _local_scalar_dense_58, _local_scalar_dense_59, _local_scalar_dense_60, _local_scalar_dense_61, _local_scalar_dense_62, _local_scalar_dense_63, _local_scalar_dense_64, _local_scalar_dense_65, _local_scalar_dense_66, _local_scalar_dense_67, _local_scalar_dense_68, _local_scalar_dense_69, _local_scalar_dense_70, _local_scalar_dense_71, _local_scalar_dense_72, _local_scalar_dense_73, _local_scalar_dense_74, _local_scalar_dense_75, _local_scalar_dense_76, _local_scalar_dense_77, _local_scalar_dense_78, _local_scalar_dense_79, _local_scalar_dense_80, _local_scalar_dense_81, _local_scalar_dense_82, _local_scalar_dense_83, _local_scalar_dense_84, _local_scalar_dense_85, _local_scalar_dense_86, _local_scalar_dense_87, _local_scalar_dense_88, _local_scalar_dense_89, _local_scalar_dense_90, _local_scalar_dense_91, _local_scalar_dense_92, _local_scalar_dense_93, _local_scalar_dense_94, _local_scalar_dense_95, _local_scalar_dense_96, _local_scalar_dense_97, _local_scalar_dense_98, _local_scalar_dense_99, _local_scalar_dense_100, _local_scalar_dense_101, _local_scalar_dense_102, _local_scalar_dense_103, _local_scalar_dense_104, _local_scalar_dense_105, _local_scalar_dense_106, _local_scalar_dense_107, _local_scalar_dense_108, _local_scalar_dense_109, _local_scalar_dense_110, _local_scalar_dense_111, _local_scalar_dense_112, _local_scalar_dense_113, _local_scalar_dense_114, _local_scalar_dense_115, _local_scalar_dense_116, _local_scalar_dense_117, _local_scalar_dense_118, _local_scalar_dense_119, _local_scalar_dense_120, _local_scalar_dense_121, _local_scalar_dense_122, _local_scalar_dense_123, _local_scalar_dense_124, _local_scalar_dense_125, _local_scalar_dense_126, _local_scalar_dense_127, _local_scalar_dense_128, _local_scalar_dense_129, _local_scalar_dense_130, _local_scalar_dense_131, _local_scalar_dense_132, _local_scalar_dense_133, _local_scalar_dense_134, _local_scalar_dense_135, _local_scalar_dense_136, _local_scalar_dense_137, _local_scalar_dense_138, _local_scalar_dense_139, _local_scalar_dense_140, _local_scalar_dense_141, _local_scalar_dense_142, _local_scalar_dense_143, _local_scalar_dense_144, _local_scalar_dense_145, _local_scalar_dense_146, _local_scalar_dense_147, _local_scalar_dense_148, _local_scalar_dense_149, _local_scalar_dense_150, _local_scalar_dense_151, _local_scalar_dense_152, _local_scalar_dense_153, _local_scalar_dense_154, _local_scalar_dense_155, _local_scalar_dense_156, _local_scalar_dense_157, _local_scalar_dense_158, _local_scalar_dense_159, _local_scalar_dense_160, _local_scalar_dense_161, _local_scalar_dense_162, _local_scalar_dense_163, _local_scalar_dense_164, _local_scalar_dense_165, _local_scalar_dense_166, _local_scalar_dense_167, _local_scalar_dense_168, _local_scalar_dense_169, _local_scalar_dense_170, _local_scalar_dense_171, _local_scalar_dense_172, _local_scalar_dense_173, _local_scalar_dense_174, _local_scalar_dense_175, _local_scalar_dense_176, _local_scalar_dense_177, _local_scalar_dense_178, _local_scalar_dense_179, _local_scalar_dense_180, _local_scalar_dense_181, _local_scalar_dense_182, _local_scalar_dense_183, _local_scalar_dense_184, _local_scalar_dense_185, _local_scalar_dense_186, _local_scalar_dense_187, _local_scalar_dense_188, _local_scalar_dense_189, _local_scalar_dense_190, _local_scalar_dense_191, _local_scalar_dense_192, _local_scalar_dense_193, _local_scalar_dense_194, _local_scalar_dense_195, _local_scalar_dense_196, _local_scalar_dense_197, _local_scalar_dense_198, _local_scalar_dense_199, _local_scalar_dense_200, _local_scalar_dense_201, _local_scalar_dense_202, _local_scalar_dense_203, _local_scalar_dense_204, _local_scalar_dense_205, _local_scalar_dense_206, _local_scalar_dense_207, _local_scalar_dense_208, _local_scalar_dense_209, _local_scalar_dense_210, _local_scalar_dense_211, _local_scalar_dense_212, _local_scalar_dense_213, _local_scalar_dense_214, _local_scalar_dense_215, _local_scalar_dense_216, _local_scalar_dense_217, _local_scalar_dense_218, _local_scalar_dense_219, _local_scalar_dense_220, _local_scalar_dense_221, _local_scalar_dense_222, _local_scalar_dense_223, _local_scalar_dense_224, _local_scalar_dense_225, _local_scalar_dense_226, _local_scalar_dense_227, _local_scalar_dense_228, _local_scalar_dense_229, _local_scalar_dense_230, _local_scalar_dense_231, _local_scalar_dense_232, _local_scalar_dense_233, _local_scalar_dense_234, _local_scalar_dense_235, _local_scalar_dense_236, _local_scalar_dense_237, _local_scalar_dense_238, _local_scalar_dense_239, _local_scalar_dense_240, _local_scalar_dense_241, _local_scalar_dense_242, _local_scalar_dense_243, _local_scalar_dense_244, _local_scalar_dense_245, _local_scalar_dense_246, _local_scalar_dense_247, _local_scalar_dense_248, _local_scalar_dense_249, _local_scalar_dense_250, _local_scalar_dense_251, _local_scalar_dense_252, _local_scalar_dense_253, _local_scalar_dense_254, _local_scalar_dense_255, _local_scalar_dense_256, _local_scalar_dense_257, _local_scalar_dense_258, _local_scalar_dense_259, _local_scalar_dense_260, _local_scalar_dense_261, _local_scalar_dense_262, _local_scalar_dense_263, _local_scalar_dense_264, _local_scalar_dense_265, _local_scalar_dense_266, _local_scalar_dense_267, _local_scalar_dense_268, _local_scalar_dense_269, _local_scalar_dense_270, _local_scalar_dense_271, _local_scalar_dense_272, _local_scalar_dense_273, _local_scalar_dense_274, _local_scalar_dense_275, _local_scalar_dense_276, _local_scalar_dense_277, _local_scalar_dense_278, _local_scalar_dense_279, _local_scalar_dense_280, _local_scalar_dense_281, _local_scalar_dense_282, _local_scalar_dense_283, _local_scalar_dense_284, _local_scalar_dense_285, _local_scalar_dense_286, _local_scalar_dense_287, _local_scalar_dense_288, _local_scalar_dense_289, _local_scalar_dense_290, _local_scalar_dense_291, _local_scalar_dense_292, _local_scalar_dense_293, _local_scalar_dense_294, _local_scalar_dense_295, _local_scalar_dense_296, _local_scalar_dense_297, _local_scalar_dense_298, _local_scalar_dense_299, _local_scalar_dense_300, _local_scalar_dense_301, _local_scalar_dense_302, _local_scalar_dense_303, _local_scalar_dense_304, _local_scalar_dense_305, _local_scalar_dense_306, _local_scalar_dense_307, _local_scalar_dense_308, _local_scalar_dense_309, _local_scalar_dense_310, _local_scalar_dense_311, _local_scalar_dense_312, _local_scalar_dense_313, _local_scalar_dense_314, _local_scalar_dense_315, _local_scalar_dense_316, _local_scalar_dense_317, _local_scalar_dense_318, _local_scalar_dense_319, _local_scalar_dense_320, _local_scalar_dense_321, _local_scalar_dense_322, _local_scalar_dense_323, _local_scalar_dense_324, _local_scalar_dense_325, _local_scalar_dense_326, _local_scalar_dense_327, _local_scalar_dense_328, _local_scalar_dense_329, _local_scalar_dense_330, _local_scalar_dense_331, _local_scalar_dense_332, _local_scalar_dense_333, _local_scalar_dense_334, _local_scalar_dense_335, _local_scalar_dense_336, _local_scalar_dense_337, _local_scalar_dense_338, _local_scalar_dense_339, _local_scalar_dense_340, _local_scalar_dense_341, _local_scalar_dense_342, _local_scalar_dense_343, _local_scalar_dense_344, _local_scalar_dense_345, _local_scalar_dense_346, _local_scalar_dense_347, _local_scalar_dense_348, _local_scalar_dense_349, _local_scalar_dense_350, _local_scalar_dense_351, _local_scalar_dense_352, _local_scalar_dense_353, _local_scalar_dense_354, _local_scalar_dense_355, _local_scalar_dense_356, _local_scalar_dense_357, _local_scalar_dense_358, _local_scalar_dense_359, _local_scalar_dense_360, _local_scalar_dense_361, _local_scalar_dense_362, _local_scalar_dense_363, _local_scalar_dense_364, _local_scalar_dense_365, _local_scalar_dense_366, _local_scalar_dense_367, _local_scalar_dense_368, _local_scalar_dense_369, _local_scalar_dense_370, _local_scalar_dense_371, _local_scalar_dense_372, _local_scalar_dense_373, _local_scalar_dense_374, _local_scalar_dense_375, _local_scalar_dense_376, _local_scalar_dense_377, _local_scalar_dense_378, _local_scalar_dense_379, _local_scalar_dense_380, _local_scalar_dense_381, _local_scalar_dense_382, _local_scalar_dense_383, _local_scalar_dense_384, _local_scalar_dense_385, _local_scalar_dense_386, _local_scalar_dense_387, _local_scalar_dense_388, _local_scalar_dense_389, _local_scalar_dense_390, _local_scalar_dense_391, _local_scalar_dense_392, _local_scalar_dense_393, _local_scalar_dense_394, _local_scalar_dense_395, _local_scalar_dense_396, _local_scalar_dense_397, _local_scalar_dense_398, _local_scalar_dense_399, _local_scalar_dense_400, _local_scalar_dense_401, _local_scalar_dense_402, _local_scalar_dense_403, _local_scalar_dense_404, _local_scalar_dense_405, _local_scalar_dense_406, _local_scalar_dense_407, _local_scalar_dense_408, _local_scalar_dense_409, _local_scalar_dense_410, _local_scalar_dense_411, _local_scalar_dense_412, _local_scalar_dense_413, _local_scalar_dense_414, _local_scalar_dense_415, sym_size_int_1, sym_size_int_5, sym_size_int_9, sym_size_int_13, sym_size_int_17, sym_size_int_21, sym_size_int_25, sym_size_int_29, sym_size_int_33, sym_size_int_37, sym_size_int_41, sym_size_int_45, sym_size_int_49, sym_size_int_53, sym_size_int_57, sym_size_int_61, sym_size_int_65, sym_size_int_69, sym_size_int_73, sym_size_int_77, sym_size_int_81, sym_size_int_85, sym_size_int_89, sym_size_int_93, sym_size_int_97, sym_size_int_101, add_1781, add_1796, add_1811, add_1826, add_1841, add_1856, add_1871, add_1886, add_1901, add_1916, add_1931, add_1946, add_1961, add_1976, add_1991, add_2006, add_2021, add_2036, add_2051, add_2066, add_2081, add_2096, add_2111, add_2126, add_2141, add_2156) + +def load_args(reader): + buf0 = reader.storage(None, 13107200, device=device(type='cuda', index=0)) + reader.tensor(buf0, (1600, 2048), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 65536, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 4096), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 1048576, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (4096, 32), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf3, (32,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf4, (48, 2048), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf5, (9, 2048), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf6, (8,), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf7, (64, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf8, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_9 + buf9 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf9, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_10 + buf10 = reader.storage(None, 32768, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf10, (2, 4096), dtype=torch.int32, is_leaf=True) # primals_11 + buf11 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf11, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_12 + buf12 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf12, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_13 + buf13 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf13, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_14 + buf14 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf14, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_15 + buf15 = reader.storage(None, 256, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf15, (2, 1, 32), dtype=torch.int32, is_leaf=True) # primals_16 + buf16 = reader.storage(None, 8192, device=device(type='cuda', index=0), dtype_hint=torch.int32) + reader.tensor(buf16, (2, 1, 32, 32), dtype=torch.int32, is_leaf=True) # primals_17 + buf17 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf17, (32, 2048), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf18, (32,), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 1400832, device=device(type='cuda', index=0)) + reader.tensor(buf19, (171, 2048), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 1400832, device=device(type='cuda', index=0)) + reader.tensor(buf20, (171, 2048), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 1400832, device=device(type='cuda', index=0)) + reader.tensor(buf21, (32, 10944), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf22, (32,), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf23, (48, 2048), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf24, (9, 2048), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf25, (8,), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf26, (64, 512), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf27, (32, 2048), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf28, (32,), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf29, (64,), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf30, (1, 2048), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf31, (64,), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf32, (1, 1408, 2048), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf33, (1, 2048, 1408), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf34, (1, 1408, 2048), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf35, (44, 2048), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf36, (44, 2048), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf37, (32, 2816), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf38, (32,), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf39, (48, 2048), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf40, (9, 2048), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf41, (8,), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf42, (64, 512), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf43, (32, 2048), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf44, (32,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf45, (64,), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf46, (1, 2048), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf47, (64,), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf48, (1, 1408, 2048), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf49, (1, 2048, 1408), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf50, (1, 1408, 2048), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf51, (44, 2048), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf52, (44, 2048), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf53, (32, 2816), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf54, (32,), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf55, (48, 2048), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf56, (9, 2048), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf57, (8,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf58, (64, 512), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf59, (32, 2048), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf60, (32,), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf61, (64,), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf62, (1, 2048), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf63, (64,), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf64, (1, 1408, 2048), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf65, (1, 2048, 1408), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf66, (1, 1408, 2048), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf67, (44, 2048), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf68, (44, 2048), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf69, (32, 2816), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf70, (32,), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf71, (48, 2048), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf72, (9, 2048), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf73, (8,), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf74, (64, 512), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf75, (32, 2048), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf76, (32,), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf77, (64,), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf78, (1, 2048), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf79, (64,), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf80, (1, 1408, 2048), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf81, (1, 2048, 1408), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf82, (1, 1408, 2048), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf83, (44, 2048), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf84, (44, 2048), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf85, (32, 2816), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf86, (32,), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf87, (48, 2048), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf88, (9, 2048), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf89, (8,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf90, (64, 512), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf91, (32, 2048), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf92, (32,), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf93, (64,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf94, (1, 2048), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf95, (64,), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf96, (1, 1408, 2048), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf97, (1, 2048, 1408), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf98, (1, 1408, 2048), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf99, (44, 2048), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf100, (44, 2048), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf101, (32, 2816), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf102, (32,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf103, (48, 2048), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf104, (9, 2048), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf105, (8,), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf106, (64, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf107, (32, 2048), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf108, (32,), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf109, (64,), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf110, (1, 2048), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf111, (64,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf112, (1, 1408, 2048), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf113, (1, 2048, 1408), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf114, (1, 1408, 2048), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf115, (44, 2048), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf116, (44, 2048), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf117, (32, 2816), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf118, (32,), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf119, (48, 2048), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf120, (9, 2048), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf121, (8,), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf122, (64, 512), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf123, (32, 2048), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf124, (32,), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf125, (64,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf126, (1, 2048), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf127, (64,), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf128, (1, 1408, 2048), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf129, (1, 2048, 1408), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf130, (1, 1408, 2048), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf131, (44, 2048), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf132, (44, 2048), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf133, (32, 2816), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf134, (32,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf135, (48, 2048), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf136, (9, 2048), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf137, (8,), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf138, (64, 512), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (32, 2048), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf140, (32,), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf141, (64,), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf142, (1, 2048), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf143, (64,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf144, (1, 1408, 2048), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf145, (1, 2048, 1408), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf146, (1, 1408, 2048), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf147, (44, 2048), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf148, (44, 2048), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf149, (32, 2816), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf150, (32,), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf151, (48, 2048), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf152, (9, 2048), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf153, (8,), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf154, (64, 512), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf155, (32, 2048), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf156, (32,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf157, (64,), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf158, (1, 2048), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf159, (64,), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf160, (1, 1408, 2048), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf161, (1, 2048, 1408), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf162, (1, 1408, 2048), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf163, (44, 2048), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf164, (44, 2048), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf165, (32, 2816), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf166, (32,), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf167, (48, 2048), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf168, (9, 2048), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf169, (8,), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf170, (64, 512), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf171, (32, 2048), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf172, (32,), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf173, (64,), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf174, (1, 2048), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf175, (64,), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf176, (1, 1408, 2048), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf177, (1, 2048, 1408), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf178, (1, 1408, 2048), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf179, (44, 2048), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf180, (44, 2048), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf181, (32, 2816), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf182, (32,), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf183, (48, 2048), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf184, (9, 2048), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf185, (8,), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf186, (64, 512), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf187, (32, 2048), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf188, (32,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf189, (64,), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf190, (1, 2048), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf191, (64,), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf192, (1, 1408, 2048), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf193, (1, 2048, 1408), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf194, (1, 1408, 2048), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf195, (44, 2048), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf196, (44, 2048), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf197, (32, 2816), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf198, (32,), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf199, (48, 2048), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf200, (9, 2048), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf201, (8,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf202, (64, 512), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf203, (32, 2048), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf204, (32,), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf205, (64,), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf206, (1, 2048), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf207, (64,), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf208, (1, 1408, 2048), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf209, (1, 2048, 1408), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf210, (1, 1408, 2048), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf211, (44, 2048), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf212, (44, 2048), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf213, (32, 2816), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf214, (32,), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf215, (48, 2048), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf216, (9, 2048), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf217, (8,), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf218, (64, 512), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf219, (32, 2048), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf220, (32,), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf221, (64,), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf222, (1, 2048), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf223, (64,), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf224, (1, 1408, 2048), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf225, (1, 2048, 1408), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf226, (1, 1408, 2048), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf227, (44, 2048), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf228, (44, 2048), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf229, (32, 2816), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf230, (32,), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf231, (48, 2048), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf232, (9, 2048), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf233, (8,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf234, (64, 512), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf235, (32, 2048), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf236, (32,), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf237, (64,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf238, (1, 2048), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf239, (64,), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf240, (1, 1408, 2048), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf241, (1, 2048, 1408), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf242, (1, 1408, 2048), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf243, (44, 2048), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf244, (44, 2048), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf245, (32, 2816), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf246, (32,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf247, (48, 2048), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf248, (9, 2048), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf249, (8,), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf250, (64, 512), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf251, (32, 2048), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf252, (32,), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf253, (64,), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf254, (1, 2048), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf255, (64,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf256, (1, 1408, 2048), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf257, (1, 2048, 1408), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf258, (1, 1408, 2048), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf259, (44, 2048), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf260, (44, 2048), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf261, (32, 2816), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf262, (32,), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf263, (48, 2048), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf264, (9, 2048), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf265, (8,), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf266, (64, 512), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf267, (32, 2048), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf268, (32,), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf269, (64,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf270, (1, 2048), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf271, (64,), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf272, (1, 1408, 2048), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf273, (1, 2048, 1408), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf274, (1, 1408, 2048), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf275, (44, 2048), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf276, (44, 2048), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf277, (32, 2816), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf278, (32,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf279, (48, 2048), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf280, (9, 2048), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf281, (8,), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf282, (64, 512), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf283, (32, 2048), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf284, (32,), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf285, (64,), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf286, (1, 2048), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf287, (64,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf288, (1, 1408, 2048), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf289, (1, 2048, 1408), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf290, (1, 1408, 2048), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf291, (44, 2048), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf292, (44, 2048), is_leaf=True) # primals_293 + buf293 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf293, (32, 2816), is_leaf=True) # primals_294 + buf294 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf294, (32,), is_leaf=True) # primals_295 + buf295 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf295, (48, 2048), is_leaf=True) # primals_296 + buf296 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf296, (9, 2048), is_leaf=True) # primals_297 + buf297 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf297, (8,), is_leaf=True) # primals_298 + buf298 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf298, (64, 512), is_leaf=True) # primals_299 + buf299 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf299, (32, 2048), is_leaf=True) # primals_300 + buf300 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf300, (32,), is_leaf=True) # primals_301 + buf301 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf301, (64,), is_leaf=True) # primals_302 + buf302 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf302, (1, 2048), is_leaf=True) # primals_303 + buf303 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf303, (64,), is_leaf=True) # primals_304 + buf304 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf304, (1, 1408, 2048), is_leaf=True) # primals_305 + buf305 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf305, (1, 2048, 1408), is_leaf=True) # primals_306 + buf306 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf306, (1, 1408, 2048), is_leaf=True) # primals_307 + buf307 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf307, (44, 2048), is_leaf=True) # primals_308 + buf308 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf308, (44, 2048), is_leaf=True) # primals_309 + buf309 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf309, (32, 2816), is_leaf=True) # primals_310 + buf310 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf310, (32,), is_leaf=True) # primals_311 + buf311 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf311, (48, 2048), is_leaf=True) # primals_312 + buf312 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf312, (9, 2048), is_leaf=True) # primals_313 + buf313 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf313, (8,), is_leaf=True) # primals_314 + buf314 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf314, (64, 512), is_leaf=True) # primals_315 + buf315 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf315, (32, 2048), is_leaf=True) # primals_316 + buf316 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf316, (32,), is_leaf=True) # primals_317 + buf317 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf317, (64,), is_leaf=True) # primals_318 + buf318 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf318, (1, 2048), is_leaf=True) # primals_319 + buf319 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf319, (64,), is_leaf=True) # primals_320 + buf320 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf320, (1, 1408, 2048), is_leaf=True) # primals_321 + buf321 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf321, (1, 2048, 1408), is_leaf=True) # primals_322 + buf322 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf322, (1, 1408, 2048), is_leaf=True) # primals_323 + buf323 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf323, (44, 2048), is_leaf=True) # primals_324 + buf324 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf324, (44, 2048), is_leaf=True) # primals_325 + buf325 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf325, (32, 2816), is_leaf=True) # primals_326 + buf326 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf326, (32,), is_leaf=True) # primals_327 + buf327 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf327, (48, 2048), is_leaf=True) # primals_328 + buf328 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf328, (9, 2048), is_leaf=True) # primals_329 + buf329 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf329, (8,), is_leaf=True) # primals_330 + buf330 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf330, (64, 512), is_leaf=True) # primals_331 + buf331 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf331, (32, 2048), is_leaf=True) # primals_332 + buf332 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf332, (32,), is_leaf=True) # primals_333 + buf333 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf333, (64,), is_leaf=True) # primals_334 + buf334 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf334, (1, 2048), is_leaf=True) # primals_335 + buf335 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf335, (64,), is_leaf=True) # primals_336 + buf336 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf336, (1, 1408, 2048), is_leaf=True) # primals_337 + buf337 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf337, (1, 2048, 1408), is_leaf=True) # primals_338 + buf338 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf338, (1, 1408, 2048), is_leaf=True) # primals_339 + buf339 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf339, (44, 2048), is_leaf=True) # primals_340 + buf340 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf340, (44, 2048), is_leaf=True) # primals_341 + buf341 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf341, (32, 2816), is_leaf=True) # primals_342 + buf342 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf342, (32,), is_leaf=True) # primals_343 + buf343 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf343, (48, 2048), is_leaf=True) # primals_344 + buf344 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf344, (9, 2048), is_leaf=True) # primals_345 + buf345 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf345, (8,), is_leaf=True) # primals_346 + buf346 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf346, (64, 512), is_leaf=True) # primals_347 + buf347 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf347, (32, 2048), is_leaf=True) # primals_348 + buf348 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf348, (32,), is_leaf=True) # primals_349 + buf349 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf349, (64,), is_leaf=True) # primals_350 + buf350 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf350, (1, 2048), is_leaf=True) # primals_351 + buf351 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf351, (64,), is_leaf=True) # primals_352 + buf352 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf352, (1, 1408, 2048), is_leaf=True) # primals_353 + buf353 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf353, (1, 2048, 1408), is_leaf=True) # primals_354 + buf354 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf354, (1, 1408, 2048), is_leaf=True) # primals_355 + buf355 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf355, (44, 2048), is_leaf=True) # primals_356 + buf356 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf356, (44, 2048), is_leaf=True) # primals_357 + buf357 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf357, (32, 2816), is_leaf=True) # primals_358 + buf358 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf358, (32,), is_leaf=True) # primals_359 + buf359 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf359, (48, 2048), is_leaf=True) # primals_360 + buf360 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf360, (9, 2048), is_leaf=True) # primals_361 + buf361 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf361, (8,), is_leaf=True) # primals_362 + buf362 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf362, (64, 512), is_leaf=True) # primals_363 + buf363 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf363, (32, 2048), is_leaf=True) # primals_364 + buf364 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf364, (32,), is_leaf=True) # primals_365 + buf365 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf365, (64,), is_leaf=True) # primals_366 + buf366 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf366, (1, 2048), is_leaf=True) # primals_367 + buf367 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf367, (64,), is_leaf=True) # primals_368 + buf368 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf368, (1, 1408, 2048), is_leaf=True) # primals_369 + buf369 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf369, (1, 2048, 1408), is_leaf=True) # primals_370 + buf370 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf370, (1, 1408, 2048), is_leaf=True) # primals_371 + buf371 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf371, (44, 2048), is_leaf=True) # primals_372 + buf372 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf372, (44, 2048), is_leaf=True) # primals_373 + buf373 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf373, (32, 2816), is_leaf=True) # primals_374 + buf374 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf374, (32,), is_leaf=True) # primals_375 + buf375 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf375, (48, 2048), is_leaf=True) # primals_376 + buf376 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf376, (9, 2048), is_leaf=True) # primals_377 + buf377 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf377, (8,), is_leaf=True) # primals_378 + buf378 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf378, (64, 512), is_leaf=True) # primals_379 + buf379 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf379, (32, 2048), is_leaf=True) # primals_380 + buf380 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf380, (32,), is_leaf=True) # primals_381 + buf381 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf381, (64,), is_leaf=True) # primals_382 + buf382 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf382, (1, 2048), is_leaf=True) # primals_383 + buf383 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf383, (64,), is_leaf=True) # primals_384 + buf384 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf384, (1, 1408, 2048), is_leaf=True) # primals_385 + buf385 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf385, (1, 2048, 1408), is_leaf=True) # primals_386 + buf386 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf386, (1, 1408, 2048), is_leaf=True) # primals_387 + buf387 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf387, (44, 2048), is_leaf=True) # primals_388 + buf388 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf388, (44, 2048), is_leaf=True) # primals_389 + buf389 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf389, (32, 2816), is_leaf=True) # primals_390 + buf390 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf390, (32,), is_leaf=True) # primals_391 + buf391 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf391, (48, 2048), is_leaf=True) # primals_392 + buf392 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf392, (9, 2048), is_leaf=True) # primals_393 + buf393 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf393, (8,), is_leaf=True) # primals_394 + buf394 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf394, (64, 512), is_leaf=True) # primals_395 + buf395 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf395, (32, 2048), is_leaf=True) # primals_396 + buf396 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf396, (32,), is_leaf=True) # primals_397 + buf397 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf397, (64,), is_leaf=True) # primals_398 + buf398 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf398, (1, 2048), is_leaf=True) # primals_399 + buf399 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf399, (64,), is_leaf=True) # primals_400 + buf400 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf400, (1, 1408, 2048), is_leaf=True) # primals_401 + buf401 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf401, (1, 2048, 1408), is_leaf=True) # primals_402 + buf402 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf402, (1, 1408, 2048), is_leaf=True) # primals_403 + buf403 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf403, (44, 2048), is_leaf=True) # primals_404 + buf404 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf404, (44, 2048), is_leaf=True) # primals_405 + buf405 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf405, (32, 2816), is_leaf=True) # primals_406 + buf406 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf406, (32,), is_leaf=True) # primals_407 + buf407 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf407, (48, 2048), is_leaf=True) # primals_408 + buf408 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf408, (9, 2048), is_leaf=True) # primals_409 + buf409 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf409, (8,), is_leaf=True) # primals_410 + buf410 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf410, (64, 512), is_leaf=True) # primals_411 + buf411 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf411, (32, 2048), is_leaf=True) # primals_412 + buf412 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf412, (32,), is_leaf=True) # primals_413 + buf413 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf413, (64,), is_leaf=True) # primals_414 + buf414 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf414, (1, 2048), is_leaf=True) # primals_415 + buf415 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf415, (64,), is_leaf=True) # primals_416 + buf416 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf416, (1, 1408, 2048), is_leaf=True) # primals_417 + buf417 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf417, (1, 2048, 1408), is_leaf=True) # primals_418 + buf418 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf418, (1, 1408, 2048), is_leaf=True) # primals_419 + buf419 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf419, (44, 2048), is_leaf=True) # primals_420 + buf420 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf420, (44, 2048), is_leaf=True) # primals_421 + buf421 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf421, (32, 2816), is_leaf=True) # primals_422 + buf422 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf422, (32,), is_leaf=True) # primals_423 + buf423 = reader.storage(None, 393216, device=device(type='cuda', index=0)) + reader.tensor(buf423, (48, 2048), is_leaf=True) # primals_424 + buf424 = reader.storage(None, 73728, device=device(type='cuda', index=0)) + reader.tensor(buf424, (9, 2048), is_leaf=True) # primals_425 + buf425 = reader.storage(None, 32, device=device(type='cuda', index=0)) + reader.tensor(buf425, (8,), is_leaf=True) # primals_426 + buf426 = reader.storage(None, 131072, device=device(type='cuda', index=0)) + reader.tensor(buf426, (64, 512), is_leaf=True) # primals_427 + buf427 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf427, (32, 2048), is_leaf=True) # primals_428 + buf428 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf428, (32,), is_leaf=True) # primals_429 + buf429 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf429, (64,), is_leaf=True) # primals_430 + buf430 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf430, (1, 2048), is_leaf=True) # primals_431 + buf431 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf431, (64,), is_leaf=True) # primals_432 + buf432 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf432, (1, 1408, 2048), is_leaf=True) # primals_433 + buf433 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf433, (1, 2048, 1408), is_leaf=True) # primals_434 + buf434 = reader.storage(None, 11534336, device=device(type='cuda', index=0)) + reader.tensor(buf434, (1, 1408, 2048), is_leaf=True) # primals_435 + buf435 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf435, (44, 2048), is_leaf=True) # primals_436 + buf436 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf436, (44, 2048), is_leaf=True) # primals_437 + buf437 = reader.storage(None, 360448, device=device(type='cuda', index=0)) + reader.tensor(buf437, (32, 2816), is_leaf=True) # primals_438 + buf438 = reader.storage(None, 128, device=device(type='cuda', index=0)) + reader.tensor(buf438, (32,), is_leaf=True) # primals_439 + buf439 = reader.storage(None, 13107200, device=device(type='cuda', index=0)) + reader.tensor(buf439, (1600, 2048), is_leaf=True) # primals_440 +load_args._version = 0 +mod = Repro() +if __name__ == '__main__': + from torch._dynamo.repro.after_aot import run_repro + from torch._dynamo.repro.after_aot import setup_fake_process_groups + setup_fake_process_groups({'0': {'size': 64, 'rank': 0}, '521': {'size': 8, 'rank': 0}, '513': {'size': 8, 'rank': 0}}) + with torch.no_grad(): + run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='real', check_str=None) + # To run it separately, do + # mod, args = run_repro(mod, load_args, accuracy=False, command='get_args', save_dir=None, tracing_mode='real', check_str=None) + # mod(*args) + dist.destroy_process_group() + +# Helper functions for overlap simulator +def get_pg_config(): + """DSv3 64 GPUs: FSDP=64, TP=1, EP=8.""" + return {'0': {'size': 64, 'rank': 0}, '513': {'size': 8, 'rank': 0}, '521': {'size': 8, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls8_8.table" + +def get_colls_group_mapping(): + # FSDP "0" → internode (table group "0"), EP "513","521" → intranode (table group "1") + return {'0': '0', '513': '1', '521': '1'} diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_1d.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_1d.py new file mode 100644 index 00000000..a6d810f1 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_1d.py @@ -0,0 +1,8954 @@ +# fmt: off +# flake8: noqa +# isort: skip_file +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, embedding, mm, mm_2, getitem, getitem_1, getitem_6, getitem_7, mm_4, add_3, mm_7, mm_9, getitem_9, getitem_10, getitem_15, getitem_16, mm_11, add_7, mm_14, mm_16, getitem_18, getitem_19, getitem_24, getitem_25, mm_18, add_11, mm_21, mm_23, getitem_27, getitem_28, getitem_33, getitem_34, mm_25, add_15, mm_28, mm_30, getitem_36, getitem_37, getitem_42, getitem_43, mm_32, add_19, mm_35, mm_37, getitem_45, getitem_46, getitem_51, getitem_52, mm_39, add_23, mm_42, mm_44, getitem_54, getitem_55, getitem_60, getitem_61, mm_46, add_27, mm_49, mm_51, getitem_63, getitem_64, getitem_69, getitem_70, mm_53, add_31, mm_56, mm_58, getitem_72, getitem_73, getitem_78, getitem_79, mm_60, add_35, mm_63, mm_65, getitem_81, getitem_82, getitem_87, getitem_88, mm_67, add_39, mm_70, mm_72, getitem_90, getitem_91, getitem_96, getitem_97, mm_74, add_43, mm_77, mm_79, getitem_99, getitem_100, getitem_105, getitem_106, mm_81, add_47, mm_84, mm_86, getitem_108, getitem_109, getitem_114, getitem_115, mm_88, add_51, mm_91, mm_93, getitem_117, getitem_118, getitem_123, getitem_124, mm_95, add_55, mm_98, mm_100, getitem_126, getitem_127, getitem_132, getitem_133, mm_102, add_59, mm_105, mm_107, getitem_135, getitem_136, getitem_141, getitem_142, mm_109, add_63, mm_112, mm_114, getitem_144, getitem_145, getitem_150, getitem_151, mm_116, add_67, mm_119, mm_121, getitem_153, getitem_154, getitem_159, getitem_160, mm_123, add_71, mm_126, mm_128, getitem_162, getitem_163, getitem_168, getitem_169, mm_130, add_75, mm_133, mm_135, getitem_171, getitem_172, getitem_177, getitem_178, mm_137, add_79, mm_140, mm_142, getitem_180, getitem_181, getitem_186, getitem_187, mm_144, add_83, mm_147, mm_149, getitem_189, getitem_190, getitem_195, getitem_196, mm_151, add_87, mm_154, mm_156, getitem_198, getitem_199, getitem_204, getitem_205, mm_158, add_91, mm_161, mm_163, getitem_207, getitem_208, getitem_213, getitem_214, mm_165, add_95, mm_168, mm_170, getitem_216, getitem_217, getitem_222, getitem_223, mm_172, add_99, mm_175, mm_177, getitem_225, getitem_226, getitem_231, getitem_232, mm_179, add_103, mm_182, mm_184, getitem_234, getitem_235, getitem_240, getitem_241, mm_186, add_107, mm_189, mm_191, getitem_243, getitem_244, getitem_249, getitem_250, mm_193, add_111, mm_196, mm_198, getitem_252, getitem_253, getitem_258, getitem_259, mm_200, add_115, mm_203, mm_205, getitem_261, getitem_262, getitem_267, getitem_268, mm_207, add_119, mm_210, mm_212, getitem_270, getitem_271, getitem_276, getitem_277, mm_214, add_123, mm_217, mm_219, getitem_279, getitem_280, getitem_285, getitem_286, mm_221, mm_223, rsqrt_64, view_1091, tangents_1): + view_1093 = torch.ops.aten.view.default(tangents_1, [16384, 128256]); tangents_1 = None + permute_353 = torch.ops.aten.permute.default(view_1093, [1, 0]) + mm_225 = torch.ops.aten.mm.default(permute_353, view_1091); permute_353 = view_1091 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 256, '0'); convert_element_type_1060 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + permute_355 = torch.ops.aten.permute.default(permute_352, [1, 0]); permute_352 = None + mm_226 = torch.ops.aten.mm.default(view_1093, permute_355); view_1093 = permute_355 = None + view_1094 = torch.ops.aten.view.default(mm_226, [2, 8192, 4096]); mm_226 = None + convert_element_type_1067 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1067, 'avg', 256, '0'); convert_element_type_1067 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1068 = torch.ops.prims.convert_element_type.default(view_1094, torch.float32); view_1094 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 256, '0'); convert_element_type_1057 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_1070 = torch.ops.prims.convert_element_type.default(wait_tensor_289, torch.float32); wait_tensor_289 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_1068, convert_element_type_1070); convert_element_type_1070 = None + permute_347 = torch.ops.aten.permute.default(getitem_279, [0, 2, 1, 3]) + view_1075 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 256, '0'); convert_element_type_1040 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1077 = torch.ops.aten.view.default(view_1075, [16384, 4096]); view_1075 = None + mm_220 = torch.ops.aten.mm.default(view_1077, permute_348) + view_1078 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + add_125 = torch.ops.aten.add.Tensor(add_123, view_1078); view_1078 = None + view_1088 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]); mm_223 = None + add_127 = torch.ops.aten.add.Tensor(add_125, view_1088); view_1088 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_260 = torch.ops.aten.mul.Tensor(mul_256, mul_258) + sum_1 = torch.ops.aten.sum.dim_IntList(mul_260, [2], True); mul_260 = None + div = torch.ops.aten.div.Tensor(mul_256, 4096) + mul_261 = torch.ops.aten.mul.Tensor(div, sum_1); div = sum_1 = None + sub = torch.ops.aten.sub.Tensor(mul_258, mul_261); mul_258 = mul_261 = None + mul_262 = torch.ops.aten.mul.Tensor(sub, rsqrt_64); sub = rsqrt_64 = None + mul_263 = torch.ops.aten.mul.Tensor(convert_element_type_1068, mul_256); convert_element_type_1068 = mul_256 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_263, [0, 1]); mul_263 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(mul_262, torch.bfloat16); mul_262 = None + convert_element_type_default_65 = torch.ops.prims.convert_element_type.default(sum_2, torch.float32); sum_2 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_65, 'avg', 256, '0'); convert_element_type_default_65 = None + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + view_1095 = torch.ops.aten.view.default(convert_element_type_1071, [16384, 4096]) + permute_357 = torch.ops.aten.permute.default(view_1095, [1, 0]) + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 256, '0'); convert_element_type_1043 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32); add_125 = None + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_285) + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + view_1081 = torch.ops.aten.view.default(convert_element_type_1045, [16384, 4096]); convert_element_type_1045 = None + view_1082 = torch.ops.aten.view.default(mm_221, [2, 8192, 14336]); mm_221 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_1082, torch.float32); view_1082 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16); primals_290 = None + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 256, '0'); convert_element_type_1051 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_287, [1, 0]); wait_tensor_287 = None + mm_222 = torch.ops.aten.mm.default(view_1081, permute_350) + view_1085 = torch.ops.aten.view.default(mm_222, [2, 8192, 14336]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_1085) + view_1087 = torch.ops.aten.view.default(mul_255, [16384, 14336]); mul_255 = None + mm_227 = torch.ops.aten.mm.default(permute_357, view_1087); permute_357 = view_1087 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16); primals_291 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 256, '0'); convert_element_type_1054 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + permute_359 = torch.ops.aten.permute.default(permute_351, [1, 0]); permute_351 = None + mm_228 = torch.ops.aten.mm.default(view_1095, permute_359); view_1095 = permute_359 = None + view_1096 = torch.ops.aten.view.default(mm_228, [2, 8192, 14336]); mm_228 = None + convert_element_type_1078 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1078, 'avg', 256, '0'); convert_element_type_1078 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None + mul_264 = torch.ops.aten.mul.Tensor(view_1096, convert_element_type_1050); convert_element_type_1050 = None + mul_265 = torch.ops.aten.mul.Tensor(view_1096, view_1085); view_1096 = view_1085 = None + view_1097 = torch.ops.aten.view.default(mul_264, [16384, 14336]); mul_264 = None + permute_361 = torch.ops.aten.permute.default(view_1097, [1, 0]) + mm_229 = torch.ops.aten.mm.default(permute_361, view_1081); permute_361 = None + permute_363 = torch.ops.aten.permute.default(permute_350, [1, 0]); permute_350 = None + mm_230 = torch.ops.aten.mm.default(view_1097, permute_363); view_1097 = permute_363 = None + view_1098 = torch.ops.aten.view.default(mm_230, [2, 8192, 4096]); mm_230 = None + convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1083, 'avg', 256, '0'); convert_element_type_1083 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); reduce_scatter_tensor_3 = None + convert_element_type_1084 = torch.ops.prims.convert_element_type.default(mul_265, torch.float32); mul_265 = None + neg = torch.ops.aten.neg.default(convert_element_type_1049) + exp = torch.ops.aten.exp.default(neg); neg = None + add_129 = torch.ops.aten.add.Tensor(exp, 1); exp = None + reciprocal = torch.ops.aten.reciprocal.default(add_129); add_129 = None + mul_266 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_267 = torch.ops.aten.mul.Tensor(convert_element_type_1084, mul_266); convert_element_type_1084 = None + sub_1 = torch.ops.aten.sub.Tensor(1, mul_266); mul_266 = None + mul_268 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sub_1); convert_element_type_1049 = sub_1 = None + add_130 = torch.ops.aten.add.Tensor(mul_268, 1); mul_268 = None + mul_269 = torch.ops.aten.mul.Tensor(mul_267, add_130); mul_267 = add_130 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(mul_269, torch.bfloat16); mul_269 = None + view_1099 = torch.ops.aten.view.default(convert_element_type_1086, [16384, 14336]); convert_element_type_1086 = None + permute_365 = torch.ops.aten.permute.default(view_1099, [1, 0]) + mm_231 = torch.ops.aten.mm.default(permute_365, view_1081); permute_365 = view_1081 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16); primals_289 = None + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 256, '0'); convert_element_type_1046 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + permute_367 = torch.ops.aten.permute.default(permute_349, [1, 0]); permute_349 = None + mm_232 = torch.ops.aten.mm.default(view_1099, permute_367); view_1099 = permute_367 = None + view_1100 = torch.ops.aten.view.default(mm_232, [2, 8192, 4096]); mm_232 = None + add_131 = torch.ops.aten.add.Tensor(view_1098, view_1100); view_1098 = view_1100 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None + reduce_scatter_tensor_4 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1091, 'avg', 256, '0'); convert_element_type_1091 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + convert_element_type_1092 = torch.ops.prims.convert_element_type.default(add_131, torch.float32); add_131 = None + convert_element_type_1094 = torch.ops.prims.convert_element_type.default(wait_tensor_285, torch.float32); wait_tensor_285 = None + mul_270 = torch.ops.aten.mul.Tensor(convert_element_type_1092, convert_element_type_1094); convert_element_type_1094 = None + mul_272 = torch.ops.aten.mul.Tensor(mul_252, mul_270) + sum_3 = torch.ops.aten.sum.dim_IntList(mul_272, [2], True); mul_272 = None + div_1 = torch.ops.aten.div.Tensor(mul_252, 4096) + mul_273 = torch.ops.aten.mul.Tensor(div_1, sum_3); div_1 = sum_3 = None + sub_2 = torch.ops.aten.sub.Tensor(mul_270, mul_273); mul_270 = mul_273 = None + mul_274 = torch.ops.aten.mul.Tensor(sub_2, rsqrt_63); sub_2 = rsqrt_63 = None + mul_275 = torch.ops.aten.mul.Tensor(convert_element_type_1092, mul_252); convert_element_type_1092 = mul_252 = None + sum_4 = torch.ops.aten.sum.dim_IntList(mul_275, [0, 1]); mul_275 = None + convert_element_type_1095 = torch.ops.prims.convert_element_type.default(mul_274, torch.bfloat16); mul_274 = None + add_132 = torch.ops.aten.add.Tensor(convert_element_type_1071, convert_element_type_1095); convert_element_type_1071 = convert_element_type_1095 = None + convert_element_type_default_64 = torch.ops.prims.convert_element_type.default(sum_4, torch.float32); sum_4 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_64, 'avg', 256, '0'); convert_element_type_default_64 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + view_1101 = torch.ops.aten.view.default(add_132, [16384, 4096]) + permute_369 = torch.ops.aten.permute.default(view_1101, [1, 0]) + mm_233 = torch.ops.aten.mm.default(permute_369, view_1077); permute_369 = view_1077 = None + permute_371 = torch.ops.aten.permute.default(permute_348, [1, 0]); permute_348 = None + mm_234 = torch.ops.aten.mm.default(view_1101, permute_371); view_1101 = permute_371 = None + view_1102 = torch.ops.aten.view.default(mm_234, [2, 8192, 4096]); mm_234 = None + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); mm_233 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1102, 'avg', 256, '0'); convert_element_type_1102 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + view_1103 = torch.ops.aten.view.default(view_1102, [2, 8192, 32, 128]); view_1102 = None + permute_373 = torch.ops.aten.permute.default(view_1103, [0, 2, 1, 3]); view_1103 = None + view_16 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]); primals_3 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 256, '0'); convert_element_type_1024 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32); add_123 = None + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_280) + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + view_1057 = torch.ops.aten.view.default(convert_element_type_1026, [16384, 4096]); convert_element_type_1026 = None + view_1058 = torch.ops.aten.view.default(mm_217, [2, 8192, 4096]); mm_217 = None + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 256, '0'); convert_element_type_1030 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + mm_218 = torch.ops.aten.mm.default(view_1057, permute_342) + view_1061 = torch.ops.aten.view.default(mm_218, [2, 8192, 1024]); mm_218 = None + view_1064 = torch.ops.aten.view.default(mm_219, [2, 8192, 1024]); mm_219 = None + view_1065 = torch.ops.aten.view.default(view_1058, [2, 8192, -1, 128]); view_1058 = None + view_1066 = torch.ops.aten.view.default(view_1061, [2, 8192, -1, 128]); view_1061 = None + view_1067 = torch.ops.aten.view.default(view_1064, [2, 8192, -1, 128]); view_1064 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_1065, torch.float32); view_1065 = None + view_1068 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 32, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1068); view_1068 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_1066, torch.float32); view_1066 = None + view_1069 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 8, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1069); view_1069 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_16); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_1071 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 32, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_16); view_as_complex_63 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_1072 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 8, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_1071, torch.bfloat16); view_1071 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_1072, torch.bfloat16); view_1072 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 8, 4, 128]); unsqueeze_62 = None + clone_62 = torch.ops.aten.clone.default(expand_62, memory_format = torch.contiguous_format); expand_62 = None + view_1073 = torch.ops.aten.view.default(clone_62, [2, 8192, 32, 128]); clone_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1067, 3); view_1067 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 8, 4, 128]); unsqueeze_63 = None + clone_63 = torch.ops.aten.clone.default(expand_63, memory_format = torch.contiguous_format); expand_63 = None + view_1074 = torch.ops.aten.view.default(clone_63, [2, 8192, 32, 128]); clone_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_1073, [0, 2, 1, 3]); view_1073 = None + permute_346 = torch.ops.aten.permute.default(view_1074, [0, 2, 1, 3]); view_1074 = None + _scaled_dot_product_cudnn_attention_backward = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_373, permute_344, permute_345, permute_346, getitem_279, getitem_280, getitem_285, getitem_286, None, None, None, 8192, 8192, 0.0, True); permute_373 = permute_344 = permute_345 = permute_346 = getitem_279 = getitem_280 = getitem_285 = getitem_286 = None + getitem_288 = _scaled_dot_product_cudnn_attention_backward[0] + getitem_289 = _scaled_dot_product_cudnn_attention_backward[1] + getitem_290 = _scaled_dot_product_cudnn_attention_backward[2]; _scaled_dot_product_cudnn_attention_backward = None + permute_374 = torch.ops.aten.permute.default(getitem_290, [0, 2, 1, 3]); getitem_290 = None + permute_375 = torch.ops.aten.permute.default(getitem_289, [0, 2, 1, 3]); getitem_289 = None + permute_376 = torch.ops.aten.permute.default(getitem_288, [0, 2, 1, 3]); getitem_288 = None + view_1104 = torch.ops.aten.view.default(permute_374, [2, 8192, 8, 4, 128]); permute_374 = None + sum_5 = torch.ops.aten.sum.dim_IntList(view_1104, [3], True); view_1104 = None + squeeze = torch.ops.aten.squeeze.dim(sum_5, 3); sum_5 = None + view_1105 = torch.ops.aten.view.default(permute_375, [2, 8192, 8, 4, 128]); permute_375 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_1105, [3], True); view_1105 = None + squeeze_1 = torch.ops.aten.squeeze.dim(sum_6, 3); sum_6 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(squeeze_1, torch.float32); squeeze_1 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(permute_376, torch.float32); permute_376 = None + view_1106 = torch.ops.aten.view.default(convert_element_type_1103, [2, 8192, 8, 64, 2]); convert_element_type_1103 = None + view_as_complex_64 = torch.ops.aten.view_as_complex.default(view_1106); view_1106 = None + _conj = torch.ops.aten._conj.default(view_16) + mul_276 = torch.ops.aten.mul.Tensor(view_as_complex_64, _conj); view_as_complex_64 = None + view_1107 = torch.ops.aten.view.default(convert_element_type_1104, [2, 8192, 32, 64, 2]); convert_element_type_1104 = None + view_as_complex_65 = torch.ops.aten.view_as_complex.default(view_1107); view_1107 = None + mul_277 = torch.ops.aten.mul.Tensor(view_as_complex_65, _conj); view_as_complex_65 = None + view_as_real_64 = torch.ops.aten.view_as_real.default(mul_276); mul_276 = None + view_1108 = torch.ops.aten.view.default(view_as_real_64, [2, 8192, 8, 128]); view_as_real_64 = None + convert_element_type_1105 = torch.ops.prims.convert_element_type.default(view_1108, torch.bfloat16); view_1108 = None + view_as_real_65 = torch.ops.aten.view_as_real.default(mul_277); mul_277 = None + view_1109 = torch.ops.aten.view.default(view_as_real_65, [2, 8192, 32, 128]); view_as_real_65 = None + convert_element_type_1106 = torch.ops.prims.convert_element_type.default(view_1109, torch.bfloat16); view_1109 = None + view_1110 = torch.ops.aten.view.default(squeeze, [2, 8192, 1024]); squeeze = None + view_1111 = torch.ops.aten.view.default(convert_element_type_1105, [2, 8192, 1024]); convert_element_type_1105 = None + view_1112 = torch.ops.aten.view.default(convert_element_type_1106, [2, 8192, 4096]); convert_element_type_1106 = None + view_1113 = torch.ops.aten.view.default(view_1110, [16384, 1024]); view_1110 = None + permute_377 = torch.ops.aten.permute.default(view_1113, [1, 0]) + mm_235 = torch.ops.aten.mm.default(permute_377, view_1057); permute_377 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 256, '0'); convert_element_type_1033 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + permute_379 = torch.ops.aten.permute.default(permute_343, [1, 0]); permute_343 = None + mm_236 = torch.ops.aten.mm.default(view_1113, permute_379); view_1113 = permute_379 = None + view_1114 = torch.ops.aten.view.default(mm_236, [2, 8192, 4096]); mm_236 = None + convert_element_type_1111 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1111, 'avg', 256, '0'); convert_element_type_1111 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + view_1115 = torch.ops.aten.view.default(view_1111, [16384, 1024]); view_1111 = None + permute_381 = torch.ops.aten.permute.default(view_1115, [1, 0]) + mm_237 = torch.ops.aten.mm.default(permute_381, view_1057); permute_381 = None + permute_383 = torch.ops.aten.permute.default(permute_342, [1, 0]); permute_342 = None + mm_238 = torch.ops.aten.mm.default(view_1115, permute_383); view_1115 = permute_383 = None + view_1116 = torch.ops.aten.view.default(mm_238, [2, 8192, 4096]); mm_238 = None + add_133 = torch.ops.aten.add.Tensor(view_1114, view_1116); view_1114 = view_1116 = None + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(mm_237, torch.float32); mm_237 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1116, 'avg', 256, '0'); convert_element_type_1116 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + view_1117 = torch.ops.aten.view.default(view_1112, [16384, 4096]); view_1112 = None + permute_385 = torch.ops.aten.permute.default(view_1117, [1, 0]) + mm_239 = torch.ops.aten.mm.default(permute_385, view_1057); permute_385 = view_1057 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 256, '0'); convert_element_type_1027 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + permute_387 = torch.ops.aten.permute.default(permute_341, [1, 0]); permute_341 = None + mm_240 = torch.ops.aten.mm.default(view_1117, permute_387); view_1117 = permute_387 = None + view_1118 = torch.ops.aten.view.default(mm_240, [2, 8192, 4096]); mm_240 = None + add_134 = torch.ops.aten.add.Tensor(add_133, view_1118); add_133 = view_1118 = None + convert_element_type_1121 = torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1121, 'avg', 256, '0'); convert_element_type_1121 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + convert_element_type_1122 = torch.ops.prims.convert_element_type.default(add_134, torch.float32); add_134 = None + convert_element_type_1124 = torch.ops.prims.convert_element_type.default(wait_tensor_280, torch.float32); wait_tensor_280 = None + mul_278 = torch.ops.aten.mul.Tensor(convert_element_type_1122, convert_element_type_1124); convert_element_type_1124 = None + mul_280 = torch.ops.aten.mul.Tensor(mul_248, mul_278) + sum_7 = torch.ops.aten.sum.dim_IntList(mul_280, [2], True); mul_280 = None + div_2 = torch.ops.aten.div.Tensor(mul_248, 4096) + mul_281 = torch.ops.aten.mul.Tensor(div_2, sum_7); div_2 = sum_7 = None + sub_3 = torch.ops.aten.sub.Tensor(mul_278, mul_281); mul_278 = mul_281 = None + mul_282 = torch.ops.aten.mul.Tensor(sub_3, rsqrt_62); sub_3 = rsqrt_62 = None + mul_283 = torch.ops.aten.mul.Tensor(convert_element_type_1122, mul_248); convert_element_type_1122 = mul_248 = None + sum_8 = torch.ops.aten.sum.dim_IntList(mul_283, [0, 1]); mul_283 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(mul_282, torch.bfloat16); mul_282 = None + add_135 = torch.ops.aten.add.Tensor(add_132, convert_element_type_1125); add_132 = convert_element_type_1125 = None + convert_element_type_default_63 = torch.ops.prims.convert_element_type.default(sum_8, torch.float32); sum_8 = None + reduce_scatter_tensor_10 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_63, 'avg', 256, '0'); convert_element_type_default_63 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + view_1119 = torch.ops.aten.view.default(add_135, [16384, 4096]) + permute_389 = torch.ops.aten.permute.default(view_1119, [1, 0]) + permute_336 = torch.ops.aten.permute.default(getitem_270, [0, 2, 1, 3]) + view_1041 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16); primals_278 = None + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 256, '0'); convert_element_type_1007 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_275, [1, 0]); wait_tensor_275 = None + view_1043 = torch.ops.aten.view.default(view_1041, [16384, 4096]); view_1041 = None + mm_213 = torch.ops.aten.mm.default(view_1043, permute_337) + view_1044 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + add_121 = torch.ops.aten.add.Tensor(add_119, view_1044); view_1044 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16); primals_279 = None + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 256, '0'); convert_element_type_1010 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32); add_121 = None + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_276) + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + view_1047 = torch.ops.aten.view.default(convert_element_type_1012, [16384, 4096]); convert_element_type_1012 = None + view_1048 = torch.ops.aten.view.default(mm_214, [2, 8192, 14336]); mm_214 = None + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1048, torch.float32); view_1048 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 256, '0'); convert_element_type_1018 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_215 = torch.ops.aten.mm.default(view_1047, permute_339) + view_1051 = torch.ops.aten.view.default(mm_215, [2, 8192, 14336]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_1051) + view_1053 = torch.ops.aten.view.default(mul_247, [16384, 14336]); mul_247 = None + mm_241 = torch.ops.aten.mm.default(permute_389, view_1053); permute_389 = view_1053 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 256, '0'); convert_element_type_1021 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + permute_391 = torch.ops.aten.permute.default(permute_340, [1, 0]); permute_340 = None + mm_242 = torch.ops.aten.mm.default(view_1119, permute_391); view_1119 = permute_391 = None + view_1120 = torch.ops.aten.view.default(mm_242, [2, 8192, 14336]); mm_242 = None + convert_element_type_1132 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1132, 'avg', 256, '0'); convert_element_type_1132 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + mul_284 = torch.ops.aten.mul.Tensor(view_1120, convert_element_type_1017); convert_element_type_1017 = None + mul_285 = torch.ops.aten.mul.Tensor(view_1120, view_1051); view_1120 = view_1051 = None + view_1121 = torch.ops.aten.view.default(mul_284, [16384, 14336]); mul_284 = None + permute_393 = torch.ops.aten.permute.default(view_1121, [1, 0]) + mm_243 = torch.ops.aten.mm.default(permute_393, view_1047); permute_393 = None + permute_395 = torch.ops.aten.permute.default(permute_339, [1, 0]); permute_339 = None + mm_244 = torch.ops.aten.mm.default(view_1121, permute_395); view_1121 = permute_395 = None + view_1122 = torch.ops.aten.view.default(mm_244, [2, 8192, 4096]); mm_244 = None + convert_element_type_1137 = torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None + reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1137, 'avg', 256, '0'); convert_element_type_1137 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + convert_element_type_1138 = torch.ops.prims.convert_element_type.default(mul_285, torch.float32); mul_285 = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_1016) + exp_1 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_136 = torch.ops.aten.add.Tensor(exp_1, 1); exp_1 = None + reciprocal_1 = torch.ops.aten.reciprocal.default(add_136); add_136 = None + mul_286 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None + mul_287 = torch.ops.aten.mul.Tensor(convert_element_type_1138, mul_286); convert_element_type_1138 = None + sub_4 = torch.ops.aten.sub.Tensor(1, mul_286); mul_286 = None + mul_288 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sub_4); convert_element_type_1016 = sub_4 = None + add_137 = torch.ops.aten.add.Tensor(mul_288, 1); mul_288 = None + mul_289 = torch.ops.aten.mul.Tensor(mul_287, add_137); mul_287 = add_137 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(mul_289, torch.bfloat16); mul_289 = None + view_1123 = torch.ops.aten.view.default(convert_element_type_1140, [16384, 14336]); convert_element_type_1140 = None + permute_397 = torch.ops.aten.permute.default(view_1123, [1, 0]) + mm_245 = torch.ops.aten.mm.default(permute_397, view_1047); permute_397 = view_1047 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 256, '0'); convert_element_type_1013 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + permute_399 = torch.ops.aten.permute.default(permute_338, [1, 0]); permute_338 = None + mm_246 = torch.ops.aten.mm.default(view_1123, permute_399); view_1123 = permute_399 = None + view_1124 = torch.ops.aten.view.default(mm_246, [2, 8192, 4096]); mm_246 = None + add_138 = torch.ops.aten.add.Tensor(view_1122, view_1124); view_1122 = view_1124 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1145, 'avg', 256, '0'); convert_element_type_1145 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + convert_element_type_1146 = torch.ops.prims.convert_element_type.default(add_138, torch.float32); add_138 = None + convert_element_type_1148 = torch.ops.prims.convert_element_type.default(wait_tensor_276, torch.float32); wait_tensor_276 = None + mul_290 = torch.ops.aten.mul.Tensor(convert_element_type_1146, convert_element_type_1148); convert_element_type_1148 = None + mul_292 = torch.ops.aten.mul.Tensor(mul_244, mul_290) + sum_9 = torch.ops.aten.sum.dim_IntList(mul_292, [2], True); mul_292 = None + div_3 = torch.ops.aten.div.Tensor(mul_244, 4096) + mul_293 = torch.ops.aten.mul.Tensor(div_3, sum_9); div_3 = sum_9 = None + sub_5 = torch.ops.aten.sub.Tensor(mul_290, mul_293); mul_290 = mul_293 = None + mul_294 = torch.ops.aten.mul.Tensor(sub_5, rsqrt_61); sub_5 = rsqrt_61 = None + mul_295 = torch.ops.aten.mul.Tensor(convert_element_type_1146, mul_244); convert_element_type_1146 = mul_244 = None + sum_10 = torch.ops.aten.sum.dim_IntList(mul_295, [0, 1]); mul_295 = None + convert_element_type_1149 = torch.ops.prims.convert_element_type.default(mul_294, torch.bfloat16); mul_294 = None + add_139 = torch.ops.aten.add.Tensor(add_135, convert_element_type_1149); add_135 = convert_element_type_1149 = None + convert_element_type_default_62 = torch.ops.prims.convert_element_type.default(sum_10, torch.float32); sum_10 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_62, 'avg', 256, '0'); convert_element_type_default_62 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + view_1125 = torch.ops.aten.view.default(add_139, [16384, 4096]) + permute_401 = torch.ops.aten.permute.default(view_1125, [1, 0]) + mm_247 = torch.ops.aten.mm.default(permute_401, view_1043); permute_401 = view_1043 = None + permute_403 = torch.ops.aten.permute.default(permute_337, [1, 0]); permute_337 = None + mm_248 = torch.ops.aten.mm.default(view_1125, permute_403); view_1125 = permute_403 = None + view_1126 = torch.ops.aten.view.default(mm_248, [2, 8192, 4096]); mm_248 = None + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1156, 'avg', 256, '0'); convert_element_type_1156 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + view_1127 = torch.ops.aten.view.default(view_1126, [2, 8192, 32, 128]); view_1126 = None + permute_405 = torch.ops.aten.permute.default(view_1127, [0, 2, 1, 3]); view_1127 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16); primals_274 = None + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 256, '0'); convert_element_type_991 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32); add_119 = None + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_271) + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + view_1023 = torch.ops.aten.view.default(convert_element_type_993, [16384, 4096]); convert_element_type_993 = None + view_1024 = torch.ops.aten.view.default(mm_210, [2, 8192, 4096]); mm_210 = None + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16); primals_276 = None + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 256, '0'); convert_element_type_997 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + mm_211 = torch.ops.aten.mm.default(view_1023, permute_331) + view_1027 = torch.ops.aten.view.default(mm_211, [2, 8192, 1024]); mm_211 = None + view_1030 = torch.ops.aten.view.default(mm_212, [2, 8192, 1024]); mm_212 = None + view_1031 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1032 = torch.ops.aten.view.default(view_1027, [2, 8192, -1, 128]); view_1027 = None + view_1033 = torch.ops.aten.view.default(view_1030, [2, 8192, -1, 128]); view_1030 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_1031, torch.float32); view_1031 = None + view_1034 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 32, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1034); view_1034 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_1032, torch.float32); view_1032 = None + view_1035 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 8, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1035); view_1035 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_16); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_1037 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 32, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_16); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_1038 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 8, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_1037, torch.bfloat16); view_1037 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_1038, torch.bfloat16); view_1038 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 8, 4, 128]); unsqueeze_60 = None + clone_60 = torch.ops.aten.clone.default(expand_60, memory_format = torch.contiguous_format); expand_60 = None + view_1039 = torch.ops.aten.view.default(clone_60, [2, 8192, 32, 128]); clone_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1033, 3); view_1033 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 8, 4, 128]); unsqueeze_61 = None + clone_61 = torch.ops.aten.clone.default(expand_61, memory_format = torch.contiguous_format); expand_61 = None + view_1040 = torch.ops.aten.view.default(clone_61, [2, 8192, 32, 128]); clone_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_1039, [0, 2, 1, 3]); view_1039 = None + permute_335 = torch.ops.aten.permute.default(view_1040, [0, 2, 1, 3]); view_1040 = None + _scaled_dot_product_cudnn_attention_backward_1 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_405, permute_333, permute_334, permute_335, getitem_270, getitem_271, getitem_276, getitem_277, None, None, None, 8192, 8192, 0.0, True); permute_405 = permute_333 = permute_334 = permute_335 = getitem_270 = getitem_271 = getitem_276 = getitem_277 = None + getitem_291 = _scaled_dot_product_cudnn_attention_backward_1[0] + getitem_292 = _scaled_dot_product_cudnn_attention_backward_1[1] + getitem_293 = _scaled_dot_product_cudnn_attention_backward_1[2]; _scaled_dot_product_cudnn_attention_backward_1 = None + permute_406 = torch.ops.aten.permute.default(getitem_293, [0, 2, 1, 3]); getitem_293 = None + permute_407 = torch.ops.aten.permute.default(getitem_292, [0, 2, 1, 3]); getitem_292 = None + permute_408 = torch.ops.aten.permute.default(getitem_291, [0, 2, 1, 3]); getitem_291 = None + view_1128 = torch.ops.aten.view.default(permute_406, [2, 8192, 8, 4, 128]); permute_406 = None + sum_11 = torch.ops.aten.sum.dim_IntList(view_1128, [3], True); view_1128 = None + squeeze_2 = torch.ops.aten.squeeze.dim(sum_11, 3); sum_11 = None + view_1129 = torch.ops.aten.view.default(permute_407, [2, 8192, 8, 4, 128]); permute_407 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_1129, [3], True); view_1129 = None + squeeze_3 = torch.ops.aten.squeeze.dim(sum_12, 3); sum_12 = None + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(squeeze_3, torch.float32); squeeze_3 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(permute_408, torch.float32); permute_408 = None + view_1130 = torch.ops.aten.view.default(convert_element_type_1157, [2, 8192, 8, 64, 2]); convert_element_type_1157 = None + view_as_complex_66 = torch.ops.aten.view_as_complex.default(view_1130); view_1130 = None + mul_296 = torch.ops.aten.mul.Tensor(view_as_complex_66, _conj); view_as_complex_66 = None + view_1131 = torch.ops.aten.view.default(convert_element_type_1158, [2, 8192, 32, 64, 2]); convert_element_type_1158 = None + view_as_complex_67 = torch.ops.aten.view_as_complex.default(view_1131); view_1131 = None + mul_297 = torch.ops.aten.mul.Tensor(view_as_complex_67, _conj); view_as_complex_67 = None + view_as_real_66 = torch.ops.aten.view_as_real.default(mul_296); mul_296 = None + view_1132 = torch.ops.aten.view.default(view_as_real_66, [2, 8192, 8, 128]); view_as_real_66 = None + convert_element_type_1159 = torch.ops.prims.convert_element_type.default(view_1132, torch.bfloat16); view_1132 = None + view_as_real_67 = torch.ops.aten.view_as_real.default(mul_297); mul_297 = None + view_1133 = torch.ops.aten.view.default(view_as_real_67, [2, 8192, 32, 128]); view_as_real_67 = None + convert_element_type_1160 = torch.ops.prims.convert_element_type.default(view_1133, torch.bfloat16); view_1133 = None + view_1134 = torch.ops.aten.view.default(squeeze_2, [2, 8192, 1024]); squeeze_2 = None + view_1135 = torch.ops.aten.view.default(convert_element_type_1159, [2, 8192, 1024]); convert_element_type_1159 = None + view_1136 = torch.ops.aten.view.default(convert_element_type_1160, [2, 8192, 4096]); convert_element_type_1160 = None + view_1137 = torch.ops.aten.view.default(view_1134, [16384, 1024]); view_1134 = None + permute_409 = torch.ops.aten.permute.default(view_1137, [1, 0]) + mm_249 = torch.ops.aten.mm.default(permute_409, view_1023); permute_409 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16); primals_277 = None + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 256, '0'); convert_element_type_1000 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_274, [1, 0]); wait_tensor_274 = None + permute_411 = torch.ops.aten.permute.default(permute_332, [1, 0]); permute_332 = None + mm_250 = torch.ops.aten.mm.default(view_1137, permute_411); view_1137 = permute_411 = None + view_1138 = torch.ops.aten.view.default(mm_250, [2, 8192, 4096]); mm_250 = None + convert_element_type_1165 = torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1165, 'avg', 256, '0'); convert_element_type_1165 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + view_1139 = torch.ops.aten.view.default(view_1135, [16384, 1024]); view_1135 = None + permute_413 = torch.ops.aten.permute.default(view_1139, [1, 0]) + mm_251 = torch.ops.aten.mm.default(permute_413, view_1023); permute_413 = None + permute_415 = torch.ops.aten.permute.default(permute_331, [1, 0]); permute_331 = None + mm_252 = torch.ops.aten.mm.default(view_1139, permute_415); view_1139 = permute_415 = None + view_1140 = torch.ops.aten.view.default(mm_252, [2, 8192, 4096]); mm_252 = None + add_140 = torch.ops.aten.add.Tensor(view_1138, view_1140); view_1138 = view_1140 = None + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None + reduce_scatter_tensor_17 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1170, 'avg', 256, '0'); convert_element_type_1170 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + view_1141 = torch.ops.aten.view.default(view_1136, [16384, 4096]); view_1136 = None + permute_417 = torch.ops.aten.permute.default(view_1141, [1, 0]) + mm_253 = torch.ops.aten.mm.default(permute_417, view_1023); permute_417 = view_1023 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16); primals_275 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 256, '0'); convert_element_type_994 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + permute_419 = torch.ops.aten.permute.default(permute_330, [1, 0]); permute_330 = None + mm_254 = torch.ops.aten.mm.default(view_1141, permute_419); view_1141 = permute_419 = None + view_1142 = torch.ops.aten.view.default(mm_254, [2, 8192, 4096]); mm_254 = None + add_141 = torch.ops.aten.add.Tensor(add_140, view_1142); add_140 = view_1142 = None + convert_element_type_1175 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1175, 'avg', 256, '0'); convert_element_type_1175 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + convert_element_type_1176 = torch.ops.prims.convert_element_type.default(add_141, torch.float32); add_141 = None + convert_element_type_1178 = torch.ops.prims.convert_element_type.default(wait_tensor_271, torch.float32); wait_tensor_271 = None + mul_298 = torch.ops.aten.mul.Tensor(convert_element_type_1176, convert_element_type_1178); convert_element_type_1178 = None + mul_300 = torch.ops.aten.mul.Tensor(mul_240, mul_298) + sum_13 = torch.ops.aten.sum.dim_IntList(mul_300, [2], True); mul_300 = None + div_4 = torch.ops.aten.div.Tensor(mul_240, 4096) + mul_301 = torch.ops.aten.mul.Tensor(div_4, sum_13); div_4 = sum_13 = None + sub_6 = torch.ops.aten.sub.Tensor(mul_298, mul_301); mul_298 = mul_301 = None + mul_302 = torch.ops.aten.mul.Tensor(sub_6, rsqrt_60); sub_6 = rsqrt_60 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_1176, mul_240); convert_element_type_1176 = mul_240 = None + sum_14 = torch.ops.aten.sum.dim_IntList(mul_303, [0, 1]); mul_303 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(mul_302, torch.bfloat16); mul_302 = None + add_142 = torch.ops.aten.add.Tensor(add_139, convert_element_type_1179); add_139 = convert_element_type_1179 = None + convert_element_type_default_61 = torch.ops.prims.convert_element_type.default(sum_14, torch.float32); sum_14 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_61, 'avg', 256, '0'); convert_element_type_default_61 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + view_1143 = torch.ops.aten.view.default(add_142, [16384, 4096]) + permute_421 = torch.ops.aten.permute.default(view_1143, [1, 0]) + permute_325 = torch.ops.aten.permute.default(getitem_261, [0, 2, 1, 3]) + view_1007 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 256, '0'); convert_element_type_974 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + view_1009 = torch.ops.aten.view.default(view_1007, [16384, 4096]); view_1007 = None + mm_206 = torch.ops.aten.mm.default(view_1009, permute_326) + view_1010 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + add_117 = torch.ops.aten.add.Tensor(add_115, view_1010); view_1010 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16); primals_270 = None + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 256, '0'); convert_element_type_977 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32); add_117 = None + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_267) + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + view_1013 = torch.ops.aten.view.default(convert_element_type_979, [16384, 4096]); convert_element_type_979 = None + view_1014 = torch.ops.aten.view.default(mm_207, [2, 8192, 14336]); mm_207 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_1014, torch.float32); view_1014 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16); primals_272 = None + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 256, '0'); convert_element_type_985 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_269, [1, 0]); wait_tensor_269 = None + mm_208 = torch.ops.aten.mm.default(view_1013, permute_328) + view_1017 = torch.ops.aten.view.default(mm_208, [2, 8192, 14336]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_1017) + view_1019 = torch.ops.aten.view.default(mul_239, [16384, 14336]); mul_239 = None + mm_255 = torch.ops.aten.mm.default(permute_421, view_1019); permute_421 = view_1019 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16); primals_273 = None + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 256, '0'); convert_element_type_988 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + permute_423 = torch.ops.aten.permute.default(permute_329, [1, 0]); permute_329 = None + mm_256 = torch.ops.aten.mm.default(view_1143, permute_423); view_1143 = permute_423 = None + view_1144 = torch.ops.aten.view.default(mm_256, [2, 8192, 14336]); mm_256 = None + convert_element_type_1186 = torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1186, 'avg', 256, '0'); convert_element_type_1186 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + mul_304 = torch.ops.aten.mul.Tensor(view_1144, convert_element_type_984); convert_element_type_984 = None + mul_305 = torch.ops.aten.mul.Tensor(view_1144, view_1017); view_1144 = view_1017 = None + view_1145 = torch.ops.aten.view.default(mul_304, [16384, 14336]); mul_304 = None + permute_425 = torch.ops.aten.permute.default(view_1145, [1, 0]) + mm_257 = torch.ops.aten.mm.default(permute_425, view_1013); permute_425 = None + permute_427 = torch.ops.aten.permute.default(permute_328, [1, 0]); permute_328 = None + mm_258 = torch.ops.aten.mm.default(view_1145, permute_427); view_1145 = permute_427 = None + view_1146 = torch.ops.aten.view.default(mm_258, [2, 8192, 4096]); mm_258 = None + convert_element_type_1191 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1191, 'avg', 256, '0'); convert_element_type_1191 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + convert_element_type_1192 = torch.ops.prims.convert_element_type.default(mul_305, torch.float32); mul_305 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_983) + exp_2 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_143 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_143); add_143 = None + mul_306 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_1192, mul_306); convert_element_type_1192 = None + sub_7 = torch.ops.aten.sub.Tensor(1, mul_306); mul_306 = None + mul_308 = torch.ops.aten.mul.Tensor(convert_element_type_983, sub_7); convert_element_type_983 = sub_7 = None + add_144 = torch.ops.aten.add.Tensor(mul_308, 1); mul_308 = None + mul_309 = torch.ops.aten.mul.Tensor(mul_307, add_144); mul_307 = add_144 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(mul_309, torch.bfloat16); mul_309 = None + view_1147 = torch.ops.aten.view.default(convert_element_type_1194, [16384, 14336]); convert_element_type_1194 = None + permute_429 = torch.ops.aten.permute.default(view_1147, [1, 0]) + mm_259 = torch.ops.aten.mm.default(permute_429, view_1013); permute_429 = view_1013 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16); primals_271 = None + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 256, '0'); convert_element_type_980 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + permute_431 = torch.ops.aten.permute.default(permute_327, [1, 0]); permute_327 = None + mm_260 = torch.ops.aten.mm.default(view_1147, permute_431); view_1147 = permute_431 = None + view_1148 = torch.ops.aten.view.default(mm_260, [2, 8192, 4096]); mm_260 = None + add_145 = torch.ops.aten.add.Tensor(view_1146, view_1148); view_1146 = view_1148 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_259, torch.float32); mm_259 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1199, 'avg', 256, '0'); convert_element_type_1199 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); reduce_scatter_tensor_22 = None + convert_element_type_1200 = torch.ops.prims.convert_element_type.default(add_145, torch.float32); add_145 = None + convert_element_type_1202 = torch.ops.prims.convert_element_type.default(wait_tensor_267, torch.float32); wait_tensor_267 = None + mul_310 = torch.ops.aten.mul.Tensor(convert_element_type_1200, convert_element_type_1202); convert_element_type_1202 = None + mul_312 = torch.ops.aten.mul.Tensor(mul_236, mul_310) + sum_15 = torch.ops.aten.sum.dim_IntList(mul_312, [2], True); mul_312 = None + div_5 = torch.ops.aten.div.Tensor(mul_236, 4096) + mul_313 = torch.ops.aten.mul.Tensor(div_5, sum_15); div_5 = sum_15 = None + sub_8 = torch.ops.aten.sub.Tensor(mul_310, mul_313); mul_310 = mul_313 = None + mul_314 = torch.ops.aten.mul.Tensor(sub_8, rsqrt_59); sub_8 = rsqrt_59 = None + mul_315 = torch.ops.aten.mul.Tensor(convert_element_type_1200, mul_236); convert_element_type_1200 = mul_236 = None + sum_16 = torch.ops.aten.sum.dim_IntList(mul_315, [0, 1]); mul_315 = None + convert_element_type_1203 = torch.ops.prims.convert_element_type.default(mul_314, torch.bfloat16); mul_314 = None + add_146 = torch.ops.aten.add.Tensor(add_142, convert_element_type_1203); add_142 = convert_element_type_1203 = None + convert_element_type_default_60 = torch.ops.prims.convert_element_type.default(sum_16, torch.float32); sum_16 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_60, 'avg', 256, '0'); convert_element_type_default_60 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = None + view_1149 = torch.ops.aten.view.default(add_146, [16384, 4096]) + permute_433 = torch.ops.aten.permute.default(view_1149, [1, 0]) + mm_261 = torch.ops.aten.mm.default(permute_433, view_1009); permute_433 = view_1009 = None + permute_435 = torch.ops.aten.permute.default(permute_326, [1, 0]); permute_326 = None + mm_262 = torch.ops.aten.mm.default(view_1149, permute_435); view_1149 = permute_435 = None + view_1150 = torch.ops.aten.view.default(mm_262, [2, 8192, 4096]); mm_262 = None + convert_element_type_1210 = torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1210, 'avg', 256, '0'); convert_element_type_1210 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + view_1151 = torch.ops.aten.view.default(view_1150, [2, 8192, 32, 128]); view_1150 = None + permute_437 = torch.ops.aten.permute.default(view_1151, [0, 2, 1, 3]); view_1151 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 256, '0'); convert_element_type_958 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32); add_115 = None + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_262) + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + view_989 = torch.ops.aten.view.default(convert_element_type_960, [16384, 4096]); convert_element_type_960 = None + view_990 = torch.ops.aten.view.default(mm_203, [2, 8192, 4096]); mm_203 = None + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 256, '0'); convert_element_type_964 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + mm_204 = torch.ops.aten.mm.default(view_989, permute_320) + view_993 = torch.ops.aten.view.default(mm_204, [2, 8192, 1024]); mm_204 = None + view_996 = torch.ops.aten.view.default(mm_205, [2, 8192, 1024]); mm_205 = None + view_997 = torch.ops.aten.view.default(view_990, [2, 8192, -1, 128]); view_990 = None + view_998 = torch.ops.aten.view.default(view_993, [2, 8192, -1, 128]); view_993 = None + view_999 = torch.ops.aten.view.default(view_996, [2, 8192, -1, 128]); view_996 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + view_1000 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 32, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1000); view_1000 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_998, torch.float32); view_998 = None + view_1001 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 8, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1001); view_1001 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_16); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_1003 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 32, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_16); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_1004 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 8, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_1003, torch.bfloat16); view_1003 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_1004, torch.bfloat16); view_1004 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 8, 4, 128]); unsqueeze_58 = None + clone_58 = torch.ops.aten.clone.default(expand_58, memory_format = torch.contiguous_format); expand_58 = None + view_1005 = torch.ops.aten.view.default(clone_58, [2, 8192, 32, 128]); clone_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_999, 3); view_999 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 8, 4, 128]); unsqueeze_59 = None + clone_59 = torch.ops.aten.clone.default(expand_59, memory_format = torch.contiguous_format); expand_59 = None + view_1006 = torch.ops.aten.view.default(clone_59, [2, 8192, 32, 128]); clone_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_1005, [0, 2, 1, 3]); view_1005 = None + permute_324 = torch.ops.aten.permute.default(view_1006, [0, 2, 1, 3]); view_1006 = None + _scaled_dot_product_cudnn_attention_backward_2 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_437, permute_322, permute_323, permute_324, getitem_261, getitem_262, getitem_267, getitem_268, None, None, None, 8192, 8192, 0.0, True); permute_437 = permute_322 = permute_323 = permute_324 = getitem_261 = getitem_262 = getitem_267 = getitem_268 = None + getitem_294 = _scaled_dot_product_cudnn_attention_backward_2[0] + getitem_295 = _scaled_dot_product_cudnn_attention_backward_2[1] + getitem_296 = _scaled_dot_product_cudnn_attention_backward_2[2]; _scaled_dot_product_cudnn_attention_backward_2 = None + permute_438 = torch.ops.aten.permute.default(getitem_296, [0, 2, 1, 3]); getitem_296 = None + permute_439 = torch.ops.aten.permute.default(getitem_295, [0, 2, 1, 3]); getitem_295 = None + permute_440 = torch.ops.aten.permute.default(getitem_294, [0, 2, 1, 3]); getitem_294 = None + view_1152 = torch.ops.aten.view.default(permute_438, [2, 8192, 8, 4, 128]); permute_438 = None + sum_17 = torch.ops.aten.sum.dim_IntList(view_1152, [3], True); view_1152 = None + squeeze_4 = torch.ops.aten.squeeze.dim(sum_17, 3); sum_17 = None + view_1153 = torch.ops.aten.view.default(permute_439, [2, 8192, 8, 4, 128]); permute_439 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_1153, [3], True); view_1153 = None + squeeze_5 = torch.ops.aten.squeeze.dim(sum_18, 3); sum_18 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(squeeze_5, torch.float32); squeeze_5 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(permute_440, torch.float32); permute_440 = None + view_1154 = torch.ops.aten.view.default(convert_element_type_1211, [2, 8192, 8, 64, 2]); convert_element_type_1211 = None + view_as_complex_68 = torch.ops.aten.view_as_complex.default(view_1154); view_1154 = None + mul_316 = torch.ops.aten.mul.Tensor(view_as_complex_68, _conj); view_as_complex_68 = None + view_1155 = torch.ops.aten.view.default(convert_element_type_1212, [2, 8192, 32, 64, 2]); convert_element_type_1212 = None + view_as_complex_69 = torch.ops.aten.view_as_complex.default(view_1155); view_1155 = None + mul_317 = torch.ops.aten.mul.Tensor(view_as_complex_69, _conj); view_as_complex_69 = None + view_as_real_68 = torch.ops.aten.view_as_real.default(mul_316); mul_316 = None + view_1156 = torch.ops.aten.view.default(view_as_real_68, [2, 8192, 8, 128]); view_as_real_68 = None + convert_element_type_1213 = torch.ops.prims.convert_element_type.default(view_1156, torch.bfloat16); view_1156 = None + view_as_real_69 = torch.ops.aten.view_as_real.default(mul_317); mul_317 = None + view_1157 = torch.ops.aten.view.default(view_as_real_69, [2, 8192, 32, 128]); view_as_real_69 = None + convert_element_type_1214 = torch.ops.prims.convert_element_type.default(view_1157, torch.bfloat16); view_1157 = None + view_1158 = torch.ops.aten.view.default(squeeze_4, [2, 8192, 1024]); squeeze_4 = None + view_1159 = torch.ops.aten.view.default(convert_element_type_1213, [2, 8192, 1024]); convert_element_type_1213 = None + view_1160 = torch.ops.aten.view.default(convert_element_type_1214, [2, 8192, 4096]); convert_element_type_1214 = None + view_1161 = torch.ops.aten.view.default(view_1158, [16384, 1024]); view_1158 = None + permute_441 = torch.ops.aten.permute.default(view_1161, [1, 0]) + mm_263 = torch.ops.aten.mm.default(permute_441, view_989); permute_441 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 256, '0'); convert_element_type_967 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + permute_443 = torch.ops.aten.permute.default(permute_321, [1, 0]); permute_321 = None + mm_264 = torch.ops.aten.mm.default(view_1161, permute_443); view_1161 = permute_443 = None + view_1162 = torch.ops.aten.view.default(mm_264, [2, 8192, 4096]); mm_264 = None + convert_element_type_1219 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None + reduce_scatter_tensor_25 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1219, 'avg', 256, '0'); convert_element_type_1219 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + view_1163 = torch.ops.aten.view.default(view_1159, [16384, 1024]); view_1159 = None + permute_445 = torch.ops.aten.permute.default(view_1163, [1, 0]) + mm_265 = torch.ops.aten.mm.default(permute_445, view_989); permute_445 = None + permute_447 = torch.ops.aten.permute.default(permute_320, [1, 0]); permute_320 = None + mm_266 = torch.ops.aten.mm.default(view_1163, permute_447); view_1163 = permute_447 = None + view_1164 = torch.ops.aten.view.default(mm_266, [2, 8192, 4096]); mm_266 = None + add_147 = torch.ops.aten.add.Tensor(view_1162, view_1164); view_1162 = view_1164 = None + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1224, 'avg', 256, '0'); convert_element_type_1224 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + view_1165 = torch.ops.aten.view.default(view_1160, [16384, 4096]); view_1160 = None + permute_449 = torch.ops.aten.permute.default(view_1165, [1, 0]) + mm_267 = torch.ops.aten.mm.default(permute_449, view_989); permute_449 = view_989 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 256, '0'); convert_element_type_961 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_263, [1, 0]); wait_tensor_263 = None + permute_451 = torch.ops.aten.permute.default(permute_319, [1, 0]); permute_319 = None + mm_268 = torch.ops.aten.mm.default(view_1165, permute_451); view_1165 = permute_451 = None + view_1166 = torch.ops.aten.view.default(mm_268, [2, 8192, 4096]); mm_268 = None + add_148 = torch.ops.aten.add.Tensor(add_147, view_1166); add_147 = view_1166 = None + convert_element_type_1229 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1229, 'avg', 256, '0'); convert_element_type_1229 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + convert_element_type_1230 = torch.ops.prims.convert_element_type.default(add_148, torch.float32); add_148 = None + convert_element_type_1232 = torch.ops.prims.convert_element_type.default(wait_tensor_262, torch.float32); wait_tensor_262 = None + mul_318 = torch.ops.aten.mul.Tensor(convert_element_type_1230, convert_element_type_1232); convert_element_type_1232 = None + mul_320 = torch.ops.aten.mul.Tensor(mul_232, mul_318) + sum_19 = torch.ops.aten.sum.dim_IntList(mul_320, [2], True); mul_320 = None + div_6 = torch.ops.aten.div.Tensor(mul_232, 4096) + mul_321 = torch.ops.aten.mul.Tensor(div_6, sum_19); div_6 = sum_19 = None + sub_9 = torch.ops.aten.sub.Tensor(mul_318, mul_321); mul_318 = mul_321 = None + mul_322 = torch.ops.aten.mul.Tensor(sub_9, rsqrt_58); sub_9 = rsqrt_58 = None + mul_323 = torch.ops.aten.mul.Tensor(convert_element_type_1230, mul_232); convert_element_type_1230 = mul_232 = None + sum_20 = torch.ops.aten.sum.dim_IntList(mul_323, [0, 1]); mul_323 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(mul_322, torch.bfloat16); mul_322 = None + add_149 = torch.ops.aten.add.Tensor(add_146, convert_element_type_1233); add_146 = convert_element_type_1233 = None + convert_element_type_default_59 = torch.ops.prims.convert_element_type.default(sum_20, torch.float32); sum_20 = None + reduce_scatter_tensor_28 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_59, 'avg', 256, '0'); convert_element_type_default_59 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + view_1167 = torch.ops.aten.view.default(add_149, [16384, 4096]) + permute_453 = torch.ops.aten.permute.default(view_1167, [1, 0]) + permute_314 = torch.ops.aten.permute.default(getitem_252, [0, 2, 1, 3]) + view_973 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16); primals_260 = None + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 256, '0'); convert_element_type_941 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_257, [1, 0]); wait_tensor_257 = None + view_975 = torch.ops.aten.view.default(view_973, [16384, 4096]); view_973 = None + mm_199 = torch.ops.aten.mm.default(view_975, permute_315) + view_976 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + add_113 = torch.ops.aten.add.Tensor(add_111, view_976); view_976 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16); primals_261 = None + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 256, '0'); convert_element_type_944 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32); add_113 = None + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_258) + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + view_979 = torch.ops.aten.view.default(convert_element_type_946, [16384, 4096]); convert_element_type_946 = None + view_980 = torch.ops.aten.view.default(mm_200, [2, 8192, 14336]); mm_200 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_980, torch.float32); view_980 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 256, '0'); convert_element_type_952 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_201 = torch.ops.aten.mm.default(view_979, permute_317) + view_983 = torch.ops.aten.view.default(mm_201, [2, 8192, 14336]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_983) + view_985 = torch.ops.aten.view.default(mul_231, [16384, 14336]); mul_231 = None + mm_269 = torch.ops.aten.mm.default(permute_453, view_985); permute_453 = view_985 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 256, '0'); convert_element_type_955 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + permute_455 = torch.ops.aten.permute.default(permute_318, [1, 0]); permute_318 = None + mm_270 = torch.ops.aten.mm.default(view_1167, permute_455); view_1167 = permute_455 = None + view_1168 = torch.ops.aten.view.default(mm_270, [2, 8192, 14336]); mm_270 = None + convert_element_type_1240 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1240, 'avg', 256, '0'); convert_element_type_1240 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + mul_324 = torch.ops.aten.mul.Tensor(view_1168, convert_element_type_951); convert_element_type_951 = None + mul_325 = torch.ops.aten.mul.Tensor(view_1168, view_983); view_1168 = view_983 = None + view_1169 = torch.ops.aten.view.default(mul_324, [16384, 14336]); mul_324 = None + permute_457 = torch.ops.aten.permute.default(view_1169, [1, 0]) + mm_271 = torch.ops.aten.mm.default(permute_457, view_979); permute_457 = None + permute_459 = torch.ops.aten.permute.default(permute_317, [1, 0]); permute_317 = None + mm_272 = torch.ops.aten.mm.default(view_1169, permute_459); view_1169 = permute_459 = None + view_1170 = torch.ops.aten.view.default(mm_272, [2, 8192, 4096]); mm_272 = None + convert_element_type_1245 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None + reduce_scatter_tensor_30 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1245, 'avg', 256, '0'); convert_element_type_1245 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + convert_element_type_1246 = torch.ops.prims.convert_element_type.default(mul_325, torch.float32); mul_325 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_950) + exp_3 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_150 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_150); add_150 = None + mul_326 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_327 = torch.ops.aten.mul.Tensor(convert_element_type_1246, mul_326); convert_element_type_1246 = None + sub_10 = torch.ops.aten.sub.Tensor(1, mul_326); mul_326 = None + mul_328 = torch.ops.aten.mul.Tensor(convert_element_type_950, sub_10); convert_element_type_950 = sub_10 = None + add_151 = torch.ops.aten.add.Tensor(mul_328, 1); mul_328 = None + mul_329 = torch.ops.aten.mul.Tensor(mul_327, add_151); mul_327 = add_151 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(mul_329, torch.bfloat16); mul_329 = None + view_1171 = torch.ops.aten.view.default(convert_element_type_1248, [16384, 14336]); convert_element_type_1248 = None + permute_461 = torch.ops.aten.permute.default(view_1171, [1, 0]) + mm_273 = torch.ops.aten.mm.default(permute_461, view_979); permute_461 = view_979 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 256, '0'); convert_element_type_947 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + permute_463 = torch.ops.aten.permute.default(permute_316, [1, 0]); permute_316 = None + mm_274 = torch.ops.aten.mm.default(view_1171, permute_463); view_1171 = permute_463 = None + view_1172 = torch.ops.aten.view.default(mm_274, [2, 8192, 4096]); mm_274 = None + add_152 = torch.ops.aten.add.Tensor(view_1170, view_1172); view_1170 = view_1172 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1253, 'avg', 256, '0'); convert_element_type_1253 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + convert_element_type_1254 = torch.ops.prims.convert_element_type.default(add_152, torch.float32); add_152 = None + convert_element_type_1256 = torch.ops.prims.convert_element_type.default(wait_tensor_258, torch.float32); wait_tensor_258 = None + mul_330 = torch.ops.aten.mul.Tensor(convert_element_type_1254, convert_element_type_1256); convert_element_type_1256 = None + mul_332 = torch.ops.aten.mul.Tensor(mul_228, mul_330) + sum_21 = torch.ops.aten.sum.dim_IntList(mul_332, [2], True); mul_332 = None + div_7 = torch.ops.aten.div.Tensor(mul_228, 4096) + mul_333 = torch.ops.aten.mul.Tensor(div_7, sum_21); div_7 = sum_21 = None + sub_11 = torch.ops.aten.sub.Tensor(mul_330, mul_333); mul_330 = mul_333 = None + mul_334 = torch.ops.aten.mul.Tensor(sub_11, rsqrt_57); sub_11 = rsqrt_57 = None + mul_335 = torch.ops.aten.mul.Tensor(convert_element_type_1254, mul_228); convert_element_type_1254 = mul_228 = None + sum_22 = torch.ops.aten.sum.dim_IntList(mul_335, [0, 1]); mul_335 = None + convert_element_type_1257 = torch.ops.prims.convert_element_type.default(mul_334, torch.bfloat16); mul_334 = None + add_153 = torch.ops.aten.add.Tensor(add_149, convert_element_type_1257); add_149 = convert_element_type_1257 = None + convert_element_type_default_58 = torch.ops.prims.convert_element_type.default(sum_22, torch.float32); sum_22 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_58, 'avg', 256, '0'); convert_element_type_default_58 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + view_1173 = torch.ops.aten.view.default(add_153, [16384, 4096]) + permute_465 = torch.ops.aten.permute.default(view_1173, [1, 0]) + mm_275 = torch.ops.aten.mm.default(permute_465, view_975); permute_465 = view_975 = None + permute_467 = torch.ops.aten.permute.default(permute_315, [1, 0]); permute_315 = None + mm_276 = torch.ops.aten.mm.default(view_1173, permute_467); view_1173 = permute_467 = None + view_1174 = torch.ops.aten.view.default(mm_276, [2, 8192, 4096]); mm_276 = None + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1264, 'avg', 256, '0'); convert_element_type_1264 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + view_1175 = torch.ops.aten.view.default(view_1174, [2, 8192, 32, 128]); view_1174 = None + permute_469 = torch.ops.aten.permute.default(view_1175, [0, 2, 1, 3]); view_1175 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16); primals_256 = None + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 256, '0'); convert_element_type_925 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32); add_111 = None + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_253) + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + view_955 = torch.ops.aten.view.default(convert_element_type_927, [16384, 4096]); convert_element_type_927 = None + view_956 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]); mm_196 = None + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16); primals_258 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 256, '0'); convert_element_type_931 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_255, [1, 0]); wait_tensor_255 = None + mm_197 = torch.ops.aten.mm.default(view_955, permute_309) + view_959 = torch.ops.aten.view.default(mm_197, [2, 8192, 1024]); mm_197 = None + view_962 = torch.ops.aten.view.default(mm_198, [2, 8192, 1024]); mm_198 = None + view_963 = torch.ops.aten.view.default(view_956, [2, 8192, -1, 128]); view_956 = None + view_964 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_965 = torch.ops.aten.view.default(view_962, [2, 8192, -1, 128]); view_962 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_963, torch.float32); view_963 = None + view_966 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 32, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_966); view_966 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_964, torch.float32); view_964 = None + view_967 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 8, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_967); view_967 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_16); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_969 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 32, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_16); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_970 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 8, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_969, torch.bfloat16); view_969 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_970, torch.bfloat16); view_970 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 8, 4, 128]); unsqueeze_56 = None + clone_56 = torch.ops.aten.clone.default(expand_56, memory_format = torch.contiguous_format); expand_56 = None + view_971 = torch.ops.aten.view.default(clone_56, [2, 8192, 32, 128]); clone_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_965, 3); view_965 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 8, 4, 128]); unsqueeze_57 = None + clone_57 = torch.ops.aten.clone.default(expand_57, memory_format = torch.contiguous_format); expand_57 = None + view_972 = torch.ops.aten.view.default(clone_57, [2, 8192, 32, 128]); clone_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_971, [0, 2, 1, 3]); view_971 = None + permute_313 = torch.ops.aten.permute.default(view_972, [0, 2, 1, 3]); view_972 = None + _scaled_dot_product_cudnn_attention_backward_3 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_469, permute_311, permute_312, permute_313, getitem_252, getitem_253, getitem_258, getitem_259, None, None, None, 8192, 8192, 0.0, True); permute_469 = permute_311 = permute_312 = permute_313 = getitem_252 = getitem_253 = getitem_258 = getitem_259 = None + getitem_297 = _scaled_dot_product_cudnn_attention_backward_3[0] + getitem_298 = _scaled_dot_product_cudnn_attention_backward_3[1] + getitem_299 = _scaled_dot_product_cudnn_attention_backward_3[2]; _scaled_dot_product_cudnn_attention_backward_3 = None + permute_470 = torch.ops.aten.permute.default(getitem_299, [0, 2, 1, 3]); getitem_299 = None + permute_471 = torch.ops.aten.permute.default(getitem_298, [0, 2, 1, 3]); getitem_298 = None + permute_472 = torch.ops.aten.permute.default(getitem_297, [0, 2, 1, 3]); getitem_297 = None + view_1176 = torch.ops.aten.view.default(permute_470, [2, 8192, 8, 4, 128]); permute_470 = None + sum_23 = torch.ops.aten.sum.dim_IntList(view_1176, [3], True); view_1176 = None + squeeze_6 = torch.ops.aten.squeeze.dim(sum_23, 3); sum_23 = None + view_1177 = torch.ops.aten.view.default(permute_471, [2, 8192, 8, 4, 128]); permute_471 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_1177, [3], True); view_1177 = None + squeeze_7 = torch.ops.aten.squeeze.dim(sum_24, 3); sum_24 = None + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(squeeze_7, torch.float32); squeeze_7 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(permute_472, torch.float32); permute_472 = None + view_1178 = torch.ops.aten.view.default(convert_element_type_1265, [2, 8192, 8, 64, 2]); convert_element_type_1265 = None + view_as_complex_70 = torch.ops.aten.view_as_complex.default(view_1178); view_1178 = None + mul_336 = torch.ops.aten.mul.Tensor(view_as_complex_70, _conj); view_as_complex_70 = None + view_1179 = torch.ops.aten.view.default(convert_element_type_1266, [2, 8192, 32, 64, 2]); convert_element_type_1266 = None + view_as_complex_71 = torch.ops.aten.view_as_complex.default(view_1179); view_1179 = None + mul_337 = torch.ops.aten.mul.Tensor(view_as_complex_71, _conj); view_as_complex_71 = None + view_as_real_70 = torch.ops.aten.view_as_real.default(mul_336); mul_336 = None + view_1180 = torch.ops.aten.view.default(view_as_real_70, [2, 8192, 8, 128]); view_as_real_70 = None + convert_element_type_1267 = torch.ops.prims.convert_element_type.default(view_1180, torch.bfloat16); view_1180 = None + view_as_real_71 = torch.ops.aten.view_as_real.default(mul_337); mul_337 = None + view_1181 = torch.ops.aten.view.default(view_as_real_71, [2, 8192, 32, 128]); view_as_real_71 = None + convert_element_type_1268 = torch.ops.prims.convert_element_type.default(view_1181, torch.bfloat16); view_1181 = None + view_1182 = torch.ops.aten.view.default(squeeze_6, [2, 8192, 1024]); squeeze_6 = None + view_1183 = torch.ops.aten.view.default(convert_element_type_1267, [2, 8192, 1024]); convert_element_type_1267 = None + view_1184 = torch.ops.aten.view.default(convert_element_type_1268, [2, 8192, 4096]); convert_element_type_1268 = None + view_1185 = torch.ops.aten.view.default(view_1182, [16384, 1024]); view_1182 = None + permute_473 = torch.ops.aten.permute.default(view_1185, [1, 0]) + mm_277 = torch.ops.aten.mm.default(permute_473, view_955); permute_473 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16); primals_259 = None + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 256, '0'); convert_element_type_934 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_256, [1, 0]); wait_tensor_256 = None + permute_475 = torch.ops.aten.permute.default(permute_310, [1, 0]); permute_310 = None + mm_278 = torch.ops.aten.mm.default(view_1185, permute_475); view_1185 = permute_475 = None + view_1186 = torch.ops.aten.view.default(mm_278, [2, 8192, 4096]); mm_278 = None + convert_element_type_1273 = torch.ops.prims.convert_element_type.default(mm_277, torch.float32); mm_277 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1273, 'avg', 256, '0'); convert_element_type_1273 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + view_1187 = torch.ops.aten.view.default(view_1183, [16384, 1024]); view_1183 = None + permute_477 = torch.ops.aten.permute.default(view_1187, [1, 0]) + mm_279 = torch.ops.aten.mm.default(permute_477, view_955); permute_477 = None + permute_479 = torch.ops.aten.permute.default(permute_309, [1, 0]); permute_309 = None + mm_280 = torch.ops.aten.mm.default(view_1187, permute_479); view_1187 = permute_479 = None + view_1188 = torch.ops.aten.view.default(mm_280, [2, 8192, 4096]); mm_280 = None + add_154 = torch.ops.aten.add.Tensor(view_1186, view_1188); view_1186 = view_1188 = None + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(mm_279, torch.float32); mm_279 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1278, 'avg', 256, '0'); convert_element_type_1278 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + view_1189 = torch.ops.aten.view.default(view_1184, [16384, 4096]); view_1184 = None + permute_481 = torch.ops.aten.permute.default(view_1189, [1, 0]) + mm_281 = torch.ops.aten.mm.default(permute_481, view_955); permute_481 = view_955 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16); primals_257 = None + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 256, '0'); convert_element_type_928 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + permute_483 = torch.ops.aten.permute.default(permute_308, [1, 0]); permute_308 = None + mm_282 = torch.ops.aten.mm.default(view_1189, permute_483); view_1189 = permute_483 = None + view_1190 = torch.ops.aten.view.default(mm_282, [2, 8192, 4096]); mm_282 = None + add_155 = torch.ops.aten.add.Tensor(add_154, view_1190); add_154 = view_1190 = None + convert_element_type_1283 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1283, 'avg', 256, '0'); convert_element_type_1283 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + convert_element_type_1284 = torch.ops.prims.convert_element_type.default(add_155, torch.float32); add_155 = None + convert_element_type_1286 = torch.ops.prims.convert_element_type.default(wait_tensor_253, torch.float32); wait_tensor_253 = None + mul_338 = torch.ops.aten.mul.Tensor(convert_element_type_1284, convert_element_type_1286); convert_element_type_1286 = None + mul_340 = torch.ops.aten.mul.Tensor(mul_224, mul_338) + sum_25 = torch.ops.aten.sum.dim_IntList(mul_340, [2], True); mul_340 = None + div_8 = torch.ops.aten.div.Tensor(mul_224, 4096) + mul_341 = torch.ops.aten.mul.Tensor(div_8, sum_25); div_8 = sum_25 = None + sub_12 = torch.ops.aten.sub.Tensor(mul_338, mul_341); mul_338 = mul_341 = None + mul_342 = torch.ops.aten.mul.Tensor(sub_12, rsqrt_56); sub_12 = rsqrt_56 = None + mul_343 = torch.ops.aten.mul.Tensor(convert_element_type_1284, mul_224); convert_element_type_1284 = mul_224 = None + sum_26 = torch.ops.aten.sum.dim_IntList(mul_343, [0, 1]); mul_343 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(mul_342, torch.bfloat16); mul_342 = None + add_156 = torch.ops.aten.add.Tensor(add_153, convert_element_type_1287); add_153 = convert_element_type_1287 = None + convert_element_type_default_57 = torch.ops.prims.convert_element_type.default(sum_26, torch.float32); sum_26 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_57, 'avg', 256, '0'); convert_element_type_default_57 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + view_1191 = torch.ops.aten.view.default(add_156, [16384, 4096]) + permute_485 = torch.ops.aten.permute.default(view_1191, [1, 0]) + permute_303 = torch.ops.aten.permute.default(getitem_243, [0, 2, 1, 3]) + view_939 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 256, '0'); convert_element_type_908 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_248, [1, 0]); wait_tensor_248 = None + view_941 = torch.ops.aten.view.default(view_939, [16384, 4096]); view_939 = None + mm_192 = torch.ops.aten.mm.default(view_941, permute_304) + view_942 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + add_109 = torch.ops.aten.add.Tensor(add_107, view_942); view_942 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 256, '0'); convert_element_type_911 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32); add_109 = None + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_249) + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + view_945 = torch.ops.aten.view.default(convert_element_type_913, [16384, 4096]); convert_element_type_913 = None + view_946 = torch.ops.aten.view.default(mm_193, [2, 8192, 14336]); mm_193 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_946, torch.float32); view_946 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16); primals_254 = None + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 256, '0'); convert_element_type_919 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + mm_194 = torch.ops.aten.mm.default(view_945, permute_306) + view_949 = torch.ops.aten.view.default(mm_194, [2, 8192, 14336]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_949) + view_951 = torch.ops.aten.view.default(mul_223, [16384, 14336]); mul_223 = None + mm_283 = torch.ops.aten.mm.default(permute_485, view_951); permute_485 = view_951 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16); primals_255 = None + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 256, '0'); convert_element_type_922 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + permute_487 = torch.ops.aten.permute.default(permute_307, [1, 0]); permute_307 = None + mm_284 = torch.ops.aten.mm.default(view_1191, permute_487); view_1191 = permute_487 = None + view_1192 = torch.ops.aten.view.default(mm_284, [2, 8192, 14336]); mm_284 = None + convert_element_type_1294 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1294, 'avg', 256, '0'); convert_element_type_1294 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + mul_344 = torch.ops.aten.mul.Tensor(view_1192, convert_element_type_918); convert_element_type_918 = None + mul_345 = torch.ops.aten.mul.Tensor(view_1192, view_949); view_1192 = view_949 = None + view_1193 = torch.ops.aten.view.default(mul_344, [16384, 14336]); mul_344 = None + permute_489 = torch.ops.aten.permute.default(view_1193, [1, 0]) + mm_285 = torch.ops.aten.mm.default(permute_489, view_945); permute_489 = None + permute_491 = torch.ops.aten.permute.default(permute_306, [1, 0]); permute_306 = None + mm_286 = torch.ops.aten.mm.default(view_1193, permute_491); view_1193 = permute_491 = None + view_1194 = torch.ops.aten.view.default(mm_286, [2, 8192, 4096]); mm_286 = None + convert_element_type_1299 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1299, 'avg', 256, '0'); convert_element_type_1299 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + convert_element_type_1300 = torch.ops.prims.convert_element_type.default(mul_345, torch.float32); mul_345 = None + neg_4 = torch.ops.aten.neg.default(convert_element_type_917) + exp_4 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_157 = torch.ops.aten.add.Tensor(exp_4, 1); exp_4 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_157); add_157 = None + mul_346 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_347 = torch.ops.aten.mul.Tensor(convert_element_type_1300, mul_346); convert_element_type_1300 = None + sub_13 = torch.ops.aten.sub.Tensor(1, mul_346); mul_346 = None + mul_348 = torch.ops.aten.mul.Tensor(convert_element_type_917, sub_13); convert_element_type_917 = sub_13 = None + add_158 = torch.ops.aten.add.Tensor(mul_348, 1); mul_348 = None + mul_349 = torch.ops.aten.mul.Tensor(mul_347, add_158); mul_347 = add_158 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(mul_349, torch.bfloat16); mul_349 = None + view_1195 = torch.ops.aten.view.default(convert_element_type_1302, [16384, 14336]); convert_element_type_1302 = None + permute_493 = torch.ops.aten.permute.default(view_1195, [1, 0]) + mm_287 = torch.ops.aten.mm.default(permute_493, view_945); permute_493 = view_945 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16); primals_253 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 256, '0'); convert_element_type_914 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_250, [1, 0]); wait_tensor_250 = None + permute_495 = torch.ops.aten.permute.default(permute_305, [1, 0]); permute_305 = None + mm_288 = torch.ops.aten.mm.default(view_1195, permute_495); view_1195 = permute_495 = None + view_1196 = torch.ops.aten.view.default(mm_288, [2, 8192, 4096]); mm_288 = None + add_159 = torch.ops.aten.add.Tensor(view_1194, view_1196); view_1194 = view_1196 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1307, 'avg', 256, '0'); convert_element_type_1307 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + convert_element_type_1308 = torch.ops.prims.convert_element_type.default(add_159, torch.float32); add_159 = None + convert_element_type_1310 = torch.ops.prims.convert_element_type.default(wait_tensor_249, torch.float32); wait_tensor_249 = None + mul_350 = torch.ops.aten.mul.Tensor(convert_element_type_1308, convert_element_type_1310); convert_element_type_1310 = None + mul_352 = torch.ops.aten.mul.Tensor(mul_220, mul_350) + sum_27 = torch.ops.aten.sum.dim_IntList(mul_352, [2], True); mul_352 = None + div_9 = torch.ops.aten.div.Tensor(mul_220, 4096) + mul_353 = torch.ops.aten.mul.Tensor(div_9, sum_27); div_9 = sum_27 = None + sub_14 = torch.ops.aten.sub.Tensor(mul_350, mul_353); mul_350 = mul_353 = None + mul_354 = torch.ops.aten.mul.Tensor(sub_14, rsqrt_55); sub_14 = rsqrt_55 = None + mul_355 = torch.ops.aten.mul.Tensor(convert_element_type_1308, mul_220); convert_element_type_1308 = mul_220 = None + sum_28 = torch.ops.aten.sum.dim_IntList(mul_355, [0, 1]); mul_355 = None + convert_element_type_1311 = torch.ops.prims.convert_element_type.default(mul_354, torch.bfloat16); mul_354 = None + add_160 = torch.ops.aten.add.Tensor(add_156, convert_element_type_1311); add_156 = convert_element_type_1311 = None + convert_element_type_default_56 = torch.ops.prims.convert_element_type.default(sum_28, torch.float32); sum_28 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_56, 'avg', 256, '0'); convert_element_type_default_56 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + view_1197 = torch.ops.aten.view.default(add_160, [16384, 4096]) + permute_497 = torch.ops.aten.permute.default(view_1197, [1, 0]) + mm_289 = torch.ops.aten.mm.default(permute_497, view_941); permute_497 = view_941 = None + permute_499 = torch.ops.aten.permute.default(permute_304, [1, 0]); permute_304 = None + mm_290 = torch.ops.aten.mm.default(view_1197, permute_499); view_1197 = permute_499 = None + view_1198 = torch.ops.aten.view.default(mm_290, [2, 8192, 4096]); mm_290 = None + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1318, 'avg', 256, '0'); convert_element_type_1318 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + view_1199 = torch.ops.aten.view.default(view_1198, [2, 8192, 32, 128]); view_1198 = None + permute_501 = torch.ops.aten.permute.default(view_1199, [0, 2, 1, 3]); view_1199 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 256, '0'); convert_element_type_892 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32); add_107 = None + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_244) + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + view_921 = torch.ops.aten.view.default(convert_element_type_894, [16384, 4096]); convert_element_type_894 = None + view_922 = torch.ops.aten.view.default(mm_189, [2, 8192, 4096]); mm_189 = None + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 256, '0'); convert_element_type_898 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_190 = torch.ops.aten.mm.default(view_921, permute_298) + view_925 = torch.ops.aten.view.default(mm_190, [2, 8192, 1024]); mm_190 = None + view_928 = torch.ops.aten.view.default(mm_191, [2, 8192, 1024]); mm_191 = None + view_929 = torch.ops.aten.view.default(view_922, [2, 8192, -1, 128]); view_922 = None + view_930 = torch.ops.aten.view.default(view_925, [2, 8192, -1, 128]); view_925 = None + view_931 = torch.ops.aten.view.default(view_928, [2, 8192, -1, 128]); view_928 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_929, torch.float32); view_929 = None + view_932 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 32, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_932); view_932 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_930, torch.float32); view_930 = None + view_933 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 8, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_933); view_933 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_16); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_935 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 32, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_16); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_936 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 8, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_935, torch.bfloat16); view_935 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_936, torch.bfloat16); view_936 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 8, 4, 128]); unsqueeze_54 = None + clone_54 = torch.ops.aten.clone.default(expand_54, memory_format = torch.contiguous_format); expand_54 = None + view_937 = torch.ops.aten.view.default(clone_54, [2, 8192, 32, 128]); clone_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_931, 3); view_931 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 8, 4, 128]); unsqueeze_55 = None + clone_55 = torch.ops.aten.clone.default(expand_55, memory_format = torch.contiguous_format); expand_55 = None + view_938 = torch.ops.aten.view.default(clone_55, [2, 8192, 32, 128]); clone_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_937, [0, 2, 1, 3]); view_937 = None + permute_302 = torch.ops.aten.permute.default(view_938, [0, 2, 1, 3]); view_938 = None + _scaled_dot_product_cudnn_attention_backward_4 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_501, permute_300, permute_301, permute_302, getitem_243, getitem_244, getitem_249, getitem_250, None, None, None, 8192, 8192, 0.0, True); permute_501 = permute_300 = permute_301 = permute_302 = getitem_243 = getitem_244 = getitem_249 = getitem_250 = None + getitem_300 = _scaled_dot_product_cudnn_attention_backward_4[0] + getitem_301 = _scaled_dot_product_cudnn_attention_backward_4[1] + getitem_302 = _scaled_dot_product_cudnn_attention_backward_4[2]; _scaled_dot_product_cudnn_attention_backward_4 = None + permute_502 = torch.ops.aten.permute.default(getitem_302, [0, 2, 1, 3]); getitem_302 = None + permute_503 = torch.ops.aten.permute.default(getitem_301, [0, 2, 1, 3]); getitem_301 = None + permute_504 = torch.ops.aten.permute.default(getitem_300, [0, 2, 1, 3]); getitem_300 = None + view_1200 = torch.ops.aten.view.default(permute_502, [2, 8192, 8, 4, 128]); permute_502 = None + sum_29 = torch.ops.aten.sum.dim_IntList(view_1200, [3], True); view_1200 = None + squeeze_8 = torch.ops.aten.squeeze.dim(sum_29, 3); sum_29 = None + view_1201 = torch.ops.aten.view.default(permute_503, [2, 8192, 8, 4, 128]); permute_503 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_1201, [3], True); view_1201 = None + squeeze_9 = torch.ops.aten.squeeze.dim(sum_30, 3); sum_30 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(squeeze_9, torch.float32); squeeze_9 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(permute_504, torch.float32); permute_504 = None + view_1202 = torch.ops.aten.view.default(convert_element_type_1319, [2, 8192, 8, 64, 2]); convert_element_type_1319 = None + view_as_complex_72 = torch.ops.aten.view_as_complex.default(view_1202); view_1202 = None + mul_356 = torch.ops.aten.mul.Tensor(view_as_complex_72, _conj); view_as_complex_72 = None + view_1203 = torch.ops.aten.view.default(convert_element_type_1320, [2, 8192, 32, 64, 2]); convert_element_type_1320 = None + view_as_complex_73 = torch.ops.aten.view_as_complex.default(view_1203); view_1203 = None + mul_357 = torch.ops.aten.mul.Tensor(view_as_complex_73, _conj); view_as_complex_73 = None + view_as_real_72 = torch.ops.aten.view_as_real.default(mul_356); mul_356 = None + view_1204 = torch.ops.aten.view.default(view_as_real_72, [2, 8192, 8, 128]); view_as_real_72 = None + convert_element_type_1321 = torch.ops.prims.convert_element_type.default(view_1204, torch.bfloat16); view_1204 = None + view_as_real_73 = torch.ops.aten.view_as_real.default(mul_357); mul_357 = None + view_1205 = torch.ops.aten.view.default(view_as_real_73, [2, 8192, 32, 128]); view_as_real_73 = None + convert_element_type_1322 = torch.ops.prims.convert_element_type.default(view_1205, torch.bfloat16); view_1205 = None + view_1206 = torch.ops.aten.view.default(squeeze_8, [2, 8192, 1024]); squeeze_8 = None + view_1207 = torch.ops.aten.view.default(convert_element_type_1321, [2, 8192, 1024]); convert_element_type_1321 = None + view_1208 = torch.ops.aten.view.default(convert_element_type_1322, [2, 8192, 4096]); convert_element_type_1322 = None + view_1209 = torch.ops.aten.view.default(view_1206, [16384, 1024]); view_1206 = None + permute_505 = torch.ops.aten.permute.default(view_1209, [1, 0]) + mm_291 = torch.ops.aten.mm.default(permute_505, view_921); permute_505 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 256, '0'); convert_element_type_901 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + permute_507 = torch.ops.aten.permute.default(permute_299, [1, 0]); permute_299 = None + mm_292 = torch.ops.aten.mm.default(view_1209, permute_507); view_1209 = permute_507 = None + view_1210 = torch.ops.aten.view.default(mm_292, [2, 8192, 4096]); mm_292 = None + convert_element_type_1327 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1327, 'avg', 256, '0'); convert_element_type_1327 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + view_1211 = torch.ops.aten.view.default(view_1207, [16384, 1024]); view_1207 = None + permute_509 = torch.ops.aten.permute.default(view_1211, [1, 0]) + mm_293 = torch.ops.aten.mm.default(permute_509, view_921); permute_509 = None + permute_511 = torch.ops.aten.permute.default(permute_298, [1, 0]); permute_298 = None + mm_294 = torch.ops.aten.mm.default(view_1211, permute_511); view_1211 = permute_511 = None + view_1212 = torch.ops.aten.view.default(mm_294, [2, 8192, 4096]); mm_294 = None + add_161 = torch.ops.aten.add.Tensor(view_1210, view_1212); view_1210 = view_1212 = None + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1332, 'avg', 256, '0'); convert_element_type_1332 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + view_1213 = torch.ops.aten.view.default(view_1208, [16384, 4096]); view_1208 = None + permute_513 = torch.ops.aten.permute.default(view_1213, [1, 0]) + mm_295 = torch.ops.aten.mm.default(permute_513, view_921); permute_513 = view_921 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 256, '0'); convert_element_type_895 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + permute_515 = torch.ops.aten.permute.default(permute_297, [1, 0]); permute_297 = None + mm_296 = torch.ops.aten.mm.default(view_1213, permute_515); view_1213 = permute_515 = None + view_1214 = torch.ops.aten.view.default(mm_296, [2, 8192, 4096]); mm_296 = None + add_162 = torch.ops.aten.add.Tensor(add_161, view_1214); add_161 = view_1214 = None + convert_element_type_1337 = torch.ops.prims.convert_element_type.default(mm_295, torch.float32); mm_295 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1337, 'avg', 256, '0'); convert_element_type_1337 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); reduce_scatter_tensor_45 = None + convert_element_type_1338 = torch.ops.prims.convert_element_type.default(add_162, torch.float32); add_162 = None + convert_element_type_1340 = torch.ops.prims.convert_element_type.default(wait_tensor_244, torch.float32); wait_tensor_244 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_1338, convert_element_type_1340); convert_element_type_1340 = None + mul_360 = torch.ops.aten.mul.Tensor(mul_216, mul_358) + sum_31 = torch.ops.aten.sum.dim_IntList(mul_360, [2], True); mul_360 = None + div_10 = torch.ops.aten.div.Tensor(mul_216, 4096) + mul_361 = torch.ops.aten.mul.Tensor(div_10, sum_31); div_10 = sum_31 = None + sub_15 = torch.ops.aten.sub.Tensor(mul_358, mul_361); mul_358 = mul_361 = None + mul_362 = torch.ops.aten.mul.Tensor(sub_15, rsqrt_54); sub_15 = rsqrt_54 = None + mul_363 = torch.ops.aten.mul.Tensor(convert_element_type_1338, mul_216); convert_element_type_1338 = mul_216 = None + sum_32 = torch.ops.aten.sum.dim_IntList(mul_363, [0, 1]); mul_363 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(mul_362, torch.bfloat16); mul_362 = None + add_163 = torch.ops.aten.add.Tensor(add_160, convert_element_type_1341); add_160 = convert_element_type_1341 = None + convert_element_type_default_55 = torch.ops.prims.convert_element_type.default(sum_32, torch.float32); sum_32 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_55, 'avg', 256, '0'); convert_element_type_default_55 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + view_1215 = torch.ops.aten.view.default(add_163, [16384, 4096]) + permute_517 = torch.ops.aten.permute.default(view_1215, [1, 0]) + permute_292 = torch.ops.aten.permute.default(getitem_234, [0, 2, 1, 3]) + view_905 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16); primals_242 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 256, '0'); convert_element_type_875 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + view_907 = torch.ops.aten.view.default(view_905, [16384, 4096]); view_905 = None + mm_185 = torch.ops.aten.mm.default(view_907, permute_293) + view_908 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + add_105 = torch.ops.aten.add.Tensor(add_103, view_908); view_908 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16); primals_243 = None + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 256, '0'); convert_element_type_878 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32); add_105 = None + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_240) + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + view_911 = torch.ops.aten.view.default(convert_element_type_880, [16384, 4096]); convert_element_type_880 = None + view_912 = torch.ops.aten.view.default(mm_186, [2, 8192, 14336]); mm_186 = None + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_912, torch.float32); view_912 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 256, '0'); convert_element_type_886 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_242, [1, 0]); wait_tensor_242 = None + mm_187 = torch.ops.aten.mm.default(view_911, permute_295) + view_915 = torch.ops.aten.view.default(mm_187, [2, 8192, 14336]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_915) + view_917 = torch.ops.aten.view.default(mul_215, [16384, 14336]); mul_215 = None + mm_297 = torch.ops.aten.mm.default(permute_517, view_917); permute_517 = view_917 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 256, '0'); convert_element_type_889 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + permute_519 = torch.ops.aten.permute.default(permute_296, [1, 0]); permute_296 = None + mm_298 = torch.ops.aten.mm.default(view_1215, permute_519); view_1215 = permute_519 = None + view_1216 = torch.ops.aten.view.default(mm_298, [2, 8192, 14336]); mm_298 = None + convert_element_type_1348 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1348, 'avg', 256, '0'); convert_element_type_1348 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + mul_364 = torch.ops.aten.mul.Tensor(view_1216, convert_element_type_885); convert_element_type_885 = None + mul_365 = torch.ops.aten.mul.Tensor(view_1216, view_915); view_1216 = view_915 = None + view_1217 = torch.ops.aten.view.default(mul_364, [16384, 14336]); mul_364 = None + permute_521 = torch.ops.aten.permute.default(view_1217, [1, 0]) + mm_299 = torch.ops.aten.mm.default(permute_521, view_911); permute_521 = None + permute_523 = torch.ops.aten.permute.default(permute_295, [1, 0]); permute_295 = None + mm_300 = torch.ops.aten.mm.default(view_1217, permute_523); view_1217 = permute_523 = None + view_1218 = torch.ops.aten.view.default(mm_300, [2, 8192, 4096]); mm_300 = None + convert_element_type_1353 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None + reduce_scatter_tensor_48 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1353, 'avg', 256, '0'); convert_element_type_1353 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + convert_element_type_1354 = torch.ops.prims.convert_element_type.default(mul_365, torch.float32); mul_365 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_884) + exp_5 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_164 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_164); add_164 = None + mul_366 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_367 = torch.ops.aten.mul.Tensor(convert_element_type_1354, mul_366); convert_element_type_1354 = None + sub_16 = torch.ops.aten.sub.Tensor(1, mul_366); mul_366 = None + mul_368 = torch.ops.aten.mul.Tensor(convert_element_type_884, sub_16); convert_element_type_884 = sub_16 = None + add_165 = torch.ops.aten.add.Tensor(mul_368, 1); mul_368 = None + mul_369 = torch.ops.aten.mul.Tensor(mul_367, add_165); mul_367 = add_165 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(mul_369, torch.bfloat16); mul_369 = None + view_1219 = torch.ops.aten.view.default(convert_element_type_1356, [16384, 14336]); convert_element_type_1356 = None + permute_525 = torch.ops.aten.permute.default(view_1219, [1, 0]) + mm_301 = torch.ops.aten.mm.default(permute_525, view_911); permute_525 = view_911 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 256, '0'); convert_element_type_881 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + permute_527 = torch.ops.aten.permute.default(permute_294, [1, 0]); permute_294 = None + mm_302 = torch.ops.aten.mm.default(view_1219, permute_527); view_1219 = permute_527 = None + view_1220 = torch.ops.aten.view.default(mm_302, [2, 8192, 4096]); mm_302 = None + add_166 = torch.ops.aten.add.Tensor(view_1218, view_1220); view_1218 = view_1220 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1361, 'avg', 256, '0'); convert_element_type_1361 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + convert_element_type_1362 = torch.ops.prims.convert_element_type.default(add_166, torch.float32); add_166 = None + convert_element_type_1364 = torch.ops.prims.convert_element_type.default(wait_tensor_240, torch.float32); wait_tensor_240 = None + mul_370 = torch.ops.aten.mul.Tensor(convert_element_type_1362, convert_element_type_1364); convert_element_type_1364 = None + mul_372 = torch.ops.aten.mul.Tensor(mul_212, mul_370) + sum_33 = torch.ops.aten.sum.dim_IntList(mul_372, [2], True); mul_372 = None + div_11 = torch.ops.aten.div.Tensor(mul_212, 4096) + mul_373 = torch.ops.aten.mul.Tensor(div_11, sum_33); div_11 = sum_33 = None + sub_17 = torch.ops.aten.sub.Tensor(mul_370, mul_373); mul_370 = mul_373 = None + mul_374 = torch.ops.aten.mul.Tensor(sub_17, rsqrt_53); sub_17 = rsqrt_53 = None + mul_375 = torch.ops.aten.mul.Tensor(convert_element_type_1362, mul_212); convert_element_type_1362 = mul_212 = None + sum_34 = torch.ops.aten.sum.dim_IntList(mul_375, [0, 1]); mul_375 = None + convert_element_type_1365 = torch.ops.prims.convert_element_type.default(mul_374, torch.bfloat16); mul_374 = None + add_167 = torch.ops.aten.add.Tensor(add_163, convert_element_type_1365); add_163 = convert_element_type_1365 = None + convert_element_type_default_54 = torch.ops.prims.convert_element_type.default(sum_34, torch.float32); sum_34 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_54, 'avg', 256, '0'); convert_element_type_default_54 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + view_1221 = torch.ops.aten.view.default(add_167, [16384, 4096]) + permute_529 = torch.ops.aten.permute.default(view_1221, [1, 0]) + mm_303 = torch.ops.aten.mm.default(permute_529, view_907); permute_529 = view_907 = None + permute_531 = torch.ops.aten.permute.default(permute_293, [1, 0]); permute_293 = None + mm_304 = torch.ops.aten.mm.default(view_1221, permute_531); view_1221 = permute_531 = None + view_1222 = torch.ops.aten.view.default(mm_304, [2, 8192, 4096]); mm_304 = None + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1372, 'avg', 256, '0'); convert_element_type_1372 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + view_1223 = torch.ops.aten.view.default(view_1222, [2, 8192, 32, 128]); view_1222 = None + permute_533 = torch.ops.aten.permute.default(view_1223, [0, 2, 1, 3]); view_1223 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16); primals_238 = None + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 256, '0'); convert_element_type_859 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32); add_103 = None + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_235) + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + view_887 = torch.ops.aten.view.default(convert_element_type_861, [16384, 4096]); convert_element_type_861 = None + view_888 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]); mm_182 = None + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16); primals_240 = None + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 256, '0'); convert_element_type_865 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_237, [1, 0]); wait_tensor_237 = None + mm_183 = torch.ops.aten.mm.default(view_887, permute_287) + view_891 = torch.ops.aten.view.default(mm_183, [2, 8192, 1024]); mm_183 = None + view_894 = torch.ops.aten.view.default(mm_184, [2, 8192, 1024]); mm_184 = None + view_895 = torch.ops.aten.view.default(view_888, [2, 8192, -1, 128]); view_888 = None + view_896 = torch.ops.aten.view.default(view_891, [2, 8192, -1, 128]); view_891 = None + view_897 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_895, torch.float32); view_895 = None + view_898 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 32, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_898); view_898 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 8, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_16); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_901 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 32, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_16); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_902 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 8, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_901, torch.bfloat16); view_901 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 8, 4, 128]); unsqueeze_52 = None + clone_52 = torch.ops.aten.clone.default(expand_52, memory_format = torch.contiguous_format); expand_52 = None + view_903 = torch.ops.aten.view.default(clone_52, [2, 8192, 32, 128]); clone_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_897, 3); view_897 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 8, 4, 128]); unsqueeze_53 = None + clone_53 = torch.ops.aten.clone.default(expand_53, memory_format = torch.contiguous_format); expand_53 = None + view_904 = torch.ops.aten.view.default(clone_53, [2, 8192, 32, 128]); clone_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_903, [0, 2, 1, 3]); view_903 = None + permute_291 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + _scaled_dot_product_cudnn_attention_backward_5 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_533, permute_289, permute_290, permute_291, getitem_234, getitem_235, getitem_240, getitem_241, None, None, None, 8192, 8192, 0.0, True); permute_533 = permute_289 = permute_290 = permute_291 = getitem_234 = getitem_235 = getitem_240 = getitem_241 = None + getitem_303 = _scaled_dot_product_cudnn_attention_backward_5[0] + getitem_304 = _scaled_dot_product_cudnn_attention_backward_5[1] + getitem_305 = _scaled_dot_product_cudnn_attention_backward_5[2]; _scaled_dot_product_cudnn_attention_backward_5 = None + permute_534 = torch.ops.aten.permute.default(getitem_305, [0, 2, 1, 3]); getitem_305 = None + permute_535 = torch.ops.aten.permute.default(getitem_304, [0, 2, 1, 3]); getitem_304 = None + permute_536 = torch.ops.aten.permute.default(getitem_303, [0, 2, 1, 3]); getitem_303 = None + view_1224 = torch.ops.aten.view.default(permute_534, [2, 8192, 8, 4, 128]); permute_534 = None + sum_35 = torch.ops.aten.sum.dim_IntList(view_1224, [3], True); view_1224 = None + squeeze_10 = torch.ops.aten.squeeze.dim(sum_35, 3); sum_35 = None + view_1225 = torch.ops.aten.view.default(permute_535, [2, 8192, 8, 4, 128]); permute_535 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_1225, [3], True); view_1225 = None + squeeze_11 = torch.ops.aten.squeeze.dim(sum_36, 3); sum_36 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(squeeze_11, torch.float32); squeeze_11 = None + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(permute_536, torch.float32); permute_536 = None + view_1226 = torch.ops.aten.view.default(convert_element_type_1373, [2, 8192, 8, 64, 2]); convert_element_type_1373 = None + view_as_complex_74 = torch.ops.aten.view_as_complex.default(view_1226); view_1226 = None + mul_376 = torch.ops.aten.mul.Tensor(view_as_complex_74, _conj); view_as_complex_74 = None + view_1227 = torch.ops.aten.view.default(convert_element_type_1374, [2, 8192, 32, 64, 2]); convert_element_type_1374 = None + view_as_complex_75 = torch.ops.aten.view_as_complex.default(view_1227); view_1227 = None + mul_377 = torch.ops.aten.mul.Tensor(view_as_complex_75, _conj); view_as_complex_75 = None + view_as_real_74 = torch.ops.aten.view_as_real.default(mul_376); mul_376 = None + view_1228 = torch.ops.aten.view.default(view_as_real_74, [2, 8192, 8, 128]); view_as_real_74 = None + convert_element_type_1375 = torch.ops.prims.convert_element_type.default(view_1228, torch.bfloat16); view_1228 = None + view_as_real_75 = torch.ops.aten.view_as_real.default(mul_377); mul_377 = None + view_1229 = torch.ops.aten.view.default(view_as_real_75, [2, 8192, 32, 128]); view_as_real_75 = None + convert_element_type_1376 = torch.ops.prims.convert_element_type.default(view_1229, torch.bfloat16); view_1229 = None + view_1230 = torch.ops.aten.view.default(squeeze_10, [2, 8192, 1024]); squeeze_10 = None + view_1231 = torch.ops.aten.view.default(convert_element_type_1375, [2, 8192, 1024]); convert_element_type_1375 = None + view_1232 = torch.ops.aten.view.default(convert_element_type_1376, [2, 8192, 4096]); convert_element_type_1376 = None + view_1233 = torch.ops.aten.view.default(view_1230, [16384, 1024]); view_1230 = None + permute_537 = torch.ops.aten.permute.default(view_1233, [1, 0]) + mm_305 = torch.ops.aten.mm.default(permute_537, view_887); permute_537 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16); primals_241 = None + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 256, '0'); convert_element_type_868 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + permute_539 = torch.ops.aten.permute.default(permute_288, [1, 0]); permute_288 = None + mm_306 = torch.ops.aten.mm.default(view_1233, permute_539); view_1233 = permute_539 = None + view_1234 = torch.ops.aten.view.default(mm_306, [2, 8192, 4096]); mm_306 = None + convert_element_type_1381 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1381, 'avg', 256, '0'); convert_element_type_1381 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + view_1235 = torch.ops.aten.view.default(view_1231, [16384, 1024]); view_1231 = None + permute_541 = torch.ops.aten.permute.default(view_1235, [1, 0]) + mm_307 = torch.ops.aten.mm.default(permute_541, view_887); permute_541 = None + permute_543 = torch.ops.aten.permute.default(permute_287, [1, 0]); permute_287 = None + mm_308 = torch.ops.aten.mm.default(view_1235, permute_543); view_1235 = permute_543 = None + view_1236 = torch.ops.aten.view.default(mm_308, [2, 8192, 4096]); mm_308 = None + add_168 = torch.ops.aten.add.Tensor(view_1234, view_1236); view_1234 = view_1236 = None + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(mm_307, torch.float32); mm_307 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1386, 'avg', 256, '0'); convert_element_type_1386 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + view_1237 = torch.ops.aten.view.default(view_1232, [16384, 4096]); view_1232 = None + permute_545 = torch.ops.aten.permute.default(view_1237, [1, 0]) + mm_309 = torch.ops.aten.mm.default(permute_545, view_887); permute_545 = view_887 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16); primals_239 = None + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 256, '0'); convert_element_type_862 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_236, [1, 0]); wait_tensor_236 = None + permute_547 = torch.ops.aten.permute.default(permute_286, [1, 0]); permute_286 = None + mm_310 = torch.ops.aten.mm.default(view_1237, permute_547); view_1237 = permute_547 = None + view_1238 = torch.ops.aten.view.default(mm_310, [2, 8192, 4096]); mm_310 = None + add_169 = torch.ops.aten.add.Tensor(add_168, view_1238); add_168 = view_1238 = None + convert_element_type_1391 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1391, 'avg', 256, '0'); convert_element_type_1391 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + convert_element_type_1392 = torch.ops.prims.convert_element_type.default(add_169, torch.float32); add_169 = None + convert_element_type_1394 = torch.ops.prims.convert_element_type.default(wait_tensor_235, torch.float32); wait_tensor_235 = None + mul_378 = torch.ops.aten.mul.Tensor(convert_element_type_1392, convert_element_type_1394); convert_element_type_1394 = None + mul_380 = torch.ops.aten.mul.Tensor(mul_208, mul_378) + sum_37 = torch.ops.aten.sum.dim_IntList(mul_380, [2], True); mul_380 = None + div_12 = torch.ops.aten.div.Tensor(mul_208, 4096) + mul_381 = torch.ops.aten.mul.Tensor(div_12, sum_37); div_12 = sum_37 = None + sub_18 = torch.ops.aten.sub.Tensor(mul_378, mul_381); mul_378 = mul_381 = None + mul_382 = torch.ops.aten.mul.Tensor(sub_18, rsqrt_52); sub_18 = rsqrt_52 = None + mul_383 = torch.ops.aten.mul.Tensor(convert_element_type_1392, mul_208); convert_element_type_1392 = mul_208 = None + sum_38 = torch.ops.aten.sum.dim_IntList(mul_383, [0, 1]); mul_383 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(mul_382, torch.bfloat16); mul_382 = None + add_170 = torch.ops.aten.add.Tensor(add_167, convert_element_type_1395); add_167 = convert_element_type_1395 = None + convert_element_type_default_53 = torch.ops.prims.convert_element_type.default(sum_38, torch.float32); sum_38 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_53, 'avg', 256, '0'); convert_element_type_default_53 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); reduce_scatter_tensor_55 = None + view_1239 = torch.ops.aten.view.default(add_170, [16384, 4096]) + permute_549 = torch.ops.aten.permute.default(view_1239, [1, 0]) + permute_281 = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]) + view_871 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 256, '0'); convert_element_type_842 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_230, [1, 0]); wait_tensor_230 = None + view_873 = torch.ops.aten.view.default(view_871, [16384, 4096]); view_871 = None + mm_178 = torch.ops.aten.mm.default(view_873, permute_282) + view_874 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + add_101 = torch.ops.aten.add.Tensor(add_99, view_874); view_874 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 256, '0'); convert_element_type_845 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32); add_101 = None + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_231) + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + view_877 = torch.ops.aten.view.default(convert_element_type_847, [16384, 4096]); convert_element_type_847 = None + view_878 = torch.ops.aten.view.default(mm_179, [2, 8192, 14336]); mm_179 = None + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_878, torch.float32); view_878 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16); primals_236 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 256, '0'); convert_element_type_853 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_180 = torch.ops.aten.mm.default(view_877, permute_284) + view_881 = torch.ops.aten.view.default(mm_180, [2, 8192, 14336]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_881) + view_883 = torch.ops.aten.view.default(mul_207, [16384, 14336]); mul_207 = None + mm_311 = torch.ops.aten.mm.default(permute_549, view_883); permute_549 = view_883 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16); primals_237 = None + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 256, '0'); convert_element_type_856 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + permute_551 = torch.ops.aten.permute.default(permute_285, [1, 0]); permute_285 = None + mm_312 = torch.ops.aten.mm.default(view_1239, permute_551); view_1239 = permute_551 = None + view_1240 = torch.ops.aten.view.default(mm_312, [2, 8192, 14336]); mm_312 = None + convert_element_type_1402 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1402, 'avg', 256, '0'); convert_element_type_1402 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + mul_384 = torch.ops.aten.mul.Tensor(view_1240, convert_element_type_852); convert_element_type_852 = None + mul_385 = torch.ops.aten.mul.Tensor(view_1240, view_881); view_1240 = view_881 = None + view_1241 = torch.ops.aten.view.default(mul_384, [16384, 14336]); mul_384 = None + permute_553 = torch.ops.aten.permute.default(view_1241, [1, 0]) + mm_313 = torch.ops.aten.mm.default(permute_553, view_877); permute_553 = None + permute_555 = torch.ops.aten.permute.default(permute_284, [1, 0]); permute_284 = None + mm_314 = torch.ops.aten.mm.default(view_1241, permute_555); view_1241 = permute_555 = None + view_1242 = torch.ops.aten.view.default(mm_314, [2, 8192, 4096]); mm_314 = None + convert_element_type_1407 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1407, 'avg', 256, '0'); convert_element_type_1407 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + convert_element_type_1408 = torch.ops.prims.convert_element_type.default(mul_385, torch.float32); mul_385 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_851) + exp_6 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_171 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_171); add_171 = None + mul_386 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); reciprocal_6 = None + mul_387 = torch.ops.aten.mul.Tensor(convert_element_type_1408, mul_386); convert_element_type_1408 = None + sub_19 = torch.ops.aten.sub.Tensor(1, mul_386); mul_386 = None + mul_388 = torch.ops.aten.mul.Tensor(convert_element_type_851, sub_19); convert_element_type_851 = sub_19 = None + add_172 = torch.ops.aten.add.Tensor(mul_388, 1); mul_388 = None + mul_389 = torch.ops.aten.mul.Tensor(mul_387, add_172); mul_387 = add_172 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(mul_389, torch.bfloat16); mul_389 = None + view_1243 = torch.ops.aten.view.default(convert_element_type_1410, [16384, 14336]); convert_element_type_1410 = None + permute_557 = torch.ops.aten.permute.default(view_1243, [1, 0]) + mm_315 = torch.ops.aten.mm.default(permute_557, view_877); permute_557 = view_877 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16); primals_235 = None + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 256, '0'); convert_element_type_848 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + permute_559 = torch.ops.aten.permute.default(permute_283, [1, 0]); permute_283 = None + mm_316 = torch.ops.aten.mm.default(view_1243, permute_559); view_1243 = permute_559 = None + view_1244 = torch.ops.aten.view.default(mm_316, [2, 8192, 4096]); mm_316 = None + add_173 = torch.ops.aten.add.Tensor(view_1242, view_1244); view_1242 = view_1244 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1415, 'avg', 256, '0'); convert_element_type_1415 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + convert_element_type_1416 = torch.ops.prims.convert_element_type.default(add_173, torch.float32); add_173 = None + convert_element_type_1418 = torch.ops.prims.convert_element_type.default(wait_tensor_231, torch.float32); wait_tensor_231 = None + mul_390 = torch.ops.aten.mul.Tensor(convert_element_type_1416, convert_element_type_1418); convert_element_type_1418 = None + mul_392 = torch.ops.aten.mul.Tensor(mul_204, mul_390) + sum_39 = torch.ops.aten.sum.dim_IntList(mul_392, [2], True); mul_392 = None + div_13 = torch.ops.aten.div.Tensor(mul_204, 4096) + mul_393 = torch.ops.aten.mul.Tensor(div_13, sum_39); div_13 = sum_39 = None + sub_20 = torch.ops.aten.sub.Tensor(mul_390, mul_393); mul_390 = mul_393 = None + mul_394 = torch.ops.aten.mul.Tensor(sub_20, rsqrt_51); sub_20 = rsqrt_51 = None + mul_395 = torch.ops.aten.mul.Tensor(convert_element_type_1416, mul_204); convert_element_type_1416 = mul_204 = None + sum_40 = torch.ops.aten.sum.dim_IntList(mul_395, [0, 1]); mul_395 = None + convert_element_type_1419 = torch.ops.prims.convert_element_type.default(mul_394, torch.bfloat16); mul_394 = None + add_174 = torch.ops.aten.add.Tensor(add_170, convert_element_type_1419); add_170 = convert_element_type_1419 = None + convert_element_type_default_52 = torch.ops.prims.convert_element_type.default(sum_40, torch.float32); sum_40 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_52, 'avg', 256, '0'); convert_element_type_default_52 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + view_1245 = torch.ops.aten.view.default(add_174, [16384, 4096]) + permute_561 = torch.ops.aten.permute.default(view_1245, [1, 0]) + mm_317 = torch.ops.aten.mm.default(permute_561, view_873); permute_561 = view_873 = None + permute_563 = torch.ops.aten.permute.default(permute_282, [1, 0]); permute_282 = None + mm_318 = torch.ops.aten.mm.default(view_1245, permute_563); view_1245 = permute_563 = None + view_1246 = torch.ops.aten.view.default(mm_318, [2, 8192, 4096]); mm_318 = None + convert_element_type_1426 = torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1426, 'avg', 256, '0'); convert_element_type_1426 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + view_1247 = torch.ops.aten.view.default(view_1246, [2, 8192, 32, 128]); view_1246 = None + permute_565 = torch.ops.aten.permute.default(view_1247, [0, 2, 1, 3]); view_1247 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 256, '0'); convert_element_type_826 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32); add_99 = None + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_226) + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + view_853 = torch.ops.aten.view.default(convert_element_type_828, [16384, 4096]); convert_element_type_828 = None + view_854 = torch.ops.aten.view.default(mm_175, [2, 8192, 4096]); mm_175 = None + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 256, '0'); convert_element_type_832 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + mm_176 = torch.ops.aten.mm.default(view_853, permute_276) + view_857 = torch.ops.aten.view.default(mm_176, [2, 8192, 1024]); mm_176 = None + view_860 = torch.ops.aten.view.default(mm_177, [2, 8192, 1024]); mm_177 = None + view_861 = torch.ops.aten.view.default(view_854, [2, 8192, -1, 128]); view_854 = None + view_862 = torch.ops.aten.view.default(view_857, [2, 8192, -1, 128]); view_857 = None + view_863 = torch.ops.aten.view.default(view_860, [2, 8192, -1, 128]); view_860 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_861, torch.float32); view_861 = None + view_864 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 32, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_864); view_864 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_862, torch.float32); view_862 = None + view_865 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 8, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_865); view_865 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_16); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_867 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 32, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_16); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_868 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 8, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_867, torch.bfloat16); view_867 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_868, torch.bfloat16); view_868 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 8, 4, 128]); unsqueeze_50 = None + clone_50 = torch.ops.aten.clone.default(expand_50, memory_format = torch.contiguous_format); expand_50 = None + view_869 = torch.ops.aten.view.default(clone_50, [2, 8192, 32, 128]); clone_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_863, 3); view_863 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 8, 4, 128]); unsqueeze_51 = None + clone_51 = torch.ops.aten.clone.default(expand_51, memory_format = torch.contiguous_format); expand_51 = None + view_870 = torch.ops.aten.view.default(clone_51, [2, 8192, 32, 128]); clone_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_869, [0, 2, 1, 3]); view_869 = None + permute_280 = torch.ops.aten.permute.default(view_870, [0, 2, 1, 3]); view_870 = None + _scaled_dot_product_cudnn_attention_backward_6 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_565, permute_278, permute_279, permute_280, getitem_225, getitem_226, getitem_231, getitem_232, None, None, None, 8192, 8192, 0.0, True); permute_565 = permute_278 = permute_279 = permute_280 = getitem_225 = getitem_226 = getitem_231 = getitem_232 = None + getitem_306 = _scaled_dot_product_cudnn_attention_backward_6[0] + getitem_307 = _scaled_dot_product_cudnn_attention_backward_6[1] + getitem_308 = _scaled_dot_product_cudnn_attention_backward_6[2]; _scaled_dot_product_cudnn_attention_backward_6 = None + permute_566 = torch.ops.aten.permute.default(getitem_308, [0, 2, 1, 3]); getitem_308 = None + permute_567 = torch.ops.aten.permute.default(getitem_307, [0, 2, 1, 3]); getitem_307 = None + permute_568 = torch.ops.aten.permute.default(getitem_306, [0, 2, 1, 3]); getitem_306 = None + view_1248 = torch.ops.aten.view.default(permute_566, [2, 8192, 8, 4, 128]); permute_566 = None + sum_41 = torch.ops.aten.sum.dim_IntList(view_1248, [3], True); view_1248 = None + squeeze_12 = torch.ops.aten.squeeze.dim(sum_41, 3); sum_41 = None + view_1249 = torch.ops.aten.view.default(permute_567, [2, 8192, 8, 4, 128]); permute_567 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_1249, [3], True); view_1249 = None + squeeze_13 = torch.ops.aten.squeeze.dim(sum_42, 3); sum_42 = None + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(squeeze_13, torch.float32); squeeze_13 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(permute_568, torch.float32); permute_568 = None + view_1250 = torch.ops.aten.view.default(convert_element_type_1427, [2, 8192, 8, 64, 2]); convert_element_type_1427 = None + view_as_complex_76 = torch.ops.aten.view_as_complex.default(view_1250); view_1250 = None + mul_396 = torch.ops.aten.mul.Tensor(view_as_complex_76, _conj); view_as_complex_76 = None + view_1251 = torch.ops.aten.view.default(convert_element_type_1428, [2, 8192, 32, 64, 2]); convert_element_type_1428 = None + view_as_complex_77 = torch.ops.aten.view_as_complex.default(view_1251); view_1251 = None + mul_397 = torch.ops.aten.mul.Tensor(view_as_complex_77, _conj); view_as_complex_77 = None + view_as_real_76 = torch.ops.aten.view_as_real.default(mul_396); mul_396 = None + view_1252 = torch.ops.aten.view.default(view_as_real_76, [2, 8192, 8, 128]); view_as_real_76 = None + convert_element_type_1429 = torch.ops.prims.convert_element_type.default(view_1252, torch.bfloat16); view_1252 = None + view_as_real_77 = torch.ops.aten.view_as_real.default(mul_397); mul_397 = None + view_1253 = torch.ops.aten.view.default(view_as_real_77, [2, 8192, 32, 128]); view_as_real_77 = None + convert_element_type_1430 = torch.ops.prims.convert_element_type.default(view_1253, torch.bfloat16); view_1253 = None + view_1254 = torch.ops.aten.view.default(squeeze_12, [2, 8192, 1024]); squeeze_12 = None + view_1255 = torch.ops.aten.view.default(convert_element_type_1429, [2, 8192, 1024]); convert_element_type_1429 = None + view_1256 = torch.ops.aten.view.default(convert_element_type_1430, [2, 8192, 4096]); convert_element_type_1430 = None + view_1257 = torch.ops.aten.view.default(view_1254, [16384, 1024]); view_1254 = None + permute_569 = torch.ops.aten.permute.default(view_1257, [1, 0]) + mm_319 = torch.ops.aten.mm.default(permute_569, view_853); permute_569 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 256, '0'); convert_element_type_835 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_229, [1, 0]); wait_tensor_229 = None + permute_571 = torch.ops.aten.permute.default(permute_277, [1, 0]); permute_277 = None + mm_320 = torch.ops.aten.mm.default(view_1257, permute_571); view_1257 = permute_571 = None + view_1258 = torch.ops.aten.view.default(mm_320, [2, 8192, 4096]); mm_320 = None + convert_element_type_1435 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1435, 'avg', 256, '0'); convert_element_type_1435 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + view_1259 = torch.ops.aten.view.default(view_1255, [16384, 1024]); view_1255 = None + permute_573 = torch.ops.aten.permute.default(view_1259, [1, 0]) + mm_321 = torch.ops.aten.mm.default(permute_573, view_853); permute_573 = None + permute_575 = torch.ops.aten.permute.default(permute_276, [1, 0]); permute_276 = None + mm_322 = torch.ops.aten.mm.default(view_1259, permute_575); view_1259 = permute_575 = None + view_1260 = torch.ops.aten.view.default(mm_322, [2, 8192, 4096]); mm_322 = None + add_175 = torch.ops.aten.add.Tensor(view_1258, view_1260); view_1258 = view_1260 = None + convert_element_type_1440 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1440, 'avg', 256, '0'); convert_element_type_1440 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + view_1261 = torch.ops.aten.view.default(view_1256, [16384, 4096]); view_1256 = None + permute_577 = torch.ops.aten.permute.default(view_1261, [1, 0]) + mm_323 = torch.ops.aten.mm.default(permute_577, view_853); permute_577 = view_853 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 256, '0'); convert_element_type_829 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + permute_579 = torch.ops.aten.permute.default(permute_275, [1, 0]); permute_275 = None + mm_324 = torch.ops.aten.mm.default(view_1261, permute_579); view_1261 = permute_579 = None + view_1262 = torch.ops.aten.view.default(mm_324, [2, 8192, 4096]); mm_324 = None + add_176 = torch.ops.aten.add.Tensor(add_175, view_1262); add_175 = view_1262 = None + convert_element_type_1445 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1445, 'avg', 256, '0'); convert_element_type_1445 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + convert_element_type_1446 = torch.ops.prims.convert_element_type.default(add_176, torch.float32); add_176 = None + convert_element_type_1448 = torch.ops.prims.convert_element_type.default(wait_tensor_226, torch.float32); wait_tensor_226 = None + mul_398 = torch.ops.aten.mul.Tensor(convert_element_type_1446, convert_element_type_1448); convert_element_type_1448 = None + mul_400 = torch.ops.aten.mul.Tensor(mul_200, mul_398) + sum_43 = torch.ops.aten.sum.dim_IntList(mul_400, [2], True); mul_400 = None + div_14 = torch.ops.aten.div.Tensor(mul_200, 4096) + mul_401 = torch.ops.aten.mul.Tensor(div_14, sum_43); div_14 = sum_43 = None + sub_21 = torch.ops.aten.sub.Tensor(mul_398, mul_401); mul_398 = mul_401 = None + mul_402 = torch.ops.aten.mul.Tensor(sub_21, rsqrt_50); sub_21 = rsqrt_50 = None + mul_403 = torch.ops.aten.mul.Tensor(convert_element_type_1446, mul_200); convert_element_type_1446 = mul_200 = None + sum_44 = torch.ops.aten.sum.dim_IntList(mul_403, [0, 1]); mul_403 = None + convert_element_type_1449 = torch.ops.prims.convert_element_type.default(mul_402, torch.bfloat16); mul_402 = None + add_177 = torch.ops.aten.add.Tensor(add_174, convert_element_type_1449); add_174 = convert_element_type_1449 = None + convert_element_type_default_51 = torch.ops.prims.convert_element_type.default(sum_44, torch.float32); sum_44 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_51, 'avg', 256, '0'); convert_element_type_default_51 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + view_1263 = torch.ops.aten.view.default(add_177, [16384, 4096]) + permute_581 = torch.ops.aten.permute.default(view_1263, [1, 0]) + permute_270 = torch.ops.aten.permute.default(getitem_216, [0, 2, 1, 3]) + view_837 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16); primals_224 = None + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 256, '0'); convert_element_type_809 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_839 = torch.ops.aten.view.default(view_837, [16384, 4096]); view_837 = None + mm_171 = torch.ops.aten.mm.default(view_839, permute_271) + view_840 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + add_97 = torch.ops.aten.add.Tensor(add_95, view_840); view_840 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16); primals_225 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 256, '0'); convert_element_type_812 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32); add_97 = None + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_222) + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + view_843 = torch.ops.aten.view.default(convert_element_type_814, [16384, 4096]); convert_element_type_814 = None + view_844 = torch.ops.aten.view.default(mm_172, [2, 8192, 14336]); mm_172 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_844, torch.float32); view_844 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16); primals_227 = None + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 256, '0'); convert_element_type_820 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_224, [1, 0]); wait_tensor_224 = None + mm_173 = torch.ops.aten.mm.default(view_843, permute_273) + view_847 = torch.ops.aten.view.default(mm_173, [2, 8192, 14336]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_847) + view_849 = torch.ops.aten.view.default(mul_199, [16384, 14336]); mul_199 = None + mm_325 = torch.ops.aten.mm.default(permute_581, view_849); permute_581 = view_849 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 256, '0'); convert_element_type_823 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + permute_583 = torch.ops.aten.permute.default(permute_274, [1, 0]); permute_274 = None + mm_326 = torch.ops.aten.mm.default(view_1263, permute_583); view_1263 = permute_583 = None + view_1264 = torch.ops.aten.view.default(mm_326, [2, 8192, 14336]); mm_326 = None + convert_element_type_1456 = torch.ops.prims.convert_element_type.default(mm_325, torch.float32); mm_325 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1456, 'avg', 256, '0'); convert_element_type_1456 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); reduce_scatter_tensor_65 = None + mul_404 = torch.ops.aten.mul.Tensor(view_1264, convert_element_type_819); convert_element_type_819 = None + mul_405 = torch.ops.aten.mul.Tensor(view_1264, view_847); view_1264 = view_847 = None + view_1265 = torch.ops.aten.view.default(mul_404, [16384, 14336]); mul_404 = None + permute_585 = torch.ops.aten.permute.default(view_1265, [1, 0]) + mm_327 = torch.ops.aten.mm.default(permute_585, view_843); permute_585 = None + permute_587 = torch.ops.aten.permute.default(permute_273, [1, 0]); permute_273 = None + mm_328 = torch.ops.aten.mm.default(view_1265, permute_587); view_1265 = permute_587 = None + view_1266 = torch.ops.aten.view.default(mm_328, [2, 8192, 4096]); mm_328 = None + convert_element_type_1461 = torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None + reduce_scatter_tensor_66 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1461, 'avg', 256, '0'); convert_element_type_1461 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + convert_element_type_1462 = torch.ops.prims.convert_element_type.default(mul_405, torch.float32); mul_405 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_818) + exp_7 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_178 = torch.ops.aten.add.Tensor(exp_7, 1); exp_7 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_178); add_178 = None + mul_406 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_1462, mul_406); convert_element_type_1462 = None + sub_22 = torch.ops.aten.sub.Tensor(1, mul_406); mul_406 = None + mul_408 = torch.ops.aten.mul.Tensor(convert_element_type_818, sub_22); convert_element_type_818 = sub_22 = None + add_179 = torch.ops.aten.add.Tensor(mul_408, 1); mul_408 = None + mul_409 = torch.ops.aten.mul.Tensor(mul_407, add_179); mul_407 = add_179 = None + convert_element_type_1464 = torch.ops.prims.convert_element_type.default(mul_409, torch.bfloat16); mul_409 = None + view_1267 = torch.ops.aten.view.default(convert_element_type_1464, [16384, 14336]); convert_element_type_1464 = None + permute_589 = torch.ops.aten.permute.default(view_1267, [1, 0]) + mm_329 = torch.ops.aten.mm.default(permute_589, view_843); permute_589 = view_843 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16); primals_226 = None + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 256, '0'); convert_element_type_815 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + permute_591 = torch.ops.aten.permute.default(permute_272, [1, 0]); permute_272 = None + mm_330 = torch.ops.aten.mm.default(view_1267, permute_591); view_1267 = permute_591 = None + view_1268 = torch.ops.aten.view.default(mm_330, [2, 8192, 4096]); mm_330 = None + add_180 = torch.ops.aten.add.Tensor(view_1266, view_1268); view_1266 = view_1268 = None + convert_element_type_1469 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1469, 'avg', 256, '0'); convert_element_type_1469 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + convert_element_type_1470 = torch.ops.prims.convert_element_type.default(add_180, torch.float32); add_180 = None + convert_element_type_1472 = torch.ops.prims.convert_element_type.default(wait_tensor_222, torch.float32); wait_tensor_222 = None + mul_410 = torch.ops.aten.mul.Tensor(convert_element_type_1470, convert_element_type_1472); convert_element_type_1472 = None + mul_412 = torch.ops.aten.mul.Tensor(mul_196, mul_410) + sum_45 = torch.ops.aten.sum.dim_IntList(mul_412, [2], True); mul_412 = None + div_15 = torch.ops.aten.div.Tensor(mul_196, 4096) + mul_413 = torch.ops.aten.mul.Tensor(div_15, sum_45); div_15 = sum_45 = None + sub_23 = torch.ops.aten.sub.Tensor(mul_410, mul_413); mul_410 = mul_413 = None + mul_414 = torch.ops.aten.mul.Tensor(sub_23, rsqrt_49); sub_23 = rsqrt_49 = None + mul_415 = torch.ops.aten.mul.Tensor(convert_element_type_1470, mul_196); convert_element_type_1470 = mul_196 = None + sum_46 = torch.ops.aten.sum.dim_IntList(mul_415, [0, 1]); mul_415 = None + convert_element_type_1473 = torch.ops.prims.convert_element_type.default(mul_414, torch.bfloat16); mul_414 = None + add_181 = torch.ops.aten.add.Tensor(add_177, convert_element_type_1473); add_177 = convert_element_type_1473 = None + convert_element_type_default_50 = torch.ops.prims.convert_element_type.default(sum_46, torch.float32); sum_46 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_50, 'avg', 256, '0'); convert_element_type_default_50 = None + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + view_1269 = torch.ops.aten.view.default(add_181, [16384, 4096]) + permute_593 = torch.ops.aten.permute.default(view_1269, [1, 0]) + mm_331 = torch.ops.aten.mm.default(permute_593, view_839); permute_593 = view_839 = None + permute_595 = torch.ops.aten.permute.default(permute_271, [1, 0]); permute_271 = None + mm_332 = torch.ops.aten.mm.default(view_1269, permute_595); view_1269 = permute_595 = None + view_1270 = torch.ops.aten.view.default(mm_332, [2, 8192, 4096]); mm_332 = None + convert_element_type_1480 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1480, 'avg', 256, '0'); convert_element_type_1480 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + view_1271 = torch.ops.aten.view.default(view_1270, [2, 8192, 32, 128]); view_1270 = None + permute_597 = torch.ops.aten.permute.default(view_1271, [0, 2, 1, 3]); view_1271 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 256, '0'); convert_element_type_793 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32); add_95 = None + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_217) + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + view_819 = torch.ops.aten.view.default(convert_element_type_795, [16384, 4096]); convert_element_type_795 = None + view_820 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]); mm_168 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16); primals_222 = None + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 256, '0'); convert_element_type_799 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_169 = torch.ops.aten.mm.default(view_819, permute_265) + view_823 = torch.ops.aten.view.default(mm_169, [2, 8192, 1024]); mm_169 = None + view_826 = torch.ops.aten.view.default(mm_170, [2, 8192, 1024]); mm_170 = None + view_827 = torch.ops.aten.view.default(view_820, [2, 8192, -1, 128]); view_820 = None + view_828 = torch.ops.aten.view.default(view_823, [2, 8192, -1, 128]); view_823 = None + view_829 = torch.ops.aten.view.default(view_826, [2, 8192, -1, 128]); view_826 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_827, torch.float32); view_827 = None + view_830 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 32, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_830); view_830 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_828, torch.float32); view_828 = None + view_831 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 8, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_831); view_831 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_16); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_833 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 32, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_16); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_834 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 8, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_833, torch.bfloat16); view_833 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_834, torch.bfloat16); view_834 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 8, 4, 128]); unsqueeze_48 = None + clone_48 = torch.ops.aten.clone.default(expand_48, memory_format = torch.contiguous_format); expand_48 = None + view_835 = torch.ops.aten.view.default(clone_48, [2, 8192, 32, 128]); clone_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_829, 3); view_829 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 8, 4, 128]); unsqueeze_49 = None + clone_49 = torch.ops.aten.clone.default(expand_49, memory_format = torch.contiguous_format); expand_49 = None + view_836 = torch.ops.aten.view.default(clone_49, [2, 8192, 32, 128]); clone_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_835, [0, 2, 1, 3]); view_835 = None + permute_269 = torch.ops.aten.permute.default(view_836, [0, 2, 1, 3]); view_836 = None + _scaled_dot_product_cudnn_attention_backward_7 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_597, permute_267, permute_268, permute_269, getitem_216, getitem_217, getitem_222, getitem_223, None, None, None, 8192, 8192, 0.0, True); permute_597 = permute_267 = permute_268 = permute_269 = getitem_216 = getitem_217 = getitem_222 = getitem_223 = None + getitem_309 = _scaled_dot_product_cudnn_attention_backward_7[0] + getitem_310 = _scaled_dot_product_cudnn_attention_backward_7[1] + getitem_311 = _scaled_dot_product_cudnn_attention_backward_7[2]; _scaled_dot_product_cudnn_attention_backward_7 = None + permute_598 = torch.ops.aten.permute.default(getitem_311, [0, 2, 1, 3]); getitem_311 = None + permute_599 = torch.ops.aten.permute.default(getitem_310, [0, 2, 1, 3]); getitem_310 = None + permute_600 = torch.ops.aten.permute.default(getitem_309, [0, 2, 1, 3]); getitem_309 = None + view_1272 = torch.ops.aten.view.default(permute_598, [2, 8192, 8, 4, 128]); permute_598 = None + sum_47 = torch.ops.aten.sum.dim_IntList(view_1272, [3], True); view_1272 = None + squeeze_14 = torch.ops.aten.squeeze.dim(sum_47, 3); sum_47 = None + view_1273 = torch.ops.aten.view.default(permute_599, [2, 8192, 8, 4, 128]); permute_599 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_1273, [3], True); view_1273 = None + squeeze_15 = torch.ops.aten.squeeze.dim(sum_48, 3); sum_48 = None + convert_element_type_1481 = torch.ops.prims.convert_element_type.default(squeeze_15, torch.float32); squeeze_15 = None + convert_element_type_1482 = torch.ops.prims.convert_element_type.default(permute_600, torch.float32); permute_600 = None + view_1274 = torch.ops.aten.view.default(convert_element_type_1481, [2, 8192, 8, 64, 2]); convert_element_type_1481 = None + view_as_complex_78 = torch.ops.aten.view_as_complex.default(view_1274); view_1274 = None + mul_416 = torch.ops.aten.mul.Tensor(view_as_complex_78, _conj); view_as_complex_78 = None + view_1275 = torch.ops.aten.view.default(convert_element_type_1482, [2, 8192, 32, 64, 2]); convert_element_type_1482 = None + view_as_complex_79 = torch.ops.aten.view_as_complex.default(view_1275); view_1275 = None + mul_417 = torch.ops.aten.mul.Tensor(view_as_complex_79, _conj); view_as_complex_79 = None + view_as_real_78 = torch.ops.aten.view_as_real.default(mul_416); mul_416 = None + view_1276 = torch.ops.aten.view.default(view_as_real_78, [2, 8192, 8, 128]); view_as_real_78 = None + convert_element_type_1483 = torch.ops.prims.convert_element_type.default(view_1276, torch.bfloat16); view_1276 = None + view_as_real_79 = torch.ops.aten.view_as_real.default(mul_417); mul_417 = None + view_1277 = torch.ops.aten.view.default(view_as_real_79, [2, 8192, 32, 128]); view_as_real_79 = None + convert_element_type_1484 = torch.ops.prims.convert_element_type.default(view_1277, torch.bfloat16); view_1277 = None + view_1278 = torch.ops.aten.view.default(squeeze_14, [2, 8192, 1024]); squeeze_14 = None + view_1279 = torch.ops.aten.view.default(convert_element_type_1483, [2, 8192, 1024]); convert_element_type_1483 = None + view_1280 = torch.ops.aten.view.default(convert_element_type_1484, [2, 8192, 4096]); convert_element_type_1484 = None + view_1281 = torch.ops.aten.view.default(view_1278, [16384, 1024]); view_1278 = None + permute_601 = torch.ops.aten.permute.default(view_1281, [1, 0]) + mm_333 = torch.ops.aten.mm.default(permute_601, view_819); permute_601 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16); primals_223 = None + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 256, '0'); convert_element_type_802 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + permute_603 = torch.ops.aten.permute.default(permute_266, [1, 0]); permute_266 = None + mm_334 = torch.ops.aten.mm.default(view_1281, permute_603); view_1281 = permute_603 = None + view_1282 = torch.ops.aten.view.default(mm_334, [2, 8192, 4096]); mm_334 = None + convert_element_type_1489 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None + reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1489, 'avg', 256, '0'); convert_element_type_1489 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + view_1283 = torch.ops.aten.view.default(view_1279, [16384, 1024]); view_1279 = None + permute_605 = torch.ops.aten.permute.default(view_1283, [1, 0]) + mm_335 = torch.ops.aten.mm.default(permute_605, view_819); permute_605 = None + permute_607 = torch.ops.aten.permute.default(permute_265, [1, 0]); permute_265 = None + mm_336 = torch.ops.aten.mm.default(view_1283, permute_607); view_1283 = permute_607 = None + view_1284 = torch.ops.aten.view.default(mm_336, [2, 8192, 4096]); mm_336 = None + add_182 = torch.ops.aten.add.Tensor(view_1282, view_1284); view_1282 = view_1284 = None + convert_element_type_1494 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None + reduce_scatter_tensor_71 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1494, 'avg', 256, '0'); convert_element_type_1494 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + view_1285 = torch.ops.aten.view.default(view_1280, [16384, 4096]); view_1280 = None + permute_609 = torch.ops.aten.permute.default(view_1285, [1, 0]) + mm_337 = torch.ops.aten.mm.default(permute_609, view_819); permute_609 = view_819 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16); primals_221 = None + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 256, '0'); convert_element_type_796 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + permute_611 = torch.ops.aten.permute.default(permute_264, [1, 0]); permute_264 = None + mm_338 = torch.ops.aten.mm.default(view_1285, permute_611); view_1285 = permute_611 = None + view_1286 = torch.ops.aten.view.default(mm_338, [2, 8192, 4096]); mm_338 = None + add_183 = torch.ops.aten.add.Tensor(add_182, view_1286); add_182 = view_1286 = None + convert_element_type_1499 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1499, 'avg', 256, '0'); convert_element_type_1499 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + convert_element_type_1500 = torch.ops.prims.convert_element_type.default(add_183, torch.float32); add_183 = None + convert_element_type_1502 = torch.ops.prims.convert_element_type.default(wait_tensor_217, torch.float32); wait_tensor_217 = None + mul_418 = torch.ops.aten.mul.Tensor(convert_element_type_1500, convert_element_type_1502); convert_element_type_1502 = None + mul_420 = torch.ops.aten.mul.Tensor(mul_192, mul_418) + sum_49 = torch.ops.aten.sum.dim_IntList(mul_420, [2], True); mul_420 = None + div_16 = torch.ops.aten.div.Tensor(mul_192, 4096) + mul_421 = torch.ops.aten.mul.Tensor(div_16, sum_49); div_16 = sum_49 = None + sub_24 = torch.ops.aten.sub.Tensor(mul_418, mul_421); mul_418 = mul_421 = None + mul_422 = torch.ops.aten.mul.Tensor(sub_24, rsqrt_48); sub_24 = rsqrt_48 = None + mul_423 = torch.ops.aten.mul.Tensor(convert_element_type_1500, mul_192); convert_element_type_1500 = mul_192 = None + sum_50 = torch.ops.aten.sum.dim_IntList(mul_423, [0, 1]); mul_423 = None + convert_element_type_1503 = torch.ops.prims.convert_element_type.default(mul_422, torch.bfloat16); mul_422 = None + add_184 = torch.ops.aten.add.Tensor(add_181, convert_element_type_1503); add_181 = convert_element_type_1503 = None + convert_element_type_default_49 = torch.ops.prims.convert_element_type.default(sum_50, torch.float32); sum_50 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_49, 'avg', 256, '0'); convert_element_type_default_49 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + view_1287 = torch.ops.aten.view.default(add_184, [16384, 4096]) + permute_613 = torch.ops.aten.permute.default(view_1287, [1, 0]) + permute_259 = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]) + view_803 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 256, '0'); convert_element_type_776 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_805 = torch.ops.aten.view.default(view_803, [16384, 4096]); view_803 = None + mm_164 = torch.ops.aten.mm.default(view_805, permute_260) + view_806 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + add_93 = torch.ops.aten.add.Tensor(add_91, view_806); view_806 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 256, '0'); convert_element_type_779 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32); add_93 = None + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_213) + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + view_809 = torch.ops.aten.view.default(convert_element_type_781, [16384, 4096]); convert_element_type_781 = None + view_810 = torch.ops.aten.view.default(mm_165, [2, 8192, 14336]); mm_165 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_810, torch.float32); view_810 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16); primals_218 = None + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 256, '0'); convert_element_type_787 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + mm_166 = torch.ops.aten.mm.default(view_809, permute_262) + view_813 = torch.ops.aten.view.default(mm_166, [2, 8192, 14336]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_813) + view_815 = torch.ops.aten.view.default(mul_191, [16384, 14336]); mul_191 = None + mm_339 = torch.ops.aten.mm.default(permute_613, view_815); permute_613 = view_815 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 256, '0'); convert_element_type_790 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_216, [1, 0]); wait_tensor_216 = None + permute_615 = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None + mm_340 = torch.ops.aten.mm.default(view_1287, permute_615); view_1287 = permute_615 = None + view_1288 = torch.ops.aten.view.default(mm_340, [2, 8192, 14336]); mm_340 = None + convert_element_type_1510 = torch.ops.prims.convert_element_type.default(mm_339, torch.float32); mm_339 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1510, 'avg', 256, '0'); convert_element_type_1510 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + mul_424 = torch.ops.aten.mul.Tensor(view_1288, convert_element_type_786); convert_element_type_786 = None + mul_425 = torch.ops.aten.mul.Tensor(view_1288, view_813); view_1288 = view_813 = None + view_1289 = torch.ops.aten.view.default(mul_424, [16384, 14336]); mul_424 = None + permute_617 = torch.ops.aten.permute.default(view_1289, [1, 0]) + mm_341 = torch.ops.aten.mm.default(permute_617, view_809); permute_617 = None + permute_619 = torch.ops.aten.permute.default(permute_262, [1, 0]); permute_262 = None + mm_342 = torch.ops.aten.mm.default(view_1289, permute_619); view_1289 = permute_619 = None + view_1290 = torch.ops.aten.view.default(mm_342, [2, 8192, 4096]); mm_342 = None + convert_element_type_1515 = torch.ops.prims.convert_element_type.default(mm_341, torch.float32); mm_341 = None + reduce_scatter_tensor_75 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1515, 'avg', 256, '0'); convert_element_type_1515 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + convert_element_type_1516 = torch.ops.prims.convert_element_type.default(mul_425, torch.float32); mul_425 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_785) + exp_8 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_185 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_185); add_185 = None + mul_426 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_1516, mul_426); convert_element_type_1516 = None + sub_25 = torch.ops.aten.sub.Tensor(1, mul_426); mul_426 = None + mul_428 = torch.ops.aten.mul.Tensor(convert_element_type_785, sub_25); convert_element_type_785 = sub_25 = None + add_186 = torch.ops.aten.add.Tensor(mul_428, 1); mul_428 = None + mul_429 = torch.ops.aten.mul.Tensor(mul_427, add_186); mul_427 = add_186 = None + convert_element_type_1518 = torch.ops.prims.convert_element_type.default(mul_429, torch.bfloat16); mul_429 = None + view_1291 = torch.ops.aten.view.default(convert_element_type_1518, [16384, 14336]); convert_element_type_1518 = None + permute_621 = torch.ops.aten.permute.default(view_1291, [1, 0]) + mm_343 = torch.ops.aten.mm.default(permute_621, view_809); permute_621 = view_809 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16); primals_217 = None + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 256, '0'); convert_element_type_782 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + permute_623 = torch.ops.aten.permute.default(permute_261, [1, 0]); permute_261 = None + mm_344 = torch.ops.aten.mm.default(view_1291, permute_623); view_1291 = permute_623 = None + view_1292 = torch.ops.aten.view.default(mm_344, [2, 8192, 4096]); mm_344 = None + add_187 = torch.ops.aten.add.Tensor(view_1290, view_1292); view_1290 = view_1292 = None + convert_element_type_1523 = torch.ops.prims.convert_element_type.default(mm_343, torch.float32); mm_343 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1523, 'avg', 256, '0'); convert_element_type_1523 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + convert_element_type_1524 = torch.ops.prims.convert_element_type.default(add_187, torch.float32); add_187 = None + convert_element_type_1526 = torch.ops.prims.convert_element_type.default(wait_tensor_213, torch.float32); wait_tensor_213 = None + mul_430 = torch.ops.aten.mul.Tensor(convert_element_type_1524, convert_element_type_1526); convert_element_type_1526 = None + mul_432 = torch.ops.aten.mul.Tensor(mul_188, mul_430) + sum_51 = torch.ops.aten.sum.dim_IntList(mul_432, [2], True); mul_432 = None + div_17 = torch.ops.aten.div.Tensor(mul_188, 4096) + mul_433 = torch.ops.aten.mul.Tensor(div_17, sum_51); div_17 = sum_51 = None + sub_26 = torch.ops.aten.sub.Tensor(mul_430, mul_433); mul_430 = mul_433 = None + mul_434 = torch.ops.aten.mul.Tensor(sub_26, rsqrt_47); sub_26 = rsqrt_47 = None + mul_435 = torch.ops.aten.mul.Tensor(convert_element_type_1524, mul_188); convert_element_type_1524 = mul_188 = None + sum_52 = torch.ops.aten.sum.dim_IntList(mul_435, [0, 1]); mul_435 = None + convert_element_type_1527 = torch.ops.prims.convert_element_type.default(mul_434, torch.bfloat16); mul_434 = None + add_188 = torch.ops.aten.add.Tensor(add_184, convert_element_type_1527); add_184 = convert_element_type_1527 = None + convert_element_type_default_48 = torch.ops.prims.convert_element_type.default(sum_52, torch.float32); sum_52 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_48, 'avg', 256, '0'); convert_element_type_default_48 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + view_1293 = torch.ops.aten.view.default(add_188, [16384, 4096]) + permute_625 = torch.ops.aten.permute.default(view_1293, [1, 0]) + mm_345 = torch.ops.aten.mm.default(permute_625, view_805); permute_625 = view_805 = None + permute_627 = torch.ops.aten.permute.default(permute_260, [1, 0]); permute_260 = None + mm_346 = torch.ops.aten.mm.default(view_1293, permute_627); view_1293 = permute_627 = None + view_1294 = torch.ops.aten.view.default(mm_346, [2, 8192, 4096]); mm_346 = None + convert_element_type_1534 = torch.ops.prims.convert_element_type.default(mm_345, torch.float32); mm_345 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1534, 'avg', 256, '0'); convert_element_type_1534 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + view_1295 = torch.ops.aten.view.default(view_1294, [2, 8192, 32, 128]); view_1294 = None + permute_629 = torch.ops.aten.permute.default(view_1295, [0, 2, 1, 3]); view_1295 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16); primals_211 = None + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 256, '0'); convert_element_type_760 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32); add_91 = None + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_208) + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + view_785 = torch.ops.aten.view.default(convert_element_type_762, [16384, 4096]); convert_element_type_762 = None + view_786 = torch.ops.aten.view.default(mm_161, [2, 8192, 4096]); mm_161 = None + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 256, '0'); convert_element_type_766 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_210, [1, 0]); wait_tensor_210 = None + mm_162 = torch.ops.aten.mm.default(view_785, permute_254) + view_789 = torch.ops.aten.view.default(mm_162, [2, 8192, 1024]); mm_162 = None + view_792 = torch.ops.aten.view.default(mm_163, [2, 8192, 1024]); mm_163 = None + view_793 = torch.ops.aten.view.default(view_786, [2, 8192, -1, 128]); view_786 = None + view_794 = torch.ops.aten.view.default(view_789, [2, 8192, -1, 128]); view_789 = None + view_795 = torch.ops.aten.view.default(view_792, [2, 8192, -1, 128]); view_792 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_793, torch.float32); view_793 = None + view_796 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 32, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_796); view_796 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_794, torch.float32); view_794 = None + view_797 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 8, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_797); view_797 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_16); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_799 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 32, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_16); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_800 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 8, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_799, torch.bfloat16); view_799 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_800, torch.bfloat16); view_800 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 8, 4, 128]); unsqueeze_46 = None + clone_46 = torch.ops.aten.clone.default(expand_46, memory_format = torch.contiguous_format); expand_46 = None + view_801 = torch.ops.aten.view.default(clone_46, [2, 8192, 32, 128]); clone_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_795, 3); view_795 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 8, 4, 128]); unsqueeze_47 = None + clone_47 = torch.ops.aten.clone.default(expand_47, memory_format = torch.contiguous_format); expand_47 = None + view_802 = torch.ops.aten.view.default(clone_47, [2, 8192, 32, 128]); clone_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_801, [0, 2, 1, 3]); view_801 = None + permute_258 = torch.ops.aten.permute.default(view_802, [0, 2, 1, 3]); view_802 = None + _scaled_dot_product_cudnn_attention_backward_8 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_629, permute_256, permute_257, permute_258, getitem_207, getitem_208, getitem_213, getitem_214, None, None, None, 8192, 8192, 0.0, True); permute_629 = permute_256 = permute_257 = permute_258 = getitem_207 = getitem_208 = getitem_213 = getitem_214 = None + getitem_312 = _scaled_dot_product_cudnn_attention_backward_8[0] + getitem_313 = _scaled_dot_product_cudnn_attention_backward_8[1] + getitem_314 = _scaled_dot_product_cudnn_attention_backward_8[2]; _scaled_dot_product_cudnn_attention_backward_8 = None + permute_630 = torch.ops.aten.permute.default(getitem_314, [0, 2, 1, 3]); getitem_314 = None + permute_631 = torch.ops.aten.permute.default(getitem_313, [0, 2, 1, 3]); getitem_313 = None + permute_632 = torch.ops.aten.permute.default(getitem_312, [0, 2, 1, 3]); getitem_312 = None + view_1296 = torch.ops.aten.view.default(permute_630, [2, 8192, 8, 4, 128]); permute_630 = None + sum_53 = torch.ops.aten.sum.dim_IntList(view_1296, [3], True); view_1296 = None + squeeze_16 = torch.ops.aten.squeeze.dim(sum_53, 3); sum_53 = None + view_1297 = torch.ops.aten.view.default(permute_631, [2, 8192, 8, 4, 128]); permute_631 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_1297, [3], True); view_1297 = None + squeeze_17 = torch.ops.aten.squeeze.dim(sum_54, 3); sum_54 = None + convert_element_type_1535 = torch.ops.prims.convert_element_type.default(squeeze_17, torch.float32); squeeze_17 = None + convert_element_type_1536 = torch.ops.prims.convert_element_type.default(permute_632, torch.float32); permute_632 = None + view_1298 = torch.ops.aten.view.default(convert_element_type_1535, [2, 8192, 8, 64, 2]); convert_element_type_1535 = None + view_as_complex_80 = torch.ops.aten.view_as_complex.default(view_1298); view_1298 = None + mul_436 = torch.ops.aten.mul.Tensor(view_as_complex_80, _conj); view_as_complex_80 = None + view_1299 = torch.ops.aten.view.default(convert_element_type_1536, [2, 8192, 32, 64, 2]); convert_element_type_1536 = None + view_as_complex_81 = torch.ops.aten.view_as_complex.default(view_1299); view_1299 = None + mul_437 = torch.ops.aten.mul.Tensor(view_as_complex_81, _conj); view_as_complex_81 = None + view_as_real_80 = torch.ops.aten.view_as_real.default(mul_436); mul_436 = None + view_1300 = torch.ops.aten.view.default(view_as_real_80, [2, 8192, 8, 128]); view_as_real_80 = None + convert_element_type_1537 = torch.ops.prims.convert_element_type.default(view_1300, torch.bfloat16); view_1300 = None + view_as_real_81 = torch.ops.aten.view_as_real.default(mul_437); mul_437 = None + view_1301 = torch.ops.aten.view.default(view_as_real_81, [2, 8192, 32, 128]); view_as_real_81 = None + convert_element_type_1538 = torch.ops.prims.convert_element_type.default(view_1301, torch.bfloat16); view_1301 = None + view_1302 = torch.ops.aten.view.default(squeeze_16, [2, 8192, 1024]); squeeze_16 = None + view_1303 = torch.ops.aten.view.default(convert_element_type_1537, [2, 8192, 1024]); convert_element_type_1537 = None + view_1304 = torch.ops.aten.view.default(convert_element_type_1538, [2, 8192, 4096]); convert_element_type_1538 = None + view_1305 = torch.ops.aten.view.default(view_1302, [16384, 1024]); view_1302 = None + permute_633 = torch.ops.aten.permute.default(view_1305, [1, 0]) + mm_347 = torch.ops.aten.mm.default(permute_633, view_785); permute_633 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 256, '0'); convert_element_type_769 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_211, [1, 0]); wait_tensor_211 = None + permute_635 = torch.ops.aten.permute.default(permute_255, [1, 0]); permute_255 = None + mm_348 = torch.ops.aten.mm.default(view_1305, permute_635); view_1305 = permute_635 = None + view_1306 = torch.ops.aten.view.default(mm_348, [2, 8192, 4096]); mm_348 = None + convert_element_type_1543 = torch.ops.prims.convert_element_type.default(mm_347, torch.float32); mm_347 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1543, 'avg', 256, '0'); convert_element_type_1543 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + view_1307 = torch.ops.aten.view.default(view_1303, [16384, 1024]); view_1303 = None + permute_637 = torch.ops.aten.permute.default(view_1307, [1, 0]) + mm_349 = torch.ops.aten.mm.default(permute_637, view_785); permute_637 = None + permute_639 = torch.ops.aten.permute.default(permute_254, [1, 0]); permute_254 = None + mm_350 = torch.ops.aten.mm.default(view_1307, permute_639); view_1307 = permute_639 = None + view_1308 = torch.ops.aten.view.default(mm_350, [2, 8192, 4096]); mm_350 = None + add_189 = torch.ops.aten.add.Tensor(view_1306, view_1308); view_1306 = view_1308 = None + convert_element_type_1548 = torch.ops.prims.convert_element_type.default(mm_349, torch.float32); mm_349 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1548, 'avg', 256, '0'); convert_element_type_1548 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + view_1309 = torch.ops.aten.view.default(view_1304, [16384, 4096]); view_1304 = None + permute_641 = torch.ops.aten.permute.default(view_1309, [1, 0]) + mm_351 = torch.ops.aten.mm.default(permute_641, view_785); permute_641 = view_785 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 256, '0'); convert_element_type_763 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_209, [1, 0]); wait_tensor_209 = None + permute_643 = torch.ops.aten.permute.default(permute_253, [1, 0]); permute_253 = None + mm_352 = torch.ops.aten.mm.default(view_1309, permute_643); view_1309 = permute_643 = None + view_1310 = torch.ops.aten.view.default(mm_352, [2, 8192, 4096]); mm_352 = None + add_190 = torch.ops.aten.add.Tensor(add_189, view_1310); add_189 = view_1310 = None + convert_element_type_1553 = torch.ops.prims.convert_element_type.default(mm_351, torch.float32); mm_351 = None + reduce_scatter_tensor_81 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1553, 'avg', 256, '0'); convert_element_type_1553 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + convert_element_type_1554 = torch.ops.prims.convert_element_type.default(add_190, torch.float32); add_190 = None + convert_element_type_1556 = torch.ops.prims.convert_element_type.default(wait_tensor_208, torch.float32); wait_tensor_208 = None + mul_438 = torch.ops.aten.mul.Tensor(convert_element_type_1554, convert_element_type_1556); convert_element_type_1556 = None + mul_440 = torch.ops.aten.mul.Tensor(mul_184, mul_438) + sum_55 = torch.ops.aten.sum.dim_IntList(mul_440, [2], True); mul_440 = None + div_18 = torch.ops.aten.div.Tensor(mul_184, 4096) + mul_441 = torch.ops.aten.mul.Tensor(div_18, sum_55); div_18 = sum_55 = None + sub_27 = torch.ops.aten.sub.Tensor(mul_438, mul_441); mul_438 = mul_441 = None + mul_442 = torch.ops.aten.mul.Tensor(sub_27, rsqrt_46); sub_27 = rsqrt_46 = None + mul_443 = torch.ops.aten.mul.Tensor(convert_element_type_1554, mul_184); convert_element_type_1554 = mul_184 = None + sum_56 = torch.ops.aten.sum.dim_IntList(mul_443, [0, 1]); mul_443 = None + convert_element_type_1557 = torch.ops.prims.convert_element_type.default(mul_442, torch.bfloat16); mul_442 = None + add_191 = torch.ops.aten.add.Tensor(add_188, convert_element_type_1557); add_188 = convert_element_type_1557 = None + convert_element_type_default_47 = torch.ops.prims.convert_element_type.default(sum_56, torch.float32); sum_56 = None + reduce_scatter_tensor_82 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_47, 'avg', 256, '0'); convert_element_type_default_47 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + view_1311 = torch.ops.aten.view.default(add_191, [16384, 4096]) + permute_645 = torch.ops.aten.permute.default(view_1311, [1, 0]) + permute_248 = torch.ops.aten.permute.default(getitem_198, [0, 2, 1, 3]) + view_769 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16); primals_206 = None + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 256, '0'); convert_element_type_743 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_203, [1, 0]); wait_tensor_203 = None + view_771 = torch.ops.aten.view.default(view_769, [16384, 4096]); view_769 = None + mm_157 = torch.ops.aten.mm.default(view_771, permute_249) + view_772 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + add_89 = torch.ops.aten.add.Tensor(add_87, view_772); view_772 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16); primals_207 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 256, '0'); convert_element_type_746 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32); add_89 = None + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_204) + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + view_775 = torch.ops.aten.view.default(convert_element_type_748, [16384, 4096]); convert_element_type_748 = None + view_776 = torch.ops.aten.view.default(mm_158, [2, 8192, 14336]); mm_158 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_776, torch.float32); view_776 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16); primals_209 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 256, '0'); convert_element_type_754 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + mm_159 = torch.ops.aten.mm.default(view_775, permute_251) + view_779 = torch.ops.aten.view.default(mm_159, [2, 8192, 14336]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_779) + view_781 = torch.ops.aten.view.default(mul_183, [16384, 14336]); mul_183 = None + mm_353 = torch.ops.aten.mm.default(permute_645, view_781); permute_645 = view_781 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16); primals_210 = None + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 256, '0'); convert_element_type_757 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + permute_647 = torch.ops.aten.permute.default(permute_252, [1, 0]); permute_252 = None + mm_354 = torch.ops.aten.mm.default(view_1311, permute_647); view_1311 = permute_647 = None + view_1312 = torch.ops.aten.view.default(mm_354, [2, 8192, 14336]); mm_354 = None + convert_element_type_1564 = torch.ops.prims.convert_element_type.default(mm_353, torch.float32); mm_353 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1564, 'avg', 256, '0'); convert_element_type_1564 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + mul_444 = torch.ops.aten.mul.Tensor(view_1312, convert_element_type_753); convert_element_type_753 = None + mul_445 = torch.ops.aten.mul.Tensor(view_1312, view_779); view_1312 = view_779 = None + view_1313 = torch.ops.aten.view.default(mul_444, [16384, 14336]); mul_444 = None + permute_649 = torch.ops.aten.permute.default(view_1313, [1, 0]) + mm_355 = torch.ops.aten.mm.default(permute_649, view_775); permute_649 = None + permute_651 = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None + mm_356 = torch.ops.aten.mm.default(view_1313, permute_651); view_1313 = permute_651 = None + view_1314 = torch.ops.aten.view.default(mm_356, [2, 8192, 4096]); mm_356 = None + convert_element_type_1569 = torch.ops.prims.convert_element_type.default(mm_355, torch.float32); mm_355 = None + reduce_scatter_tensor_84 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1569, 'avg', 256, '0'); convert_element_type_1569 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + convert_element_type_1570 = torch.ops.prims.convert_element_type.default(mul_445, torch.float32); mul_445 = None + neg_9 = torch.ops.aten.neg.default(convert_element_type_752) + exp_9 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_192 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_192); add_192 = None + mul_446 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_1570, mul_446); convert_element_type_1570 = None + sub_28 = torch.ops.aten.sub.Tensor(1, mul_446); mul_446 = None + mul_448 = torch.ops.aten.mul.Tensor(convert_element_type_752, sub_28); convert_element_type_752 = sub_28 = None + add_193 = torch.ops.aten.add.Tensor(mul_448, 1); mul_448 = None + mul_449 = torch.ops.aten.mul.Tensor(mul_447, add_193); mul_447 = add_193 = None + convert_element_type_1572 = torch.ops.prims.convert_element_type.default(mul_449, torch.bfloat16); mul_449 = None + view_1315 = torch.ops.aten.view.default(convert_element_type_1572, [16384, 14336]); convert_element_type_1572 = None + permute_653 = torch.ops.aten.permute.default(view_1315, [1, 0]) + mm_357 = torch.ops.aten.mm.default(permute_653, view_775); permute_653 = view_775 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16); primals_208 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 256, '0'); convert_element_type_749 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + permute_655 = torch.ops.aten.permute.default(permute_250, [1, 0]); permute_250 = None + mm_358 = torch.ops.aten.mm.default(view_1315, permute_655); view_1315 = permute_655 = None + view_1316 = torch.ops.aten.view.default(mm_358, [2, 8192, 4096]); mm_358 = None + add_194 = torch.ops.aten.add.Tensor(view_1314, view_1316); view_1314 = view_1316 = None + convert_element_type_1577 = torch.ops.prims.convert_element_type.default(mm_357, torch.float32); mm_357 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1577, 'avg', 256, '0'); convert_element_type_1577 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); reduce_scatter_tensor_85 = None + convert_element_type_1578 = torch.ops.prims.convert_element_type.default(add_194, torch.float32); add_194 = None + convert_element_type_1580 = torch.ops.prims.convert_element_type.default(wait_tensor_204, torch.float32); wait_tensor_204 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_1578, convert_element_type_1580); convert_element_type_1580 = None + mul_452 = torch.ops.aten.mul.Tensor(mul_180, mul_450) + sum_57 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True); mul_452 = None + div_19 = torch.ops.aten.div.Tensor(mul_180, 4096) + mul_453 = torch.ops.aten.mul.Tensor(div_19, sum_57); div_19 = sum_57 = None + sub_29 = torch.ops.aten.sub.Tensor(mul_450, mul_453); mul_450 = mul_453 = None + mul_454 = torch.ops.aten.mul.Tensor(sub_29, rsqrt_45); sub_29 = rsqrt_45 = None + mul_455 = torch.ops.aten.mul.Tensor(convert_element_type_1578, mul_180); convert_element_type_1578 = mul_180 = None + sum_58 = torch.ops.aten.sum.dim_IntList(mul_455, [0, 1]); mul_455 = None + convert_element_type_1581 = torch.ops.prims.convert_element_type.default(mul_454, torch.bfloat16); mul_454 = None + add_195 = torch.ops.aten.add.Tensor(add_191, convert_element_type_1581); add_191 = convert_element_type_1581 = None + convert_element_type_default_46 = torch.ops.prims.convert_element_type.default(sum_58, torch.float32); sum_58 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_46, 'avg', 256, '0'); convert_element_type_default_46 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + view_1317 = torch.ops.aten.view.default(add_195, [16384, 4096]) + permute_657 = torch.ops.aten.permute.default(view_1317, [1, 0]) + mm_359 = torch.ops.aten.mm.default(permute_657, view_771); permute_657 = view_771 = None + permute_659 = torch.ops.aten.permute.default(permute_249, [1, 0]); permute_249 = None + mm_360 = torch.ops.aten.mm.default(view_1317, permute_659); view_1317 = permute_659 = None + view_1318 = torch.ops.aten.view.default(mm_360, [2, 8192, 4096]); mm_360 = None + convert_element_type_1588 = torch.ops.prims.convert_element_type.default(mm_359, torch.float32); mm_359 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1588, 'avg', 256, '0'); convert_element_type_1588 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + view_1319 = torch.ops.aten.view.default(view_1318, [2, 8192, 32, 128]); view_1318 = None + permute_661 = torch.ops.aten.permute.default(view_1319, [0, 2, 1, 3]); view_1319 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16); primals_202 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 256, '0'); convert_element_type_727 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32); add_87 = None + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_199) + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + view_751 = torch.ops.aten.view.default(convert_element_type_729, [16384, 4096]); convert_element_type_729 = None + view_752 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]); mm_154 = None + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16); primals_204 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 256, '0'); convert_element_type_733 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_155 = torch.ops.aten.mm.default(view_751, permute_243) + view_755 = torch.ops.aten.view.default(mm_155, [2, 8192, 1024]); mm_155 = None + view_758 = torch.ops.aten.view.default(mm_156, [2, 8192, 1024]); mm_156 = None + view_759 = torch.ops.aten.view.default(view_752, [2, 8192, -1, 128]); view_752 = None + view_760 = torch.ops.aten.view.default(view_755, [2, 8192, -1, 128]); view_755 = None + view_761 = torch.ops.aten.view.default(view_758, [2, 8192, -1, 128]); view_758 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_759, torch.float32); view_759 = None + view_762 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 32, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_762); view_762 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_760, torch.float32); view_760 = None + view_763 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 8, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_763); view_763 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_16); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_765 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 32, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_16); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_766 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 8, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_765, torch.bfloat16); view_765 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_766, torch.bfloat16); view_766 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 8, 4, 128]); unsqueeze_44 = None + clone_44 = torch.ops.aten.clone.default(expand_44, memory_format = torch.contiguous_format); expand_44 = None + view_767 = torch.ops.aten.view.default(clone_44, [2, 8192, 32, 128]); clone_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_761, 3); view_761 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 8, 4, 128]); unsqueeze_45 = None + clone_45 = torch.ops.aten.clone.default(expand_45, memory_format = torch.contiguous_format); expand_45 = None + view_768 = torch.ops.aten.view.default(clone_45, [2, 8192, 32, 128]); clone_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_767, [0, 2, 1, 3]); view_767 = None + permute_247 = torch.ops.aten.permute.default(view_768, [0, 2, 1, 3]); view_768 = None + _scaled_dot_product_cudnn_attention_backward_9 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_661, permute_245, permute_246, permute_247, getitem_198, getitem_199, getitem_204, getitem_205, None, None, None, 8192, 8192, 0.0, True); permute_661 = permute_245 = permute_246 = permute_247 = getitem_198 = getitem_199 = getitem_204 = getitem_205 = None + getitem_315 = _scaled_dot_product_cudnn_attention_backward_9[0] + getitem_316 = _scaled_dot_product_cudnn_attention_backward_9[1] + getitem_317 = _scaled_dot_product_cudnn_attention_backward_9[2]; _scaled_dot_product_cudnn_attention_backward_9 = None + permute_662 = torch.ops.aten.permute.default(getitem_317, [0, 2, 1, 3]); getitem_317 = None + permute_663 = torch.ops.aten.permute.default(getitem_316, [0, 2, 1, 3]); getitem_316 = None + permute_664 = torch.ops.aten.permute.default(getitem_315, [0, 2, 1, 3]); getitem_315 = None + view_1320 = torch.ops.aten.view.default(permute_662, [2, 8192, 8, 4, 128]); permute_662 = None + sum_59 = torch.ops.aten.sum.dim_IntList(view_1320, [3], True); view_1320 = None + squeeze_18 = torch.ops.aten.squeeze.dim(sum_59, 3); sum_59 = None + view_1321 = torch.ops.aten.view.default(permute_663, [2, 8192, 8, 4, 128]); permute_663 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_1321, [3], True); view_1321 = None + squeeze_19 = torch.ops.aten.squeeze.dim(sum_60, 3); sum_60 = None + convert_element_type_1589 = torch.ops.prims.convert_element_type.default(squeeze_19, torch.float32); squeeze_19 = None + convert_element_type_1590 = torch.ops.prims.convert_element_type.default(permute_664, torch.float32); permute_664 = None + view_1322 = torch.ops.aten.view.default(convert_element_type_1589, [2, 8192, 8, 64, 2]); convert_element_type_1589 = None + view_as_complex_82 = torch.ops.aten.view_as_complex.default(view_1322); view_1322 = None + mul_456 = torch.ops.aten.mul.Tensor(view_as_complex_82, _conj); view_as_complex_82 = None + view_1323 = torch.ops.aten.view.default(convert_element_type_1590, [2, 8192, 32, 64, 2]); convert_element_type_1590 = None + view_as_complex_83 = torch.ops.aten.view_as_complex.default(view_1323); view_1323 = None + mul_457 = torch.ops.aten.mul.Tensor(view_as_complex_83, _conj); view_as_complex_83 = None + view_as_real_82 = torch.ops.aten.view_as_real.default(mul_456); mul_456 = None + view_1324 = torch.ops.aten.view.default(view_as_real_82, [2, 8192, 8, 128]); view_as_real_82 = None + convert_element_type_1591 = torch.ops.prims.convert_element_type.default(view_1324, torch.bfloat16); view_1324 = None + view_as_real_83 = torch.ops.aten.view_as_real.default(mul_457); mul_457 = None + view_1325 = torch.ops.aten.view.default(view_as_real_83, [2, 8192, 32, 128]); view_as_real_83 = None + convert_element_type_1592 = torch.ops.prims.convert_element_type.default(view_1325, torch.bfloat16); view_1325 = None + view_1326 = torch.ops.aten.view.default(squeeze_18, [2, 8192, 1024]); squeeze_18 = None + view_1327 = torch.ops.aten.view.default(convert_element_type_1591, [2, 8192, 1024]); convert_element_type_1591 = None + view_1328 = torch.ops.aten.view.default(convert_element_type_1592, [2, 8192, 4096]); convert_element_type_1592 = None + view_1329 = torch.ops.aten.view.default(view_1326, [16384, 1024]); view_1326 = None + permute_665 = torch.ops.aten.permute.default(view_1329, [1, 0]) + mm_361 = torch.ops.aten.mm.default(permute_665, view_751); permute_665 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16); primals_205 = None + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 256, '0'); convert_element_type_736 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + permute_667 = torch.ops.aten.permute.default(permute_244, [1, 0]); permute_244 = None + mm_362 = torch.ops.aten.mm.default(view_1329, permute_667); view_1329 = permute_667 = None + view_1330 = torch.ops.aten.view.default(mm_362, [2, 8192, 4096]); mm_362 = None + convert_element_type_1597 = torch.ops.prims.convert_element_type.default(mm_361, torch.float32); mm_361 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1597, 'avg', 256, '0'); convert_element_type_1597 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + view_1331 = torch.ops.aten.view.default(view_1327, [16384, 1024]); view_1327 = None + permute_669 = torch.ops.aten.permute.default(view_1331, [1, 0]) + mm_363 = torch.ops.aten.mm.default(permute_669, view_751); permute_669 = None + permute_671 = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None + mm_364 = torch.ops.aten.mm.default(view_1331, permute_671); view_1331 = permute_671 = None + view_1332 = torch.ops.aten.view.default(mm_364, [2, 8192, 4096]); mm_364 = None + add_196 = torch.ops.aten.add.Tensor(view_1330, view_1332); view_1330 = view_1332 = None + convert_element_type_1602 = torch.ops.prims.convert_element_type.default(mm_363, torch.float32); mm_363 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1602, 'avg', 256, '0'); convert_element_type_1602 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + view_1333 = torch.ops.aten.view.default(view_1328, [16384, 4096]); view_1328 = None + permute_673 = torch.ops.aten.permute.default(view_1333, [1, 0]) + mm_365 = torch.ops.aten.mm.default(permute_673, view_751); permute_673 = view_751 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16); primals_203 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 256, '0'); convert_element_type_730 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + permute_675 = torch.ops.aten.permute.default(permute_242, [1, 0]); permute_242 = None + mm_366 = torch.ops.aten.mm.default(view_1333, permute_675); view_1333 = permute_675 = None + view_1334 = torch.ops.aten.view.default(mm_366, [2, 8192, 4096]); mm_366 = None + add_197 = torch.ops.aten.add.Tensor(add_196, view_1334); add_196 = view_1334 = None + convert_element_type_1607 = torch.ops.prims.convert_element_type.default(mm_365, torch.float32); mm_365 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1607, 'avg', 256, '0'); convert_element_type_1607 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + convert_element_type_1608 = torch.ops.prims.convert_element_type.default(add_197, torch.float32); add_197 = None + convert_element_type_1610 = torch.ops.prims.convert_element_type.default(wait_tensor_199, torch.float32); wait_tensor_199 = None + mul_458 = torch.ops.aten.mul.Tensor(convert_element_type_1608, convert_element_type_1610); convert_element_type_1610 = None + mul_460 = torch.ops.aten.mul.Tensor(mul_176, mul_458) + sum_61 = torch.ops.aten.sum.dim_IntList(mul_460, [2], True); mul_460 = None + div_20 = torch.ops.aten.div.Tensor(mul_176, 4096) + mul_461 = torch.ops.aten.mul.Tensor(div_20, sum_61); div_20 = sum_61 = None + sub_30 = torch.ops.aten.sub.Tensor(mul_458, mul_461); mul_458 = mul_461 = None + mul_462 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_44); sub_30 = rsqrt_44 = None + mul_463 = torch.ops.aten.mul.Tensor(convert_element_type_1608, mul_176); convert_element_type_1608 = mul_176 = None + sum_62 = torch.ops.aten.sum.dim_IntList(mul_463, [0, 1]); mul_463 = None + convert_element_type_1611 = torch.ops.prims.convert_element_type.default(mul_462, torch.bfloat16); mul_462 = None + add_198 = torch.ops.aten.add.Tensor(add_195, convert_element_type_1611); add_195 = convert_element_type_1611 = None + convert_element_type_default_45 = torch.ops.prims.convert_element_type.default(sum_62, torch.float32); sum_62 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_45, 'avg', 256, '0'); convert_element_type_default_45 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + view_1335 = torch.ops.aten.view.default(add_198, [16384, 4096]) + permute_677 = torch.ops.aten.permute.default(view_1335, [1, 0]) + permute_237 = torch.ops.aten.permute.default(getitem_189, [0, 2, 1, 3]) + view_735 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 256, '0'); convert_element_type_710 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + view_737 = torch.ops.aten.view.default(view_735, [16384, 4096]); view_735 = None + mm_150 = torch.ops.aten.mm.default(view_737, permute_238) + view_738 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + add_85 = torch.ops.aten.add.Tensor(add_83, view_738); view_738 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 256, '0'); convert_element_type_713 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32); add_85 = None + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_195) + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + view_741 = torch.ops.aten.view.default(convert_element_type_715, [16384, 4096]); convert_element_type_715 = None + view_742 = torch.ops.aten.view.default(mm_151, [2, 8192, 14336]); mm_151 = None + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_742, torch.float32); view_742 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16); primals_200 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 256, '0'); convert_element_type_721 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_152 = torch.ops.aten.mm.default(view_741, permute_240) + view_745 = torch.ops.aten.view.default(mm_152, [2, 8192, 14336]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_745) + view_747 = torch.ops.aten.view.default(mul_175, [16384, 14336]); mul_175 = None + mm_367 = torch.ops.aten.mm.default(permute_677, view_747); permute_677 = view_747 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16); primals_201 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 256, '0'); convert_element_type_724 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + permute_679 = torch.ops.aten.permute.default(permute_241, [1, 0]); permute_241 = None + mm_368 = torch.ops.aten.mm.default(view_1335, permute_679); view_1335 = permute_679 = None + view_1336 = torch.ops.aten.view.default(mm_368, [2, 8192, 14336]); mm_368 = None + convert_element_type_1618 = torch.ops.prims.convert_element_type.default(mm_367, torch.float32); mm_367 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1618, 'avg', 256, '0'); convert_element_type_1618 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + mul_464 = torch.ops.aten.mul.Tensor(view_1336, convert_element_type_720); convert_element_type_720 = None + mul_465 = torch.ops.aten.mul.Tensor(view_1336, view_745); view_1336 = view_745 = None + view_1337 = torch.ops.aten.view.default(mul_464, [16384, 14336]); mul_464 = None + permute_681 = torch.ops.aten.permute.default(view_1337, [1, 0]) + mm_369 = torch.ops.aten.mm.default(permute_681, view_741); permute_681 = None + permute_683 = torch.ops.aten.permute.default(permute_240, [1, 0]); permute_240 = None + mm_370 = torch.ops.aten.mm.default(view_1337, permute_683); view_1337 = permute_683 = None + view_1338 = torch.ops.aten.view.default(mm_370, [2, 8192, 4096]); mm_370 = None + convert_element_type_1623 = torch.ops.prims.convert_element_type.default(mm_369, torch.float32); mm_369 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1623, 'avg', 256, '0'); convert_element_type_1623 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + convert_element_type_1624 = torch.ops.prims.convert_element_type.default(mul_465, torch.float32); mul_465 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_719) + exp_10 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_199 = torch.ops.aten.add.Tensor(exp_10, 1); exp_10 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_199); add_199 = None + mul_466 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_467 = torch.ops.aten.mul.Tensor(convert_element_type_1624, mul_466); convert_element_type_1624 = None + sub_31 = torch.ops.aten.sub.Tensor(1, mul_466); mul_466 = None + mul_468 = torch.ops.aten.mul.Tensor(convert_element_type_719, sub_31); convert_element_type_719 = sub_31 = None + add_200 = torch.ops.aten.add.Tensor(mul_468, 1); mul_468 = None + mul_469 = torch.ops.aten.mul.Tensor(mul_467, add_200); mul_467 = add_200 = None + convert_element_type_1626 = torch.ops.prims.convert_element_type.default(mul_469, torch.bfloat16); mul_469 = None + view_1339 = torch.ops.aten.view.default(convert_element_type_1626, [16384, 14336]); convert_element_type_1626 = None + permute_685 = torch.ops.aten.permute.default(view_1339, [1, 0]) + mm_371 = torch.ops.aten.mm.default(permute_685, view_741); permute_685 = view_741 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16); primals_199 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 256, '0'); convert_element_type_716 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_196, [1, 0]); wait_tensor_196 = None + permute_687 = torch.ops.aten.permute.default(permute_239, [1, 0]); permute_239 = None + mm_372 = torch.ops.aten.mm.default(view_1339, permute_687); view_1339 = permute_687 = None + view_1340 = torch.ops.aten.view.default(mm_372, [2, 8192, 4096]); mm_372 = None + add_201 = torch.ops.aten.add.Tensor(view_1338, view_1340); view_1338 = view_1340 = None + convert_element_type_1631 = torch.ops.prims.convert_element_type.default(mm_371, torch.float32); mm_371 = None + reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1631, 'avg', 256, '0'); convert_element_type_1631 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + convert_element_type_1632 = torch.ops.prims.convert_element_type.default(add_201, torch.float32); add_201 = None + convert_element_type_1634 = torch.ops.prims.convert_element_type.default(wait_tensor_195, torch.float32); wait_tensor_195 = None + mul_470 = torch.ops.aten.mul.Tensor(convert_element_type_1632, convert_element_type_1634); convert_element_type_1634 = None + mul_472 = torch.ops.aten.mul.Tensor(mul_172, mul_470) + sum_63 = torch.ops.aten.sum.dim_IntList(mul_472, [2], True); mul_472 = None + div_21 = torch.ops.aten.div.Tensor(mul_172, 4096) + mul_473 = torch.ops.aten.mul.Tensor(div_21, sum_63); div_21 = sum_63 = None + sub_32 = torch.ops.aten.sub.Tensor(mul_470, mul_473); mul_470 = mul_473 = None + mul_474 = torch.ops.aten.mul.Tensor(sub_32, rsqrt_43); sub_32 = rsqrt_43 = None + mul_475 = torch.ops.aten.mul.Tensor(convert_element_type_1632, mul_172); convert_element_type_1632 = mul_172 = None + sum_64 = torch.ops.aten.sum.dim_IntList(mul_475, [0, 1]); mul_475 = None + convert_element_type_1635 = torch.ops.prims.convert_element_type.default(mul_474, torch.bfloat16); mul_474 = None + add_202 = torch.ops.aten.add.Tensor(add_198, convert_element_type_1635); add_198 = convert_element_type_1635 = None + convert_element_type_default_44 = torch.ops.prims.convert_element_type.default(sum_64, torch.float32); sum_64 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_44, 'avg', 256, '0'); convert_element_type_default_44 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + view_1341 = torch.ops.aten.view.default(add_202, [16384, 4096]) + permute_689 = torch.ops.aten.permute.default(view_1341, [1, 0]) + mm_373 = torch.ops.aten.mm.default(permute_689, view_737); permute_689 = view_737 = None + permute_691 = torch.ops.aten.permute.default(permute_238, [1, 0]); permute_238 = None + mm_374 = torch.ops.aten.mm.default(view_1341, permute_691); view_1341 = permute_691 = None + view_1342 = torch.ops.aten.view.default(mm_374, [2, 8192, 4096]); mm_374 = None + convert_element_type_1642 = torch.ops.prims.convert_element_type.default(mm_373, torch.float32); mm_373 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1642, 'avg', 256, '0'); convert_element_type_1642 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + view_1343 = torch.ops.aten.view.default(view_1342, [2, 8192, 32, 128]); view_1342 = None + permute_693 = torch.ops.aten.permute.default(view_1343, [0, 2, 1, 3]); view_1343 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16); primals_193 = None + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 256, '0'); convert_element_type_694 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32); add_83 = None + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_190) + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + view_717 = torch.ops.aten.view.default(convert_element_type_696, [16384, 4096]); convert_element_type_696 = None + view_718 = torch.ops.aten.view.default(mm_147, [2, 8192, 4096]); mm_147 = None + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 256, '0'); convert_element_type_700 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_192, [1, 0]); wait_tensor_192 = None + mm_148 = torch.ops.aten.mm.default(view_717, permute_232) + view_721 = torch.ops.aten.view.default(mm_148, [2, 8192, 1024]); mm_148 = None + view_724 = torch.ops.aten.view.default(mm_149, [2, 8192, 1024]); mm_149 = None + view_725 = torch.ops.aten.view.default(view_718, [2, 8192, -1, 128]); view_718 = None + view_726 = torch.ops.aten.view.default(view_721, [2, 8192, -1, 128]); view_721 = None + view_727 = torch.ops.aten.view.default(view_724, [2, 8192, -1, 128]); view_724 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_725, torch.float32); view_725 = None + view_728 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 32, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_728); view_728 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_726, torch.float32); view_726 = None + view_729 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 8, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_729); view_729 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_16); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_731 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 32, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_16); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_732 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 8, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_731, torch.bfloat16); view_731 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_732, torch.bfloat16); view_732 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 8, 4, 128]); unsqueeze_42 = None + clone_42 = torch.ops.aten.clone.default(expand_42, memory_format = torch.contiguous_format); expand_42 = None + view_733 = torch.ops.aten.view.default(clone_42, [2, 8192, 32, 128]); clone_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_727, 3); view_727 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 8, 4, 128]); unsqueeze_43 = None + clone_43 = torch.ops.aten.clone.default(expand_43, memory_format = torch.contiguous_format); expand_43 = None + view_734 = torch.ops.aten.view.default(clone_43, [2, 8192, 32, 128]); clone_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_733, [0, 2, 1, 3]); view_733 = None + permute_236 = torch.ops.aten.permute.default(view_734, [0, 2, 1, 3]); view_734 = None + _scaled_dot_product_cudnn_attention_backward_10 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_693, permute_234, permute_235, permute_236, getitem_189, getitem_190, getitem_195, getitem_196, None, None, None, 8192, 8192, 0.0, True); permute_693 = permute_234 = permute_235 = permute_236 = getitem_189 = getitem_190 = getitem_195 = getitem_196 = None + getitem_318 = _scaled_dot_product_cudnn_attention_backward_10[0] + getitem_319 = _scaled_dot_product_cudnn_attention_backward_10[1] + getitem_320 = _scaled_dot_product_cudnn_attention_backward_10[2]; _scaled_dot_product_cudnn_attention_backward_10 = None + permute_694 = torch.ops.aten.permute.default(getitem_320, [0, 2, 1, 3]); getitem_320 = None + permute_695 = torch.ops.aten.permute.default(getitem_319, [0, 2, 1, 3]); getitem_319 = None + permute_696 = torch.ops.aten.permute.default(getitem_318, [0, 2, 1, 3]); getitem_318 = None + view_1344 = torch.ops.aten.view.default(permute_694, [2, 8192, 8, 4, 128]); permute_694 = None + sum_65 = torch.ops.aten.sum.dim_IntList(view_1344, [3], True); view_1344 = None + squeeze_20 = torch.ops.aten.squeeze.dim(sum_65, 3); sum_65 = None + view_1345 = torch.ops.aten.view.default(permute_695, [2, 8192, 8, 4, 128]); permute_695 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_1345, [3], True); view_1345 = None + squeeze_21 = torch.ops.aten.squeeze.dim(sum_66, 3); sum_66 = None + convert_element_type_1643 = torch.ops.prims.convert_element_type.default(squeeze_21, torch.float32); squeeze_21 = None + convert_element_type_1644 = torch.ops.prims.convert_element_type.default(permute_696, torch.float32); permute_696 = None + view_1346 = torch.ops.aten.view.default(convert_element_type_1643, [2, 8192, 8, 64, 2]); convert_element_type_1643 = None + view_as_complex_84 = torch.ops.aten.view_as_complex.default(view_1346); view_1346 = None + mul_476 = torch.ops.aten.mul.Tensor(view_as_complex_84, _conj); view_as_complex_84 = None + view_1347 = torch.ops.aten.view.default(convert_element_type_1644, [2, 8192, 32, 64, 2]); convert_element_type_1644 = None + view_as_complex_85 = torch.ops.aten.view_as_complex.default(view_1347); view_1347 = None + mul_477 = torch.ops.aten.mul.Tensor(view_as_complex_85, _conj); view_as_complex_85 = None + view_as_real_84 = torch.ops.aten.view_as_real.default(mul_476); mul_476 = None + view_1348 = torch.ops.aten.view.default(view_as_real_84, [2, 8192, 8, 128]); view_as_real_84 = None + convert_element_type_1645 = torch.ops.prims.convert_element_type.default(view_1348, torch.bfloat16); view_1348 = None + view_as_real_85 = torch.ops.aten.view_as_real.default(mul_477); mul_477 = None + view_1349 = torch.ops.aten.view.default(view_as_real_85, [2, 8192, 32, 128]); view_as_real_85 = None + convert_element_type_1646 = torch.ops.prims.convert_element_type.default(view_1349, torch.bfloat16); view_1349 = None + view_1350 = torch.ops.aten.view.default(squeeze_20, [2, 8192, 1024]); squeeze_20 = None + view_1351 = torch.ops.aten.view.default(convert_element_type_1645, [2, 8192, 1024]); convert_element_type_1645 = None + view_1352 = torch.ops.aten.view.default(convert_element_type_1646, [2, 8192, 4096]); convert_element_type_1646 = None + view_1353 = torch.ops.aten.view.default(view_1350, [16384, 1024]); view_1350 = None + permute_697 = torch.ops.aten.permute.default(view_1353, [1, 0]) + mm_375 = torch.ops.aten.mm.default(permute_697, view_717); permute_697 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 256, '0'); convert_element_type_703 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + permute_699 = torch.ops.aten.permute.default(permute_233, [1, 0]); permute_233 = None + mm_376 = torch.ops.aten.mm.default(view_1353, permute_699); view_1353 = permute_699 = None + view_1354 = torch.ops.aten.view.default(mm_376, [2, 8192, 4096]); mm_376 = None + convert_element_type_1651 = torch.ops.prims.convert_element_type.default(mm_375, torch.float32); mm_375 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1651, 'avg', 256, '0'); convert_element_type_1651 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + view_1355 = torch.ops.aten.view.default(view_1351, [16384, 1024]); view_1351 = None + permute_701 = torch.ops.aten.permute.default(view_1355, [1, 0]) + mm_377 = torch.ops.aten.mm.default(permute_701, view_717); permute_701 = None + permute_703 = torch.ops.aten.permute.default(permute_232, [1, 0]); permute_232 = None + mm_378 = torch.ops.aten.mm.default(view_1355, permute_703); view_1355 = permute_703 = None + view_1356 = torch.ops.aten.view.default(mm_378, [2, 8192, 4096]); mm_378 = None + add_203 = torch.ops.aten.add.Tensor(view_1354, view_1356); view_1354 = view_1356 = None + convert_element_type_1656 = torch.ops.prims.convert_element_type.default(mm_377, torch.float32); mm_377 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1656, 'avg', 256, '0'); convert_element_type_1656 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + view_1357 = torch.ops.aten.view.default(view_1352, [16384, 4096]); view_1352 = None + permute_705 = torch.ops.aten.permute.default(view_1357, [1, 0]) + mm_379 = torch.ops.aten.mm.default(permute_705, view_717); permute_705 = view_717 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16); primals_194 = None + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 256, '0'); convert_element_type_697 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_191, [1, 0]); wait_tensor_191 = None + permute_707 = torch.ops.aten.permute.default(permute_231, [1, 0]); permute_231 = None + mm_380 = torch.ops.aten.mm.default(view_1357, permute_707); view_1357 = permute_707 = None + view_1358 = torch.ops.aten.view.default(mm_380, [2, 8192, 4096]); mm_380 = None + add_204 = torch.ops.aten.add.Tensor(add_203, view_1358); add_203 = view_1358 = None + convert_element_type_1661 = torch.ops.prims.convert_element_type.default(mm_379, torch.float32); mm_379 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1661, 'avg', 256, '0'); convert_element_type_1661 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + convert_element_type_1662 = torch.ops.prims.convert_element_type.default(add_204, torch.float32); add_204 = None + convert_element_type_1664 = torch.ops.prims.convert_element_type.default(wait_tensor_190, torch.float32); wait_tensor_190 = None + mul_478 = torch.ops.aten.mul.Tensor(convert_element_type_1662, convert_element_type_1664); convert_element_type_1664 = None + mul_480 = torch.ops.aten.mul.Tensor(mul_168, mul_478) + sum_67 = torch.ops.aten.sum.dim_IntList(mul_480, [2], True); mul_480 = None + div_22 = torch.ops.aten.div.Tensor(mul_168, 4096) + mul_481 = torch.ops.aten.mul.Tensor(div_22, sum_67); div_22 = sum_67 = None + sub_33 = torch.ops.aten.sub.Tensor(mul_478, mul_481); mul_478 = mul_481 = None + mul_482 = torch.ops.aten.mul.Tensor(sub_33, rsqrt_42); sub_33 = rsqrt_42 = None + mul_483 = torch.ops.aten.mul.Tensor(convert_element_type_1662, mul_168); convert_element_type_1662 = mul_168 = None + sum_68 = torch.ops.aten.sum.dim_IntList(mul_483, [0, 1]); mul_483 = None + convert_element_type_1665 = torch.ops.prims.convert_element_type.default(mul_482, torch.bfloat16); mul_482 = None + add_205 = torch.ops.aten.add.Tensor(add_202, convert_element_type_1665); add_202 = convert_element_type_1665 = None + convert_element_type_default_43 = torch.ops.prims.convert_element_type.default(sum_68, torch.float32); sum_68 = None + reduce_scatter_tensor_100 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_43, 'avg', 256, '0'); convert_element_type_default_43 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); reduce_scatter_tensor_100 = None + view_1359 = torch.ops.aten.view.default(add_205, [16384, 4096]) + permute_709 = torch.ops.aten.permute.default(view_1359, [1, 0]) + permute_226 = torch.ops.aten.permute.default(getitem_180, [0, 2, 1, 3]) + view_701 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16); primals_188 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 256, '0'); convert_element_type_677 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_185, [1, 0]); wait_tensor_185 = None + view_703 = torch.ops.aten.view.default(view_701, [16384, 4096]); view_701 = None + mm_143 = torch.ops.aten.mm.default(view_703, permute_227) + view_704 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + add_81 = torch.ops.aten.add.Tensor(add_79, view_704); view_704 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16); primals_189 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 256, '0'); convert_element_type_680 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32); add_81 = None + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_186) + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + view_707 = torch.ops.aten.view.default(convert_element_type_682, [16384, 4096]); convert_element_type_682 = None + view_708 = torch.ops.aten.view.default(mm_144, [2, 8192, 14336]); mm_144 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_708, torch.float32); view_708 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 256, '0'); convert_element_type_688 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_145 = torch.ops.aten.mm.default(view_707, permute_229) + view_711 = torch.ops.aten.view.default(mm_145, [2, 8192, 14336]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_711) + view_713 = torch.ops.aten.view.default(mul_167, [16384, 14336]); mul_167 = None + mm_381 = torch.ops.aten.mm.default(permute_709, view_713); permute_709 = view_713 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16); primals_192 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 256, '0'); convert_element_type_691 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + permute_711 = torch.ops.aten.permute.default(permute_230, [1, 0]); permute_230 = None + mm_382 = torch.ops.aten.mm.default(view_1359, permute_711); view_1359 = permute_711 = None + view_1360 = torch.ops.aten.view.default(mm_382, [2, 8192, 14336]); mm_382 = None + convert_element_type_1672 = torch.ops.prims.convert_element_type.default(mm_381, torch.float32); mm_381 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1672, 'avg', 256, '0'); convert_element_type_1672 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + mul_484 = torch.ops.aten.mul.Tensor(view_1360, convert_element_type_687); convert_element_type_687 = None + mul_485 = torch.ops.aten.mul.Tensor(view_1360, view_711); view_1360 = view_711 = None + view_1361 = torch.ops.aten.view.default(mul_484, [16384, 14336]); mul_484 = None + permute_713 = torch.ops.aten.permute.default(view_1361, [1, 0]) + mm_383 = torch.ops.aten.mm.default(permute_713, view_707); permute_713 = None + permute_715 = torch.ops.aten.permute.default(permute_229, [1, 0]); permute_229 = None + mm_384 = torch.ops.aten.mm.default(view_1361, permute_715); view_1361 = permute_715 = None + view_1362 = torch.ops.aten.view.default(mm_384, [2, 8192, 4096]); mm_384 = None + convert_element_type_1677 = torch.ops.prims.convert_element_type.default(mm_383, torch.float32); mm_383 = None + reduce_scatter_tensor_102 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1677, 'avg', 256, '0'); convert_element_type_1677 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); reduce_scatter_tensor_102 = None + convert_element_type_1678 = torch.ops.prims.convert_element_type.default(mul_485, torch.float32); mul_485 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_686) + exp_11 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_206 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_206); add_206 = None + mul_486 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_487 = torch.ops.aten.mul.Tensor(convert_element_type_1678, mul_486); convert_element_type_1678 = None + sub_34 = torch.ops.aten.sub.Tensor(1, mul_486); mul_486 = None + mul_488 = torch.ops.aten.mul.Tensor(convert_element_type_686, sub_34); convert_element_type_686 = sub_34 = None + add_207 = torch.ops.aten.add.Tensor(mul_488, 1); mul_488 = None + mul_489 = torch.ops.aten.mul.Tensor(mul_487, add_207); mul_487 = add_207 = None + convert_element_type_1680 = torch.ops.prims.convert_element_type.default(mul_489, torch.bfloat16); mul_489 = None + view_1363 = torch.ops.aten.view.default(convert_element_type_1680, [16384, 14336]); convert_element_type_1680 = None + permute_717 = torch.ops.aten.permute.default(view_1363, [1, 0]) + mm_385 = torch.ops.aten.mm.default(permute_717, view_707); permute_717 = view_707 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16); primals_190 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 256, '0'); convert_element_type_683 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + permute_719 = torch.ops.aten.permute.default(permute_228, [1, 0]); permute_228 = None + mm_386 = torch.ops.aten.mm.default(view_1363, permute_719); view_1363 = permute_719 = None + view_1364 = torch.ops.aten.view.default(mm_386, [2, 8192, 4096]); mm_386 = None + add_208 = torch.ops.aten.add.Tensor(view_1362, view_1364); view_1362 = view_1364 = None + convert_element_type_1685 = torch.ops.prims.convert_element_type.default(mm_385, torch.float32); mm_385 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1685, 'avg', 256, '0'); convert_element_type_1685 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + convert_element_type_1686 = torch.ops.prims.convert_element_type.default(add_208, torch.float32); add_208 = None + convert_element_type_1688 = torch.ops.prims.convert_element_type.default(wait_tensor_186, torch.float32); wait_tensor_186 = None + mul_490 = torch.ops.aten.mul.Tensor(convert_element_type_1686, convert_element_type_1688); convert_element_type_1688 = None + mul_492 = torch.ops.aten.mul.Tensor(mul_164, mul_490) + sum_69 = torch.ops.aten.sum.dim_IntList(mul_492, [2], True); mul_492 = None + div_23 = torch.ops.aten.div.Tensor(mul_164, 4096) + mul_493 = torch.ops.aten.mul.Tensor(div_23, sum_69); div_23 = sum_69 = None + sub_35 = torch.ops.aten.sub.Tensor(mul_490, mul_493); mul_490 = mul_493 = None + mul_494 = torch.ops.aten.mul.Tensor(sub_35, rsqrt_41); sub_35 = rsqrt_41 = None + mul_495 = torch.ops.aten.mul.Tensor(convert_element_type_1686, mul_164); convert_element_type_1686 = mul_164 = None + sum_70 = torch.ops.aten.sum.dim_IntList(mul_495, [0, 1]); mul_495 = None + convert_element_type_1689 = torch.ops.prims.convert_element_type.default(mul_494, torch.bfloat16); mul_494 = None + add_209 = torch.ops.aten.add.Tensor(add_205, convert_element_type_1689); add_205 = convert_element_type_1689 = None + convert_element_type_default_42 = torch.ops.prims.convert_element_type.default(sum_70, torch.float32); sum_70 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_42, 'avg', 256, '0'); convert_element_type_default_42 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + view_1365 = torch.ops.aten.view.default(add_209, [16384, 4096]) + permute_721 = torch.ops.aten.permute.default(view_1365, [1, 0]) + mm_387 = torch.ops.aten.mm.default(permute_721, view_703); permute_721 = view_703 = None + permute_723 = torch.ops.aten.permute.default(permute_227, [1, 0]); permute_227 = None + mm_388 = torch.ops.aten.mm.default(view_1365, permute_723); view_1365 = permute_723 = None + view_1366 = torch.ops.aten.view.default(mm_388, [2, 8192, 4096]); mm_388 = None + convert_element_type_1696 = torch.ops.prims.convert_element_type.default(mm_387, torch.float32); mm_387 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1696, 'avg', 256, '0'); convert_element_type_1696 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + view_1367 = torch.ops.aten.view.default(view_1366, [2, 8192, 32, 128]); view_1366 = None + permute_725 = torch.ops.aten.permute.default(view_1367, [0, 2, 1, 3]); view_1367 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16); primals_184 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 256, '0'); convert_element_type_661 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32); add_79 = None + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_181) + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + view_683 = torch.ops.aten.view.default(convert_element_type_663, [16384, 4096]); convert_element_type_663 = None + view_684 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]); mm_140 = None + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16); primals_186 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 256, '0'); convert_element_type_667 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + mm_141 = torch.ops.aten.mm.default(view_683, permute_221) + view_687 = torch.ops.aten.view.default(mm_141, [2, 8192, 1024]); mm_141 = None + view_690 = torch.ops.aten.view.default(mm_142, [2, 8192, 1024]); mm_142 = None + view_691 = torch.ops.aten.view.default(view_684, [2, 8192, -1, 128]); view_684 = None + view_692 = torch.ops.aten.view.default(view_687, [2, 8192, -1, 128]); view_687 = None + view_693 = torch.ops.aten.view.default(view_690, [2, 8192, -1, 128]); view_690 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_691, torch.float32); view_691 = None + view_694 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 32, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_694); view_694 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_692, torch.float32); view_692 = None + view_695 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 8, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_695); view_695 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_16); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_697 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 32, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_16); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_698 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 8, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_697, torch.bfloat16); view_697 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_698, torch.bfloat16); view_698 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 8, 4, 128]); unsqueeze_40 = None + clone_40 = torch.ops.aten.clone.default(expand_40, memory_format = torch.contiguous_format); expand_40 = None + view_699 = torch.ops.aten.view.default(clone_40, [2, 8192, 32, 128]); clone_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_693, 3); view_693 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 8, 4, 128]); unsqueeze_41 = None + clone_41 = torch.ops.aten.clone.default(expand_41, memory_format = torch.contiguous_format); expand_41 = None + view_700 = torch.ops.aten.view.default(clone_41, [2, 8192, 32, 128]); clone_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_699, [0, 2, 1, 3]); view_699 = None + permute_225 = torch.ops.aten.permute.default(view_700, [0, 2, 1, 3]); view_700 = None + _scaled_dot_product_cudnn_attention_backward_11 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_725, permute_223, permute_224, permute_225, getitem_180, getitem_181, getitem_186, getitem_187, None, None, None, 8192, 8192, 0.0, True); permute_725 = permute_223 = permute_224 = permute_225 = getitem_180 = getitem_181 = getitem_186 = getitem_187 = None + getitem_321 = _scaled_dot_product_cudnn_attention_backward_11[0] + getitem_322 = _scaled_dot_product_cudnn_attention_backward_11[1] + getitem_323 = _scaled_dot_product_cudnn_attention_backward_11[2]; _scaled_dot_product_cudnn_attention_backward_11 = None + permute_726 = torch.ops.aten.permute.default(getitem_323, [0, 2, 1, 3]); getitem_323 = None + permute_727 = torch.ops.aten.permute.default(getitem_322, [0, 2, 1, 3]); getitem_322 = None + permute_728 = torch.ops.aten.permute.default(getitem_321, [0, 2, 1, 3]); getitem_321 = None + view_1368 = torch.ops.aten.view.default(permute_726, [2, 8192, 8, 4, 128]); permute_726 = None + sum_71 = torch.ops.aten.sum.dim_IntList(view_1368, [3], True); view_1368 = None + squeeze_22 = torch.ops.aten.squeeze.dim(sum_71, 3); sum_71 = None + view_1369 = torch.ops.aten.view.default(permute_727, [2, 8192, 8, 4, 128]); permute_727 = None + sum_72 = torch.ops.aten.sum.dim_IntList(view_1369, [3], True); view_1369 = None + squeeze_23 = torch.ops.aten.squeeze.dim(sum_72, 3); sum_72 = None + convert_element_type_1697 = torch.ops.prims.convert_element_type.default(squeeze_23, torch.float32); squeeze_23 = None + convert_element_type_1698 = torch.ops.prims.convert_element_type.default(permute_728, torch.float32); permute_728 = None + view_1370 = torch.ops.aten.view.default(convert_element_type_1697, [2, 8192, 8, 64, 2]); convert_element_type_1697 = None + view_as_complex_86 = torch.ops.aten.view_as_complex.default(view_1370); view_1370 = None + mul_496 = torch.ops.aten.mul.Tensor(view_as_complex_86, _conj); view_as_complex_86 = None + view_1371 = torch.ops.aten.view.default(convert_element_type_1698, [2, 8192, 32, 64, 2]); convert_element_type_1698 = None + view_as_complex_87 = torch.ops.aten.view_as_complex.default(view_1371); view_1371 = None + mul_497 = torch.ops.aten.mul.Tensor(view_as_complex_87, _conj); view_as_complex_87 = None + view_as_real_86 = torch.ops.aten.view_as_real.default(mul_496); mul_496 = None + view_1372 = torch.ops.aten.view.default(view_as_real_86, [2, 8192, 8, 128]); view_as_real_86 = None + convert_element_type_1699 = torch.ops.prims.convert_element_type.default(view_1372, torch.bfloat16); view_1372 = None + view_as_real_87 = torch.ops.aten.view_as_real.default(mul_497); mul_497 = None + view_1373 = torch.ops.aten.view.default(view_as_real_87, [2, 8192, 32, 128]); view_as_real_87 = None + convert_element_type_1700 = torch.ops.prims.convert_element_type.default(view_1373, torch.bfloat16); view_1373 = None + view_1374 = torch.ops.aten.view.default(squeeze_22, [2, 8192, 1024]); squeeze_22 = None + view_1375 = torch.ops.aten.view.default(convert_element_type_1699, [2, 8192, 1024]); convert_element_type_1699 = None + view_1376 = torch.ops.aten.view.default(convert_element_type_1700, [2, 8192, 4096]); convert_element_type_1700 = None + view_1377 = torch.ops.aten.view.default(view_1374, [16384, 1024]); view_1374 = None + permute_729 = torch.ops.aten.permute.default(view_1377, [1, 0]) + mm_389 = torch.ops.aten.mm.default(permute_729, view_683); permute_729 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16); primals_187 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 256, '0'); convert_element_type_670 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + permute_731 = torch.ops.aten.permute.default(permute_222, [1, 0]); permute_222 = None + mm_390 = torch.ops.aten.mm.default(view_1377, permute_731); view_1377 = permute_731 = None + view_1378 = torch.ops.aten.view.default(mm_390, [2, 8192, 4096]); mm_390 = None + convert_element_type_1705 = torch.ops.prims.convert_element_type.default(mm_389, torch.float32); mm_389 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1705, 'avg', 256, '0'); convert_element_type_1705 = None + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + view_1379 = torch.ops.aten.view.default(view_1375, [16384, 1024]); view_1375 = None + permute_733 = torch.ops.aten.permute.default(view_1379, [1, 0]) + mm_391 = torch.ops.aten.mm.default(permute_733, view_683); permute_733 = None + permute_735 = torch.ops.aten.permute.default(permute_221, [1, 0]); permute_221 = None + mm_392 = torch.ops.aten.mm.default(view_1379, permute_735); view_1379 = permute_735 = None + view_1380 = torch.ops.aten.view.default(mm_392, [2, 8192, 4096]); mm_392 = None + add_210 = torch.ops.aten.add.Tensor(view_1378, view_1380); view_1378 = view_1380 = None + convert_element_type_1710 = torch.ops.prims.convert_element_type.default(mm_391, torch.float32); mm_391 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1710, 'avg', 256, '0'); convert_element_type_1710 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_1381 = torch.ops.aten.view.default(view_1376, [16384, 4096]); view_1376 = None + permute_737 = torch.ops.aten.permute.default(view_1381, [1, 0]) + mm_393 = torch.ops.aten.mm.default(permute_737, view_683); permute_737 = view_683 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16); primals_185 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 256, '0'); convert_element_type_664 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + permute_739 = torch.ops.aten.permute.default(permute_220, [1, 0]); permute_220 = None + mm_394 = torch.ops.aten.mm.default(view_1381, permute_739); view_1381 = permute_739 = None + view_1382 = torch.ops.aten.view.default(mm_394, [2, 8192, 4096]); mm_394 = None + add_211 = torch.ops.aten.add.Tensor(add_210, view_1382); add_210 = view_1382 = None + convert_element_type_1715 = torch.ops.prims.convert_element_type.default(mm_393, torch.float32); mm_393 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1715, 'avg', 256, '0'); convert_element_type_1715 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + convert_element_type_1716 = torch.ops.prims.convert_element_type.default(add_211, torch.float32); add_211 = None + convert_element_type_1718 = torch.ops.prims.convert_element_type.default(wait_tensor_181, torch.float32); wait_tensor_181 = None + mul_498 = torch.ops.aten.mul.Tensor(convert_element_type_1716, convert_element_type_1718); convert_element_type_1718 = None + mul_500 = torch.ops.aten.mul.Tensor(mul_160, mul_498) + sum_73 = torch.ops.aten.sum.dim_IntList(mul_500, [2], True); mul_500 = None + div_24 = torch.ops.aten.div.Tensor(mul_160, 4096) + mul_501 = torch.ops.aten.mul.Tensor(div_24, sum_73); div_24 = sum_73 = None + sub_36 = torch.ops.aten.sub.Tensor(mul_498, mul_501); mul_498 = mul_501 = None + mul_502 = torch.ops.aten.mul.Tensor(sub_36, rsqrt_40); sub_36 = rsqrt_40 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_1716, mul_160); convert_element_type_1716 = mul_160 = None + sum_74 = torch.ops.aten.sum.dim_IntList(mul_503, [0, 1]); mul_503 = None + convert_element_type_1719 = torch.ops.prims.convert_element_type.default(mul_502, torch.bfloat16); mul_502 = None + add_212 = torch.ops.aten.add.Tensor(add_209, convert_element_type_1719); add_209 = convert_element_type_1719 = None + convert_element_type_default_41 = torch.ops.prims.convert_element_type.default(sum_74, torch.float32); sum_74 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_41, 'avg', 256, '0'); convert_element_type_default_41 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 = None + view_1383 = torch.ops.aten.view.default(add_212, [16384, 4096]) + permute_741 = torch.ops.aten.permute.default(view_1383, [1, 0]) + permute_215 = torch.ops.aten.permute.default(getitem_171, [0, 2, 1, 3]) + view_667 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 256, '0'); convert_element_type_644 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_669 = torch.ops.aten.view.default(view_667, [16384, 4096]); view_667 = None + mm_136 = torch.ops.aten.mm.default(view_669, permute_216) + view_670 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + add_77 = torch.ops.aten.add.Tensor(add_75, view_670); view_670 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 256, '0'); convert_element_type_647 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32); add_77 = None + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_177) + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + view_673 = torch.ops.aten.view.default(convert_element_type_649, [16384, 4096]); convert_element_type_649 = None + view_674 = torch.ops.aten.view.default(mm_137, [2, 8192, 14336]); mm_137 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_674, torch.float32); view_674 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16); primals_182 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 256, '0'); convert_element_type_655 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_179, [1, 0]); wait_tensor_179 = None + mm_138 = torch.ops.aten.mm.default(view_673, permute_218) + view_677 = torch.ops.aten.view.default(mm_138, [2, 8192, 14336]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_677) + view_679 = torch.ops.aten.view.default(mul_159, [16384, 14336]); mul_159 = None + mm_395 = torch.ops.aten.mm.default(permute_741, view_679); permute_741 = view_679 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16); primals_183 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 256, '0'); convert_element_type_658 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_743 = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None + mm_396 = torch.ops.aten.mm.default(view_1383, permute_743); view_1383 = permute_743 = None + view_1384 = torch.ops.aten.view.default(mm_396, [2, 8192, 14336]); mm_396 = None + convert_element_type_1726 = torch.ops.prims.convert_element_type.default(mm_395, torch.float32); mm_395 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1726, 'avg', 256, '0'); convert_element_type_1726 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + mul_504 = torch.ops.aten.mul.Tensor(view_1384, convert_element_type_654); convert_element_type_654 = None + mul_505 = torch.ops.aten.mul.Tensor(view_1384, view_677); view_1384 = view_677 = None + view_1385 = torch.ops.aten.view.default(mul_504, [16384, 14336]); mul_504 = None + permute_745 = torch.ops.aten.permute.default(view_1385, [1, 0]) + mm_397 = torch.ops.aten.mm.default(permute_745, view_673); permute_745 = None + permute_747 = torch.ops.aten.permute.default(permute_218, [1, 0]); permute_218 = None + mm_398 = torch.ops.aten.mm.default(view_1385, permute_747); view_1385 = permute_747 = None + view_1386 = torch.ops.aten.view.default(mm_398, [2, 8192, 4096]); mm_398 = None + convert_element_type_1731 = torch.ops.prims.convert_element_type.default(mm_397, torch.float32); mm_397 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1731, 'avg', 256, '0'); convert_element_type_1731 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + convert_element_type_1732 = torch.ops.prims.convert_element_type.default(mul_505, torch.float32); mul_505 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_653) + exp_12 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_213 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_213); add_213 = None + mul_506 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_507 = torch.ops.aten.mul.Tensor(convert_element_type_1732, mul_506); convert_element_type_1732 = None + sub_37 = torch.ops.aten.sub.Tensor(1, mul_506); mul_506 = None + mul_508 = torch.ops.aten.mul.Tensor(convert_element_type_653, sub_37); convert_element_type_653 = sub_37 = None + add_214 = torch.ops.aten.add.Tensor(mul_508, 1); mul_508 = None + mul_509 = torch.ops.aten.mul.Tensor(mul_507, add_214); mul_507 = add_214 = None + convert_element_type_1734 = torch.ops.prims.convert_element_type.default(mul_509, torch.bfloat16); mul_509 = None + view_1387 = torch.ops.aten.view.default(convert_element_type_1734, [16384, 14336]); convert_element_type_1734 = None + permute_749 = torch.ops.aten.permute.default(view_1387, [1, 0]) + mm_399 = torch.ops.aten.mm.default(permute_749, view_673); permute_749 = view_673 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16); primals_181 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 256, '0'); convert_element_type_650 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + permute_751 = torch.ops.aten.permute.default(permute_217, [1, 0]); permute_217 = None + mm_400 = torch.ops.aten.mm.default(view_1387, permute_751); view_1387 = permute_751 = None + view_1388 = torch.ops.aten.view.default(mm_400, [2, 8192, 4096]); mm_400 = None + add_215 = torch.ops.aten.add.Tensor(view_1386, view_1388); view_1386 = view_1388 = None + convert_element_type_1739 = torch.ops.prims.convert_element_type.default(mm_399, torch.float32); mm_399 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1739, 'avg', 256, '0'); convert_element_type_1739 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + convert_element_type_1740 = torch.ops.prims.convert_element_type.default(add_215, torch.float32); add_215 = None + convert_element_type_1742 = torch.ops.prims.convert_element_type.default(wait_tensor_177, torch.float32); wait_tensor_177 = None + mul_510 = torch.ops.aten.mul.Tensor(convert_element_type_1740, convert_element_type_1742); convert_element_type_1742 = None + mul_512 = torch.ops.aten.mul.Tensor(mul_156, mul_510) + sum_75 = torch.ops.aten.sum.dim_IntList(mul_512, [2], True); mul_512 = None + div_25 = torch.ops.aten.div.Tensor(mul_156, 4096) + mul_513 = torch.ops.aten.mul.Tensor(div_25, sum_75); div_25 = sum_75 = None + sub_38 = torch.ops.aten.sub.Tensor(mul_510, mul_513); mul_510 = mul_513 = None + mul_514 = torch.ops.aten.mul.Tensor(sub_38, rsqrt_39); sub_38 = rsqrt_39 = None + mul_515 = torch.ops.aten.mul.Tensor(convert_element_type_1740, mul_156); convert_element_type_1740 = mul_156 = None + sum_76 = torch.ops.aten.sum.dim_IntList(mul_515, [0, 1]); mul_515 = None + convert_element_type_1743 = torch.ops.prims.convert_element_type.default(mul_514, torch.bfloat16); mul_514 = None + add_216 = torch.ops.aten.add.Tensor(add_212, convert_element_type_1743); add_212 = convert_element_type_1743 = None + convert_element_type_default_40 = torch.ops.prims.convert_element_type.default(sum_76, torch.float32); sum_76 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_40, 'avg', 256, '0'); convert_element_type_default_40 = None + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + view_1389 = torch.ops.aten.view.default(add_216, [16384, 4096]) + permute_753 = torch.ops.aten.permute.default(view_1389, [1, 0]) + mm_401 = torch.ops.aten.mm.default(permute_753, view_669); permute_753 = view_669 = None + permute_755 = torch.ops.aten.permute.default(permute_216, [1, 0]); permute_216 = None + mm_402 = torch.ops.aten.mm.default(view_1389, permute_755); view_1389 = permute_755 = None + view_1390 = torch.ops.aten.view.default(mm_402, [2, 8192, 4096]); mm_402 = None + convert_element_type_1750 = torch.ops.prims.convert_element_type.default(mm_401, torch.float32); mm_401 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1750, 'avg', 256, '0'); convert_element_type_1750 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + view_1391 = torch.ops.aten.view.default(view_1390, [2, 8192, 32, 128]); view_1390 = None + permute_757 = torch.ops.aten.permute.default(view_1391, [0, 2, 1, 3]); view_1391 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 256, '0'); convert_element_type_628 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32); add_75 = None + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_172) + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + view_649 = torch.ops.aten.view.default(convert_element_type_630, [16384, 4096]); convert_element_type_630 = None + view_650 = torch.ops.aten.view.default(mm_133, [2, 8192, 4096]); mm_133 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16); primals_177 = None + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 256, '0'); convert_element_type_634 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_134 = torch.ops.aten.mm.default(view_649, permute_210) + view_653 = torch.ops.aten.view.default(mm_134, [2, 8192, 1024]); mm_134 = None + view_656 = torch.ops.aten.view.default(mm_135, [2, 8192, 1024]); mm_135 = None + view_657 = torch.ops.aten.view.default(view_650, [2, 8192, -1, 128]); view_650 = None + view_658 = torch.ops.aten.view.default(view_653, [2, 8192, -1, 128]); view_653 = None + view_659 = torch.ops.aten.view.default(view_656, [2, 8192, -1, 128]); view_656 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_657, torch.float32); view_657 = None + view_660 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 32, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_660); view_660 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_658, torch.float32); view_658 = None + view_661 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 8, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_661); view_661 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_16); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_663 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 32, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_16); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_664 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 8, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_663, torch.bfloat16); view_663 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_664, torch.bfloat16); view_664 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 8, 4, 128]); unsqueeze_38 = None + clone_38 = torch.ops.aten.clone.default(expand_38, memory_format = torch.contiguous_format); expand_38 = None + view_665 = torch.ops.aten.view.default(clone_38, [2, 8192, 32, 128]); clone_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_659, 3); view_659 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 8, 4, 128]); unsqueeze_39 = None + clone_39 = torch.ops.aten.clone.default(expand_39, memory_format = torch.contiguous_format); expand_39 = None + view_666 = torch.ops.aten.view.default(clone_39, [2, 8192, 32, 128]); clone_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_665, [0, 2, 1, 3]); view_665 = None + permute_214 = torch.ops.aten.permute.default(view_666, [0, 2, 1, 3]); view_666 = None + _scaled_dot_product_cudnn_attention_backward_12 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_757, permute_212, permute_213, permute_214, getitem_171, getitem_172, getitem_177, getitem_178, None, None, None, 8192, 8192, 0.0, True); permute_757 = permute_212 = permute_213 = permute_214 = getitem_171 = getitem_172 = getitem_177 = getitem_178 = None + getitem_324 = _scaled_dot_product_cudnn_attention_backward_12[0] + getitem_325 = _scaled_dot_product_cudnn_attention_backward_12[1] + getitem_326 = _scaled_dot_product_cudnn_attention_backward_12[2]; _scaled_dot_product_cudnn_attention_backward_12 = None + permute_758 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]); getitem_326 = None + permute_759 = torch.ops.aten.permute.default(getitem_325, [0, 2, 1, 3]); getitem_325 = None + permute_760 = torch.ops.aten.permute.default(getitem_324, [0, 2, 1, 3]); getitem_324 = None + view_1392 = torch.ops.aten.view.default(permute_758, [2, 8192, 8, 4, 128]); permute_758 = None + sum_77 = torch.ops.aten.sum.dim_IntList(view_1392, [3], True); view_1392 = None + squeeze_24 = torch.ops.aten.squeeze.dim(sum_77, 3); sum_77 = None + view_1393 = torch.ops.aten.view.default(permute_759, [2, 8192, 8, 4, 128]); permute_759 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_1393, [3], True); view_1393 = None + squeeze_25 = torch.ops.aten.squeeze.dim(sum_78, 3); sum_78 = None + convert_element_type_1751 = torch.ops.prims.convert_element_type.default(squeeze_25, torch.float32); squeeze_25 = None + convert_element_type_1752 = torch.ops.prims.convert_element_type.default(permute_760, torch.float32); permute_760 = None + view_1394 = torch.ops.aten.view.default(convert_element_type_1751, [2, 8192, 8, 64, 2]); convert_element_type_1751 = None + view_as_complex_88 = torch.ops.aten.view_as_complex.default(view_1394); view_1394 = None + mul_516 = torch.ops.aten.mul.Tensor(view_as_complex_88, _conj); view_as_complex_88 = None + view_1395 = torch.ops.aten.view.default(convert_element_type_1752, [2, 8192, 32, 64, 2]); convert_element_type_1752 = None + view_as_complex_89 = torch.ops.aten.view_as_complex.default(view_1395); view_1395 = None + mul_517 = torch.ops.aten.mul.Tensor(view_as_complex_89, _conj); view_as_complex_89 = None + view_as_real_88 = torch.ops.aten.view_as_real.default(mul_516); mul_516 = None + view_1396 = torch.ops.aten.view.default(view_as_real_88, [2, 8192, 8, 128]); view_as_real_88 = None + convert_element_type_1753 = torch.ops.prims.convert_element_type.default(view_1396, torch.bfloat16); view_1396 = None + view_as_real_89 = torch.ops.aten.view_as_real.default(mul_517); mul_517 = None + view_1397 = torch.ops.aten.view.default(view_as_real_89, [2, 8192, 32, 128]); view_as_real_89 = None + convert_element_type_1754 = torch.ops.prims.convert_element_type.default(view_1397, torch.bfloat16); view_1397 = None + view_1398 = torch.ops.aten.view.default(squeeze_24, [2, 8192, 1024]); squeeze_24 = None + view_1399 = torch.ops.aten.view.default(convert_element_type_1753, [2, 8192, 1024]); convert_element_type_1753 = None + view_1400 = torch.ops.aten.view.default(convert_element_type_1754, [2, 8192, 4096]); convert_element_type_1754 = None + view_1401 = torch.ops.aten.view.default(view_1398, [16384, 1024]); view_1398 = None + permute_761 = torch.ops.aten.permute.default(view_1401, [1, 0]) + mm_403 = torch.ops.aten.mm.default(permute_761, view_649); permute_761 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16); primals_178 = None + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 256, '0'); convert_element_type_637 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + permute_763 = torch.ops.aten.permute.default(permute_211, [1, 0]); permute_211 = None + mm_404 = torch.ops.aten.mm.default(view_1401, permute_763); view_1401 = permute_763 = None + view_1402 = torch.ops.aten.view.default(mm_404, [2, 8192, 4096]); mm_404 = None + convert_element_type_1759 = torch.ops.prims.convert_element_type.default(mm_403, torch.float32); mm_403 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1759, 'avg', 256, '0'); convert_element_type_1759 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + view_1403 = torch.ops.aten.view.default(view_1399, [16384, 1024]); view_1399 = None + permute_765 = torch.ops.aten.permute.default(view_1403, [1, 0]) + mm_405 = torch.ops.aten.mm.default(permute_765, view_649); permute_765 = None + permute_767 = torch.ops.aten.permute.default(permute_210, [1, 0]); permute_210 = None + mm_406 = torch.ops.aten.mm.default(view_1403, permute_767); view_1403 = permute_767 = None + view_1404 = torch.ops.aten.view.default(mm_406, [2, 8192, 4096]); mm_406 = None + add_217 = torch.ops.aten.add.Tensor(view_1402, view_1404); view_1402 = view_1404 = None + convert_element_type_1764 = torch.ops.prims.convert_element_type.default(mm_405, torch.float32); mm_405 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1764, 'avg', 256, '0'); convert_element_type_1764 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + view_1405 = torch.ops.aten.view.default(view_1400, [16384, 4096]); view_1400 = None + permute_769 = torch.ops.aten.permute.default(view_1405, [1, 0]) + mm_407 = torch.ops.aten.mm.default(permute_769, view_649); permute_769 = view_649 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16); primals_176 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 256, '0'); convert_element_type_631 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + permute_771 = torch.ops.aten.permute.default(permute_209, [1, 0]); permute_209 = None + mm_408 = torch.ops.aten.mm.default(view_1405, permute_771); view_1405 = permute_771 = None + view_1406 = torch.ops.aten.view.default(mm_408, [2, 8192, 4096]); mm_408 = None + add_218 = torch.ops.aten.add.Tensor(add_217, view_1406); add_217 = view_1406 = None + convert_element_type_1769 = torch.ops.prims.convert_element_type.default(mm_407, torch.float32); mm_407 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1769, 'avg', 256, '0'); convert_element_type_1769 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + convert_element_type_1770 = torch.ops.prims.convert_element_type.default(add_218, torch.float32); add_218 = None + convert_element_type_1772 = torch.ops.prims.convert_element_type.default(wait_tensor_172, torch.float32); wait_tensor_172 = None + mul_518 = torch.ops.aten.mul.Tensor(convert_element_type_1770, convert_element_type_1772); convert_element_type_1772 = None + mul_520 = torch.ops.aten.mul.Tensor(mul_152, mul_518) + sum_79 = torch.ops.aten.sum.dim_IntList(mul_520, [2], True); mul_520 = None + div_26 = torch.ops.aten.div.Tensor(mul_152, 4096) + mul_521 = torch.ops.aten.mul.Tensor(div_26, sum_79); div_26 = sum_79 = None + sub_39 = torch.ops.aten.sub.Tensor(mul_518, mul_521); mul_518 = mul_521 = None + mul_522 = torch.ops.aten.mul.Tensor(sub_39, rsqrt_38); sub_39 = rsqrt_38 = None + mul_523 = torch.ops.aten.mul.Tensor(convert_element_type_1770, mul_152); convert_element_type_1770 = mul_152 = None + sum_80 = torch.ops.aten.sum.dim_IntList(mul_523, [0, 1]); mul_523 = None + convert_element_type_1773 = torch.ops.prims.convert_element_type.default(mul_522, torch.bfloat16); mul_522 = None + add_219 = torch.ops.aten.add.Tensor(add_216, convert_element_type_1773); add_216 = convert_element_type_1773 = None + convert_element_type_default_39 = torch.ops.prims.convert_element_type.default(sum_80, torch.float32); sum_80 = None + reduce_scatter_tensor_118 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_39, 'avg', 256, '0'); convert_element_type_default_39 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + view_1407 = torch.ops.aten.view.default(add_219, [16384, 4096]) + permute_773 = torch.ops.aten.permute.default(view_1407, [1, 0]) + permute_204 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_633 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16); primals_170 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 256, '0'); convert_element_type_611 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_635 = torch.ops.aten.view.default(view_633, [16384, 4096]); view_633 = None + mm_129 = torch.ops.aten.mm.default(view_635, permute_205) + view_636 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + add_73 = torch.ops.aten.add.Tensor(add_71, view_636); view_636 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16); primals_171 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 256, '0'); convert_element_type_614 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32); add_73 = None + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_168) + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + view_639 = torch.ops.aten.view.default(convert_element_type_616, [16384, 4096]); convert_element_type_616 = None + view_640 = torch.ops.aten.view.default(mm_130, [2, 8192, 14336]); mm_130 = None + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_640, torch.float32); view_640 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 256, '0'); convert_element_type_622 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_170, [1, 0]); wait_tensor_170 = None + mm_131 = torch.ops.aten.mm.default(view_639, permute_207) + view_643 = torch.ops.aten.view.default(mm_131, [2, 8192, 14336]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_643) + view_645 = torch.ops.aten.view.default(mul_151, [16384, 14336]); mul_151 = None + mm_409 = torch.ops.aten.mm.default(permute_773, view_645); permute_773 = view_645 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16); primals_174 = None + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 256, '0'); convert_element_type_625 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_171, [1, 0]); wait_tensor_171 = None + permute_775 = torch.ops.aten.permute.default(permute_208, [1, 0]); permute_208 = None + mm_410 = torch.ops.aten.mm.default(view_1407, permute_775); view_1407 = permute_775 = None + view_1408 = torch.ops.aten.view.default(mm_410, [2, 8192, 14336]); mm_410 = None + convert_element_type_1780 = torch.ops.prims.convert_element_type.default(mm_409, torch.float32); mm_409 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1780, 'avg', 256, '0'); convert_element_type_1780 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + mul_524 = torch.ops.aten.mul.Tensor(view_1408, convert_element_type_621); convert_element_type_621 = None + mul_525 = torch.ops.aten.mul.Tensor(view_1408, view_643); view_1408 = view_643 = None + view_1409 = torch.ops.aten.view.default(mul_524, [16384, 14336]); mul_524 = None + permute_777 = torch.ops.aten.permute.default(view_1409, [1, 0]) + mm_411 = torch.ops.aten.mm.default(permute_777, view_639); permute_777 = None + permute_779 = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None + mm_412 = torch.ops.aten.mm.default(view_1409, permute_779); view_1409 = permute_779 = None + view_1410 = torch.ops.aten.view.default(mm_412, [2, 8192, 4096]); mm_412 = None + convert_element_type_1785 = torch.ops.prims.convert_element_type.default(mm_411, torch.float32); mm_411 = None + reduce_scatter_tensor_120 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1785, 'avg', 256, '0'); convert_element_type_1785 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + convert_element_type_1786 = torch.ops.prims.convert_element_type.default(mul_525, torch.float32); mul_525 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_620) + exp_13 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_220 = torch.ops.aten.add.Tensor(exp_13, 1); exp_13 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_220); add_220 = None + mul_526 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_527 = torch.ops.aten.mul.Tensor(convert_element_type_1786, mul_526); convert_element_type_1786 = None + sub_40 = torch.ops.aten.sub.Tensor(1, mul_526); mul_526 = None + mul_528 = torch.ops.aten.mul.Tensor(convert_element_type_620, sub_40); convert_element_type_620 = sub_40 = None + add_221 = torch.ops.aten.add.Tensor(mul_528, 1); mul_528 = None + mul_529 = torch.ops.aten.mul.Tensor(mul_527, add_221); mul_527 = add_221 = None + convert_element_type_1788 = torch.ops.prims.convert_element_type.default(mul_529, torch.bfloat16); mul_529 = None + view_1411 = torch.ops.aten.view.default(convert_element_type_1788, [16384, 14336]); convert_element_type_1788 = None + permute_781 = torch.ops.aten.permute.default(view_1411, [1, 0]) + mm_413 = torch.ops.aten.mm.default(permute_781, view_639); permute_781 = view_639 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 256, '0'); convert_element_type_617 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + permute_783 = torch.ops.aten.permute.default(permute_206, [1, 0]); permute_206 = None + mm_414 = torch.ops.aten.mm.default(view_1411, permute_783); view_1411 = permute_783 = None + view_1412 = torch.ops.aten.view.default(mm_414, [2, 8192, 4096]); mm_414 = None + add_222 = torch.ops.aten.add.Tensor(view_1410, view_1412); view_1410 = view_1412 = None + convert_element_type_1793 = torch.ops.prims.convert_element_type.default(mm_413, torch.float32); mm_413 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1793, 'avg', 256, '0'); convert_element_type_1793 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + convert_element_type_1794 = torch.ops.prims.convert_element_type.default(add_222, torch.float32); add_222 = None + convert_element_type_1796 = torch.ops.prims.convert_element_type.default(wait_tensor_168, torch.float32); wait_tensor_168 = None + mul_530 = torch.ops.aten.mul.Tensor(convert_element_type_1794, convert_element_type_1796); convert_element_type_1796 = None + mul_532 = torch.ops.aten.mul.Tensor(mul_148, mul_530) + sum_81 = torch.ops.aten.sum.dim_IntList(mul_532, [2], True); mul_532 = None + div_27 = torch.ops.aten.div.Tensor(mul_148, 4096) + mul_533 = torch.ops.aten.mul.Tensor(div_27, sum_81); div_27 = sum_81 = None + sub_41 = torch.ops.aten.sub.Tensor(mul_530, mul_533); mul_530 = mul_533 = None + mul_534 = torch.ops.aten.mul.Tensor(sub_41, rsqrt_37); sub_41 = rsqrt_37 = None + mul_535 = torch.ops.aten.mul.Tensor(convert_element_type_1794, mul_148); convert_element_type_1794 = mul_148 = None + sum_82 = torch.ops.aten.sum.dim_IntList(mul_535, [0, 1]); mul_535 = None + convert_element_type_1797 = torch.ops.prims.convert_element_type.default(mul_534, torch.bfloat16); mul_534 = None + add_223 = torch.ops.aten.add.Tensor(add_219, convert_element_type_1797); add_219 = convert_element_type_1797 = None + convert_element_type_default_38 = torch.ops.prims.convert_element_type.default(sum_82, torch.float32); sum_82 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_38, 'avg', 256, '0'); convert_element_type_default_38 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + view_1413 = torch.ops.aten.view.default(add_223, [16384, 4096]) + permute_785 = torch.ops.aten.permute.default(view_1413, [1, 0]) + mm_415 = torch.ops.aten.mm.default(permute_785, view_635); permute_785 = view_635 = None + permute_787 = torch.ops.aten.permute.default(permute_205, [1, 0]); permute_205 = None + mm_416 = torch.ops.aten.mm.default(view_1413, permute_787); view_1413 = permute_787 = None + view_1414 = torch.ops.aten.view.default(mm_416, [2, 8192, 4096]); mm_416 = None + convert_element_type_1804 = torch.ops.prims.convert_element_type.default(mm_415, torch.float32); mm_415 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1804, 'avg', 256, '0'); convert_element_type_1804 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + view_1415 = torch.ops.aten.view.default(view_1414, [2, 8192, 32, 128]); view_1414 = None + permute_789 = torch.ops.aten.permute.default(view_1415, [0, 2, 1, 3]); view_1415 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16); primals_166 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 256, '0'); convert_element_type_595 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32); add_71 = None + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_163) + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + view_615 = torch.ops.aten.view.default(convert_element_type_597, [16384, 4096]); convert_element_type_597 = None + view_616 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]); mm_126 = None + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16); primals_168 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 256, '0'); convert_element_type_601 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + mm_127 = torch.ops.aten.mm.default(view_615, permute_199) + view_619 = torch.ops.aten.view.default(mm_127, [2, 8192, 1024]); mm_127 = None + view_622 = torch.ops.aten.view.default(mm_128, [2, 8192, 1024]); mm_128 = None + view_623 = torch.ops.aten.view.default(view_616, [2, 8192, -1, 128]); view_616 = None + view_624 = torch.ops.aten.view.default(view_619, [2, 8192, -1, 128]); view_619 = None + view_625 = torch.ops.aten.view.default(view_622, [2, 8192, -1, 128]); view_622 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_623, torch.float32); view_623 = None + view_626 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 32, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_626); view_626 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_624, torch.float32); view_624 = None + view_627 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 8, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_627); view_627 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_16); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_629 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 32, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_16); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_630 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 8, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_629, torch.bfloat16); view_629 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_630, torch.bfloat16); view_630 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 8, 4, 128]); unsqueeze_36 = None + clone_36 = torch.ops.aten.clone.default(expand_36, memory_format = torch.contiguous_format); expand_36 = None + view_631 = torch.ops.aten.view.default(clone_36, [2, 8192, 32, 128]); clone_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_625, 3); view_625 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 8, 4, 128]); unsqueeze_37 = None + clone_37 = torch.ops.aten.clone.default(expand_37, memory_format = torch.contiguous_format); expand_37 = None + view_632 = torch.ops.aten.view.default(clone_37, [2, 8192, 32, 128]); clone_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_631, [0, 2, 1, 3]); view_631 = None + permute_203 = torch.ops.aten.permute.default(view_632, [0, 2, 1, 3]); view_632 = None + _scaled_dot_product_cudnn_attention_backward_13 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_789, permute_201, permute_202, permute_203, getitem_162, getitem_163, getitem_168, getitem_169, None, None, None, 8192, 8192, 0.0, True); permute_789 = permute_201 = permute_202 = permute_203 = getitem_162 = getitem_163 = getitem_168 = getitem_169 = None + getitem_327 = _scaled_dot_product_cudnn_attention_backward_13[0] + getitem_328 = _scaled_dot_product_cudnn_attention_backward_13[1] + getitem_329 = _scaled_dot_product_cudnn_attention_backward_13[2]; _scaled_dot_product_cudnn_attention_backward_13 = None + permute_790 = torch.ops.aten.permute.default(getitem_329, [0, 2, 1, 3]); getitem_329 = None + permute_791 = torch.ops.aten.permute.default(getitem_328, [0, 2, 1, 3]); getitem_328 = None + permute_792 = torch.ops.aten.permute.default(getitem_327, [0, 2, 1, 3]); getitem_327 = None + view_1416 = torch.ops.aten.view.default(permute_790, [2, 8192, 8, 4, 128]); permute_790 = None + sum_83 = torch.ops.aten.sum.dim_IntList(view_1416, [3], True); view_1416 = None + squeeze_26 = torch.ops.aten.squeeze.dim(sum_83, 3); sum_83 = None + view_1417 = torch.ops.aten.view.default(permute_791, [2, 8192, 8, 4, 128]); permute_791 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_1417, [3], True); view_1417 = None + squeeze_27 = torch.ops.aten.squeeze.dim(sum_84, 3); sum_84 = None + convert_element_type_1805 = torch.ops.prims.convert_element_type.default(squeeze_27, torch.float32); squeeze_27 = None + convert_element_type_1806 = torch.ops.prims.convert_element_type.default(permute_792, torch.float32); permute_792 = None + view_1418 = torch.ops.aten.view.default(convert_element_type_1805, [2, 8192, 8, 64, 2]); convert_element_type_1805 = None + view_as_complex_90 = torch.ops.aten.view_as_complex.default(view_1418); view_1418 = None + mul_536 = torch.ops.aten.mul.Tensor(view_as_complex_90, _conj); view_as_complex_90 = None + view_1419 = torch.ops.aten.view.default(convert_element_type_1806, [2, 8192, 32, 64, 2]); convert_element_type_1806 = None + view_as_complex_91 = torch.ops.aten.view_as_complex.default(view_1419); view_1419 = None + mul_537 = torch.ops.aten.mul.Tensor(view_as_complex_91, _conj); view_as_complex_91 = None + view_as_real_90 = torch.ops.aten.view_as_real.default(mul_536); mul_536 = None + view_1420 = torch.ops.aten.view.default(view_as_real_90, [2, 8192, 8, 128]); view_as_real_90 = None + convert_element_type_1807 = torch.ops.prims.convert_element_type.default(view_1420, torch.bfloat16); view_1420 = None + view_as_real_91 = torch.ops.aten.view_as_real.default(mul_537); mul_537 = None + view_1421 = torch.ops.aten.view.default(view_as_real_91, [2, 8192, 32, 128]); view_as_real_91 = None + convert_element_type_1808 = torch.ops.prims.convert_element_type.default(view_1421, torch.bfloat16); view_1421 = None + view_1422 = torch.ops.aten.view.default(squeeze_26, [2, 8192, 1024]); squeeze_26 = None + view_1423 = torch.ops.aten.view.default(convert_element_type_1807, [2, 8192, 1024]); convert_element_type_1807 = None + view_1424 = torch.ops.aten.view.default(convert_element_type_1808, [2, 8192, 4096]); convert_element_type_1808 = None + view_1425 = torch.ops.aten.view.default(view_1422, [16384, 1024]); view_1422 = None + permute_793 = torch.ops.aten.permute.default(view_1425, [1, 0]) + mm_417 = torch.ops.aten.mm.default(permute_793, view_615); permute_793 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16); primals_169 = None + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 256, '0'); convert_element_type_604 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_166, [1, 0]); wait_tensor_166 = None + permute_795 = torch.ops.aten.permute.default(permute_200, [1, 0]); permute_200 = None + mm_418 = torch.ops.aten.mm.default(view_1425, permute_795); view_1425 = permute_795 = None + view_1426 = torch.ops.aten.view.default(mm_418, [2, 8192, 4096]); mm_418 = None + convert_element_type_1813 = torch.ops.prims.convert_element_type.default(mm_417, torch.float32); mm_417 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1813, 'avg', 256, '0'); convert_element_type_1813 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + view_1427 = torch.ops.aten.view.default(view_1423, [16384, 1024]); view_1423 = None + permute_797 = torch.ops.aten.permute.default(view_1427, [1, 0]) + mm_419 = torch.ops.aten.mm.default(permute_797, view_615); permute_797 = None + permute_799 = torch.ops.aten.permute.default(permute_199, [1, 0]); permute_199 = None + mm_420 = torch.ops.aten.mm.default(view_1427, permute_799); view_1427 = permute_799 = None + view_1428 = torch.ops.aten.view.default(mm_420, [2, 8192, 4096]); mm_420 = None + add_224 = torch.ops.aten.add.Tensor(view_1426, view_1428); view_1426 = view_1428 = None + convert_element_type_1818 = torch.ops.prims.convert_element_type.default(mm_419, torch.float32); mm_419 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1818, 'avg', 256, '0'); convert_element_type_1818 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + view_1429 = torch.ops.aten.view.default(view_1424, [16384, 4096]); view_1424 = None + permute_801 = torch.ops.aten.permute.default(view_1429, [1, 0]) + mm_421 = torch.ops.aten.mm.default(permute_801, view_615); permute_801 = view_615 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16); primals_167 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 256, '0'); convert_element_type_598 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_164, [1, 0]); wait_tensor_164 = None + permute_803 = torch.ops.aten.permute.default(permute_198, [1, 0]); permute_198 = None + mm_422 = torch.ops.aten.mm.default(view_1429, permute_803); view_1429 = permute_803 = None + view_1430 = torch.ops.aten.view.default(mm_422, [2, 8192, 4096]); mm_422 = None + add_225 = torch.ops.aten.add.Tensor(add_224, view_1430); add_224 = view_1430 = None + convert_element_type_1823 = torch.ops.prims.convert_element_type.default(mm_421, torch.float32); mm_421 = None + reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1823, 'avg', 256, '0'); convert_element_type_1823 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + convert_element_type_1824 = torch.ops.prims.convert_element_type.default(add_225, torch.float32); add_225 = None + convert_element_type_1826 = torch.ops.prims.convert_element_type.default(wait_tensor_163, torch.float32); wait_tensor_163 = None + mul_538 = torch.ops.aten.mul.Tensor(convert_element_type_1824, convert_element_type_1826); convert_element_type_1826 = None + mul_540 = torch.ops.aten.mul.Tensor(mul_144, mul_538) + sum_85 = torch.ops.aten.sum.dim_IntList(mul_540, [2], True); mul_540 = None + div_28 = torch.ops.aten.div.Tensor(mul_144, 4096) + mul_541 = torch.ops.aten.mul.Tensor(div_28, sum_85); div_28 = sum_85 = None + sub_42 = torch.ops.aten.sub.Tensor(mul_538, mul_541); mul_538 = mul_541 = None + mul_542 = torch.ops.aten.mul.Tensor(sub_42, rsqrt_36); sub_42 = rsqrt_36 = None + mul_543 = torch.ops.aten.mul.Tensor(convert_element_type_1824, mul_144); convert_element_type_1824 = mul_144 = None + sum_86 = torch.ops.aten.sum.dim_IntList(mul_543, [0, 1]); mul_543 = None + convert_element_type_1827 = torch.ops.prims.convert_element_type.default(mul_542, torch.bfloat16); mul_542 = None + add_226 = torch.ops.aten.add.Tensor(add_223, convert_element_type_1827); add_223 = convert_element_type_1827 = None + convert_element_type_default_37 = torch.ops.prims.convert_element_type.default(sum_86, torch.float32); sum_86 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_37, 'avg', 256, '0'); convert_element_type_default_37 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); reduce_scatter_tensor_127 = None + view_1431 = torch.ops.aten.view.default(add_226, [16384, 4096]) + permute_805 = torch.ops.aten.permute.default(view_1431, [1, 0]) + permute_193 = torch.ops.aten.permute.default(getitem_153, [0, 2, 1, 3]) + view_599 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16); primals_161 = None + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 256, '0'); convert_element_type_578 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_158, [1, 0]); wait_tensor_158 = None + view_601 = torch.ops.aten.view.default(view_599, [16384, 4096]); view_599 = None + mm_122 = torch.ops.aten.mm.default(view_601, permute_194) + view_602 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + add_69 = torch.ops.aten.add.Tensor(add_67, view_602); view_602 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16); primals_162 = None + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 256, '0'); convert_element_type_581 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32); add_69 = None + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_159) + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + view_605 = torch.ops.aten.view.default(convert_element_type_583, [16384, 4096]); convert_element_type_583 = None + view_606 = torch.ops.aten.view.default(mm_123, [2, 8192, 14336]); mm_123 = None + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_606, torch.float32); view_606 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 256, '0'); convert_element_type_589 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_124 = torch.ops.aten.mm.default(view_605, permute_196) + view_609 = torch.ops.aten.view.default(mm_124, [2, 8192, 14336]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_609) + view_611 = torch.ops.aten.view.default(mul_143, [16384, 14336]); mul_143 = None + mm_423 = torch.ops.aten.mm.default(permute_805, view_611); permute_805 = view_611 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16); primals_165 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 256, '0'); convert_element_type_592 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_807 = torch.ops.aten.permute.default(permute_197, [1, 0]); permute_197 = None + mm_424 = torch.ops.aten.mm.default(view_1431, permute_807); view_1431 = permute_807 = None + view_1432 = torch.ops.aten.view.default(mm_424, [2, 8192, 14336]); mm_424 = None + convert_element_type_1834 = torch.ops.prims.convert_element_type.default(mm_423, torch.float32); mm_423 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1834, 'avg', 256, '0'); convert_element_type_1834 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + mul_544 = torch.ops.aten.mul.Tensor(view_1432, convert_element_type_588); convert_element_type_588 = None + mul_545 = torch.ops.aten.mul.Tensor(view_1432, view_609); view_1432 = view_609 = None + view_1433 = torch.ops.aten.view.default(mul_544, [16384, 14336]); mul_544 = None + permute_809 = torch.ops.aten.permute.default(view_1433, [1, 0]) + mm_425 = torch.ops.aten.mm.default(permute_809, view_605); permute_809 = None + permute_811 = torch.ops.aten.permute.default(permute_196, [1, 0]); permute_196 = None + mm_426 = torch.ops.aten.mm.default(view_1433, permute_811); view_1433 = permute_811 = None + view_1434 = torch.ops.aten.view.default(mm_426, [2, 8192, 4096]); mm_426 = None + convert_element_type_1839 = torch.ops.prims.convert_element_type.default(mm_425, torch.float32); mm_425 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1839, 'avg', 256, '0'); convert_element_type_1839 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + convert_element_type_1840 = torch.ops.prims.convert_element_type.default(mul_545, torch.float32); mul_545 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_587) + exp_14 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_227 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_227); add_227 = None + mul_546 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_547 = torch.ops.aten.mul.Tensor(convert_element_type_1840, mul_546); convert_element_type_1840 = None + sub_43 = torch.ops.aten.sub.Tensor(1, mul_546); mul_546 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_587, sub_43); convert_element_type_587 = sub_43 = None + add_228 = torch.ops.aten.add.Tensor(mul_548, 1); mul_548 = None + mul_549 = torch.ops.aten.mul.Tensor(mul_547, add_228); mul_547 = add_228 = None + convert_element_type_1842 = torch.ops.prims.convert_element_type.default(mul_549, torch.bfloat16); mul_549 = None + view_1435 = torch.ops.aten.view.default(convert_element_type_1842, [16384, 14336]); convert_element_type_1842 = None + permute_813 = torch.ops.aten.permute.default(view_1435, [1, 0]) + mm_427 = torch.ops.aten.mm.default(permute_813, view_605); permute_813 = view_605 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16); primals_163 = None + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 256, '0'); convert_element_type_584 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + permute_815 = torch.ops.aten.permute.default(permute_195, [1, 0]); permute_195 = None + mm_428 = torch.ops.aten.mm.default(view_1435, permute_815); view_1435 = permute_815 = None + view_1436 = torch.ops.aten.view.default(mm_428, [2, 8192, 4096]); mm_428 = None + add_229 = torch.ops.aten.add.Tensor(view_1434, view_1436); view_1434 = view_1436 = None + convert_element_type_1847 = torch.ops.prims.convert_element_type.default(mm_427, torch.float32); mm_427 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1847, 'avg', 256, '0'); convert_element_type_1847 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + convert_element_type_1848 = torch.ops.prims.convert_element_type.default(add_229, torch.float32); add_229 = None + convert_element_type_1850 = torch.ops.prims.convert_element_type.default(wait_tensor_159, torch.float32); wait_tensor_159 = None + mul_550 = torch.ops.aten.mul.Tensor(convert_element_type_1848, convert_element_type_1850); convert_element_type_1850 = None + mul_552 = torch.ops.aten.mul.Tensor(mul_140, mul_550) + sum_87 = torch.ops.aten.sum.dim_IntList(mul_552, [2], True); mul_552 = None + div_29 = torch.ops.aten.div.Tensor(mul_140, 4096) + mul_553 = torch.ops.aten.mul.Tensor(div_29, sum_87); div_29 = sum_87 = None + sub_44 = torch.ops.aten.sub.Tensor(mul_550, mul_553); mul_550 = mul_553 = None + mul_554 = torch.ops.aten.mul.Tensor(sub_44, rsqrt_35); sub_44 = rsqrt_35 = None + mul_555 = torch.ops.aten.mul.Tensor(convert_element_type_1848, mul_140); convert_element_type_1848 = mul_140 = None + sum_88 = torch.ops.aten.sum.dim_IntList(mul_555, [0, 1]); mul_555 = None + convert_element_type_1851 = torch.ops.prims.convert_element_type.default(mul_554, torch.bfloat16); mul_554 = None + add_230 = torch.ops.aten.add.Tensor(add_226, convert_element_type_1851); add_226 = convert_element_type_1851 = None + convert_element_type_default_36 = torch.ops.prims.convert_element_type.default(sum_88, torch.float32); sum_88 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_36, 'avg', 256, '0'); convert_element_type_default_36 = None + wait_tensor_422 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + view_1437 = torch.ops.aten.view.default(add_230, [16384, 4096]) + permute_817 = torch.ops.aten.permute.default(view_1437, [1, 0]) + mm_429 = torch.ops.aten.mm.default(permute_817, view_601); permute_817 = view_601 = None + permute_819 = torch.ops.aten.permute.default(permute_194, [1, 0]); permute_194 = None + mm_430 = torch.ops.aten.mm.default(view_1437, permute_819); view_1437 = permute_819 = None + view_1438 = torch.ops.aten.view.default(mm_430, [2, 8192, 4096]); mm_430 = None + convert_element_type_1858 = torch.ops.prims.convert_element_type.default(mm_429, torch.float32); mm_429 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1858, 'avg', 256, '0'); convert_element_type_1858 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + view_1439 = torch.ops.aten.view.default(view_1438, [2, 8192, 32, 128]); view_1438 = None + permute_821 = torch.ops.aten.permute.default(view_1439, [0, 2, 1, 3]); view_1439 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 256, '0'); convert_element_type_562 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32); add_67 = None + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_154) + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + view_581 = torch.ops.aten.view.default(convert_element_type_564, [16384, 4096]); convert_element_type_564 = None + view_582 = torch.ops.aten.view.default(mm_119, [2, 8192, 4096]); mm_119 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 256, '0'); convert_element_type_568 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_120 = torch.ops.aten.mm.default(view_581, permute_188) + view_585 = torch.ops.aten.view.default(mm_120, [2, 8192, 1024]); mm_120 = None + view_588 = torch.ops.aten.view.default(mm_121, [2, 8192, 1024]); mm_121 = None + view_589 = torch.ops.aten.view.default(view_582, [2, 8192, -1, 128]); view_582 = None + view_590 = torch.ops.aten.view.default(view_585, [2, 8192, -1, 128]); view_585 = None + view_591 = torch.ops.aten.view.default(view_588, [2, 8192, -1, 128]); view_588 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_589, torch.float32); view_589 = None + view_592 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 32, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_592); view_592 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_590, torch.float32); view_590 = None + view_593 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 8, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_593); view_593 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_16); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_595 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 32, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_16); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_596 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 8, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_595, torch.bfloat16); view_595 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_596, torch.bfloat16); view_596 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 8, 4, 128]); unsqueeze_34 = None + clone_34 = torch.ops.aten.clone.default(expand_34, memory_format = torch.contiguous_format); expand_34 = None + view_597 = torch.ops.aten.view.default(clone_34, [2, 8192, 32, 128]); clone_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_591, 3); view_591 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 8, 4, 128]); unsqueeze_35 = None + clone_35 = torch.ops.aten.clone.default(expand_35, memory_format = torch.contiguous_format); expand_35 = None + view_598 = torch.ops.aten.view.default(clone_35, [2, 8192, 32, 128]); clone_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_597, [0, 2, 1, 3]); view_597 = None + permute_192 = torch.ops.aten.permute.default(view_598, [0, 2, 1, 3]); view_598 = None + _scaled_dot_product_cudnn_attention_backward_14 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_821, permute_190, permute_191, permute_192, getitem_153, getitem_154, getitem_159, getitem_160, None, None, None, 8192, 8192, 0.0, True); permute_821 = permute_190 = permute_191 = permute_192 = getitem_153 = getitem_154 = getitem_159 = getitem_160 = None + getitem_330 = _scaled_dot_product_cudnn_attention_backward_14[0] + getitem_331 = _scaled_dot_product_cudnn_attention_backward_14[1] + getitem_332 = _scaled_dot_product_cudnn_attention_backward_14[2]; _scaled_dot_product_cudnn_attention_backward_14 = None + permute_822 = torch.ops.aten.permute.default(getitem_332, [0, 2, 1, 3]); getitem_332 = None + permute_823 = torch.ops.aten.permute.default(getitem_331, [0, 2, 1, 3]); getitem_331 = None + permute_824 = torch.ops.aten.permute.default(getitem_330, [0, 2, 1, 3]); getitem_330 = None + view_1440 = torch.ops.aten.view.default(permute_822, [2, 8192, 8, 4, 128]); permute_822 = None + sum_89 = torch.ops.aten.sum.dim_IntList(view_1440, [3], True); view_1440 = None + squeeze_28 = torch.ops.aten.squeeze.dim(sum_89, 3); sum_89 = None + view_1441 = torch.ops.aten.view.default(permute_823, [2, 8192, 8, 4, 128]); permute_823 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_1441, [3], True); view_1441 = None + squeeze_29 = torch.ops.aten.squeeze.dim(sum_90, 3); sum_90 = None + convert_element_type_1859 = torch.ops.prims.convert_element_type.default(squeeze_29, torch.float32); squeeze_29 = None + convert_element_type_1860 = torch.ops.prims.convert_element_type.default(permute_824, torch.float32); permute_824 = None + view_1442 = torch.ops.aten.view.default(convert_element_type_1859, [2, 8192, 8, 64, 2]); convert_element_type_1859 = None + view_as_complex_92 = torch.ops.aten.view_as_complex.default(view_1442); view_1442 = None + mul_556 = torch.ops.aten.mul.Tensor(view_as_complex_92, _conj); view_as_complex_92 = None + view_1443 = torch.ops.aten.view.default(convert_element_type_1860, [2, 8192, 32, 64, 2]); convert_element_type_1860 = None + view_as_complex_93 = torch.ops.aten.view_as_complex.default(view_1443); view_1443 = None + mul_557 = torch.ops.aten.mul.Tensor(view_as_complex_93, _conj); view_as_complex_93 = None + view_as_real_92 = torch.ops.aten.view_as_real.default(mul_556); mul_556 = None + view_1444 = torch.ops.aten.view.default(view_as_real_92, [2, 8192, 8, 128]); view_as_real_92 = None + convert_element_type_1861 = torch.ops.prims.convert_element_type.default(view_1444, torch.bfloat16); view_1444 = None + view_as_real_93 = torch.ops.aten.view_as_real.default(mul_557); mul_557 = None + view_1445 = torch.ops.aten.view.default(view_as_real_93, [2, 8192, 32, 128]); view_as_real_93 = None + convert_element_type_1862 = torch.ops.prims.convert_element_type.default(view_1445, torch.bfloat16); view_1445 = None + view_1446 = torch.ops.aten.view.default(squeeze_28, [2, 8192, 1024]); squeeze_28 = None + view_1447 = torch.ops.aten.view.default(convert_element_type_1861, [2, 8192, 1024]); convert_element_type_1861 = None + view_1448 = torch.ops.aten.view.default(convert_element_type_1862, [2, 8192, 4096]); convert_element_type_1862 = None + view_1449 = torch.ops.aten.view.default(view_1446, [16384, 1024]); view_1446 = None + permute_825 = torch.ops.aten.permute.default(view_1449, [1, 0]) + mm_431 = torch.ops.aten.mm.default(permute_825, view_581); permute_825 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16); primals_160 = None + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 256, '0'); convert_element_type_571 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + permute_827 = torch.ops.aten.permute.default(permute_189, [1, 0]); permute_189 = None + mm_432 = torch.ops.aten.mm.default(view_1449, permute_827); view_1449 = permute_827 = None + view_1450 = torch.ops.aten.view.default(mm_432, [2, 8192, 4096]); mm_432 = None + convert_element_type_1867 = torch.ops.prims.convert_element_type.default(mm_431, torch.float32); mm_431 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1867, 'avg', 256, '0'); convert_element_type_1867 = None + wait_tensor_424 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + view_1451 = torch.ops.aten.view.default(view_1447, [16384, 1024]); view_1447 = None + permute_829 = torch.ops.aten.permute.default(view_1451, [1, 0]) + mm_433 = torch.ops.aten.mm.default(permute_829, view_581); permute_829 = None + permute_831 = torch.ops.aten.permute.default(permute_188, [1, 0]); permute_188 = None + mm_434 = torch.ops.aten.mm.default(view_1451, permute_831); view_1451 = permute_831 = None + view_1452 = torch.ops.aten.view.default(mm_434, [2, 8192, 4096]); mm_434 = None + add_231 = torch.ops.aten.add.Tensor(view_1450, view_1452); view_1450 = view_1452 = None + convert_element_type_1872 = torch.ops.prims.convert_element_type.default(mm_433, torch.float32); mm_433 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1872, 'avg', 256, '0'); convert_element_type_1872 = None + wait_tensor_425 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + view_1453 = torch.ops.aten.view.default(view_1448, [16384, 4096]); view_1448 = None + permute_833 = torch.ops.aten.permute.default(view_1453, [1, 0]) + mm_435 = torch.ops.aten.mm.default(permute_833, view_581); permute_833 = view_581 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16); primals_158 = None + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 256, '0'); convert_element_type_565 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + permute_835 = torch.ops.aten.permute.default(permute_187, [1, 0]); permute_187 = None + mm_436 = torch.ops.aten.mm.default(view_1453, permute_835); view_1453 = permute_835 = None + view_1454 = torch.ops.aten.view.default(mm_436, [2, 8192, 4096]); mm_436 = None + add_232 = torch.ops.aten.add.Tensor(add_231, view_1454); add_231 = view_1454 = None + convert_element_type_1877 = torch.ops.prims.convert_element_type.default(mm_435, torch.float32); mm_435 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1877, 'avg', 256, '0'); convert_element_type_1877 = None + wait_tensor_426 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + convert_element_type_1878 = torch.ops.prims.convert_element_type.default(add_232, torch.float32); add_232 = None + convert_element_type_1880 = torch.ops.prims.convert_element_type.default(wait_tensor_154, torch.float32); wait_tensor_154 = None + mul_558 = torch.ops.aten.mul.Tensor(convert_element_type_1878, convert_element_type_1880); convert_element_type_1880 = None + mul_560 = torch.ops.aten.mul.Tensor(mul_136, mul_558) + sum_91 = torch.ops.aten.sum.dim_IntList(mul_560, [2], True); mul_560 = None + div_30 = torch.ops.aten.div.Tensor(mul_136, 4096) + mul_561 = torch.ops.aten.mul.Tensor(div_30, sum_91); div_30 = sum_91 = None + sub_45 = torch.ops.aten.sub.Tensor(mul_558, mul_561); mul_558 = mul_561 = None + mul_562 = torch.ops.aten.mul.Tensor(sub_45, rsqrt_34); sub_45 = rsqrt_34 = None + mul_563 = torch.ops.aten.mul.Tensor(convert_element_type_1878, mul_136); convert_element_type_1878 = mul_136 = None + sum_92 = torch.ops.aten.sum.dim_IntList(mul_563, [0, 1]); mul_563 = None + convert_element_type_1881 = torch.ops.prims.convert_element_type.default(mul_562, torch.bfloat16); mul_562 = None + add_233 = torch.ops.aten.add.Tensor(add_230, convert_element_type_1881); add_230 = convert_element_type_1881 = None + convert_element_type_default_35 = torch.ops.prims.convert_element_type.default(sum_92, torch.float32); sum_92 = None + reduce_scatter_tensor_136 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_35, 'avg', 256, '0'); convert_element_type_default_35 = None + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + view_1455 = torch.ops.aten.view.default(add_233, [16384, 4096]) + permute_837 = torch.ops.aten.permute.default(view_1455, [1, 0]) + permute_182 = torch.ops.aten.permute.default(getitem_144, [0, 2, 1, 3]) + view_565 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16); primals_152 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 256, '0'); convert_element_type_545 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + view_567 = torch.ops.aten.view.default(view_565, [16384, 4096]); view_565 = None + mm_115 = torch.ops.aten.mm.default(view_567, permute_183) + view_568 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + add_65 = torch.ops.aten.add.Tensor(add_63, view_568); view_568 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16); primals_153 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 256, '0'); convert_element_type_548 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32); add_65 = None + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_150) + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + view_571 = torch.ops.aten.view.default(convert_element_type_550, [16384, 4096]); convert_element_type_550 = None + view_572 = torch.ops.aten.view.default(mm_116, [2, 8192, 14336]); mm_116 = None + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_572, torch.float32); view_572 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 256, '0'); convert_element_type_556 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_152, [1, 0]); wait_tensor_152 = None + mm_117 = torch.ops.aten.mm.default(view_571, permute_185) + view_575 = torch.ops.aten.view.default(mm_117, [2, 8192, 14336]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_575) + view_577 = torch.ops.aten.view.default(mul_135, [16384, 14336]); mul_135 = None + mm_437 = torch.ops.aten.mm.default(permute_837, view_577); permute_837 = view_577 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 256, '0'); convert_element_type_559 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_153, [1, 0]); wait_tensor_153 = None + permute_839 = torch.ops.aten.permute.default(permute_186, [1, 0]); permute_186 = None + mm_438 = torch.ops.aten.mm.default(view_1455, permute_839); view_1455 = permute_839 = None + view_1456 = torch.ops.aten.view.default(mm_438, [2, 8192, 14336]); mm_438 = None + convert_element_type_1888 = torch.ops.prims.convert_element_type.default(mm_437, torch.float32); mm_437 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1888, 'avg', 256, '0'); convert_element_type_1888 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + mul_564 = torch.ops.aten.mul.Tensor(view_1456, convert_element_type_555); convert_element_type_555 = None + mul_565 = torch.ops.aten.mul.Tensor(view_1456, view_575); view_1456 = view_575 = None + view_1457 = torch.ops.aten.view.default(mul_564, [16384, 14336]); mul_564 = None + permute_841 = torch.ops.aten.permute.default(view_1457, [1, 0]) + mm_439 = torch.ops.aten.mm.default(permute_841, view_571); permute_841 = None + permute_843 = torch.ops.aten.permute.default(permute_185, [1, 0]); permute_185 = None + mm_440 = torch.ops.aten.mm.default(view_1457, permute_843); view_1457 = permute_843 = None + view_1458 = torch.ops.aten.view.default(mm_440, [2, 8192, 4096]); mm_440 = None + convert_element_type_1893 = torch.ops.prims.convert_element_type.default(mm_439, torch.float32); mm_439 = None + reduce_scatter_tensor_138 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1893, 'avg', 256, '0'); convert_element_type_1893 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + convert_element_type_1894 = torch.ops.prims.convert_element_type.default(mul_565, torch.float32); mul_565 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_554) + exp_15 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_234 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_234); add_234 = None + mul_566 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_567 = torch.ops.aten.mul.Tensor(convert_element_type_1894, mul_566); convert_element_type_1894 = None + sub_46 = torch.ops.aten.sub.Tensor(1, mul_566); mul_566 = None + mul_568 = torch.ops.aten.mul.Tensor(convert_element_type_554, sub_46); convert_element_type_554 = sub_46 = None + add_235 = torch.ops.aten.add.Tensor(mul_568, 1); mul_568 = None + mul_569 = torch.ops.aten.mul.Tensor(mul_567, add_235); mul_567 = add_235 = None + convert_element_type_1896 = torch.ops.prims.convert_element_type.default(mul_569, torch.bfloat16); mul_569 = None + view_1459 = torch.ops.aten.view.default(convert_element_type_1896, [16384, 14336]); convert_element_type_1896 = None + permute_845 = torch.ops.aten.permute.default(view_1459, [1, 0]) + mm_441 = torch.ops.aten.mm.default(permute_845, view_571); permute_845 = view_571 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 256, '0'); convert_element_type_551 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_151, [1, 0]); wait_tensor_151 = None + permute_847 = torch.ops.aten.permute.default(permute_184, [1, 0]); permute_184 = None + mm_442 = torch.ops.aten.mm.default(view_1459, permute_847); view_1459 = permute_847 = None + view_1460 = torch.ops.aten.view.default(mm_442, [2, 8192, 4096]); mm_442 = None + add_236 = torch.ops.aten.add.Tensor(view_1458, view_1460); view_1458 = view_1460 = None + convert_element_type_1901 = torch.ops.prims.convert_element_type.default(mm_441, torch.float32); mm_441 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1901, 'avg', 256, '0'); convert_element_type_1901 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + convert_element_type_1902 = torch.ops.prims.convert_element_type.default(add_236, torch.float32); add_236 = None + convert_element_type_1904 = torch.ops.prims.convert_element_type.default(wait_tensor_150, torch.float32); wait_tensor_150 = None + mul_570 = torch.ops.aten.mul.Tensor(convert_element_type_1902, convert_element_type_1904); convert_element_type_1904 = None + mul_572 = torch.ops.aten.mul.Tensor(mul_132, mul_570) + sum_93 = torch.ops.aten.sum.dim_IntList(mul_572, [2], True); mul_572 = None + div_31 = torch.ops.aten.div.Tensor(mul_132, 4096) + mul_573 = torch.ops.aten.mul.Tensor(div_31, sum_93); div_31 = sum_93 = None + sub_47 = torch.ops.aten.sub.Tensor(mul_570, mul_573); mul_570 = mul_573 = None + mul_574 = torch.ops.aten.mul.Tensor(sub_47, rsqrt_33); sub_47 = rsqrt_33 = None + mul_575 = torch.ops.aten.mul.Tensor(convert_element_type_1902, mul_132); convert_element_type_1902 = mul_132 = None + sum_94 = torch.ops.aten.sum.dim_IntList(mul_575, [0, 1]); mul_575 = None + convert_element_type_1905 = torch.ops.prims.convert_element_type.default(mul_574, torch.bfloat16); mul_574 = None + add_237 = torch.ops.aten.add.Tensor(add_233, convert_element_type_1905); add_233 = convert_element_type_1905 = None + convert_element_type_default_34 = torch.ops.prims.convert_element_type.default(sum_94, torch.float32); sum_94 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_34, 'avg', 256, '0'); convert_element_type_default_34 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + view_1461 = torch.ops.aten.view.default(add_237, [16384, 4096]) + permute_849 = torch.ops.aten.permute.default(view_1461, [1, 0]) + mm_443 = torch.ops.aten.mm.default(permute_849, view_567); permute_849 = view_567 = None + permute_851 = torch.ops.aten.permute.default(permute_183, [1, 0]); permute_183 = None + mm_444 = torch.ops.aten.mm.default(view_1461, permute_851); view_1461 = permute_851 = None + view_1462 = torch.ops.aten.view.default(mm_444, [2, 8192, 4096]); mm_444 = None + convert_element_type_1912 = torch.ops.prims.convert_element_type.default(mm_443, torch.float32); mm_443 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1912, 'avg', 256, '0'); convert_element_type_1912 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_1463 = torch.ops.aten.view.default(view_1462, [2, 8192, 32, 128]); view_1462 = None + permute_853 = torch.ops.aten.permute.default(view_1463, [0, 2, 1, 3]); view_1463 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 256, '0'); convert_element_type_529 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32); add_63 = None + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_145) + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + view_547 = torch.ops.aten.view.default(convert_element_type_531, [16384, 4096]); convert_element_type_531 = None + view_548 = torch.ops.aten.view.default(mm_112, [2, 8192, 4096]); mm_112 = None + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16); primals_150 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 256, '0'); convert_element_type_535 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + mm_113 = torch.ops.aten.mm.default(view_547, permute_177) + view_551 = torch.ops.aten.view.default(mm_113, [2, 8192, 1024]); mm_113 = None + view_554 = torch.ops.aten.view.default(mm_114, [2, 8192, 1024]); mm_114 = None + view_555 = torch.ops.aten.view.default(view_548, [2, 8192, -1, 128]); view_548 = None + view_556 = torch.ops.aten.view.default(view_551, [2, 8192, -1, 128]); view_551 = None + view_557 = torch.ops.aten.view.default(view_554, [2, 8192, -1, 128]); view_554 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_555, torch.float32); view_555 = None + view_558 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 32, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_558); view_558 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_556, torch.float32); view_556 = None + view_559 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 8, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_559); view_559 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_16); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_561 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 32, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_16); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_562 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 8, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_561, torch.bfloat16); view_561 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_562, torch.bfloat16); view_562 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 8, 4, 128]); unsqueeze_32 = None + clone_32 = torch.ops.aten.clone.default(expand_32, memory_format = torch.contiguous_format); expand_32 = None + view_563 = torch.ops.aten.view.default(clone_32, [2, 8192, 32, 128]); clone_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_557, 3); view_557 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 8, 4, 128]); unsqueeze_33 = None + clone_33 = torch.ops.aten.clone.default(expand_33, memory_format = torch.contiguous_format); expand_33 = None + view_564 = torch.ops.aten.view.default(clone_33, [2, 8192, 32, 128]); clone_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_563, [0, 2, 1, 3]); view_563 = None + permute_181 = torch.ops.aten.permute.default(view_564, [0, 2, 1, 3]); view_564 = None + _scaled_dot_product_cudnn_attention_backward_15 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_853, permute_179, permute_180, permute_181, getitem_144, getitem_145, getitem_150, getitem_151, None, None, None, 8192, 8192, 0.0, True); permute_853 = permute_179 = permute_180 = permute_181 = getitem_144 = getitem_145 = getitem_150 = getitem_151 = None + getitem_333 = _scaled_dot_product_cudnn_attention_backward_15[0] + getitem_334 = _scaled_dot_product_cudnn_attention_backward_15[1] + getitem_335 = _scaled_dot_product_cudnn_attention_backward_15[2]; _scaled_dot_product_cudnn_attention_backward_15 = None + permute_854 = torch.ops.aten.permute.default(getitem_335, [0, 2, 1, 3]); getitem_335 = None + permute_855 = torch.ops.aten.permute.default(getitem_334, [0, 2, 1, 3]); getitem_334 = None + permute_856 = torch.ops.aten.permute.default(getitem_333, [0, 2, 1, 3]); getitem_333 = None + view_1464 = torch.ops.aten.view.default(permute_854, [2, 8192, 8, 4, 128]); permute_854 = None + sum_95 = torch.ops.aten.sum.dim_IntList(view_1464, [3], True); view_1464 = None + squeeze_30 = torch.ops.aten.squeeze.dim(sum_95, 3); sum_95 = None + view_1465 = torch.ops.aten.view.default(permute_855, [2, 8192, 8, 4, 128]); permute_855 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_1465, [3], True); view_1465 = None + squeeze_31 = torch.ops.aten.squeeze.dim(sum_96, 3); sum_96 = None + convert_element_type_1913 = torch.ops.prims.convert_element_type.default(squeeze_31, torch.float32); squeeze_31 = None + convert_element_type_1914 = torch.ops.prims.convert_element_type.default(permute_856, torch.float32); permute_856 = None + view_1466 = torch.ops.aten.view.default(convert_element_type_1913, [2, 8192, 8, 64, 2]); convert_element_type_1913 = None + view_as_complex_94 = torch.ops.aten.view_as_complex.default(view_1466); view_1466 = None + mul_576 = torch.ops.aten.mul.Tensor(view_as_complex_94, _conj); view_as_complex_94 = None + view_1467 = torch.ops.aten.view.default(convert_element_type_1914, [2, 8192, 32, 64, 2]); convert_element_type_1914 = None + view_as_complex_95 = torch.ops.aten.view_as_complex.default(view_1467); view_1467 = None + mul_577 = torch.ops.aten.mul.Tensor(view_as_complex_95, _conj); view_as_complex_95 = None + view_as_real_94 = torch.ops.aten.view_as_real.default(mul_576); mul_576 = None + view_1468 = torch.ops.aten.view.default(view_as_real_94, [2, 8192, 8, 128]); view_as_real_94 = None + convert_element_type_1915 = torch.ops.prims.convert_element_type.default(view_1468, torch.bfloat16); view_1468 = None + view_as_real_95 = torch.ops.aten.view_as_real.default(mul_577); mul_577 = None + view_1469 = torch.ops.aten.view.default(view_as_real_95, [2, 8192, 32, 128]); view_as_real_95 = None + convert_element_type_1916 = torch.ops.prims.convert_element_type.default(view_1469, torch.bfloat16); view_1469 = None + view_1470 = torch.ops.aten.view.default(squeeze_30, [2, 8192, 1024]); squeeze_30 = None + view_1471 = torch.ops.aten.view.default(convert_element_type_1915, [2, 8192, 1024]); convert_element_type_1915 = None + view_1472 = torch.ops.aten.view.default(convert_element_type_1916, [2, 8192, 4096]); convert_element_type_1916 = None + view_1473 = torch.ops.aten.view.default(view_1470, [16384, 1024]); view_1470 = None + permute_857 = torch.ops.aten.permute.default(view_1473, [1, 0]) + mm_445 = torch.ops.aten.mm.default(permute_857, view_547); permute_857 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16); primals_151 = None + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 256, '0'); convert_element_type_538 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + permute_859 = torch.ops.aten.permute.default(permute_178, [1, 0]); permute_178 = None + mm_446 = torch.ops.aten.mm.default(view_1473, permute_859); view_1473 = permute_859 = None + view_1474 = torch.ops.aten.view.default(mm_446, [2, 8192, 4096]); mm_446 = None + convert_element_type_1921 = torch.ops.prims.convert_element_type.default(mm_445, torch.float32); mm_445 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1921, 'avg', 256, '0'); convert_element_type_1921 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + view_1475 = torch.ops.aten.view.default(view_1471, [16384, 1024]); view_1471 = None + permute_861 = torch.ops.aten.permute.default(view_1475, [1, 0]) + mm_447 = torch.ops.aten.mm.default(permute_861, view_547); permute_861 = None + permute_863 = torch.ops.aten.permute.default(permute_177, [1, 0]); permute_177 = None + mm_448 = torch.ops.aten.mm.default(view_1475, permute_863); view_1475 = permute_863 = None + view_1476 = torch.ops.aten.view.default(mm_448, [2, 8192, 4096]); mm_448 = None + add_238 = torch.ops.aten.add.Tensor(view_1474, view_1476); view_1474 = view_1476 = None + convert_element_type_1926 = torch.ops.prims.convert_element_type.default(mm_447, torch.float32); mm_447 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1926, 'avg', 256, '0'); convert_element_type_1926 = None + wait_tensor_434 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + view_1477 = torch.ops.aten.view.default(view_1472, [16384, 4096]); view_1472 = None + permute_865 = torch.ops.aten.permute.default(view_1477, [1, 0]) + mm_449 = torch.ops.aten.mm.default(permute_865, view_547); permute_865 = view_547 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); primals_149 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 256, '0'); convert_element_type_532 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_146, [1, 0]); wait_tensor_146 = None + permute_867 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_450 = torch.ops.aten.mm.default(view_1477, permute_867); view_1477 = permute_867 = None + view_1478 = torch.ops.aten.view.default(mm_450, [2, 8192, 4096]); mm_450 = None + add_239 = torch.ops.aten.add.Tensor(add_238, view_1478); add_238 = view_1478 = None + convert_element_type_1931 = torch.ops.prims.convert_element_type.default(mm_449, torch.float32); mm_449 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1931, 'avg', 256, '0'); convert_element_type_1931 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + convert_element_type_1932 = torch.ops.prims.convert_element_type.default(add_239, torch.float32); add_239 = None + convert_element_type_1934 = torch.ops.prims.convert_element_type.default(wait_tensor_145, torch.float32); wait_tensor_145 = None + mul_578 = torch.ops.aten.mul.Tensor(convert_element_type_1932, convert_element_type_1934); convert_element_type_1934 = None + mul_580 = torch.ops.aten.mul.Tensor(mul_128, mul_578) + sum_97 = torch.ops.aten.sum.dim_IntList(mul_580, [2], True); mul_580 = None + div_32 = torch.ops.aten.div.Tensor(mul_128, 4096) + mul_581 = torch.ops.aten.mul.Tensor(div_32, sum_97); div_32 = sum_97 = None + sub_48 = torch.ops.aten.sub.Tensor(mul_578, mul_581); mul_578 = mul_581 = None + mul_582 = torch.ops.aten.mul.Tensor(sub_48, rsqrt_32); sub_48 = rsqrt_32 = None + mul_583 = torch.ops.aten.mul.Tensor(convert_element_type_1932, mul_128); convert_element_type_1932 = mul_128 = None + sum_98 = torch.ops.aten.sum.dim_IntList(mul_583, [0, 1]); mul_583 = None + convert_element_type_1935 = torch.ops.prims.convert_element_type.default(mul_582, torch.bfloat16); mul_582 = None + add_240 = torch.ops.aten.add.Tensor(add_237, convert_element_type_1935); add_237 = convert_element_type_1935 = None + convert_element_type_default_33 = torch.ops.prims.convert_element_type.default(sum_98, torch.float32); sum_98 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_33, 'avg', 256, '0'); convert_element_type_default_33 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + view_1479 = torch.ops.aten.view.default(add_240, [16384, 4096]) + permute_869 = torch.ops.aten.permute.default(view_1479, [1, 0]) + permute_171 = torch.ops.aten.permute.default(getitem_135, [0, 2, 1, 3]) + view_531 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 256, '0'); convert_element_type_512 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_140, [1, 0]); wait_tensor_140 = None + view_533 = torch.ops.aten.view.default(view_531, [16384, 4096]); view_531 = None + mm_108 = torch.ops.aten.mm.default(view_533, permute_172) + view_534 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + add_61 = torch.ops.aten.add.Tensor(add_59, view_534); view_534 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 256, '0'); convert_element_type_515 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32); add_61 = None + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_141) + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + view_537 = torch.ops.aten.view.default(convert_element_type_517, [16384, 4096]); convert_element_type_517 = None + view_538 = torch.ops.aten.view.default(mm_109, [2, 8192, 14336]); mm_109 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_538, torch.float32); view_538 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16); primals_146 = None + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 256, '0'); convert_element_type_523 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + mm_110 = torch.ops.aten.mm.default(view_537, permute_174) + view_541 = torch.ops.aten.view.default(mm_110, [2, 8192, 14336]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_541) + view_543 = torch.ops.aten.view.default(mul_127, [16384, 14336]); mul_127 = None + mm_451 = torch.ops.aten.mm.default(permute_869, view_543); permute_869 = view_543 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16); primals_147 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 256, '0'); convert_element_type_526 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + permute_871 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_452 = torch.ops.aten.mm.default(view_1479, permute_871); view_1479 = permute_871 = None + view_1480 = torch.ops.aten.view.default(mm_452, [2, 8192, 14336]); mm_452 = None + convert_element_type_1942 = torch.ops.prims.convert_element_type.default(mm_451, torch.float32); mm_451 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1942, 'avg', 256, '0'); convert_element_type_1942 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + mul_584 = torch.ops.aten.mul.Tensor(view_1480, convert_element_type_522); convert_element_type_522 = None + mul_585 = torch.ops.aten.mul.Tensor(view_1480, view_541); view_1480 = view_541 = None + view_1481 = torch.ops.aten.view.default(mul_584, [16384, 14336]); mul_584 = None + permute_873 = torch.ops.aten.permute.default(view_1481, [1, 0]) + mm_453 = torch.ops.aten.mm.default(permute_873, view_537); permute_873 = None + permute_875 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_454 = torch.ops.aten.mm.default(view_1481, permute_875); view_1481 = permute_875 = None + view_1482 = torch.ops.aten.view.default(mm_454, [2, 8192, 4096]); mm_454 = None + convert_element_type_1947 = torch.ops.prims.convert_element_type.default(mm_453, torch.float32); mm_453 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1947, 'avg', 256, '0'); convert_element_type_1947 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + convert_element_type_1948 = torch.ops.prims.convert_element_type.default(mul_585, torch.float32); mul_585 = None + neg_16 = torch.ops.aten.neg.default(convert_element_type_521) + exp_16 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_241 = torch.ops.aten.add.Tensor(exp_16, 1); exp_16 = None + reciprocal_16 = torch.ops.aten.reciprocal.default(add_241); add_241 = None + mul_586 = torch.ops.aten.mul.Tensor(reciprocal_16, 1); reciprocal_16 = None + mul_587 = torch.ops.aten.mul.Tensor(convert_element_type_1948, mul_586); convert_element_type_1948 = None + sub_49 = torch.ops.aten.sub.Tensor(1, mul_586); mul_586 = None + mul_588 = torch.ops.aten.mul.Tensor(convert_element_type_521, sub_49); convert_element_type_521 = sub_49 = None + add_242 = torch.ops.aten.add.Tensor(mul_588, 1); mul_588 = None + mul_589 = torch.ops.aten.mul.Tensor(mul_587, add_242); mul_587 = add_242 = None + convert_element_type_1950 = torch.ops.prims.convert_element_type.default(mul_589, torch.bfloat16); mul_589 = None + view_1483 = torch.ops.aten.view.default(convert_element_type_1950, [16384, 14336]); convert_element_type_1950 = None + permute_877 = torch.ops.aten.permute.default(view_1483, [1, 0]) + mm_455 = torch.ops.aten.mm.default(permute_877, view_537); permute_877 = view_537 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16); primals_145 = None + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 256, '0'); convert_element_type_518 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + permute_879 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_456 = torch.ops.aten.mm.default(view_1483, permute_879); view_1483 = permute_879 = None + view_1484 = torch.ops.aten.view.default(mm_456, [2, 8192, 4096]); mm_456 = None + add_243 = torch.ops.aten.add.Tensor(view_1482, view_1484); view_1482 = view_1484 = None + convert_element_type_1955 = torch.ops.prims.convert_element_type.default(mm_455, torch.float32); mm_455 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1955, 'avg', 256, '0'); convert_element_type_1955 = None + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + convert_element_type_1956 = torch.ops.prims.convert_element_type.default(add_243, torch.float32); add_243 = None + convert_element_type_1958 = torch.ops.prims.convert_element_type.default(wait_tensor_141, torch.float32); wait_tensor_141 = None + mul_590 = torch.ops.aten.mul.Tensor(convert_element_type_1956, convert_element_type_1958); convert_element_type_1958 = None + mul_592 = torch.ops.aten.mul.Tensor(mul_124, mul_590) + sum_99 = torch.ops.aten.sum.dim_IntList(mul_592, [2], True); mul_592 = None + div_33 = torch.ops.aten.div.Tensor(mul_124, 4096) + mul_593 = torch.ops.aten.mul.Tensor(div_33, sum_99); div_33 = sum_99 = None + sub_50 = torch.ops.aten.sub.Tensor(mul_590, mul_593); mul_590 = mul_593 = None + mul_594 = torch.ops.aten.mul.Tensor(sub_50, rsqrt_31); sub_50 = rsqrt_31 = None + mul_595 = torch.ops.aten.mul.Tensor(convert_element_type_1956, mul_124); convert_element_type_1956 = mul_124 = None + sum_100 = torch.ops.aten.sum.dim_IntList(mul_595, [0, 1]); mul_595 = None + convert_element_type_1959 = torch.ops.prims.convert_element_type.default(mul_594, torch.bfloat16); mul_594 = None + add_244 = torch.ops.aten.add.Tensor(add_240, convert_element_type_1959); add_240 = convert_element_type_1959 = None + convert_element_type_default_32 = torch.ops.prims.convert_element_type.default(sum_100, torch.float32); sum_100 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_32, 'avg', 256, '0'); convert_element_type_default_32 = None + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + view_1485 = torch.ops.aten.view.default(add_244, [16384, 4096]) + permute_881 = torch.ops.aten.permute.default(view_1485, [1, 0]) + mm_457 = torch.ops.aten.mm.default(permute_881, view_533); permute_881 = view_533 = None + permute_883 = torch.ops.aten.permute.default(permute_172, [1, 0]); permute_172 = None + mm_458 = torch.ops.aten.mm.default(view_1485, permute_883); view_1485 = permute_883 = None + view_1486 = torch.ops.aten.view.default(mm_458, [2, 8192, 4096]); mm_458 = None + convert_element_type_1966 = torch.ops.prims.convert_element_type.default(mm_457, torch.float32); mm_457 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1966, 'avg', 256, '0'); convert_element_type_1966 = None + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = None + view_1487 = torch.ops.aten.view.default(view_1486, [2, 8192, 32, 128]); view_1486 = None + permute_885 = torch.ops.aten.permute.default(view_1487, [0, 2, 1, 3]); view_1487 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 256, '0'); convert_element_type_496 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32); add_59 = None + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_136) + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + view_513 = torch.ops.aten.view.default(convert_element_type_498, [16384, 4096]); convert_element_type_498 = None + view_514 = torch.ops.aten.view.default(mm_105, [2, 8192, 4096]); mm_105 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 256, '0'); convert_element_type_502 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + mm_106 = torch.ops.aten.mm.default(view_513, permute_166) + view_517 = torch.ops.aten.view.default(mm_106, [2, 8192, 1024]); mm_106 = None + view_520 = torch.ops.aten.view.default(mm_107, [2, 8192, 1024]); mm_107 = None + view_521 = torch.ops.aten.view.default(view_514, [2, 8192, -1, 128]); view_514 = None + view_522 = torch.ops.aten.view.default(view_517, [2, 8192, -1, 128]); view_517 = None + view_523 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_521, torch.float32); view_521 = None + view_524 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 32, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_524); view_524 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_522, torch.float32); view_522 = None + view_525 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 8, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_525); view_525 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_16); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_527 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 32, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_16); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_528 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 8, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_527, torch.bfloat16); view_527 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_528, torch.bfloat16); view_528 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 8, 4, 128]); unsqueeze_30 = None + clone_30 = torch.ops.aten.clone.default(expand_30, memory_format = torch.contiguous_format); expand_30 = None + view_529 = torch.ops.aten.view.default(clone_30, [2, 8192, 32, 128]); clone_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_523, 3); view_523 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 8, 4, 128]); unsqueeze_31 = None + clone_31 = torch.ops.aten.clone.default(expand_31, memory_format = torch.contiguous_format); expand_31 = None + view_530 = torch.ops.aten.view.default(clone_31, [2, 8192, 32, 128]); clone_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_529, [0, 2, 1, 3]); view_529 = None + permute_170 = torch.ops.aten.permute.default(view_530, [0, 2, 1, 3]); view_530 = None + _scaled_dot_product_cudnn_attention_backward_16 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_885, permute_168, permute_169, permute_170, getitem_135, getitem_136, getitem_141, getitem_142, None, None, None, 8192, 8192, 0.0, True); permute_885 = permute_168 = permute_169 = permute_170 = getitem_135 = getitem_136 = getitem_141 = getitem_142 = None + getitem_336 = _scaled_dot_product_cudnn_attention_backward_16[0] + getitem_337 = _scaled_dot_product_cudnn_attention_backward_16[1] + getitem_338 = _scaled_dot_product_cudnn_attention_backward_16[2]; _scaled_dot_product_cudnn_attention_backward_16 = None + permute_886 = torch.ops.aten.permute.default(getitem_338, [0, 2, 1, 3]); getitem_338 = None + permute_887 = torch.ops.aten.permute.default(getitem_337, [0, 2, 1, 3]); getitem_337 = None + permute_888 = torch.ops.aten.permute.default(getitem_336, [0, 2, 1, 3]); getitem_336 = None + view_1488 = torch.ops.aten.view.default(permute_886, [2, 8192, 8, 4, 128]); permute_886 = None + sum_101 = torch.ops.aten.sum.dim_IntList(view_1488, [3], True); view_1488 = None + squeeze_32 = torch.ops.aten.squeeze.dim(sum_101, 3); sum_101 = None + view_1489 = torch.ops.aten.view.default(permute_887, [2, 8192, 8, 4, 128]); permute_887 = None + sum_102 = torch.ops.aten.sum.dim_IntList(view_1489, [3], True); view_1489 = None + squeeze_33 = torch.ops.aten.squeeze.dim(sum_102, 3); sum_102 = None + convert_element_type_1967 = torch.ops.prims.convert_element_type.default(squeeze_33, torch.float32); squeeze_33 = None + convert_element_type_1968 = torch.ops.prims.convert_element_type.default(permute_888, torch.float32); permute_888 = None + view_1490 = torch.ops.aten.view.default(convert_element_type_1967, [2, 8192, 8, 64, 2]); convert_element_type_1967 = None + view_as_complex_96 = torch.ops.aten.view_as_complex.default(view_1490); view_1490 = None + mul_596 = torch.ops.aten.mul.Tensor(view_as_complex_96, _conj); view_as_complex_96 = None + view_1491 = torch.ops.aten.view.default(convert_element_type_1968, [2, 8192, 32, 64, 2]); convert_element_type_1968 = None + view_as_complex_97 = torch.ops.aten.view_as_complex.default(view_1491); view_1491 = None + mul_597 = torch.ops.aten.mul.Tensor(view_as_complex_97, _conj); view_as_complex_97 = None + view_as_real_96 = torch.ops.aten.view_as_real.default(mul_596); mul_596 = None + view_1492 = torch.ops.aten.view.default(view_as_real_96, [2, 8192, 8, 128]); view_as_real_96 = None + convert_element_type_1969 = torch.ops.prims.convert_element_type.default(view_1492, torch.bfloat16); view_1492 = None + view_as_real_97 = torch.ops.aten.view_as_real.default(mul_597); mul_597 = None + view_1493 = torch.ops.aten.view.default(view_as_real_97, [2, 8192, 32, 128]); view_as_real_97 = None + convert_element_type_1970 = torch.ops.prims.convert_element_type.default(view_1493, torch.bfloat16); view_1493 = None + view_1494 = torch.ops.aten.view.default(squeeze_32, [2, 8192, 1024]); squeeze_32 = None + view_1495 = torch.ops.aten.view.default(convert_element_type_1969, [2, 8192, 1024]); convert_element_type_1969 = None + view_1496 = torch.ops.aten.view.default(convert_element_type_1970, [2, 8192, 4096]); convert_element_type_1970 = None + view_1497 = torch.ops.aten.view.default(view_1494, [16384, 1024]); view_1494 = None + permute_889 = torch.ops.aten.permute.default(view_1497, [1, 0]) + mm_459 = torch.ops.aten.mm.default(permute_889, view_513); permute_889 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 256, '0'); convert_element_type_505 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + permute_891 = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None + mm_460 = torch.ops.aten.mm.default(view_1497, permute_891); view_1497 = permute_891 = None + view_1498 = torch.ops.aten.view.default(mm_460, [2, 8192, 4096]); mm_460 = None + convert_element_type_1975 = torch.ops.prims.convert_element_type.default(mm_459, torch.float32); mm_459 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1975, 'avg', 256, '0'); convert_element_type_1975 = None + wait_tensor_442 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + view_1499 = torch.ops.aten.view.default(view_1495, [16384, 1024]); view_1495 = None + permute_893 = torch.ops.aten.permute.default(view_1499, [1, 0]) + mm_461 = torch.ops.aten.mm.default(permute_893, view_513); permute_893 = None + permute_895 = torch.ops.aten.permute.default(permute_166, [1, 0]); permute_166 = None + mm_462 = torch.ops.aten.mm.default(view_1499, permute_895); view_1499 = permute_895 = None + view_1500 = torch.ops.aten.view.default(mm_462, [2, 8192, 4096]); mm_462 = None + add_245 = torch.ops.aten.add.Tensor(view_1498, view_1500); view_1498 = view_1500 = None + convert_element_type_1980 = torch.ops.prims.convert_element_type.default(mm_461, torch.float32); mm_461 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1980, 'avg', 256, '0'); convert_element_type_1980 = None + wait_tensor_443 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + view_1501 = torch.ops.aten.view.default(view_1496, [16384, 4096]); view_1496 = None + permute_897 = torch.ops.aten.permute.default(view_1501, [1, 0]) + mm_463 = torch.ops.aten.mm.default(permute_897, view_513); permute_897 = view_513 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 256, '0'); convert_element_type_499 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + permute_899 = torch.ops.aten.permute.default(permute_165, [1, 0]); permute_165 = None + mm_464 = torch.ops.aten.mm.default(view_1501, permute_899); view_1501 = permute_899 = None + view_1502 = torch.ops.aten.view.default(mm_464, [2, 8192, 4096]); mm_464 = None + add_246 = torch.ops.aten.add.Tensor(add_245, view_1502); add_245 = view_1502 = None + convert_element_type_1985 = torch.ops.prims.convert_element_type.default(mm_463, torch.float32); mm_463 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1985, 'avg', 256, '0'); convert_element_type_1985 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + convert_element_type_1986 = torch.ops.prims.convert_element_type.default(add_246, torch.float32); add_246 = None + convert_element_type_1988 = torch.ops.prims.convert_element_type.default(wait_tensor_136, torch.float32); wait_tensor_136 = None + mul_598 = torch.ops.aten.mul.Tensor(convert_element_type_1986, convert_element_type_1988); convert_element_type_1988 = None + mul_600 = torch.ops.aten.mul.Tensor(mul_120, mul_598) + sum_103 = torch.ops.aten.sum.dim_IntList(mul_600, [2], True); mul_600 = None + div_34 = torch.ops.aten.div.Tensor(mul_120, 4096) + mul_601 = torch.ops.aten.mul.Tensor(div_34, sum_103); div_34 = sum_103 = None + sub_51 = torch.ops.aten.sub.Tensor(mul_598, mul_601); mul_598 = mul_601 = None + mul_602 = torch.ops.aten.mul.Tensor(sub_51, rsqrt_30); sub_51 = rsqrt_30 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_1986, mul_120); convert_element_type_1986 = mul_120 = None + sum_104 = torch.ops.aten.sum.dim_IntList(mul_603, [0, 1]); mul_603 = None + convert_element_type_1989 = torch.ops.prims.convert_element_type.default(mul_602, torch.bfloat16); mul_602 = None + add_247 = torch.ops.aten.add.Tensor(add_244, convert_element_type_1989); add_244 = convert_element_type_1989 = None + convert_element_type_default_31 = torch.ops.prims.convert_element_type.default(sum_104, torch.float32); sum_104 = None + reduce_scatter_tensor_154 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_31, 'avg', 256, '0'); convert_element_type_default_31 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + view_1503 = torch.ops.aten.view.default(add_247, [16384, 4096]) + permute_901 = torch.ops.aten.permute.default(view_1503, [1, 0]) + permute_160 = torch.ops.aten.permute.default(getitem_126, [0, 2, 1, 3]) + view_497 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 256, '0'); convert_element_type_479 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_131, [1, 0]); wait_tensor_131 = None + view_499 = torch.ops.aten.view.default(view_497, [16384, 4096]); view_497 = None + mm_101 = torch.ops.aten.mm.default(view_499, permute_161) + view_500 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + add_57 = torch.ops.aten.add.Tensor(add_55, view_500); view_500 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 256, '0'); convert_element_type_482 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32); add_57 = None + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_132) + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + view_503 = torch.ops.aten.view.default(convert_element_type_484, [16384, 4096]); convert_element_type_484 = None + view_504 = torch.ops.aten.view.default(mm_102, [2, 8192, 14336]); mm_102 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_504, torch.float32); view_504 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 256, '0'); convert_element_type_490 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_103 = torch.ops.aten.mm.default(view_503, permute_163) + view_507 = torch.ops.aten.view.default(mm_103, [2, 8192, 14336]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_507) + view_509 = torch.ops.aten.view.default(mul_119, [16384, 14336]); mul_119 = None + mm_465 = torch.ops.aten.mm.default(permute_901, view_509); permute_901 = view_509 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 256, '0'); convert_element_type_493 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + permute_903 = torch.ops.aten.permute.default(permute_164, [1, 0]); permute_164 = None + mm_466 = torch.ops.aten.mm.default(view_1503, permute_903); view_1503 = permute_903 = None + view_1504 = torch.ops.aten.view.default(mm_466, [2, 8192, 14336]); mm_466 = None + convert_element_type_1996 = torch.ops.prims.convert_element_type.default(mm_465, torch.float32); mm_465 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1996, 'avg', 256, '0'); convert_element_type_1996 = None + wait_tensor_446 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); reduce_scatter_tensor_155 = None + mul_604 = torch.ops.aten.mul.Tensor(view_1504, convert_element_type_489); convert_element_type_489 = None + mul_605 = torch.ops.aten.mul.Tensor(view_1504, view_507); view_1504 = view_507 = None + view_1505 = torch.ops.aten.view.default(mul_604, [16384, 14336]); mul_604 = None + permute_905 = torch.ops.aten.permute.default(view_1505, [1, 0]) + mm_467 = torch.ops.aten.mm.default(permute_905, view_503); permute_905 = None + permute_907 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_468 = torch.ops.aten.mm.default(view_1505, permute_907); view_1505 = permute_907 = None + view_1506 = torch.ops.aten.view.default(mm_468, [2, 8192, 4096]); mm_468 = None + convert_element_type_2001 = torch.ops.prims.convert_element_type.default(mm_467, torch.float32); mm_467 = None + reduce_scatter_tensor_156 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2001, 'avg', 256, '0'); convert_element_type_2001 = None + wait_tensor_447 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + convert_element_type_2002 = torch.ops.prims.convert_element_type.default(mul_605, torch.float32); mul_605 = None + neg_17 = torch.ops.aten.neg.default(convert_element_type_488) + exp_17 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_248 = torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + reciprocal_17 = torch.ops.aten.reciprocal.default(add_248); add_248 = None + mul_606 = torch.ops.aten.mul.Tensor(reciprocal_17, 1); reciprocal_17 = None + mul_607 = torch.ops.aten.mul.Tensor(convert_element_type_2002, mul_606); convert_element_type_2002 = None + sub_52 = torch.ops.aten.sub.Tensor(1, mul_606); mul_606 = None + mul_608 = torch.ops.aten.mul.Tensor(convert_element_type_488, sub_52); convert_element_type_488 = sub_52 = None + add_249 = torch.ops.aten.add.Tensor(mul_608, 1); mul_608 = None + mul_609 = torch.ops.aten.mul.Tensor(mul_607, add_249); mul_607 = add_249 = None + convert_element_type_2004 = torch.ops.prims.convert_element_type.default(mul_609, torch.bfloat16); mul_609 = None + view_1507 = torch.ops.aten.view.default(convert_element_type_2004, [16384, 14336]); convert_element_type_2004 = None + permute_909 = torch.ops.aten.permute.default(view_1507, [1, 0]) + mm_469 = torch.ops.aten.mm.default(permute_909, view_503); permute_909 = view_503 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 256, '0'); convert_element_type_485 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_133, [1, 0]); wait_tensor_133 = None + permute_911 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_470 = torch.ops.aten.mm.default(view_1507, permute_911); view_1507 = permute_911 = None + view_1508 = torch.ops.aten.view.default(mm_470, [2, 8192, 4096]); mm_470 = None + add_250 = torch.ops.aten.add.Tensor(view_1506, view_1508); view_1506 = view_1508 = None + convert_element_type_2009 = torch.ops.prims.convert_element_type.default(mm_469, torch.float32); mm_469 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2009, 'avg', 256, '0'); convert_element_type_2009 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + convert_element_type_2010 = torch.ops.prims.convert_element_type.default(add_250, torch.float32); add_250 = None + convert_element_type_2012 = torch.ops.prims.convert_element_type.default(wait_tensor_132, torch.float32); wait_tensor_132 = None + mul_610 = torch.ops.aten.mul.Tensor(convert_element_type_2010, convert_element_type_2012); convert_element_type_2012 = None + mul_612 = torch.ops.aten.mul.Tensor(mul_116, mul_610) + sum_105 = torch.ops.aten.sum.dim_IntList(mul_612, [2], True); mul_612 = None + div_35 = torch.ops.aten.div.Tensor(mul_116, 4096) + mul_613 = torch.ops.aten.mul.Tensor(div_35, sum_105); div_35 = sum_105 = None + sub_53 = torch.ops.aten.sub.Tensor(mul_610, mul_613); mul_610 = mul_613 = None + mul_614 = torch.ops.aten.mul.Tensor(sub_53, rsqrt_29); sub_53 = rsqrt_29 = None + mul_615 = torch.ops.aten.mul.Tensor(convert_element_type_2010, mul_116); convert_element_type_2010 = mul_116 = None + sum_106 = torch.ops.aten.sum.dim_IntList(mul_615, [0, 1]); mul_615 = None + convert_element_type_2013 = torch.ops.prims.convert_element_type.default(mul_614, torch.bfloat16); mul_614 = None + add_251 = torch.ops.aten.add.Tensor(add_247, convert_element_type_2013); add_247 = convert_element_type_2013 = None + convert_element_type_default_30 = torch.ops.prims.convert_element_type.default(sum_106, torch.float32); sum_106 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_30, 'avg', 256, '0'); convert_element_type_default_30 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + view_1509 = torch.ops.aten.view.default(add_251, [16384, 4096]) + permute_913 = torch.ops.aten.permute.default(view_1509, [1, 0]) + mm_471 = torch.ops.aten.mm.default(permute_913, view_499); permute_913 = view_499 = None + permute_915 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_472 = torch.ops.aten.mm.default(view_1509, permute_915); view_1509 = permute_915 = None + view_1510 = torch.ops.aten.view.default(mm_472, [2, 8192, 4096]); mm_472 = None + convert_element_type_2020 = torch.ops.prims.convert_element_type.default(mm_471, torch.float32); mm_471 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2020, 'avg', 256, '0'); convert_element_type_2020 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + view_1511 = torch.ops.aten.view.default(view_1510, [2, 8192, 32, 128]); view_1510 = None + permute_917 = torch.ops.aten.permute.default(view_1511, [0, 2, 1, 3]); view_1511 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16); primals_130 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 256, '0'); convert_element_type_463 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32); add_55 = None + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_127) + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + view_479 = torch.ops.aten.view.default(convert_element_type_465, [16384, 4096]); convert_element_type_465 = None + view_480 = torch.ops.aten.view.default(mm_98, [2, 8192, 4096]); mm_98 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 256, '0'); convert_element_type_469 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_99 = torch.ops.aten.mm.default(view_479, permute_155) + view_483 = torch.ops.aten.view.default(mm_99, [2, 8192, 1024]); mm_99 = None + view_486 = torch.ops.aten.view.default(mm_100, [2, 8192, 1024]); mm_100 = None + view_487 = torch.ops.aten.view.default(view_480, [2, 8192, -1, 128]); view_480 = None + view_488 = torch.ops.aten.view.default(view_483, [2, 8192, -1, 128]); view_483 = None + view_489 = torch.ops.aten.view.default(view_486, [2, 8192, -1, 128]); view_486 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_487, torch.float32); view_487 = None + view_490 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 32, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_490); view_490 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_488, torch.float32); view_488 = None + view_491 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 8, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_491); view_491 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_16); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_493 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 32, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_16); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_494 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 8, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_493, torch.bfloat16); view_493 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_494, torch.bfloat16); view_494 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 8, 4, 128]); unsqueeze_28 = None + clone_28 = torch.ops.aten.clone.default(expand_28, memory_format = torch.contiguous_format); expand_28 = None + view_495 = torch.ops.aten.view.default(clone_28, [2, 8192, 32, 128]); clone_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_489, 3); view_489 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 8, 4, 128]); unsqueeze_29 = None + clone_29 = torch.ops.aten.clone.default(expand_29, memory_format = torch.contiguous_format); expand_29 = None + view_496 = torch.ops.aten.view.default(clone_29, [2, 8192, 32, 128]); clone_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_495, [0, 2, 1, 3]); view_495 = None + permute_159 = torch.ops.aten.permute.default(view_496, [0, 2, 1, 3]); view_496 = None + _scaled_dot_product_cudnn_attention_backward_17 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_917, permute_157, permute_158, permute_159, getitem_126, getitem_127, getitem_132, getitem_133, None, None, None, 8192, 8192, 0.0, True); permute_917 = permute_157 = permute_158 = permute_159 = getitem_126 = getitem_127 = getitem_132 = getitem_133 = None + getitem_339 = _scaled_dot_product_cudnn_attention_backward_17[0] + getitem_340 = _scaled_dot_product_cudnn_attention_backward_17[1] + getitem_341 = _scaled_dot_product_cudnn_attention_backward_17[2]; _scaled_dot_product_cudnn_attention_backward_17 = None + permute_918 = torch.ops.aten.permute.default(getitem_341, [0, 2, 1, 3]); getitem_341 = None + permute_919 = torch.ops.aten.permute.default(getitem_340, [0, 2, 1, 3]); getitem_340 = None + permute_920 = torch.ops.aten.permute.default(getitem_339, [0, 2, 1, 3]); getitem_339 = None + view_1512 = torch.ops.aten.view.default(permute_918, [2, 8192, 8, 4, 128]); permute_918 = None + sum_107 = torch.ops.aten.sum.dim_IntList(view_1512, [3], True); view_1512 = None + squeeze_34 = torch.ops.aten.squeeze.dim(sum_107, 3); sum_107 = None + view_1513 = torch.ops.aten.view.default(permute_919, [2, 8192, 8, 4, 128]); permute_919 = None + sum_108 = torch.ops.aten.sum.dim_IntList(view_1513, [3], True); view_1513 = None + squeeze_35 = torch.ops.aten.squeeze.dim(sum_108, 3); sum_108 = None + convert_element_type_2021 = torch.ops.prims.convert_element_type.default(squeeze_35, torch.float32); squeeze_35 = None + convert_element_type_2022 = torch.ops.prims.convert_element_type.default(permute_920, torch.float32); permute_920 = None + view_1514 = torch.ops.aten.view.default(convert_element_type_2021, [2, 8192, 8, 64, 2]); convert_element_type_2021 = None + view_as_complex_98 = torch.ops.aten.view_as_complex.default(view_1514); view_1514 = None + mul_616 = torch.ops.aten.mul.Tensor(view_as_complex_98, _conj); view_as_complex_98 = None + view_1515 = torch.ops.aten.view.default(convert_element_type_2022, [2, 8192, 32, 64, 2]); convert_element_type_2022 = None + view_as_complex_99 = torch.ops.aten.view_as_complex.default(view_1515); view_1515 = None + mul_617 = torch.ops.aten.mul.Tensor(view_as_complex_99, _conj); view_as_complex_99 = None + view_as_real_98 = torch.ops.aten.view_as_real.default(mul_616); mul_616 = None + view_1516 = torch.ops.aten.view.default(view_as_real_98, [2, 8192, 8, 128]); view_as_real_98 = None + convert_element_type_2023 = torch.ops.prims.convert_element_type.default(view_1516, torch.bfloat16); view_1516 = None + view_as_real_99 = torch.ops.aten.view_as_real.default(mul_617); mul_617 = None + view_1517 = torch.ops.aten.view.default(view_as_real_99, [2, 8192, 32, 128]); view_as_real_99 = None + convert_element_type_2024 = torch.ops.prims.convert_element_type.default(view_1517, torch.bfloat16); view_1517 = None + view_1518 = torch.ops.aten.view.default(squeeze_34, [2, 8192, 1024]); squeeze_34 = None + view_1519 = torch.ops.aten.view.default(convert_element_type_2023, [2, 8192, 1024]); convert_element_type_2023 = None + view_1520 = torch.ops.aten.view.default(convert_element_type_2024, [2, 8192, 4096]); convert_element_type_2024 = None + view_1521 = torch.ops.aten.view.default(view_1518, [16384, 1024]); view_1518 = None + permute_921 = torch.ops.aten.permute.default(view_1521, [1, 0]) + mm_473 = torch.ops.aten.mm.default(permute_921, view_479); permute_921 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 256, '0'); convert_element_type_472 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + permute_923 = torch.ops.aten.permute.default(permute_156, [1, 0]); permute_156 = None + mm_474 = torch.ops.aten.mm.default(view_1521, permute_923); view_1521 = permute_923 = None + view_1522 = torch.ops.aten.view.default(mm_474, [2, 8192, 4096]); mm_474 = None + convert_element_type_2029 = torch.ops.prims.convert_element_type.default(mm_473, torch.float32); mm_473 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2029, 'avg', 256, '0'); convert_element_type_2029 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + view_1523 = torch.ops.aten.view.default(view_1519, [16384, 1024]); view_1519 = None + permute_925 = torch.ops.aten.permute.default(view_1523, [1, 0]) + mm_475 = torch.ops.aten.mm.default(permute_925, view_479); permute_925 = None + permute_927 = torch.ops.aten.permute.default(permute_155, [1, 0]); permute_155 = None + mm_476 = torch.ops.aten.mm.default(view_1523, permute_927); view_1523 = permute_927 = None + view_1524 = torch.ops.aten.view.default(mm_476, [2, 8192, 4096]); mm_476 = None + add_252 = torch.ops.aten.add.Tensor(view_1522, view_1524); view_1522 = view_1524 = None + convert_element_type_2034 = torch.ops.prims.convert_element_type.default(mm_475, torch.float32); mm_475 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2034, 'avg', 256, '0'); convert_element_type_2034 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + view_1525 = torch.ops.aten.view.default(view_1520, [16384, 4096]); view_1520 = None + permute_929 = torch.ops.aten.permute.default(view_1525, [1, 0]) + mm_477 = torch.ops.aten.mm.default(permute_929, view_479); permute_929 = view_479 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16); primals_131 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 256, '0'); convert_element_type_466 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + permute_931 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_478 = torch.ops.aten.mm.default(view_1525, permute_931); view_1525 = permute_931 = None + view_1526 = torch.ops.aten.view.default(mm_478, [2, 8192, 4096]); mm_478 = None + add_253 = torch.ops.aten.add.Tensor(add_252, view_1526); add_252 = view_1526 = None + convert_element_type_2039 = torch.ops.prims.convert_element_type.default(mm_477, torch.float32); mm_477 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2039, 'avg', 256, '0'); convert_element_type_2039 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + convert_element_type_2040 = torch.ops.prims.convert_element_type.default(add_253, torch.float32); add_253 = None + convert_element_type_2042 = torch.ops.prims.convert_element_type.default(wait_tensor_127, torch.float32); wait_tensor_127 = None + mul_618 = torch.ops.aten.mul.Tensor(convert_element_type_2040, convert_element_type_2042); convert_element_type_2042 = None + mul_620 = torch.ops.aten.mul.Tensor(mul_112, mul_618) + sum_109 = torch.ops.aten.sum.dim_IntList(mul_620, [2], True); mul_620 = None + div_36 = torch.ops.aten.div.Tensor(mul_112, 4096) + mul_621 = torch.ops.aten.mul.Tensor(div_36, sum_109); div_36 = sum_109 = None + sub_54 = torch.ops.aten.sub.Tensor(mul_618, mul_621); mul_618 = mul_621 = None + mul_622 = torch.ops.aten.mul.Tensor(sub_54, rsqrt_28); sub_54 = rsqrt_28 = None + mul_623 = torch.ops.aten.mul.Tensor(convert_element_type_2040, mul_112); convert_element_type_2040 = mul_112 = None + sum_110 = torch.ops.aten.sum.dim_IntList(mul_623, [0, 1]); mul_623 = None + convert_element_type_2043 = torch.ops.prims.convert_element_type.default(mul_622, torch.bfloat16); mul_622 = None + add_254 = torch.ops.aten.add.Tensor(add_251, convert_element_type_2043); add_251 = convert_element_type_2043 = None + convert_element_type_default_29 = torch.ops.prims.convert_element_type.default(sum_110, torch.float32); sum_110 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_29, 'avg', 256, '0'); convert_element_type_default_29 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_1527 = torch.ops.aten.view.default(add_254, [16384, 4096]) + permute_933 = torch.ops.aten.permute.default(view_1527, [1, 0]) + permute_149 = torch.ops.aten.permute.default(getitem_117, [0, 2, 1, 3]) + view_463 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 256, '0'); convert_element_type_446 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + view_465 = torch.ops.aten.view.default(view_463, [16384, 4096]); view_463 = None + mm_94 = torch.ops.aten.mm.default(view_465, permute_150) + view_466 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + add_53 = torch.ops.aten.add.Tensor(add_51, view_466); view_466 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 256, '0'); convert_element_type_449 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32); add_53 = None + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_123) + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + view_469 = torch.ops.aten.view.default(convert_element_type_451, [16384, 4096]); convert_element_type_451 = None + view_470 = torch.ops.aten.view.default(mm_95, [2, 8192, 14336]); mm_95 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_470, torch.float32); view_470 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16); primals_128 = None + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 256, '0'); convert_element_type_457 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_125, [1, 0]); wait_tensor_125 = None + mm_96 = torch.ops.aten.mm.default(view_469, permute_152) + view_473 = torch.ops.aten.view.default(mm_96, [2, 8192, 14336]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_473) + view_475 = torch.ops.aten.view.default(mul_111, [16384, 14336]); mul_111 = None + mm_479 = torch.ops.aten.mm.default(permute_933, view_475); permute_933 = view_475 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16); primals_129 = None + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 256, '0'); convert_element_type_460 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_126, [1, 0]); wait_tensor_126 = None + permute_935 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_480 = torch.ops.aten.mm.default(view_1527, permute_935); view_1527 = permute_935 = None + view_1528 = torch.ops.aten.view.default(mm_480, [2, 8192, 14336]); mm_480 = None + convert_element_type_2050 = torch.ops.prims.convert_element_type.default(mm_479, torch.float32); mm_479 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2050, 'avg', 256, '0'); convert_element_type_2050 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 = None + mul_624 = torch.ops.aten.mul.Tensor(view_1528, convert_element_type_456); convert_element_type_456 = None + mul_625 = torch.ops.aten.mul.Tensor(view_1528, view_473); view_1528 = view_473 = None + view_1529 = torch.ops.aten.view.default(mul_624, [16384, 14336]); mul_624 = None + permute_937 = torch.ops.aten.permute.default(view_1529, [1, 0]) + mm_481 = torch.ops.aten.mm.default(permute_937, view_469); permute_937 = None + permute_939 = torch.ops.aten.permute.default(permute_152, [1, 0]); permute_152 = None + mm_482 = torch.ops.aten.mm.default(view_1529, permute_939); view_1529 = permute_939 = None + view_1530 = torch.ops.aten.view.default(mm_482, [2, 8192, 4096]); mm_482 = None + convert_element_type_2055 = torch.ops.prims.convert_element_type.default(mm_481, torch.float32); mm_481 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2055, 'avg', 256, '0'); convert_element_type_2055 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + convert_element_type_2056 = torch.ops.prims.convert_element_type.default(mul_625, torch.float32); mul_625 = None + neg_18 = torch.ops.aten.neg.default(convert_element_type_455) + exp_18 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_255 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + reciprocal_18 = torch.ops.aten.reciprocal.default(add_255); add_255 = None + mul_626 = torch.ops.aten.mul.Tensor(reciprocal_18, 1); reciprocal_18 = None + mul_627 = torch.ops.aten.mul.Tensor(convert_element_type_2056, mul_626); convert_element_type_2056 = None + sub_55 = torch.ops.aten.sub.Tensor(1, mul_626); mul_626 = None + mul_628 = torch.ops.aten.mul.Tensor(convert_element_type_455, sub_55); convert_element_type_455 = sub_55 = None + add_256 = torch.ops.aten.add.Tensor(mul_628, 1); mul_628 = None + mul_629 = torch.ops.aten.mul.Tensor(mul_627, add_256); mul_627 = add_256 = None + convert_element_type_2058 = torch.ops.prims.convert_element_type.default(mul_629, torch.bfloat16); mul_629 = None + view_1531 = torch.ops.aten.view.default(convert_element_type_2058, [16384, 14336]); convert_element_type_2058 = None + permute_941 = torch.ops.aten.permute.default(view_1531, [1, 0]) + mm_483 = torch.ops.aten.mm.default(permute_941, view_469); permute_941 = view_469 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 256, '0'); convert_element_type_452 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + permute_943 = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None + mm_484 = torch.ops.aten.mm.default(view_1531, permute_943); view_1531 = permute_943 = None + view_1532 = torch.ops.aten.view.default(mm_484, [2, 8192, 4096]); mm_484 = None + add_257 = torch.ops.aten.add.Tensor(view_1530, view_1532); view_1530 = view_1532 = None + convert_element_type_2063 = torch.ops.prims.convert_element_type.default(mm_483, torch.float32); mm_483 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2063, 'avg', 256, '0'); convert_element_type_2063 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + convert_element_type_2064 = torch.ops.prims.convert_element_type.default(add_257, torch.float32); add_257 = None + convert_element_type_2066 = torch.ops.prims.convert_element_type.default(wait_tensor_123, torch.float32); wait_tensor_123 = None + mul_630 = torch.ops.aten.mul.Tensor(convert_element_type_2064, convert_element_type_2066); convert_element_type_2066 = None + mul_632 = torch.ops.aten.mul.Tensor(mul_108, mul_630) + sum_111 = torch.ops.aten.sum.dim_IntList(mul_632, [2], True); mul_632 = None + div_37 = torch.ops.aten.div.Tensor(mul_108, 4096) + mul_633 = torch.ops.aten.mul.Tensor(div_37, sum_111); div_37 = sum_111 = None + sub_56 = torch.ops.aten.sub.Tensor(mul_630, mul_633); mul_630 = mul_633 = None + mul_634 = torch.ops.aten.mul.Tensor(sub_56, rsqrt_27); sub_56 = rsqrt_27 = None + mul_635 = torch.ops.aten.mul.Tensor(convert_element_type_2064, mul_108); convert_element_type_2064 = mul_108 = None + sum_112 = torch.ops.aten.sum.dim_IntList(mul_635, [0, 1]); mul_635 = None + convert_element_type_2067 = torch.ops.prims.convert_element_type.default(mul_634, torch.bfloat16); mul_634 = None + add_258 = torch.ops.aten.add.Tensor(add_254, convert_element_type_2067); add_254 = convert_element_type_2067 = None + convert_element_type_default_28 = torch.ops.prims.convert_element_type.default(sum_112, torch.float32); sum_112 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_28, 'avg', 256, '0'); convert_element_type_default_28 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + view_1533 = torch.ops.aten.view.default(add_258, [16384, 4096]) + permute_945 = torch.ops.aten.permute.default(view_1533, [1, 0]) + mm_485 = torch.ops.aten.mm.default(permute_945, view_465); permute_945 = view_465 = None + permute_947 = torch.ops.aten.permute.default(permute_150, [1, 0]); permute_150 = None + mm_486 = torch.ops.aten.mm.default(view_1533, permute_947); view_1533 = permute_947 = None + view_1534 = torch.ops.aten.view.default(mm_486, [2, 8192, 4096]); mm_486 = None + convert_element_type_2074 = torch.ops.prims.convert_element_type.default(mm_485, torch.float32); mm_485 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2074, 'avg', 256, '0'); convert_element_type_2074 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + view_1535 = torch.ops.aten.view.default(view_1534, [2, 8192, 32, 128]); view_1534 = None + permute_949 = torch.ops.aten.permute.default(view_1535, [0, 2, 1, 3]); view_1535 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 256, '0'); convert_element_type_430 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32); add_51 = None + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_118) + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + view_445 = torch.ops.aten.view.default(convert_element_type_432, [16384, 4096]); convert_element_type_432 = None + view_446 = torch.ops.aten.view.default(mm_91, [2, 8192, 4096]); mm_91 = None + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 256, '0'); convert_element_type_436 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + mm_92 = torch.ops.aten.mm.default(view_445, permute_144) + view_449 = torch.ops.aten.view.default(mm_92, [2, 8192, 1024]); mm_92 = None + view_452 = torch.ops.aten.view.default(mm_93, [2, 8192, 1024]); mm_93 = None + view_453 = torch.ops.aten.view.default(view_446, [2, 8192, -1, 128]); view_446 = None + view_454 = torch.ops.aten.view.default(view_449, [2, 8192, -1, 128]); view_449 = None + view_455 = torch.ops.aten.view.default(view_452, [2, 8192, -1, 128]); view_452 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_453, torch.float32); view_453 = None + view_456 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 32, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_456); view_456 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_454, torch.float32); view_454 = None + view_457 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 8, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_457); view_457 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_16); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_459 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 32, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_16); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_460 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 8, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_459, torch.bfloat16); view_459 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_460, torch.bfloat16); view_460 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 8, 4, 128]); unsqueeze_26 = None + clone_26 = torch.ops.aten.clone.default(expand_26, memory_format = torch.contiguous_format); expand_26 = None + view_461 = torch.ops.aten.view.default(clone_26, [2, 8192, 32, 128]); clone_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_455, 3); view_455 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 8, 4, 128]); unsqueeze_27 = None + clone_27 = torch.ops.aten.clone.default(expand_27, memory_format = torch.contiguous_format); expand_27 = None + view_462 = torch.ops.aten.view.default(clone_27, [2, 8192, 32, 128]); clone_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_461, [0, 2, 1, 3]); view_461 = None + permute_148 = torch.ops.aten.permute.default(view_462, [0, 2, 1, 3]); view_462 = None + _scaled_dot_product_cudnn_attention_backward_18 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_949, permute_146, permute_147, permute_148, getitem_117, getitem_118, getitem_123, getitem_124, None, None, None, 8192, 8192, 0.0, True); permute_949 = permute_146 = permute_147 = permute_148 = getitem_117 = getitem_118 = getitem_123 = getitem_124 = None + getitem_342 = _scaled_dot_product_cudnn_attention_backward_18[0] + getitem_343 = _scaled_dot_product_cudnn_attention_backward_18[1] + getitem_344 = _scaled_dot_product_cudnn_attention_backward_18[2]; _scaled_dot_product_cudnn_attention_backward_18 = None + permute_950 = torch.ops.aten.permute.default(getitem_344, [0, 2, 1, 3]); getitem_344 = None + permute_951 = torch.ops.aten.permute.default(getitem_343, [0, 2, 1, 3]); getitem_343 = None + permute_952 = torch.ops.aten.permute.default(getitem_342, [0, 2, 1, 3]); getitem_342 = None + view_1536 = torch.ops.aten.view.default(permute_950, [2, 8192, 8, 4, 128]); permute_950 = None + sum_113 = torch.ops.aten.sum.dim_IntList(view_1536, [3], True); view_1536 = None + squeeze_36 = torch.ops.aten.squeeze.dim(sum_113, 3); sum_113 = None + view_1537 = torch.ops.aten.view.default(permute_951, [2, 8192, 8, 4, 128]); permute_951 = None + sum_114 = torch.ops.aten.sum.dim_IntList(view_1537, [3], True); view_1537 = None + squeeze_37 = torch.ops.aten.squeeze.dim(sum_114, 3); sum_114 = None + convert_element_type_2075 = torch.ops.prims.convert_element_type.default(squeeze_37, torch.float32); squeeze_37 = None + convert_element_type_2076 = torch.ops.prims.convert_element_type.default(permute_952, torch.float32); permute_952 = None + view_1538 = torch.ops.aten.view.default(convert_element_type_2075, [2, 8192, 8, 64, 2]); convert_element_type_2075 = None + view_as_complex_100 = torch.ops.aten.view_as_complex.default(view_1538); view_1538 = None + mul_636 = torch.ops.aten.mul.Tensor(view_as_complex_100, _conj); view_as_complex_100 = None + view_1539 = torch.ops.aten.view.default(convert_element_type_2076, [2, 8192, 32, 64, 2]); convert_element_type_2076 = None + view_as_complex_101 = torch.ops.aten.view_as_complex.default(view_1539); view_1539 = None + mul_637 = torch.ops.aten.mul.Tensor(view_as_complex_101, _conj); view_as_complex_101 = None + view_as_real_100 = torch.ops.aten.view_as_real.default(mul_636); mul_636 = None + view_1540 = torch.ops.aten.view.default(view_as_real_100, [2, 8192, 8, 128]); view_as_real_100 = None + convert_element_type_2077 = torch.ops.prims.convert_element_type.default(view_1540, torch.bfloat16); view_1540 = None + view_as_real_101 = torch.ops.aten.view_as_real.default(mul_637); mul_637 = None + view_1541 = torch.ops.aten.view.default(view_as_real_101, [2, 8192, 32, 128]); view_as_real_101 = None + convert_element_type_2078 = torch.ops.prims.convert_element_type.default(view_1541, torch.bfloat16); view_1541 = None + view_1542 = torch.ops.aten.view.default(squeeze_36, [2, 8192, 1024]); squeeze_36 = None + view_1543 = torch.ops.aten.view.default(convert_element_type_2077, [2, 8192, 1024]); convert_element_type_2077 = None + view_1544 = torch.ops.aten.view.default(convert_element_type_2078, [2, 8192, 4096]); convert_element_type_2078 = None + view_1545 = torch.ops.aten.view.default(view_1542, [16384, 1024]); view_1542 = None + permute_953 = torch.ops.aten.permute.default(view_1545, [1, 0]) + mm_487 = torch.ops.aten.mm.default(permute_953, view_445); permute_953 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 256, '0'); convert_element_type_439 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_955 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_488 = torch.ops.aten.mm.default(view_1545, permute_955); view_1545 = permute_955 = None + view_1546 = torch.ops.aten.view.default(mm_488, [2, 8192, 4096]); mm_488 = None + convert_element_type_2083 = torch.ops.prims.convert_element_type.default(mm_487, torch.float32); mm_487 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2083, 'avg', 256, '0'); convert_element_type_2083 = None + wait_tensor_460 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + view_1547 = torch.ops.aten.view.default(view_1543, [16384, 1024]); view_1543 = None + permute_957 = torch.ops.aten.permute.default(view_1547, [1, 0]) + mm_489 = torch.ops.aten.mm.default(permute_957, view_445); permute_957 = None + permute_959 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_490 = torch.ops.aten.mm.default(view_1547, permute_959); view_1547 = permute_959 = None + view_1548 = torch.ops.aten.view.default(mm_490, [2, 8192, 4096]); mm_490 = None + add_259 = torch.ops.aten.add.Tensor(view_1546, view_1548); view_1546 = view_1548 = None + convert_element_type_2088 = torch.ops.prims.convert_element_type.default(mm_489, torch.float32); mm_489 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2088, 'avg', 256, '0'); convert_element_type_2088 = None + wait_tensor_461 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + view_1549 = torch.ops.aten.view.default(view_1544, [16384, 4096]); view_1544 = None + permute_961 = torch.ops.aten.permute.default(view_1549, [1, 0]) + mm_491 = torch.ops.aten.mm.default(permute_961, view_445); permute_961 = view_445 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 256, '0'); convert_element_type_433 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_119, [1, 0]); wait_tensor_119 = None + permute_963 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_492 = torch.ops.aten.mm.default(view_1549, permute_963); view_1549 = permute_963 = None + view_1550 = torch.ops.aten.view.default(mm_492, [2, 8192, 4096]); mm_492 = None + add_260 = torch.ops.aten.add.Tensor(add_259, view_1550); add_259 = view_1550 = None + convert_element_type_2093 = torch.ops.prims.convert_element_type.default(mm_491, torch.float32); mm_491 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2093, 'avg', 256, '0'); convert_element_type_2093 = None + wait_tensor_462 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + convert_element_type_2094 = torch.ops.prims.convert_element_type.default(add_260, torch.float32); add_260 = None + convert_element_type_2096 = torch.ops.prims.convert_element_type.default(wait_tensor_118, torch.float32); wait_tensor_118 = None + mul_638 = torch.ops.aten.mul.Tensor(convert_element_type_2094, convert_element_type_2096); convert_element_type_2096 = None + mul_640 = torch.ops.aten.mul.Tensor(mul_104, mul_638) + sum_115 = torch.ops.aten.sum.dim_IntList(mul_640, [2], True); mul_640 = None + div_38 = torch.ops.aten.div.Tensor(mul_104, 4096) + mul_641 = torch.ops.aten.mul.Tensor(div_38, sum_115); div_38 = sum_115 = None + sub_57 = torch.ops.aten.sub.Tensor(mul_638, mul_641); mul_638 = mul_641 = None + mul_642 = torch.ops.aten.mul.Tensor(sub_57, rsqrt_26); sub_57 = rsqrt_26 = None + mul_643 = torch.ops.aten.mul.Tensor(convert_element_type_2094, mul_104); convert_element_type_2094 = mul_104 = None + sum_116 = torch.ops.aten.sum.dim_IntList(mul_643, [0, 1]); mul_643 = None + convert_element_type_2097 = torch.ops.prims.convert_element_type.default(mul_642, torch.bfloat16); mul_642 = None + add_261 = torch.ops.aten.add.Tensor(add_258, convert_element_type_2097); add_258 = convert_element_type_2097 = None + convert_element_type_default_27 = torch.ops.prims.convert_element_type.default(sum_116, torch.float32); sum_116 = None + reduce_scatter_tensor_172 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_27, 'avg', 256, '0'); convert_element_type_default_27 = None + wait_tensor_463 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + view_1551 = torch.ops.aten.view.default(add_261, [16384, 4096]) + permute_965 = torch.ops.aten.permute.default(view_1551, [1, 0]) + permute_138 = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]) + view_429 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 256, '0'); convert_element_type_413 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + view_431 = torch.ops.aten.view.default(view_429, [16384, 4096]); view_429 = None + mm_87 = torch.ops.aten.mm.default(view_431, permute_139) + view_432 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + add_49 = torch.ops.aten.add.Tensor(add_47, view_432); view_432 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 256, '0'); convert_element_type_416 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32); add_49 = None + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_114) + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + view_435 = torch.ops.aten.view.default(convert_element_type_418, [16384, 4096]); convert_element_type_418 = None + view_436 = torch.ops.aten.view.default(mm_88, [2, 8192, 14336]); mm_88 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_436, torch.float32); view_436 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 256, '0'); convert_element_type_424 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_89 = torch.ops.aten.mm.default(view_435, permute_141) + view_439 = torch.ops.aten.view.default(mm_89, [2, 8192, 14336]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_439) + view_441 = torch.ops.aten.view.default(mul_103, [16384, 14336]); mul_103 = None + mm_493 = torch.ops.aten.mm.default(permute_965, view_441); permute_965 = view_441 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 256, '0'); convert_element_type_427 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_967 = torch.ops.aten.permute.default(permute_142, [1, 0]); permute_142 = None + mm_494 = torch.ops.aten.mm.default(view_1551, permute_967); view_1551 = permute_967 = None + view_1552 = torch.ops.aten.view.default(mm_494, [2, 8192, 14336]); mm_494 = None + convert_element_type_2104 = torch.ops.prims.convert_element_type.default(mm_493, torch.float32); mm_493 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2104, 'avg', 256, '0'); convert_element_type_2104 = None + wait_tensor_464 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + mul_644 = torch.ops.aten.mul.Tensor(view_1552, convert_element_type_423); convert_element_type_423 = None + mul_645 = torch.ops.aten.mul.Tensor(view_1552, view_439); view_1552 = view_439 = None + view_1553 = torch.ops.aten.view.default(mul_644, [16384, 14336]); mul_644 = None + permute_969 = torch.ops.aten.permute.default(view_1553, [1, 0]) + mm_495 = torch.ops.aten.mm.default(permute_969, view_435); permute_969 = None + permute_971 = torch.ops.aten.permute.default(permute_141, [1, 0]); permute_141 = None + mm_496 = torch.ops.aten.mm.default(view_1553, permute_971); view_1553 = permute_971 = None + view_1554 = torch.ops.aten.view.default(mm_496, [2, 8192, 4096]); mm_496 = None + convert_element_type_2109 = torch.ops.prims.convert_element_type.default(mm_495, torch.float32); mm_495 = None + reduce_scatter_tensor_174 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2109, 'avg', 256, '0'); convert_element_type_2109 = None + wait_tensor_465 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + convert_element_type_2110 = torch.ops.prims.convert_element_type.default(mul_645, torch.float32); mul_645 = None + neg_19 = torch.ops.aten.neg.default(convert_element_type_422) + exp_19 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_262 = torch.ops.aten.add.Tensor(exp_19, 1); exp_19 = None + reciprocal_19 = torch.ops.aten.reciprocal.default(add_262); add_262 = None + mul_646 = torch.ops.aten.mul.Tensor(reciprocal_19, 1); reciprocal_19 = None + mul_647 = torch.ops.aten.mul.Tensor(convert_element_type_2110, mul_646); convert_element_type_2110 = None + sub_58 = torch.ops.aten.sub.Tensor(1, mul_646); mul_646 = None + mul_648 = torch.ops.aten.mul.Tensor(convert_element_type_422, sub_58); convert_element_type_422 = sub_58 = None + add_263 = torch.ops.aten.add.Tensor(mul_648, 1); mul_648 = None + mul_649 = torch.ops.aten.mul.Tensor(mul_647, add_263); mul_647 = add_263 = None + convert_element_type_2112 = torch.ops.prims.convert_element_type.default(mul_649, torch.bfloat16); mul_649 = None + view_1555 = torch.ops.aten.view.default(convert_element_type_2112, [16384, 14336]); convert_element_type_2112 = None + permute_973 = torch.ops.aten.permute.default(view_1555, [1, 0]) + mm_497 = torch.ops.aten.mm.default(permute_973, view_435); permute_973 = view_435 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 256, '0'); convert_element_type_419 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_975 = torch.ops.aten.permute.default(permute_140, [1, 0]); permute_140 = None + mm_498 = torch.ops.aten.mm.default(view_1555, permute_975); view_1555 = permute_975 = None + view_1556 = torch.ops.aten.view.default(mm_498, [2, 8192, 4096]); mm_498 = None + add_264 = torch.ops.aten.add.Tensor(view_1554, view_1556); view_1554 = view_1556 = None + convert_element_type_2117 = torch.ops.prims.convert_element_type.default(mm_497, torch.float32); mm_497 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2117, 'avg', 256, '0'); convert_element_type_2117 = None + wait_tensor_466 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + convert_element_type_2118 = torch.ops.prims.convert_element_type.default(add_264, torch.float32); add_264 = None + convert_element_type_2120 = torch.ops.prims.convert_element_type.default(wait_tensor_114, torch.float32); wait_tensor_114 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_2118, convert_element_type_2120); convert_element_type_2120 = None + mul_652 = torch.ops.aten.mul.Tensor(mul_100, mul_650) + sum_117 = torch.ops.aten.sum.dim_IntList(mul_652, [2], True); mul_652 = None + div_39 = torch.ops.aten.div.Tensor(mul_100, 4096) + mul_653 = torch.ops.aten.mul.Tensor(div_39, sum_117); div_39 = sum_117 = None + sub_59 = torch.ops.aten.sub.Tensor(mul_650, mul_653); mul_650 = mul_653 = None + mul_654 = torch.ops.aten.mul.Tensor(sub_59, rsqrt_25); sub_59 = rsqrt_25 = None + mul_655 = torch.ops.aten.mul.Tensor(convert_element_type_2118, mul_100); convert_element_type_2118 = mul_100 = None + sum_118 = torch.ops.aten.sum.dim_IntList(mul_655, [0, 1]); mul_655 = None + convert_element_type_2121 = torch.ops.prims.convert_element_type.default(mul_654, torch.bfloat16); mul_654 = None + add_265 = torch.ops.aten.add.Tensor(add_261, convert_element_type_2121); add_261 = convert_element_type_2121 = None + convert_element_type_default_26 = torch.ops.prims.convert_element_type.default(sum_118, torch.float32); sum_118 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_26, 'avg', 256, '0'); convert_element_type_default_26 = None + wait_tensor_467 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + view_1557 = torch.ops.aten.view.default(add_265, [16384, 4096]) + permute_977 = torch.ops.aten.permute.default(view_1557, [1, 0]) + mm_499 = torch.ops.aten.mm.default(permute_977, view_431); permute_977 = view_431 = None + permute_979 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_500 = torch.ops.aten.mm.default(view_1557, permute_979); view_1557 = permute_979 = None + view_1558 = torch.ops.aten.view.default(mm_500, [2, 8192, 4096]); mm_500 = None + convert_element_type_2128 = torch.ops.prims.convert_element_type.default(mm_499, torch.float32); mm_499 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2128, 'avg', 256, '0'); convert_element_type_2128 = None + wait_tensor_468 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + view_1559 = torch.ops.aten.view.default(view_1558, [2, 8192, 32, 128]); view_1558 = None + permute_981 = torch.ops.aten.permute.default(view_1559, [0, 2, 1, 3]); view_1559 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16); primals_112 = None + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 256, '0'); convert_element_type_397 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32); add_47 = None + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_109) + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + view_411 = torch.ops.aten.view.default(convert_element_type_399, [16384, 4096]); convert_element_type_399 = None + view_412 = torch.ops.aten.view.default(mm_84, [2, 8192, 4096]); mm_84 = None + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16); primals_114 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 256, '0'); convert_element_type_403 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + mm_85 = torch.ops.aten.mm.default(view_411, permute_133) + view_415 = torch.ops.aten.view.default(mm_85, [2, 8192, 1024]); mm_85 = None + view_418 = torch.ops.aten.view.default(mm_86, [2, 8192, 1024]); mm_86 = None + view_419 = torch.ops.aten.view.default(view_412, [2, 8192, -1, 128]); view_412 = None + view_420 = torch.ops.aten.view.default(view_415, [2, 8192, -1, 128]); view_415 = None + view_421 = torch.ops.aten.view.default(view_418, [2, 8192, -1, 128]); view_418 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_419, torch.float32); view_419 = None + view_422 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 32, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_422); view_422 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_420, torch.float32); view_420 = None + view_423 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 8, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_423); view_423 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_16); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_425 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 32, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_16); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_426 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 8, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_425, torch.bfloat16); view_425 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_426, torch.bfloat16); view_426 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 8, 4, 128]); unsqueeze_24 = None + clone_24 = torch.ops.aten.clone.default(expand_24, memory_format = torch.contiguous_format); expand_24 = None + view_427 = torch.ops.aten.view.default(clone_24, [2, 8192, 32, 128]); clone_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_421, 3); view_421 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 8, 4, 128]); unsqueeze_25 = None + clone_25 = torch.ops.aten.clone.default(expand_25, memory_format = torch.contiguous_format); expand_25 = None + view_428 = torch.ops.aten.view.default(clone_25, [2, 8192, 32, 128]); clone_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_427, [0, 2, 1, 3]); view_427 = None + permute_137 = torch.ops.aten.permute.default(view_428, [0, 2, 1, 3]); view_428 = None + _scaled_dot_product_cudnn_attention_backward_19 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_981, permute_135, permute_136, permute_137, getitem_108, getitem_109, getitem_114, getitem_115, None, None, None, 8192, 8192, 0.0, True); permute_981 = permute_135 = permute_136 = permute_137 = getitem_108 = getitem_109 = getitem_114 = getitem_115 = None + getitem_345 = _scaled_dot_product_cudnn_attention_backward_19[0] + getitem_346 = _scaled_dot_product_cudnn_attention_backward_19[1] + getitem_347 = _scaled_dot_product_cudnn_attention_backward_19[2]; _scaled_dot_product_cudnn_attention_backward_19 = None + permute_982 = torch.ops.aten.permute.default(getitem_347, [0, 2, 1, 3]); getitem_347 = None + permute_983 = torch.ops.aten.permute.default(getitem_346, [0, 2, 1, 3]); getitem_346 = None + permute_984 = torch.ops.aten.permute.default(getitem_345, [0, 2, 1, 3]); getitem_345 = None + view_1560 = torch.ops.aten.view.default(permute_982, [2, 8192, 8, 4, 128]); permute_982 = None + sum_119 = torch.ops.aten.sum.dim_IntList(view_1560, [3], True); view_1560 = None + squeeze_38 = torch.ops.aten.squeeze.dim(sum_119, 3); sum_119 = None + view_1561 = torch.ops.aten.view.default(permute_983, [2, 8192, 8, 4, 128]); permute_983 = None + sum_120 = torch.ops.aten.sum.dim_IntList(view_1561, [3], True); view_1561 = None + squeeze_39 = torch.ops.aten.squeeze.dim(sum_120, 3); sum_120 = None + convert_element_type_2129 = torch.ops.prims.convert_element_type.default(squeeze_39, torch.float32); squeeze_39 = None + convert_element_type_2130 = torch.ops.prims.convert_element_type.default(permute_984, torch.float32); permute_984 = None + view_1562 = torch.ops.aten.view.default(convert_element_type_2129, [2, 8192, 8, 64, 2]); convert_element_type_2129 = None + view_as_complex_102 = torch.ops.aten.view_as_complex.default(view_1562); view_1562 = None + mul_656 = torch.ops.aten.mul.Tensor(view_as_complex_102, _conj); view_as_complex_102 = None + view_1563 = torch.ops.aten.view.default(convert_element_type_2130, [2, 8192, 32, 64, 2]); convert_element_type_2130 = None + view_as_complex_103 = torch.ops.aten.view_as_complex.default(view_1563); view_1563 = None + mul_657 = torch.ops.aten.mul.Tensor(view_as_complex_103, _conj); view_as_complex_103 = None + view_as_real_102 = torch.ops.aten.view_as_real.default(mul_656); mul_656 = None + view_1564 = torch.ops.aten.view.default(view_as_real_102, [2, 8192, 8, 128]); view_as_real_102 = None + convert_element_type_2131 = torch.ops.prims.convert_element_type.default(view_1564, torch.bfloat16); view_1564 = None + view_as_real_103 = torch.ops.aten.view_as_real.default(mul_657); mul_657 = None + view_1565 = torch.ops.aten.view.default(view_as_real_103, [2, 8192, 32, 128]); view_as_real_103 = None + convert_element_type_2132 = torch.ops.prims.convert_element_type.default(view_1565, torch.bfloat16); view_1565 = None + view_1566 = torch.ops.aten.view.default(squeeze_38, [2, 8192, 1024]); squeeze_38 = None + view_1567 = torch.ops.aten.view.default(convert_element_type_2131, [2, 8192, 1024]); convert_element_type_2131 = None + view_1568 = torch.ops.aten.view.default(convert_element_type_2132, [2, 8192, 4096]); convert_element_type_2132 = None + view_1569 = torch.ops.aten.view.default(view_1566, [16384, 1024]); view_1566 = None + permute_985 = torch.ops.aten.permute.default(view_1569, [1, 0]) + mm_501 = torch.ops.aten.mm.default(permute_985, view_411); permute_985 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16); primals_115 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 256, '0'); convert_element_type_406 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_112, [1, 0]); wait_tensor_112 = None + permute_987 = torch.ops.aten.permute.default(permute_134, [1, 0]); permute_134 = None + mm_502 = torch.ops.aten.mm.default(view_1569, permute_987); view_1569 = permute_987 = None + view_1570 = torch.ops.aten.view.default(mm_502, [2, 8192, 4096]); mm_502 = None + convert_element_type_2137 = torch.ops.prims.convert_element_type.default(mm_501, torch.float32); mm_501 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2137, 'avg', 256, '0'); convert_element_type_2137 = None + wait_tensor_469 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + view_1571 = torch.ops.aten.view.default(view_1567, [16384, 1024]); view_1567 = None + permute_989 = torch.ops.aten.permute.default(view_1571, [1, 0]) + mm_503 = torch.ops.aten.mm.default(permute_989, view_411); permute_989 = None + permute_991 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_504 = torch.ops.aten.mm.default(view_1571, permute_991); view_1571 = permute_991 = None + view_1572 = torch.ops.aten.view.default(mm_504, [2, 8192, 4096]); mm_504 = None + add_266 = torch.ops.aten.add.Tensor(view_1570, view_1572); view_1570 = view_1572 = None + convert_element_type_2142 = torch.ops.prims.convert_element_type.default(mm_503, torch.float32); mm_503 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2142, 'avg', 256, '0'); convert_element_type_2142 = None + wait_tensor_470 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + view_1573 = torch.ops.aten.view.default(view_1568, [16384, 4096]); view_1568 = None + permute_993 = torch.ops.aten.permute.default(view_1573, [1, 0]) + mm_505 = torch.ops.aten.mm.default(permute_993, view_411); permute_993 = view_411 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16); primals_113 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 256, '0'); convert_element_type_400 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + permute_995 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_506 = torch.ops.aten.mm.default(view_1573, permute_995); view_1573 = permute_995 = None + view_1574 = torch.ops.aten.view.default(mm_506, [2, 8192, 4096]); mm_506 = None + add_267 = torch.ops.aten.add.Tensor(add_266, view_1574); add_266 = view_1574 = None + convert_element_type_2147 = torch.ops.prims.convert_element_type.default(mm_505, torch.float32); mm_505 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2147, 'avg', 256, '0'); convert_element_type_2147 = None + wait_tensor_471 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + convert_element_type_2148 = torch.ops.prims.convert_element_type.default(add_267, torch.float32); add_267 = None + convert_element_type_2150 = torch.ops.prims.convert_element_type.default(wait_tensor_109, torch.float32); wait_tensor_109 = None + mul_658 = torch.ops.aten.mul.Tensor(convert_element_type_2148, convert_element_type_2150); convert_element_type_2150 = None + mul_660 = torch.ops.aten.mul.Tensor(mul_96, mul_658) + sum_121 = torch.ops.aten.sum.dim_IntList(mul_660, [2], True); mul_660 = None + div_40 = torch.ops.aten.div.Tensor(mul_96, 4096) + mul_661 = torch.ops.aten.mul.Tensor(div_40, sum_121); div_40 = sum_121 = None + sub_60 = torch.ops.aten.sub.Tensor(mul_658, mul_661); mul_658 = mul_661 = None + mul_662 = torch.ops.aten.mul.Tensor(sub_60, rsqrt_24); sub_60 = rsqrt_24 = None + mul_663 = torch.ops.aten.mul.Tensor(convert_element_type_2148, mul_96); convert_element_type_2148 = mul_96 = None + sum_122 = torch.ops.aten.sum.dim_IntList(mul_663, [0, 1]); mul_663 = None + convert_element_type_2151 = torch.ops.prims.convert_element_type.default(mul_662, torch.bfloat16); mul_662 = None + add_268 = torch.ops.aten.add.Tensor(add_265, convert_element_type_2151); add_265 = convert_element_type_2151 = None + convert_element_type_default_25 = torch.ops.prims.convert_element_type.default(sum_122, torch.float32); sum_122 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_25, 'avg', 256, '0'); convert_element_type_default_25 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + view_1575 = torch.ops.aten.view.default(add_268, [16384, 4096]) + permute_997 = torch.ops.aten.permute.default(view_1575, [1, 0]) + permute_127 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_395 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 256, '0'); convert_element_type_380 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_397 = torch.ops.aten.view.default(view_395, [16384, 4096]); view_395 = None + mm_80 = torch.ops.aten.mm.default(view_397, permute_128) + view_398 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + add_45 = torch.ops.aten.add.Tensor(add_43, view_398); view_398 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 256, '0'); convert_element_type_383 = None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32); add_45 = None + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_105) + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + view_401 = torch.ops.aten.view.default(convert_element_type_385, [16384, 4096]); convert_element_type_385 = None + view_402 = torch.ops.aten.view.default(mm_81, [2, 8192, 14336]); mm_81 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_402, torch.float32); view_402 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16); primals_110 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 256, '0'); convert_element_type_391 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_107, [1, 0]); wait_tensor_107 = None + mm_82 = torch.ops.aten.mm.default(view_401, permute_130) + view_405 = torch.ops.aten.view.default(mm_82, [2, 8192, 14336]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_405) + view_407 = torch.ops.aten.view.default(mul_95, [16384, 14336]); mul_95 = None + mm_507 = torch.ops.aten.mm.default(permute_997, view_407); permute_997 = view_407 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 256, '0'); convert_element_type_394 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + permute_999 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_508 = torch.ops.aten.mm.default(view_1575, permute_999); view_1575 = permute_999 = None + view_1576 = torch.ops.aten.view.default(mm_508, [2, 8192, 14336]); mm_508 = None + convert_element_type_2158 = torch.ops.prims.convert_element_type.default(mm_507, torch.float32); mm_507 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2158, 'avg', 256, '0'); convert_element_type_2158 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + mul_664 = torch.ops.aten.mul.Tensor(view_1576, convert_element_type_390); convert_element_type_390 = None + mul_665 = torch.ops.aten.mul.Tensor(view_1576, view_405); view_1576 = view_405 = None + view_1577 = torch.ops.aten.view.default(mul_664, [16384, 14336]); mul_664 = None + permute_1001 = torch.ops.aten.permute.default(view_1577, [1, 0]) + mm_509 = torch.ops.aten.mm.default(permute_1001, view_401); permute_1001 = None + permute_1003 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_510 = torch.ops.aten.mm.default(view_1577, permute_1003); view_1577 = permute_1003 = None + view_1578 = torch.ops.aten.view.default(mm_510, [2, 8192, 4096]); mm_510 = None + convert_element_type_2163 = torch.ops.prims.convert_element_type.default(mm_509, torch.float32); mm_509 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2163, 'avg', 256, '0'); convert_element_type_2163 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + convert_element_type_2164 = torch.ops.prims.convert_element_type.default(mul_665, torch.float32); mul_665 = None + neg_20 = torch.ops.aten.neg.default(convert_element_type_389) + exp_20 = torch.ops.aten.exp.default(neg_20); neg_20 = None + add_269 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + reciprocal_20 = torch.ops.aten.reciprocal.default(add_269); add_269 = None + mul_666 = torch.ops.aten.mul.Tensor(reciprocal_20, 1); reciprocal_20 = None + mul_667 = torch.ops.aten.mul.Tensor(convert_element_type_2164, mul_666); convert_element_type_2164 = None + sub_61 = torch.ops.aten.sub.Tensor(1, mul_666); mul_666 = None + mul_668 = torch.ops.aten.mul.Tensor(convert_element_type_389, sub_61); convert_element_type_389 = sub_61 = None + add_270 = torch.ops.aten.add.Tensor(mul_668, 1); mul_668 = None + mul_669 = torch.ops.aten.mul.Tensor(mul_667, add_270); mul_667 = add_270 = None + convert_element_type_2166 = torch.ops.prims.convert_element_type.default(mul_669, torch.bfloat16); mul_669 = None + view_1579 = torch.ops.aten.view.default(convert_element_type_2166, [16384, 14336]); convert_element_type_2166 = None + permute_1005 = torch.ops.aten.permute.default(view_1579, [1, 0]) + mm_511 = torch.ops.aten.mm.default(permute_1005, view_401); permute_1005 = view_401 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 256, '0'); convert_element_type_386 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_106, [1, 0]); wait_tensor_106 = None + permute_1007 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_512 = torch.ops.aten.mm.default(view_1579, permute_1007); view_1579 = permute_1007 = None + view_1580 = torch.ops.aten.view.default(mm_512, [2, 8192, 4096]); mm_512 = None + add_271 = torch.ops.aten.add.Tensor(view_1578, view_1580); view_1578 = view_1580 = None + convert_element_type_2171 = torch.ops.prims.convert_element_type.default(mm_511, torch.float32); mm_511 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2171, 'avg', 256, '0'); convert_element_type_2171 = None + wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + convert_element_type_2172 = torch.ops.prims.convert_element_type.default(add_271, torch.float32); add_271 = None + convert_element_type_2174 = torch.ops.prims.convert_element_type.default(wait_tensor_105, torch.float32); wait_tensor_105 = None + mul_670 = torch.ops.aten.mul.Tensor(convert_element_type_2172, convert_element_type_2174); convert_element_type_2174 = None + mul_672 = torch.ops.aten.mul.Tensor(mul_92, mul_670) + sum_123 = torch.ops.aten.sum.dim_IntList(mul_672, [2], True); mul_672 = None + div_41 = torch.ops.aten.div.Tensor(mul_92, 4096) + mul_673 = torch.ops.aten.mul.Tensor(div_41, sum_123); div_41 = sum_123 = None + sub_62 = torch.ops.aten.sub.Tensor(mul_670, mul_673); mul_670 = mul_673 = None + mul_674 = torch.ops.aten.mul.Tensor(sub_62, rsqrt_23); sub_62 = rsqrt_23 = None + mul_675 = torch.ops.aten.mul.Tensor(convert_element_type_2172, mul_92); convert_element_type_2172 = mul_92 = None + sum_124 = torch.ops.aten.sum.dim_IntList(mul_675, [0, 1]); mul_675 = None + convert_element_type_2175 = torch.ops.prims.convert_element_type.default(mul_674, torch.bfloat16); mul_674 = None + add_272 = torch.ops.aten.add.Tensor(add_268, convert_element_type_2175); add_268 = convert_element_type_2175 = None + convert_element_type_default_24 = torch.ops.prims.convert_element_type.default(sum_124, torch.float32); sum_124 = None + reduce_scatter_tensor_185 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_24, 'avg', 256, '0'); convert_element_type_default_24 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + view_1581 = torch.ops.aten.view.default(add_272, [16384, 4096]) + permute_1009 = torch.ops.aten.permute.default(view_1581, [1, 0]) + mm_513 = torch.ops.aten.mm.default(permute_1009, view_397); permute_1009 = view_397 = None + permute_1011 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_514 = torch.ops.aten.mm.default(view_1581, permute_1011); view_1581 = permute_1011 = None + view_1582 = torch.ops.aten.view.default(mm_514, [2, 8192, 4096]); mm_514 = None + convert_element_type_2182 = torch.ops.prims.convert_element_type.default(mm_513, torch.float32); mm_513 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2182, 'avg', 256, '0'); convert_element_type_2182 = None + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + view_1583 = torch.ops.aten.view.default(view_1582, [2, 8192, 32, 128]); view_1582 = None + permute_1013 = torch.ops.aten.permute.default(view_1583, [0, 2, 1, 3]); view_1583 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 256, '0'); convert_element_type_364 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32); add_43 = None + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_100) + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + view_377 = torch.ops.aten.view.default(convert_element_type_366, [16384, 4096]); convert_element_type_366 = None + view_378 = torch.ops.aten.view.default(mm_77, [2, 8192, 4096]); mm_77 = None + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 256, '0'); convert_element_type_370 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + mm_78 = torch.ops.aten.mm.default(view_377, permute_122) + view_381 = torch.ops.aten.view.default(mm_78, [2, 8192, 1024]); mm_78 = None + view_384 = torch.ops.aten.view.default(mm_79, [2, 8192, 1024]); mm_79 = None + view_385 = torch.ops.aten.view.default(view_378, [2, 8192, -1, 128]); view_378 = None + view_386 = torch.ops.aten.view.default(view_381, [2, 8192, -1, 128]); view_381 = None + view_387 = torch.ops.aten.view.default(view_384, [2, 8192, -1, 128]); view_384 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_385, torch.float32); view_385 = None + view_388 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 32, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_388); view_388 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_386, torch.float32); view_386 = None + view_389 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 8, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_389); view_389 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_16); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_391 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 32, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_16); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_392 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 8, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_391, torch.bfloat16); view_391 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_392, torch.bfloat16); view_392 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 8, 4, 128]); unsqueeze_22 = None + clone_22 = torch.ops.aten.clone.default(expand_22, memory_format = torch.contiguous_format); expand_22 = None + view_393 = torch.ops.aten.view.default(clone_22, [2, 8192, 32, 128]); clone_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_387, 3); view_387 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 8, 4, 128]); unsqueeze_23 = None + clone_23 = torch.ops.aten.clone.default(expand_23, memory_format = torch.contiguous_format); expand_23 = None + view_394 = torch.ops.aten.view.default(clone_23, [2, 8192, 32, 128]); clone_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_393, [0, 2, 1, 3]); view_393 = None + permute_126 = torch.ops.aten.permute.default(view_394, [0, 2, 1, 3]); view_394 = None + _scaled_dot_product_cudnn_attention_backward_20 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1013, permute_124, permute_125, permute_126, getitem_99, getitem_100, getitem_105, getitem_106, None, None, None, 8192, 8192, 0.0, True); permute_1013 = permute_124 = permute_125 = permute_126 = getitem_99 = getitem_100 = getitem_105 = getitem_106 = None + getitem_348 = _scaled_dot_product_cudnn_attention_backward_20[0] + getitem_349 = _scaled_dot_product_cudnn_attention_backward_20[1] + getitem_350 = _scaled_dot_product_cudnn_attention_backward_20[2]; _scaled_dot_product_cudnn_attention_backward_20 = None + permute_1014 = torch.ops.aten.permute.default(getitem_350, [0, 2, 1, 3]); getitem_350 = None + permute_1015 = torch.ops.aten.permute.default(getitem_349, [0, 2, 1, 3]); getitem_349 = None + permute_1016 = torch.ops.aten.permute.default(getitem_348, [0, 2, 1, 3]); getitem_348 = None + view_1584 = torch.ops.aten.view.default(permute_1014, [2, 8192, 8, 4, 128]); permute_1014 = None + sum_125 = torch.ops.aten.sum.dim_IntList(view_1584, [3], True); view_1584 = None + squeeze_40 = torch.ops.aten.squeeze.dim(sum_125, 3); sum_125 = None + view_1585 = torch.ops.aten.view.default(permute_1015, [2, 8192, 8, 4, 128]); permute_1015 = None + sum_126 = torch.ops.aten.sum.dim_IntList(view_1585, [3], True); view_1585 = None + squeeze_41 = torch.ops.aten.squeeze.dim(sum_126, 3); sum_126 = None + convert_element_type_2183 = torch.ops.prims.convert_element_type.default(squeeze_41, torch.float32); squeeze_41 = None + convert_element_type_2184 = torch.ops.prims.convert_element_type.default(permute_1016, torch.float32); permute_1016 = None + view_1586 = torch.ops.aten.view.default(convert_element_type_2183, [2, 8192, 8, 64, 2]); convert_element_type_2183 = None + view_as_complex_104 = torch.ops.aten.view_as_complex.default(view_1586); view_1586 = None + mul_676 = torch.ops.aten.mul.Tensor(view_as_complex_104, _conj); view_as_complex_104 = None + view_1587 = torch.ops.aten.view.default(convert_element_type_2184, [2, 8192, 32, 64, 2]); convert_element_type_2184 = None + view_as_complex_105 = torch.ops.aten.view_as_complex.default(view_1587); view_1587 = None + mul_677 = torch.ops.aten.mul.Tensor(view_as_complex_105, _conj); view_as_complex_105 = None + view_as_real_104 = torch.ops.aten.view_as_real.default(mul_676); mul_676 = None + view_1588 = torch.ops.aten.view.default(view_as_real_104, [2, 8192, 8, 128]); view_as_real_104 = None + convert_element_type_2185 = torch.ops.prims.convert_element_type.default(view_1588, torch.bfloat16); view_1588 = None + view_as_real_105 = torch.ops.aten.view_as_real.default(mul_677); mul_677 = None + view_1589 = torch.ops.aten.view.default(view_as_real_105, [2, 8192, 32, 128]); view_as_real_105 = None + convert_element_type_2186 = torch.ops.prims.convert_element_type.default(view_1589, torch.bfloat16); view_1589 = None + view_1590 = torch.ops.aten.view.default(squeeze_40, [2, 8192, 1024]); squeeze_40 = None + view_1591 = torch.ops.aten.view.default(convert_element_type_2185, [2, 8192, 1024]); convert_element_type_2185 = None + view_1592 = torch.ops.aten.view.default(convert_element_type_2186, [2, 8192, 4096]); convert_element_type_2186 = None + view_1593 = torch.ops.aten.view.default(view_1590, [16384, 1024]); view_1590 = None + permute_1017 = torch.ops.aten.permute.default(view_1593, [1, 0]) + mm_515 = torch.ops.aten.mm.default(permute_1017, view_377); permute_1017 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 256, '0'); convert_element_type_373 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + permute_1019 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_516 = torch.ops.aten.mm.default(view_1593, permute_1019); view_1593 = permute_1019 = None + view_1594 = torch.ops.aten.view.default(mm_516, [2, 8192, 4096]); mm_516 = None + convert_element_type_2191 = torch.ops.prims.convert_element_type.default(mm_515, torch.float32); mm_515 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2191, 'avg', 256, '0'); convert_element_type_2191 = None + wait_tensor_478 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + view_1595 = torch.ops.aten.view.default(view_1591, [16384, 1024]); view_1591 = None + permute_1021 = torch.ops.aten.permute.default(view_1595, [1, 0]) + mm_517 = torch.ops.aten.mm.default(permute_1021, view_377); permute_1021 = None + permute_1023 = torch.ops.aten.permute.default(permute_122, [1, 0]); permute_122 = None + mm_518 = torch.ops.aten.mm.default(view_1595, permute_1023); view_1595 = permute_1023 = None + view_1596 = torch.ops.aten.view.default(mm_518, [2, 8192, 4096]); mm_518 = None + add_273 = torch.ops.aten.add.Tensor(view_1594, view_1596); view_1594 = view_1596 = None + convert_element_type_2196 = torch.ops.prims.convert_element_type.default(mm_517, torch.float32); mm_517 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2196, 'avg', 256, '0'); convert_element_type_2196 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + view_1597 = torch.ops.aten.view.default(view_1592, [16384, 4096]); view_1592 = None + permute_1025 = torch.ops.aten.permute.default(view_1597, [1, 0]) + mm_519 = torch.ops.aten.mm.default(permute_1025, view_377); permute_1025 = view_377 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 256, '0'); convert_element_type_367 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_101, [1, 0]); wait_tensor_101 = None + permute_1027 = torch.ops.aten.permute.default(permute_121, [1, 0]); permute_121 = None + mm_520 = torch.ops.aten.mm.default(view_1597, permute_1027); view_1597 = permute_1027 = None + view_1598 = torch.ops.aten.view.default(mm_520, [2, 8192, 4096]); mm_520 = None + add_274 = torch.ops.aten.add.Tensor(add_273, view_1598); add_273 = view_1598 = None + convert_element_type_2201 = torch.ops.prims.convert_element_type.default(mm_519, torch.float32); mm_519 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2201, 'avg', 256, '0'); convert_element_type_2201 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + convert_element_type_2202 = torch.ops.prims.convert_element_type.default(add_274, torch.float32); add_274 = None + convert_element_type_2204 = torch.ops.prims.convert_element_type.default(wait_tensor_100, torch.float32); wait_tensor_100 = None + mul_678 = torch.ops.aten.mul.Tensor(convert_element_type_2202, convert_element_type_2204); convert_element_type_2204 = None + mul_680 = torch.ops.aten.mul.Tensor(mul_88, mul_678) + sum_127 = torch.ops.aten.sum.dim_IntList(mul_680, [2], True); mul_680 = None + div_42 = torch.ops.aten.div.Tensor(mul_88, 4096) + mul_681 = torch.ops.aten.mul.Tensor(div_42, sum_127); div_42 = sum_127 = None + sub_63 = torch.ops.aten.sub.Tensor(mul_678, mul_681); mul_678 = mul_681 = None + mul_682 = torch.ops.aten.mul.Tensor(sub_63, rsqrt_22); sub_63 = rsqrt_22 = None + mul_683 = torch.ops.aten.mul.Tensor(convert_element_type_2202, mul_88); convert_element_type_2202 = mul_88 = None + sum_128 = torch.ops.aten.sum.dim_IntList(mul_683, [0, 1]); mul_683 = None + convert_element_type_2205 = torch.ops.prims.convert_element_type.default(mul_682, torch.bfloat16); mul_682 = None + add_275 = torch.ops.aten.add.Tensor(add_272, convert_element_type_2205); add_272 = convert_element_type_2205 = None + convert_element_type_default_23 = torch.ops.prims.convert_element_type.default(sum_128, torch.float32); sum_128 = None + reduce_scatter_tensor_190 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_23, 'avg', 256, '0'); convert_element_type_default_23 = None + wait_tensor_481 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + view_1599 = torch.ops.aten.view.default(add_275, [16384, 4096]) + permute_1029 = torch.ops.aten.permute.default(view_1599, [1, 0]) + permute_116 = torch.ops.aten.permute.default(getitem_90, [0, 2, 1, 3]) + view_361 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16); primals_98 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 256, '0'); convert_element_type_347 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_363 = torch.ops.aten.view.default(view_361, [16384, 4096]); view_361 = None + mm_73 = torch.ops.aten.mm.default(view_363, permute_117) + view_364 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + add_41 = torch.ops.aten.add.Tensor(add_39, view_364); view_364 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16); primals_99 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 256, '0'); convert_element_type_350 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32); add_41 = None + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_96) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + view_367 = torch.ops.aten.view.default(convert_element_type_352, [16384, 4096]); convert_element_type_352 = None + view_368 = torch.ops.aten.view.default(mm_74, [2, 8192, 14336]); mm_74 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_368, torch.float32); view_368 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 256, '0'); convert_element_type_358 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + mm_75 = torch.ops.aten.mm.default(view_367, permute_119) + view_371 = torch.ops.aten.view.default(mm_75, [2, 8192, 14336]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_371) + view_373 = torch.ops.aten.view.default(mul_87, [16384, 14336]); mul_87 = None + mm_521 = torch.ops.aten.mm.default(permute_1029, view_373); permute_1029 = view_373 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 256, '0'); convert_element_type_361 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + permute_1031 = torch.ops.aten.permute.default(permute_120, [1, 0]); permute_120 = None + mm_522 = torch.ops.aten.mm.default(view_1599, permute_1031); view_1599 = permute_1031 = None + view_1600 = torch.ops.aten.view.default(mm_522, [2, 8192, 14336]); mm_522 = None + convert_element_type_2212 = torch.ops.prims.convert_element_type.default(mm_521, torch.float32); mm_521 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2212, 'avg', 256, '0'); convert_element_type_2212 = None + wait_tensor_482 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + mul_684 = torch.ops.aten.mul.Tensor(view_1600, convert_element_type_357); convert_element_type_357 = None + mul_685 = torch.ops.aten.mul.Tensor(view_1600, view_371); view_1600 = view_371 = None + view_1601 = torch.ops.aten.view.default(mul_684, [16384, 14336]); mul_684 = None + permute_1033 = torch.ops.aten.permute.default(view_1601, [1, 0]) + mm_523 = torch.ops.aten.mm.default(permute_1033, view_367); permute_1033 = None + permute_1035 = torch.ops.aten.permute.default(permute_119, [1, 0]); permute_119 = None + mm_524 = torch.ops.aten.mm.default(view_1601, permute_1035); view_1601 = permute_1035 = None + view_1602 = torch.ops.aten.view.default(mm_524, [2, 8192, 4096]); mm_524 = None + convert_element_type_2217 = torch.ops.prims.convert_element_type.default(mm_523, torch.float32); mm_523 = None + reduce_scatter_tensor_192 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2217, 'avg', 256, '0'); convert_element_type_2217 = None + wait_tensor_483 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + convert_element_type_2218 = torch.ops.prims.convert_element_type.default(mul_685, torch.float32); mul_685 = None + neg_21 = torch.ops.aten.neg.default(convert_element_type_356) + exp_21 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_276 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + reciprocal_21 = torch.ops.aten.reciprocal.default(add_276); add_276 = None + mul_686 = torch.ops.aten.mul.Tensor(reciprocal_21, 1); reciprocal_21 = None + mul_687 = torch.ops.aten.mul.Tensor(convert_element_type_2218, mul_686); convert_element_type_2218 = None + sub_64 = torch.ops.aten.sub.Tensor(1, mul_686); mul_686 = None + mul_688 = torch.ops.aten.mul.Tensor(convert_element_type_356, sub_64); convert_element_type_356 = sub_64 = None + add_277 = torch.ops.aten.add.Tensor(mul_688, 1); mul_688 = None + mul_689 = torch.ops.aten.mul.Tensor(mul_687, add_277); mul_687 = add_277 = None + convert_element_type_2220 = torch.ops.prims.convert_element_type.default(mul_689, torch.bfloat16); mul_689 = None + view_1603 = torch.ops.aten.view.default(convert_element_type_2220, [16384, 14336]); convert_element_type_2220 = None + permute_1037 = torch.ops.aten.permute.default(view_1603, [1, 0]) + mm_525 = torch.ops.aten.mm.default(permute_1037, view_367); permute_1037 = view_367 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 256, '0'); convert_element_type_353 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + permute_1039 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_526 = torch.ops.aten.mm.default(view_1603, permute_1039); view_1603 = permute_1039 = None + view_1604 = torch.ops.aten.view.default(mm_526, [2, 8192, 4096]); mm_526 = None + add_278 = torch.ops.aten.add.Tensor(view_1602, view_1604); view_1602 = view_1604 = None + convert_element_type_2225 = torch.ops.prims.convert_element_type.default(mm_525, torch.float32); mm_525 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2225, 'avg', 256, '0'); convert_element_type_2225 = None + wait_tensor_484 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + convert_element_type_2226 = torch.ops.prims.convert_element_type.default(add_278, torch.float32); add_278 = None + convert_element_type_2228 = torch.ops.prims.convert_element_type.default(wait_tensor_96, torch.float32); wait_tensor_96 = None + mul_690 = torch.ops.aten.mul.Tensor(convert_element_type_2226, convert_element_type_2228); convert_element_type_2228 = None + mul_692 = torch.ops.aten.mul.Tensor(mul_84, mul_690) + sum_129 = torch.ops.aten.sum.dim_IntList(mul_692, [2], True); mul_692 = None + div_43 = torch.ops.aten.div.Tensor(mul_84, 4096) + mul_693 = torch.ops.aten.mul.Tensor(div_43, sum_129); div_43 = sum_129 = None + sub_65 = torch.ops.aten.sub.Tensor(mul_690, mul_693); mul_690 = mul_693 = None + mul_694 = torch.ops.aten.mul.Tensor(sub_65, rsqrt_21); sub_65 = rsqrt_21 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_2226, mul_84); convert_element_type_2226 = mul_84 = None + sum_130 = torch.ops.aten.sum.dim_IntList(mul_695, [0, 1]); mul_695 = None + convert_element_type_2229 = torch.ops.prims.convert_element_type.default(mul_694, torch.bfloat16); mul_694 = None + add_279 = torch.ops.aten.add.Tensor(add_275, convert_element_type_2229); add_275 = convert_element_type_2229 = None + convert_element_type_default_22 = torch.ops.prims.convert_element_type.default(sum_130, torch.float32); sum_130 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_22, 'avg', 256, '0'); convert_element_type_default_22 = None + wait_tensor_485 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + view_1605 = torch.ops.aten.view.default(add_279, [16384, 4096]) + permute_1041 = torch.ops.aten.permute.default(view_1605, [1, 0]) + mm_527 = torch.ops.aten.mm.default(permute_1041, view_363); permute_1041 = view_363 = None + permute_1043 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_528 = torch.ops.aten.mm.default(view_1605, permute_1043); view_1605 = permute_1043 = None + view_1606 = torch.ops.aten.view.default(mm_528, [2, 8192, 4096]); mm_528 = None + convert_element_type_2236 = torch.ops.prims.convert_element_type.default(mm_527, torch.float32); mm_527 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2236, 'avg', 256, '0'); convert_element_type_2236 = None + wait_tensor_486 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + view_1607 = torch.ops.aten.view.default(view_1606, [2, 8192, 32, 128]); view_1606 = None + permute_1045 = torch.ops.aten.permute.default(view_1607, [0, 2, 1, 3]); view_1607 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16); primals_94 = None + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 256, '0'); convert_element_type_331 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32); add_39 = None + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_91) + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + view_343 = torch.ops.aten.view.default(convert_element_type_333, [16384, 4096]); convert_element_type_333 = None + view_344 = torch.ops.aten.view.default(mm_70, [2, 8192, 4096]); mm_70 = None + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16); primals_96 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 256, '0'); convert_element_type_337 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_71 = torch.ops.aten.mm.default(view_343, permute_111) + view_347 = torch.ops.aten.view.default(mm_71, [2, 8192, 1024]); mm_71 = None + view_350 = torch.ops.aten.view.default(mm_72, [2, 8192, 1024]); mm_72 = None + view_351 = torch.ops.aten.view.default(view_344, [2, 8192, -1, 128]); view_344 = None + view_352 = torch.ops.aten.view.default(view_347, [2, 8192, -1, 128]); view_347 = None + view_353 = torch.ops.aten.view.default(view_350, [2, 8192, -1, 128]); view_350 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_351, torch.float32); view_351 = None + view_354 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 32, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_354); view_354 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_352, torch.float32); view_352 = None + view_355 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 8, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_355); view_355 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_16); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_357 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 32, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_16); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_358 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 8, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_357, torch.bfloat16); view_357 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_358, torch.bfloat16); view_358 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 8, 4, 128]); unsqueeze_20 = None + clone_20 = torch.ops.aten.clone.default(expand_20, memory_format = torch.contiguous_format); expand_20 = None + view_359 = torch.ops.aten.view.default(clone_20, [2, 8192, 32, 128]); clone_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_353, 3); view_353 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 8, 4, 128]); unsqueeze_21 = None + clone_21 = torch.ops.aten.clone.default(expand_21, memory_format = torch.contiguous_format); expand_21 = None + view_360 = torch.ops.aten.view.default(clone_21, [2, 8192, 32, 128]); clone_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_359, [0, 2, 1, 3]); view_359 = None + permute_115 = torch.ops.aten.permute.default(view_360, [0, 2, 1, 3]); view_360 = None + _scaled_dot_product_cudnn_attention_backward_21 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1045, permute_113, permute_114, permute_115, getitem_90, getitem_91, getitem_96, getitem_97, None, None, None, 8192, 8192, 0.0, True); permute_1045 = permute_113 = permute_114 = permute_115 = getitem_90 = getitem_91 = getitem_96 = getitem_97 = None + getitem_351 = _scaled_dot_product_cudnn_attention_backward_21[0] + getitem_352 = _scaled_dot_product_cudnn_attention_backward_21[1] + getitem_353 = _scaled_dot_product_cudnn_attention_backward_21[2]; _scaled_dot_product_cudnn_attention_backward_21 = None + permute_1046 = torch.ops.aten.permute.default(getitem_353, [0, 2, 1, 3]); getitem_353 = None + permute_1047 = torch.ops.aten.permute.default(getitem_352, [0, 2, 1, 3]); getitem_352 = None + permute_1048 = torch.ops.aten.permute.default(getitem_351, [0, 2, 1, 3]); getitem_351 = None + view_1608 = torch.ops.aten.view.default(permute_1046, [2, 8192, 8, 4, 128]); permute_1046 = None + sum_131 = torch.ops.aten.sum.dim_IntList(view_1608, [3], True); view_1608 = None + squeeze_42 = torch.ops.aten.squeeze.dim(sum_131, 3); sum_131 = None + view_1609 = torch.ops.aten.view.default(permute_1047, [2, 8192, 8, 4, 128]); permute_1047 = None + sum_132 = torch.ops.aten.sum.dim_IntList(view_1609, [3], True); view_1609 = None + squeeze_43 = torch.ops.aten.squeeze.dim(sum_132, 3); sum_132 = None + convert_element_type_2237 = torch.ops.prims.convert_element_type.default(squeeze_43, torch.float32); squeeze_43 = None + convert_element_type_2238 = torch.ops.prims.convert_element_type.default(permute_1048, torch.float32); permute_1048 = None + view_1610 = torch.ops.aten.view.default(convert_element_type_2237, [2, 8192, 8, 64, 2]); convert_element_type_2237 = None + view_as_complex_106 = torch.ops.aten.view_as_complex.default(view_1610); view_1610 = None + mul_696 = torch.ops.aten.mul.Tensor(view_as_complex_106, _conj); view_as_complex_106 = None + view_1611 = torch.ops.aten.view.default(convert_element_type_2238, [2, 8192, 32, 64, 2]); convert_element_type_2238 = None + view_as_complex_107 = torch.ops.aten.view_as_complex.default(view_1611); view_1611 = None + mul_697 = torch.ops.aten.mul.Tensor(view_as_complex_107, _conj); view_as_complex_107 = None + view_as_real_106 = torch.ops.aten.view_as_real.default(mul_696); mul_696 = None + view_1612 = torch.ops.aten.view.default(view_as_real_106, [2, 8192, 8, 128]); view_as_real_106 = None + convert_element_type_2239 = torch.ops.prims.convert_element_type.default(view_1612, torch.bfloat16); view_1612 = None + view_as_real_107 = torch.ops.aten.view_as_real.default(mul_697); mul_697 = None + view_1613 = torch.ops.aten.view.default(view_as_real_107, [2, 8192, 32, 128]); view_as_real_107 = None + convert_element_type_2240 = torch.ops.prims.convert_element_type.default(view_1613, torch.bfloat16); view_1613 = None + view_1614 = torch.ops.aten.view.default(squeeze_42, [2, 8192, 1024]); squeeze_42 = None + view_1615 = torch.ops.aten.view.default(convert_element_type_2239, [2, 8192, 1024]); convert_element_type_2239 = None + view_1616 = torch.ops.aten.view.default(convert_element_type_2240, [2, 8192, 4096]); convert_element_type_2240 = None + view_1617 = torch.ops.aten.view.default(view_1614, [16384, 1024]); view_1614 = None + permute_1049 = torch.ops.aten.permute.default(view_1617, [1, 0]) + mm_529 = torch.ops.aten.mm.default(permute_1049, view_343); permute_1049 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16); primals_97 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 256, '0'); convert_element_type_340 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + permute_1051 = torch.ops.aten.permute.default(permute_112, [1, 0]); permute_112 = None + mm_530 = torch.ops.aten.mm.default(view_1617, permute_1051); view_1617 = permute_1051 = None + view_1618 = torch.ops.aten.view.default(mm_530, [2, 8192, 4096]); mm_530 = None + convert_element_type_2245 = torch.ops.prims.convert_element_type.default(mm_529, torch.float32); mm_529 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2245, 'avg', 256, '0'); convert_element_type_2245 = None + wait_tensor_487 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + view_1619 = torch.ops.aten.view.default(view_1615, [16384, 1024]); view_1615 = None + permute_1053 = torch.ops.aten.permute.default(view_1619, [1, 0]) + mm_531 = torch.ops.aten.mm.default(permute_1053, view_343); permute_1053 = None + permute_1055 = torch.ops.aten.permute.default(permute_111, [1, 0]); permute_111 = None + mm_532 = torch.ops.aten.mm.default(view_1619, permute_1055); view_1619 = permute_1055 = None + view_1620 = torch.ops.aten.view.default(mm_532, [2, 8192, 4096]); mm_532 = None + add_280 = torch.ops.aten.add.Tensor(view_1618, view_1620); view_1618 = view_1620 = None + convert_element_type_2250 = torch.ops.prims.convert_element_type.default(mm_531, torch.float32); mm_531 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2250, 'avg', 256, '0'); convert_element_type_2250 = None + wait_tensor_488 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + view_1621 = torch.ops.aten.view.default(view_1616, [16384, 4096]); view_1616 = None + permute_1057 = torch.ops.aten.permute.default(view_1621, [1, 0]) + mm_533 = torch.ops.aten.mm.default(permute_1057, view_343); permute_1057 = view_343 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 256, '0'); convert_element_type_334 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + permute_1059 = torch.ops.aten.permute.default(permute_110, [1, 0]); permute_110 = None + mm_534 = torch.ops.aten.mm.default(view_1621, permute_1059); view_1621 = permute_1059 = None + view_1622 = torch.ops.aten.view.default(mm_534, [2, 8192, 4096]); mm_534 = None + add_281 = torch.ops.aten.add.Tensor(add_280, view_1622); add_280 = view_1622 = None + convert_element_type_2255 = torch.ops.prims.convert_element_type.default(mm_533, torch.float32); mm_533 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2255, 'avg', 256, '0'); convert_element_type_2255 = None + wait_tensor_489 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + convert_element_type_2256 = torch.ops.prims.convert_element_type.default(add_281, torch.float32); add_281 = None + convert_element_type_2258 = torch.ops.prims.convert_element_type.default(wait_tensor_91, torch.float32); wait_tensor_91 = None + mul_698 = torch.ops.aten.mul.Tensor(convert_element_type_2256, convert_element_type_2258); convert_element_type_2258 = None + mul_700 = torch.ops.aten.mul.Tensor(mul_80, mul_698) + sum_133 = torch.ops.aten.sum.dim_IntList(mul_700, [2], True); mul_700 = None + div_44 = torch.ops.aten.div.Tensor(mul_80, 4096) + mul_701 = torch.ops.aten.mul.Tensor(div_44, sum_133); div_44 = sum_133 = None + sub_66 = torch.ops.aten.sub.Tensor(mul_698, mul_701); mul_698 = mul_701 = None + mul_702 = torch.ops.aten.mul.Tensor(sub_66, rsqrt_20); sub_66 = rsqrt_20 = None + mul_703 = torch.ops.aten.mul.Tensor(convert_element_type_2256, mul_80); convert_element_type_2256 = mul_80 = None + sum_134 = torch.ops.aten.sum.dim_IntList(mul_703, [0, 1]); mul_703 = None + convert_element_type_2259 = torch.ops.prims.convert_element_type.default(mul_702, torch.bfloat16); mul_702 = None + add_282 = torch.ops.aten.add.Tensor(add_279, convert_element_type_2259); add_279 = convert_element_type_2259 = None + convert_element_type_default_21 = torch.ops.prims.convert_element_type.default(sum_134, torch.float32); sum_134 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_21, 'avg', 256, '0'); convert_element_type_default_21 = None + wait_tensor_490 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + view_1623 = torch.ops.aten.view.default(add_282, [16384, 4096]) + permute_1061 = torch.ops.aten.permute.default(view_1623, [1, 0]) + permute_105 = torch.ops.aten.permute.default(getitem_81, [0, 2, 1, 3]) + view_327 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 256, '0'); convert_element_type_314 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_86, [1, 0]); wait_tensor_86 = None + view_329 = torch.ops.aten.view.default(view_327, [16384, 4096]); view_327 = None + mm_66 = torch.ops.aten.mm.default(view_329, permute_106) + view_330 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + add_37 = torch.ops.aten.add.Tensor(add_35, view_330); view_330 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 256, '0'); convert_element_type_317 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32); add_37 = None + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_87) + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + view_333 = torch.ops.aten.view.default(convert_element_type_319, [16384, 4096]); convert_element_type_319 = None + view_334 = torch.ops.aten.view.default(mm_67, [2, 8192, 14336]); mm_67 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_334, torch.float32); view_334 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16); primals_92 = None + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 256, '0'); convert_element_type_325 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + mm_68 = torch.ops.aten.mm.default(view_333, permute_108) + view_337 = torch.ops.aten.view.default(mm_68, [2, 8192, 14336]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_337) + view_339 = torch.ops.aten.view.default(mul_79, [16384, 14336]); mul_79 = None + mm_535 = torch.ops.aten.mm.default(permute_1061, view_339); permute_1061 = view_339 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 256, '0'); convert_element_type_328 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + permute_1063 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_536 = torch.ops.aten.mm.default(view_1623, permute_1063); view_1623 = permute_1063 = None + view_1624 = torch.ops.aten.view.default(mm_536, [2, 8192, 14336]); mm_536 = None + convert_element_type_2266 = torch.ops.prims.convert_element_type.default(mm_535, torch.float32); mm_535 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2266, 'avg', 256, '0'); convert_element_type_2266 = None + wait_tensor_491 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + mul_704 = torch.ops.aten.mul.Tensor(view_1624, convert_element_type_324); convert_element_type_324 = None + mul_705 = torch.ops.aten.mul.Tensor(view_1624, view_337); view_1624 = view_337 = None + view_1625 = torch.ops.aten.view.default(mul_704, [16384, 14336]); mul_704 = None + permute_1065 = torch.ops.aten.permute.default(view_1625, [1, 0]) + mm_537 = torch.ops.aten.mm.default(permute_1065, view_333); permute_1065 = None + permute_1067 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_538 = torch.ops.aten.mm.default(view_1625, permute_1067); view_1625 = permute_1067 = None + view_1626 = torch.ops.aten.view.default(mm_538, [2, 8192, 4096]); mm_538 = None + convert_element_type_2271 = torch.ops.prims.convert_element_type.default(mm_537, torch.float32); mm_537 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2271, 'avg', 256, '0'); convert_element_type_2271 = None + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + convert_element_type_2272 = torch.ops.prims.convert_element_type.default(mul_705, torch.float32); mul_705 = None + neg_22 = torch.ops.aten.neg.default(convert_element_type_323) + exp_22 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_283 = torch.ops.aten.add.Tensor(exp_22, 1); exp_22 = None + reciprocal_22 = torch.ops.aten.reciprocal.default(add_283); add_283 = None + mul_706 = torch.ops.aten.mul.Tensor(reciprocal_22, 1); reciprocal_22 = None + mul_707 = torch.ops.aten.mul.Tensor(convert_element_type_2272, mul_706); convert_element_type_2272 = None + sub_67 = torch.ops.aten.sub.Tensor(1, mul_706); mul_706 = None + mul_708 = torch.ops.aten.mul.Tensor(convert_element_type_323, sub_67); convert_element_type_323 = sub_67 = None + add_284 = torch.ops.aten.add.Tensor(mul_708, 1); mul_708 = None + mul_709 = torch.ops.aten.mul.Tensor(mul_707, add_284); mul_707 = add_284 = None + convert_element_type_2274 = torch.ops.prims.convert_element_type.default(mul_709, torch.bfloat16); mul_709 = None + view_1627 = torch.ops.aten.view.default(convert_element_type_2274, [16384, 14336]); convert_element_type_2274 = None + permute_1069 = torch.ops.aten.permute.default(view_1627, [1, 0]) + mm_539 = torch.ops.aten.mm.default(permute_1069, view_333); permute_1069 = view_333 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 256, '0'); convert_element_type_320 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_88, [1, 0]); wait_tensor_88 = None + permute_1071 = torch.ops.aten.permute.default(permute_107, [1, 0]); permute_107 = None + mm_540 = torch.ops.aten.mm.default(view_1627, permute_1071); view_1627 = permute_1071 = None + view_1628 = torch.ops.aten.view.default(mm_540, [2, 8192, 4096]); mm_540 = None + add_285 = torch.ops.aten.add.Tensor(view_1626, view_1628); view_1626 = view_1628 = None + convert_element_type_2279 = torch.ops.prims.convert_element_type.default(mm_539, torch.float32); mm_539 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2279, 'avg', 256, '0'); convert_element_type_2279 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + convert_element_type_2280 = torch.ops.prims.convert_element_type.default(add_285, torch.float32); add_285 = None + convert_element_type_2282 = torch.ops.prims.convert_element_type.default(wait_tensor_87, torch.float32); wait_tensor_87 = None + mul_710 = torch.ops.aten.mul.Tensor(convert_element_type_2280, convert_element_type_2282); convert_element_type_2282 = None + mul_712 = torch.ops.aten.mul.Tensor(mul_76, mul_710) + sum_135 = torch.ops.aten.sum.dim_IntList(mul_712, [2], True); mul_712 = None + div_45 = torch.ops.aten.div.Tensor(mul_76, 4096) + mul_713 = torch.ops.aten.mul.Tensor(div_45, sum_135); div_45 = sum_135 = None + sub_68 = torch.ops.aten.sub.Tensor(mul_710, mul_713); mul_710 = mul_713 = None + mul_714 = torch.ops.aten.mul.Tensor(sub_68, rsqrt_19); sub_68 = rsqrt_19 = None + mul_715 = torch.ops.aten.mul.Tensor(convert_element_type_2280, mul_76); convert_element_type_2280 = mul_76 = None + sum_136 = torch.ops.aten.sum.dim_IntList(mul_715, [0, 1]); mul_715 = None + convert_element_type_2283 = torch.ops.prims.convert_element_type.default(mul_714, torch.bfloat16); mul_714 = None + add_286 = torch.ops.aten.add.Tensor(add_282, convert_element_type_2283); add_282 = convert_element_type_2283 = None + convert_element_type_default_20 = torch.ops.prims.convert_element_type.default(sum_136, torch.float32); sum_136 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_20, 'avg', 256, '0'); convert_element_type_default_20 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + view_1629 = torch.ops.aten.view.default(add_286, [16384, 4096]) + permute_1073 = torch.ops.aten.permute.default(view_1629, [1, 0]) + mm_541 = torch.ops.aten.mm.default(permute_1073, view_329); permute_1073 = view_329 = None + permute_1075 = torch.ops.aten.permute.default(permute_106, [1, 0]); permute_106 = None + mm_542 = torch.ops.aten.mm.default(view_1629, permute_1075); view_1629 = permute_1075 = None + view_1630 = torch.ops.aten.view.default(mm_542, [2, 8192, 4096]); mm_542 = None + convert_element_type_2290 = torch.ops.prims.convert_element_type.default(mm_541, torch.float32); mm_541 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2290, 'avg', 256, '0'); convert_element_type_2290 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + view_1631 = torch.ops.aten.view.default(view_1630, [2, 8192, 32, 128]); view_1630 = None + permute_1077 = torch.ops.aten.permute.default(view_1631, [0, 2, 1, 3]); view_1631 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 256, '0'); convert_element_type_298 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_82) + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + view_309 = torch.ops.aten.view.default(convert_element_type_300, [16384, 4096]); convert_element_type_300 = None + view_310 = torch.ops.aten.view.default(mm_63, [2, 8192, 4096]); mm_63 = None + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 256, '0'); convert_element_type_304 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_64 = torch.ops.aten.mm.default(view_309, permute_100) + view_313 = torch.ops.aten.view.default(mm_64, [2, 8192, 1024]); mm_64 = None + view_316 = torch.ops.aten.view.default(mm_65, [2, 8192, 1024]); mm_65 = None + view_317 = torch.ops.aten.view.default(view_310, [2, 8192, -1, 128]); view_310 = None + view_318 = torch.ops.aten.view.default(view_313, [2, 8192, -1, 128]); view_313 = None + view_319 = torch.ops.aten.view.default(view_316, [2, 8192, -1, 128]); view_316 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_317, torch.float32); view_317 = None + view_320 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 32, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_320); view_320 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_318, torch.float32); view_318 = None + view_321 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 8, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_321); view_321 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_16); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_323 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 32, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_16); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_324 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 8, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_323, torch.bfloat16); view_323 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_324, torch.bfloat16); view_324 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 8, 4, 128]); unsqueeze_18 = None + clone_18 = torch.ops.aten.clone.default(expand_18, memory_format = torch.contiguous_format); expand_18 = None + view_325 = torch.ops.aten.view.default(clone_18, [2, 8192, 32, 128]); clone_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_319, 3); view_319 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 8, 4, 128]); unsqueeze_19 = None + clone_19 = torch.ops.aten.clone.default(expand_19, memory_format = torch.contiguous_format); expand_19 = None + view_326 = torch.ops.aten.view.default(clone_19, [2, 8192, 32, 128]); clone_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_325, [0, 2, 1, 3]); view_325 = None + permute_104 = torch.ops.aten.permute.default(view_326, [0, 2, 1, 3]); view_326 = None + _scaled_dot_product_cudnn_attention_backward_22 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1077, permute_102, permute_103, permute_104, getitem_81, getitem_82, getitem_87, getitem_88, None, None, None, 8192, 8192, 0.0, True); permute_1077 = permute_102 = permute_103 = permute_104 = getitem_81 = getitem_82 = getitem_87 = getitem_88 = None + getitem_354 = _scaled_dot_product_cudnn_attention_backward_22[0] + getitem_355 = _scaled_dot_product_cudnn_attention_backward_22[1] + getitem_356 = _scaled_dot_product_cudnn_attention_backward_22[2]; _scaled_dot_product_cudnn_attention_backward_22 = None + permute_1078 = torch.ops.aten.permute.default(getitem_356, [0, 2, 1, 3]); getitem_356 = None + permute_1079 = torch.ops.aten.permute.default(getitem_355, [0, 2, 1, 3]); getitem_355 = None + permute_1080 = torch.ops.aten.permute.default(getitem_354, [0, 2, 1, 3]); getitem_354 = None + view_1632 = torch.ops.aten.view.default(permute_1078, [2, 8192, 8, 4, 128]); permute_1078 = None + sum_137 = torch.ops.aten.sum.dim_IntList(view_1632, [3], True); view_1632 = None + squeeze_44 = torch.ops.aten.squeeze.dim(sum_137, 3); sum_137 = None + view_1633 = torch.ops.aten.view.default(permute_1079, [2, 8192, 8, 4, 128]); permute_1079 = None + sum_138 = torch.ops.aten.sum.dim_IntList(view_1633, [3], True); view_1633 = None + squeeze_45 = torch.ops.aten.squeeze.dim(sum_138, 3); sum_138 = None + convert_element_type_2291 = torch.ops.prims.convert_element_type.default(squeeze_45, torch.float32); squeeze_45 = None + convert_element_type_2292 = torch.ops.prims.convert_element_type.default(permute_1080, torch.float32); permute_1080 = None + view_1634 = torch.ops.aten.view.default(convert_element_type_2291, [2, 8192, 8, 64, 2]); convert_element_type_2291 = None + view_as_complex_108 = torch.ops.aten.view_as_complex.default(view_1634); view_1634 = None + mul_716 = torch.ops.aten.mul.Tensor(view_as_complex_108, _conj); view_as_complex_108 = None + view_1635 = torch.ops.aten.view.default(convert_element_type_2292, [2, 8192, 32, 64, 2]); convert_element_type_2292 = None + view_as_complex_109 = torch.ops.aten.view_as_complex.default(view_1635); view_1635 = None + mul_717 = torch.ops.aten.mul.Tensor(view_as_complex_109, _conj); view_as_complex_109 = None + view_as_real_108 = torch.ops.aten.view_as_real.default(mul_716); mul_716 = None + view_1636 = torch.ops.aten.view.default(view_as_real_108, [2, 8192, 8, 128]); view_as_real_108 = None + convert_element_type_2293 = torch.ops.prims.convert_element_type.default(view_1636, torch.bfloat16); view_1636 = None + view_as_real_109 = torch.ops.aten.view_as_real.default(mul_717); mul_717 = None + view_1637 = torch.ops.aten.view.default(view_as_real_109, [2, 8192, 32, 128]); view_as_real_109 = None + convert_element_type_2294 = torch.ops.prims.convert_element_type.default(view_1637, torch.bfloat16); view_1637 = None + view_1638 = torch.ops.aten.view.default(squeeze_44, [2, 8192, 1024]); squeeze_44 = None + view_1639 = torch.ops.aten.view.default(convert_element_type_2293, [2, 8192, 1024]); convert_element_type_2293 = None + view_1640 = torch.ops.aten.view.default(convert_element_type_2294, [2, 8192, 4096]); convert_element_type_2294 = None + view_1641 = torch.ops.aten.view.default(view_1638, [16384, 1024]); view_1638 = None + permute_1081 = torch.ops.aten.permute.default(view_1641, [1, 0]) + mm_543 = torch.ops.aten.mm.default(permute_1081, view_309); permute_1081 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 256, '0'); convert_element_type_307 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + permute_1083 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_544 = torch.ops.aten.mm.default(view_1641, permute_1083); view_1641 = permute_1083 = None + view_1642 = torch.ops.aten.view.default(mm_544, [2, 8192, 4096]); mm_544 = None + convert_element_type_2299 = torch.ops.prims.convert_element_type.default(mm_543, torch.float32); mm_543 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2299, 'avg', 256, '0'); convert_element_type_2299 = None + wait_tensor_496 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + view_1643 = torch.ops.aten.view.default(view_1639, [16384, 1024]); view_1639 = None + permute_1085 = torch.ops.aten.permute.default(view_1643, [1, 0]) + mm_545 = torch.ops.aten.mm.default(permute_1085, view_309); permute_1085 = None + permute_1087 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_546 = torch.ops.aten.mm.default(view_1643, permute_1087); view_1643 = permute_1087 = None + view_1644 = torch.ops.aten.view.default(mm_546, [2, 8192, 4096]); mm_546 = None + add_287 = torch.ops.aten.add.Tensor(view_1642, view_1644); view_1642 = view_1644 = None + convert_element_type_2304 = torch.ops.prims.convert_element_type.default(mm_545, torch.float32); mm_545 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2304, 'avg', 256, '0'); convert_element_type_2304 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_1645 = torch.ops.aten.view.default(view_1640, [16384, 4096]); view_1640 = None + permute_1089 = torch.ops.aten.permute.default(view_1645, [1, 0]) + mm_547 = torch.ops.aten.mm.default(permute_1089, view_309); permute_1089 = view_309 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 256, '0'); convert_element_type_301 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + permute_1091 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_548 = torch.ops.aten.mm.default(view_1645, permute_1091); view_1645 = permute_1091 = None + view_1646 = torch.ops.aten.view.default(mm_548, [2, 8192, 4096]); mm_548 = None + add_288 = torch.ops.aten.add.Tensor(add_287, view_1646); add_287 = view_1646 = None + convert_element_type_2309 = torch.ops.prims.convert_element_type.default(mm_547, torch.float32); mm_547 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2309, 'avg', 256, '0'); convert_element_type_2309 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + convert_element_type_2310 = torch.ops.prims.convert_element_type.default(add_288, torch.float32); add_288 = None + convert_element_type_2312 = torch.ops.prims.convert_element_type.default(wait_tensor_82, torch.float32); wait_tensor_82 = None + mul_718 = torch.ops.aten.mul.Tensor(convert_element_type_2310, convert_element_type_2312); convert_element_type_2312 = None + mul_720 = torch.ops.aten.mul.Tensor(mul_72, mul_718) + sum_139 = torch.ops.aten.sum.dim_IntList(mul_720, [2], True); mul_720 = None + div_46 = torch.ops.aten.div.Tensor(mul_72, 4096) + mul_721 = torch.ops.aten.mul.Tensor(div_46, sum_139); div_46 = sum_139 = None + sub_69 = torch.ops.aten.sub.Tensor(mul_718, mul_721); mul_718 = mul_721 = None + mul_722 = torch.ops.aten.mul.Tensor(sub_69, rsqrt_18); sub_69 = rsqrt_18 = None + mul_723 = torch.ops.aten.mul.Tensor(convert_element_type_2310, mul_72); convert_element_type_2310 = mul_72 = None + sum_140 = torch.ops.aten.sum.dim_IntList(mul_723, [0, 1]); mul_723 = None + convert_element_type_2313 = torch.ops.prims.convert_element_type.default(mul_722, torch.bfloat16); mul_722 = None + add_289 = torch.ops.aten.add.Tensor(add_286, convert_element_type_2313); add_286 = convert_element_type_2313 = None + convert_element_type_default_19 = torch.ops.prims.convert_element_type.default(sum_140, torch.float32); sum_140 = None + reduce_scatter_tensor_208 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_19, 'avg', 256, '0'); convert_element_type_default_19 = None + wait_tensor_499 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + view_1647 = torch.ops.aten.view.default(add_289, [16384, 4096]) + permute_1093 = torch.ops.aten.permute.default(view_1647, [1, 0]) + permute_94 = torch.ops.aten.permute.default(getitem_72, [0, 2, 1, 3]) + view_293 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16); primals_80 = None + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 256, '0'); convert_element_type_281 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + view_295 = torch.ops.aten.view.default(view_293, [16384, 4096]); view_293 = None + mm_59 = torch.ops.aten.mm.default(view_295, permute_95) + view_296 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + add_33 = torch.ops.aten.add.Tensor(add_31, view_296); view_296 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16); primals_81 = None + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 256, '0'); convert_element_type_284 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_78) + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + view_299 = torch.ops.aten.view.default(convert_element_type_286, [16384, 4096]); convert_element_type_286 = None + view_300 = torch.ops.aten.view.default(mm_60, [2, 8192, 14336]); mm_60 = None + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 256, '0'); convert_element_type_292 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_80, [1, 0]); wait_tensor_80 = None + mm_61 = torch.ops.aten.mm.default(view_299, permute_97) + view_303 = torch.ops.aten.view.default(mm_61, [2, 8192, 14336]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_303) + view_305 = torch.ops.aten.view.default(mul_71, [16384, 14336]); mul_71 = None + mm_549 = torch.ops.aten.mm.default(permute_1093, view_305); permute_1093 = view_305 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 256, '0'); convert_element_type_295 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + permute_1095 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None + mm_550 = torch.ops.aten.mm.default(view_1647, permute_1095); view_1647 = permute_1095 = None + view_1648 = torch.ops.aten.view.default(mm_550, [2, 8192, 14336]); mm_550 = None + convert_element_type_2320 = torch.ops.prims.convert_element_type.default(mm_549, torch.float32); mm_549 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2320, 'avg', 256, '0'); convert_element_type_2320 = None + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + mul_724 = torch.ops.aten.mul.Tensor(view_1648, convert_element_type_291); convert_element_type_291 = None + mul_725 = torch.ops.aten.mul.Tensor(view_1648, view_303); view_1648 = view_303 = None + view_1649 = torch.ops.aten.view.default(mul_724, [16384, 14336]); mul_724 = None + permute_1097 = torch.ops.aten.permute.default(view_1649, [1, 0]) + mm_551 = torch.ops.aten.mm.default(permute_1097, view_299); permute_1097 = None + permute_1099 = torch.ops.aten.permute.default(permute_97, [1, 0]); permute_97 = None + mm_552 = torch.ops.aten.mm.default(view_1649, permute_1099); view_1649 = permute_1099 = None + view_1650 = torch.ops.aten.view.default(mm_552, [2, 8192, 4096]); mm_552 = None + convert_element_type_2325 = torch.ops.prims.convert_element_type.default(mm_551, torch.float32); mm_551 = None + reduce_scatter_tensor_210 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2325, 'avg', 256, '0'); convert_element_type_2325 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); reduce_scatter_tensor_210 = None + convert_element_type_2326 = torch.ops.prims.convert_element_type.default(mul_725, torch.float32); mul_725 = None + neg_23 = torch.ops.aten.neg.default(convert_element_type_290) + exp_23 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_290 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + reciprocal_23 = torch.ops.aten.reciprocal.default(add_290); add_290 = None + mul_726 = torch.ops.aten.mul.Tensor(reciprocal_23, 1); reciprocal_23 = None + mul_727 = torch.ops.aten.mul.Tensor(convert_element_type_2326, mul_726); convert_element_type_2326 = None + sub_70 = torch.ops.aten.sub.Tensor(1, mul_726); mul_726 = None + mul_728 = torch.ops.aten.mul.Tensor(convert_element_type_290, sub_70); convert_element_type_290 = sub_70 = None + add_291 = torch.ops.aten.add.Tensor(mul_728, 1); mul_728 = None + mul_729 = torch.ops.aten.mul.Tensor(mul_727, add_291); mul_727 = add_291 = None + convert_element_type_2328 = torch.ops.prims.convert_element_type.default(mul_729, torch.bfloat16); mul_729 = None + view_1651 = torch.ops.aten.view.default(convert_element_type_2328, [16384, 14336]); convert_element_type_2328 = None + permute_1101 = torch.ops.aten.permute.default(view_1651, [1, 0]) + mm_553 = torch.ops.aten.mm.default(permute_1101, view_299); permute_1101 = view_299 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 256, '0'); convert_element_type_287 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + permute_1103 = torch.ops.aten.permute.default(permute_96, [1, 0]); permute_96 = None + mm_554 = torch.ops.aten.mm.default(view_1651, permute_1103); view_1651 = permute_1103 = None + view_1652 = torch.ops.aten.view.default(mm_554, [2, 8192, 4096]); mm_554 = None + add_292 = torch.ops.aten.add.Tensor(view_1650, view_1652); view_1650 = view_1652 = None + convert_element_type_2333 = torch.ops.prims.convert_element_type.default(mm_553, torch.float32); mm_553 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2333, 'avg', 256, '0'); convert_element_type_2333 = None + wait_tensor_502 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + convert_element_type_2334 = torch.ops.prims.convert_element_type.default(add_292, torch.float32); add_292 = None + convert_element_type_2336 = torch.ops.prims.convert_element_type.default(wait_tensor_78, torch.float32); wait_tensor_78 = None + mul_730 = torch.ops.aten.mul.Tensor(convert_element_type_2334, convert_element_type_2336); convert_element_type_2336 = None + mul_732 = torch.ops.aten.mul.Tensor(mul_68, mul_730) + sum_141 = torch.ops.aten.sum.dim_IntList(mul_732, [2], True); mul_732 = None + div_47 = torch.ops.aten.div.Tensor(mul_68, 4096) + mul_733 = torch.ops.aten.mul.Tensor(div_47, sum_141); div_47 = sum_141 = None + sub_71 = torch.ops.aten.sub.Tensor(mul_730, mul_733); mul_730 = mul_733 = None + mul_734 = torch.ops.aten.mul.Tensor(sub_71, rsqrt_17); sub_71 = rsqrt_17 = None + mul_735 = torch.ops.aten.mul.Tensor(convert_element_type_2334, mul_68); convert_element_type_2334 = mul_68 = None + sum_142 = torch.ops.aten.sum.dim_IntList(mul_735, [0, 1]); mul_735 = None + convert_element_type_2337 = torch.ops.prims.convert_element_type.default(mul_734, torch.bfloat16); mul_734 = None + add_293 = torch.ops.aten.add.Tensor(add_289, convert_element_type_2337); add_289 = convert_element_type_2337 = None + convert_element_type_default_18 = torch.ops.prims.convert_element_type.default(sum_142, torch.float32); sum_142 = None + reduce_scatter_tensor_212 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_18, 'avg', 256, '0'); convert_element_type_default_18 = None + wait_tensor_503 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + view_1653 = torch.ops.aten.view.default(add_293, [16384, 4096]) + permute_1105 = torch.ops.aten.permute.default(view_1653, [1, 0]) + mm_555 = torch.ops.aten.mm.default(permute_1105, view_295); permute_1105 = view_295 = None + permute_1107 = torch.ops.aten.permute.default(permute_95, [1, 0]); permute_95 = None + mm_556 = torch.ops.aten.mm.default(view_1653, permute_1107); view_1653 = permute_1107 = None + view_1654 = torch.ops.aten.view.default(mm_556, [2, 8192, 4096]); mm_556 = None + convert_element_type_2344 = torch.ops.prims.convert_element_type.default(mm_555, torch.float32); mm_555 = None + reduce_scatter_tensor_213 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2344, 'avg', 256, '0'); convert_element_type_2344 = None + wait_tensor_504 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_213); reduce_scatter_tensor_213 = None + view_1655 = torch.ops.aten.view.default(view_1654, [2, 8192, 32, 128]); view_1654 = None + permute_1109 = torch.ops.aten.permute.default(view_1655, [0, 2, 1, 3]); view_1655 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16); primals_76 = None + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 256, '0'); convert_element_type_265 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32); add_31 = None + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_73) + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + view_275 = torch.ops.aten.view.default(convert_element_type_267, [16384, 4096]); convert_element_type_267 = None + view_276 = torch.ops.aten.view.default(mm_56, [2, 8192, 4096]); mm_56 = None + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16); primals_78 = None + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 256, '0'); convert_element_type_271 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + mm_57 = torch.ops.aten.mm.default(view_275, permute_89) + view_279 = torch.ops.aten.view.default(mm_57, [2, 8192, 1024]); mm_57 = None + view_282 = torch.ops.aten.view.default(mm_58, [2, 8192, 1024]); mm_58 = None + view_283 = torch.ops.aten.view.default(view_276, [2, 8192, -1, 128]); view_276 = None + view_284 = torch.ops.aten.view.default(view_279, [2, 8192, -1, 128]); view_279 = None + view_285 = torch.ops.aten.view.default(view_282, [2, 8192, -1, 128]); view_282 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_283, torch.float32); view_283 = None + view_286 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 32, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_286); view_286 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_284, torch.float32); view_284 = None + view_287 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 8, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_287); view_287 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_16); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_289 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 32, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_16); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_290 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 8, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_289, torch.bfloat16); view_289 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_290, torch.bfloat16); view_290 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 8, 4, 128]); unsqueeze_16 = None + clone_16 = torch.ops.aten.clone.default(expand_16, memory_format = torch.contiguous_format); expand_16 = None + view_291 = torch.ops.aten.view.default(clone_16, [2, 8192, 32, 128]); clone_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_285, 3); view_285 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 8, 4, 128]); unsqueeze_17 = None + clone_17 = torch.ops.aten.clone.default(expand_17, memory_format = torch.contiguous_format); expand_17 = None + view_292 = torch.ops.aten.view.default(clone_17, [2, 8192, 32, 128]); clone_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_291, [0, 2, 1, 3]); view_291 = None + permute_93 = torch.ops.aten.permute.default(view_292, [0, 2, 1, 3]); view_292 = None + _scaled_dot_product_cudnn_attention_backward_23 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1109, permute_91, permute_92, permute_93, getitem_72, getitem_73, getitem_78, getitem_79, None, None, None, 8192, 8192, 0.0, True); permute_1109 = permute_91 = permute_92 = permute_93 = getitem_72 = getitem_73 = getitem_78 = getitem_79 = None + getitem_357 = _scaled_dot_product_cudnn_attention_backward_23[0] + getitem_358 = _scaled_dot_product_cudnn_attention_backward_23[1] + getitem_359 = _scaled_dot_product_cudnn_attention_backward_23[2]; _scaled_dot_product_cudnn_attention_backward_23 = None + permute_1110 = torch.ops.aten.permute.default(getitem_359, [0, 2, 1, 3]); getitem_359 = None + permute_1111 = torch.ops.aten.permute.default(getitem_358, [0, 2, 1, 3]); getitem_358 = None + permute_1112 = torch.ops.aten.permute.default(getitem_357, [0, 2, 1, 3]); getitem_357 = None + view_1656 = torch.ops.aten.view.default(permute_1110, [2, 8192, 8, 4, 128]); permute_1110 = None + sum_143 = torch.ops.aten.sum.dim_IntList(view_1656, [3], True); view_1656 = None + squeeze_46 = torch.ops.aten.squeeze.dim(sum_143, 3); sum_143 = None + view_1657 = torch.ops.aten.view.default(permute_1111, [2, 8192, 8, 4, 128]); permute_1111 = None + sum_144 = torch.ops.aten.sum.dim_IntList(view_1657, [3], True); view_1657 = None + squeeze_47 = torch.ops.aten.squeeze.dim(sum_144, 3); sum_144 = None + convert_element_type_2345 = torch.ops.prims.convert_element_type.default(squeeze_47, torch.float32); squeeze_47 = None + convert_element_type_2346 = torch.ops.prims.convert_element_type.default(permute_1112, torch.float32); permute_1112 = None + view_1658 = torch.ops.aten.view.default(convert_element_type_2345, [2, 8192, 8, 64, 2]); convert_element_type_2345 = None + view_as_complex_110 = torch.ops.aten.view_as_complex.default(view_1658); view_1658 = None + mul_736 = torch.ops.aten.mul.Tensor(view_as_complex_110, _conj); view_as_complex_110 = None + view_1659 = torch.ops.aten.view.default(convert_element_type_2346, [2, 8192, 32, 64, 2]); convert_element_type_2346 = None + view_as_complex_111 = torch.ops.aten.view_as_complex.default(view_1659); view_1659 = None + mul_737 = torch.ops.aten.mul.Tensor(view_as_complex_111, _conj); view_as_complex_111 = None + view_as_real_110 = torch.ops.aten.view_as_real.default(mul_736); mul_736 = None + view_1660 = torch.ops.aten.view.default(view_as_real_110, [2, 8192, 8, 128]); view_as_real_110 = None + convert_element_type_2347 = torch.ops.prims.convert_element_type.default(view_1660, torch.bfloat16); view_1660 = None + view_as_real_111 = torch.ops.aten.view_as_real.default(mul_737); mul_737 = None + view_1661 = torch.ops.aten.view.default(view_as_real_111, [2, 8192, 32, 128]); view_as_real_111 = None + convert_element_type_2348 = torch.ops.prims.convert_element_type.default(view_1661, torch.bfloat16); view_1661 = None + view_1662 = torch.ops.aten.view.default(squeeze_46, [2, 8192, 1024]); squeeze_46 = None + view_1663 = torch.ops.aten.view.default(convert_element_type_2347, [2, 8192, 1024]); convert_element_type_2347 = None + view_1664 = torch.ops.aten.view.default(convert_element_type_2348, [2, 8192, 4096]); convert_element_type_2348 = None + view_1665 = torch.ops.aten.view.default(view_1662, [16384, 1024]); view_1662 = None + permute_1113 = torch.ops.aten.permute.default(view_1665, [1, 0]) + mm_557 = torch.ops.aten.mm.default(permute_1113, view_275); permute_1113 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); primals_79 = None + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 256, '0'); convert_element_type_274 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + permute_1115 = torch.ops.aten.permute.default(permute_90, [1, 0]); permute_90 = None + mm_558 = torch.ops.aten.mm.default(view_1665, permute_1115); view_1665 = permute_1115 = None + view_1666 = torch.ops.aten.view.default(mm_558, [2, 8192, 4096]); mm_558 = None + convert_element_type_2353 = torch.ops.prims.convert_element_type.default(mm_557, torch.float32); mm_557 = None + reduce_scatter_tensor_214 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2353, 'avg', 256, '0'); convert_element_type_2353 = None + wait_tensor_505 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_214); reduce_scatter_tensor_214 = None + view_1667 = torch.ops.aten.view.default(view_1663, [16384, 1024]); view_1663 = None + permute_1117 = torch.ops.aten.permute.default(view_1667, [1, 0]) + mm_559 = torch.ops.aten.mm.default(permute_1117, view_275); permute_1117 = None + permute_1119 = torch.ops.aten.permute.default(permute_89, [1, 0]); permute_89 = None + mm_560 = torch.ops.aten.mm.default(view_1667, permute_1119); view_1667 = permute_1119 = None + view_1668 = torch.ops.aten.view.default(mm_560, [2, 8192, 4096]); mm_560 = None + add_294 = torch.ops.aten.add.Tensor(view_1666, view_1668); view_1666 = view_1668 = None + convert_element_type_2358 = torch.ops.prims.convert_element_type.default(mm_559, torch.float32); mm_559 = None + reduce_scatter_tensor_215 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2358, 'avg', 256, '0'); convert_element_type_2358 = None + wait_tensor_506 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_215); reduce_scatter_tensor_215 = None + view_1669 = torch.ops.aten.view.default(view_1664, [16384, 4096]); view_1664 = None + permute_1121 = torch.ops.aten.permute.default(view_1669, [1, 0]) + mm_561 = torch.ops.aten.mm.default(permute_1121, view_275); permute_1121 = view_275 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 256, '0'); convert_element_type_268 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_74, [1, 0]); wait_tensor_74 = None + permute_1123 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_562 = torch.ops.aten.mm.default(view_1669, permute_1123); view_1669 = permute_1123 = None + view_1670 = torch.ops.aten.view.default(mm_562, [2, 8192, 4096]); mm_562 = None + add_295 = torch.ops.aten.add.Tensor(add_294, view_1670); add_294 = view_1670 = None + convert_element_type_2363 = torch.ops.prims.convert_element_type.default(mm_561, torch.float32); mm_561 = None + reduce_scatter_tensor_216 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2363, 'avg', 256, '0'); convert_element_type_2363 = None + wait_tensor_507 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_216); reduce_scatter_tensor_216 = None + convert_element_type_2364 = torch.ops.prims.convert_element_type.default(add_295, torch.float32); add_295 = None + convert_element_type_2366 = torch.ops.prims.convert_element_type.default(wait_tensor_73, torch.float32); wait_tensor_73 = None + mul_738 = torch.ops.aten.mul.Tensor(convert_element_type_2364, convert_element_type_2366); convert_element_type_2366 = None + mul_740 = torch.ops.aten.mul.Tensor(mul_64, mul_738) + sum_145 = torch.ops.aten.sum.dim_IntList(mul_740, [2], True); mul_740 = None + div_48 = torch.ops.aten.div.Tensor(mul_64, 4096) + mul_741 = torch.ops.aten.mul.Tensor(div_48, sum_145); div_48 = sum_145 = None + sub_72 = torch.ops.aten.sub.Tensor(mul_738, mul_741); mul_738 = mul_741 = None + mul_742 = torch.ops.aten.mul.Tensor(sub_72, rsqrt_16); sub_72 = rsqrt_16 = None + mul_743 = torch.ops.aten.mul.Tensor(convert_element_type_2364, mul_64); convert_element_type_2364 = mul_64 = None + sum_146 = torch.ops.aten.sum.dim_IntList(mul_743, [0, 1]); mul_743 = None + convert_element_type_2367 = torch.ops.prims.convert_element_type.default(mul_742, torch.bfloat16); mul_742 = None + add_296 = torch.ops.aten.add.Tensor(add_293, convert_element_type_2367); add_293 = convert_element_type_2367 = None + convert_element_type_default_17 = torch.ops.prims.convert_element_type.default(sum_146, torch.float32); sum_146 = None + reduce_scatter_tensor_217 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_17, 'avg', 256, '0'); convert_element_type_default_17 = None + wait_tensor_508 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_217); reduce_scatter_tensor_217 = None + view_1671 = torch.ops.aten.view.default(add_296, [16384, 4096]) + permute_1125 = torch.ops.aten.permute.default(view_1671, [1, 0]) + permute_83 = torch.ops.aten.permute.default(getitem_63, [0, 2, 1, 3]) + view_259 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 256, '0'); convert_element_type_248 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_68, [1, 0]); wait_tensor_68 = None + view_261 = torch.ops.aten.view.default(view_259, [16384, 4096]); view_259 = None + mm_52 = torch.ops.aten.mm.default(view_261, permute_84) + view_262 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + add_29 = torch.ops.aten.add.Tensor(add_27, view_262); view_262 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 256, '0'); convert_element_type_251 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32); add_29 = None + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_69) + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + view_265 = torch.ops.aten.view.default(convert_element_type_253, [16384, 4096]); convert_element_type_253 = None + view_266 = torch.ops.aten.view.default(mm_53, [2, 8192, 14336]); mm_53 = None + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_266, torch.float32); view_266 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 256, '0'); convert_element_type_259 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_54 = torch.ops.aten.mm.default(view_265, permute_86) + view_269 = torch.ops.aten.view.default(mm_54, [2, 8192, 14336]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_269) + view_271 = torch.ops.aten.view.default(mul_63, [16384, 14336]); mul_63 = None + mm_563 = torch.ops.aten.mm.default(permute_1125, view_271); permute_1125 = view_271 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 256, '0'); convert_element_type_262 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_1127 = torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_564 = torch.ops.aten.mm.default(view_1671, permute_1127); view_1671 = permute_1127 = None + view_1672 = torch.ops.aten.view.default(mm_564, [2, 8192, 14336]); mm_564 = None + convert_element_type_2374 = torch.ops.prims.convert_element_type.default(mm_563, torch.float32); mm_563 = None + reduce_scatter_tensor_218 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2374, 'avg', 256, '0'); convert_element_type_2374 = None + wait_tensor_509 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_218); reduce_scatter_tensor_218 = None + mul_744 = torch.ops.aten.mul.Tensor(view_1672, convert_element_type_258); convert_element_type_258 = None + mul_745 = torch.ops.aten.mul.Tensor(view_1672, view_269); view_1672 = view_269 = None + view_1673 = torch.ops.aten.view.default(mul_744, [16384, 14336]); mul_744 = None + permute_1129 = torch.ops.aten.permute.default(view_1673, [1, 0]) + mm_565 = torch.ops.aten.mm.default(permute_1129, view_265); permute_1129 = None + permute_1131 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_566 = torch.ops.aten.mm.default(view_1673, permute_1131); view_1673 = permute_1131 = None + view_1674 = torch.ops.aten.view.default(mm_566, [2, 8192, 4096]); mm_566 = None + convert_element_type_2379 = torch.ops.prims.convert_element_type.default(mm_565, torch.float32); mm_565 = None + reduce_scatter_tensor_219 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2379, 'avg', 256, '0'); convert_element_type_2379 = None + wait_tensor_510 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_219); reduce_scatter_tensor_219 = None + convert_element_type_2380 = torch.ops.prims.convert_element_type.default(mul_745, torch.float32); mul_745 = None + neg_24 = torch.ops.aten.neg.default(convert_element_type_257) + exp_24 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_297 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + reciprocal_24 = torch.ops.aten.reciprocal.default(add_297); add_297 = None + mul_746 = torch.ops.aten.mul.Tensor(reciprocal_24, 1); reciprocal_24 = None + mul_747 = torch.ops.aten.mul.Tensor(convert_element_type_2380, mul_746); convert_element_type_2380 = None + sub_73 = torch.ops.aten.sub.Tensor(1, mul_746); mul_746 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_257, sub_73); convert_element_type_257 = sub_73 = None + add_298 = torch.ops.aten.add.Tensor(mul_748, 1); mul_748 = None + mul_749 = torch.ops.aten.mul.Tensor(mul_747, add_298); mul_747 = add_298 = None + convert_element_type_2382 = torch.ops.prims.convert_element_type.default(mul_749, torch.bfloat16); mul_749 = None + view_1675 = torch.ops.aten.view.default(convert_element_type_2382, [16384, 14336]); convert_element_type_2382 = None + permute_1133 = torch.ops.aten.permute.default(view_1675, [1, 0]) + mm_567 = torch.ops.aten.mm.default(permute_1133, view_265); permute_1133 = view_265 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 256, '0'); convert_element_type_254 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + permute_1135 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_568 = torch.ops.aten.mm.default(view_1675, permute_1135); view_1675 = permute_1135 = None + view_1676 = torch.ops.aten.view.default(mm_568, [2, 8192, 4096]); mm_568 = None + add_299 = torch.ops.aten.add.Tensor(view_1674, view_1676); view_1674 = view_1676 = None + convert_element_type_2387 = torch.ops.prims.convert_element_type.default(mm_567, torch.float32); mm_567 = None + reduce_scatter_tensor_220 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2387, 'avg', 256, '0'); convert_element_type_2387 = None + wait_tensor_511 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_220); reduce_scatter_tensor_220 = None + convert_element_type_2388 = torch.ops.prims.convert_element_type.default(add_299, torch.float32); add_299 = None + convert_element_type_2390 = torch.ops.prims.convert_element_type.default(wait_tensor_69, torch.float32); wait_tensor_69 = None + mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_2388, convert_element_type_2390); convert_element_type_2390 = None + mul_752 = torch.ops.aten.mul.Tensor(mul_60, mul_750) + sum_147 = torch.ops.aten.sum.dim_IntList(mul_752, [2], True); mul_752 = None + div_49 = torch.ops.aten.div.Tensor(mul_60, 4096) + mul_753 = torch.ops.aten.mul.Tensor(div_49, sum_147); div_49 = sum_147 = None + sub_74 = torch.ops.aten.sub.Tensor(mul_750, mul_753); mul_750 = mul_753 = None + mul_754 = torch.ops.aten.mul.Tensor(sub_74, rsqrt_15); sub_74 = rsqrt_15 = None + mul_755 = torch.ops.aten.mul.Tensor(convert_element_type_2388, mul_60); convert_element_type_2388 = mul_60 = None + sum_148 = torch.ops.aten.sum.dim_IntList(mul_755, [0, 1]); mul_755 = None + convert_element_type_2391 = torch.ops.prims.convert_element_type.default(mul_754, torch.bfloat16); mul_754 = None + add_300 = torch.ops.aten.add.Tensor(add_296, convert_element_type_2391); add_296 = convert_element_type_2391 = None + convert_element_type_default_16 = torch.ops.prims.convert_element_type.default(sum_148, torch.float32); sum_148 = None + reduce_scatter_tensor_221 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_16, 'avg', 256, '0'); convert_element_type_default_16 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_221); reduce_scatter_tensor_221 = None + view_1677 = torch.ops.aten.view.default(add_300, [16384, 4096]) + permute_1137 = torch.ops.aten.permute.default(view_1677, [1, 0]) + mm_569 = torch.ops.aten.mm.default(permute_1137, view_261); permute_1137 = view_261 = None + permute_1139 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_570 = torch.ops.aten.mm.default(view_1677, permute_1139); view_1677 = permute_1139 = None + view_1678 = torch.ops.aten.view.default(mm_570, [2, 8192, 4096]); mm_570 = None + convert_element_type_2398 = torch.ops.prims.convert_element_type.default(mm_569, torch.float32); mm_569 = None + reduce_scatter_tensor_222 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2398, 'avg', 256, '0'); convert_element_type_2398 = None + wait_tensor_513 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_222); reduce_scatter_tensor_222 = None + view_1679 = torch.ops.aten.view.default(view_1678, [2, 8192, 32, 128]); view_1678 = None + permute_1141 = torch.ops.aten.permute.default(view_1679, [0, 2, 1, 3]); view_1679 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 256, '0'); convert_element_type_232 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32); add_27 = None + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_64) + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + view_241 = torch.ops.aten.view.default(convert_element_type_234, [16384, 4096]); convert_element_type_234 = None + view_242 = torch.ops.aten.view.default(mm_49, [2, 8192, 4096]); mm_49 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 256, '0'); convert_element_type_238 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_66, [1, 0]); wait_tensor_66 = None + mm_50 = torch.ops.aten.mm.default(view_241, permute_78) + view_245 = torch.ops.aten.view.default(mm_50, [2, 8192, 1024]); mm_50 = None + view_248 = torch.ops.aten.view.default(mm_51, [2, 8192, 1024]); mm_51 = None + view_249 = torch.ops.aten.view.default(view_242, [2, 8192, -1, 128]); view_242 = None + view_250 = torch.ops.aten.view.default(view_245, [2, 8192, -1, 128]); view_245 = None + view_251 = torch.ops.aten.view.default(view_248, [2, 8192, -1, 128]); view_248 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 32, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_250, torch.float32); view_250 = None + view_253 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 8, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_253); view_253 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_16); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_255 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 32, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_16); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_256 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 8, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_256, torch.bfloat16); view_256 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 8, 4, 128]); unsqueeze_14 = None + clone_14 = torch.ops.aten.clone.default(expand_14, memory_format = torch.contiguous_format); expand_14 = None + view_257 = torch.ops.aten.view.default(clone_14, [2, 8192, 32, 128]); clone_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_251, 3); view_251 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 8, 4, 128]); unsqueeze_15 = None + clone_15 = torch.ops.aten.clone.default(expand_15, memory_format = torch.contiguous_format); expand_15 = None + view_258 = torch.ops.aten.view.default(clone_15, [2, 8192, 32, 128]); clone_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + permute_82 = torch.ops.aten.permute.default(view_258, [0, 2, 1, 3]); view_258 = None + _scaled_dot_product_cudnn_attention_backward_24 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1141, permute_80, permute_81, permute_82, getitem_63, getitem_64, getitem_69, getitem_70, None, None, None, 8192, 8192, 0.0, True); permute_1141 = permute_80 = permute_81 = permute_82 = getitem_63 = getitem_64 = getitem_69 = getitem_70 = None + getitem_360 = _scaled_dot_product_cudnn_attention_backward_24[0] + getitem_361 = _scaled_dot_product_cudnn_attention_backward_24[1] + getitem_362 = _scaled_dot_product_cudnn_attention_backward_24[2]; _scaled_dot_product_cudnn_attention_backward_24 = None + permute_1142 = torch.ops.aten.permute.default(getitem_362, [0, 2, 1, 3]); getitem_362 = None + permute_1143 = torch.ops.aten.permute.default(getitem_361, [0, 2, 1, 3]); getitem_361 = None + permute_1144 = torch.ops.aten.permute.default(getitem_360, [0, 2, 1, 3]); getitem_360 = None + view_1680 = torch.ops.aten.view.default(permute_1142, [2, 8192, 8, 4, 128]); permute_1142 = None + sum_149 = torch.ops.aten.sum.dim_IntList(view_1680, [3], True); view_1680 = None + squeeze_48 = torch.ops.aten.squeeze.dim(sum_149, 3); sum_149 = None + view_1681 = torch.ops.aten.view.default(permute_1143, [2, 8192, 8, 4, 128]); permute_1143 = None + sum_150 = torch.ops.aten.sum.dim_IntList(view_1681, [3], True); view_1681 = None + squeeze_49 = torch.ops.aten.squeeze.dim(sum_150, 3); sum_150 = None + convert_element_type_2399 = torch.ops.prims.convert_element_type.default(squeeze_49, torch.float32); squeeze_49 = None + convert_element_type_2400 = torch.ops.prims.convert_element_type.default(permute_1144, torch.float32); permute_1144 = None + view_1682 = torch.ops.aten.view.default(convert_element_type_2399, [2, 8192, 8, 64, 2]); convert_element_type_2399 = None + view_as_complex_112 = torch.ops.aten.view_as_complex.default(view_1682); view_1682 = None + mul_756 = torch.ops.aten.mul.Tensor(view_as_complex_112, _conj); view_as_complex_112 = None + view_1683 = torch.ops.aten.view.default(convert_element_type_2400, [2, 8192, 32, 64, 2]); convert_element_type_2400 = None + view_as_complex_113 = torch.ops.aten.view_as_complex.default(view_1683); view_1683 = None + mul_757 = torch.ops.aten.mul.Tensor(view_as_complex_113, _conj); view_as_complex_113 = None + view_as_real_112 = torch.ops.aten.view_as_real.default(mul_756); mul_756 = None + view_1684 = torch.ops.aten.view.default(view_as_real_112, [2, 8192, 8, 128]); view_as_real_112 = None + convert_element_type_2401 = torch.ops.prims.convert_element_type.default(view_1684, torch.bfloat16); view_1684 = None + view_as_real_113 = torch.ops.aten.view_as_real.default(mul_757); mul_757 = None + view_1685 = torch.ops.aten.view.default(view_as_real_113, [2, 8192, 32, 128]); view_as_real_113 = None + convert_element_type_2402 = torch.ops.prims.convert_element_type.default(view_1685, torch.bfloat16); view_1685 = None + view_1686 = torch.ops.aten.view.default(squeeze_48, [2, 8192, 1024]); squeeze_48 = None + view_1687 = torch.ops.aten.view.default(convert_element_type_2401, [2, 8192, 1024]); convert_element_type_2401 = None + view_1688 = torch.ops.aten.view.default(convert_element_type_2402, [2, 8192, 4096]); convert_element_type_2402 = None + view_1689 = torch.ops.aten.view.default(view_1686, [16384, 1024]); view_1686 = None + permute_1145 = torch.ops.aten.permute.default(view_1689, [1, 0]) + mm_571 = torch.ops.aten.mm.default(permute_1145, view_241); permute_1145 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 256, '0'); convert_element_type_241 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_67, [1, 0]); wait_tensor_67 = None + permute_1147 = torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_572 = torch.ops.aten.mm.default(view_1689, permute_1147); view_1689 = permute_1147 = None + view_1690 = torch.ops.aten.view.default(mm_572, [2, 8192, 4096]); mm_572 = None + convert_element_type_2407 = torch.ops.prims.convert_element_type.default(mm_571, torch.float32); mm_571 = None + reduce_scatter_tensor_223 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2407, 'avg', 256, '0'); convert_element_type_2407 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_223); reduce_scatter_tensor_223 = None + view_1691 = torch.ops.aten.view.default(view_1687, [16384, 1024]); view_1687 = None + permute_1149 = torch.ops.aten.permute.default(view_1691, [1, 0]) + mm_573 = torch.ops.aten.mm.default(permute_1149, view_241); permute_1149 = None + permute_1151 = torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_574 = torch.ops.aten.mm.default(view_1691, permute_1151); view_1691 = permute_1151 = None + view_1692 = torch.ops.aten.view.default(mm_574, [2, 8192, 4096]); mm_574 = None + add_301 = torch.ops.aten.add.Tensor(view_1690, view_1692); view_1690 = view_1692 = None + convert_element_type_2412 = torch.ops.prims.convert_element_type.default(mm_573, torch.float32); mm_573 = None + reduce_scatter_tensor_224 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2412, 'avg', 256, '0'); convert_element_type_2412 = None + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_224); reduce_scatter_tensor_224 = None + view_1693 = torch.ops.aten.view.default(view_1688, [16384, 4096]); view_1688 = None + permute_1153 = torch.ops.aten.permute.default(view_1693, [1, 0]) + mm_575 = torch.ops.aten.mm.default(permute_1153, view_241); permute_1153 = view_241 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 256, '0'); convert_element_type_235 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + permute_1155 = torch.ops.aten.permute.default(permute_77, [1, 0]); permute_77 = None + mm_576 = torch.ops.aten.mm.default(view_1693, permute_1155); view_1693 = permute_1155 = None + view_1694 = torch.ops.aten.view.default(mm_576, [2, 8192, 4096]); mm_576 = None + add_302 = torch.ops.aten.add.Tensor(add_301, view_1694); add_301 = view_1694 = None + convert_element_type_2417 = torch.ops.prims.convert_element_type.default(mm_575, torch.float32); mm_575 = None + reduce_scatter_tensor_225 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2417, 'avg', 256, '0'); convert_element_type_2417 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_225); reduce_scatter_tensor_225 = None + convert_element_type_2418 = torch.ops.prims.convert_element_type.default(add_302, torch.float32); add_302 = None + convert_element_type_2420 = torch.ops.prims.convert_element_type.default(wait_tensor_64, torch.float32); wait_tensor_64 = None + mul_758 = torch.ops.aten.mul.Tensor(convert_element_type_2418, convert_element_type_2420); convert_element_type_2420 = None + mul_760 = torch.ops.aten.mul.Tensor(mul_56, mul_758) + sum_151 = torch.ops.aten.sum.dim_IntList(mul_760, [2], True); mul_760 = None + div_50 = torch.ops.aten.div.Tensor(mul_56, 4096) + mul_761 = torch.ops.aten.mul.Tensor(div_50, sum_151); div_50 = sum_151 = None + sub_75 = torch.ops.aten.sub.Tensor(mul_758, mul_761); mul_758 = mul_761 = None + mul_762 = torch.ops.aten.mul.Tensor(sub_75, rsqrt_14); sub_75 = rsqrt_14 = None + mul_763 = torch.ops.aten.mul.Tensor(convert_element_type_2418, mul_56); convert_element_type_2418 = mul_56 = None + sum_152 = torch.ops.aten.sum.dim_IntList(mul_763, [0, 1]); mul_763 = None + convert_element_type_2421 = torch.ops.prims.convert_element_type.default(mul_762, torch.bfloat16); mul_762 = None + add_303 = torch.ops.aten.add.Tensor(add_300, convert_element_type_2421); add_300 = convert_element_type_2421 = None + convert_element_type_default_15 = torch.ops.prims.convert_element_type.default(sum_152, torch.float32); sum_152 = None + reduce_scatter_tensor_226 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_15, 'avg', 256, '0'); convert_element_type_default_15 = None + wait_tensor_517 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_226); reduce_scatter_tensor_226 = None + view_1695 = torch.ops.aten.view.default(add_303, [16384, 4096]) + permute_1157 = torch.ops.aten.permute.default(view_1695, [1, 0]) + permute_72 = torch.ops.aten.permute.default(getitem_54, [0, 2, 1, 3]) + view_225 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16); primals_62 = None + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 256, '0'); convert_element_type_215 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_227 = torch.ops.aten.view.default(view_225, [16384, 4096]); view_225 = None + mm_45 = torch.ops.aten.mm.default(view_227, permute_73) + view_228 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + add_25 = torch.ops.aten.add.Tensor(add_23, view_228); view_228 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 256, '0'); convert_element_type_218 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32); add_25 = None + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_60) + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + view_231 = torch.ops.aten.view.default(convert_element_type_220, [16384, 4096]); convert_element_type_220 = None + view_232 = torch.ops.aten.view.default(mm_46, [2, 8192, 14336]); mm_46 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_232, torch.float32); view_232 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 256, '0'); convert_element_type_226 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_62, [1, 0]); wait_tensor_62 = None + mm_47 = torch.ops.aten.mm.default(view_231, permute_75) + view_235 = torch.ops.aten.view.default(mm_47, [2, 8192, 14336]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_235) + view_237 = torch.ops.aten.view.default(mul_55, [16384, 14336]); mul_55 = None + mm_577 = torch.ops.aten.mm.default(permute_1157, view_237); permute_1157 = view_237 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 256, '0'); convert_element_type_229 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + permute_1159 = torch.ops.aten.permute.default(permute_76, [1, 0]); permute_76 = None + mm_578 = torch.ops.aten.mm.default(view_1695, permute_1159); view_1695 = permute_1159 = None + view_1696 = torch.ops.aten.view.default(mm_578, [2, 8192, 14336]); mm_578 = None + convert_element_type_2428 = torch.ops.prims.convert_element_type.default(mm_577, torch.float32); mm_577 = None + reduce_scatter_tensor_227 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2428, 'avg', 256, '0'); convert_element_type_2428 = None + wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_227); reduce_scatter_tensor_227 = None + mul_764 = torch.ops.aten.mul.Tensor(view_1696, convert_element_type_225); convert_element_type_225 = None + mul_765 = torch.ops.aten.mul.Tensor(view_1696, view_235); view_1696 = view_235 = None + view_1697 = torch.ops.aten.view.default(mul_764, [16384, 14336]); mul_764 = None + permute_1161 = torch.ops.aten.permute.default(view_1697, [1, 0]) + mm_579 = torch.ops.aten.mm.default(permute_1161, view_231); permute_1161 = None + permute_1163 = torch.ops.aten.permute.default(permute_75, [1, 0]); permute_75 = None + mm_580 = torch.ops.aten.mm.default(view_1697, permute_1163); view_1697 = permute_1163 = None + view_1698 = torch.ops.aten.view.default(mm_580, [2, 8192, 4096]); mm_580 = None + convert_element_type_2433 = torch.ops.prims.convert_element_type.default(mm_579, torch.float32); mm_579 = None + reduce_scatter_tensor_228 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2433, 'avg', 256, '0'); convert_element_type_2433 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_228); reduce_scatter_tensor_228 = None + convert_element_type_2434 = torch.ops.prims.convert_element_type.default(mul_765, torch.float32); mul_765 = None + neg_25 = torch.ops.aten.neg.default(convert_element_type_224) + exp_25 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_304 = torch.ops.aten.add.Tensor(exp_25, 1); exp_25 = None + reciprocal_25 = torch.ops.aten.reciprocal.default(add_304); add_304 = None + mul_766 = torch.ops.aten.mul.Tensor(reciprocal_25, 1); reciprocal_25 = None + mul_767 = torch.ops.aten.mul.Tensor(convert_element_type_2434, mul_766); convert_element_type_2434 = None + sub_76 = torch.ops.aten.sub.Tensor(1, mul_766); mul_766 = None + mul_768 = torch.ops.aten.mul.Tensor(convert_element_type_224, sub_76); convert_element_type_224 = sub_76 = None + add_305 = torch.ops.aten.add.Tensor(mul_768, 1); mul_768 = None + mul_769 = torch.ops.aten.mul.Tensor(mul_767, add_305); mul_767 = add_305 = None + convert_element_type_2436 = torch.ops.prims.convert_element_type.default(mul_769, torch.bfloat16); mul_769 = None + view_1699 = torch.ops.aten.view.default(convert_element_type_2436, [16384, 14336]); convert_element_type_2436 = None + permute_1165 = torch.ops.aten.permute.default(view_1699, [1, 0]) + mm_581 = torch.ops.aten.mm.default(permute_1165, view_231); permute_1165 = view_231 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 256, '0'); convert_element_type_221 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_61, [1, 0]); wait_tensor_61 = None + permute_1167 = torch.ops.aten.permute.default(permute_74, [1, 0]); permute_74 = None + mm_582 = torch.ops.aten.mm.default(view_1699, permute_1167); view_1699 = permute_1167 = None + view_1700 = torch.ops.aten.view.default(mm_582, [2, 8192, 4096]); mm_582 = None + add_306 = torch.ops.aten.add.Tensor(view_1698, view_1700); view_1698 = view_1700 = None + convert_element_type_2441 = torch.ops.prims.convert_element_type.default(mm_581, torch.float32); mm_581 = None + reduce_scatter_tensor_229 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2441, 'avg', 256, '0'); convert_element_type_2441 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_229); reduce_scatter_tensor_229 = None + convert_element_type_2442 = torch.ops.prims.convert_element_type.default(add_306, torch.float32); add_306 = None + convert_element_type_2444 = torch.ops.prims.convert_element_type.default(wait_tensor_60, torch.float32); wait_tensor_60 = None + mul_770 = torch.ops.aten.mul.Tensor(convert_element_type_2442, convert_element_type_2444); convert_element_type_2444 = None + mul_772 = torch.ops.aten.mul.Tensor(mul_52, mul_770) + sum_153 = torch.ops.aten.sum.dim_IntList(mul_772, [2], True); mul_772 = None + div_51 = torch.ops.aten.div.Tensor(mul_52, 4096) + mul_773 = torch.ops.aten.mul.Tensor(div_51, sum_153); div_51 = sum_153 = None + sub_77 = torch.ops.aten.sub.Tensor(mul_770, mul_773); mul_770 = mul_773 = None + mul_774 = torch.ops.aten.mul.Tensor(sub_77, rsqrt_13); sub_77 = rsqrt_13 = None + mul_775 = torch.ops.aten.mul.Tensor(convert_element_type_2442, mul_52); convert_element_type_2442 = mul_52 = None + sum_154 = torch.ops.aten.sum.dim_IntList(mul_775, [0, 1]); mul_775 = None + convert_element_type_2445 = torch.ops.prims.convert_element_type.default(mul_774, torch.bfloat16); mul_774 = None + add_307 = torch.ops.aten.add.Tensor(add_303, convert_element_type_2445); add_303 = convert_element_type_2445 = None + convert_element_type_default_14 = torch.ops.prims.convert_element_type.default(sum_154, torch.float32); sum_154 = None + reduce_scatter_tensor_230 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_14, 'avg', 256, '0'); convert_element_type_default_14 = None + wait_tensor_521 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_230); reduce_scatter_tensor_230 = None + view_1701 = torch.ops.aten.view.default(add_307, [16384, 4096]) + permute_1169 = torch.ops.aten.permute.default(view_1701, [1, 0]) + mm_583 = torch.ops.aten.mm.default(permute_1169, view_227); permute_1169 = view_227 = None + permute_1171 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_584 = torch.ops.aten.mm.default(view_1701, permute_1171); view_1701 = permute_1171 = None + view_1702 = torch.ops.aten.view.default(mm_584, [2, 8192, 4096]); mm_584 = None + convert_element_type_2452 = torch.ops.prims.convert_element_type.default(mm_583, torch.float32); mm_583 = None + reduce_scatter_tensor_231 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2452, 'avg', 256, '0'); convert_element_type_2452 = None + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_231); reduce_scatter_tensor_231 = None + view_1703 = torch.ops.aten.view.default(view_1702, [2, 8192, 32, 128]); view_1702 = None + permute_1173 = torch.ops.aten.permute.default(view_1703, [0, 2, 1, 3]); view_1703 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 256, '0'); convert_element_type_199 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32); add_23 = None + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_55) + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + view_207 = torch.ops.aten.view.default(convert_element_type_201, [16384, 4096]); convert_element_type_201 = None + view_208 = torch.ops.aten.view.default(mm_42, [2, 8192, 4096]); mm_42 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 256, '0'); convert_element_type_205 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_43 = torch.ops.aten.mm.default(view_207, permute_67) + view_211 = torch.ops.aten.view.default(mm_43, [2, 8192, 1024]); mm_43 = None + view_214 = torch.ops.aten.view.default(mm_44, [2, 8192, 1024]); mm_44 = None + view_215 = torch.ops.aten.view.default(view_208, [2, 8192, -1, 128]); view_208 = None + view_216 = torch.ops.aten.view.default(view_211, [2, 8192, -1, 128]); view_211 = None + view_217 = torch.ops.aten.view.default(view_214, [2, 8192, -1, 128]); view_214 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_215, torch.float32); view_215 = None + view_218 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 32, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_218); view_218 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_216, torch.float32); view_216 = None + view_219 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 8, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_219); view_219 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_16); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_221 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 32, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_16); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_222 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 8, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_221, torch.bfloat16); view_221 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_222, torch.bfloat16); view_222 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 8, 4, 128]); unsqueeze_12 = None + clone_12 = torch.ops.aten.clone.default(expand_12, memory_format = torch.contiguous_format); expand_12 = None + view_223 = torch.ops.aten.view.default(clone_12, [2, 8192, 32, 128]); clone_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_217, 3); view_217 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 8, 4, 128]); unsqueeze_13 = None + clone_13 = torch.ops.aten.clone.default(expand_13, memory_format = torch.contiguous_format); expand_13 = None + view_224 = torch.ops.aten.view.default(clone_13, [2, 8192, 32, 128]); clone_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_223, [0, 2, 1, 3]); view_223 = None + permute_71 = torch.ops.aten.permute.default(view_224, [0, 2, 1, 3]); view_224 = None + _scaled_dot_product_cudnn_attention_backward_25 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1173, permute_69, permute_70, permute_71, getitem_54, getitem_55, getitem_60, getitem_61, None, None, None, 8192, 8192, 0.0, True); permute_1173 = permute_69 = permute_70 = permute_71 = getitem_54 = getitem_55 = getitem_60 = getitem_61 = None + getitem_363 = _scaled_dot_product_cudnn_attention_backward_25[0] + getitem_364 = _scaled_dot_product_cudnn_attention_backward_25[1] + getitem_365 = _scaled_dot_product_cudnn_attention_backward_25[2]; _scaled_dot_product_cudnn_attention_backward_25 = None + permute_1174 = torch.ops.aten.permute.default(getitem_365, [0, 2, 1, 3]); getitem_365 = None + permute_1175 = torch.ops.aten.permute.default(getitem_364, [0, 2, 1, 3]); getitem_364 = None + permute_1176 = torch.ops.aten.permute.default(getitem_363, [0, 2, 1, 3]); getitem_363 = None + view_1704 = torch.ops.aten.view.default(permute_1174, [2, 8192, 8, 4, 128]); permute_1174 = None + sum_155 = torch.ops.aten.sum.dim_IntList(view_1704, [3], True); view_1704 = None + squeeze_50 = torch.ops.aten.squeeze.dim(sum_155, 3); sum_155 = None + view_1705 = torch.ops.aten.view.default(permute_1175, [2, 8192, 8, 4, 128]); permute_1175 = None + sum_156 = torch.ops.aten.sum.dim_IntList(view_1705, [3], True); view_1705 = None + squeeze_51 = torch.ops.aten.squeeze.dim(sum_156, 3); sum_156 = None + convert_element_type_2453 = torch.ops.prims.convert_element_type.default(squeeze_51, torch.float32); squeeze_51 = None + convert_element_type_2454 = torch.ops.prims.convert_element_type.default(permute_1176, torch.float32); permute_1176 = None + view_1706 = torch.ops.aten.view.default(convert_element_type_2453, [2, 8192, 8, 64, 2]); convert_element_type_2453 = None + view_as_complex_114 = torch.ops.aten.view_as_complex.default(view_1706); view_1706 = None + mul_776 = torch.ops.aten.mul.Tensor(view_as_complex_114, _conj); view_as_complex_114 = None + view_1707 = torch.ops.aten.view.default(convert_element_type_2454, [2, 8192, 32, 64, 2]); convert_element_type_2454 = None + view_as_complex_115 = torch.ops.aten.view_as_complex.default(view_1707); view_1707 = None + mul_777 = torch.ops.aten.mul.Tensor(view_as_complex_115, _conj); view_as_complex_115 = None + view_as_real_114 = torch.ops.aten.view_as_real.default(mul_776); mul_776 = None + view_1708 = torch.ops.aten.view.default(view_as_real_114, [2, 8192, 8, 128]); view_as_real_114 = None + convert_element_type_2455 = torch.ops.prims.convert_element_type.default(view_1708, torch.bfloat16); view_1708 = None + view_as_real_115 = torch.ops.aten.view_as_real.default(mul_777); mul_777 = None + view_1709 = torch.ops.aten.view.default(view_as_real_115, [2, 8192, 32, 128]); view_as_real_115 = None + convert_element_type_2456 = torch.ops.prims.convert_element_type.default(view_1709, torch.bfloat16); view_1709 = None + view_1710 = torch.ops.aten.view.default(squeeze_50, [2, 8192, 1024]); squeeze_50 = None + view_1711 = torch.ops.aten.view.default(convert_element_type_2455, [2, 8192, 1024]); convert_element_type_2455 = None + view_1712 = torch.ops.aten.view.default(convert_element_type_2456, [2, 8192, 4096]); convert_element_type_2456 = None + view_1713 = torch.ops.aten.view.default(view_1710, [16384, 1024]); view_1710 = None + permute_1177 = torch.ops.aten.permute.default(view_1713, [1, 0]) + mm_585 = torch.ops.aten.mm.default(permute_1177, view_207); permute_1177 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 256, '0'); convert_element_type_208 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_1179 = torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_586 = torch.ops.aten.mm.default(view_1713, permute_1179); view_1713 = permute_1179 = None + view_1714 = torch.ops.aten.view.default(mm_586, [2, 8192, 4096]); mm_586 = None + convert_element_type_2461 = torch.ops.prims.convert_element_type.default(mm_585, torch.float32); mm_585 = None + reduce_scatter_tensor_232 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2461, 'avg', 256, '0'); convert_element_type_2461 = None + wait_tensor_523 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_232); reduce_scatter_tensor_232 = None + view_1715 = torch.ops.aten.view.default(view_1711, [16384, 1024]); view_1711 = None + permute_1181 = torch.ops.aten.permute.default(view_1715, [1, 0]) + mm_587 = torch.ops.aten.mm.default(permute_1181, view_207); permute_1181 = None + permute_1183 = torch.ops.aten.permute.default(permute_67, [1, 0]); permute_67 = None + mm_588 = torch.ops.aten.mm.default(view_1715, permute_1183); view_1715 = permute_1183 = None + view_1716 = torch.ops.aten.view.default(mm_588, [2, 8192, 4096]); mm_588 = None + add_308 = torch.ops.aten.add.Tensor(view_1714, view_1716); view_1714 = view_1716 = None + convert_element_type_2466 = torch.ops.prims.convert_element_type.default(mm_587, torch.float32); mm_587 = None + reduce_scatter_tensor_233 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2466, 'avg', 256, '0'); convert_element_type_2466 = None + wait_tensor_524 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_233); reduce_scatter_tensor_233 = None + view_1717 = torch.ops.aten.view.default(view_1712, [16384, 4096]); view_1712 = None + permute_1185 = torch.ops.aten.permute.default(view_1717, [1, 0]) + mm_589 = torch.ops.aten.mm.default(permute_1185, view_207); permute_1185 = view_207 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); primals_59 = None + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 256, '0'); convert_element_type_202 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + permute_1187 = torch.ops.aten.permute.default(permute_66, [1, 0]); permute_66 = None + mm_590 = torch.ops.aten.mm.default(view_1717, permute_1187); view_1717 = permute_1187 = None + view_1718 = torch.ops.aten.view.default(mm_590, [2, 8192, 4096]); mm_590 = None + add_309 = torch.ops.aten.add.Tensor(add_308, view_1718); add_308 = view_1718 = None + convert_element_type_2471 = torch.ops.prims.convert_element_type.default(mm_589, torch.float32); mm_589 = None + reduce_scatter_tensor_234 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2471, 'avg', 256, '0'); convert_element_type_2471 = None + wait_tensor_525 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_234); reduce_scatter_tensor_234 = None + convert_element_type_2472 = torch.ops.prims.convert_element_type.default(add_309, torch.float32); add_309 = None + convert_element_type_2474 = torch.ops.prims.convert_element_type.default(wait_tensor_55, torch.float32); wait_tensor_55 = None + mul_778 = torch.ops.aten.mul.Tensor(convert_element_type_2472, convert_element_type_2474); convert_element_type_2474 = None + mul_780 = torch.ops.aten.mul.Tensor(mul_48, mul_778) + sum_157 = torch.ops.aten.sum.dim_IntList(mul_780, [2], True); mul_780 = None + div_52 = torch.ops.aten.div.Tensor(mul_48, 4096) + mul_781 = torch.ops.aten.mul.Tensor(div_52, sum_157); div_52 = sum_157 = None + sub_78 = torch.ops.aten.sub.Tensor(mul_778, mul_781); mul_778 = mul_781 = None + mul_782 = torch.ops.aten.mul.Tensor(sub_78, rsqrt_12); sub_78 = rsqrt_12 = None + mul_783 = torch.ops.aten.mul.Tensor(convert_element_type_2472, mul_48); convert_element_type_2472 = mul_48 = None + sum_158 = torch.ops.aten.sum.dim_IntList(mul_783, [0, 1]); mul_783 = None + convert_element_type_2475 = torch.ops.prims.convert_element_type.default(mul_782, torch.bfloat16); mul_782 = None + add_310 = torch.ops.aten.add.Tensor(add_307, convert_element_type_2475); add_307 = convert_element_type_2475 = None + convert_element_type_default_13 = torch.ops.prims.convert_element_type.default(sum_158, torch.float32); sum_158 = None + reduce_scatter_tensor_235 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_13, 'avg', 256, '0'); convert_element_type_default_13 = None + wait_tensor_526 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_235); reduce_scatter_tensor_235 = None + view_1719 = torch.ops.aten.view.default(add_310, [16384, 4096]) + permute_1189 = torch.ops.aten.permute.default(view_1719, [1, 0]) + permute_61 = torch.ops.aten.permute.default(getitem_45, [0, 2, 1, 3]) + view_191 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 256, '0'); convert_element_type_182 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_193 = torch.ops.aten.view.default(view_191, [16384, 4096]); view_191 = None + mm_38 = torch.ops.aten.mm.default(view_193, permute_62) + view_194 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + add_21 = torch.ops.aten.add.Tensor(add_19, view_194); view_194 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 256, '0'); convert_element_type_185 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32); add_21 = None + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_51) + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + view_197 = torch.ops.aten.view.default(convert_element_type_187, [16384, 4096]); convert_element_type_187 = None + view_198 = torch.ops.aten.view.default(mm_39, [2, 8192, 14336]); mm_39 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_198, torch.float32); view_198 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 256, '0'); convert_element_type_193 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_53, [1, 0]); wait_tensor_53 = None + mm_40 = torch.ops.aten.mm.default(view_197, permute_64) + view_201 = torch.ops.aten.view.default(mm_40, [2, 8192, 14336]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_201) + view_203 = torch.ops.aten.view.default(mul_47, [16384, 14336]); mul_47 = None + mm_591 = torch.ops.aten.mm.default(permute_1189, view_203); permute_1189 = view_203 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 256, '0'); convert_element_type_196 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + permute_1191 = torch.ops.aten.permute.default(permute_65, [1, 0]); permute_65 = None + mm_592 = torch.ops.aten.mm.default(view_1719, permute_1191); view_1719 = permute_1191 = None + view_1720 = torch.ops.aten.view.default(mm_592, [2, 8192, 14336]); mm_592 = None + convert_element_type_2482 = torch.ops.prims.convert_element_type.default(mm_591, torch.float32); mm_591 = None + reduce_scatter_tensor_236 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2482, 'avg', 256, '0'); convert_element_type_2482 = None + wait_tensor_527 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_236); reduce_scatter_tensor_236 = None + mul_784 = torch.ops.aten.mul.Tensor(view_1720, convert_element_type_192); convert_element_type_192 = None + mul_785 = torch.ops.aten.mul.Tensor(view_1720, view_201); view_1720 = view_201 = None + view_1721 = torch.ops.aten.view.default(mul_784, [16384, 14336]); mul_784 = None + permute_1193 = torch.ops.aten.permute.default(view_1721, [1, 0]) + mm_593 = torch.ops.aten.mm.default(permute_1193, view_197); permute_1193 = None + permute_1195 = torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_594 = torch.ops.aten.mm.default(view_1721, permute_1195); view_1721 = permute_1195 = None + view_1722 = torch.ops.aten.view.default(mm_594, [2, 8192, 4096]); mm_594 = None + convert_element_type_2487 = torch.ops.prims.convert_element_type.default(mm_593, torch.float32); mm_593 = None + reduce_scatter_tensor_237 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2487, 'avg', 256, '0'); convert_element_type_2487 = None + wait_tensor_528 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_237); reduce_scatter_tensor_237 = None + convert_element_type_2488 = torch.ops.prims.convert_element_type.default(mul_785, torch.float32); mul_785 = None + neg_26 = torch.ops.aten.neg.default(convert_element_type_191) + exp_26 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_311 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + reciprocal_26 = torch.ops.aten.reciprocal.default(add_311); add_311 = None + mul_786 = torch.ops.aten.mul.Tensor(reciprocal_26, 1); reciprocal_26 = None + mul_787 = torch.ops.aten.mul.Tensor(convert_element_type_2488, mul_786); convert_element_type_2488 = None + sub_79 = torch.ops.aten.sub.Tensor(1, mul_786); mul_786 = None + mul_788 = torch.ops.aten.mul.Tensor(convert_element_type_191, sub_79); convert_element_type_191 = sub_79 = None + add_312 = torch.ops.aten.add.Tensor(mul_788, 1); mul_788 = None + mul_789 = torch.ops.aten.mul.Tensor(mul_787, add_312); mul_787 = add_312 = None + convert_element_type_2490 = torch.ops.prims.convert_element_type.default(mul_789, torch.bfloat16); mul_789 = None + view_1723 = torch.ops.aten.view.default(convert_element_type_2490, [16384, 14336]); convert_element_type_2490 = None + permute_1197 = torch.ops.aten.permute.default(view_1723, [1, 0]) + mm_595 = torch.ops.aten.mm.default(permute_1197, view_197); permute_1197 = view_197 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 256, '0'); convert_element_type_188 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_1199 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_596 = torch.ops.aten.mm.default(view_1723, permute_1199); view_1723 = permute_1199 = None + view_1724 = torch.ops.aten.view.default(mm_596, [2, 8192, 4096]); mm_596 = None + add_313 = torch.ops.aten.add.Tensor(view_1722, view_1724); view_1722 = view_1724 = None + convert_element_type_2495 = torch.ops.prims.convert_element_type.default(mm_595, torch.float32); mm_595 = None + reduce_scatter_tensor_238 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2495, 'avg', 256, '0'); convert_element_type_2495 = None + wait_tensor_529 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_238); reduce_scatter_tensor_238 = None + convert_element_type_2496 = torch.ops.prims.convert_element_type.default(add_313, torch.float32); add_313 = None + convert_element_type_2498 = torch.ops.prims.convert_element_type.default(wait_tensor_51, torch.float32); wait_tensor_51 = None + mul_790 = torch.ops.aten.mul.Tensor(convert_element_type_2496, convert_element_type_2498); convert_element_type_2498 = None + mul_792 = torch.ops.aten.mul.Tensor(mul_44, mul_790) + sum_159 = torch.ops.aten.sum.dim_IntList(mul_792, [2], True); mul_792 = None + div_53 = torch.ops.aten.div.Tensor(mul_44, 4096) + mul_793 = torch.ops.aten.mul.Tensor(div_53, sum_159); div_53 = sum_159 = None + sub_80 = torch.ops.aten.sub.Tensor(mul_790, mul_793); mul_790 = mul_793 = None + mul_794 = torch.ops.aten.mul.Tensor(sub_80, rsqrt_11); sub_80 = rsqrt_11 = None + mul_795 = torch.ops.aten.mul.Tensor(convert_element_type_2496, mul_44); convert_element_type_2496 = mul_44 = None + sum_160 = torch.ops.aten.sum.dim_IntList(mul_795, [0, 1]); mul_795 = None + convert_element_type_2499 = torch.ops.prims.convert_element_type.default(mul_794, torch.bfloat16); mul_794 = None + add_314 = torch.ops.aten.add.Tensor(add_310, convert_element_type_2499); add_310 = convert_element_type_2499 = None + convert_element_type_default_12 = torch.ops.prims.convert_element_type.default(sum_160, torch.float32); sum_160 = None + reduce_scatter_tensor_239 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_12, 'avg', 256, '0'); convert_element_type_default_12 = None + wait_tensor_530 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_239); reduce_scatter_tensor_239 = None + view_1725 = torch.ops.aten.view.default(add_314, [16384, 4096]) + permute_1201 = torch.ops.aten.permute.default(view_1725, [1, 0]) + mm_597 = torch.ops.aten.mm.default(permute_1201, view_193); permute_1201 = view_193 = None + permute_1203 = torch.ops.aten.permute.default(permute_62, [1, 0]); permute_62 = None + mm_598 = torch.ops.aten.mm.default(view_1725, permute_1203); view_1725 = permute_1203 = None + view_1726 = torch.ops.aten.view.default(mm_598, [2, 8192, 4096]); mm_598 = None + convert_element_type_2506 = torch.ops.prims.convert_element_type.default(mm_597, torch.float32); mm_597 = None + reduce_scatter_tensor_240 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2506, 'avg', 256, '0'); convert_element_type_2506 = None + wait_tensor_531 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_240); reduce_scatter_tensor_240 = None + view_1727 = torch.ops.aten.view.default(view_1726, [2, 8192, 32, 128]); view_1726 = None + permute_1205 = torch.ops.aten.permute.default(view_1727, [0, 2, 1, 3]); view_1727 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 256, '0'); convert_element_type_166 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32); add_19 = None + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_46) + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + view_173 = torch.ops.aten.view.default(convert_element_type_168, [16384, 4096]); convert_element_type_168 = None + view_174 = torch.ops.aten.view.default(mm_35, [2, 8192, 4096]); mm_35 = None + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 256, '0'); convert_element_type_172 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_48, [1, 0]); wait_tensor_48 = None + mm_36 = torch.ops.aten.mm.default(view_173, permute_56) + view_177 = torch.ops.aten.view.default(mm_36, [2, 8192, 1024]); mm_36 = None + view_180 = torch.ops.aten.view.default(mm_37, [2, 8192, 1024]); mm_37 = None + view_181 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + view_182 = torch.ops.aten.view.default(view_177, [2, 8192, -1, 128]); view_177 = None + view_183 = torch.ops.aten.view.default(view_180, [2, 8192, -1, 128]); view_180 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_181, torch.float32); view_181 = None + view_184 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 32, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_184); view_184 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_182, torch.float32); view_182 = None + view_185 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 8, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_185); view_185 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_16); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_187 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 32, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_16); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_188 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 8, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_187, torch.bfloat16); view_187 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_188, torch.bfloat16); view_188 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 8, 4, 128]); unsqueeze_10 = None + clone_10 = torch.ops.aten.clone.default(expand_10, memory_format = torch.contiguous_format); expand_10 = None + view_189 = torch.ops.aten.view.default(clone_10, [2, 8192, 32, 128]); clone_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_183, 3); view_183 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 8, 4, 128]); unsqueeze_11 = None + clone_11 = torch.ops.aten.clone.default(expand_11, memory_format = torch.contiguous_format); expand_11 = None + view_190 = torch.ops.aten.view.default(clone_11, [2, 8192, 32, 128]); clone_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_189, [0, 2, 1, 3]); view_189 = None + permute_60 = torch.ops.aten.permute.default(view_190, [0, 2, 1, 3]); view_190 = None + _scaled_dot_product_cudnn_attention_backward_26 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1205, permute_58, permute_59, permute_60, getitem_45, getitem_46, getitem_51, getitem_52, None, None, None, 8192, 8192, 0.0, True); permute_1205 = permute_58 = permute_59 = permute_60 = getitem_45 = getitem_46 = getitem_51 = getitem_52 = None + getitem_366 = _scaled_dot_product_cudnn_attention_backward_26[0] + getitem_367 = _scaled_dot_product_cudnn_attention_backward_26[1] + getitem_368 = _scaled_dot_product_cudnn_attention_backward_26[2]; _scaled_dot_product_cudnn_attention_backward_26 = None + permute_1206 = torch.ops.aten.permute.default(getitem_368, [0, 2, 1, 3]); getitem_368 = None + permute_1207 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]); getitem_367 = None + permute_1208 = torch.ops.aten.permute.default(getitem_366, [0, 2, 1, 3]); getitem_366 = None + view_1728 = torch.ops.aten.view.default(permute_1206, [2, 8192, 8, 4, 128]); permute_1206 = None + sum_161 = torch.ops.aten.sum.dim_IntList(view_1728, [3], True); view_1728 = None + squeeze_52 = torch.ops.aten.squeeze.dim(sum_161, 3); sum_161 = None + view_1729 = torch.ops.aten.view.default(permute_1207, [2, 8192, 8, 4, 128]); permute_1207 = None + sum_162 = torch.ops.aten.sum.dim_IntList(view_1729, [3], True); view_1729 = None + squeeze_53 = torch.ops.aten.squeeze.dim(sum_162, 3); sum_162 = None + convert_element_type_2507 = torch.ops.prims.convert_element_type.default(squeeze_53, torch.float32); squeeze_53 = None + convert_element_type_2508 = torch.ops.prims.convert_element_type.default(permute_1208, torch.float32); permute_1208 = None + view_1730 = torch.ops.aten.view.default(convert_element_type_2507, [2, 8192, 8, 64, 2]); convert_element_type_2507 = None + view_as_complex_116 = torch.ops.aten.view_as_complex.default(view_1730); view_1730 = None + mul_796 = torch.ops.aten.mul.Tensor(view_as_complex_116, _conj); view_as_complex_116 = None + view_1731 = torch.ops.aten.view.default(convert_element_type_2508, [2, 8192, 32, 64, 2]); convert_element_type_2508 = None + view_as_complex_117 = torch.ops.aten.view_as_complex.default(view_1731); view_1731 = None + mul_797 = torch.ops.aten.mul.Tensor(view_as_complex_117, _conj); view_as_complex_117 = None + view_as_real_116 = torch.ops.aten.view_as_real.default(mul_796); mul_796 = None + view_1732 = torch.ops.aten.view.default(view_as_real_116, [2, 8192, 8, 128]); view_as_real_116 = None + convert_element_type_2509 = torch.ops.prims.convert_element_type.default(view_1732, torch.bfloat16); view_1732 = None + view_as_real_117 = torch.ops.aten.view_as_real.default(mul_797); mul_797 = None + view_1733 = torch.ops.aten.view.default(view_as_real_117, [2, 8192, 32, 128]); view_as_real_117 = None + convert_element_type_2510 = torch.ops.prims.convert_element_type.default(view_1733, torch.bfloat16); view_1733 = None + view_1734 = torch.ops.aten.view.default(squeeze_52, [2, 8192, 1024]); squeeze_52 = None + view_1735 = torch.ops.aten.view.default(convert_element_type_2509, [2, 8192, 1024]); convert_element_type_2509 = None + view_1736 = torch.ops.aten.view.default(convert_element_type_2510, [2, 8192, 4096]); convert_element_type_2510 = None + view_1737 = torch.ops.aten.view.default(view_1734, [16384, 1024]); view_1734 = None + permute_1209 = torch.ops.aten.permute.default(view_1737, [1, 0]) + mm_599 = torch.ops.aten.mm.default(permute_1209, view_173); permute_1209 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 256, '0'); convert_element_type_175 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_49, [1, 0]); wait_tensor_49 = None + permute_1211 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_600 = torch.ops.aten.mm.default(view_1737, permute_1211); view_1737 = permute_1211 = None + view_1738 = torch.ops.aten.view.default(mm_600, [2, 8192, 4096]); mm_600 = None + convert_element_type_2515 = torch.ops.prims.convert_element_type.default(mm_599, torch.float32); mm_599 = None + reduce_scatter_tensor_241 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2515, 'avg', 256, '0'); convert_element_type_2515 = None + wait_tensor_532 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_241); reduce_scatter_tensor_241 = None + view_1739 = torch.ops.aten.view.default(view_1735, [16384, 1024]); view_1735 = None + permute_1213 = torch.ops.aten.permute.default(view_1739, [1, 0]) + mm_601 = torch.ops.aten.mm.default(permute_1213, view_173); permute_1213 = None + permute_1215 = torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_602 = torch.ops.aten.mm.default(view_1739, permute_1215); view_1739 = permute_1215 = None + view_1740 = torch.ops.aten.view.default(mm_602, [2, 8192, 4096]); mm_602 = None + add_315 = torch.ops.aten.add.Tensor(view_1738, view_1740); view_1738 = view_1740 = None + convert_element_type_2520 = torch.ops.prims.convert_element_type.default(mm_601, torch.float32); mm_601 = None + reduce_scatter_tensor_242 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2520, 'avg', 256, '0'); convert_element_type_2520 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_242); reduce_scatter_tensor_242 = None + view_1741 = torch.ops.aten.view.default(view_1736, [16384, 4096]); view_1736 = None + permute_1217 = torch.ops.aten.permute.default(view_1741, [1, 0]) + mm_603 = torch.ops.aten.mm.default(permute_1217, view_173); permute_1217 = view_173 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 256, '0'); convert_element_type_169 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_47, [1, 0]); wait_tensor_47 = None + permute_1219 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_604 = torch.ops.aten.mm.default(view_1741, permute_1219); view_1741 = permute_1219 = None + view_1742 = torch.ops.aten.view.default(mm_604, [2, 8192, 4096]); mm_604 = None + add_316 = torch.ops.aten.add.Tensor(add_315, view_1742); add_315 = view_1742 = None + convert_element_type_2525 = torch.ops.prims.convert_element_type.default(mm_603, torch.float32); mm_603 = None + reduce_scatter_tensor_243 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2525, 'avg', 256, '0'); convert_element_type_2525 = None + wait_tensor_534 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_243); reduce_scatter_tensor_243 = None + convert_element_type_2526 = torch.ops.prims.convert_element_type.default(add_316, torch.float32); add_316 = None + convert_element_type_2528 = torch.ops.prims.convert_element_type.default(wait_tensor_46, torch.float32); wait_tensor_46 = None + mul_798 = torch.ops.aten.mul.Tensor(convert_element_type_2526, convert_element_type_2528); convert_element_type_2528 = None + mul_800 = torch.ops.aten.mul.Tensor(mul_40, mul_798) + sum_163 = torch.ops.aten.sum.dim_IntList(mul_800, [2], True); mul_800 = None + div_54 = torch.ops.aten.div.Tensor(mul_40, 4096) + mul_801 = torch.ops.aten.mul.Tensor(div_54, sum_163); div_54 = sum_163 = None + sub_81 = torch.ops.aten.sub.Tensor(mul_798, mul_801); mul_798 = mul_801 = None + mul_802 = torch.ops.aten.mul.Tensor(sub_81, rsqrt_10); sub_81 = rsqrt_10 = None + mul_803 = torch.ops.aten.mul.Tensor(convert_element_type_2526, mul_40); convert_element_type_2526 = mul_40 = None + sum_164 = torch.ops.aten.sum.dim_IntList(mul_803, [0, 1]); mul_803 = None + convert_element_type_2529 = torch.ops.prims.convert_element_type.default(mul_802, torch.bfloat16); mul_802 = None + add_317 = torch.ops.aten.add.Tensor(add_314, convert_element_type_2529); add_314 = convert_element_type_2529 = None + convert_element_type_default_11 = torch.ops.prims.convert_element_type.default(sum_164, torch.float32); sum_164 = None + reduce_scatter_tensor_244 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_11, 'avg', 256, '0'); convert_element_type_default_11 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_244); reduce_scatter_tensor_244 = None + view_1743 = torch.ops.aten.view.default(add_317, [16384, 4096]) + permute_1221 = torch.ops.aten.permute.default(view_1743, [1, 0]) + permute_50 = torch.ops.aten.permute.default(getitem_36, [0, 2, 1, 3]) + view_157 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 256, '0'); convert_element_type_149 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_41, [1, 0]); wait_tensor_41 = None + view_159 = torch.ops.aten.view.default(view_157, [16384, 4096]); view_157 = None + mm_31 = torch.ops.aten.mm.default(view_159, permute_51) + view_160 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + add_17 = torch.ops.aten.add.Tensor(add_15, view_160); view_160 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 256, '0'); convert_element_type_152 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32); add_17 = None + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_42) + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + view_163 = torch.ops.aten.view.default(convert_element_type_154, [16384, 4096]); convert_element_type_154 = None + view_164 = torch.ops.aten.view.default(mm_32, [2, 8192, 14336]); mm_32 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_164, torch.float32); view_164 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 256, '0'); convert_element_type_160 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_33 = torch.ops.aten.mm.default(view_163, permute_53) + view_167 = torch.ops.aten.view.default(mm_33, [2, 8192, 14336]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_167) + view_169 = torch.ops.aten.view.default(mul_39, [16384, 14336]); mul_39 = None + mm_605 = torch.ops.aten.mm.default(permute_1221, view_169); permute_1221 = view_169 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 256, '0'); convert_element_type_163 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + permute_1223 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_606 = torch.ops.aten.mm.default(view_1743, permute_1223); view_1743 = permute_1223 = None + view_1744 = torch.ops.aten.view.default(mm_606, [2, 8192, 14336]); mm_606 = None + convert_element_type_2536 = torch.ops.prims.convert_element_type.default(mm_605, torch.float32); mm_605 = None + reduce_scatter_tensor_245 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2536, 'avg', 256, '0'); convert_element_type_2536 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_245); reduce_scatter_tensor_245 = None + mul_804 = torch.ops.aten.mul.Tensor(view_1744, convert_element_type_159); convert_element_type_159 = None + mul_805 = torch.ops.aten.mul.Tensor(view_1744, view_167); view_1744 = view_167 = None + view_1745 = torch.ops.aten.view.default(mul_804, [16384, 14336]); mul_804 = None + permute_1225 = torch.ops.aten.permute.default(view_1745, [1, 0]) + mm_607 = torch.ops.aten.mm.default(permute_1225, view_163); permute_1225 = None + permute_1227 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_608 = torch.ops.aten.mm.default(view_1745, permute_1227); view_1745 = permute_1227 = None + view_1746 = torch.ops.aten.view.default(mm_608, [2, 8192, 4096]); mm_608 = None + convert_element_type_2541 = torch.ops.prims.convert_element_type.default(mm_607, torch.float32); mm_607 = None + reduce_scatter_tensor_246 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2541, 'avg', 256, '0'); convert_element_type_2541 = None + wait_tensor_537 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_246); reduce_scatter_tensor_246 = None + convert_element_type_2542 = torch.ops.prims.convert_element_type.default(mul_805, torch.float32); mul_805 = None + neg_27 = torch.ops.aten.neg.default(convert_element_type_158) + exp_27 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_318 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + reciprocal_27 = torch.ops.aten.reciprocal.default(add_318); add_318 = None + mul_806 = torch.ops.aten.mul.Tensor(reciprocal_27, 1); reciprocal_27 = None + mul_807 = torch.ops.aten.mul.Tensor(convert_element_type_2542, mul_806); convert_element_type_2542 = None + sub_82 = torch.ops.aten.sub.Tensor(1, mul_806); mul_806 = None + mul_808 = torch.ops.aten.mul.Tensor(convert_element_type_158, sub_82); convert_element_type_158 = sub_82 = None + add_319 = torch.ops.aten.add.Tensor(mul_808, 1); mul_808 = None + mul_809 = torch.ops.aten.mul.Tensor(mul_807, add_319); mul_807 = add_319 = None + convert_element_type_2544 = torch.ops.prims.convert_element_type.default(mul_809, torch.bfloat16); mul_809 = None + view_1747 = torch.ops.aten.view.default(convert_element_type_2544, [16384, 14336]); convert_element_type_2544 = None + permute_1229 = torch.ops.aten.permute.default(view_1747, [1, 0]) + mm_609 = torch.ops.aten.mm.default(permute_1229, view_163); permute_1229 = view_163 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 256, '0'); convert_element_type_155 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + permute_1231 = torch.ops.aten.permute.default(permute_52, [1, 0]); permute_52 = None + mm_610 = torch.ops.aten.mm.default(view_1747, permute_1231); view_1747 = permute_1231 = None + view_1748 = torch.ops.aten.view.default(mm_610, [2, 8192, 4096]); mm_610 = None + add_320 = torch.ops.aten.add.Tensor(view_1746, view_1748); view_1746 = view_1748 = None + convert_element_type_2549 = torch.ops.prims.convert_element_type.default(mm_609, torch.float32); mm_609 = None + reduce_scatter_tensor_247 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2549, 'avg', 256, '0'); convert_element_type_2549 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_247); reduce_scatter_tensor_247 = None + convert_element_type_2550 = torch.ops.prims.convert_element_type.default(add_320, torch.float32); add_320 = None + convert_element_type_2552 = torch.ops.prims.convert_element_type.default(wait_tensor_42, torch.float32); wait_tensor_42 = None + mul_810 = torch.ops.aten.mul.Tensor(convert_element_type_2550, convert_element_type_2552); convert_element_type_2552 = None + mul_812 = torch.ops.aten.mul.Tensor(mul_36, mul_810) + sum_165 = torch.ops.aten.sum.dim_IntList(mul_812, [2], True); mul_812 = None + div_55 = torch.ops.aten.div.Tensor(mul_36, 4096) + mul_813 = torch.ops.aten.mul.Tensor(div_55, sum_165); div_55 = sum_165 = None + sub_83 = torch.ops.aten.sub.Tensor(mul_810, mul_813); mul_810 = mul_813 = None + mul_814 = torch.ops.aten.mul.Tensor(sub_83, rsqrt_9); sub_83 = rsqrt_9 = None + mul_815 = torch.ops.aten.mul.Tensor(convert_element_type_2550, mul_36); convert_element_type_2550 = mul_36 = None + sum_166 = torch.ops.aten.sum.dim_IntList(mul_815, [0, 1]); mul_815 = None + convert_element_type_2553 = torch.ops.prims.convert_element_type.default(mul_814, torch.bfloat16); mul_814 = None + add_321 = torch.ops.aten.add.Tensor(add_317, convert_element_type_2553); add_317 = convert_element_type_2553 = None + convert_element_type_default_10 = torch.ops.prims.convert_element_type.default(sum_166, torch.float32); sum_166 = None + reduce_scatter_tensor_248 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_10, 'avg', 256, '0'); convert_element_type_default_10 = None + wait_tensor_539 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_248); reduce_scatter_tensor_248 = None + view_1749 = torch.ops.aten.view.default(add_321, [16384, 4096]) + permute_1233 = torch.ops.aten.permute.default(view_1749, [1, 0]) + mm_611 = torch.ops.aten.mm.default(permute_1233, view_159); permute_1233 = view_159 = None + permute_1235 = torch.ops.aten.permute.default(permute_51, [1, 0]); permute_51 = None + mm_612 = torch.ops.aten.mm.default(view_1749, permute_1235); view_1749 = permute_1235 = None + view_1750 = torch.ops.aten.view.default(mm_612, [2, 8192, 4096]); mm_612 = None + convert_element_type_2560 = torch.ops.prims.convert_element_type.default(mm_611, torch.float32); mm_611 = None + reduce_scatter_tensor_249 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2560, 'avg', 256, '0'); convert_element_type_2560 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_249); reduce_scatter_tensor_249 = None + view_1751 = torch.ops.aten.view.default(view_1750, [2, 8192, 32, 128]); view_1750 = None + permute_1237 = torch.ops.aten.permute.default(view_1751, [0, 2, 1, 3]); view_1751 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 256, '0'); convert_element_type_133 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32); add_15 = None + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_37) + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + view_139 = torch.ops.aten.view.default(convert_element_type_135, [16384, 4096]); convert_element_type_135 = None + view_140 = torch.ops.aten.view.default(mm_28, [2, 8192, 4096]); mm_28 = None + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 256, '0'); convert_element_type_139 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + mm_29 = torch.ops.aten.mm.default(view_139, permute_45) + view_143 = torch.ops.aten.view.default(mm_29, [2, 8192, 1024]); mm_29 = None + view_146 = torch.ops.aten.view.default(mm_30, [2, 8192, 1024]); mm_30 = None + view_147 = torch.ops.aten.view.default(view_140, [2, 8192, -1, 128]); view_140 = None + view_148 = torch.ops.aten.view.default(view_143, [2, 8192, -1, 128]); view_143 = None + view_149 = torch.ops.aten.view.default(view_146, [2, 8192, -1, 128]); view_146 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_147, torch.float32); view_147 = None + view_150 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 32, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_150); view_150 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None + view_151 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 8, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_151); view_151 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_16); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_153 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 32, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_16); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_154 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 8, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_153, torch.bfloat16); view_153 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 8, 4, 128]); unsqueeze_8 = None + clone_8 = torch.ops.aten.clone.default(expand_8, memory_format = torch.contiguous_format); expand_8 = None + view_155 = torch.ops.aten.view.default(clone_8, [2, 8192, 32, 128]); clone_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_149, 3); view_149 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 8, 4, 128]); unsqueeze_9 = None + clone_9 = torch.ops.aten.clone.default(expand_9, memory_format = torch.contiguous_format); expand_9 = None + view_156 = torch.ops.aten.view.default(clone_9, [2, 8192, 32, 128]); clone_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_155, [0, 2, 1, 3]); view_155 = None + permute_49 = torch.ops.aten.permute.default(view_156, [0, 2, 1, 3]); view_156 = None + _scaled_dot_product_cudnn_attention_backward_27 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1237, permute_47, permute_48, permute_49, getitem_36, getitem_37, getitem_42, getitem_43, None, None, None, 8192, 8192, 0.0, True); permute_1237 = permute_47 = permute_48 = permute_49 = getitem_36 = getitem_37 = getitem_42 = getitem_43 = None + getitem_369 = _scaled_dot_product_cudnn_attention_backward_27[0] + getitem_370 = _scaled_dot_product_cudnn_attention_backward_27[1] + getitem_371 = _scaled_dot_product_cudnn_attention_backward_27[2]; _scaled_dot_product_cudnn_attention_backward_27 = None + permute_1238 = torch.ops.aten.permute.default(getitem_371, [0, 2, 1, 3]); getitem_371 = None + permute_1239 = torch.ops.aten.permute.default(getitem_370, [0, 2, 1, 3]); getitem_370 = None + permute_1240 = torch.ops.aten.permute.default(getitem_369, [0, 2, 1, 3]); getitem_369 = None + view_1752 = torch.ops.aten.view.default(permute_1238, [2, 8192, 8, 4, 128]); permute_1238 = None + sum_167 = torch.ops.aten.sum.dim_IntList(view_1752, [3], True); view_1752 = None + squeeze_54 = torch.ops.aten.squeeze.dim(sum_167, 3); sum_167 = None + view_1753 = torch.ops.aten.view.default(permute_1239, [2, 8192, 8, 4, 128]); permute_1239 = None + sum_168 = torch.ops.aten.sum.dim_IntList(view_1753, [3], True); view_1753 = None + squeeze_55 = torch.ops.aten.squeeze.dim(sum_168, 3); sum_168 = None + convert_element_type_2561 = torch.ops.prims.convert_element_type.default(squeeze_55, torch.float32); squeeze_55 = None + convert_element_type_2562 = torch.ops.prims.convert_element_type.default(permute_1240, torch.float32); permute_1240 = None + view_1754 = torch.ops.aten.view.default(convert_element_type_2561, [2, 8192, 8, 64, 2]); convert_element_type_2561 = None + view_as_complex_118 = torch.ops.aten.view_as_complex.default(view_1754); view_1754 = None + mul_816 = torch.ops.aten.mul.Tensor(view_as_complex_118, _conj); view_as_complex_118 = None + view_1755 = torch.ops.aten.view.default(convert_element_type_2562, [2, 8192, 32, 64, 2]); convert_element_type_2562 = None + view_as_complex_119 = torch.ops.aten.view_as_complex.default(view_1755); view_1755 = None + mul_817 = torch.ops.aten.mul.Tensor(view_as_complex_119, _conj); view_as_complex_119 = None + view_as_real_118 = torch.ops.aten.view_as_real.default(mul_816); mul_816 = None + view_1756 = torch.ops.aten.view.default(view_as_real_118, [2, 8192, 8, 128]); view_as_real_118 = None + convert_element_type_2563 = torch.ops.prims.convert_element_type.default(view_1756, torch.bfloat16); view_1756 = None + view_as_real_119 = torch.ops.aten.view_as_real.default(mul_817); mul_817 = None + view_1757 = torch.ops.aten.view.default(view_as_real_119, [2, 8192, 32, 128]); view_as_real_119 = None + convert_element_type_2564 = torch.ops.prims.convert_element_type.default(view_1757, torch.bfloat16); view_1757 = None + view_1758 = torch.ops.aten.view.default(squeeze_54, [2, 8192, 1024]); squeeze_54 = None + view_1759 = torch.ops.aten.view.default(convert_element_type_2563, [2, 8192, 1024]); convert_element_type_2563 = None + view_1760 = torch.ops.aten.view.default(convert_element_type_2564, [2, 8192, 4096]); convert_element_type_2564 = None + view_1761 = torch.ops.aten.view.default(view_1758, [16384, 1024]); view_1758 = None + permute_1241 = torch.ops.aten.permute.default(view_1761, [1, 0]) + mm_613 = torch.ops.aten.mm.default(permute_1241, view_139); permute_1241 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 256, '0'); convert_element_type_142 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_40, [1, 0]); wait_tensor_40 = None + permute_1243 = torch.ops.aten.permute.default(permute_46, [1, 0]); permute_46 = None + mm_614 = torch.ops.aten.mm.default(view_1761, permute_1243); view_1761 = permute_1243 = None + view_1762 = torch.ops.aten.view.default(mm_614, [2, 8192, 4096]); mm_614 = None + convert_element_type_2569 = torch.ops.prims.convert_element_type.default(mm_613, torch.float32); mm_613 = None + reduce_scatter_tensor_250 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2569, 'avg', 256, '0'); convert_element_type_2569 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_250); reduce_scatter_tensor_250 = None + view_1763 = torch.ops.aten.view.default(view_1759, [16384, 1024]); view_1759 = None + permute_1245 = torch.ops.aten.permute.default(view_1763, [1, 0]) + mm_615 = torch.ops.aten.mm.default(permute_1245, view_139); permute_1245 = None + permute_1247 = torch.ops.aten.permute.default(permute_45, [1, 0]); permute_45 = None + mm_616 = torch.ops.aten.mm.default(view_1763, permute_1247); view_1763 = permute_1247 = None + view_1764 = torch.ops.aten.view.default(mm_616, [2, 8192, 4096]); mm_616 = None + add_322 = torch.ops.aten.add.Tensor(view_1762, view_1764); view_1762 = view_1764 = None + convert_element_type_2574 = torch.ops.prims.convert_element_type.default(mm_615, torch.float32); mm_615 = None + reduce_scatter_tensor_251 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2574, 'avg', 256, '0'); convert_element_type_2574 = None + wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_251); reduce_scatter_tensor_251 = None + view_1765 = torch.ops.aten.view.default(view_1760, [16384, 4096]); view_1760 = None + permute_1249 = torch.ops.aten.permute.default(view_1765, [1, 0]) + mm_617 = torch.ops.aten.mm.default(permute_1249, view_139); permute_1249 = view_139 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 256, '0'); convert_element_type_136 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + permute_1251 = torch.ops.aten.permute.default(permute_44, [1, 0]); permute_44 = None + mm_618 = torch.ops.aten.mm.default(view_1765, permute_1251); view_1765 = permute_1251 = None + view_1766 = torch.ops.aten.view.default(mm_618, [2, 8192, 4096]); mm_618 = None + add_323 = torch.ops.aten.add.Tensor(add_322, view_1766); add_322 = view_1766 = None + convert_element_type_2579 = torch.ops.prims.convert_element_type.default(mm_617, torch.float32); mm_617 = None + reduce_scatter_tensor_252 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2579, 'avg', 256, '0'); convert_element_type_2579 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_252); reduce_scatter_tensor_252 = None + convert_element_type_2580 = torch.ops.prims.convert_element_type.default(add_323, torch.float32); add_323 = None + convert_element_type_2582 = torch.ops.prims.convert_element_type.default(wait_tensor_37, torch.float32); wait_tensor_37 = None + mul_818 = torch.ops.aten.mul.Tensor(convert_element_type_2580, convert_element_type_2582); convert_element_type_2582 = None + mul_820 = torch.ops.aten.mul.Tensor(mul_32, mul_818) + sum_169 = torch.ops.aten.sum.dim_IntList(mul_820, [2], True); mul_820 = None + div_56 = torch.ops.aten.div.Tensor(mul_32, 4096) + mul_821 = torch.ops.aten.mul.Tensor(div_56, sum_169); div_56 = sum_169 = None + sub_84 = torch.ops.aten.sub.Tensor(mul_818, mul_821); mul_818 = mul_821 = None + mul_822 = torch.ops.aten.mul.Tensor(sub_84, rsqrt_8); sub_84 = rsqrt_8 = None + mul_823 = torch.ops.aten.mul.Tensor(convert_element_type_2580, mul_32); convert_element_type_2580 = mul_32 = None + sum_170 = torch.ops.aten.sum.dim_IntList(mul_823, [0, 1]); mul_823 = None + convert_element_type_2583 = torch.ops.prims.convert_element_type.default(mul_822, torch.bfloat16); mul_822 = None + add_324 = torch.ops.aten.add.Tensor(add_321, convert_element_type_2583); add_321 = convert_element_type_2583 = None + convert_element_type_default_9 = torch.ops.prims.convert_element_type.default(sum_170, torch.float32); sum_170 = None + reduce_scatter_tensor_253 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_9, 'avg', 256, '0'); convert_element_type_default_9 = None + wait_tensor_544 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_253); reduce_scatter_tensor_253 = None + view_1767 = torch.ops.aten.view.default(add_324, [16384, 4096]) + permute_1253 = torch.ops.aten.permute.default(view_1767, [1, 0]) + permute_39 = torch.ops.aten.permute.default(getitem_27, [0, 2, 1, 3]) + view_123 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 256, '0'); convert_element_type_116 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + view_125 = torch.ops.aten.view.default(view_123, [16384, 4096]); view_123 = None + mm_24 = torch.ops.aten.mm.default(view_125, permute_40) + view_126 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + add_13 = torch.ops.aten.add.Tensor(add_11, view_126); view_126 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 256, '0'); convert_element_type_119 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32); add_13 = None + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_33) + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + view_129 = torch.ops.aten.view.default(convert_element_type_121, [16384, 4096]); convert_element_type_121 = None + view_130 = torch.ops.aten.view.default(mm_25, [2, 8192, 14336]); mm_25 = None + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 256, '0'); convert_element_type_127 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_35, [1, 0]); wait_tensor_35 = None + mm_26 = torch.ops.aten.mm.default(view_129, permute_42) + view_133 = torch.ops.aten.view.default(mm_26, [2, 8192, 14336]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_133) + view_135 = torch.ops.aten.view.default(mul_31, [16384, 14336]); mul_31 = None + mm_619 = torch.ops.aten.mm.default(permute_1253, view_135); permute_1253 = view_135 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 256, '0'); convert_element_type_130 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + permute_1255 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_620 = torch.ops.aten.mm.default(view_1767, permute_1255); view_1767 = permute_1255 = None + view_1768 = torch.ops.aten.view.default(mm_620, [2, 8192, 14336]); mm_620 = None + convert_element_type_2590 = torch.ops.prims.convert_element_type.default(mm_619, torch.float32); mm_619 = None + reduce_scatter_tensor_254 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2590, 'avg', 256, '0'); convert_element_type_2590 = None + wait_tensor_545 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_254); reduce_scatter_tensor_254 = None + mul_824 = torch.ops.aten.mul.Tensor(view_1768, convert_element_type_126); convert_element_type_126 = None + mul_825 = torch.ops.aten.mul.Tensor(view_1768, view_133); view_1768 = view_133 = None + view_1769 = torch.ops.aten.view.default(mul_824, [16384, 14336]); mul_824 = None + permute_1257 = torch.ops.aten.permute.default(view_1769, [1, 0]) + mm_621 = torch.ops.aten.mm.default(permute_1257, view_129); permute_1257 = None + permute_1259 = torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_622 = torch.ops.aten.mm.default(view_1769, permute_1259); view_1769 = permute_1259 = None + view_1770 = torch.ops.aten.view.default(mm_622, [2, 8192, 4096]); mm_622 = None + convert_element_type_2595 = torch.ops.prims.convert_element_type.default(mm_621, torch.float32); mm_621 = None + reduce_scatter_tensor_255 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2595, 'avg', 256, '0'); convert_element_type_2595 = None + wait_tensor_546 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_255); reduce_scatter_tensor_255 = None + convert_element_type_2596 = torch.ops.prims.convert_element_type.default(mul_825, torch.float32); mul_825 = None + neg_28 = torch.ops.aten.neg.default(convert_element_type_125) + exp_28 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_325 = torch.ops.aten.add.Tensor(exp_28, 1); exp_28 = None + reciprocal_28 = torch.ops.aten.reciprocal.default(add_325); add_325 = None + mul_826 = torch.ops.aten.mul.Tensor(reciprocal_28, 1); reciprocal_28 = None + mul_827 = torch.ops.aten.mul.Tensor(convert_element_type_2596, mul_826); convert_element_type_2596 = None + sub_85 = torch.ops.aten.sub.Tensor(1, mul_826); mul_826 = None + mul_828 = torch.ops.aten.mul.Tensor(convert_element_type_125, sub_85); convert_element_type_125 = sub_85 = None + add_326 = torch.ops.aten.add.Tensor(mul_828, 1); mul_828 = None + mul_829 = torch.ops.aten.mul.Tensor(mul_827, add_326); mul_827 = add_326 = None + convert_element_type_2598 = torch.ops.prims.convert_element_type.default(mul_829, torch.bfloat16); mul_829 = None + view_1771 = torch.ops.aten.view.default(convert_element_type_2598, [16384, 14336]); convert_element_type_2598 = None + permute_1261 = torch.ops.aten.permute.default(view_1771, [1, 0]) + mm_623 = torch.ops.aten.mm.default(permute_1261, view_129); permute_1261 = view_129 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 256, '0'); convert_element_type_122 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + permute_1263 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_624 = torch.ops.aten.mm.default(view_1771, permute_1263); view_1771 = permute_1263 = None + view_1772 = torch.ops.aten.view.default(mm_624, [2, 8192, 4096]); mm_624 = None + add_327 = torch.ops.aten.add.Tensor(view_1770, view_1772); view_1770 = view_1772 = None + convert_element_type_2603 = torch.ops.prims.convert_element_type.default(mm_623, torch.float32); mm_623 = None + reduce_scatter_tensor_256 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2603, 'avg', 256, '0'); convert_element_type_2603 = None + wait_tensor_547 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_256); reduce_scatter_tensor_256 = None + convert_element_type_2604 = torch.ops.prims.convert_element_type.default(add_327, torch.float32); add_327 = None + convert_element_type_2606 = torch.ops.prims.convert_element_type.default(wait_tensor_33, torch.float32); wait_tensor_33 = None + mul_830 = torch.ops.aten.mul.Tensor(convert_element_type_2604, convert_element_type_2606); convert_element_type_2606 = None + mul_832 = torch.ops.aten.mul.Tensor(mul_28, mul_830) + sum_171 = torch.ops.aten.sum.dim_IntList(mul_832, [2], True); mul_832 = None + div_57 = torch.ops.aten.div.Tensor(mul_28, 4096) + mul_833 = torch.ops.aten.mul.Tensor(div_57, sum_171); div_57 = sum_171 = None + sub_86 = torch.ops.aten.sub.Tensor(mul_830, mul_833); mul_830 = mul_833 = None + mul_834 = torch.ops.aten.mul.Tensor(sub_86, rsqrt_7); sub_86 = rsqrt_7 = None + mul_835 = torch.ops.aten.mul.Tensor(convert_element_type_2604, mul_28); convert_element_type_2604 = mul_28 = None + sum_172 = torch.ops.aten.sum.dim_IntList(mul_835, [0, 1]); mul_835 = None + convert_element_type_2607 = torch.ops.prims.convert_element_type.default(mul_834, torch.bfloat16); mul_834 = None + add_328 = torch.ops.aten.add.Tensor(add_324, convert_element_type_2607); add_324 = convert_element_type_2607 = None + convert_element_type_default_8 = torch.ops.prims.convert_element_type.default(sum_172, torch.float32); sum_172 = None + reduce_scatter_tensor_257 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_8, 'avg', 256, '0'); convert_element_type_default_8 = None + wait_tensor_548 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_257); reduce_scatter_tensor_257 = None + view_1773 = torch.ops.aten.view.default(add_328, [16384, 4096]) + permute_1265 = torch.ops.aten.permute.default(view_1773, [1, 0]) + mm_625 = torch.ops.aten.mm.default(permute_1265, view_125); permute_1265 = view_125 = None + permute_1267 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_626 = torch.ops.aten.mm.default(view_1773, permute_1267); view_1773 = permute_1267 = None + view_1774 = torch.ops.aten.view.default(mm_626, [2, 8192, 4096]); mm_626 = None + convert_element_type_2614 = torch.ops.prims.convert_element_type.default(mm_625, torch.float32); mm_625 = None + reduce_scatter_tensor_258 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2614, 'avg', 256, '0'); convert_element_type_2614 = None + wait_tensor_549 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_258); reduce_scatter_tensor_258 = None + view_1775 = torch.ops.aten.view.default(view_1774, [2, 8192, 32, 128]); view_1774 = None + permute_1269 = torch.ops.aten.permute.default(view_1775, [0, 2, 1, 3]); view_1775 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 256, '0'); convert_element_type_100 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32); add_11 = None + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_28) + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + view_105 = torch.ops.aten.view.default(convert_element_type_102, [16384, 4096]); convert_element_type_102 = None + view_106 = torch.ops.aten.view.default(mm_21, [2, 8192, 4096]); mm_21 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 256, '0'); convert_element_type_106 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_22 = torch.ops.aten.mm.default(view_105, permute_34) + view_109 = torch.ops.aten.view.default(mm_22, [2, 8192, 1024]); mm_22 = None + view_112 = torch.ops.aten.view.default(mm_23, [2, 8192, 1024]); mm_23 = None + view_113 = torch.ops.aten.view.default(view_106, [2, 8192, -1, 128]); view_106 = None + view_114 = torch.ops.aten.view.default(view_109, [2, 8192, -1, 128]); view_109 = None + view_115 = torch.ops.aten.view.default(view_112, [2, 8192, -1, 128]); view_112 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_113, torch.float32); view_113 = None + view_116 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 32, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_116); view_116 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_114, torch.float32); view_114 = None + view_117 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 8, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_117); view_117 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_16); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_119 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 32, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_16); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_120 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 8, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_119, torch.bfloat16); view_119 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_120, torch.bfloat16); view_120 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 8, 4, 128]); unsqueeze_6 = None + clone_6 = torch.ops.aten.clone.default(expand_6, memory_format = torch.contiguous_format); expand_6 = None + view_121 = torch.ops.aten.view.default(clone_6, [2, 8192, 32, 128]); clone_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_115, 3); view_115 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 8, 4, 128]); unsqueeze_7 = None + clone_7 = torch.ops.aten.clone.default(expand_7, memory_format = torch.contiguous_format); expand_7 = None + view_122 = torch.ops.aten.view.default(clone_7, [2, 8192, 32, 128]); clone_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_121, [0, 2, 1, 3]); view_121 = None + permute_38 = torch.ops.aten.permute.default(view_122, [0, 2, 1, 3]); view_122 = None + _scaled_dot_product_cudnn_attention_backward_28 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1269, permute_36, permute_37, permute_38, getitem_27, getitem_28, getitem_33, getitem_34, None, None, None, 8192, 8192, 0.0, True); permute_1269 = permute_36 = permute_37 = permute_38 = getitem_27 = getitem_28 = getitem_33 = getitem_34 = None + getitem_372 = _scaled_dot_product_cudnn_attention_backward_28[0] + getitem_373 = _scaled_dot_product_cudnn_attention_backward_28[1] + getitem_374 = _scaled_dot_product_cudnn_attention_backward_28[2]; _scaled_dot_product_cudnn_attention_backward_28 = None + permute_1270 = torch.ops.aten.permute.default(getitem_374, [0, 2, 1, 3]); getitem_374 = None + permute_1271 = torch.ops.aten.permute.default(getitem_373, [0, 2, 1, 3]); getitem_373 = None + permute_1272 = torch.ops.aten.permute.default(getitem_372, [0, 2, 1, 3]); getitem_372 = None + view_1776 = torch.ops.aten.view.default(permute_1270, [2, 8192, 8, 4, 128]); permute_1270 = None + sum_173 = torch.ops.aten.sum.dim_IntList(view_1776, [3], True); view_1776 = None + squeeze_56 = torch.ops.aten.squeeze.dim(sum_173, 3); sum_173 = None + view_1777 = torch.ops.aten.view.default(permute_1271, [2, 8192, 8, 4, 128]); permute_1271 = None + sum_174 = torch.ops.aten.sum.dim_IntList(view_1777, [3], True); view_1777 = None + squeeze_57 = torch.ops.aten.squeeze.dim(sum_174, 3); sum_174 = None + convert_element_type_2615 = torch.ops.prims.convert_element_type.default(squeeze_57, torch.float32); squeeze_57 = None + convert_element_type_2616 = torch.ops.prims.convert_element_type.default(permute_1272, torch.float32); permute_1272 = None + view_1778 = torch.ops.aten.view.default(convert_element_type_2615, [2, 8192, 8, 64, 2]); convert_element_type_2615 = None + view_as_complex_120 = torch.ops.aten.view_as_complex.default(view_1778); view_1778 = None + mul_836 = torch.ops.aten.mul.Tensor(view_as_complex_120, _conj); view_as_complex_120 = None + view_1779 = torch.ops.aten.view.default(convert_element_type_2616, [2, 8192, 32, 64, 2]); convert_element_type_2616 = None + view_as_complex_121 = torch.ops.aten.view_as_complex.default(view_1779); view_1779 = None + mul_837 = torch.ops.aten.mul.Tensor(view_as_complex_121, _conj); view_as_complex_121 = None + view_as_real_120 = torch.ops.aten.view_as_real.default(mul_836); mul_836 = None + view_1780 = torch.ops.aten.view.default(view_as_real_120, [2, 8192, 8, 128]); view_as_real_120 = None + convert_element_type_2617 = torch.ops.prims.convert_element_type.default(view_1780, torch.bfloat16); view_1780 = None + view_as_real_121 = torch.ops.aten.view_as_real.default(mul_837); mul_837 = None + view_1781 = torch.ops.aten.view.default(view_as_real_121, [2, 8192, 32, 128]); view_as_real_121 = None + convert_element_type_2618 = torch.ops.prims.convert_element_type.default(view_1781, torch.bfloat16); view_1781 = None + view_1782 = torch.ops.aten.view.default(squeeze_56, [2, 8192, 1024]); squeeze_56 = None + view_1783 = torch.ops.aten.view.default(convert_element_type_2617, [2, 8192, 1024]); convert_element_type_2617 = None + view_1784 = torch.ops.aten.view.default(convert_element_type_2618, [2, 8192, 4096]); convert_element_type_2618 = None + view_1785 = torch.ops.aten.view.default(view_1782, [16384, 1024]); view_1782 = None + permute_1273 = torch.ops.aten.permute.default(view_1785, [1, 0]) + mm_627 = torch.ops.aten.mm.default(permute_1273, view_105); permute_1273 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 256, '0'); convert_element_type_109 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + permute_1275 = torch.ops.aten.permute.default(permute_35, [1, 0]); permute_35 = None + mm_628 = torch.ops.aten.mm.default(view_1785, permute_1275); view_1785 = permute_1275 = None + view_1786 = torch.ops.aten.view.default(mm_628, [2, 8192, 4096]); mm_628 = None + convert_element_type_2623 = torch.ops.prims.convert_element_type.default(mm_627, torch.float32); mm_627 = None + reduce_scatter_tensor_259 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2623, 'avg', 256, '0'); convert_element_type_2623 = None + wait_tensor_550 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_259); reduce_scatter_tensor_259 = None + view_1787 = torch.ops.aten.view.default(view_1783, [16384, 1024]); view_1783 = None + permute_1277 = torch.ops.aten.permute.default(view_1787, [1, 0]) + mm_629 = torch.ops.aten.mm.default(permute_1277, view_105); permute_1277 = None + permute_1279 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_630 = torch.ops.aten.mm.default(view_1787, permute_1279); view_1787 = permute_1279 = None + view_1788 = torch.ops.aten.view.default(mm_630, [2, 8192, 4096]); mm_630 = None + add_329 = torch.ops.aten.add.Tensor(view_1786, view_1788); view_1786 = view_1788 = None + convert_element_type_2628 = torch.ops.prims.convert_element_type.default(mm_629, torch.float32); mm_629 = None + reduce_scatter_tensor_260 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2628, 'avg', 256, '0'); convert_element_type_2628 = None + wait_tensor_551 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_260); reduce_scatter_tensor_260 = None + view_1789 = torch.ops.aten.view.default(view_1784, [16384, 4096]); view_1784 = None + permute_1281 = torch.ops.aten.permute.default(view_1789, [1, 0]) + mm_631 = torch.ops.aten.mm.default(permute_1281, view_105); permute_1281 = view_105 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 256, '0'); convert_element_type_103 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + permute_1283 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_632 = torch.ops.aten.mm.default(view_1789, permute_1283); view_1789 = permute_1283 = None + view_1790 = torch.ops.aten.view.default(mm_632, [2, 8192, 4096]); mm_632 = None + add_330 = torch.ops.aten.add.Tensor(add_329, view_1790); add_329 = view_1790 = None + convert_element_type_2633 = torch.ops.prims.convert_element_type.default(mm_631, torch.float32); mm_631 = None + reduce_scatter_tensor_261 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2633, 'avg', 256, '0'); convert_element_type_2633 = None + wait_tensor_552 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_261); reduce_scatter_tensor_261 = None + convert_element_type_2634 = torch.ops.prims.convert_element_type.default(add_330, torch.float32); add_330 = None + convert_element_type_2636 = torch.ops.prims.convert_element_type.default(wait_tensor_28, torch.float32); wait_tensor_28 = None + mul_838 = torch.ops.aten.mul.Tensor(convert_element_type_2634, convert_element_type_2636); convert_element_type_2636 = None + mul_840 = torch.ops.aten.mul.Tensor(mul_24, mul_838) + sum_175 = torch.ops.aten.sum.dim_IntList(mul_840, [2], True); mul_840 = None + div_58 = torch.ops.aten.div.Tensor(mul_24, 4096) + mul_841 = torch.ops.aten.mul.Tensor(div_58, sum_175); div_58 = sum_175 = None + sub_87 = torch.ops.aten.sub.Tensor(mul_838, mul_841); mul_838 = mul_841 = None + mul_842 = torch.ops.aten.mul.Tensor(sub_87, rsqrt_6); sub_87 = rsqrt_6 = None + mul_843 = torch.ops.aten.mul.Tensor(convert_element_type_2634, mul_24); convert_element_type_2634 = mul_24 = None + sum_176 = torch.ops.aten.sum.dim_IntList(mul_843, [0, 1]); mul_843 = None + convert_element_type_2637 = torch.ops.prims.convert_element_type.default(mul_842, torch.bfloat16); mul_842 = None + add_331 = torch.ops.aten.add.Tensor(add_328, convert_element_type_2637); add_328 = convert_element_type_2637 = None + convert_element_type_default_7 = torch.ops.prims.convert_element_type.default(sum_176, torch.float32); sum_176 = None + reduce_scatter_tensor_262 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_7, 'avg', 256, '0'); convert_element_type_default_7 = None + wait_tensor_553 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_262); reduce_scatter_tensor_262 = None + view_1791 = torch.ops.aten.view.default(add_331, [16384, 4096]) + permute_1285 = torch.ops.aten.permute.default(view_1791, [1, 0]) + permute_28 = torch.ops.aten.permute.default(getitem_18, [0, 2, 1, 3]) + view_89 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 256, '0'); convert_element_type_83 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_23, [1, 0]); wait_tensor_23 = None + view_91 = torch.ops.aten.view.default(view_89, [16384, 4096]); view_89 = None + mm_17 = torch.ops.aten.mm.default(view_91, permute_29) + view_92 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + add_9 = torch.ops.aten.add.Tensor(add_7, view_92); view_92 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 256, '0'); convert_element_type_86 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32); add_9 = None + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_24) + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + view_95 = torch.ops.aten.view.default(convert_element_type_88, [16384, 4096]); convert_element_type_88 = None + view_96 = torch.ops.aten.view.default(mm_18, [2, 8192, 14336]); mm_18 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_96, torch.float32); view_96 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 256, '0'); convert_element_type_94 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + mm_19 = torch.ops.aten.mm.default(view_95, permute_31) + view_99 = torch.ops.aten.view.default(mm_19, [2, 8192, 14336]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_99) + view_101 = torch.ops.aten.view.default(mul_23, [16384, 14336]); mul_23 = None + mm_633 = torch.ops.aten.mm.default(permute_1285, view_101); permute_1285 = view_101 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 256, '0'); convert_element_type_97 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_27, [1, 0]); wait_tensor_27 = None + permute_1287 = torch.ops.aten.permute.default(permute_32, [1, 0]); permute_32 = None + mm_634 = torch.ops.aten.mm.default(view_1791, permute_1287); view_1791 = permute_1287 = None + view_1792 = torch.ops.aten.view.default(mm_634, [2, 8192, 14336]); mm_634 = None + convert_element_type_2644 = torch.ops.prims.convert_element_type.default(mm_633, torch.float32); mm_633 = None + reduce_scatter_tensor_263 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2644, 'avg', 256, '0'); convert_element_type_2644 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_263); reduce_scatter_tensor_263 = None + mul_844 = torch.ops.aten.mul.Tensor(view_1792, convert_element_type_93); convert_element_type_93 = None + mul_845 = torch.ops.aten.mul.Tensor(view_1792, view_99); view_1792 = view_99 = None + view_1793 = torch.ops.aten.view.default(mul_844, [16384, 14336]); mul_844 = None + permute_1289 = torch.ops.aten.permute.default(view_1793, [1, 0]) + mm_635 = torch.ops.aten.mm.default(permute_1289, view_95); permute_1289 = None + permute_1291 = torch.ops.aten.permute.default(permute_31, [1, 0]); permute_31 = None + mm_636 = torch.ops.aten.mm.default(view_1793, permute_1291); view_1793 = permute_1291 = None + view_1794 = torch.ops.aten.view.default(mm_636, [2, 8192, 4096]); mm_636 = None + convert_element_type_2649 = torch.ops.prims.convert_element_type.default(mm_635, torch.float32); mm_635 = None + reduce_scatter_tensor_264 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2649, 'avg', 256, '0'); convert_element_type_2649 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_264); reduce_scatter_tensor_264 = None + convert_element_type_2650 = torch.ops.prims.convert_element_type.default(mul_845, torch.float32); mul_845 = None + neg_29 = torch.ops.aten.neg.default(convert_element_type_92) + exp_29 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_332 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + reciprocal_29 = torch.ops.aten.reciprocal.default(add_332); add_332 = None + mul_846 = torch.ops.aten.mul.Tensor(reciprocal_29, 1); reciprocal_29 = None + mul_847 = torch.ops.aten.mul.Tensor(convert_element_type_2650, mul_846); convert_element_type_2650 = None + sub_88 = torch.ops.aten.sub.Tensor(1, mul_846); mul_846 = None + mul_848 = torch.ops.aten.mul.Tensor(convert_element_type_92, sub_88); convert_element_type_92 = sub_88 = None + add_333 = torch.ops.aten.add.Tensor(mul_848, 1); mul_848 = None + mul_849 = torch.ops.aten.mul.Tensor(mul_847, add_333); mul_847 = add_333 = None + convert_element_type_2652 = torch.ops.prims.convert_element_type.default(mul_849, torch.bfloat16); mul_849 = None + view_1795 = torch.ops.aten.view.default(convert_element_type_2652, [16384, 14336]); convert_element_type_2652 = None + permute_1293 = torch.ops.aten.permute.default(view_1795, [1, 0]) + mm_637 = torch.ops.aten.mm.default(permute_1293, view_95); permute_1293 = view_95 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 256, '0'); convert_element_type_89 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + permute_1295 = torch.ops.aten.permute.default(permute_30, [1, 0]); permute_30 = None + mm_638 = torch.ops.aten.mm.default(view_1795, permute_1295); view_1795 = permute_1295 = None + view_1796 = torch.ops.aten.view.default(mm_638, [2, 8192, 4096]); mm_638 = None + add_334 = torch.ops.aten.add.Tensor(view_1794, view_1796); view_1794 = view_1796 = None + convert_element_type_2657 = torch.ops.prims.convert_element_type.default(mm_637, torch.float32); mm_637 = None + reduce_scatter_tensor_265 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2657, 'avg', 256, '0'); convert_element_type_2657 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_265); reduce_scatter_tensor_265 = None + convert_element_type_2658 = torch.ops.prims.convert_element_type.default(add_334, torch.float32); add_334 = None + convert_element_type_2660 = torch.ops.prims.convert_element_type.default(wait_tensor_24, torch.float32); wait_tensor_24 = None + mul_850 = torch.ops.aten.mul.Tensor(convert_element_type_2658, convert_element_type_2660); convert_element_type_2660 = None + mul_852 = torch.ops.aten.mul.Tensor(mul_20, mul_850) + sum_177 = torch.ops.aten.sum.dim_IntList(mul_852, [2], True); mul_852 = None + div_59 = torch.ops.aten.div.Tensor(mul_20, 4096) + mul_853 = torch.ops.aten.mul.Tensor(div_59, sum_177); div_59 = sum_177 = None + sub_89 = torch.ops.aten.sub.Tensor(mul_850, mul_853); mul_850 = mul_853 = None + mul_854 = torch.ops.aten.mul.Tensor(sub_89, rsqrt_5); sub_89 = rsqrt_5 = None + mul_855 = torch.ops.aten.mul.Tensor(convert_element_type_2658, mul_20); convert_element_type_2658 = mul_20 = None + sum_178 = torch.ops.aten.sum.dim_IntList(mul_855, [0, 1]); mul_855 = None + convert_element_type_2661 = torch.ops.prims.convert_element_type.default(mul_854, torch.bfloat16); mul_854 = None + add_335 = torch.ops.aten.add.Tensor(add_331, convert_element_type_2661); add_331 = convert_element_type_2661 = None + convert_element_type_default_6 = torch.ops.prims.convert_element_type.default(sum_178, torch.float32); sum_178 = None + reduce_scatter_tensor_266 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_6, 'avg', 256, '0'); convert_element_type_default_6 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_266); reduce_scatter_tensor_266 = None + view_1797 = torch.ops.aten.view.default(add_335, [16384, 4096]) + permute_1297 = torch.ops.aten.permute.default(view_1797, [1, 0]) + mm_639 = torch.ops.aten.mm.default(permute_1297, view_91); permute_1297 = view_91 = None + permute_1299 = torch.ops.aten.permute.default(permute_29, [1, 0]); permute_29 = None + mm_640 = torch.ops.aten.mm.default(view_1797, permute_1299); view_1797 = permute_1299 = None + view_1798 = torch.ops.aten.view.default(mm_640, [2, 8192, 4096]); mm_640 = None + convert_element_type_2668 = torch.ops.prims.convert_element_type.default(mm_639, torch.float32); mm_639 = None + reduce_scatter_tensor_267 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2668, 'avg', 256, '0'); convert_element_type_2668 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_267); reduce_scatter_tensor_267 = None + view_1799 = torch.ops.aten.view.default(view_1798, [2, 8192, 32, 128]); view_1798 = None + permute_1301 = torch.ops.aten.permute.default(view_1799, [0, 2, 1, 3]); view_1799 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 256, '0'); convert_element_type_67 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32); add_7 = None + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_19) + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + view_71 = torch.ops.aten.view.default(convert_element_type_69, [16384, 4096]); convert_element_type_69 = None + view_72 = torch.ops.aten.view.default(mm_14, [2, 8192, 4096]); mm_14 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 256, '0'); convert_element_type_73 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_21, [1, 0]); wait_tensor_21 = None + mm_15 = torch.ops.aten.mm.default(view_71, permute_23) + view_75 = torch.ops.aten.view.default(mm_15, [2, 8192, 1024]); mm_15 = None + view_78 = torch.ops.aten.view.default(mm_16, [2, 8192, 1024]); mm_16 = None + view_79 = torch.ops.aten.view.default(view_72, [2, 8192, -1, 128]); view_72 = None + view_80 = torch.ops.aten.view.default(view_75, [2, 8192, -1, 128]); view_75 = None + view_81 = torch.ops.aten.view.default(view_78, [2, 8192, -1, 128]); view_78 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_79, torch.float32); view_79 = None + view_82 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 32, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_82); view_82 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_80, torch.float32); view_80 = None + view_83 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 8, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_83); view_83 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_16); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_85 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 32, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_16); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_86 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 8, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_85, torch.bfloat16); view_85 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_86, torch.bfloat16); view_86 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 8, 4, 128]); unsqueeze_4 = None + clone_4 = torch.ops.aten.clone.default(expand_4, memory_format = torch.contiguous_format); expand_4 = None + view_87 = torch.ops.aten.view.default(clone_4, [2, 8192, 32, 128]); clone_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_81, 3); view_81 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 8, 4, 128]); unsqueeze_5 = None + clone_5 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format); expand_5 = None + view_88 = torch.ops.aten.view.default(clone_5, [2, 8192, 32, 128]); clone_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_87, [0, 2, 1, 3]); view_87 = None + permute_27 = torch.ops.aten.permute.default(view_88, [0, 2, 1, 3]); view_88 = None + _scaled_dot_product_cudnn_attention_backward_29 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1301, permute_25, permute_26, permute_27, getitem_18, getitem_19, getitem_24, getitem_25, None, None, None, 8192, 8192, 0.0, True); permute_1301 = permute_25 = permute_26 = permute_27 = getitem_18 = getitem_19 = getitem_24 = getitem_25 = None + getitem_375 = _scaled_dot_product_cudnn_attention_backward_29[0] + getitem_376 = _scaled_dot_product_cudnn_attention_backward_29[1] + getitem_377 = _scaled_dot_product_cudnn_attention_backward_29[2]; _scaled_dot_product_cudnn_attention_backward_29 = None + permute_1302 = torch.ops.aten.permute.default(getitem_377, [0, 2, 1, 3]); getitem_377 = None + permute_1303 = torch.ops.aten.permute.default(getitem_376, [0, 2, 1, 3]); getitem_376 = None + permute_1304 = torch.ops.aten.permute.default(getitem_375, [0, 2, 1, 3]); getitem_375 = None + view_1800 = torch.ops.aten.view.default(permute_1302, [2, 8192, 8, 4, 128]); permute_1302 = None + sum_179 = torch.ops.aten.sum.dim_IntList(view_1800, [3], True); view_1800 = None + squeeze_58 = torch.ops.aten.squeeze.dim(sum_179, 3); sum_179 = None + view_1801 = torch.ops.aten.view.default(permute_1303, [2, 8192, 8, 4, 128]); permute_1303 = None + sum_180 = torch.ops.aten.sum.dim_IntList(view_1801, [3], True); view_1801 = None + squeeze_59 = torch.ops.aten.squeeze.dim(sum_180, 3); sum_180 = None + convert_element_type_2669 = torch.ops.prims.convert_element_type.default(squeeze_59, torch.float32); squeeze_59 = None + convert_element_type_2670 = torch.ops.prims.convert_element_type.default(permute_1304, torch.float32); permute_1304 = None + view_1802 = torch.ops.aten.view.default(convert_element_type_2669, [2, 8192, 8, 64, 2]); convert_element_type_2669 = None + view_as_complex_122 = torch.ops.aten.view_as_complex.default(view_1802); view_1802 = None + mul_856 = torch.ops.aten.mul.Tensor(view_as_complex_122, _conj); view_as_complex_122 = None + view_1803 = torch.ops.aten.view.default(convert_element_type_2670, [2, 8192, 32, 64, 2]); convert_element_type_2670 = None + view_as_complex_123 = torch.ops.aten.view_as_complex.default(view_1803); view_1803 = None + mul_857 = torch.ops.aten.mul.Tensor(view_as_complex_123, _conj); view_as_complex_123 = None + view_as_real_122 = torch.ops.aten.view_as_real.default(mul_856); mul_856 = None + view_1804 = torch.ops.aten.view.default(view_as_real_122, [2, 8192, 8, 128]); view_as_real_122 = None + convert_element_type_2671 = torch.ops.prims.convert_element_type.default(view_1804, torch.bfloat16); view_1804 = None + view_as_real_123 = torch.ops.aten.view_as_real.default(mul_857); mul_857 = None + view_1805 = torch.ops.aten.view.default(view_as_real_123, [2, 8192, 32, 128]); view_as_real_123 = None + convert_element_type_2672 = torch.ops.prims.convert_element_type.default(view_1805, torch.bfloat16); view_1805 = None + view_1806 = torch.ops.aten.view.default(squeeze_58, [2, 8192, 1024]); squeeze_58 = None + view_1807 = torch.ops.aten.view.default(convert_element_type_2671, [2, 8192, 1024]); convert_element_type_2671 = None + view_1808 = torch.ops.aten.view.default(convert_element_type_2672, [2, 8192, 4096]); convert_element_type_2672 = None + view_1809 = torch.ops.aten.view.default(view_1806, [16384, 1024]); view_1806 = None + permute_1305 = torch.ops.aten.permute.default(view_1809, [1, 0]) + mm_641 = torch.ops.aten.mm.default(permute_1305, view_71); permute_1305 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 256, '0'); convert_element_type_76 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_22, [1, 0]); wait_tensor_22 = None + permute_1307 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_642 = torch.ops.aten.mm.default(view_1809, permute_1307); view_1809 = permute_1307 = None + view_1810 = torch.ops.aten.view.default(mm_642, [2, 8192, 4096]); mm_642 = None + convert_element_type_2677 = torch.ops.prims.convert_element_type.default(mm_641, torch.float32); mm_641 = None + reduce_scatter_tensor_268 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2677, 'avg', 256, '0'); convert_element_type_2677 = None + wait_tensor_559 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_268); reduce_scatter_tensor_268 = None + view_1811 = torch.ops.aten.view.default(view_1807, [16384, 1024]); view_1807 = None + permute_1309 = torch.ops.aten.permute.default(view_1811, [1, 0]) + mm_643 = torch.ops.aten.mm.default(permute_1309, view_71); permute_1309 = None + permute_1311 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_644 = torch.ops.aten.mm.default(view_1811, permute_1311); view_1811 = permute_1311 = None + view_1812 = torch.ops.aten.view.default(mm_644, [2, 8192, 4096]); mm_644 = None + add_336 = torch.ops.aten.add.Tensor(view_1810, view_1812); view_1810 = view_1812 = None + convert_element_type_2682 = torch.ops.prims.convert_element_type.default(mm_643, torch.float32); mm_643 = None + reduce_scatter_tensor_269 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2682, 'avg', 256, '0'); convert_element_type_2682 = None + wait_tensor_560 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_269); reduce_scatter_tensor_269 = None + view_1813 = torch.ops.aten.view.default(view_1808, [16384, 4096]); view_1808 = None + permute_1313 = torch.ops.aten.permute.default(view_1813, [1, 0]) + mm_645 = torch.ops.aten.mm.default(permute_1313, view_71); permute_1313 = view_71 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 256, '0'); convert_element_type_70 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + permute_1315 = torch.ops.aten.permute.default(permute_22, [1, 0]); permute_22 = None + mm_646 = torch.ops.aten.mm.default(view_1813, permute_1315); view_1813 = permute_1315 = None + view_1814 = torch.ops.aten.view.default(mm_646, [2, 8192, 4096]); mm_646 = None + add_337 = torch.ops.aten.add.Tensor(add_336, view_1814); add_336 = view_1814 = None + convert_element_type_2687 = torch.ops.prims.convert_element_type.default(mm_645, torch.float32); mm_645 = None + reduce_scatter_tensor_270 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2687, 'avg', 256, '0'); convert_element_type_2687 = None + wait_tensor_561 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_270); reduce_scatter_tensor_270 = None + convert_element_type_2688 = torch.ops.prims.convert_element_type.default(add_337, torch.float32); add_337 = None + convert_element_type_2690 = torch.ops.prims.convert_element_type.default(wait_tensor_19, torch.float32); wait_tensor_19 = None + mul_858 = torch.ops.aten.mul.Tensor(convert_element_type_2688, convert_element_type_2690); convert_element_type_2690 = None + mul_860 = torch.ops.aten.mul.Tensor(mul_16, mul_858) + sum_181 = torch.ops.aten.sum.dim_IntList(mul_860, [2], True); mul_860 = None + div_60 = torch.ops.aten.div.Tensor(mul_16, 4096) + mul_861 = torch.ops.aten.mul.Tensor(div_60, sum_181); div_60 = sum_181 = None + sub_90 = torch.ops.aten.sub.Tensor(mul_858, mul_861); mul_858 = mul_861 = None + mul_862 = torch.ops.aten.mul.Tensor(sub_90, rsqrt_4); sub_90 = rsqrt_4 = None + mul_863 = torch.ops.aten.mul.Tensor(convert_element_type_2688, mul_16); convert_element_type_2688 = mul_16 = None + sum_182 = torch.ops.aten.sum.dim_IntList(mul_863, [0, 1]); mul_863 = None + convert_element_type_2691 = torch.ops.prims.convert_element_type.default(mul_862, torch.bfloat16); mul_862 = None + add_338 = torch.ops.aten.add.Tensor(add_335, convert_element_type_2691); add_335 = convert_element_type_2691 = None + convert_element_type_default_5 = torch.ops.prims.convert_element_type.default(sum_182, torch.float32); sum_182 = None + reduce_scatter_tensor_271 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_5, 'avg', 256, '0'); convert_element_type_default_5 = None + wait_tensor_562 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_271); reduce_scatter_tensor_271 = None + view_1815 = torch.ops.aten.view.default(add_338, [16384, 4096]) + permute_1317 = torch.ops.aten.permute.default(view_1815, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_9, [0, 2, 1, 3]) + view_55 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 256, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_14, [1, 0]); wait_tensor_14 = None + view_57 = torch.ops.aten.view.default(view_55, [16384, 4096]); view_55 = None + mm_10 = torch.ops.aten.mm.default(view_57, permute_18) + view_58 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + add_5 = torch.ops.aten.add.Tensor(add_3, view_58); view_58 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 256, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_15) + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + view_61 = torch.ops.aten.view.default(convert_element_type_55, [16384, 4096]); convert_element_type_55 = None + view_62 = torch.ops.aten.view.default(mm_11, [2, 8192, 14336]); mm_11 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_62, torch.float32); view_62 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 256, '0'); convert_element_type_61 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + mm_12 = torch.ops.aten.mm.default(view_61, permute_20) + view_65 = torch.ops.aten.view.default(mm_12, [2, 8192, 14336]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_65) + view_67 = torch.ops.aten.view.default(mul_15, [16384, 14336]); mul_15 = None + mm_647 = torch.ops.aten.mm.default(permute_1317, view_67); permute_1317 = view_67 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 256, '0'); convert_element_type_64 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + permute_1319 = torch.ops.aten.permute.default(permute_21, [1, 0]); permute_21 = None + mm_648 = torch.ops.aten.mm.default(view_1815, permute_1319); view_1815 = permute_1319 = None + view_1816 = torch.ops.aten.view.default(mm_648, [2, 8192, 14336]); mm_648 = None + convert_element_type_2698 = torch.ops.prims.convert_element_type.default(mm_647, torch.float32); mm_647 = None + reduce_scatter_tensor_272 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2698, 'avg', 256, '0'); convert_element_type_2698 = None + wait_tensor_563 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_272); reduce_scatter_tensor_272 = None + mul_864 = torch.ops.aten.mul.Tensor(view_1816, convert_element_type_60); convert_element_type_60 = None + mul_865 = torch.ops.aten.mul.Tensor(view_1816, view_65); view_1816 = view_65 = None + view_1817 = torch.ops.aten.view.default(mul_864, [16384, 14336]); mul_864 = None + permute_1321 = torch.ops.aten.permute.default(view_1817, [1, 0]) + mm_649 = torch.ops.aten.mm.default(permute_1321, view_61); permute_1321 = None + permute_1323 = torch.ops.aten.permute.default(permute_20, [1, 0]); permute_20 = None + mm_650 = torch.ops.aten.mm.default(view_1817, permute_1323); view_1817 = permute_1323 = None + view_1818 = torch.ops.aten.view.default(mm_650, [2, 8192, 4096]); mm_650 = None + convert_element_type_2703 = torch.ops.prims.convert_element_type.default(mm_649, torch.float32); mm_649 = None + reduce_scatter_tensor_273 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2703, 'avg', 256, '0'); convert_element_type_2703 = None + wait_tensor_564 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_273); reduce_scatter_tensor_273 = None + convert_element_type_2704 = torch.ops.prims.convert_element_type.default(mul_865, torch.float32); mul_865 = None + neg_30 = torch.ops.aten.neg.default(convert_element_type_59) + exp_30 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_339 = torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + reciprocal_30 = torch.ops.aten.reciprocal.default(add_339); add_339 = None + mul_866 = torch.ops.aten.mul.Tensor(reciprocal_30, 1); reciprocal_30 = None + mul_867 = torch.ops.aten.mul.Tensor(convert_element_type_2704, mul_866); convert_element_type_2704 = None + sub_91 = torch.ops.aten.sub.Tensor(1, mul_866); mul_866 = None + mul_868 = torch.ops.aten.mul.Tensor(convert_element_type_59, sub_91); convert_element_type_59 = sub_91 = None + add_340 = torch.ops.aten.add.Tensor(mul_868, 1); mul_868 = None + mul_869 = torch.ops.aten.mul.Tensor(mul_867, add_340); mul_867 = add_340 = None + convert_element_type_2706 = torch.ops.prims.convert_element_type.default(mul_869, torch.bfloat16); mul_869 = None + view_1819 = torch.ops.aten.view.default(convert_element_type_2706, [16384, 14336]); convert_element_type_2706 = None + permute_1325 = torch.ops.aten.permute.default(view_1819, [1, 0]) + mm_651 = torch.ops.aten.mm.default(permute_1325, view_61); permute_1325 = view_61 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 256, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + permute_1327 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_652 = torch.ops.aten.mm.default(view_1819, permute_1327); view_1819 = permute_1327 = None + view_1820 = torch.ops.aten.view.default(mm_652, [2, 8192, 4096]); mm_652 = None + add_341 = torch.ops.aten.add.Tensor(view_1818, view_1820); view_1818 = view_1820 = None + convert_element_type_2711 = torch.ops.prims.convert_element_type.default(mm_651, torch.float32); mm_651 = None + reduce_scatter_tensor_274 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2711, 'avg', 256, '0'); convert_element_type_2711 = None + wait_tensor_565 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_274); reduce_scatter_tensor_274 = None + convert_element_type_2712 = torch.ops.prims.convert_element_type.default(add_341, torch.float32); add_341 = None + convert_element_type_2714 = torch.ops.prims.convert_element_type.default(wait_tensor_15, torch.float32); wait_tensor_15 = None + mul_870 = torch.ops.aten.mul.Tensor(convert_element_type_2712, convert_element_type_2714); convert_element_type_2714 = None + mul_872 = torch.ops.aten.mul.Tensor(mul_12, mul_870) + sum_183 = torch.ops.aten.sum.dim_IntList(mul_872, [2], True); mul_872 = None + div_61 = torch.ops.aten.div.Tensor(mul_12, 4096) + mul_873 = torch.ops.aten.mul.Tensor(div_61, sum_183); div_61 = sum_183 = None + sub_92 = torch.ops.aten.sub.Tensor(mul_870, mul_873); mul_870 = mul_873 = None + mul_874 = torch.ops.aten.mul.Tensor(sub_92, rsqrt_3); sub_92 = rsqrt_3 = None + mul_875 = torch.ops.aten.mul.Tensor(convert_element_type_2712, mul_12); convert_element_type_2712 = mul_12 = None + sum_184 = torch.ops.aten.sum.dim_IntList(mul_875, [0, 1]); mul_875 = None + convert_element_type_2715 = torch.ops.prims.convert_element_type.default(mul_874, torch.bfloat16); mul_874 = None + add_342 = torch.ops.aten.add.Tensor(add_338, convert_element_type_2715); add_338 = convert_element_type_2715 = None + convert_element_type_default_4 = torch.ops.prims.convert_element_type.default(sum_184, torch.float32); sum_184 = None + reduce_scatter_tensor_275 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_4, 'avg', 256, '0'); convert_element_type_default_4 = None + wait_tensor_566 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_275); reduce_scatter_tensor_275 = None + view_1821 = torch.ops.aten.view.default(add_342, [16384, 4096]) + permute_1329 = torch.ops.aten.permute.default(view_1821, [1, 0]) + mm_653 = torch.ops.aten.mm.default(permute_1329, view_57); permute_1329 = view_57 = None + permute_1331 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_654 = torch.ops.aten.mm.default(view_1821, permute_1331); view_1821 = permute_1331 = None + view_1822 = torch.ops.aten.view.default(mm_654, [2, 8192, 4096]); mm_654 = None + convert_element_type_2722 = torch.ops.prims.convert_element_type.default(mm_653, torch.float32); mm_653 = None + reduce_scatter_tensor_276 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2722, 'avg', 256, '0'); convert_element_type_2722 = None + wait_tensor_567 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_276); reduce_scatter_tensor_276 = None + view_1823 = torch.ops.aten.view.default(view_1822, [2, 8192, 32, 128]); view_1822 = None + permute_1333 = torch.ops.aten.permute.default(view_1823, [0, 2, 1, 3]); view_1823 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); primals_13 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 256, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32); add_3 = None + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_10) + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + view_37 = torch.ops.aten.view.default(convert_element_type_36, [16384, 4096]); convert_element_type_36 = None + view_38 = torch.ops.aten.view.default(mm_7, [2, 8192, 4096]); mm_7 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 256, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_8 = torch.ops.aten.mm.default(view_37, permute_12) + view_41 = torch.ops.aten.view.default(mm_8, [2, 8192, 1024]); mm_8 = None + view_44 = torch.ops.aten.view.default(mm_9, [2, 8192, 1024]); mm_9 = None + view_45 = torch.ops.aten.view.default(view_38, [2, 8192, -1, 128]); view_38 = None + view_46 = torch.ops.aten.view.default(view_41, [2, 8192, -1, 128]); view_41 = None + view_47 = torch.ops.aten.view.default(view_44, [2, 8192, -1, 128]); view_44 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_45, torch.float32); view_45 = None + view_48 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 32, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_48); view_48 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_46, torch.float32); view_46 = None + view_49 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 8, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_49); view_49 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_16); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_51 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 32, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_16); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_52 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 8, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_51, torch.bfloat16); view_51 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_52, torch.bfloat16); view_52 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 8, 4, 128]); unsqueeze_2 = None + clone_2 = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format); expand_2 = None + view_53 = torch.ops.aten.view.default(clone_2, [2, 8192, 32, 128]); clone_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_47, 3); view_47 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 8, 4, 128]); unsqueeze_3 = None + clone_3 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None + view_54 = torch.ops.aten.view.default(clone_3, [2, 8192, 32, 128]); clone_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_53, [0, 2, 1, 3]); view_53 = None + permute_16 = torch.ops.aten.permute.default(view_54, [0, 2, 1, 3]); view_54 = None + _scaled_dot_product_cudnn_attention_backward_30 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1333, permute_14, permute_15, permute_16, getitem_9, getitem_10, getitem_15, getitem_16, None, None, None, 8192, 8192, 0.0, True); permute_1333 = permute_14 = permute_15 = permute_16 = getitem_9 = getitem_10 = getitem_15 = getitem_16 = None + getitem_378 = _scaled_dot_product_cudnn_attention_backward_30[0] + getitem_379 = _scaled_dot_product_cudnn_attention_backward_30[1] + getitem_380 = _scaled_dot_product_cudnn_attention_backward_30[2]; _scaled_dot_product_cudnn_attention_backward_30 = None + permute_1334 = torch.ops.aten.permute.default(getitem_380, [0, 2, 1, 3]); getitem_380 = None + permute_1335 = torch.ops.aten.permute.default(getitem_379, [0, 2, 1, 3]); getitem_379 = None + permute_1336 = torch.ops.aten.permute.default(getitem_378, [0, 2, 1, 3]); getitem_378 = None + view_1824 = torch.ops.aten.view.default(permute_1334, [2, 8192, 8, 4, 128]); permute_1334 = None + sum_185 = torch.ops.aten.sum.dim_IntList(view_1824, [3], True); view_1824 = None + squeeze_60 = torch.ops.aten.squeeze.dim(sum_185, 3); sum_185 = None + view_1825 = torch.ops.aten.view.default(permute_1335, [2, 8192, 8, 4, 128]); permute_1335 = None + sum_186 = torch.ops.aten.sum.dim_IntList(view_1825, [3], True); view_1825 = None + squeeze_61 = torch.ops.aten.squeeze.dim(sum_186, 3); sum_186 = None + convert_element_type_2723 = torch.ops.prims.convert_element_type.default(squeeze_61, torch.float32); squeeze_61 = None + convert_element_type_2724 = torch.ops.prims.convert_element_type.default(permute_1336, torch.float32); permute_1336 = None + view_1826 = torch.ops.aten.view.default(convert_element_type_2723, [2, 8192, 8, 64, 2]); convert_element_type_2723 = None + view_as_complex_124 = torch.ops.aten.view_as_complex.default(view_1826); view_1826 = None + mul_876 = torch.ops.aten.mul.Tensor(view_as_complex_124, _conj); view_as_complex_124 = None + view_1827 = torch.ops.aten.view.default(convert_element_type_2724, [2, 8192, 32, 64, 2]); convert_element_type_2724 = None + view_as_complex_125 = torch.ops.aten.view_as_complex.default(view_1827); view_1827 = None + mul_877 = torch.ops.aten.mul.Tensor(view_as_complex_125, _conj); view_as_complex_125 = None + view_as_real_124 = torch.ops.aten.view_as_real.default(mul_876); mul_876 = None + view_1828 = torch.ops.aten.view.default(view_as_real_124, [2, 8192, 8, 128]); view_as_real_124 = None + convert_element_type_2725 = torch.ops.prims.convert_element_type.default(view_1828, torch.bfloat16); view_1828 = None + view_as_real_125 = torch.ops.aten.view_as_real.default(mul_877); mul_877 = None + view_1829 = torch.ops.aten.view.default(view_as_real_125, [2, 8192, 32, 128]); view_as_real_125 = None + convert_element_type_2726 = torch.ops.prims.convert_element_type.default(view_1829, torch.bfloat16); view_1829 = None + view_1830 = torch.ops.aten.view.default(squeeze_60, [2, 8192, 1024]); squeeze_60 = None + view_1831 = torch.ops.aten.view.default(convert_element_type_2725, [2, 8192, 1024]); convert_element_type_2725 = None + view_1832 = torch.ops.aten.view.default(convert_element_type_2726, [2, 8192, 4096]); convert_element_type_2726 = None + view_1833 = torch.ops.aten.view.default(view_1830, [16384, 1024]); view_1830 = None + permute_1337 = torch.ops.aten.permute.default(view_1833, [1, 0]) + mm_655 = torch.ops.aten.mm.default(permute_1337, view_37); permute_1337 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 256, '0'); convert_element_type_43 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + permute_1339 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_656 = torch.ops.aten.mm.default(view_1833, permute_1339); view_1833 = permute_1339 = None + view_1834 = torch.ops.aten.view.default(mm_656, [2, 8192, 4096]); mm_656 = None + convert_element_type_2731 = torch.ops.prims.convert_element_type.default(mm_655, torch.float32); mm_655 = None + reduce_scatter_tensor_277 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2731, 'avg', 256, '0'); convert_element_type_2731 = None + wait_tensor_568 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_277); reduce_scatter_tensor_277 = None + view_1835 = torch.ops.aten.view.default(view_1831, [16384, 1024]); view_1831 = None + permute_1341 = torch.ops.aten.permute.default(view_1835, [1, 0]) + mm_657 = torch.ops.aten.mm.default(permute_1341, view_37); permute_1341 = None + permute_1343 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_658 = torch.ops.aten.mm.default(view_1835, permute_1343); view_1835 = permute_1343 = None + view_1836 = torch.ops.aten.view.default(mm_658, [2, 8192, 4096]); mm_658 = None + add_343 = torch.ops.aten.add.Tensor(view_1834, view_1836); view_1834 = view_1836 = None + convert_element_type_2736 = torch.ops.prims.convert_element_type.default(mm_657, torch.float32); mm_657 = None + reduce_scatter_tensor_278 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2736, 'avg', 256, '0'); convert_element_type_2736 = None + wait_tensor_569 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_278); reduce_scatter_tensor_278 = None + view_1837 = torch.ops.aten.view.default(view_1832, [16384, 4096]); view_1832 = None + permute_1345 = torch.ops.aten.permute.default(view_1837, [1, 0]) + mm_659 = torch.ops.aten.mm.default(permute_1345, view_37); permute_1345 = view_37 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); primals_14 = None + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 256, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + permute_1347 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_660 = torch.ops.aten.mm.default(view_1837, permute_1347); view_1837 = permute_1347 = None + view_1838 = torch.ops.aten.view.default(mm_660, [2, 8192, 4096]); mm_660 = None + add_344 = torch.ops.aten.add.Tensor(add_343, view_1838); add_343 = view_1838 = None + convert_element_type_2741 = torch.ops.prims.convert_element_type.default(mm_659, torch.float32); mm_659 = None + reduce_scatter_tensor_279 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2741, 'avg', 256, '0'); convert_element_type_2741 = None + wait_tensor_570 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_279); reduce_scatter_tensor_279 = None + convert_element_type_2742 = torch.ops.prims.convert_element_type.default(add_344, torch.float32); add_344 = None + convert_element_type_2744 = torch.ops.prims.convert_element_type.default(wait_tensor_10, torch.float32); wait_tensor_10 = None + mul_878 = torch.ops.aten.mul.Tensor(convert_element_type_2742, convert_element_type_2744); convert_element_type_2744 = None + mul_880 = torch.ops.aten.mul.Tensor(mul_8, mul_878) + sum_187 = torch.ops.aten.sum.dim_IntList(mul_880, [2], True); mul_880 = None + div_62 = torch.ops.aten.div.Tensor(mul_8, 4096) + mul_881 = torch.ops.aten.mul.Tensor(div_62, sum_187); div_62 = sum_187 = None + sub_93 = torch.ops.aten.sub.Tensor(mul_878, mul_881); mul_878 = mul_881 = None + mul_882 = torch.ops.aten.mul.Tensor(sub_93, rsqrt_2); sub_93 = rsqrt_2 = None + mul_883 = torch.ops.aten.mul.Tensor(convert_element_type_2742, mul_8); convert_element_type_2742 = mul_8 = None + sum_188 = torch.ops.aten.sum.dim_IntList(mul_883, [0, 1]); mul_883 = None + convert_element_type_2745 = torch.ops.prims.convert_element_type.default(mul_882, torch.bfloat16); mul_882 = None + add_345 = torch.ops.aten.add.Tensor(add_342, convert_element_type_2745); add_342 = convert_element_type_2745 = None + convert_element_type_default_3 = torch.ops.prims.convert_element_type.default(sum_188, torch.float32); sum_188 = None + reduce_scatter_tensor_280 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_3, 'avg', 256, '0'); convert_element_type_default_3 = None + wait_tensor_571 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_280); reduce_scatter_tensor_280 = None + view_1839 = torch.ops.aten.view.default(add_345, [16384, 4096]) + permute_1349 = torch.ops.aten.permute.default(view_1839, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem, [0, 2, 1, 3]) + view_21 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 256, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_23 = torch.ops.aten.view.default(view_21, [16384, 4096]); view_21 = None + mm_3 = torch.ops.aten.mm.default(view_23, permute_7) + view_24 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + add_1 = torch.ops.aten.add.Tensor(embedding, view_24); view_24 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 256, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32); add_1 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_6) + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + view_27 = torch.ops.aten.view.default(convert_element_type_22, [16384, 4096]); convert_element_type_22 = None + view_28 = torch.ops.aten.view.default(mm_4, [2, 8192, 14336]); mm_4 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_28, torch.float32); view_28 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 256, '0'); convert_element_type_28 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + mm_5 = torch.ops.aten.mm.default(view_27, permute_9) + view_31 = torch.ops.aten.view.default(mm_5, [2, 8192, 14336]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_31) + view_33 = torch.ops.aten.view.default(mul_7, [16384, 14336]); mul_7 = None + mm_661 = torch.ops.aten.mm.default(permute_1349, view_33); permute_1349 = view_33 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 256, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + permute_1351 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_662 = torch.ops.aten.mm.default(view_1839, permute_1351); view_1839 = permute_1351 = None + view_1840 = torch.ops.aten.view.default(mm_662, [2, 8192, 14336]); mm_662 = None + convert_element_type_2752 = torch.ops.prims.convert_element_type.default(mm_661, torch.float32); mm_661 = None + reduce_scatter_tensor_281 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2752, 'avg', 256, '0'); convert_element_type_2752 = None + wait_tensor_572 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_281); reduce_scatter_tensor_281 = None + mul_884 = torch.ops.aten.mul.Tensor(view_1840, convert_element_type_27); convert_element_type_27 = None + mul_885 = torch.ops.aten.mul.Tensor(view_1840, view_31); view_1840 = view_31 = None + view_1841 = torch.ops.aten.view.default(mul_884, [16384, 14336]); mul_884 = None + permute_1353 = torch.ops.aten.permute.default(view_1841, [1, 0]) + mm_663 = torch.ops.aten.mm.default(permute_1353, view_27); permute_1353 = None + permute_1355 = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_664 = torch.ops.aten.mm.default(view_1841, permute_1355); view_1841 = permute_1355 = None + view_1842 = torch.ops.aten.view.default(mm_664, [2, 8192, 4096]); mm_664 = None + convert_element_type_2757 = torch.ops.prims.convert_element_type.default(mm_663, torch.float32); mm_663 = None + reduce_scatter_tensor_282 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2757, 'avg', 256, '0'); convert_element_type_2757 = None + wait_tensor_573 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_282); reduce_scatter_tensor_282 = None + convert_element_type_2758 = torch.ops.prims.convert_element_type.default(mul_885, torch.float32); mul_885 = None + neg_31 = torch.ops.aten.neg.default(convert_element_type_26) + exp_31 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_346 = torch.ops.aten.add.Tensor(exp_31, 1); exp_31 = None + reciprocal_31 = torch.ops.aten.reciprocal.default(add_346); add_346 = None + mul_886 = torch.ops.aten.mul.Tensor(reciprocal_31, 1); reciprocal_31 = None + mul_887 = torch.ops.aten.mul.Tensor(convert_element_type_2758, mul_886); convert_element_type_2758 = None + sub_94 = torch.ops.aten.sub.Tensor(1, mul_886); mul_886 = None + mul_888 = torch.ops.aten.mul.Tensor(convert_element_type_26, sub_94); convert_element_type_26 = sub_94 = None + add_347 = torch.ops.aten.add.Tensor(mul_888, 1); mul_888 = None + mul_889 = torch.ops.aten.mul.Tensor(mul_887, add_347); mul_887 = add_347 = None + convert_element_type_2760 = torch.ops.prims.convert_element_type.default(mul_889, torch.bfloat16); mul_889 = None + view_1843 = torch.ops.aten.view.default(convert_element_type_2760, [16384, 14336]); convert_element_type_2760 = None + permute_1357 = torch.ops.aten.permute.default(view_1843, [1, 0]) + mm_665 = torch.ops.aten.mm.default(permute_1357, view_27); permute_1357 = view_27 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 256, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + permute_1359 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_666 = torch.ops.aten.mm.default(view_1843, permute_1359); view_1843 = permute_1359 = None + view_1844 = torch.ops.aten.view.default(mm_666, [2, 8192, 4096]); mm_666 = None + add_348 = torch.ops.aten.add.Tensor(view_1842, view_1844); view_1842 = view_1844 = None + convert_element_type_2765 = torch.ops.prims.convert_element_type.default(mm_665, torch.float32); mm_665 = None + reduce_scatter_tensor_283 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2765, 'avg', 256, '0'); convert_element_type_2765 = None + wait_tensor_574 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_283); reduce_scatter_tensor_283 = None + convert_element_type_2766 = torch.ops.prims.convert_element_type.default(add_348, torch.float32); add_348 = None + convert_element_type_2768 = torch.ops.prims.convert_element_type.default(wait_tensor_6, torch.float32); wait_tensor_6 = None + mul_890 = torch.ops.aten.mul.Tensor(convert_element_type_2766, convert_element_type_2768); convert_element_type_2768 = None + mul_892 = torch.ops.aten.mul.Tensor(mul_4, mul_890) + sum_189 = torch.ops.aten.sum.dim_IntList(mul_892, [2], True); mul_892 = None + div_63 = torch.ops.aten.div.Tensor(mul_4, 4096) + mul_893 = torch.ops.aten.mul.Tensor(div_63, sum_189); div_63 = sum_189 = None + sub_95 = torch.ops.aten.sub.Tensor(mul_890, mul_893); mul_890 = mul_893 = None + mul_894 = torch.ops.aten.mul.Tensor(sub_95, rsqrt_1); sub_95 = rsqrt_1 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_2766, mul_4); convert_element_type_2766 = mul_4 = None + sum_190 = torch.ops.aten.sum.dim_IntList(mul_895, [0, 1]); mul_895 = None + convert_element_type_2769 = torch.ops.prims.convert_element_type.default(mul_894, torch.bfloat16); mul_894 = None + add_349 = torch.ops.aten.add.Tensor(add_345, convert_element_type_2769); add_345 = convert_element_type_2769 = None + convert_element_type_default_2 = torch.ops.prims.convert_element_type.default(sum_190, torch.float32); sum_190 = None + reduce_scatter_tensor_284 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_2, 'avg', 256, '0'); convert_element_type_default_2 = None + wait_tensor_575 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_284); reduce_scatter_tensor_284 = None + view_1845 = torch.ops.aten.view.default(add_349, [16384, 4096]) + permute_1361 = torch.ops.aten.permute.default(view_1845, [1, 0]) + mm_667 = torch.ops.aten.mm.default(permute_1361, view_23); permute_1361 = view_23 = None + permute_1363 = torch.ops.aten.permute.default(permute_7, [1, 0]); permute_7 = None + mm_668 = torch.ops.aten.mm.default(view_1845, permute_1363); view_1845 = permute_1363 = None + view_1846 = torch.ops.aten.view.default(mm_668, [2, 8192, 4096]); mm_668 = None + convert_element_type_2776 = torch.ops.prims.convert_element_type.default(mm_667, torch.float32); mm_667 = None + reduce_scatter_tensor_285 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2776, 'avg', 256, '0'); convert_element_type_2776 = None + wait_tensor_576 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_285); reduce_scatter_tensor_285 = None + view_1847 = torch.ops.aten.view.default(view_1846, [2, 8192, 32, 128]); view_1846 = None + permute_1365 = torch.ops.aten.permute.default(view_1847, [0, 2, 1, 3]); view_1847 = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 256, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32); embedding = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1) + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [16384, 4096]); convert_element_type_3 = None + view_4 = torch.ops.aten.view.default(mm, [2, 8192, 4096]); mm = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 256, '0'); convert_element_type_7 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1) + view_7 = torch.ops.aten.view.default(mm_1, [2, 8192, 1024]); mm_1 = None + view_10 = torch.ops.aten.view.default(mm_2, [2, 8192, 1024]); mm_2 = None + view_11 = torch.ops.aten.view.default(view_4, [2, 8192, -1, 128]); view_4 = None + view_12 = torch.ops.aten.view.default(view_7, [2, 8192, -1, 128]); view_7 = None + view_13 = torch.ops.aten.view.default(view_10, [2, 8192, -1, 128]); view_10 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None + view_14 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 32, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_14); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_12, torch.float32); view_12 = None + view_15 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 8, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_15); view_15 = None + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_16); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_17 = torch.ops.aten.view.default(view_as_real, [2, 8192, 32, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_16); view_as_complex_1 = view_16 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_18 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 8, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_17, torch.bfloat16); view_17 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_18, torch.bfloat16); view_18 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 8, 4, 128]); unsqueeze = None + clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None + view_19 = torch.ops.aten.view.default(clone, [2, 8192, 32, 128]); clone = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_13, 3); view_13 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 8, 4, 128]); unsqueeze_1 = None + clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None + view_20 = torch.ops.aten.view.default(clone_1, [2, 8192, 32, 128]); clone_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]); view_19 = None + permute_5 = torch.ops.aten.permute.default(view_20, [0, 2, 1, 3]); view_20 = None + _scaled_dot_product_cudnn_attention_backward_31 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1365, permute_3, permute_4, permute_5, getitem, getitem_1, getitem_6, getitem_7, None, None, None, 8192, 8192, 0.0, True); permute_1365 = permute_3 = permute_4 = permute_5 = getitem = getitem_1 = getitem_6 = getitem_7 = None + getitem_381 = _scaled_dot_product_cudnn_attention_backward_31[0] + getitem_382 = _scaled_dot_product_cudnn_attention_backward_31[1] + getitem_383 = _scaled_dot_product_cudnn_attention_backward_31[2]; _scaled_dot_product_cudnn_attention_backward_31 = None + permute_1366 = torch.ops.aten.permute.default(getitem_383, [0, 2, 1, 3]); getitem_383 = None + permute_1367 = torch.ops.aten.permute.default(getitem_382, [0, 2, 1, 3]); getitem_382 = None + permute_1368 = torch.ops.aten.permute.default(getitem_381, [0, 2, 1, 3]); getitem_381 = None + view_1848 = torch.ops.aten.view.default(permute_1366, [2, 8192, 8, 4, 128]); permute_1366 = None + sum_191 = torch.ops.aten.sum.dim_IntList(view_1848, [3], True); view_1848 = None + squeeze_62 = torch.ops.aten.squeeze.dim(sum_191, 3); sum_191 = None + view_1849 = torch.ops.aten.view.default(permute_1367, [2, 8192, 8, 4, 128]); permute_1367 = None + sum_192 = torch.ops.aten.sum.dim_IntList(view_1849, [3], True); view_1849 = None + squeeze_63 = torch.ops.aten.squeeze.dim(sum_192, 3); sum_192 = None + convert_element_type_2777 = torch.ops.prims.convert_element_type.default(squeeze_63, torch.float32); squeeze_63 = None + convert_element_type_2778 = torch.ops.prims.convert_element_type.default(permute_1368, torch.float32); permute_1368 = None + view_1850 = torch.ops.aten.view.default(convert_element_type_2777, [2, 8192, 8, 64, 2]); convert_element_type_2777 = None + view_as_complex_126 = torch.ops.aten.view_as_complex.default(view_1850); view_1850 = None + mul_896 = torch.ops.aten.mul.Tensor(view_as_complex_126, _conj); view_as_complex_126 = None + view_1851 = torch.ops.aten.view.default(convert_element_type_2778, [2, 8192, 32, 64, 2]); convert_element_type_2778 = None + view_as_complex_127 = torch.ops.aten.view_as_complex.default(view_1851); view_1851 = None + mul_897 = torch.ops.aten.mul.Tensor(view_as_complex_127, _conj); view_as_complex_127 = _conj = None + view_as_real_126 = torch.ops.aten.view_as_real.default(mul_896); mul_896 = None + view_1852 = torch.ops.aten.view.default(view_as_real_126, [2, 8192, 8, 128]); view_as_real_126 = None + convert_element_type_2779 = torch.ops.prims.convert_element_type.default(view_1852, torch.bfloat16); view_1852 = None + view_as_real_127 = torch.ops.aten.view_as_real.default(mul_897); mul_897 = None + view_1853 = torch.ops.aten.view.default(view_as_real_127, [2, 8192, 32, 128]); view_as_real_127 = None + convert_element_type_2780 = torch.ops.prims.convert_element_type.default(view_1853, torch.bfloat16); view_1853 = None + view_1854 = torch.ops.aten.view.default(squeeze_62, [2, 8192, 1024]); squeeze_62 = None + view_1855 = torch.ops.aten.view.default(convert_element_type_2779, [2, 8192, 1024]); convert_element_type_2779 = None + view_1856 = torch.ops.aten.view.default(convert_element_type_2780, [2, 8192, 4096]); convert_element_type_2780 = None + view_1857 = torch.ops.aten.view.default(view_1854, [16384, 1024]); view_1854 = None + permute_1369 = torch.ops.aten.permute.default(view_1857, [1, 0]) + mm_669 = torch.ops.aten.mm.default(permute_1369, view_3); permute_1369 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 256, '0'); convert_element_type_10 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + permute_1371 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_670 = torch.ops.aten.mm.default(view_1857, permute_1371); view_1857 = permute_1371 = None + view_1858 = torch.ops.aten.view.default(mm_670, [2, 8192, 4096]); mm_670 = None + convert_element_type_2785 = torch.ops.prims.convert_element_type.default(mm_669, torch.float32); mm_669 = None + reduce_scatter_tensor_286 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2785, 'avg', 256, '0'); convert_element_type_2785 = None + wait_tensor_577 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_286); reduce_scatter_tensor_286 = None + view_1859 = torch.ops.aten.view.default(view_1855, [16384, 1024]); view_1855 = None + permute_1373 = torch.ops.aten.permute.default(view_1859, [1, 0]) + mm_671 = torch.ops.aten.mm.default(permute_1373, view_3); permute_1373 = None + permute_1375 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_672 = torch.ops.aten.mm.default(view_1859, permute_1375); view_1859 = permute_1375 = None + view_1860 = torch.ops.aten.view.default(mm_672, [2, 8192, 4096]); mm_672 = None + add_350 = torch.ops.aten.add.Tensor(view_1858, view_1860); view_1858 = view_1860 = None + convert_element_type_2790 = torch.ops.prims.convert_element_type.default(mm_671, torch.float32); mm_671 = None + reduce_scatter_tensor_287 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2790, 'avg', 256, '0'); convert_element_type_2790 = None + wait_tensor_578 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_287); reduce_scatter_tensor_287 = None + view_1861 = torch.ops.aten.view.default(view_1856, [16384, 4096]); view_1856 = None + permute_1377 = torch.ops.aten.permute.default(view_1861, [1, 0]) + mm_673 = torch.ops.aten.mm.default(permute_1377, view_3); permute_1377 = view_3 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 256, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + permute_1379 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_674 = torch.ops.aten.mm.default(view_1861, permute_1379); view_1861 = permute_1379 = None + view_1862 = torch.ops.aten.view.default(mm_674, [2, 8192, 4096]); mm_674 = None + add_351 = torch.ops.aten.add.Tensor(add_350, view_1862); add_350 = view_1862 = None + convert_element_type_2795 = torch.ops.prims.convert_element_type.default(mm_673, torch.float32); mm_673 = None + reduce_scatter_tensor_288 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2795, 'avg', 256, '0'); convert_element_type_2795 = None + wait_tensor_579 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_288); reduce_scatter_tensor_288 = None + convert_element_type_2796 = torch.ops.prims.convert_element_type.default(add_351, torch.float32); add_351 = None + convert_element_type_2798 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + mul_898 = torch.ops.aten.mul.Tensor(convert_element_type_2796, convert_element_type_2798); convert_element_type_2798 = None + mul_900 = torch.ops.aten.mul.Tensor(mul, mul_898) + sum_193 = torch.ops.aten.sum.dim_IntList(mul_900, [2], True); mul_900 = None + div_64 = torch.ops.aten.div.Tensor(mul, 4096) + mul_901 = torch.ops.aten.mul.Tensor(div_64, sum_193); div_64 = sum_193 = None + sub_96 = torch.ops.aten.sub.Tensor(mul_898, mul_901); mul_898 = mul_901 = None + mul_902 = torch.ops.aten.mul.Tensor(sub_96, rsqrt); sub_96 = rsqrt = None + mul_903 = torch.ops.aten.mul.Tensor(convert_element_type_2796, mul); convert_element_type_2796 = mul = None + sum_194 = torch.ops.aten.sum.dim_IntList(mul_903, [0, 1]); mul_903 = None + convert_element_type_2799 = torch.ops.prims.convert_element_type.default(mul_902, torch.bfloat16); mul_902 = None + add_352 = torch.ops.aten.add.Tensor(add_349, convert_element_type_2799); add_349 = convert_element_type_2799 = None + convert_element_type_default_1 = torch.ops.prims.convert_element_type.default(sum_194, torch.float32); sum_194 = None + reduce_scatter_tensor_289 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_1, 'avg', 256, '0'); convert_element_type_default_1 = None + wait_tensor_580 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_289); reduce_scatter_tensor_289 = None + convert_element_type_2802 = torch.ops.prims.convert_element_type.default(add_352, torch.float32); add_352 = None + eq = torch.ops.aten.eq.Scalar(primals_2, -1) + unsqueeze_64 = torch.ops.aten.unsqueeze.default(eq, -1); eq = None + full_default = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_64, full_default, convert_element_type_2802); unsqueeze_64 = full_default = convert_element_type_2802 = None + full_default_1 = torch.ops.aten.full.default([128256, 4096], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put = torch.ops.aten.index_put.default(full_default_1, [primals_2], where, True); full_default_1 = primals_2 = where = None + convert_element_type_default = torch.ops.prims.convert_element_type.default(index_put, torch.float32); index_put = None + reduce_scatter_tensor_290 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default, 'avg', 256, '0'); convert_element_type_default = None + wait_tensor_581 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_290); reduce_scatter_tensor_290 = None + return (wait_tensor_581, None, None, wait_tensor_580, wait_tensor_579, wait_tensor_578, wait_tensor_577, wait_tensor_576, wait_tensor_575, wait_tensor_574, wait_tensor_573, wait_tensor_572, wait_tensor_571, wait_tensor_570, wait_tensor_569, wait_tensor_568, wait_tensor_567, wait_tensor_566, wait_tensor_565, wait_tensor_564, wait_tensor_563, wait_tensor_562, wait_tensor_561, wait_tensor_560, wait_tensor_559, wait_tensor_558, wait_tensor_557, wait_tensor_556, wait_tensor_555, wait_tensor_554, wait_tensor_553, wait_tensor_552, wait_tensor_551, wait_tensor_550, wait_tensor_549, wait_tensor_548, wait_tensor_547, wait_tensor_546, wait_tensor_545, wait_tensor_544, wait_tensor_543, wait_tensor_542, wait_tensor_541, wait_tensor_540, wait_tensor_539, wait_tensor_538, wait_tensor_537, wait_tensor_536, wait_tensor_535, wait_tensor_534, wait_tensor_533, wait_tensor_532, wait_tensor_531, wait_tensor_530, wait_tensor_529, wait_tensor_528, wait_tensor_527, wait_tensor_526, wait_tensor_525, wait_tensor_524, wait_tensor_523, wait_tensor_522, wait_tensor_521, wait_tensor_520, wait_tensor_519, wait_tensor_518, wait_tensor_517, wait_tensor_516, wait_tensor_515, wait_tensor_514, wait_tensor_513, wait_tensor_512, wait_tensor_511, wait_tensor_510, wait_tensor_509, wait_tensor_508, wait_tensor_507, wait_tensor_506, wait_tensor_505, wait_tensor_504, wait_tensor_503, wait_tensor_502, wait_tensor_501, wait_tensor_500, wait_tensor_499, wait_tensor_498, wait_tensor_497, wait_tensor_496, wait_tensor_495, wait_tensor_494, wait_tensor_493, wait_tensor_492, wait_tensor_491, wait_tensor_490, wait_tensor_489, wait_tensor_488, wait_tensor_487, wait_tensor_486, wait_tensor_485, wait_tensor_484, wait_tensor_483, wait_tensor_482, wait_tensor_481, wait_tensor_480, wait_tensor_479, wait_tensor_478, wait_tensor_477, wait_tensor_476, wait_tensor_475, wait_tensor_474, wait_tensor_473, wait_tensor_472, wait_tensor_471, wait_tensor_470, wait_tensor_469, wait_tensor_468, wait_tensor_467, wait_tensor_466, wait_tensor_465, wait_tensor_464, wait_tensor_463, wait_tensor_462, wait_tensor_461, wait_tensor_460, wait_tensor_459, wait_tensor_458, wait_tensor_457, wait_tensor_456, wait_tensor_455, wait_tensor_454, wait_tensor_453, wait_tensor_452, wait_tensor_451, wait_tensor_450, wait_tensor_449, wait_tensor_448, wait_tensor_447, wait_tensor_446, wait_tensor_445, wait_tensor_444, wait_tensor_443, wait_tensor_442, wait_tensor_441, wait_tensor_440, wait_tensor_439, wait_tensor_438, wait_tensor_437, wait_tensor_436, wait_tensor_435, wait_tensor_434, wait_tensor_433, wait_tensor_432, wait_tensor_431, wait_tensor_430, wait_tensor_429, wait_tensor_428, wait_tensor_427, wait_tensor_426, wait_tensor_425, wait_tensor_424, wait_tensor_423, wait_tensor_422, wait_tensor_421, wait_tensor_420, wait_tensor_419, wait_tensor_418, wait_tensor_417, wait_tensor_416, wait_tensor_415, wait_tensor_414, wait_tensor_413, wait_tensor_412, wait_tensor_411, wait_tensor_410, wait_tensor_409, wait_tensor_408, wait_tensor_407, wait_tensor_406, wait_tensor_405, wait_tensor_404, wait_tensor_403, wait_tensor_402, wait_tensor_401, wait_tensor_400, wait_tensor_399, wait_tensor_398, wait_tensor_397, wait_tensor_396, wait_tensor_395, wait_tensor_394, wait_tensor_393, wait_tensor_392, wait_tensor_391, wait_tensor_390, wait_tensor_389, wait_tensor_388, wait_tensor_387, wait_tensor_386, wait_tensor_385, wait_tensor_384, wait_tensor_383, wait_tensor_382, wait_tensor_381, wait_tensor_380, wait_tensor_379, wait_tensor_378, wait_tensor_377, wait_tensor_376, wait_tensor_375, wait_tensor_374, wait_tensor_373, wait_tensor_372, wait_tensor_371, wait_tensor_370, wait_tensor_369, wait_tensor_368, wait_tensor_367, wait_tensor_366, wait_tensor_365, wait_tensor_364, wait_tensor_363, wait_tensor_362, wait_tensor_361, wait_tensor_360, wait_tensor_359, wait_tensor_358, wait_tensor_357, wait_tensor_356, wait_tensor_355, wait_tensor_354, wait_tensor_353, wait_tensor_352, wait_tensor_351, wait_tensor_350, wait_tensor_349, wait_tensor_348, wait_tensor_347, wait_tensor_346, wait_tensor_345, wait_tensor_344, wait_tensor_343, wait_tensor_342, wait_tensor_341, wait_tensor_340, wait_tensor_339, wait_tensor_338, wait_tensor_337, wait_tensor_336, wait_tensor_335, wait_tensor_334, wait_tensor_333, wait_tensor_332, wait_tensor_331, wait_tensor_330, wait_tensor_329, wait_tensor_328, wait_tensor_327, wait_tensor_326, wait_tensor_325, wait_tensor_324, wait_tensor_323, wait_tensor_322, wait_tensor_321, wait_tensor_320, wait_tensor_319, wait_tensor_318, wait_tensor_317, wait_tensor_316, wait_tensor_315, wait_tensor_314, wait_tensor_313, wait_tensor_312, wait_tensor_311, wait_tensor_310, wait_tensor_309, wait_tensor_308, wait_tensor_307, wait_tensor_306, wait_tensor_305, wait_tensor_304, wait_tensor_303, wait_tensor_302, wait_tensor_301, wait_tensor_300, wait_tensor_299, wait_tensor_298, wait_tensor_297, wait_tensor_296, wait_tensor_295, wait_tensor_294, wait_tensor_293, wait_tensor_292, wait_tensor_291) + +def load_args(reader): + buf0 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf0, (501, 4096), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf3, (16,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf4, (16, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf5, (4, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf7, (16, 4096), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf8, (16,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf9, (56, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf10, (56, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf11, (16, 14336), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf12, (16,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf13, (16, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf14, (4, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf15, (4, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf16, (16, 4096), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf17, (16,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf18, (56, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf19, (56, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf20, (16, 14336), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf21, (16,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf23, (4, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf24, (4, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf25, (16, 4096), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf26, (16,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf27, (56, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf28, (56, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf29, (16, 14336), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf30, (16,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf31, (16, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf32, (4, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf33, (4, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf34, (16, 4096), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf35, (16,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf36, (56, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf37, (56, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf38, (16, 14336), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf39, (16,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf40, (16, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf41, (4, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf42, (4, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf43, (16, 4096), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf44, (16,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf45, (56, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf46, (56, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf47, (16, 14336), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf48, (16,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf49, (16, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf50, (4, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf51, (4, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf52, (16, 4096), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf53, (16,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf54, (56, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf55, (56, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf56, (16, 14336), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf57, (16,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf58, (16, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf59, (4, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf60, (4, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf61, (16, 4096), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf62, (16,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf63, (56, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf64, (56, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf65, (16, 14336), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf66, (16,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf67, (16, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf68, (4, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf69, (4, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf70, (16, 4096), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf71, (16,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf72, (56, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf73, (56, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf74, (16, 14336), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf75, (16,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf76, (16, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf77, (4, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf78, (4, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf79, (16, 4096), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf80, (16,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf81, (56, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf82, (56, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf83, (16, 14336), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf84, (16,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf85, (16, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf86, (4, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf87, (4, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf88, (16, 4096), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf89, (16,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf90, (56, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf91, (56, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf92, (16, 14336), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf93, (16,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf94, (16, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf95, (4, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf96, (4, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf97, (16, 4096), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf98, (16,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf99, (56, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf100, (56, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf101, (16, 14336), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf102, (16,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf103, (16, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf104, (4, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf105, (4, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf106, (16, 4096), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf107, (16,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf108, (56, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf109, (56, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf110, (16, 14336), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf111, (16,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf112, (16, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf113, (4, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf114, (4, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf115, (16, 4096), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf116, (16,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf117, (56, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf118, (56, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf119, (16, 14336), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf120, (16,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf121, (16, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf122, (4, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf123, (4, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf124, (16, 4096), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf125, (16,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf126, (56, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf127, (56, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf128, (16, 14336), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf129, (16,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf130, (16, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf131, (4, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf132, (4, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf133, (16, 4096), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf134, (16,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf135, (56, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf136, (56, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf137, (16, 14336), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf138, (16,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf140, (4, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf141, (4, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf142, (16, 4096), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf143, (16,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf144, (56, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf145, (56, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf146, (16, 14336), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf147, (16,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf148, (16, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf149, (4, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf150, (4, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf151, (16, 4096), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf152, (16,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf153, (56, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf154, (56, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf155, (16, 14336), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf156, (16,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf157, (16, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf158, (4, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf159, (4, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf160, (16, 4096), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf161, (16,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf162, (56, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf163, (56, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf164, (16, 14336), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf165, (16,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf166, (16, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf167, (4, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf168, (4, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf169, (16, 4096), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf170, (16,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf171, (56, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf172, (56, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf173, (16, 14336), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf174, (16,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf175, (16, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf176, (4, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf177, (4, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf178, (16, 4096), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf179, (16,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf180, (56, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf181, (56, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf182, (16, 14336), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf183, (16,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf184, (16, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf185, (4, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf186, (4, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf187, (16, 4096), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf188, (16,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf189, (56, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf190, (56, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf191, (16, 14336), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf192, (16,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf193, (16, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf194, (4, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf195, (4, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf196, (16, 4096), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf197, (16,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf198, (56, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf199, (56, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf200, (16, 14336), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf201, (16,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf202, (16, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf203, (4, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf204, (4, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf205, (16, 4096), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf206, (16,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf207, (56, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf208, (56, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf209, (16, 14336), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf210, (16,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf211, (16, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf212, (4, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf213, (4, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf214, (16, 4096), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf215, (16,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf216, (56, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf217, (56, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf218, (16, 14336), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf219, (16,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf220, (16, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf221, (4, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf222, (4, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf223, (16, 4096), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf224, (16,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf225, (56, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf226, (56, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf227, (16, 14336), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf228, (16,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf229, (16, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf230, (4, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf231, (4, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf232, (16, 4096), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf233, (16,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf234, (56, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf235, (56, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf236, (16, 14336), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf237, (16,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf238, (16, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf239, (4, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf240, (4, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf241, (16, 4096), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf242, (16,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf243, (56, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf244, (56, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf245, (16, 14336), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf246, (16,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf247, (16, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf248, (4, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf250, (16, 4096), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf251, (16,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf252, (56, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf253, (56, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf254, (16, 14336), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf255, (16,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf256, (16, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf257, (4, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf258, (4, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf259, (16, 4096), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf260, (16,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf261, (56, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf262, (56, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf263, (16, 14336), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf264, (16,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf265, (16, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf266, (4, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf267, (4, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf268, (16, 4096), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf269, (16,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf270, (56, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf271, (56, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf272, (16, 14336), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf273, (16,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf274, (16, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf275, (4, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf276, (4, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf277, (16, 4096), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf278, (16,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf279, (56, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf280, (56, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf281, (16, 14336), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf282, (16,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf283, (16, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf284, (4, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf285, (4, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf286, (16, 4096), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf287, (16,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf288, (56, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf289, (56, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf290, (16, 14336), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf291, (16,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf292, (501, 4096), is_leaf=True) # primals_293 + buf293 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf293, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # embedding + buf294 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf294, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm + buf295 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf295, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_2 + buf296 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf296, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem + buf297 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf297, (2, 32, 8192, 1), is_leaf=True) # getitem_1 + buf298 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf298, (), dtype=torch.int64, is_leaf=True) # getitem_6 + buf299 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf299, (), dtype=torch.int64, is_leaf=True) # getitem_7 + buf300 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf300, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf301 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf301, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_3 + buf302 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf302, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_7 + buf303 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf303, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_9 + buf304 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf304, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_9 + buf305 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf305, (2, 32, 8192, 1), is_leaf=True) # getitem_10 + buf306 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf306, (), dtype=torch.int64, is_leaf=True) # getitem_15 + buf307 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf307, (), dtype=torch.int64, is_leaf=True) # getitem_16 + buf308 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf308, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf309 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf309, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_7 + buf310 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf310, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_14 + buf311 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf311, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_16 + buf312 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf312, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_18 + buf313 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf313, (2, 32, 8192, 1), is_leaf=True) # getitem_19 + buf314 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf314, (), dtype=torch.int64, is_leaf=True) # getitem_24 + buf315 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf315, (), dtype=torch.int64, is_leaf=True) # getitem_25 + buf316 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf316, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_18 + buf317 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf317, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_11 + buf318 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf318, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf319 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf319, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_23 + buf320 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf320, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_27 + buf321 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf321, (2, 32, 8192, 1), is_leaf=True) # getitem_28 + buf322 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf322, (), dtype=torch.int64, is_leaf=True) # getitem_33 + buf323 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf323, (), dtype=torch.int64, is_leaf=True) # getitem_34 + buf324 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf324, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_25 + buf325 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf325, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_15 + buf326 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf326, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf327 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf327, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_30 + buf328 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf328, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_36 + buf329 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf329, (2, 32, 8192, 1), is_leaf=True) # getitem_37 + buf330 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf330, (), dtype=torch.int64, is_leaf=True) # getitem_42 + buf331 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf331, (), dtype=torch.int64, is_leaf=True) # getitem_43 + buf332 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf332, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_32 + buf333 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf333, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_19 + buf334 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf334, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf335 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf335, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf336 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf336, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_45 + buf337 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf337, (2, 32, 8192, 1), is_leaf=True) # getitem_46 + buf338 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf338, (), dtype=torch.int64, is_leaf=True) # getitem_51 + buf339 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf339, (), dtype=torch.int64, is_leaf=True) # getitem_52 + buf340 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf340, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_39 + buf341 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf341, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_23 + buf342 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf342, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_42 + buf343 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf343, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf344 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf344, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_54 + buf345 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf345, (2, 32, 8192, 1), is_leaf=True) # getitem_55 + buf346 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf346, (), dtype=torch.int64, is_leaf=True) # getitem_60 + buf347 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf347, (), dtype=torch.int64, is_leaf=True) # getitem_61 + buf348 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf348, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_46 + buf349 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf349, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_27 + buf350 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf350, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_49 + buf351 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf351, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf352 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf352, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_63 + buf353 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf353, (2, 32, 8192, 1), is_leaf=True) # getitem_64 + buf354 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf354, (), dtype=torch.int64, is_leaf=True) # getitem_69 + buf355 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf355, (), dtype=torch.int64, is_leaf=True) # getitem_70 + buf356 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf356, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf357 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf357, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_31 + buf358 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf358, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_56 + buf359 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf359, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_58 + buf360 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf360, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_72 + buf361 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf361, (2, 32, 8192, 1), is_leaf=True) # getitem_73 + buf362 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf362, (), dtype=torch.int64, is_leaf=True) # getitem_78 + buf363 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf363, (), dtype=torch.int64, is_leaf=True) # getitem_79 + buf364 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf364, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf365 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf365, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_35 + buf366 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf366, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_63 + buf367 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf367, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_65 + buf368 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf368, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_81 + buf369 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf369, (2, 32, 8192, 1), is_leaf=True) # getitem_82 + buf370 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf370, (), dtype=torch.int64, is_leaf=True) # getitem_87 + buf371 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf371, (), dtype=torch.int64, is_leaf=True) # getitem_88 + buf372 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf372, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf373 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf373, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_39 + buf374 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf374, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_70 + buf375 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf375, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_72 + buf376 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf376, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_90 + buf377 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf377, (2, 32, 8192, 1), is_leaf=True) # getitem_91 + buf378 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf378, (), dtype=torch.int64, is_leaf=True) # getitem_96 + buf379 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf379, (), dtype=torch.int64, is_leaf=True) # getitem_97 + buf380 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf380, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_74 + buf381 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf381, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_43 + buf382 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf382, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf383 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf383, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_79 + buf384 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf384, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_99 + buf385 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf385, (2, 32, 8192, 1), is_leaf=True) # getitem_100 + buf386 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf386, (), dtype=torch.int64, is_leaf=True) # getitem_105 + buf387 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf387, (), dtype=torch.int64, is_leaf=True) # getitem_106 + buf388 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf388, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_81 + buf389 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf389, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_47 + buf390 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf390, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf391 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf391, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_86 + buf392 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf392, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_108 + buf393 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf393, (2, 32, 8192, 1), is_leaf=True) # getitem_109 + buf394 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf394, (), dtype=torch.int64, is_leaf=True) # getitem_114 + buf395 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf395, (), dtype=torch.int64, is_leaf=True) # getitem_115 + buf396 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf396, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_88 + buf397 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf397, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_51 + buf398 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf398, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_91 + buf399 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf399, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf400 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf400, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_117 + buf401 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf401, (2, 32, 8192, 1), is_leaf=True) # getitem_118 + buf402 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf402, (), dtype=torch.int64, is_leaf=True) # getitem_123 + buf403 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf403, (), dtype=torch.int64, is_leaf=True) # getitem_124 + buf404 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf404, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_95 + buf405 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf405, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_55 + buf406 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf406, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_98 + buf407 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf407, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf408 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf408, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_126 + buf409 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf409, (2, 32, 8192, 1), is_leaf=True) # getitem_127 + buf410 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf410, (), dtype=torch.int64, is_leaf=True) # getitem_132 + buf411 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf411, (), dtype=torch.int64, is_leaf=True) # getitem_133 + buf412 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf412, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_102 + buf413 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf413, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_59 + buf414 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf414, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_105 + buf415 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf415, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf416 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf416, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_135 + buf417 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf417, (2, 32, 8192, 1), is_leaf=True) # getitem_136 + buf418 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf418, (), dtype=torch.int64, is_leaf=True) # getitem_141 + buf419 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf419, (), dtype=torch.int64, is_leaf=True) # getitem_142 + buf420 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf420, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf421 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf421, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_63 + buf422 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf422, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_112 + buf423 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf423, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_114 + buf424 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf424, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_144 + buf425 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf425, (2, 32, 8192, 1), is_leaf=True) # getitem_145 + buf426 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf426, (), dtype=torch.int64, is_leaf=True) # getitem_150 + buf427 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf427, (), dtype=torch.int64, is_leaf=True) # getitem_151 + buf428 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf428, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_116 + buf429 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf429, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_67 + buf430 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf430, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_119 + buf431 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf431, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_121 + buf432 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf432, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_153 + buf433 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf433, (2, 32, 8192, 1), is_leaf=True) # getitem_154 + buf434 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf434, (), dtype=torch.int64, is_leaf=True) # getitem_159 + buf435 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf435, (), dtype=torch.int64, is_leaf=True) # getitem_160 + buf436 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf436, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_123 + buf437 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf437, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_71 + buf438 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf438, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_126 + buf439 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf439, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_128 + buf440 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf440, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_162 + buf441 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf441, (2, 32, 8192, 1), is_leaf=True) # getitem_163 + buf442 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf442, (), dtype=torch.int64, is_leaf=True) # getitem_168 + buf443 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf443, (), dtype=torch.int64, is_leaf=True) # getitem_169 + buf444 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf444, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_130 + buf445 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf445, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_75 + buf446 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf446, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_133 + buf447 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf447, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_135 + buf448 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf448, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_171 + buf449 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf449, (2, 32, 8192, 1), is_leaf=True) # getitem_172 + buf450 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf450, (), dtype=torch.int64, is_leaf=True) # getitem_177 + buf451 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf451, (), dtype=torch.int64, is_leaf=True) # getitem_178 + buf452 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf452, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_137 + buf453 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf453, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_79 + buf454 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf454, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_140 + buf455 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf455, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_142 + buf456 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf456, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_180 + buf457 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf457, (2, 32, 8192, 1), is_leaf=True) # getitem_181 + buf458 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf458, (), dtype=torch.int64, is_leaf=True) # getitem_186 + buf459 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf459, (), dtype=torch.int64, is_leaf=True) # getitem_187 + buf460 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf460, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_144 + buf461 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf461, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_83 + buf462 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf462, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_147 + buf463 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf463, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_149 + buf464 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf464, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_189 + buf465 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf465, (2, 32, 8192, 1), is_leaf=True) # getitem_190 + buf466 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf466, (), dtype=torch.int64, is_leaf=True) # getitem_195 + buf467 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf467, (), dtype=torch.int64, is_leaf=True) # getitem_196 + buf468 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf468, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_151 + buf469 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf469, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_87 + buf470 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf470, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_154 + buf471 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf471, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_156 + buf472 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf472, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_198 + buf473 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf473, (2, 32, 8192, 1), is_leaf=True) # getitem_199 + buf474 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf474, (), dtype=torch.int64, is_leaf=True) # getitem_204 + buf475 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf475, (), dtype=torch.int64, is_leaf=True) # getitem_205 + buf476 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf476, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_158 + buf477 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf477, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_91 + buf478 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf478, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_161 + buf479 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf479, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_163 + buf480 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf480, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_207 + buf481 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf481, (2, 32, 8192, 1), is_leaf=True) # getitem_208 + buf482 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf482, (), dtype=torch.int64, is_leaf=True) # getitem_213 + buf483 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf483, (), dtype=torch.int64, is_leaf=True) # getitem_214 + buf484 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf484, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_165 + buf485 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf485, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_95 + buf486 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf486, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_168 + buf487 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf487, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_170 + buf488 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf488, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_216 + buf489 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf489, (2, 32, 8192, 1), is_leaf=True) # getitem_217 + buf490 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf490, (), dtype=torch.int64, is_leaf=True) # getitem_222 + buf491 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf491, (), dtype=torch.int64, is_leaf=True) # getitem_223 + buf492 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf492, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_172 + buf493 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf493, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_99 + buf494 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf494, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_175 + buf495 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf495, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_177 + buf496 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf496, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_225 + buf497 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf497, (2, 32, 8192, 1), is_leaf=True) # getitem_226 + buf498 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf498, (), dtype=torch.int64, is_leaf=True) # getitem_231 + buf499 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf499, (), dtype=torch.int64, is_leaf=True) # getitem_232 + buf500 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf500, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_179 + buf501 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf501, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_103 + buf502 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf502, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_182 + buf503 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf503, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_184 + buf504 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf504, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_234 + buf505 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf505, (2, 32, 8192, 1), is_leaf=True) # getitem_235 + buf506 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf506, (), dtype=torch.int64, is_leaf=True) # getitem_240 + buf507 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf507, (), dtype=torch.int64, is_leaf=True) # getitem_241 + buf508 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf508, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_186 + buf509 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf509, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_107 + buf510 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf510, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_189 + buf511 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf511, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_191 + buf512 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf512, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_243 + buf513 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf513, (2, 32, 8192, 1), is_leaf=True) # getitem_244 + buf514 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf514, (), dtype=torch.int64, is_leaf=True) # getitem_249 + buf515 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf515, (), dtype=torch.int64, is_leaf=True) # getitem_250 + buf516 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf516, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_193 + buf517 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf517, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_111 + buf518 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf518, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_196 + buf519 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf519, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_198 + buf520 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf520, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_252 + buf521 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf521, (2, 32, 8192, 1), is_leaf=True) # getitem_253 + buf522 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf522, (), dtype=torch.int64, is_leaf=True) # getitem_258 + buf523 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf523, (), dtype=torch.int64, is_leaf=True) # getitem_259 + buf524 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf524, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_200 + buf525 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf525, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_115 + buf526 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf526, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_203 + buf527 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf527, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_205 + buf528 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf528, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_261 + buf529 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf529, (2, 32, 8192, 1), is_leaf=True) # getitem_262 + buf530 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf530, (), dtype=torch.int64, is_leaf=True) # getitem_267 + buf531 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf531, (), dtype=torch.int64, is_leaf=True) # getitem_268 + buf532 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf532, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_207 + buf533 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf533, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_119 + buf534 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf534, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_210 + buf535 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf535, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_212 + buf536 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf536, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_270 + buf537 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf537, (2, 32, 8192, 1), is_leaf=True) # getitem_271 + buf538 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf538, (), dtype=torch.int64, is_leaf=True) # getitem_276 + buf539 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf539, (), dtype=torch.int64, is_leaf=True) # getitem_277 + buf540 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf540, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_214 + buf541 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf541, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_123 + buf542 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf542, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_217 + buf543 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf543, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_219 + buf544 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf544, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_279 + buf545 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf545, (2, 32, 8192, 1), is_leaf=True) # getitem_280 + buf546 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf546, (), dtype=torch.int64, is_leaf=True) # getitem_285 + buf547 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf547, (), dtype=torch.int64, is_leaf=True) # getitem_286 + buf548 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf548, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_221 + buf549 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf549, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_223 + buf550 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf550, (2, 8192, 1), is_leaf=True) # rsqrt_64 + buf551 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf551, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # view_1091 + buf552 = reader.storage(None, 4202692608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf552, (2, 8192, 128256), dtype=torch.bfloat16, is_leaf=True) # tangents_1 + +load_args._version = 0 + +def get_pg_config(): + return {'0': {'size': 256, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls32_8.table" + diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_2d.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_2d.py new file mode 100644 index 00000000..e1bba721 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_256_2d.py @@ -0,0 +1,11446 @@ +# fmt: off +# flake8: noqa +# isort: skip_file +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, wait_tensor_1, mm, mm_2, getitem_80, getitem_81, getitem_86, getitem_87, reduce_scatter_tensor_1, mm_4, add_3, mm_7, mm_9, getitem_121, getitem_122, getitem_127, getitem_128, reduce_scatter_tensor_3, mm_11, add_7, mm_14, mm_16, getitem_162, getitem_163, getitem_168, getitem_169, reduce_scatter_tensor_5, mm_18, add_11, mm_21, mm_23, getitem_203, getitem_204, getitem_209, getitem_210, reduce_scatter_tensor_7, mm_25, add_15, mm_28, mm_30, getitem_244, getitem_245, getitem_250, getitem_251, reduce_scatter_tensor_9, mm_32, add_19, mm_35, mm_37, getitem_285, getitem_286, getitem_291, getitem_292, reduce_scatter_tensor_11, mm_39, add_23, mm_42, mm_44, getitem_326, getitem_327, getitem_332, getitem_333, reduce_scatter_tensor_13, mm_46, add_27, mm_49, mm_51, getitem_367, getitem_368, getitem_373, getitem_374, reduce_scatter_tensor_15, mm_53, add_31, mm_56, mm_58, getitem_408, getitem_409, getitem_414, getitem_415, reduce_scatter_tensor_17, mm_60, add_35, mm_63, mm_65, getitem_449, getitem_450, getitem_455, getitem_456, reduce_scatter_tensor_19, mm_67, add_39, mm_70, mm_72, getitem_490, getitem_491, getitem_496, getitem_497, reduce_scatter_tensor_21, mm_74, add_43, mm_77, mm_79, getitem_531, getitem_532, getitem_537, getitem_538, reduce_scatter_tensor_23, mm_81, add_47, mm_84, mm_86, getitem_572, getitem_573, getitem_578, getitem_579, reduce_scatter_tensor_25, mm_88, add_51, mm_91, mm_93, getitem_613, getitem_614, getitem_619, getitem_620, reduce_scatter_tensor_27, mm_95, add_55, mm_98, mm_100, getitem_654, getitem_655, getitem_660, getitem_661, reduce_scatter_tensor_29, mm_102, add_59, mm_105, mm_107, getitem_695, getitem_696, getitem_701, getitem_702, reduce_scatter_tensor_31, mm_109, add_63, mm_112, mm_114, getitem_736, getitem_737, getitem_742, getitem_743, reduce_scatter_tensor_33, mm_116, add_67, mm_119, mm_121, getitem_777, getitem_778, getitem_783, getitem_784, reduce_scatter_tensor_35, mm_123, add_71, mm_126, mm_128, getitem_818, getitem_819, getitem_824, getitem_825, reduce_scatter_tensor_37, mm_130, add_75, mm_133, mm_135, getitem_859, getitem_860, getitem_865, getitem_866, reduce_scatter_tensor_39, mm_137, add_79, mm_140, mm_142, getitem_900, getitem_901, getitem_906, getitem_907, reduce_scatter_tensor_41, mm_144, add_83, mm_147, mm_149, getitem_941, getitem_942, getitem_947, getitem_948, reduce_scatter_tensor_43, mm_151, add_87, mm_154, mm_156, getitem_982, getitem_983, getitem_988, getitem_989, reduce_scatter_tensor_45, mm_158, add_91, mm_161, mm_163, getitem_1023, getitem_1024, getitem_1029, getitem_1030, reduce_scatter_tensor_47, mm_165, add_95, mm_168, mm_170, getitem_1064, getitem_1065, getitem_1070, getitem_1071, reduce_scatter_tensor_49, mm_172, add_99, mm_175, mm_177, getitem_1105, getitem_1106, getitem_1111, getitem_1112, reduce_scatter_tensor_51, mm_179, add_103, mm_182, mm_184, getitem_1146, getitem_1147, getitem_1152, getitem_1153, reduce_scatter_tensor_53, mm_186, add_107, mm_189, mm_191, getitem_1187, getitem_1188, getitem_1193, getitem_1194, reduce_scatter_tensor_55, mm_193, add_111, mm_196, mm_198, getitem_1228, getitem_1229, getitem_1234, getitem_1235, reduce_scatter_tensor_57, mm_200, add_115, mm_203, mm_205, getitem_1269, getitem_1270, getitem_1275, getitem_1276, reduce_scatter_tensor_59, mm_207, add_119, mm_210, mm_212, getitem_1310, getitem_1311, getitem_1316, getitem_1317, reduce_scatter_tensor_61, mm_214, add_123, mm_217, mm_219, getitem_1351, getitem_1352, getitem_1357, getitem_1358, reduce_scatter_tensor_63, mm_221, reduce_scatter_tensor_64, rsqrt_64, view_2319, tangents_1): + view_2321 = torch.ops.aten.view.default(tangents_1, [16384, 16032]); tangents_1 = None + permute_353 = torch.ops.aten.permute.default(view_2321, [1, 0]) + mm_225 = torch.ops.aten.mm.default(permute_353, view_2319); permute_353 = view_2319 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None + all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 32, '0'); convert_element_type_1060 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_420, [1, 0]); wait_tensor_420 = None + permute_355 = torch.ops.aten.permute.default(permute_352, [1, 0]); permute_352 = None + mm_226 = torch.ops.aten.mm.default(view_2321, permute_355); view_2321 = permute_355 = None + view_2322 = torch.ops.aten.view.default(mm_226, [2, 8192, 4096]); mm_226 = None + convert_element_type_1067 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1067, 'avg', 32, '0'); convert_element_type_1067 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); reduce_scatter_tensor_65 = None + split_138 = torch.ops.aten.split.Tensor(view_2322, 1024, 1); view_2322 = None + getitem_1392 = split_138[0] + getitem_1393 = split_138[1] + getitem_1394 = split_138[2] + getitem_1395 = split_138[3] + getitem_1396 = split_138[4] + getitem_1397 = split_138[5] + getitem_1398 = split_138[6] + getitem_1399 = split_138[7]; split_138 = None + cat_130 = torch.ops.aten.cat.default([getitem_1392, getitem_1393, getitem_1394, getitem_1395, getitem_1396, getitem_1397, getitem_1398, getitem_1399]); getitem_1392 = getitem_1393 = getitem_1394 = getitem_1395 = getitem_1396 = getitem_1397 = getitem_1398 = getitem_1399 = None + reduce_scatter_tensor_66 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_130, 'sum', 8, '1'); cat_130 = None + wait_tensor_422 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + convert_element_type_1068 = torch.ops.prims.convert_element_type.default(wait_tensor_422, torch.float32); wait_tensor_422 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None + all_gather_into_tensor_353 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 32, '0'); convert_element_type_1057 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + convert_element_type_1070 = torch.ops.prims.convert_element_type.default(wait_tensor_418, torch.float32); wait_tensor_418 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_1068, convert_element_type_1070); convert_element_type_1070 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + add_125 = torch.ops.aten.add.Tensor(add_123, wait_tensor_411); wait_tensor_411 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + add_127 = torch.ops.aten.add.Tensor(add_125, wait_tensor_417); wait_tensor_417 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_260 = torch.ops.aten.mul.Tensor(mul_256, mul_258) + sum_1 = torch.ops.aten.sum.dim_IntList(mul_260, [2], True); mul_260 = None + div = torch.ops.aten.div.Tensor(mul_256, 4096) + mul_261 = torch.ops.aten.mul.Tensor(div, sum_1); div = sum_1 = None + sub_1 = torch.ops.aten.sub.Tensor(mul_258, mul_261); mul_258 = mul_261 = None + mul_262 = torch.ops.aten.mul.Tensor(sub_1, rsqrt_64); sub_1 = rsqrt_64 = None + mul_263 = torch.ops.aten.mul.Tensor(convert_element_type_1068, mul_256); convert_element_type_1068 = mul_256 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_263, [0, 1]); mul_263 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(mul_262, torch.bfloat16); mul_262 = None + convert_element_type_1072 = torch.ops.prims.convert_element_type.default(sum_2, torch.bfloat16); sum_2 = None + all_reduce = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1072, 'sum', '1'); convert_element_type_1072 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(all_reduce); all_reduce = None + convert_element_type_1073 = torch.ops.prims.convert_element_type.default(wait_tensor_423, torch.float32); wait_tensor_423 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1073, 'avg', 32, '0'); convert_element_type_1073 = None + wait_tensor_424 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + all_gather_into_tensor_356 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1071, 8, '1') + wait_tensor_425 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_356); all_gather_into_tensor_356 = None + split_139 = torch.ops.aten.split.Tensor(wait_tensor_425, 2); wait_tensor_425 = None + getitem_1400 = split_139[0] + getitem_1401 = split_139[1] + getitem_1402 = split_139[2] + getitem_1403 = split_139[3] + getitem_1404 = split_139[4] + getitem_1405 = split_139[5] + getitem_1406 = split_139[6] + getitem_1407 = split_139[7]; split_139 = None + cat_131 = torch.ops.aten.cat.default([getitem_1400, getitem_1401, getitem_1402, getitem_1403, getitem_1404, getitem_1405, getitem_1406, getitem_1407], 1); getitem_1400 = getitem_1401 = getitem_1402 = getitem_1403 = getitem_1404 = getitem_1405 = getitem_1406 = getitem_1407 = None + view_2323 = torch.ops.aten.view.default(cat_131, [16384, 4096]); cat_131 = None + permute_357 = torch.ops.aten.permute.default(view_2323, [1, 0]) + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 32, '0'); convert_element_type_1043 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32); add_125 = None + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_412) + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1045, 8, '1'); convert_element_type_1045 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + split_135 = torch.ops.aten.split.Tensor(wait_tensor_413, 2); wait_tensor_413 = None + getitem_1368 = split_135[0] + getitem_1369 = split_135[1] + getitem_1370 = split_135[2] + getitem_1371 = split_135[3] + getitem_1372 = split_135[4] + getitem_1373 = split_135[5] + getitem_1374 = split_135[6] + getitem_1375 = split_135[7]; split_135 = None + cat_127 = torch.ops.aten.cat.default([getitem_1368, getitem_1369, getitem_1370, getitem_1371, getitem_1372, getitem_1373, getitem_1374, getitem_1375], 1); getitem_1368 = getitem_1369 = getitem_1370 = getitem_1371 = getitem_1372 = getitem_1373 = getitem_1374 = getitem_1375 = None + view_2292 = torch.ops.aten.view.default(cat_127, [16384, 4096]); cat_127 = None + view_2293 = torch.ops.aten.view.default(mm_221, [2, 8192, 1792]); mm_221 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_2293, torch.float32); view_2293 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16); primals_290 = None + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 32, '0'); convert_element_type_1051 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); all_gather_into_tensor_351 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); wait_tensor_415 = None + mm_222 = torch.ops.aten.mm.default(view_2292, permute_350) + view_2300 = torch.ops.aten.view.default(mm_222, [2, 8192, 1792]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_2300) + view_2307 = torch.ops.aten.view.default(mul_255, [16384, 1792]); mul_255 = None + mm_227 = torch.ops.aten.mm.default(permute_357, view_2307); permute_357 = view_2307 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16); primals_291 = None + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 32, '0'); convert_element_type_1054 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_416, [1, 0]); wait_tensor_416 = None + permute_359 = torch.ops.aten.permute.default(permute_351, [1, 0]); permute_351 = None + mm_228 = torch.ops.aten.mm.default(view_2323, permute_359); view_2323 = permute_359 = None + view_2324 = torch.ops.aten.view.default(mm_228, [2, 8192, 1792]); mm_228 = None + convert_element_type_1078 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1078, 'avg', 32, '0'); convert_element_type_1078 = None + wait_tensor_426 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + mul_264 = torch.ops.aten.mul.Tensor(view_2324, convert_element_type_1050); convert_element_type_1050 = None + mul_265 = torch.ops.aten.mul.Tensor(view_2324, view_2300); view_2324 = view_2300 = None + view_2325 = torch.ops.aten.view.default(mul_264, [16384, 1792]); mul_264 = None + permute_361 = torch.ops.aten.permute.default(view_2325, [1, 0]) + mm_229 = torch.ops.aten.mm.default(permute_361, view_2292); permute_361 = None + permute_363 = torch.ops.aten.permute.default(permute_350, [1, 0]); permute_350 = None + mm_230 = torch.ops.aten.mm.default(view_2325, permute_363); view_2325 = permute_363 = None + view_2326 = torch.ops.aten.view.default(mm_230, [2, 8192, 4096]); mm_230 = None + convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1083, 'avg', 32, '0'); convert_element_type_1083 = None + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + convert_element_type_1084 = torch.ops.prims.convert_element_type.default(mul_265, torch.float32); mul_265 = None + neg = torch.ops.aten.neg.default(convert_element_type_1049) + exp = torch.ops.aten.exp.default(neg); neg = None + add_129 = torch.ops.aten.add.Tensor(exp, 1); exp = None + reciprocal = torch.ops.aten.reciprocal.default(add_129); add_129 = None + mul_266 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_267 = torch.ops.aten.mul.Tensor(convert_element_type_1084, mul_266); convert_element_type_1084 = None + sub_2 = torch.ops.aten.sub.Tensor(1, mul_266); mul_266 = None + mul_268 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sub_2); convert_element_type_1049 = sub_2 = None + add_130 = torch.ops.aten.add.Tensor(mul_268, 1); mul_268 = None + mul_269 = torch.ops.aten.mul.Tensor(mul_267, add_130); mul_267 = add_130 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(mul_269, torch.bfloat16); mul_269 = None + view_2327 = torch.ops.aten.view.default(convert_element_type_1086, [16384, 1792]); convert_element_type_1086 = None + permute_365 = torch.ops.aten.permute.default(view_2327, [1, 0]) + mm_231 = torch.ops.aten.mm.default(permute_365, view_2292); permute_365 = view_2292 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16); primals_289 = None + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 32, '0'); convert_element_type_1046 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + permute_367 = torch.ops.aten.permute.default(permute_349, [1, 0]); permute_349 = None + mm_232 = torch.ops.aten.mm.default(view_2327, permute_367); view_2327 = permute_367 = None + view_2328 = torch.ops.aten.view.default(mm_232, [2, 8192, 4096]); mm_232 = None + add_131 = torch.ops.aten.add.Tensor(view_2326, view_2328); view_2326 = view_2328 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None + reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1091, 'avg', 32, '0'); convert_element_type_1091 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + split_140 = torch.ops.aten.split.Tensor(add_131, 1024, 1); add_131 = None + getitem_1408 = split_140[0] + getitem_1409 = split_140[1] + getitem_1410 = split_140[2] + getitem_1411 = split_140[3] + getitem_1412 = split_140[4] + getitem_1413 = split_140[5] + getitem_1414 = split_140[6] + getitem_1415 = split_140[7]; split_140 = None + cat_132 = torch.ops.aten.cat.default([getitem_1408, getitem_1409, getitem_1410, getitem_1411, getitem_1412, getitem_1413, getitem_1414, getitem_1415]); getitem_1408 = getitem_1409 = getitem_1410 = getitem_1411 = getitem_1412 = getitem_1413 = getitem_1414 = getitem_1415 = None + reduce_scatter_tensor_71 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_132, 'sum', 8, '1'); cat_132 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + convert_element_type_1092 = torch.ops.prims.convert_element_type.default(wait_tensor_429, torch.float32); wait_tensor_429 = None + convert_element_type_1094 = torch.ops.prims.convert_element_type.default(wait_tensor_412, torch.float32); wait_tensor_412 = None + mul_270 = torch.ops.aten.mul.Tensor(convert_element_type_1092, convert_element_type_1094); convert_element_type_1094 = None + mul_272 = torch.ops.aten.mul.Tensor(mul_252, mul_270) + sum_3 = torch.ops.aten.sum.dim_IntList(mul_272, [2], True); mul_272 = None + div_1 = torch.ops.aten.div.Tensor(mul_252, 4096) + mul_273 = torch.ops.aten.mul.Tensor(div_1, sum_3); div_1 = sum_3 = None + sub_3 = torch.ops.aten.sub.Tensor(mul_270, mul_273); mul_270 = mul_273 = None + mul_274 = torch.ops.aten.mul.Tensor(sub_3, rsqrt_63); sub_3 = rsqrt_63 = None + mul_275 = torch.ops.aten.mul.Tensor(convert_element_type_1092, mul_252); convert_element_type_1092 = mul_252 = None + sum_4 = torch.ops.aten.sum.dim_IntList(mul_275, [0, 1]); mul_275 = None + convert_element_type_1095 = torch.ops.prims.convert_element_type.default(mul_274, torch.bfloat16); mul_274 = None + convert_element_type_1096 = torch.ops.prims.convert_element_type.default(sum_4, torch.bfloat16); sum_4 = None + all_reduce_1 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1096, 'sum', '1'); convert_element_type_1096 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_1); all_reduce_1 = None + convert_element_type_1097 = torch.ops.prims.convert_element_type.default(wait_tensor_430, torch.float32); wait_tensor_430 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1097, 'avg', 32, '0'); convert_element_type_1097 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + add_132 = torch.ops.aten.add.Tensor(convert_element_type_1071, convert_element_type_1095); convert_element_type_1071 = convert_element_type_1095 = None + all_gather_into_tensor_357 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_132, 8, '1') + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_357); all_gather_into_tensor_357 = None + split_141 = torch.ops.aten.split.Tensor(wait_tensor_432, 2); wait_tensor_432 = None + getitem_1416 = split_141[0] + getitem_1417 = split_141[1] + getitem_1418 = split_141[2] + getitem_1419 = split_141[3] + getitem_1420 = split_141[4] + getitem_1421 = split_141[5] + getitem_1422 = split_141[6] + getitem_1423 = split_141[7]; split_141 = None + cat_133 = torch.ops.aten.cat.default([getitem_1416, getitem_1417, getitem_1418, getitem_1419, getitem_1420, getitem_1421, getitem_1422, getitem_1423], 1); getitem_1416 = getitem_1417 = getitem_1418 = getitem_1419 = getitem_1420 = getitem_1421 = getitem_1422 = getitem_1423 = None + view_2329 = torch.ops.aten.view.default(cat_133, [16384, 4096]); cat_133 = None + permute_369 = torch.ops.aten.permute.default(view_2329, [1, 0]) + permute_347 = torch.ops.aten.permute.default(getitem_1351, [0, 2, 1, 3]) + view_2274 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + view_2280 = torch.ops.aten.view.default(view_2274, [16384, 512]); view_2274 = None + mm_233 = torch.ops.aten.mm.default(permute_369, view_2280); permute_369 = view_2280 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None + all_gather_into_tensor_347 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 32, '0'); convert_element_type_1040 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_347); all_gather_into_tensor_347 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_410, [1, 0]); wait_tensor_410 = None + permute_371 = torch.ops.aten.permute.default(permute_348, [1, 0]); permute_348 = None + mm_234 = torch.ops.aten.mm.default(view_2329, permute_371); view_2329 = permute_371 = None + view_2330 = torch.ops.aten.view.default(mm_234, [2, 8192, 512]); mm_234 = None + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); mm_233 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1102, 'avg', 32, '0'); convert_element_type_1102 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + view_2331 = torch.ops.aten.view.default(view_2330, [2, 8192, 4, 128]); view_2330 = None + permute_373 = torch.ops.aten.permute.default(view_2331, [0, 2, 1, 3]); view_2331 = None + view_37 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]); primals_3 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None + all_gather_into_tensor_342 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 32, '0'); convert_element_type_1024 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_342); all_gather_into_tensor_342 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32); add_123 = None + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_405) + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + all_gather_into_tensor_343 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1026, 8, '1'); convert_element_type_1026 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_343); all_gather_into_tensor_343 = None + split_133 = torch.ops.aten.split.Tensor(wait_tensor_406, 2); wait_tensor_406 = None + getitem_1343 = split_133[0] + getitem_1344 = split_133[1] + getitem_1345 = split_133[2] + getitem_1346 = split_133[3] + getitem_1347 = split_133[4] + getitem_1348 = split_133[5] + getitem_1349 = split_133[6] + getitem_1350 = split_133[7]; split_133 = None + cat_125 = torch.ops.aten.cat.default([getitem_1343, getitem_1344, getitem_1345, getitem_1346, getitem_1347, getitem_1348, getitem_1349, getitem_1350], 1); getitem_1343 = getitem_1344 = getitem_1345 = getitem_1346 = getitem_1347 = getitem_1348 = getitem_1349 = getitem_1350 = None + view_2247 = torch.ops.aten.view.default(cat_125, [16384, 4096]); cat_125 = None + view_2248 = torch.ops.aten.view.default(mm_217, [2, 8192, 512]); mm_217 = None + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None + all_gather_into_tensor_345 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 32, '0'); convert_element_type_1030 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_345); all_gather_into_tensor_345 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + mm_218 = torch.ops.aten.mm.default(view_2247, permute_342) + view_2255 = torch.ops.aten.view.default(mm_218, [2, 8192, 128]); mm_218 = None + view_2262 = torch.ops.aten.view.default(mm_219, [2, 8192, 128]); mm_219 = None + view_2264 = torch.ops.aten.view.default(view_2248, [2, 8192, -1, 128]); view_2248 = None + view_2265 = torch.ops.aten.view.default(view_2255, [2, 8192, -1, 128]); view_2255 = None + view_2266 = torch.ops.aten.view.default(view_2262, [2, 8192, -1, 128]); view_2262 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_2264, torch.float32); view_2264 = None + view_2267 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 4, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_2267); view_2267 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_2265, torch.float32); view_2265 = None + view_2268 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 1, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_2268); view_2268 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_37); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_2270 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 4, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_37); view_as_complex_63 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_2271 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 1, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_2270, torch.bfloat16); view_2270 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_2271, torch.bfloat16); view_2271 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 1, 4, 128]); unsqueeze_62 = None + view_2272 = torch.ops.aten.view.default(expand_62, [2, 8192, 4, 128]); expand_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_2266, 3); view_2266 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 1, 4, 128]); unsqueeze_63 = None + view_2273 = torch.ops.aten.view.default(expand_63, [2, 8192, 4, 128]); expand_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_2272, [0, 2, 1, 3]); view_2272 = None + permute_346 = torch.ops.aten.permute.default(view_2273, [0, 2, 1, 3]); view_2273 = None + _scaled_dot_product_cudnn_attention_backward = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_373, permute_344, permute_345, permute_346, getitem_1351, getitem_1352, getitem_1357, getitem_1358, None, None, None, 8192, 8192, 0.0, True); permute_373 = permute_344 = permute_345 = permute_346 = getitem_1351 = getitem_1352 = getitem_1357 = getitem_1358 = None + getitem_1424 = _scaled_dot_product_cudnn_attention_backward[0] + getitem_1425 = _scaled_dot_product_cudnn_attention_backward[1] + getitem_1426 = _scaled_dot_product_cudnn_attention_backward[2]; _scaled_dot_product_cudnn_attention_backward = None + permute_374 = torch.ops.aten.permute.default(getitem_1426, [0, 2, 1, 3]); getitem_1426 = None + permute_375 = torch.ops.aten.permute.default(getitem_1425, [0, 2, 1, 3]); getitem_1425 = None + permute_376 = torch.ops.aten.permute.default(getitem_1424, [0, 2, 1, 3]); getitem_1424 = None + view_2332 = torch.ops.aten.view.default(permute_374, [2, 8192, 1, 4, 128]); permute_374 = None + sum_5 = torch.ops.aten.sum.dim_IntList(view_2332, [3], True); view_2332 = None + squeeze = torch.ops.aten.squeeze.dim(sum_5, 3); sum_5 = None + view_2333 = torch.ops.aten.view.default(permute_375, [2, 8192, 1, 4, 128]); permute_375 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_2333, [3], True); view_2333 = None + squeeze_1 = torch.ops.aten.squeeze.dim(sum_6, 3); sum_6 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(squeeze_1, torch.float32); squeeze_1 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(permute_376, torch.float32); permute_376 = None + view_2334 = torch.ops.aten.view.default(convert_element_type_1103, [2, 8192, 1, 64, 2]); convert_element_type_1103 = None + view_as_complex_64 = torch.ops.aten.view_as_complex.default(view_2334); view_2334 = None + _conj = torch.ops.aten._conj.default(view_37) + mul_276 = torch.ops.aten.mul.Tensor(view_as_complex_64, _conj); view_as_complex_64 = None + view_2335 = torch.ops.aten.view.default(convert_element_type_1104, [2, 8192, 4, 64, 2]); convert_element_type_1104 = None + view_as_complex_65 = torch.ops.aten.view_as_complex.default(view_2335); view_2335 = None + mul_277 = torch.ops.aten.mul.Tensor(view_as_complex_65, _conj); view_as_complex_65 = None + view_as_real_64 = torch.ops.aten.view_as_real.default(mul_276); mul_276 = None + view_2336 = torch.ops.aten.view.default(view_as_real_64, [2, 8192, 1, 128]); view_as_real_64 = None + convert_element_type_1105 = torch.ops.prims.convert_element_type.default(view_2336, torch.bfloat16); view_2336 = None + view_as_real_65 = torch.ops.aten.view_as_real.default(mul_277); mul_277 = None + view_2337 = torch.ops.aten.view.default(view_as_real_65, [2, 8192, 4, 128]); view_as_real_65 = None + convert_element_type_1106 = torch.ops.prims.convert_element_type.default(view_2337, torch.bfloat16); view_2337 = None + view_2338 = torch.ops.aten.view.default(squeeze, [2, 8192, 128]); squeeze = None + view_2339 = torch.ops.aten.view.default(convert_element_type_1105, [2, 8192, 128]); convert_element_type_1105 = None + view_2340 = torch.ops.aten.view.default(convert_element_type_1106, [2, 8192, 512]); convert_element_type_1106 = None + view_2341 = torch.ops.aten.view.default(view_2338, [16384, 128]); view_2338 = None + permute_377 = torch.ops.aten.permute.default(view_2341, [1, 0]) + mm_235 = torch.ops.aten.mm.default(permute_377, view_2247); permute_377 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None + all_gather_into_tensor_346 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 32, '0'); convert_element_type_1033 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_346); all_gather_into_tensor_346 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + permute_379 = torch.ops.aten.permute.default(permute_343, [1, 0]); permute_343 = None + mm_236 = torch.ops.aten.mm.default(view_2341, permute_379); view_2341 = permute_379 = None + view_2342 = torch.ops.aten.view.default(mm_236, [2, 8192, 4096]); mm_236 = None + convert_element_type_1111 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1111, 'avg', 32, '0'); convert_element_type_1111 = None + wait_tensor_434 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + view_2343 = torch.ops.aten.view.default(view_2339, [16384, 128]); view_2339 = None + permute_381 = torch.ops.aten.permute.default(view_2343, [1, 0]) + mm_237 = torch.ops.aten.mm.default(permute_381, view_2247); permute_381 = None + permute_383 = torch.ops.aten.permute.default(permute_342, [1, 0]); permute_342 = None + mm_238 = torch.ops.aten.mm.default(view_2343, permute_383); view_2343 = permute_383 = None + view_2344 = torch.ops.aten.view.default(mm_238, [2, 8192, 4096]); mm_238 = None + add_133 = torch.ops.aten.add.Tensor(view_2342, view_2344); view_2342 = view_2344 = None + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(mm_237, torch.float32); mm_237 = None + reduce_scatter_tensor_75 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1116, 'avg', 32, '0'); convert_element_type_1116 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + view_2345 = torch.ops.aten.view.default(view_2340, [16384, 512]); view_2340 = None + permute_385 = torch.ops.aten.permute.default(view_2345, [1, 0]) + mm_239 = torch.ops.aten.mm.default(permute_385, view_2247); permute_385 = view_2247 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None + all_gather_into_tensor_344 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 32, '0'); convert_element_type_1027 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_344); all_gather_into_tensor_344 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + permute_387 = torch.ops.aten.permute.default(permute_341, [1, 0]); permute_341 = None + mm_240 = torch.ops.aten.mm.default(view_2345, permute_387); view_2345 = permute_387 = None + view_2346 = torch.ops.aten.view.default(mm_240, [2, 8192, 4096]); mm_240 = None + add_134 = torch.ops.aten.add.Tensor(add_133, view_2346); add_133 = view_2346 = None + convert_element_type_1121 = torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1121, 'avg', 32, '0'); convert_element_type_1121 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + split_142 = torch.ops.aten.split.Tensor(add_134, 1024, 1); add_134 = None + getitem_1427 = split_142[0] + getitem_1428 = split_142[1] + getitem_1429 = split_142[2] + getitem_1430 = split_142[3] + getitem_1431 = split_142[4] + getitem_1432 = split_142[5] + getitem_1433 = split_142[6] + getitem_1434 = split_142[7]; split_142 = None + cat_134 = torch.ops.aten.cat.default([getitem_1427, getitem_1428, getitem_1429, getitem_1430, getitem_1431, getitem_1432, getitem_1433, getitem_1434]); getitem_1427 = getitem_1428 = getitem_1429 = getitem_1430 = getitem_1431 = getitem_1432 = getitem_1433 = getitem_1434 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_134, 'sum', 8, '1'); cat_134 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + convert_element_type_1122 = torch.ops.prims.convert_element_type.default(wait_tensor_437, torch.float32); wait_tensor_437 = None + convert_element_type_1124 = torch.ops.prims.convert_element_type.default(wait_tensor_405, torch.float32); wait_tensor_405 = None + mul_278 = torch.ops.aten.mul.Tensor(convert_element_type_1122, convert_element_type_1124); convert_element_type_1124 = None + mul_280 = torch.ops.aten.mul.Tensor(mul_248, mul_278) + sum_7 = torch.ops.aten.sum.dim_IntList(mul_280, [2], True); mul_280 = None + div_2 = torch.ops.aten.div.Tensor(mul_248, 4096) + mul_281 = torch.ops.aten.mul.Tensor(div_2, sum_7); div_2 = sum_7 = None + sub_4 = torch.ops.aten.sub.Tensor(mul_278, mul_281); mul_278 = mul_281 = None + mul_282 = torch.ops.aten.mul.Tensor(sub_4, rsqrt_62); sub_4 = rsqrt_62 = None + mul_283 = torch.ops.aten.mul.Tensor(convert_element_type_1122, mul_248); convert_element_type_1122 = mul_248 = None + sum_8 = torch.ops.aten.sum.dim_IntList(mul_283, [0, 1]); mul_283 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(mul_282, torch.bfloat16); mul_282 = None + convert_element_type_1126 = torch.ops.prims.convert_element_type.default(sum_8, torch.bfloat16); sum_8 = None + all_reduce_2 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1126, 'sum', '1'); convert_element_type_1126 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_2); all_reduce_2 = None + convert_element_type_1127 = torch.ops.prims.convert_element_type.default(wait_tensor_438, torch.float32); wait_tensor_438 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1127, 'avg', 32, '0'); convert_element_type_1127 = None + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + add_135 = torch.ops.aten.add.Tensor(add_132, convert_element_type_1125); add_132 = convert_element_type_1125 = None + all_gather_into_tensor_358 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_135, 8, '1') + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_358); all_gather_into_tensor_358 = None + split_143 = torch.ops.aten.split.Tensor(wait_tensor_440, 2); wait_tensor_440 = None + getitem_1435 = split_143[0] + getitem_1436 = split_143[1] + getitem_1437 = split_143[2] + getitem_1438 = split_143[3] + getitem_1439 = split_143[4] + getitem_1440 = split_143[5] + getitem_1441 = split_143[6] + getitem_1442 = split_143[7]; split_143 = None + cat_135 = torch.ops.aten.cat.default([getitem_1435, getitem_1436, getitem_1437, getitem_1438, getitem_1439, getitem_1440, getitem_1441, getitem_1442], 1); getitem_1435 = getitem_1436 = getitem_1437 = getitem_1438 = getitem_1439 = getitem_1440 = getitem_1441 = getitem_1442 = None + view_2347 = torch.ops.aten.view.default(cat_135, [16384, 4096]); cat_135 = None + permute_389 = torch.ops.aten.permute.default(view_2347, [1, 0]) + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + add_121 = torch.ops.aten.add.Tensor(add_119, wait_tensor_398); wait_tensor_398 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16); primals_279 = None + all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 32, '0'); convert_element_type_1010 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32); add_121 = None + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_399) + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 8, '1'); convert_element_type_1012 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + split_131 = torch.ops.aten.split.Tensor(wait_tensor_400, 2); wait_tensor_400 = None + getitem_1327 = split_131[0] + getitem_1328 = split_131[1] + getitem_1329 = split_131[2] + getitem_1330 = split_131[3] + getitem_1331 = split_131[4] + getitem_1332 = split_131[5] + getitem_1333 = split_131[6] + getitem_1334 = split_131[7]; split_131 = None + cat_123 = torch.ops.aten.cat.default([getitem_1327, getitem_1328, getitem_1329, getitem_1330, getitem_1331, getitem_1332, getitem_1333, getitem_1334], 1); getitem_1327 = getitem_1328 = getitem_1329 = getitem_1330 = getitem_1331 = getitem_1332 = getitem_1333 = getitem_1334 = None + view_2220 = torch.ops.aten.view.default(cat_123, [16384, 4096]); cat_123 = None + view_2221 = torch.ops.aten.view.default(mm_214, [2, 8192, 1792]); mm_214 = None + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_2221, torch.float32); view_2221 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 32, '0'); convert_element_type_1018 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_402, [1, 0]); wait_tensor_402 = None + mm_215 = torch.ops.aten.mm.default(view_2220, permute_339) + view_2228 = torch.ops.aten.view.default(mm_215, [2, 8192, 1792]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_2228) + view_2235 = torch.ops.aten.view.default(mul_247, [16384, 1792]); mul_247 = None + mm_241 = torch.ops.aten.mm.default(permute_389, view_2235); permute_389 = view_2235 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 32, '0'); convert_element_type_1021 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_403, [1, 0]); wait_tensor_403 = None + permute_391 = torch.ops.aten.permute.default(permute_340, [1, 0]); permute_340 = None + mm_242 = torch.ops.aten.mm.default(view_2347, permute_391); view_2347 = permute_391 = None + view_2348 = torch.ops.aten.view.default(mm_242, [2, 8192, 1792]); mm_242 = None + convert_element_type_1132 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1132, 'avg', 32, '0'); convert_element_type_1132 = None + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + mul_284 = torch.ops.aten.mul.Tensor(view_2348, convert_element_type_1017); convert_element_type_1017 = None + mul_285 = torch.ops.aten.mul.Tensor(view_2348, view_2228); view_2348 = view_2228 = None + view_2349 = torch.ops.aten.view.default(mul_284, [16384, 1792]); mul_284 = None + permute_393 = torch.ops.aten.permute.default(view_2349, [1, 0]) + mm_243 = torch.ops.aten.mm.default(permute_393, view_2220); permute_393 = None + permute_395 = torch.ops.aten.permute.default(permute_339, [1, 0]); permute_339 = None + mm_244 = torch.ops.aten.mm.default(view_2349, permute_395); view_2349 = permute_395 = None + view_2350 = torch.ops.aten.view.default(mm_244, [2, 8192, 4096]); mm_244 = None + convert_element_type_1137 = torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1137, 'avg', 32, '0'); convert_element_type_1137 = None + wait_tensor_442 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + convert_element_type_1138 = torch.ops.prims.convert_element_type.default(mul_285, torch.float32); mul_285 = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_1016) + exp_1 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_136 = torch.ops.aten.add.Tensor(exp_1, 1); exp_1 = None + reciprocal_1 = torch.ops.aten.reciprocal.default(add_136); add_136 = None + mul_286 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None + mul_287 = torch.ops.aten.mul.Tensor(convert_element_type_1138, mul_286); convert_element_type_1138 = None + sub_5 = torch.ops.aten.sub.Tensor(1, mul_286); mul_286 = None + mul_288 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sub_5); convert_element_type_1016 = sub_5 = None + add_137 = torch.ops.aten.add.Tensor(mul_288, 1); mul_288 = None + mul_289 = torch.ops.aten.mul.Tensor(mul_287, add_137); mul_287 = add_137 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(mul_289, torch.bfloat16); mul_289 = None + view_2351 = torch.ops.aten.view.default(convert_element_type_1140, [16384, 1792]); convert_element_type_1140 = None + permute_397 = torch.ops.aten.permute.default(view_2351, [1, 0]) + mm_245 = torch.ops.aten.mm.default(permute_397, view_2220); permute_397 = view_2220 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 32, '0'); convert_element_type_1013 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_401, [1, 0]); wait_tensor_401 = None + permute_399 = torch.ops.aten.permute.default(permute_338, [1, 0]); permute_338 = None + mm_246 = torch.ops.aten.mm.default(view_2351, permute_399); view_2351 = permute_399 = None + view_2352 = torch.ops.aten.view.default(mm_246, [2, 8192, 4096]); mm_246 = None + add_138 = torch.ops.aten.add.Tensor(view_2350, view_2352); view_2350 = view_2352 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None + reduce_scatter_tensor_81 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1145, 'avg', 32, '0'); convert_element_type_1145 = None + wait_tensor_443 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + split_144 = torch.ops.aten.split.Tensor(add_138, 1024, 1); add_138 = None + getitem_1443 = split_144[0] + getitem_1444 = split_144[1] + getitem_1445 = split_144[2] + getitem_1446 = split_144[3] + getitem_1447 = split_144[4] + getitem_1448 = split_144[5] + getitem_1449 = split_144[6] + getitem_1450 = split_144[7]; split_144 = None + cat_136 = torch.ops.aten.cat.default([getitem_1443, getitem_1444, getitem_1445, getitem_1446, getitem_1447, getitem_1448, getitem_1449, getitem_1450]); getitem_1443 = getitem_1444 = getitem_1445 = getitem_1446 = getitem_1447 = getitem_1448 = getitem_1449 = getitem_1450 = None + reduce_scatter_tensor_82 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_136, 'sum', 8, '1'); cat_136 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + convert_element_type_1146 = torch.ops.prims.convert_element_type.default(wait_tensor_444, torch.float32); wait_tensor_444 = None + convert_element_type_1148 = torch.ops.prims.convert_element_type.default(wait_tensor_399, torch.float32); wait_tensor_399 = None + mul_290 = torch.ops.aten.mul.Tensor(convert_element_type_1146, convert_element_type_1148); convert_element_type_1148 = None + mul_292 = torch.ops.aten.mul.Tensor(mul_244, mul_290) + sum_9 = torch.ops.aten.sum.dim_IntList(mul_292, [2], True); mul_292 = None + div_3 = torch.ops.aten.div.Tensor(mul_244, 4096) + mul_293 = torch.ops.aten.mul.Tensor(div_3, sum_9); div_3 = sum_9 = None + sub_6 = torch.ops.aten.sub.Tensor(mul_290, mul_293); mul_290 = mul_293 = None + mul_294 = torch.ops.aten.mul.Tensor(sub_6, rsqrt_61); sub_6 = rsqrt_61 = None + mul_295 = torch.ops.aten.mul.Tensor(convert_element_type_1146, mul_244); convert_element_type_1146 = mul_244 = None + sum_10 = torch.ops.aten.sum.dim_IntList(mul_295, [0, 1]); mul_295 = None + convert_element_type_1149 = torch.ops.prims.convert_element_type.default(mul_294, torch.bfloat16); mul_294 = None + convert_element_type_1150 = torch.ops.prims.convert_element_type.default(sum_10, torch.bfloat16); sum_10 = None + all_reduce_3 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1150, 'sum', '1'); convert_element_type_1150 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_3); all_reduce_3 = None + convert_element_type_1151 = torch.ops.prims.convert_element_type.default(wait_tensor_445, torch.float32); wait_tensor_445 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1151, 'avg', 32, '0'); convert_element_type_1151 = None + wait_tensor_446 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + add_139 = torch.ops.aten.add.Tensor(add_135, convert_element_type_1149); add_135 = convert_element_type_1149 = None + all_gather_into_tensor_359 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_139, 8, '1') + wait_tensor_447 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_359); all_gather_into_tensor_359 = None + split_145 = torch.ops.aten.split.Tensor(wait_tensor_447, 2); wait_tensor_447 = None + getitem_1451 = split_145[0] + getitem_1452 = split_145[1] + getitem_1453 = split_145[2] + getitem_1454 = split_145[3] + getitem_1455 = split_145[4] + getitem_1456 = split_145[5] + getitem_1457 = split_145[6] + getitem_1458 = split_145[7]; split_145 = None + cat_137 = torch.ops.aten.cat.default([getitem_1451, getitem_1452, getitem_1453, getitem_1454, getitem_1455, getitem_1456, getitem_1457, getitem_1458], 1); getitem_1451 = getitem_1452 = getitem_1453 = getitem_1454 = getitem_1455 = getitem_1456 = getitem_1457 = getitem_1458 = None + view_2353 = torch.ops.aten.view.default(cat_137, [16384, 4096]); cat_137 = None + permute_401 = torch.ops.aten.permute.default(view_2353, [1, 0]) + permute_336 = torch.ops.aten.permute.default(getitem_1310, [0, 2, 1, 3]) + view_2202 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + view_2208 = torch.ops.aten.view.default(view_2202, [16384, 512]); view_2202 = None + mm_247 = torch.ops.aten.mm.default(permute_401, view_2208); permute_401 = view_2208 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16); primals_278 = None + all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 32, '0'); convert_element_type_1007 = None + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_397, [1, 0]); wait_tensor_397 = None + permute_403 = torch.ops.aten.permute.default(permute_337, [1, 0]); permute_337 = None + mm_248 = torch.ops.aten.mm.default(view_2353, permute_403); view_2353 = permute_403 = None + view_2354 = torch.ops.aten.view.default(mm_248, [2, 8192, 512]); mm_248 = None + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None + reduce_scatter_tensor_84 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1156, 'avg', 32, '0'); convert_element_type_1156 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + view_2355 = torch.ops.aten.view.default(view_2354, [2, 8192, 4, 128]); view_2354 = None + permute_405 = torch.ops.aten.permute.default(view_2355, [0, 2, 1, 3]); view_2355 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16); primals_274 = None + all_gather_into_tensor_331 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 32, '0'); convert_element_type_991 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32); add_119 = None + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_392) + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_993, 8, '1'); convert_element_type_993 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + split_129 = torch.ops.aten.split.Tensor(wait_tensor_393, 2); wait_tensor_393 = None + getitem_1302 = split_129[0] + getitem_1303 = split_129[1] + getitem_1304 = split_129[2] + getitem_1305 = split_129[3] + getitem_1306 = split_129[4] + getitem_1307 = split_129[5] + getitem_1308 = split_129[6] + getitem_1309 = split_129[7]; split_129 = None + cat_121 = torch.ops.aten.cat.default([getitem_1302, getitem_1303, getitem_1304, getitem_1305, getitem_1306, getitem_1307, getitem_1308, getitem_1309], 1); getitem_1302 = getitem_1303 = getitem_1304 = getitem_1305 = getitem_1306 = getitem_1307 = getitem_1308 = getitem_1309 = None + view_2175 = torch.ops.aten.view.default(cat_121, [16384, 4096]); cat_121 = None + view_2176 = torch.ops.aten.view.default(mm_210, [2, 8192, 512]); mm_210 = None + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16); primals_276 = None + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 32, '0'); convert_element_type_997 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_395, [1, 0]); wait_tensor_395 = None + mm_211 = torch.ops.aten.mm.default(view_2175, permute_331) + view_2183 = torch.ops.aten.view.default(mm_211, [2, 8192, 128]); mm_211 = None + view_2190 = torch.ops.aten.view.default(mm_212, [2, 8192, 128]); mm_212 = None + view_2192 = torch.ops.aten.view.default(view_2176, [2, 8192, -1, 128]); view_2176 = None + view_2193 = torch.ops.aten.view.default(view_2183, [2, 8192, -1, 128]); view_2183 = None + view_2194 = torch.ops.aten.view.default(view_2190, [2, 8192, -1, 128]); view_2190 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_2192, torch.float32); view_2192 = None + view_2195 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 4, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_2195); view_2195 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_2193, torch.float32); view_2193 = None + view_2196 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 1, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_2196); view_2196 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_37); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_2198 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 4, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_37); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_2199 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 1, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_2198, torch.bfloat16); view_2198 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_2199, torch.bfloat16); view_2199 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 1, 4, 128]); unsqueeze_60 = None + view_2200 = torch.ops.aten.view.default(expand_60, [2, 8192, 4, 128]); expand_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_2194, 3); view_2194 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 1, 4, 128]); unsqueeze_61 = None + view_2201 = torch.ops.aten.view.default(expand_61, [2, 8192, 4, 128]); expand_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_2200, [0, 2, 1, 3]); view_2200 = None + permute_335 = torch.ops.aten.permute.default(view_2201, [0, 2, 1, 3]); view_2201 = None + _scaled_dot_product_cudnn_attention_backward_1 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_405, permute_333, permute_334, permute_335, getitem_1310, getitem_1311, getitem_1316, getitem_1317, None, None, None, 8192, 8192, 0.0, True); permute_405 = permute_333 = permute_334 = permute_335 = getitem_1310 = getitem_1311 = getitem_1316 = getitem_1317 = None + getitem_1459 = _scaled_dot_product_cudnn_attention_backward_1[0] + getitem_1460 = _scaled_dot_product_cudnn_attention_backward_1[1] + getitem_1461 = _scaled_dot_product_cudnn_attention_backward_1[2]; _scaled_dot_product_cudnn_attention_backward_1 = None + permute_406 = torch.ops.aten.permute.default(getitem_1461, [0, 2, 1, 3]); getitem_1461 = None + permute_407 = torch.ops.aten.permute.default(getitem_1460, [0, 2, 1, 3]); getitem_1460 = None + permute_408 = torch.ops.aten.permute.default(getitem_1459, [0, 2, 1, 3]); getitem_1459 = None + view_2356 = torch.ops.aten.view.default(permute_406, [2, 8192, 1, 4, 128]); permute_406 = None + sum_11 = torch.ops.aten.sum.dim_IntList(view_2356, [3], True); view_2356 = None + squeeze_2 = torch.ops.aten.squeeze.dim(sum_11, 3); sum_11 = None + view_2357 = torch.ops.aten.view.default(permute_407, [2, 8192, 1, 4, 128]); permute_407 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_2357, [3], True); view_2357 = None + squeeze_3 = torch.ops.aten.squeeze.dim(sum_12, 3); sum_12 = None + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(squeeze_3, torch.float32); squeeze_3 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(permute_408, torch.float32); permute_408 = None + view_2358 = torch.ops.aten.view.default(convert_element_type_1157, [2, 8192, 1, 64, 2]); convert_element_type_1157 = None + view_as_complex_66 = torch.ops.aten.view_as_complex.default(view_2358); view_2358 = None + mul_296 = torch.ops.aten.mul.Tensor(view_as_complex_66, _conj); view_as_complex_66 = None + view_2359 = torch.ops.aten.view.default(convert_element_type_1158, [2, 8192, 4, 64, 2]); convert_element_type_1158 = None + view_as_complex_67 = torch.ops.aten.view_as_complex.default(view_2359); view_2359 = None + mul_297 = torch.ops.aten.mul.Tensor(view_as_complex_67, _conj); view_as_complex_67 = None + view_as_real_66 = torch.ops.aten.view_as_real.default(mul_296); mul_296 = None + view_2360 = torch.ops.aten.view.default(view_as_real_66, [2, 8192, 1, 128]); view_as_real_66 = None + convert_element_type_1159 = torch.ops.prims.convert_element_type.default(view_2360, torch.bfloat16); view_2360 = None + view_as_real_67 = torch.ops.aten.view_as_real.default(mul_297); mul_297 = None + view_2361 = torch.ops.aten.view.default(view_as_real_67, [2, 8192, 4, 128]); view_as_real_67 = None + convert_element_type_1160 = torch.ops.prims.convert_element_type.default(view_2361, torch.bfloat16); view_2361 = None + view_2362 = torch.ops.aten.view.default(squeeze_2, [2, 8192, 128]); squeeze_2 = None + view_2363 = torch.ops.aten.view.default(convert_element_type_1159, [2, 8192, 128]); convert_element_type_1159 = None + view_2364 = torch.ops.aten.view.default(convert_element_type_1160, [2, 8192, 512]); convert_element_type_1160 = None + view_2365 = torch.ops.aten.view.default(view_2362, [16384, 128]); view_2362 = None + permute_409 = torch.ops.aten.permute.default(view_2365, [1, 0]) + mm_249 = torch.ops.aten.mm.default(permute_409, view_2175); permute_409 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16); primals_277 = None + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 32, '0'); convert_element_type_1000 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_396, [1, 0]); wait_tensor_396 = None + permute_411 = torch.ops.aten.permute.default(permute_332, [1, 0]); permute_332 = None + mm_250 = torch.ops.aten.mm.default(view_2365, permute_411); view_2365 = permute_411 = None + view_2366 = torch.ops.aten.view.default(mm_250, [2, 8192, 4096]); mm_250 = None + convert_element_type_1165 = torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1165, 'avg', 32, '0'); convert_element_type_1165 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); reduce_scatter_tensor_85 = None + view_2367 = torch.ops.aten.view.default(view_2363, [16384, 128]); view_2363 = None + permute_413 = torch.ops.aten.permute.default(view_2367, [1, 0]) + mm_251 = torch.ops.aten.mm.default(permute_413, view_2175); permute_413 = None + permute_415 = torch.ops.aten.permute.default(permute_331, [1, 0]); permute_331 = None + mm_252 = torch.ops.aten.mm.default(view_2367, permute_415); view_2367 = permute_415 = None + view_2368 = torch.ops.aten.view.default(mm_252, [2, 8192, 4096]); mm_252 = None + add_140 = torch.ops.aten.add.Tensor(view_2366, view_2368); view_2366 = view_2368 = None + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1170, 'avg', 32, '0'); convert_element_type_1170 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + view_2369 = torch.ops.aten.view.default(view_2364, [16384, 512]); view_2364 = None + permute_417 = torch.ops.aten.permute.default(view_2369, [1, 0]) + mm_253 = torch.ops.aten.mm.default(permute_417, view_2175); permute_417 = view_2175 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16); primals_275 = None + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 32, '0'); convert_element_type_994 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + permute_419 = torch.ops.aten.permute.default(permute_330, [1, 0]); permute_330 = None + mm_254 = torch.ops.aten.mm.default(view_2369, permute_419); view_2369 = permute_419 = None + view_2370 = torch.ops.aten.view.default(mm_254, [2, 8192, 4096]); mm_254 = None + add_141 = torch.ops.aten.add.Tensor(add_140, view_2370); add_140 = view_2370 = None + convert_element_type_1175 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1175, 'avg', 32, '0'); convert_element_type_1175 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + split_146 = torch.ops.aten.split.Tensor(add_141, 1024, 1); add_141 = None + getitem_1462 = split_146[0] + getitem_1463 = split_146[1] + getitem_1464 = split_146[2] + getitem_1465 = split_146[3] + getitem_1466 = split_146[4] + getitem_1467 = split_146[5] + getitem_1468 = split_146[6] + getitem_1469 = split_146[7]; split_146 = None + cat_138 = torch.ops.aten.cat.default([getitem_1462, getitem_1463, getitem_1464, getitem_1465, getitem_1466, getitem_1467, getitem_1468, getitem_1469]); getitem_1462 = getitem_1463 = getitem_1464 = getitem_1465 = getitem_1466 = getitem_1467 = getitem_1468 = getitem_1469 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_138, 'sum', 8, '1'); cat_138 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + convert_element_type_1176 = torch.ops.prims.convert_element_type.default(wait_tensor_452, torch.float32); wait_tensor_452 = None + convert_element_type_1178 = torch.ops.prims.convert_element_type.default(wait_tensor_392, torch.float32); wait_tensor_392 = None + mul_298 = torch.ops.aten.mul.Tensor(convert_element_type_1176, convert_element_type_1178); convert_element_type_1178 = None + mul_300 = torch.ops.aten.mul.Tensor(mul_240, mul_298) + sum_13 = torch.ops.aten.sum.dim_IntList(mul_300, [2], True); mul_300 = None + div_4 = torch.ops.aten.div.Tensor(mul_240, 4096) + mul_301 = torch.ops.aten.mul.Tensor(div_4, sum_13); div_4 = sum_13 = None + sub_7 = torch.ops.aten.sub.Tensor(mul_298, mul_301); mul_298 = mul_301 = None + mul_302 = torch.ops.aten.mul.Tensor(sub_7, rsqrt_60); sub_7 = rsqrt_60 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_1176, mul_240); convert_element_type_1176 = mul_240 = None + sum_14 = torch.ops.aten.sum.dim_IntList(mul_303, [0, 1]); mul_303 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(mul_302, torch.bfloat16); mul_302 = None + convert_element_type_1180 = torch.ops.prims.convert_element_type.default(sum_14, torch.bfloat16); sum_14 = None + all_reduce_4 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1180, 'sum', '1'); convert_element_type_1180 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_4); all_reduce_4 = None + convert_element_type_1181 = torch.ops.prims.convert_element_type.default(wait_tensor_453, torch.float32); wait_tensor_453 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1181, 'avg', 32, '0'); convert_element_type_1181 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + add_142 = torch.ops.aten.add.Tensor(add_139, convert_element_type_1179); add_139 = convert_element_type_1179 = None + all_gather_into_tensor_360 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_142, 8, '1') + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_360); all_gather_into_tensor_360 = None + split_147 = torch.ops.aten.split.Tensor(wait_tensor_455, 2); wait_tensor_455 = None + getitem_1470 = split_147[0] + getitem_1471 = split_147[1] + getitem_1472 = split_147[2] + getitem_1473 = split_147[3] + getitem_1474 = split_147[4] + getitem_1475 = split_147[5] + getitem_1476 = split_147[6] + getitem_1477 = split_147[7]; split_147 = None + cat_139 = torch.ops.aten.cat.default([getitem_1470, getitem_1471, getitem_1472, getitem_1473, getitem_1474, getitem_1475, getitem_1476, getitem_1477], 1); getitem_1470 = getitem_1471 = getitem_1472 = getitem_1473 = getitem_1474 = getitem_1475 = getitem_1476 = getitem_1477 = None + view_2371 = torch.ops.aten.view.default(cat_139, [16384, 4096]); cat_139 = None + permute_421 = torch.ops.aten.permute.default(view_2371, [1, 0]) + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + add_117 = torch.ops.aten.add.Tensor(add_115, wait_tensor_385); wait_tensor_385 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16); primals_270 = None + all_gather_into_tensor_326 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 32, '0'); convert_element_type_977 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_326); all_gather_into_tensor_326 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32); add_117 = None + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_386) + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + all_gather_into_tensor_327 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_979, 8, '1'); convert_element_type_979 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_327); all_gather_into_tensor_327 = None + split_127 = torch.ops.aten.split.Tensor(wait_tensor_387, 2); wait_tensor_387 = None + getitem_1286 = split_127[0] + getitem_1287 = split_127[1] + getitem_1288 = split_127[2] + getitem_1289 = split_127[3] + getitem_1290 = split_127[4] + getitem_1291 = split_127[5] + getitem_1292 = split_127[6] + getitem_1293 = split_127[7]; split_127 = None + cat_119 = torch.ops.aten.cat.default([getitem_1286, getitem_1287, getitem_1288, getitem_1289, getitem_1290, getitem_1291, getitem_1292, getitem_1293], 1); getitem_1286 = getitem_1287 = getitem_1288 = getitem_1289 = getitem_1290 = getitem_1291 = getitem_1292 = getitem_1293 = None + view_2148 = torch.ops.aten.view.default(cat_119, [16384, 4096]); cat_119 = None + view_2149 = torch.ops.aten.view.default(mm_207, [2, 8192, 1792]); mm_207 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_2149, torch.float32); view_2149 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16); primals_272 = None + all_gather_into_tensor_329 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 32, '0'); convert_element_type_985 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_329); all_gather_into_tensor_329 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_389, [1, 0]); wait_tensor_389 = None + mm_208 = torch.ops.aten.mm.default(view_2148, permute_328) + view_2156 = torch.ops.aten.view.default(mm_208, [2, 8192, 1792]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_2156) + view_2163 = torch.ops.aten.view.default(mul_239, [16384, 1792]); mul_239 = None + mm_255 = torch.ops.aten.mm.default(permute_421, view_2163); permute_421 = view_2163 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16); primals_273 = None + all_gather_into_tensor_330 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 32, '0'); convert_element_type_988 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_330); all_gather_into_tensor_330 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + permute_423 = torch.ops.aten.permute.default(permute_329, [1, 0]); permute_329 = None + mm_256 = torch.ops.aten.mm.default(view_2371, permute_423); view_2371 = permute_423 = None + view_2372 = torch.ops.aten.view.default(mm_256, [2, 8192, 1792]); mm_256 = None + convert_element_type_1186 = torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1186, 'avg', 32, '0'); convert_element_type_1186 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + mul_304 = torch.ops.aten.mul.Tensor(view_2372, convert_element_type_984); convert_element_type_984 = None + mul_305 = torch.ops.aten.mul.Tensor(view_2372, view_2156); view_2372 = view_2156 = None + view_2373 = torch.ops.aten.view.default(mul_304, [16384, 1792]); mul_304 = None + permute_425 = torch.ops.aten.permute.default(view_2373, [1, 0]) + mm_257 = torch.ops.aten.mm.default(permute_425, view_2148); permute_425 = None + permute_427 = torch.ops.aten.permute.default(permute_328, [1, 0]); permute_328 = None + mm_258 = torch.ops.aten.mm.default(view_2373, permute_427); view_2373 = permute_427 = None + view_2374 = torch.ops.aten.view.default(mm_258, [2, 8192, 4096]); mm_258 = None + convert_element_type_1191 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1191, 'avg', 32, '0'); convert_element_type_1191 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + convert_element_type_1192 = torch.ops.prims.convert_element_type.default(mul_305, torch.float32); mul_305 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_983) + exp_2 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_143 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_143); add_143 = None + mul_306 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_1192, mul_306); convert_element_type_1192 = None + sub_8 = torch.ops.aten.sub.Tensor(1, mul_306); mul_306 = None + mul_308 = torch.ops.aten.mul.Tensor(convert_element_type_983, sub_8); convert_element_type_983 = sub_8 = None + add_144 = torch.ops.aten.add.Tensor(mul_308, 1); mul_308 = None + mul_309 = torch.ops.aten.mul.Tensor(mul_307, add_144); mul_307 = add_144 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(mul_309, torch.bfloat16); mul_309 = None + view_2375 = torch.ops.aten.view.default(convert_element_type_1194, [16384, 1792]); convert_element_type_1194 = None + permute_429 = torch.ops.aten.permute.default(view_2375, [1, 0]) + mm_259 = torch.ops.aten.mm.default(permute_429, view_2148); permute_429 = view_2148 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16); primals_271 = None + all_gather_into_tensor_328 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 32, '0'); convert_element_type_980 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_328); all_gather_into_tensor_328 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + permute_431 = torch.ops.aten.permute.default(permute_327, [1, 0]); permute_327 = None + mm_260 = torch.ops.aten.mm.default(view_2375, permute_431); view_2375 = permute_431 = None + view_2376 = torch.ops.aten.view.default(mm_260, [2, 8192, 4096]); mm_260 = None + add_145 = torch.ops.aten.add.Tensor(view_2374, view_2376); view_2374 = view_2376 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_259, torch.float32); mm_259 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1199, 'avg', 32, '0'); convert_element_type_1199 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + split_148 = torch.ops.aten.split.Tensor(add_145, 1024, 1); add_145 = None + getitem_1478 = split_148[0] + getitem_1479 = split_148[1] + getitem_1480 = split_148[2] + getitem_1481 = split_148[3] + getitem_1482 = split_148[4] + getitem_1483 = split_148[5] + getitem_1484 = split_148[6] + getitem_1485 = split_148[7]; split_148 = None + cat_140 = torch.ops.aten.cat.default([getitem_1478, getitem_1479, getitem_1480, getitem_1481, getitem_1482, getitem_1483, getitem_1484, getitem_1485]); getitem_1478 = getitem_1479 = getitem_1480 = getitem_1481 = getitem_1482 = getitem_1483 = getitem_1484 = getitem_1485 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_140, 'sum', 8, '1'); cat_140 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + convert_element_type_1200 = torch.ops.prims.convert_element_type.default(wait_tensor_459, torch.float32); wait_tensor_459 = None + convert_element_type_1202 = torch.ops.prims.convert_element_type.default(wait_tensor_386, torch.float32); wait_tensor_386 = None + mul_310 = torch.ops.aten.mul.Tensor(convert_element_type_1200, convert_element_type_1202); convert_element_type_1202 = None + mul_312 = torch.ops.aten.mul.Tensor(mul_236, mul_310) + sum_15 = torch.ops.aten.sum.dim_IntList(mul_312, [2], True); mul_312 = None + div_5 = torch.ops.aten.div.Tensor(mul_236, 4096) + mul_313 = torch.ops.aten.mul.Tensor(div_5, sum_15); div_5 = sum_15 = None + sub_9 = torch.ops.aten.sub.Tensor(mul_310, mul_313); mul_310 = mul_313 = None + mul_314 = torch.ops.aten.mul.Tensor(sub_9, rsqrt_59); sub_9 = rsqrt_59 = None + mul_315 = torch.ops.aten.mul.Tensor(convert_element_type_1200, mul_236); convert_element_type_1200 = mul_236 = None + sum_16 = torch.ops.aten.sum.dim_IntList(mul_315, [0, 1]); mul_315 = None + convert_element_type_1203 = torch.ops.prims.convert_element_type.default(mul_314, torch.bfloat16); mul_314 = None + convert_element_type_1204 = torch.ops.prims.convert_element_type.default(sum_16, torch.bfloat16); sum_16 = None + all_reduce_5 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1204, 'sum', '1'); convert_element_type_1204 = None + wait_tensor_460 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_5); all_reduce_5 = None + convert_element_type_1205 = torch.ops.prims.convert_element_type.default(wait_tensor_460, torch.float32); wait_tensor_460 = None + reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1205, 'avg', 32, '0'); convert_element_type_1205 = None + wait_tensor_461 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + add_146 = torch.ops.aten.add.Tensor(add_142, convert_element_type_1203); add_142 = convert_element_type_1203 = None + all_gather_into_tensor_361 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_146, 8, '1') + wait_tensor_462 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_361); all_gather_into_tensor_361 = None + split_149 = torch.ops.aten.split.Tensor(wait_tensor_462, 2); wait_tensor_462 = None + getitem_1486 = split_149[0] + getitem_1487 = split_149[1] + getitem_1488 = split_149[2] + getitem_1489 = split_149[3] + getitem_1490 = split_149[4] + getitem_1491 = split_149[5] + getitem_1492 = split_149[6] + getitem_1493 = split_149[7]; split_149 = None + cat_141 = torch.ops.aten.cat.default([getitem_1486, getitem_1487, getitem_1488, getitem_1489, getitem_1490, getitem_1491, getitem_1492, getitem_1493], 1); getitem_1486 = getitem_1487 = getitem_1488 = getitem_1489 = getitem_1490 = getitem_1491 = getitem_1492 = getitem_1493 = None + view_2377 = torch.ops.aten.view.default(cat_141, [16384, 4096]); cat_141 = None + permute_433 = torch.ops.aten.permute.default(view_2377, [1, 0]) + permute_325 = torch.ops.aten.permute.default(getitem_1269, [0, 2, 1, 3]) + view_2130 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + view_2136 = torch.ops.aten.view.default(view_2130, [16384, 512]); view_2130 = None + mm_261 = torch.ops.aten.mm.default(permute_433, view_2136); permute_433 = view_2136 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None + all_gather_into_tensor_325 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 32, '0'); convert_element_type_974 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_325); all_gather_into_tensor_325 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_384, [1, 0]); wait_tensor_384 = None + permute_435 = torch.ops.aten.permute.default(permute_326, [1, 0]); permute_326 = None + mm_262 = torch.ops.aten.mm.default(view_2377, permute_435); view_2377 = permute_435 = None + view_2378 = torch.ops.aten.view.default(mm_262, [2, 8192, 512]); mm_262 = None + convert_element_type_1210 = torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1210, 'avg', 32, '0'); convert_element_type_1210 = None + wait_tensor_463 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + view_2379 = torch.ops.aten.view.default(view_2378, [2, 8192, 4, 128]); view_2378 = None + permute_437 = torch.ops.aten.permute.default(view_2379, [0, 2, 1, 3]); view_2379 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None + all_gather_into_tensor_320 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 32, '0'); convert_element_type_958 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32); add_115 = None + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_379) + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_960, 8, '1'); convert_element_type_960 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + split_125 = torch.ops.aten.split.Tensor(wait_tensor_380, 2); wait_tensor_380 = None + getitem_1261 = split_125[0] + getitem_1262 = split_125[1] + getitem_1263 = split_125[2] + getitem_1264 = split_125[3] + getitem_1265 = split_125[4] + getitem_1266 = split_125[5] + getitem_1267 = split_125[6] + getitem_1268 = split_125[7]; split_125 = None + cat_117 = torch.ops.aten.cat.default([getitem_1261, getitem_1262, getitem_1263, getitem_1264, getitem_1265, getitem_1266, getitem_1267, getitem_1268], 1); getitem_1261 = getitem_1262 = getitem_1263 = getitem_1264 = getitem_1265 = getitem_1266 = getitem_1267 = getitem_1268 = None + view_2103 = torch.ops.aten.view.default(cat_117, [16384, 4096]); cat_117 = None + view_2104 = torch.ops.aten.view.default(mm_203, [2, 8192, 512]); mm_203 = None + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 32, '0'); convert_element_type_964 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_382, [1, 0]); wait_tensor_382 = None + mm_204 = torch.ops.aten.mm.default(view_2103, permute_320) + view_2111 = torch.ops.aten.view.default(mm_204, [2, 8192, 128]); mm_204 = None + view_2118 = torch.ops.aten.view.default(mm_205, [2, 8192, 128]); mm_205 = None + view_2120 = torch.ops.aten.view.default(view_2104, [2, 8192, -1, 128]); view_2104 = None + view_2121 = torch.ops.aten.view.default(view_2111, [2, 8192, -1, 128]); view_2111 = None + view_2122 = torch.ops.aten.view.default(view_2118, [2, 8192, -1, 128]); view_2118 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_2120, torch.float32); view_2120 = None + view_2123 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 4, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_2123); view_2123 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_2121, torch.float32); view_2121 = None + view_2124 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 1, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_2124); view_2124 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_37); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_2126 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 4, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_37); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_2127 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 1, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_2126, torch.bfloat16); view_2126 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_2127, torch.bfloat16); view_2127 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 1, 4, 128]); unsqueeze_58 = None + view_2128 = torch.ops.aten.view.default(expand_58, [2, 8192, 4, 128]); expand_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_2122, 3); view_2122 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 1, 4, 128]); unsqueeze_59 = None + view_2129 = torch.ops.aten.view.default(expand_59, [2, 8192, 4, 128]); expand_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_2128, [0, 2, 1, 3]); view_2128 = None + permute_324 = torch.ops.aten.permute.default(view_2129, [0, 2, 1, 3]); view_2129 = None + _scaled_dot_product_cudnn_attention_backward_2 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_437, permute_322, permute_323, permute_324, getitem_1269, getitem_1270, getitem_1275, getitem_1276, None, None, None, 8192, 8192, 0.0, True); permute_437 = permute_322 = permute_323 = permute_324 = getitem_1269 = getitem_1270 = getitem_1275 = getitem_1276 = None + getitem_1494 = _scaled_dot_product_cudnn_attention_backward_2[0] + getitem_1495 = _scaled_dot_product_cudnn_attention_backward_2[1] + getitem_1496 = _scaled_dot_product_cudnn_attention_backward_2[2]; _scaled_dot_product_cudnn_attention_backward_2 = None + permute_438 = torch.ops.aten.permute.default(getitem_1496, [0, 2, 1, 3]); getitem_1496 = None + permute_439 = torch.ops.aten.permute.default(getitem_1495, [0, 2, 1, 3]); getitem_1495 = None + permute_440 = torch.ops.aten.permute.default(getitem_1494, [0, 2, 1, 3]); getitem_1494 = None + view_2380 = torch.ops.aten.view.default(permute_438, [2, 8192, 1, 4, 128]); permute_438 = None + sum_17 = torch.ops.aten.sum.dim_IntList(view_2380, [3], True); view_2380 = None + squeeze_4 = torch.ops.aten.squeeze.dim(sum_17, 3); sum_17 = None + view_2381 = torch.ops.aten.view.default(permute_439, [2, 8192, 1, 4, 128]); permute_439 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_2381, [3], True); view_2381 = None + squeeze_5 = torch.ops.aten.squeeze.dim(sum_18, 3); sum_18 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(squeeze_5, torch.float32); squeeze_5 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(permute_440, torch.float32); permute_440 = None + view_2382 = torch.ops.aten.view.default(convert_element_type_1211, [2, 8192, 1, 64, 2]); convert_element_type_1211 = None + view_as_complex_68 = torch.ops.aten.view_as_complex.default(view_2382); view_2382 = None + mul_316 = torch.ops.aten.mul.Tensor(view_as_complex_68, _conj); view_as_complex_68 = None + view_2383 = torch.ops.aten.view.default(convert_element_type_1212, [2, 8192, 4, 64, 2]); convert_element_type_1212 = None + view_as_complex_69 = torch.ops.aten.view_as_complex.default(view_2383); view_2383 = None + mul_317 = torch.ops.aten.mul.Tensor(view_as_complex_69, _conj); view_as_complex_69 = None + view_as_real_68 = torch.ops.aten.view_as_real.default(mul_316); mul_316 = None + view_2384 = torch.ops.aten.view.default(view_as_real_68, [2, 8192, 1, 128]); view_as_real_68 = None + convert_element_type_1213 = torch.ops.prims.convert_element_type.default(view_2384, torch.bfloat16); view_2384 = None + view_as_real_69 = torch.ops.aten.view_as_real.default(mul_317); mul_317 = None + view_2385 = torch.ops.aten.view.default(view_as_real_69, [2, 8192, 4, 128]); view_as_real_69 = None + convert_element_type_1214 = torch.ops.prims.convert_element_type.default(view_2385, torch.bfloat16); view_2385 = None + view_2386 = torch.ops.aten.view.default(squeeze_4, [2, 8192, 128]); squeeze_4 = None + view_2387 = torch.ops.aten.view.default(convert_element_type_1213, [2, 8192, 128]); convert_element_type_1213 = None + view_2388 = torch.ops.aten.view.default(convert_element_type_1214, [2, 8192, 512]); convert_element_type_1214 = None + view_2389 = torch.ops.aten.view.default(view_2386, [16384, 128]); view_2386 = None + permute_441 = torch.ops.aten.permute.default(view_2389, [1, 0]) + mm_263 = torch.ops.aten.mm.default(permute_441, view_2103); permute_441 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 32, '0'); convert_element_type_967 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_383, [1, 0]); wait_tensor_383 = None + permute_443 = torch.ops.aten.permute.default(permute_321, [1, 0]); permute_321 = None + mm_264 = torch.ops.aten.mm.default(view_2389, permute_443); view_2389 = permute_443 = None + view_2390 = torch.ops.aten.view.default(mm_264, [2, 8192, 4096]); mm_264 = None + convert_element_type_1219 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1219, 'avg', 32, '0'); convert_element_type_1219 = None + wait_tensor_464 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + view_2391 = torch.ops.aten.view.default(view_2387, [16384, 128]); view_2387 = None + permute_445 = torch.ops.aten.permute.default(view_2391, [1, 0]) + mm_265 = torch.ops.aten.mm.default(permute_445, view_2103); permute_445 = None + permute_447 = torch.ops.aten.permute.default(permute_320, [1, 0]); permute_320 = None + mm_266 = torch.ops.aten.mm.default(view_2391, permute_447); view_2391 = permute_447 = None + view_2392 = torch.ops.aten.view.default(mm_266, [2, 8192, 4096]); mm_266 = None + add_147 = torch.ops.aten.add.Tensor(view_2390, view_2392); view_2390 = view_2392 = None + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1224, 'avg', 32, '0'); convert_element_type_1224 = None + wait_tensor_465 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + view_2393 = torch.ops.aten.view.default(view_2388, [16384, 512]); view_2388 = None + permute_449 = torch.ops.aten.permute.default(view_2393, [1, 0]) + mm_267 = torch.ops.aten.mm.default(permute_449, view_2103); permute_449 = view_2103 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None + all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 32, '0'); convert_element_type_961 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_381, [1, 0]); wait_tensor_381 = None + permute_451 = torch.ops.aten.permute.default(permute_319, [1, 0]); permute_319 = None + mm_268 = torch.ops.aten.mm.default(view_2393, permute_451); view_2393 = permute_451 = None + view_2394 = torch.ops.aten.view.default(mm_268, [2, 8192, 4096]); mm_268 = None + add_148 = torch.ops.aten.add.Tensor(add_147, view_2394); add_147 = view_2394 = None + convert_element_type_1229 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1229, 'avg', 32, '0'); convert_element_type_1229 = None + wait_tensor_466 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + split_150 = torch.ops.aten.split.Tensor(add_148, 1024, 1); add_148 = None + getitem_1497 = split_150[0] + getitem_1498 = split_150[1] + getitem_1499 = split_150[2] + getitem_1500 = split_150[3] + getitem_1501 = split_150[4] + getitem_1502 = split_150[5] + getitem_1503 = split_150[6] + getitem_1504 = split_150[7]; split_150 = None + cat_142 = torch.ops.aten.cat.default([getitem_1497, getitem_1498, getitem_1499, getitem_1500, getitem_1501, getitem_1502, getitem_1503, getitem_1504]); getitem_1497 = getitem_1498 = getitem_1499 = getitem_1500 = getitem_1501 = getitem_1502 = getitem_1503 = getitem_1504 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_142, 'sum', 8, '1'); cat_142 = None + wait_tensor_467 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + convert_element_type_1230 = torch.ops.prims.convert_element_type.default(wait_tensor_467, torch.float32); wait_tensor_467 = None + convert_element_type_1232 = torch.ops.prims.convert_element_type.default(wait_tensor_379, torch.float32); wait_tensor_379 = None + mul_318 = torch.ops.aten.mul.Tensor(convert_element_type_1230, convert_element_type_1232); convert_element_type_1232 = None + mul_320 = torch.ops.aten.mul.Tensor(mul_232, mul_318) + sum_19 = torch.ops.aten.sum.dim_IntList(mul_320, [2], True); mul_320 = None + div_6 = torch.ops.aten.div.Tensor(mul_232, 4096) + mul_321 = torch.ops.aten.mul.Tensor(div_6, sum_19); div_6 = sum_19 = None + sub_10 = torch.ops.aten.sub.Tensor(mul_318, mul_321); mul_318 = mul_321 = None + mul_322 = torch.ops.aten.mul.Tensor(sub_10, rsqrt_58); sub_10 = rsqrt_58 = None + mul_323 = torch.ops.aten.mul.Tensor(convert_element_type_1230, mul_232); convert_element_type_1230 = mul_232 = None + sum_20 = torch.ops.aten.sum.dim_IntList(mul_323, [0, 1]); mul_323 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(mul_322, torch.bfloat16); mul_322 = None + convert_element_type_1234 = torch.ops.prims.convert_element_type.default(sum_20, torch.bfloat16); sum_20 = None + all_reduce_6 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1234, 'sum', '1'); convert_element_type_1234 = None + wait_tensor_468 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_6); all_reduce_6 = None + convert_element_type_1235 = torch.ops.prims.convert_element_type.default(wait_tensor_468, torch.float32); wait_tensor_468 = None + reduce_scatter_tensor_100 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1235, 'avg', 32, '0'); convert_element_type_1235 = None + wait_tensor_469 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); reduce_scatter_tensor_100 = None + add_149 = torch.ops.aten.add.Tensor(add_146, convert_element_type_1233); add_146 = convert_element_type_1233 = None + all_gather_into_tensor_362 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_149, 8, '1') + wait_tensor_470 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_362); all_gather_into_tensor_362 = None + split_151 = torch.ops.aten.split.Tensor(wait_tensor_470, 2); wait_tensor_470 = None + getitem_1505 = split_151[0] + getitem_1506 = split_151[1] + getitem_1507 = split_151[2] + getitem_1508 = split_151[3] + getitem_1509 = split_151[4] + getitem_1510 = split_151[5] + getitem_1511 = split_151[6] + getitem_1512 = split_151[7]; split_151 = None + cat_143 = torch.ops.aten.cat.default([getitem_1505, getitem_1506, getitem_1507, getitem_1508, getitem_1509, getitem_1510, getitem_1511, getitem_1512], 1); getitem_1505 = getitem_1506 = getitem_1507 = getitem_1508 = getitem_1509 = getitem_1510 = getitem_1511 = getitem_1512 = None + view_2395 = torch.ops.aten.view.default(cat_143, [16384, 4096]); cat_143 = None + permute_453 = torch.ops.aten.permute.default(view_2395, [1, 0]) + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + add_113 = torch.ops.aten.add.Tensor(add_111, wait_tensor_372); wait_tensor_372 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16); primals_261 = None + all_gather_into_tensor_315 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 32, '0'); convert_element_type_944 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32); add_113 = None + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_373) + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_946, 8, '1'); convert_element_type_946 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + split_123 = torch.ops.aten.split.Tensor(wait_tensor_374, 2); wait_tensor_374 = None + getitem_1245 = split_123[0] + getitem_1246 = split_123[1] + getitem_1247 = split_123[2] + getitem_1248 = split_123[3] + getitem_1249 = split_123[4] + getitem_1250 = split_123[5] + getitem_1251 = split_123[6] + getitem_1252 = split_123[7]; split_123 = None + cat_115 = torch.ops.aten.cat.default([getitem_1245, getitem_1246, getitem_1247, getitem_1248, getitem_1249, getitem_1250, getitem_1251, getitem_1252], 1); getitem_1245 = getitem_1246 = getitem_1247 = getitem_1248 = getitem_1249 = getitem_1250 = getitem_1251 = getitem_1252 = None + view_2076 = torch.ops.aten.view.default(cat_115, [16384, 4096]); cat_115 = None + view_2077 = torch.ops.aten.view.default(mm_200, [2, 8192, 1792]); mm_200 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_2077, torch.float32); view_2077 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 32, '0'); convert_element_type_952 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_376, [1, 0]); wait_tensor_376 = None + mm_201 = torch.ops.aten.mm.default(view_2076, permute_317) + view_2084 = torch.ops.aten.view.default(mm_201, [2, 8192, 1792]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_2084) + view_2091 = torch.ops.aten.view.default(mul_231, [16384, 1792]); mul_231 = None + mm_269 = torch.ops.aten.mm.default(permute_453, view_2091); permute_453 = view_2091 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None + all_gather_into_tensor_319 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 32, '0'); convert_element_type_955 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_377, [1, 0]); wait_tensor_377 = None + permute_455 = torch.ops.aten.permute.default(permute_318, [1, 0]); permute_318 = None + mm_270 = torch.ops.aten.mm.default(view_2395, permute_455); view_2395 = permute_455 = None + view_2396 = torch.ops.aten.view.default(mm_270, [2, 8192, 1792]); mm_270 = None + convert_element_type_1240 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1240, 'avg', 32, '0'); convert_element_type_1240 = None + wait_tensor_471 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + mul_324 = torch.ops.aten.mul.Tensor(view_2396, convert_element_type_951); convert_element_type_951 = None + mul_325 = torch.ops.aten.mul.Tensor(view_2396, view_2084); view_2396 = view_2084 = None + view_2397 = torch.ops.aten.view.default(mul_324, [16384, 1792]); mul_324 = None + permute_457 = torch.ops.aten.permute.default(view_2397, [1, 0]) + mm_271 = torch.ops.aten.mm.default(permute_457, view_2076); permute_457 = None + permute_459 = torch.ops.aten.permute.default(permute_317, [1, 0]); permute_317 = None + mm_272 = torch.ops.aten.mm.default(view_2397, permute_459); view_2397 = permute_459 = None + view_2398 = torch.ops.aten.view.default(mm_272, [2, 8192, 4096]); mm_272 = None + convert_element_type_1245 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None + reduce_scatter_tensor_102 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1245, 'avg', 32, '0'); convert_element_type_1245 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); reduce_scatter_tensor_102 = None + convert_element_type_1246 = torch.ops.prims.convert_element_type.default(mul_325, torch.float32); mul_325 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_950) + exp_3 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_150 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_150); add_150 = None + mul_326 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_327 = torch.ops.aten.mul.Tensor(convert_element_type_1246, mul_326); convert_element_type_1246 = None + sub_11 = torch.ops.aten.sub.Tensor(1, mul_326); mul_326 = None + mul_328 = torch.ops.aten.mul.Tensor(convert_element_type_950, sub_11); convert_element_type_950 = sub_11 = None + add_151 = torch.ops.aten.add.Tensor(mul_328, 1); mul_328 = None + mul_329 = torch.ops.aten.mul.Tensor(mul_327, add_151); mul_327 = add_151 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(mul_329, torch.bfloat16); mul_329 = None + view_2399 = torch.ops.aten.view.default(convert_element_type_1248, [16384, 1792]); convert_element_type_1248 = None + permute_461 = torch.ops.aten.permute.default(view_2399, [1, 0]) + mm_273 = torch.ops.aten.mm.default(permute_461, view_2076); permute_461 = view_2076 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None + all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 32, '0'); convert_element_type_947 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_375, [1, 0]); wait_tensor_375 = None + permute_463 = torch.ops.aten.permute.default(permute_316, [1, 0]); permute_316 = None + mm_274 = torch.ops.aten.mm.default(view_2399, permute_463); view_2399 = permute_463 = None + view_2400 = torch.ops.aten.view.default(mm_274, [2, 8192, 4096]); mm_274 = None + add_152 = torch.ops.aten.add.Tensor(view_2398, view_2400); view_2398 = view_2400 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1253, 'avg', 32, '0'); convert_element_type_1253 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + split_152 = torch.ops.aten.split.Tensor(add_152, 1024, 1); add_152 = None + getitem_1513 = split_152[0] + getitem_1514 = split_152[1] + getitem_1515 = split_152[2] + getitem_1516 = split_152[3] + getitem_1517 = split_152[4] + getitem_1518 = split_152[5] + getitem_1519 = split_152[6] + getitem_1520 = split_152[7]; split_152 = None + cat_144 = torch.ops.aten.cat.default([getitem_1513, getitem_1514, getitem_1515, getitem_1516, getitem_1517, getitem_1518, getitem_1519, getitem_1520]); getitem_1513 = getitem_1514 = getitem_1515 = getitem_1516 = getitem_1517 = getitem_1518 = getitem_1519 = getitem_1520 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_144, 'sum', 8, '1'); cat_144 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + convert_element_type_1254 = torch.ops.prims.convert_element_type.default(wait_tensor_474, torch.float32); wait_tensor_474 = None + convert_element_type_1256 = torch.ops.prims.convert_element_type.default(wait_tensor_373, torch.float32); wait_tensor_373 = None + mul_330 = torch.ops.aten.mul.Tensor(convert_element_type_1254, convert_element_type_1256); convert_element_type_1256 = None + mul_332 = torch.ops.aten.mul.Tensor(mul_228, mul_330) + sum_21 = torch.ops.aten.sum.dim_IntList(mul_332, [2], True); mul_332 = None + div_7 = torch.ops.aten.div.Tensor(mul_228, 4096) + mul_333 = torch.ops.aten.mul.Tensor(div_7, sum_21); div_7 = sum_21 = None + sub_12 = torch.ops.aten.sub.Tensor(mul_330, mul_333); mul_330 = mul_333 = None + mul_334 = torch.ops.aten.mul.Tensor(sub_12, rsqrt_57); sub_12 = rsqrt_57 = None + mul_335 = torch.ops.aten.mul.Tensor(convert_element_type_1254, mul_228); convert_element_type_1254 = mul_228 = None + sum_22 = torch.ops.aten.sum.dim_IntList(mul_335, [0, 1]); mul_335 = None + convert_element_type_1257 = torch.ops.prims.convert_element_type.default(mul_334, torch.bfloat16); mul_334 = None + convert_element_type_1258 = torch.ops.prims.convert_element_type.default(sum_22, torch.bfloat16); sum_22 = None + all_reduce_7 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1258, 'sum', '1'); convert_element_type_1258 = None + wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_7); all_reduce_7 = None + convert_element_type_1259 = torch.ops.prims.convert_element_type.default(wait_tensor_475, torch.float32); wait_tensor_475 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1259, 'avg', 32, '0'); convert_element_type_1259 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + add_153 = torch.ops.aten.add.Tensor(add_149, convert_element_type_1257); add_149 = convert_element_type_1257 = None + all_gather_into_tensor_363 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_153, 8, '1') + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_363); all_gather_into_tensor_363 = None + split_153 = torch.ops.aten.split.Tensor(wait_tensor_477, 2); wait_tensor_477 = None + getitem_1521 = split_153[0] + getitem_1522 = split_153[1] + getitem_1523 = split_153[2] + getitem_1524 = split_153[3] + getitem_1525 = split_153[4] + getitem_1526 = split_153[5] + getitem_1527 = split_153[6] + getitem_1528 = split_153[7]; split_153 = None + cat_145 = torch.ops.aten.cat.default([getitem_1521, getitem_1522, getitem_1523, getitem_1524, getitem_1525, getitem_1526, getitem_1527, getitem_1528], 1); getitem_1521 = getitem_1522 = getitem_1523 = getitem_1524 = getitem_1525 = getitem_1526 = getitem_1527 = getitem_1528 = None + view_2401 = torch.ops.aten.view.default(cat_145, [16384, 4096]); cat_145 = None + permute_465 = torch.ops.aten.permute.default(view_2401, [1, 0]) + permute_314 = torch.ops.aten.permute.default(getitem_1228, [0, 2, 1, 3]) + view_2058 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + view_2064 = torch.ops.aten.view.default(view_2058, [16384, 512]); view_2058 = None + mm_275 = torch.ops.aten.mm.default(permute_465, view_2064); permute_465 = view_2064 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16); primals_260 = None + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 32, '0'); convert_element_type_941 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_371, [1, 0]); wait_tensor_371 = None + permute_467 = torch.ops.aten.permute.default(permute_315, [1, 0]); permute_315 = None + mm_276 = torch.ops.aten.mm.default(view_2401, permute_467); view_2401 = permute_467 = None + view_2402 = torch.ops.aten.view.default(mm_276, [2, 8192, 512]); mm_276 = None + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1264, 'avg', 32, '0'); convert_element_type_1264 = None + wait_tensor_478 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + view_2403 = torch.ops.aten.view.default(view_2402, [2, 8192, 4, 128]); view_2402 = None + permute_469 = torch.ops.aten.permute.default(view_2403, [0, 2, 1, 3]); view_2403 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16); primals_256 = None + all_gather_into_tensor_309 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 32, '0'); convert_element_type_925 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_309); all_gather_into_tensor_309 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32); add_111 = None + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_366) + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + all_gather_into_tensor_310 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_927, 8, '1'); convert_element_type_927 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_310); all_gather_into_tensor_310 = None + split_121 = torch.ops.aten.split.Tensor(wait_tensor_367, 2); wait_tensor_367 = None + getitem_1220 = split_121[0] + getitem_1221 = split_121[1] + getitem_1222 = split_121[2] + getitem_1223 = split_121[3] + getitem_1224 = split_121[4] + getitem_1225 = split_121[5] + getitem_1226 = split_121[6] + getitem_1227 = split_121[7]; split_121 = None + cat_113 = torch.ops.aten.cat.default([getitem_1220, getitem_1221, getitem_1222, getitem_1223, getitem_1224, getitem_1225, getitem_1226, getitem_1227], 1); getitem_1220 = getitem_1221 = getitem_1222 = getitem_1223 = getitem_1224 = getitem_1225 = getitem_1226 = getitem_1227 = None + view_2031 = torch.ops.aten.view.default(cat_113, [16384, 4096]); cat_113 = None + view_2032 = torch.ops.aten.view.default(mm_196, [2, 8192, 512]); mm_196 = None + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16); primals_258 = None + all_gather_into_tensor_312 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 32, '0'); convert_element_type_931 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_312); all_gather_into_tensor_312 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_369, [1, 0]); wait_tensor_369 = None + mm_197 = torch.ops.aten.mm.default(view_2031, permute_309) + view_2039 = torch.ops.aten.view.default(mm_197, [2, 8192, 128]); mm_197 = None + view_2046 = torch.ops.aten.view.default(mm_198, [2, 8192, 128]); mm_198 = None + view_2048 = torch.ops.aten.view.default(view_2032, [2, 8192, -1, 128]); view_2032 = None + view_2049 = torch.ops.aten.view.default(view_2039, [2, 8192, -1, 128]); view_2039 = None + view_2050 = torch.ops.aten.view.default(view_2046, [2, 8192, -1, 128]); view_2046 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_2048, torch.float32); view_2048 = None + view_2051 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 4, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_2051); view_2051 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_2049, torch.float32); view_2049 = None + view_2052 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 1, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_2052); view_2052 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_37); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_2054 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 4, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_37); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_2055 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 1, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_2054, torch.bfloat16); view_2054 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_2055, torch.bfloat16); view_2055 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 1, 4, 128]); unsqueeze_56 = None + view_2056 = torch.ops.aten.view.default(expand_56, [2, 8192, 4, 128]); expand_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_2050, 3); view_2050 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 1, 4, 128]); unsqueeze_57 = None + view_2057 = torch.ops.aten.view.default(expand_57, [2, 8192, 4, 128]); expand_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_2056, [0, 2, 1, 3]); view_2056 = None + permute_313 = torch.ops.aten.permute.default(view_2057, [0, 2, 1, 3]); view_2057 = None + _scaled_dot_product_cudnn_attention_backward_3 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_469, permute_311, permute_312, permute_313, getitem_1228, getitem_1229, getitem_1234, getitem_1235, None, None, None, 8192, 8192, 0.0, True); permute_469 = permute_311 = permute_312 = permute_313 = getitem_1228 = getitem_1229 = getitem_1234 = getitem_1235 = None + getitem_1529 = _scaled_dot_product_cudnn_attention_backward_3[0] + getitem_1530 = _scaled_dot_product_cudnn_attention_backward_3[1] + getitem_1531 = _scaled_dot_product_cudnn_attention_backward_3[2]; _scaled_dot_product_cudnn_attention_backward_3 = None + permute_470 = torch.ops.aten.permute.default(getitem_1531, [0, 2, 1, 3]); getitem_1531 = None + permute_471 = torch.ops.aten.permute.default(getitem_1530, [0, 2, 1, 3]); getitem_1530 = None + permute_472 = torch.ops.aten.permute.default(getitem_1529, [0, 2, 1, 3]); getitem_1529 = None + view_2404 = torch.ops.aten.view.default(permute_470, [2, 8192, 1, 4, 128]); permute_470 = None + sum_23 = torch.ops.aten.sum.dim_IntList(view_2404, [3], True); view_2404 = None + squeeze_6 = torch.ops.aten.squeeze.dim(sum_23, 3); sum_23 = None + view_2405 = torch.ops.aten.view.default(permute_471, [2, 8192, 1, 4, 128]); permute_471 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_2405, [3], True); view_2405 = None + squeeze_7 = torch.ops.aten.squeeze.dim(sum_24, 3); sum_24 = None + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(squeeze_7, torch.float32); squeeze_7 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(permute_472, torch.float32); permute_472 = None + view_2406 = torch.ops.aten.view.default(convert_element_type_1265, [2, 8192, 1, 64, 2]); convert_element_type_1265 = None + view_as_complex_70 = torch.ops.aten.view_as_complex.default(view_2406); view_2406 = None + mul_336 = torch.ops.aten.mul.Tensor(view_as_complex_70, _conj); view_as_complex_70 = None + view_2407 = torch.ops.aten.view.default(convert_element_type_1266, [2, 8192, 4, 64, 2]); convert_element_type_1266 = None + view_as_complex_71 = torch.ops.aten.view_as_complex.default(view_2407); view_2407 = None + mul_337 = torch.ops.aten.mul.Tensor(view_as_complex_71, _conj); view_as_complex_71 = None + view_as_real_70 = torch.ops.aten.view_as_real.default(mul_336); mul_336 = None + view_2408 = torch.ops.aten.view.default(view_as_real_70, [2, 8192, 1, 128]); view_as_real_70 = None + convert_element_type_1267 = torch.ops.prims.convert_element_type.default(view_2408, torch.bfloat16); view_2408 = None + view_as_real_71 = torch.ops.aten.view_as_real.default(mul_337); mul_337 = None + view_2409 = torch.ops.aten.view.default(view_as_real_71, [2, 8192, 4, 128]); view_as_real_71 = None + convert_element_type_1268 = torch.ops.prims.convert_element_type.default(view_2409, torch.bfloat16); view_2409 = None + view_2410 = torch.ops.aten.view.default(squeeze_6, [2, 8192, 128]); squeeze_6 = None + view_2411 = torch.ops.aten.view.default(convert_element_type_1267, [2, 8192, 128]); convert_element_type_1267 = None + view_2412 = torch.ops.aten.view.default(convert_element_type_1268, [2, 8192, 512]); convert_element_type_1268 = None + view_2413 = torch.ops.aten.view.default(view_2410, [16384, 128]); view_2410 = None + permute_473 = torch.ops.aten.permute.default(view_2413, [1, 0]) + mm_277 = torch.ops.aten.mm.default(permute_473, view_2031); permute_473 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16); primals_259 = None + all_gather_into_tensor_313 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 32, '0'); convert_element_type_934 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_313); all_gather_into_tensor_313 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_370, [1, 0]); wait_tensor_370 = None + permute_475 = torch.ops.aten.permute.default(permute_310, [1, 0]); permute_310 = None + mm_278 = torch.ops.aten.mm.default(view_2413, permute_475); view_2413 = permute_475 = None + view_2414 = torch.ops.aten.view.default(mm_278, [2, 8192, 4096]); mm_278 = None + convert_element_type_1273 = torch.ops.prims.convert_element_type.default(mm_277, torch.float32); mm_277 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1273, 'avg', 32, '0'); convert_element_type_1273 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_2415 = torch.ops.aten.view.default(view_2411, [16384, 128]); view_2411 = None + permute_477 = torch.ops.aten.permute.default(view_2415, [1, 0]) + mm_279 = torch.ops.aten.mm.default(permute_477, view_2031); permute_477 = None + permute_479 = torch.ops.aten.permute.default(permute_309, [1, 0]); permute_309 = None + mm_280 = torch.ops.aten.mm.default(view_2415, permute_479); view_2415 = permute_479 = None + view_2416 = torch.ops.aten.view.default(mm_280, [2, 8192, 4096]); mm_280 = None + add_154 = torch.ops.aten.add.Tensor(view_2414, view_2416); view_2414 = view_2416 = None + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(mm_279, torch.float32); mm_279 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1278, 'avg', 32, '0'); convert_element_type_1278 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + view_2417 = torch.ops.aten.view.default(view_2412, [16384, 512]); view_2412 = None + permute_481 = torch.ops.aten.permute.default(view_2417, [1, 0]) + mm_281 = torch.ops.aten.mm.default(permute_481, view_2031); permute_481 = view_2031 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16); primals_257 = None + all_gather_into_tensor_311 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 32, '0'); convert_element_type_928 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_311); all_gather_into_tensor_311 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_368, [1, 0]); wait_tensor_368 = None + permute_483 = torch.ops.aten.permute.default(permute_308, [1, 0]); permute_308 = None + mm_282 = torch.ops.aten.mm.default(view_2417, permute_483); view_2417 = permute_483 = None + view_2418 = torch.ops.aten.view.default(mm_282, [2, 8192, 4096]); mm_282 = None + add_155 = torch.ops.aten.add.Tensor(add_154, view_2418); add_154 = view_2418 = None + convert_element_type_1283 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1283, 'avg', 32, '0'); convert_element_type_1283 = None + wait_tensor_481 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 = None + split_154 = torch.ops.aten.split.Tensor(add_155, 1024, 1); add_155 = None + getitem_1532 = split_154[0] + getitem_1533 = split_154[1] + getitem_1534 = split_154[2] + getitem_1535 = split_154[3] + getitem_1536 = split_154[4] + getitem_1537 = split_154[5] + getitem_1538 = split_154[6] + getitem_1539 = split_154[7]; split_154 = None + cat_146 = torch.ops.aten.cat.default([getitem_1532, getitem_1533, getitem_1534, getitem_1535, getitem_1536, getitem_1537, getitem_1538, getitem_1539]); getitem_1532 = getitem_1533 = getitem_1534 = getitem_1535 = getitem_1536 = getitem_1537 = getitem_1538 = getitem_1539 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_146, 'sum', 8, '1'); cat_146 = None + wait_tensor_482 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + convert_element_type_1284 = torch.ops.prims.convert_element_type.default(wait_tensor_482, torch.float32); wait_tensor_482 = None + convert_element_type_1286 = torch.ops.prims.convert_element_type.default(wait_tensor_366, torch.float32); wait_tensor_366 = None + mul_338 = torch.ops.aten.mul.Tensor(convert_element_type_1284, convert_element_type_1286); convert_element_type_1286 = None + mul_340 = torch.ops.aten.mul.Tensor(mul_224, mul_338) + sum_25 = torch.ops.aten.sum.dim_IntList(mul_340, [2], True); mul_340 = None + div_8 = torch.ops.aten.div.Tensor(mul_224, 4096) + mul_341 = torch.ops.aten.mul.Tensor(div_8, sum_25); div_8 = sum_25 = None + sub_13 = torch.ops.aten.sub.Tensor(mul_338, mul_341); mul_338 = mul_341 = None + mul_342 = torch.ops.aten.mul.Tensor(sub_13, rsqrt_56); sub_13 = rsqrt_56 = None + mul_343 = torch.ops.aten.mul.Tensor(convert_element_type_1284, mul_224); convert_element_type_1284 = mul_224 = None + sum_26 = torch.ops.aten.sum.dim_IntList(mul_343, [0, 1]); mul_343 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(mul_342, torch.bfloat16); mul_342 = None + convert_element_type_1288 = torch.ops.prims.convert_element_type.default(sum_26, torch.bfloat16); sum_26 = None + all_reduce_8 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1288, 'sum', '1'); convert_element_type_1288 = None + wait_tensor_483 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_8); all_reduce_8 = None + convert_element_type_1289 = torch.ops.prims.convert_element_type.default(wait_tensor_483, torch.float32); wait_tensor_483 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1289, 'avg', 32, '0'); convert_element_type_1289 = None + wait_tensor_484 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + add_156 = torch.ops.aten.add.Tensor(add_153, convert_element_type_1287); add_153 = convert_element_type_1287 = None + all_gather_into_tensor_364 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_156, 8, '1') + wait_tensor_485 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_364); all_gather_into_tensor_364 = None + split_155 = torch.ops.aten.split.Tensor(wait_tensor_485, 2); wait_tensor_485 = None + getitem_1540 = split_155[0] + getitem_1541 = split_155[1] + getitem_1542 = split_155[2] + getitem_1543 = split_155[3] + getitem_1544 = split_155[4] + getitem_1545 = split_155[5] + getitem_1546 = split_155[6] + getitem_1547 = split_155[7]; split_155 = None + cat_147 = torch.ops.aten.cat.default([getitem_1540, getitem_1541, getitem_1542, getitem_1543, getitem_1544, getitem_1545, getitem_1546, getitem_1547], 1); getitem_1540 = getitem_1541 = getitem_1542 = getitem_1543 = getitem_1544 = getitem_1545 = getitem_1546 = getitem_1547 = None + view_2419 = torch.ops.aten.view.default(cat_147, [16384, 4096]); cat_147 = None + permute_485 = torch.ops.aten.permute.default(view_2419, [1, 0]) + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); reduce_scatter_tensor_55 = None + add_109 = torch.ops.aten.add.Tensor(add_107, wait_tensor_359); wait_tensor_359 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 32, '0'); convert_element_type_911 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32); add_109 = None + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_360) + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_913, 8, '1'); convert_element_type_913 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + split_119 = torch.ops.aten.split.Tensor(wait_tensor_361, 2); wait_tensor_361 = None + getitem_1204 = split_119[0] + getitem_1205 = split_119[1] + getitem_1206 = split_119[2] + getitem_1207 = split_119[3] + getitem_1208 = split_119[4] + getitem_1209 = split_119[5] + getitem_1210 = split_119[6] + getitem_1211 = split_119[7]; split_119 = None + cat_111 = torch.ops.aten.cat.default([getitem_1204, getitem_1205, getitem_1206, getitem_1207, getitem_1208, getitem_1209, getitem_1210, getitem_1211], 1); getitem_1204 = getitem_1205 = getitem_1206 = getitem_1207 = getitem_1208 = getitem_1209 = getitem_1210 = getitem_1211 = None + view_2004 = torch.ops.aten.view.default(cat_111, [16384, 4096]); cat_111 = None + view_2005 = torch.ops.aten.view.default(mm_193, [2, 8192, 1792]); mm_193 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_2005, torch.float32); view_2005 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16); primals_254 = None + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 32, '0'); convert_element_type_919 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_363, [1, 0]); wait_tensor_363 = None + mm_194 = torch.ops.aten.mm.default(view_2004, permute_306) + view_2012 = torch.ops.aten.view.default(mm_194, [2, 8192, 1792]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_2012) + view_2019 = torch.ops.aten.view.default(mul_223, [16384, 1792]); mul_223 = None + mm_283 = torch.ops.aten.mm.default(permute_485, view_2019); permute_485 = view_2019 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16); primals_255 = None + all_gather_into_tensor_308 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 32, '0'); convert_element_type_922 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_308); all_gather_into_tensor_308 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_364, [1, 0]); wait_tensor_364 = None + permute_487 = torch.ops.aten.permute.default(permute_307, [1, 0]); permute_307 = None + mm_284 = torch.ops.aten.mm.default(view_2419, permute_487); view_2419 = permute_487 = None + view_2420 = torch.ops.aten.view.default(mm_284, [2, 8192, 1792]); mm_284 = None + convert_element_type_1294 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1294, 'avg', 32, '0'); convert_element_type_1294 = None + wait_tensor_486 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + mul_344 = torch.ops.aten.mul.Tensor(view_2420, convert_element_type_918); convert_element_type_918 = None + mul_345 = torch.ops.aten.mul.Tensor(view_2420, view_2012); view_2420 = view_2012 = None + view_2421 = torch.ops.aten.view.default(mul_344, [16384, 1792]); mul_344 = None + permute_489 = torch.ops.aten.permute.default(view_2421, [1, 0]) + mm_285 = torch.ops.aten.mm.default(permute_489, view_2004); permute_489 = None + permute_491 = torch.ops.aten.permute.default(permute_306, [1, 0]); permute_306 = None + mm_286 = torch.ops.aten.mm.default(view_2421, permute_491); view_2421 = permute_491 = None + view_2422 = torch.ops.aten.view.default(mm_286, [2, 8192, 4096]); mm_286 = None + convert_element_type_1299 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1299, 'avg', 32, '0'); convert_element_type_1299 = None + wait_tensor_487 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + convert_element_type_1300 = torch.ops.prims.convert_element_type.default(mul_345, torch.float32); mul_345 = None + neg_4 = torch.ops.aten.neg.default(convert_element_type_917) + exp_4 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_157 = torch.ops.aten.add.Tensor(exp_4, 1); exp_4 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_157); add_157 = None + mul_346 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_347 = torch.ops.aten.mul.Tensor(convert_element_type_1300, mul_346); convert_element_type_1300 = None + sub_14 = torch.ops.aten.sub.Tensor(1, mul_346); mul_346 = None + mul_348 = torch.ops.aten.mul.Tensor(convert_element_type_917, sub_14); convert_element_type_917 = sub_14 = None + add_158 = torch.ops.aten.add.Tensor(mul_348, 1); mul_348 = None + mul_349 = torch.ops.aten.mul.Tensor(mul_347, add_158); mul_347 = add_158 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(mul_349, torch.bfloat16); mul_349 = None + view_2423 = torch.ops.aten.view.default(convert_element_type_1302, [16384, 1792]); convert_element_type_1302 = None + permute_493 = torch.ops.aten.permute.default(view_2423, [1, 0]) + mm_287 = torch.ops.aten.mm.default(permute_493, view_2004); permute_493 = view_2004 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16); primals_253 = None + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 32, '0'); convert_element_type_914 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_362, [1, 0]); wait_tensor_362 = None + permute_495 = torch.ops.aten.permute.default(permute_305, [1, 0]); permute_305 = None + mm_288 = torch.ops.aten.mm.default(view_2423, permute_495); view_2423 = permute_495 = None + view_2424 = torch.ops.aten.view.default(mm_288, [2, 8192, 4096]); mm_288 = None + add_159 = torch.ops.aten.add.Tensor(view_2422, view_2424); view_2422 = view_2424 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1307, 'avg', 32, '0'); convert_element_type_1307 = None + wait_tensor_488 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + split_156 = torch.ops.aten.split.Tensor(add_159, 1024, 1); add_159 = None + getitem_1548 = split_156[0] + getitem_1549 = split_156[1] + getitem_1550 = split_156[2] + getitem_1551 = split_156[3] + getitem_1552 = split_156[4] + getitem_1553 = split_156[5] + getitem_1554 = split_156[6] + getitem_1555 = split_156[7]; split_156 = None + cat_148 = torch.ops.aten.cat.default([getitem_1548, getitem_1549, getitem_1550, getitem_1551, getitem_1552, getitem_1553, getitem_1554, getitem_1555]); getitem_1548 = getitem_1549 = getitem_1550 = getitem_1551 = getitem_1552 = getitem_1553 = getitem_1554 = getitem_1555 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_148, 'sum', 8, '1'); cat_148 = None + wait_tensor_489 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + convert_element_type_1308 = torch.ops.prims.convert_element_type.default(wait_tensor_489, torch.float32); wait_tensor_489 = None + convert_element_type_1310 = torch.ops.prims.convert_element_type.default(wait_tensor_360, torch.float32); wait_tensor_360 = None + mul_350 = torch.ops.aten.mul.Tensor(convert_element_type_1308, convert_element_type_1310); convert_element_type_1310 = None + mul_352 = torch.ops.aten.mul.Tensor(mul_220, mul_350) + sum_27 = torch.ops.aten.sum.dim_IntList(mul_352, [2], True); mul_352 = None + div_9 = torch.ops.aten.div.Tensor(mul_220, 4096) + mul_353 = torch.ops.aten.mul.Tensor(div_9, sum_27); div_9 = sum_27 = None + sub_15 = torch.ops.aten.sub.Tensor(mul_350, mul_353); mul_350 = mul_353 = None + mul_354 = torch.ops.aten.mul.Tensor(sub_15, rsqrt_55); sub_15 = rsqrt_55 = None + mul_355 = torch.ops.aten.mul.Tensor(convert_element_type_1308, mul_220); convert_element_type_1308 = mul_220 = None + sum_28 = torch.ops.aten.sum.dim_IntList(mul_355, [0, 1]); mul_355 = None + convert_element_type_1311 = torch.ops.prims.convert_element_type.default(mul_354, torch.bfloat16); mul_354 = None + convert_element_type_1312 = torch.ops.prims.convert_element_type.default(sum_28, torch.bfloat16); sum_28 = None + all_reduce_9 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1312, 'sum', '1'); convert_element_type_1312 = None + wait_tensor_490 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_9); all_reduce_9 = None + convert_element_type_1313 = torch.ops.prims.convert_element_type.default(wait_tensor_490, torch.float32); wait_tensor_490 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1313, 'avg', 32, '0'); convert_element_type_1313 = None + wait_tensor_491 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + add_160 = torch.ops.aten.add.Tensor(add_156, convert_element_type_1311); add_156 = convert_element_type_1311 = None + all_gather_into_tensor_365 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_160, 8, '1') + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_365); all_gather_into_tensor_365 = None + split_157 = torch.ops.aten.split.Tensor(wait_tensor_492, 2); wait_tensor_492 = None + getitem_1556 = split_157[0] + getitem_1557 = split_157[1] + getitem_1558 = split_157[2] + getitem_1559 = split_157[3] + getitem_1560 = split_157[4] + getitem_1561 = split_157[5] + getitem_1562 = split_157[6] + getitem_1563 = split_157[7]; split_157 = None + cat_149 = torch.ops.aten.cat.default([getitem_1556, getitem_1557, getitem_1558, getitem_1559, getitem_1560, getitem_1561, getitem_1562, getitem_1563], 1); getitem_1556 = getitem_1557 = getitem_1558 = getitem_1559 = getitem_1560 = getitem_1561 = getitem_1562 = getitem_1563 = None + view_2425 = torch.ops.aten.view.default(cat_149, [16384, 4096]); cat_149 = None + permute_497 = torch.ops.aten.permute.default(view_2425, [1, 0]) + permute_303 = torch.ops.aten.permute.default(getitem_1187, [0, 2, 1, 3]) + view_1986 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + view_1992 = torch.ops.aten.view.default(view_1986, [16384, 512]); view_1986 = None + mm_289 = torch.ops.aten.mm.default(permute_497, view_1992); permute_497 = view_1992 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 32, '0'); convert_element_type_908 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_358, [1, 0]); wait_tensor_358 = None + permute_499 = torch.ops.aten.permute.default(permute_304, [1, 0]); permute_304 = None + mm_290 = torch.ops.aten.mm.default(view_2425, permute_499); view_2425 = permute_499 = None + view_2426 = torch.ops.aten.view.default(mm_290, [2, 8192, 512]); mm_290 = None + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1318, 'avg', 32, '0'); convert_element_type_1318 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + view_2427 = torch.ops.aten.view.default(view_2426, [2, 8192, 4, 128]); view_2426 = None + permute_501 = torch.ops.aten.permute.default(view_2427, [0, 2, 1, 3]); view_2427 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 32, '0'); convert_element_type_892 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32); add_107 = None + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_353) + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_894, 8, '1'); convert_element_type_894 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + split_117 = torch.ops.aten.split.Tensor(wait_tensor_354, 2); wait_tensor_354 = None + getitem_1179 = split_117[0] + getitem_1180 = split_117[1] + getitem_1181 = split_117[2] + getitem_1182 = split_117[3] + getitem_1183 = split_117[4] + getitem_1184 = split_117[5] + getitem_1185 = split_117[6] + getitem_1186 = split_117[7]; split_117 = None + cat_109 = torch.ops.aten.cat.default([getitem_1179, getitem_1180, getitem_1181, getitem_1182, getitem_1183, getitem_1184, getitem_1185, getitem_1186], 1); getitem_1179 = getitem_1180 = getitem_1181 = getitem_1182 = getitem_1183 = getitem_1184 = getitem_1185 = getitem_1186 = None + view_1959 = torch.ops.aten.view.default(cat_109, [16384, 4096]); cat_109 = None + view_1960 = torch.ops.aten.view.default(mm_189, [2, 8192, 512]); mm_189 = None + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 32, '0'); convert_element_type_898 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_356, [1, 0]); wait_tensor_356 = None + mm_190 = torch.ops.aten.mm.default(view_1959, permute_298) + view_1967 = torch.ops.aten.view.default(mm_190, [2, 8192, 128]); mm_190 = None + view_1974 = torch.ops.aten.view.default(mm_191, [2, 8192, 128]); mm_191 = None + view_1976 = torch.ops.aten.view.default(view_1960, [2, 8192, -1, 128]); view_1960 = None + view_1977 = torch.ops.aten.view.default(view_1967, [2, 8192, -1, 128]); view_1967 = None + view_1978 = torch.ops.aten.view.default(view_1974, [2, 8192, -1, 128]); view_1974 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_1976, torch.float32); view_1976 = None + view_1979 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 4, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1979); view_1979 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_1977, torch.float32); view_1977 = None + view_1980 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 1, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1980); view_1980 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_37); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_1982 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 4, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_37); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_1983 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 1, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_1982, torch.bfloat16); view_1982 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_1983, torch.bfloat16); view_1983 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 1, 4, 128]); unsqueeze_54 = None + view_1984 = torch.ops.aten.view.default(expand_54, [2, 8192, 4, 128]); expand_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_1978, 3); view_1978 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 1, 4, 128]); unsqueeze_55 = None + view_1985 = torch.ops.aten.view.default(expand_55, [2, 8192, 4, 128]); expand_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_1984, [0, 2, 1, 3]); view_1984 = None + permute_302 = torch.ops.aten.permute.default(view_1985, [0, 2, 1, 3]); view_1985 = None + _scaled_dot_product_cudnn_attention_backward_4 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_501, permute_300, permute_301, permute_302, getitem_1187, getitem_1188, getitem_1193, getitem_1194, None, None, None, 8192, 8192, 0.0, True); permute_501 = permute_300 = permute_301 = permute_302 = getitem_1187 = getitem_1188 = getitem_1193 = getitem_1194 = None + getitem_1564 = _scaled_dot_product_cudnn_attention_backward_4[0] + getitem_1565 = _scaled_dot_product_cudnn_attention_backward_4[1] + getitem_1566 = _scaled_dot_product_cudnn_attention_backward_4[2]; _scaled_dot_product_cudnn_attention_backward_4 = None + permute_502 = torch.ops.aten.permute.default(getitem_1566, [0, 2, 1, 3]); getitem_1566 = None + permute_503 = torch.ops.aten.permute.default(getitem_1565, [0, 2, 1, 3]); getitem_1565 = None + permute_504 = torch.ops.aten.permute.default(getitem_1564, [0, 2, 1, 3]); getitem_1564 = None + view_2428 = torch.ops.aten.view.default(permute_502, [2, 8192, 1, 4, 128]); permute_502 = None + sum_29 = torch.ops.aten.sum.dim_IntList(view_2428, [3], True); view_2428 = None + squeeze_8 = torch.ops.aten.squeeze.dim(sum_29, 3); sum_29 = None + view_2429 = torch.ops.aten.view.default(permute_503, [2, 8192, 1, 4, 128]); permute_503 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_2429, [3], True); view_2429 = None + squeeze_9 = torch.ops.aten.squeeze.dim(sum_30, 3); sum_30 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(squeeze_9, torch.float32); squeeze_9 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(permute_504, torch.float32); permute_504 = None + view_2430 = torch.ops.aten.view.default(convert_element_type_1319, [2, 8192, 1, 64, 2]); convert_element_type_1319 = None + view_as_complex_72 = torch.ops.aten.view_as_complex.default(view_2430); view_2430 = None + mul_356 = torch.ops.aten.mul.Tensor(view_as_complex_72, _conj); view_as_complex_72 = None + view_2431 = torch.ops.aten.view.default(convert_element_type_1320, [2, 8192, 4, 64, 2]); convert_element_type_1320 = None + view_as_complex_73 = torch.ops.aten.view_as_complex.default(view_2431); view_2431 = None + mul_357 = torch.ops.aten.mul.Tensor(view_as_complex_73, _conj); view_as_complex_73 = None + view_as_real_72 = torch.ops.aten.view_as_real.default(mul_356); mul_356 = None + view_2432 = torch.ops.aten.view.default(view_as_real_72, [2, 8192, 1, 128]); view_as_real_72 = None + convert_element_type_1321 = torch.ops.prims.convert_element_type.default(view_2432, torch.bfloat16); view_2432 = None + view_as_real_73 = torch.ops.aten.view_as_real.default(mul_357); mul_357 = None + view_2433 = torch.ops.aten.view.default(view_as_real_73, [2, 8192, 4, 128]); view_as_real_73 = None + convert_element_type_1322 = torch.ops.prims.convert_element_type.default(view_2433, torch.bfloat16); view_2433 = None + view_2434 = torch.ops.aten.view.default(squeeze_8, [2, 8192, 128]); squeeze_8 = None + view_2435 = torch.ops.aten.view.default(convert_element_type_1321, [2, 8192, 128]); convert_element_type_1321 = None + view_2436 = torch.ops.aten.view.default(convert_element_type_1322, [2, 8192, 512]); convert_element_type_1322 = None + view_2437 = torch.ops.aten.view.default(view_2434, [16384, 128]); view_2434 = None + permute_505 = torch.ops.aten.permute.default(view_2437, [1, 0]) + mm_291 = torch.ops.aten.mm.default(permute_505, view_1959); permute_505 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 32, '0'); convert_element_type_901 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_357, [1, 0]); wait_tensor_357 = None + permute_507 = torch.ops.aten.permute.default(permute_299, [1, 0]); permute_299 = None + mm_292 = torch.ops.aten.mm.default(view_2437, permute_507); view_2437 = permute_507 = None + view_2438 = torch.ops.aten.view.default(mm_292, [2, 8192, 4096]); mm_292 = None + convert_element_type_1327 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None + reduce_scatter_tensor_118 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1327, 'avg', 32, '0'); convert_element_type_1327 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + view_2439 = torch.ops.aten.view.default(view_2435, [16384, 128]); view_2435 = None + permute_509 = torch.ops.aten.permute.default(view_2439, [1, 0]) + mm_293 = torch.ops.aten.mm.default(permute_509, view_1959); permute_509 = None + permute_511 = torch.ops.aten.permute.default(permute_298, [1, 0]); permute_298 = None + mm_294 = torch.ops.aten.mm.default(view_2439, permute_511); view_2439 = permute_511 = None + view_2440 = torch.ops.aten.view.default(mm_294, [2, 8192, 4096]); mm_294 = None + add_161 = torch.ops.aten.add.Tensor(view_2438, view_2440); view_2438 = view_2440 = None + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1332, 'avg', 32, '0'); convert_element_type_1332 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + view_2441 = torch.ops.aten.view.default(view_2436, [16384, 512]); view_2436 = None + permute_513 = torch.ops.aten.permute.default(view_2441, [1, 0]) + mm_295 = torch.ops.aten.mm.default(permute_513, view_1959); permute_513 = view_1959 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 32, '0'); convert_element_type_895 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_355, [1, 0]); wait_tensor_355 = None + permute_515 = torch.ops.aten.permute.default(permute_297, [1, 0]); permute_297 = None + mm_296 = torch.ops.aten.mm.default(view_2441, permute_515); view_2441 = permute_515 = None + view_2442 = torch.ops.aten.view.default(mm_296, [2, 8192, 4096]); mm_296 = None + add_162 = torch.ops.aten.add.Tensor(add_161, view_2442); add_161 = view_2442 = None + convert_element_type_1337 = torch.ops.prims.convert_element_type.default(mm_295, torch.float32); mm_295 = None + reduce_scatter_tensor_120 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1337, 'avg', 32, '0'); convert_element_type_1337 = None + wait_tensor_496 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + split_158 = torch.ops.aten.split.Tensor(add_162, 1024, 1); add_162 = None + getitem_1567 = split_158[0] + getitem_1568 = split_158[1] + getitem_1569 = split_158[2] + getitem_1570 = split_158[3] + getitem_1571 = split_158[4] + getitem_1572 = split_158[5] + getitem_1573 = split_158[6] + getitem_1574 = split_158[7]; split_158 = None + cat_150 = torch.ops.aten.cat.default([getitem_1567, getitem_1568, getitem_1569, getitem_1570, getitem_1571, getitem_1572, getitem_1573, getitem_1574]); getitem_1567 = getitem_1568 = getitem_1569 = getitem_1570 = getitem_1571 = getitem_1572 = getitem_1573 = getitem_1574 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_150, 'sum', 8, '1'); cat_150 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + convert_element_type_1338 = torch.ops.prims.convert_element_type.default(wait_tensor_497, torch.float32); wait_tensor_497 = None + convert_element_type_1340 = torch.ops.prims.convert_element_type.default(wait_tensor_353, torch.float32); wait_tensor_353 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_1338, convert_element_type_1340); convert_element_type_1340 = None + mul_360 = torch.ops.aten.mul.Tensor(mul_216, mul_358) + sum_31 = torch.ops.aten.sum.dim_IntList(mul_360, [2], True); mul_360 = None + div_10 = torch.ops.aten.div.Tensor(mul_216, 4096) + mul_361 = torch.ops.aten.mul.Tensor(div_10, sum_31); div_10 = sum_31 = None + sub_16 = torch.ops.aten.sub.Tensor(mul_358, mul_361); mul_358 = mul_361 = None + mul_362 = torch.ops.aten.mul.Tensor(sub_16, rsqrt_54); sub_16 = rsqrt_54 = None + mul_363 = torch.ops.aten.mul.Tensor(convert_element_type_1338, mul_216); convert_element_type_1338 = mul_216 = None + sum_32 = torch.ops.aten.sum.dim_IntList(mul_363, [0, 1]); mul_363 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(mul_362, torch.bfloat16); mul_362 = None + convert_element_type_1342 = torch.ops.prims.convert_element_type.default(sum_32, torch.bfloat16); sum_32 = None + all_reduce_10 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1342, 'sum', '1'); convert_element_type_1342 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_10); all_reduce_10 = None + convert_element_type_1343 = torch.ops.prims.convert_element_type.default(wait_tensor_498, torch.float32); wait_tensor_498 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1343, 'avg', 32, '0'); convert_element_type_1343 = None + wait_tensor_499 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + add_163 = torch.ops.aten.add.Tensor(add_160, convert_element_type_1341); add_160 = convert_element_type_1341 = None + all_gather_into_tensor_366 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_163, 8, '1') + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_366); all_gather_into_tensor_366 = None + split_159 = torch.ops.aten.split.Tensor(wait_tensor_500, 2); wait_tensor_500 = None + getitem_1575 = split_159[0] + getitem_1576 = split_159[1] + getitem_1577 = split_159[2] + getitem_1578 = split_159[3] + getitem_1579 = split_159[4] + getitem_1580 = split_159[5] + getitem_1581 = split_159[6] + getitem_1582 = split_159[7]; split_159 = None + cat_151 = torch.ops.aten.cat.default([getitem_1575, getitem_1576, getitem_1577, getitem_1578, getitem_1579, getitem_1580, getitem_1581, getitem_1582], 1); getitem_1575 = getitem_1576 = getitem_1577 = getitem_1578 = getitem_1579 = getitem_1580 = getitem_1581 = getitem_1582 = None + view_2443 = torch.ops.aten.view.default(cat_151, [16384, 4096]); cat_151 = None + permute_517 = torch.ops.aten.permute.default(view_2443, [1, 0]) + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + add_105 = torch.ops.aten.add.Tensor(add_103, wait_tensor_346); wait_tensor_346 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16); primals_243 = None + all_gather_into_tensor_293 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 32, '0'); convert_element_type_878 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_293); all_gather_into_tensor_293 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32); add_105 = None + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_347) + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + all_gather_into_tensor_294 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_880, 8, '1'); convert_element_type_880 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_294); all_gather_into_tensor_294 = None + split_115 = torch.ops.aten.split.Tensor(wait_tensor_348, 2); wait_tensor_348 = None + getitem_1163 = split_115[0] + getitem_1164 = split_115[1] + getitem_1165 = split_115[2] + getitem_1166 = split_115[3] + getitem_1167 = split_115[4] + getitem_1168 = split_115[5] + getitem_1169 = split_115[6] + getitem_1170 = split_115[7]; split_115 = None + cat_107 = torch.ops.aten.cat.default([getitem_1163, getitem_1164, getitem_1165, getitem_1166, getitem_1167, getitem_1168, getitem_1169, getitem_1170], 1); getitem_1163 = getitem_1164 = getitem_1165 = getitem_1166 = getitem_1167 = getitem_1168 = getitem_1169 = getitem_1170 = None + view_1932 = torch.ops.aten.view.default(cat_107, [16384, 4096]); cat_107 = None + view_1933 = torch.ops.aten.view.default(mm_186, [2, 8192, 1792]); mm_186 = None + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_1933, torch.float32); view_1933 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None + all_gather_into_tensor_296 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 32, '0'); convert_element_type_886 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_296); all_gather_into_tensor_296 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_350, [1, 0]); wait_tensor_350 = None + mm_187 = torch.ops.aten.mm.default(view_1932, permute_295) + view_1940 = torch.ops.aten.view.default(mm_187, [2, 8192, 1792]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_1940) + view_1947 = torch.ops.aten.view.default(mul_215, [16384, 1792]); mul_215 = None + mm_297 = torch.ops.aten.mm.default(permute_517, view_1947); permute_517 = view_1947 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 32, '0'); convert_element_type_889 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + permute_519 = torch.ops.aten.permute.default(permute_296, [1, 0]); permute_296 = None + mm_298 = torch.ops.aten.mm.default(view_2443, permute_519); view_2443 = permute_519 = None + view_2444 = torch.ops.aten.view.default(mm_298, [2, 8192, 1792]); mm_298 = None + convert_element_type_1348 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1348, 'avg', 32, '0'); convert_element_type_1348 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + mul_364 = torch.ops.aten.mul.Tensor(view_2444, convert_element_type_885); convert_element_type_885 = None + mul_365 = torch.ops.aten.mul.Tensor(view_2444, view_1940); view_2444 = view_1940 = None + view_2445 = torch.ops.aten.view.default(mul_364, [16384, 1792]); mul_364 = None + permute_521 = torch.ops.aten.permute.default(view_2445, [1, 0]) + mm_299 = torch.ops.aten.mm.default(permute_521, view_1932); permute_521 = None + permute_523 = torch.ops.aten.permute.default(permute_295, [1, 0]); permute_295 = None + mm_300 = torch.ops.aten.mm.default(view_2445, permute_523); view_2445 = permute_523 = None + view_2446 = torch.ops.aten.view.default(mm_300, [2, 8192, 4096]); mm_300 = None + convert_element_type_1353 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1353, 'avg', 32, '0'); convert_element_type_1353 = None + wait_tensor_502 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + convert_element_type_1354 = torch.ops.prims.convert_element_type.default(mul_365, torch.float32); mul_365 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_884) + exp_5 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_164 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_164); add_164 = None + mul_366 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_367 = torch.ops.aten.mul.Tensor(convert_element_type_1354, mul_366); convert_element_type_1354 = None + sub_17 = torch.ops.aten.sub.Tensor(1, mul_366); mul_366 = None + mul_368 = torch.ops.aten.mul.Tensor(convert_element_type_884, sub_17); convert_element_type_884 = sub_17 = None + add_165 = torch.ops.aten.add.Tensor(mul_368, 1); mul_368 = None + mul_369 = torch.ops.aten.mul.Tensor(mul_367, add_165); mul_367 = add_165 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(mul_369, torch.bfloat16); mul_369 = None + view_2447 = torch.ops.aten.view.default(convert_element_type_1356, [16384, 1792]); convert_element_type_1356 = None + permute_525 = torch.ops.aten.permute.default(view_2447, [1, 0]) + mm_301 = torch.ops.aten.mm.default(permute_525, view_1932); permute_525 = view_1932 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None + all_gather_into_tensor_295 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 32, '0'); convert_element_type_881 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_295); all_gather_into_tensor_295 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_349, [1, 0]); wait_tensor_349 = None + permute_527 = torch.ops.aten.permute.default(permute_294, [1, 0]); permute_294 = None + mm_302 = torch.ops.aten.mm.default(view_2447, permute_527); view_2447 = permute_527 = None + view_2448 = torch.ops.aten.view.default(mm_302, [2, 8192, 4096]); mm_302 = None + add_166 = torch.ops.aten.add.Tensor(view_2446, view_2448); view_2446 = view_2448 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1361, 'avg', 32, '0'); convert_element_type_1361 = None + wait_tensor_503 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + split_160 = torch.ops.aten.split.Tensor(add_166, 1024, 1); add_166 = None + getitem_1583 = split_160[0] + getitem_1584 = split_160[1] + getitem_1585 = split_160[2] + getitem_1586 = split_160[3] + getitem_1587 = split_160[4] + getitem_1588 = split_160[5] + getitem_1589 = split_160[6] + getitem_1590 = split_160[7]; split_160 = None + cat_152 = torch.ops.aten.cat.default([getitem_1583, getitem_1584, getitem_1585, getitem_1586, getitem_1587, getitem_1588, getitem_1589, getitem_1590]); getitem_1583 = getitem_1584 = getitem_1585 = getitem_1586 = getitem_1587 = getitem_1588 = getitem_1589 = getitem_1590 = None + reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_152, 'sum', 8, '1'); cat_152 = None + wait_tensor_504 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + convert_element_type_1362 = torch.ops.prims.convert_element_type.default(wait_tensor_504, torch.float32); wait_tensor_504 = None + convert_element_type_1364 = torch.ops.prims.convert_element_type.default(wait_tensor_347, torch.float32); wait_tensor_347 = None + mul_370 = torch.ops.aten.mul.Tensor(convert_element_type_1362, convert_element_type_1364); convert_element_type_1364 = None + mul_372 = torch.ops.aten.mul.Tensor(mul_212, mul_370) + sum_33 = torch.ops.aten.sum.dim_IntList(mul_372, [2], True); mul_372 = None + div_11 = torch.ops.aten.div.Tensor(mul_212, 4096) + mul_373 = torch.ops.aten.mul.Tensor(div_11, sum_33); div_11 = sum_33 = None + sub_18 = torch.ops.aten.sub.Tensor(mul_370, mul_373); mul_370 = mul_373 = None + mul_374 = torch.ops.aten.mul.Tensor(sub_18, rsqrt_53); sub_18 = rsqrt_53 = None + mul_375 = torch.ops.aten.mul.Tensor(convert_element_type_1362, mul_212); convert_element_type_1362 = mul_212 = None + sum_34 = torch.ops.aten.sum.dim_IntList(mul_375, [0, 1]); mul_375 = None + convert_element_type_1365 = torch.ops.prims.convert_element_type.default(mul_374, torch.bfloat16); mul_374 = None + convert_element_type_1366 = torch.ops.prims.convert_element_type.default(sum_34, torch.bfloat16); sum_34 = None + all_reduce_11 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1366, 'sum', '1'); convert_element_type_1366 = None + wait_tensor_505 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_11); all_reduce_11 = None + convert_element_type_1367 = torch.ops.prims.convert_element_type.default(wait_tensor_505, torch.float32); wait_tensor_505 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1367, 'avg', 32, '0'); convert_element_type_1367 = None + wait_tensor_506 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); reduce_scatter_tensor_127 = None + add_167 = torch.ops.aten.add.Tensor(add_163, convert_element_type_1365); add_163 = convert_element_type_1365 = None + all_gather_into_tensor_367 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_167, 8, '1') + wait_tensor_507 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_367); all_gather_into_tensor_367 = None + split_161 = torch.ops.aten.split.Tensor(wait_tensor_507, 2); wait_tensor_507 = None + getitem_1591 = split_161[0] + getitem_1592 = split_161[1] + getitem_1593 = split_161[2] + getitem_1594 = split_161[3] + getitem_1595 = split_161[4] + getitem_1596 = split_161[5] + getitem_1597 = split_161[6] + getitem_1598 = split_161[7]; split_161 = None + cat_153 = torch.ops.aten.cat.default([getitem_1591, getitem_1592, getitem_1593, getitem_1594, getitem_1595, getitem_1596, getitem_1597, getitem_1598], 1); getitem_1591 = getitem_1592 = getitem_1593 = getitem_1594 = getitem_1595 = getitem_1596 = getitem_1597 = getitem_1598 = None + view_2449 = torch.ops.aten.view.default(cat_153, [16384, 4096]); cat_153 = None + permute_529 = torch.ops.aten.permute.default(view_2449, [1, 0]) + permute_292 = torch.ops.aten.permute.default(getitem_1146, [0, 2, 1, 3]) + view_1914 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + view_1920 = torch.ops.aten.view.default(view_1914, [16384, 512]); view_1914 = None + mm_303 = torch.ops.aten.mm.default(permute_529, view_1920); permute_529 = view_1920 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16); primals_242 = None + all_gather_into_tensor_292 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 32, '0'); convert_element_type_875 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_292); all_gather_into_tensor_292 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); wait_tensor_345 = None + permute_531 = torch.ops.aten.permute.default(permute_293, [1, 0]); permute_293 = None + mm_304 = torch.ops.aten.mm.default(view_2449, permute_531); view_2449 = permute_531 = None + view_2450 = torch.ops.aten.view.default(mm_304, [2, 8192, 512]); mm_304 = None + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1372, 'avg', 32, '0'); convert_element_type_1372 = None + wait_tensor_508 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + view_2451 = torch.ops.aten.view.default(view_2450, [2, 8192, 4, 128]); view_2450 = None + permute_533 = torch.ops.aten.permute.default(view_2451, [0, 2, 1, 3]); view_2451 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16); primals_238 = None + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 32, '0'); convert_element_type_859 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32); add_103 = None + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_340) + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_861, 8, '1'); convert_element_type_861 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + split_113 = torch.ops.aten.split.Tensor(wait_tensor_341, 2); wait_tensor_341 = None + getitem_1138 = split_113[0] + getitem_1139 = split_113[1] + getitem_1140 = split_113[2] + getitem_1141 = split_113[3] + getitem_1142 = split_113[4] + getitem_1143 = split_113[5] + getitem_1144 = split_113[6] + getitem_1145 = split_113[7]; split_113 = None + cat_105 = torch.ops.aten.cat.default([getitem_1138, getitem_1139, getitem_1140, getitem_1141, getitem_1142, getitem_1143, getitem_1144, getitem_1145], 1); getitem_1138 = getitem_1139 = getitem_1140 = getitem_1141 = getitem_1142 = getitem_1143 = getitem_1144 = getitem_1145 = None + view_1887 = torch.ops.aten.view.default(cat_105, [16384, 4096]); cat_105 = None + view_1888 = torch.ops.aten.view.default(mm_182, [2, 8192, 512]); mm_182 = None + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16); primals_240 = None + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 32, '0'); convert_element_type_865 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_343, [1, 0]); wait_tensor_343 = None + mm_183 = torch.ops.aten.mm.default(view_1887, permute_287) + view_1895 = torch.ops.aten.view.default(mm_183, [2, 8192, 128]); mm_183 = None + view_1902 = torch.ops.aten.view.default(mm_184, [2, 8192, 128]); mm_184 = None + view_1904 = torch.ops.aten.view.default(view_1888, [2, 8192, -1, 128]); view_1888 = None + view_1905 = torch.ops.aten.view.default(view_1895, [2, 8192, -1, 128]); view_1895 = None + view_1906 = torch.ops.aten.view.default(view_1902, [2, 8192, -1, 128]); view_1902 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_1904, torch.float32); view_1904 = None + view_1907 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 4, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_1907); view_1907 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_1905, torch.float32); view_1905 = None + view_1908 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 1, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1908); view_1908 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_37); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_1910 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 4, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_37); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_1911 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 1, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_1910, torch.bfloat16); view_1910 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_1911, torch.bfloat16); view_1911 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 1, 4, 128]); unsqueeze_52 = None + view_1912 = torch.ops.aten.view.default(expand_52, [2, 8192, 4, 128]); expand_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_1906, 3); view_1906 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 1, 4, 128]); unsqueeze_53 = None + view_1913 = torch.ops.aten.view.default(expand_53, [2, 8192, 4, 128]); expand_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_1912, [0, 2, 1, 3]); view_1912 = None + permute_291 = torch.ops.aten.permute.default(view_1913, [0, 2, 1, 3]); view_1913 = None + _scaled_dot_product_cudnn_attention_backward_5 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_533, permute_289, permute_290, permute_291, getitem_1146, getitem_1147, getitem_1152, getitem_1153, None, None, None, 8192, 8192, 0.0, True); permute_533 = permute_289 = permute_290 = permute_291 = getitem_1146 = getitem_1147 = getitem_1152 = getitem_1153 = None + getitem_1599 = _scaled_dot_product_cudnn_attention_backward_5[0] + getitem_1600 = _scaled_dot_product_cudnn_attention_backward_5[1] + getitem_1601 = _scaled_dot_product_cudnn_attention_backward_5[2]; _scaled_dot_product_cudnn_attention_backward_5 = None + permute_534 = torch.ops.aten.permute.default(getitem_1601, [0, 2, 1, 3]); getitem_1601 = None + permute_535 = torch.ops.aten.permute.default(getitem_1600, [0, 2, 1, 3]); getitem_1600 = None + permute_536 = torch.ops.aten.permute.default(getitem_1599, [0, 2, 1, 3]); getitem_1599 = None + view_2452 = torch.ops.aten.view.default(permute_534, [2, 8192, 1, 4, 128]); permute_534 = None + sum_35 = torch.ops.aten.sum.dim_IntList(view_2452, [3], True); view_2452 = None + squeeze_10 = torch.ops.aten.squeeze.dim(sum_35, 3); sum_35 = None + view_2453 = torch.ops.aten.view.default(permute_535, [2, 8192, 1, 4, 128]); permute_535 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_2453, [3], True); view_2453 = None + squeeze_11 = torch.ops.aten.squeeze.dim(sum_36, 3); sum_36 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(squeeze_11, torch.float32); squeeze_11 = None + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(permute_536, torch.float32); permute_536 = None + view_2454 = torch.ops.aten.view.default(convert_element_type_1373, [2, 8192, 1, 64, 2]); convert_element_type_1373 = None + view_as_complex_74 = torch.ops.aten.view_as_complex.default(view_2454); view_2454 = None + mul_376 = torch.ops.aten.mul.Tensor(view_as_complex_74, _conj); view_as_complex_74 = None + view_2455 = torch.ops.aten.view.default(convert_element_type_1374, [2, 8192, 4, 64, 2]); convert_element_type_1374 = None + view_as_complex_75 = torch.ops.aten.view_as_complex.default(view_2455); view_2455 = None + mul_377 = torch.ops.aten.mul.Tensor(view_as_complex_75, _conj); view_as_complex_75 = None + view_as_real_74 = torch.ops.aten.view_as_real.default(mul_376); mul_376 = None + view_2456 = torch.ops.aten.view.default(view_as_real_74, [2, 8192, 1, 128]); view_as_real_74 = None + convert_element_type_1375 = torch.ops.prims.convert_element_type.default(view_2456, torch.bfloat16); view_2456 = None + view_as_real_75 = torch.ops.aten.view_as_real.default(mul_377); mul_377 = None + view_2457 = torch.ops.aten.view.default(view_as_real_75, [2, 8192, 4, 128]); view_as_real_75 = None + convert_element_type_1376 = torch.ops.prims.convert_element_type.default(view_2457, torch.bfloat16); view_2457 = None + view_2458 = torch.ops.aten.view.default(squeeze_10, [2, 8192, 128]); squeeze_10 = None + view_2459 = torch.ops.aten.view.default(convert_element_type_1375, [2, 8192, 128]); convert_element_type_1375 = None + view_2460 = torch.ops.aten.view.default(convert_element_type_1376, [2, 8192, 512]); convert_element_type_1376 = None + view_2461 = torch.ops.aten.view.default(view_2458, [16384, 128]); view_2458 = None + permute_537 = torch.ops.aten.permute.default(view_2461, [1, 0]) + mm_305 = torch.ops.aten.mm.default(permute_537, view_1887); permute_537 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16); primals_241 = None + all_gather_into_tensor_291 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 32, '0'); convert_element_type_868 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_291); all_gather_into_tensor_291 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + permute_539 = torch.ops.aten.permute.default(permute_288, [1, 0]); permute_288 = None + mm_306 = torch.ops.aten.mm.default(view_2461, permute_539); view_2461 = permute_539 = None + view_2462 = torch.ops.aten.view.default(mm_306, [2, 8192, 4096]); mm_306 = None + convert_element_type_1381 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1381, 'avg', 32, '0'); convert_element_type_1381 = None + wait_tensor_509 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + view_2463 = torch.ops.aten.view.default(view_2459, [16384, 128]); view_2459 = None + permute_541 = torch.ops.aten.permute.default(view_2463, [1, 0]) + mm_307 = torch.ops.aten.mm.default(permute_541, view_1887); permute_541 = None + permute_543 = torch.ops.aten.permute.default(permute_287, [1, 0]); permute_287 = None + mm_308 = torch.ops.aten.mm.default(view_2463, permute_543); view_2463 = permute_543 = None + view_2464 = torch.ops.aten.view.default(mm_308, [2, 8192, 4096]); mm_308 = None + add_168 = torch.ops.aten.add.Tensor(view_2462, view_2464); view_2462 = view_2464 = None + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(mm_307, torch.float32); mm_307 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1386, 'avg', 32, '0'); convert_element_type_1386 = None + wait_tensor_510 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + view_2465 = torch.ops.aten.view.default(view_2460, [16384, 512]); view_2460 = None + permute_545 = torch.ops.aten.permute.default(view_2465, [1, 0]) + mm_309 = torch.ops.aten.mm.default(permute_545, view_1887); permute_545 = view_1887 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16); primals_239 = None + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 32, '0'); convert_element_type_862 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_342, [1, 0]); wait_tensor_342 = None + permute_547 = torch.ops.aten.permute.default(permute_286, [1, 0]); permute_286 = None + mm_310 = torch.ops.aten.mm.default(view_2465, permute_547); view_2465 = permute_547 = None + view_2466 = torch.ops.aten.view.default(mm_310, [2, 8192, 4096]); mm_310 = None + add_169 = torch.ops.aten.add.Tensor(add_168, view_2466); add_168 = view_2466 = None + convert_element_type_1391 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1391, 'avg', 32, '0'); convert_element_type_1391 = None + wait_tensor_511 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + split_162 = torch.ops.aten.split.Tensor(add_169, 1024, 1); add_169 = None + getitem_1602 = split_162[0] + getitem_1603 = split_162[1] + getitem_1604 = split_162[2] + getitem_1605 = split_162[3] + getitem_1606 = split_162[4] + getitem_1607 = split_162[5] + getitem_1608 = split_162[6] + getitem_1609 = split_162[7]; split_162 = None + cat_154 = torch.ops.aten.cat.default([getitem_1602, getitem_1603, getitem_1604, getitem_1605, getitem_1606, getitem_1607, getitem_1608, getitem_1609]); getitem_1602 = getitem_1603 = getitem_1604 = getitem_1605 = getitem_1606 = getitem_1607 = getitem_1608 = getitem_1609 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_154, 'sum', 8, '1'); cat_154 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + convert_element_type_1392 = torch.ops.prims.convert_element_type.default(wait_tensor_512, torch.float32); wait_tensor_512 = None + convert_element_type_1394 = torch.ops.prims.convert_element_type.default(wait_tensor_340, torch.float32); wait_tensor_340 = None + mul_378 = torch.ops.aten.mul.Tensor(convert_element_type_1392, convert_element_type_1394); convert_element_type_1394 = None + mul_380 = torch.ops.aten.mul.Tensor(mul_208, mul_378) + sum_37 = torch.ops.aten.sum.dim_IntList(mul_380, [2], True); mul_380 = None + div_12 = torch.ops.aten.div.Tensor(mul_208, 4096) + mul_381 = torch.ops.aten.mul.Tensor(div_12, sum_37); div_12 = sum_37 = None + sub_19 = torch.ops.aten.sub.Tensor(mul_378, mul_381); mul_378 = mul_381 = None + mul_382 = torch.ops.aten.mul.Tensor(sub_19, rsqrt_52); sub_19 = rsqrt_52 = None + mul_383 = torch.ops.aten.mul.Tensor(convert_element_type_1392, mul_208); convert_element_type_1392 = mul_208 = None + sum_38 = torch.ops.aten.sum.dim_IntList(mul_383, [0, 1]); mul_383 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(mul_382, torch.bfloat16); mul_382 = None + convert_element_type_1396 = torch.ops.prims.convert_element_type.default(sum_38, torch.bfloat16); sum_38 = None + all_reduce_12 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1396, 'sum', '1'); convert_element_type_1396 = None + wait_tensor_513 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_12); all_reduce_12 = None + convert_element_type_1397 = torch.ops.prims.convert_element_type.default(wait_tensor_513, torch.float32); wait_tensor_513 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1397, 'avg', 32, '0'); convert_element_type_1397 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + add_170 = torch.ops.aten.add.Tensor(add_167, convert_element_type_1395); add_167 = convert_element_type_1395 = None + all_gather_into_tensor_368 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_170, 8, '1') + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_368); all_gather_into_tensor_368 = None + split_163 = torch.ops.aten.split.Tensor(wait_tensor_515, 2); wait_tensor_515 = None + getitem_1610 = split_163[0] + getitem_1611 = split_163[1] + getitem_1612 = split_163[2] + getitem_1613 = split_163[3] + getitem_1614 = split_163[4] + getitem_1615 = split_163[5] + getitem_1616 = split_163[6] + getitem_1617 = split_163[7]; split_163 = None + cat_155 = torch.ops.aten.cat.default([getitem_1610, getitem_1611, getitem_1612, getitem_1613, getitem_1614, getitem_1615, getitem_1616, getitem_1617], 1); getitem_1610 = getitem_1611 = getitem_1612 = getitem_1613 = getitem_1614 = getitem_1615 = getitem_1616 = getitem_1617 = None + view_2467 = torch.ops.aten.view.default(cat_155, [16384, 4096]); cat_155 = None + permute_549 = torch.ops.aten.permute.default(view_2467, [1, 0]) + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + add_101 = torch.ops.aten.add.Tensor(add_99, wait_tensor_333); wait_tensor_333 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 32, '0'); convert_element_type_845 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32); add_101 = None + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_334) + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 8, '1'); convert_element_type_847 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + split_111 = torch.ops.aten.split.Tensor(wait_tensor_335, 2); wait_tensor_335 = None + getitem_1122 = split_111[0] + getitem_1123 = split_111[1] + getitem_1124 = split_111[2] + getitem_1125 = split_111[3] + getitem_1126 = split_111[4] + getitem_1127 = split_111[5] + getitem_1128 = split_111[6] + getitem_1129 = split_111[7]; split_111 = None + cat_103 = torch.ops.aten.cat.default([getitem_1122, getitem_1123, getitem_1124, getitem_1125, getitem_1126, getitem_1127, getitem_1128, getitem_1129], 1); getitem_1122 = getitem_1123 = getitem_1124 = getitem_1125 = getitem_1126 = getitem_1127 = getitem_1128 = getitem_1129 = None + view_1860 = torch.ops.aten.view.default(cat_103, [16384, 4096]); cat_103 = None + view_1861 = torch.ops.aten.view.default(mm_179, [2, 8192, 1792]); mm_179 = None + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_1861, torch.float32); view_1861 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16); primals_236 = None + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 32, '0'); convert_element_type_853 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_337, [1, 0]); wait_tensor_337 = None + mm_180 = torch.ops.aten.mm.default(view_1860, permute_284) + view_1868 = torch.ops.aten.view.default(mm_180, [2, 8192, 1792]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_1868) + view_1875 = torch.ops.aten.view.default(mul_207, [16384, 1792]); mul_207 = None + mm_311 = torch.ops.aten.mm.default(permute_549, view_1875); permute_549 = view_1875 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16); primals_237 = None + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 32, '0'); convert_element_type_856 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_338, [1, 0]); wait_tensor_338 = None + permute_551 = torch.ops.aten.permute.default(permute_285, [1, 0]); permute_285 = None + mm_312 = torch.ops.aten.mm.default(view_2467, permute_551); view_2467 = permute_551 = None + view_2468 = torch.ops.aten.view.default(mm_312, [2, 8192, 1792]); mm_312 = None + convert_element_type_1402 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1402, 'avg', 32, '0'); convert_element_type_1402 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + mul_384 = torch.ops.aten.mul.Tensor(view_2468, convert_element_type_852); convert_element_type_852 = None + mul_385 = torch.ops.aten.mul.Tensor(view_2468, view_1868); view_2468 = view_1868 = None + view_2469 = torch.ops.aten.view.default(mul_384, [16384, 1792]); mul_384 = None + permute_553 = torch.ops.aten.permute.default(view_2469, [1, 0]) + mm_313 = torch.ops.aten.mm.default(permute_553, view_1860); permute_553 = None + permute_555 = torch.ops.aten.permute.default(permute_284, [1, 0]); permute_284 = None + mm_314 = torch.ops.aten.mm.default(view_2469, permute_555); view_2469 = permute_555 = None + view_2470 = torch.ops.aten.view.default(mm_314, [2, 8192, 4096]); mm_314 = None + convert_element_type_1407 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1407, 'avg', 32, '0'); convert_element_type_1407 = None + wait_tensor_517 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + convert_element_type_1408 = torch.ops.prims.convert_element_type.default(mul_385, torch.float32); mul_385 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_851) + exp_6 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_171 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_171); add_171 = None + mul_386 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); reciprocal_6 = None + mul_387 = torch.ops.aten.mul.Tensor(convert_element_type_1408, mul_386); convert_element_type_1408 = None + sub_20 = torch.ops.aten.sub.Tensor(1, mul_386); mul_386 = None + mul_388 = torch.ops.aten.mul.Tensor(convert_element_type_851, sub_20); convert_element_type_851 = sub_20 = None + add_172 = torch.ops.aten.add.Tensor(mul_388, 1); mul_388 = None + mul_389 = torch.ops.aten.mul.Tensor(mul_387, add_172); mul_387 = add_172 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(mul_389, torch.bfloat16); mul_389 = None + view_2471 = torch.ops.aten.view.default(convert_element_type_1410, [16384, 1792]); convert_element_type_1410 = None + permute_557 = torch.ops.aten.permute.default(view_2471, [1, 0]) + mm_315 = torch.ops.aten.mm.default(permute_557, view_1860); permute_557 = view_1860 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16); primals_235 = None + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 32, '0'); convert_element_type_848 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_336, [1, 0]); wait_tensor_336 = None + permute_559 = torch.ops.aten.permute.default(permute_283, [1, 0]); permute_283 = None + mm_316 = torch.ops.aten.mm.default(view_2471, permute_559); view_2471 = permute_559 = None + view_2472 = torch.ops.aten.view.default(mm_316, [2, 8192, 4096]); mm_316 = None + add_173 = torch.ops.aten.add.Tensor(view_2470, view_2472); view_2470 = view_2472 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None + reduce_scatter_tensor_136 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1415, 'avg', 32, '0'); convert_element_type_1415 = None + wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + split_164 = torch.ops.aten.split.Tensor(add_173, 1024, 1); add_173 = None + getitem_1618 = split_164[0] + getitem_1619 = split_164[1] + getitem_1620 = split_164[2] + getitem_1621 = split_164[3] + getitem_1622 = split_164[4] + getitem_1623 = split_164[5] + getitem_1624 = split_164[6] + getitem_1625 = split_164[7]; split_164 = None + cat_156 = torch.ops.aten.cat.default([getitem_1618, getitem_1619, getitem_1620, getitem_1621, getitem_1622, getitem_1623, getitem_1624, getitem_1625]); getitem_1618 = getitem_1619 = getitem_1620 = getitem_1621 = getitem_1622 = getitem_1623 = getitem_1624 = getitem_1625 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_156, 'sum', 8, '1'); cat_156 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + convert_element_type_1416 = torch.ops.prims.convert_element_type.default(wait_tensor_519, torch.float32); wait_tensor_519 = None + convert_element_type_1418 = torch.ops.prims.convert_element_type.default(wait_tensor_334, torch.float32); wait_tensor_334 = None + mul_390 = torch.ops.aten.mul.Tensor(convert_element_type_1416, convert_element_type_1418); convert_element_type_1418 = None + mul_392 = torch.ops.aten.mul.Tensor(mul_204, mul_390) + sum_39 = torch.ops.aten.sum.dim_IntList(mul_392, [2], True); mul_392 = None + div_13 = torch.ops.aten.div.Tensor(mul_204, 4096) + mul_393 = torch.ops.aten.mul.Tensor(div_13, sum_39); div_13 = sum_39 = None + sub_21 = torch.ops.aten.sub.Tensor(mul_390, mul_393); mul_390 = mul_393 = None + mul_394 = torch.ops.aten.mul.Tensor(sub_21, rsqrt_51); sub_21 = rsqrt_51 = None + mul_395 = torch.ops.aten.mul.Tensor(convert_element_type_1416, mul_204); convert_element_type_1416 = mul_204 = None + sum_40 = torch.ops.aten.sum.dim_IntList(mul_395, [0, 1]); mul_395 = None + convert_element_type_1419 = torch.ops.prims.convert_element_type.default(mul_394, torch.bfloat16); mul_394 = None + convert_element_type_1420 = torch.ops.prims.convert_element_type.default(sum_40, torch.bfloat16); sum_40 = None + all_reduce_13 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1420, 'sum', '1'); convert_element_type_1420 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_13); all_reduce_13 = None + convert_element_type_1421 = torch.ops.prims.convert_element_type.default(wait_tensor_520, torch.float32); wait_tensor_520 = None + reduce_scatter_tensor_138 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1421, 'avg', 32, '0'); convert_element_type_1421 = None + wait_tensor_521 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + add_174 = torch.ops.aten.add.Tensor(add_170, convert_element_type_1419); add_170 = convert_element_type_1419 = None + all_gather_into_tensor_369 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_174, 8, '1') + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_369); all_gather_into_tensor_369 = None + split_165 = torch.ops.aten.split.Tensor(wait_tensor_522, 2); wait_tensor_522 = None + getitem_1626 = split_165[0] + getitem_1627 = split_165[1] + getitem_1628 = split_165[2] + getitem_1629 = split_165[3] + getitem_1630 = split_165[4] + getitem_1631 = split_165[5] + getitem_1632 = split_165[6] + getitem_1633 = split_165[7]; split_165 = None + cat_157 = torch.ops.aten.cat.default([getitem_1626, getitem_1627, getitem_1628, getitem_1629, getitem_1630, getitem_1631, getitem_1632, getitem_1633], 1); getitem_1626 = getitem_1627 = getitem_1628 = getitem_1629 = getitem_1630 = getitem_1631 = getitem_1632 = getitem_1633 = None + view_2473 = torch.ops.aten.view.default(cat_157, [16384, 4096]); cat_157 = None + permute_561 = torch.ops.aten.permute.default(view_2473, [1, 0]) + permute_281 = torch.ops.aten.permute.default(getitem_1105, [0, 2, 1, 3]) + view_1842 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + view_1848 = torch.ops.aten.view.default(view_1842, [16384, 512]); view_1842 = None + mm_317 = torch.ops.aten.mm.default(permute_561, view_1848); permute_561 = view_1848 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 32, '0'); convert_element_type_842 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_332, [1, 0]); wait_tensor_332 = None + permute_563 = torch.ops.aten.permute.default(permute_282, [1, 0]); permute_282 = None + mm_318 = torch.ops.aten.mm.default(view_2473, permute_563); view_2473 = permute_563 = None + view_2474 = torch.ops.aten.view.default(mm_318, [2, 8192, 512]); mm_318 = None + convert_element_type_1426 = torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1426, 'avg', 32, '0'); convert_element_type_1426 = None + wait_tensor_523 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + view_2475 = torch.ops.aten.view.default(view_2474, [2, 8192, 4, 128]); view_2474 = None + permute_565 = torch.ops.aten.permute.default(view_2475, [0, 2, 1, 3]); view_2475 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 32, '0'); convert_element_type_826 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32); add_99 = None + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_327) + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_828, 8, '1'); convert_element_type_828 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + split_109 = torch.ops.aten.split.Tensor(wait_tensor_328, 2); wait_tensor_328 = None + getitem_1097 = split_109[0] + getitem_1098 = split_109[1] + getitem_1099 = split_109[2] + getitem_1100 = split_109[3] + getitem_1101 = split_109[4] + getitem_1102 = split_109[5] + getitem_1103 = split_109[6] + getitem_1104 = split_109[7]; split_109 = None + cat_101 = torch.ops.aten.cat.default([getitem_1097, getitem_1098, getitem_1099, getitem_1100, getitem_1101, getitem_1102, getitem_1103, getitem_1104], 1); getitem_1097 = getitem_1098 = getitem_1099 = getitem_1100 = getitem_1101 = getitem_1102 = getitem_1103 = getitem_1104 = None + view_1815 = torch.ops.aten.view.default(cat_101, [16384, 4096]); cat_101 = None + view_1816 = torch.ops.aten.view.default(mm_175, [2, 8192, 512]); mm_175 = None + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 32, '0'); convert_element_type_832 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + mm_176 = torch.ops.aten.mm.default(view_1815, permute_276) + view_1823 = torch.ops.aten.view.default(mm_176, [2, 8192, 128]); mm_176 = None + view_1830 = torch.ops.aten.view.default(mm_177, [2, 8192, 128]); mm_177 = None + view_1832 = torch.ops.aten.view.default(view_1816, [2, 8192, -1, 128]); view_1816 = None + view_1833 = torch.ops.aten.view.default(view_1823, [2, 8192, -1, 128]); view_1823 = None + view_1834 = torch.ops.aten.view.default(view_1830, [2, 8192, -1, 128]); view_1830 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_1832, torch.float32); view_1832 = None + view_1835 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 4, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1835); view_1835 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_1833, torch.float32); view_1833 = None + view_1836 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 1, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1836); view_1836 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_37); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_1838 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 4, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_37); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_1839 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 1, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_1838, torch.bfloat16); view_1838 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_1839, torch.bfloat16); view_1839 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 1, 4, 128]); unsqueeze_50 = None + view_1840 = torch.ops.aten.view.default(expand_50, [2, 8192, 4, 128]); expand_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_1834, 3); view_1834 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 1, 4, 128]); unsqueeze_51 = None + view_1841 = torch.ops.aten.view.default(expand_51, [2, 8192, 4, 128]); expand_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_1840, [0, 2, 1, 3]); view_1840 = None + permute_280 = torch.ops.aten.permute.default(view_1841, [0, 2, 1, 3]); view_1841 = None + _scaled_dot_product_cudnn_attention_backward_6 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_565, permute_278, permute_279, permute_280, getitem_1105, getitem_1106, getitem_1111, getitem_1112, None, None, None, 8192, 8192, 0.0, True); permute_565 = permute_278 = permute_279 = permute_280 = getitem_1105 = getitem_1106 = getitem_1111 = getitem_1112 = None + getitem_1634 = _scaled_dot_product_cudnn_attention_backward_6[0] + getitem_1635 = _scaled_dot_product_cudnn_attention_backward_6[1] + getitem_1636 = _scaled_dot_product_cudnn_attention_backward_6[2]; _scaled_dot_product_cudnn_attention_backward_6 = None + permute_566 = torch.ops.aten.permute.default(getitem_1636, [0, 2, 1, 3]); getitem_1636 = None + permute_567 = torch.ops.aten.permute.default(getitem_1635, [0, 2, 1, 3]); getitem_1635 = None + permute_568 = torch.ops.aten.permute.default(getitem_1634, [0, 2, 1, 3]); getitem_1634 = None + view_2476 = torch.ops.aten.view.default(permute_566, [2, 8192, 1, 4, 128]); permute_566 = None + sum_41 = torch.ops.aten.sum.dim_IntList(view_2476, [3], True); view_2476 = None + squeeze_12 = torch.ops.aten.squeeze.dim(sum_41, 3); sum_41 = None + view_2477 = torch.ops.aten.view.default(permute_567, [2, 8192, 1, 4, 128]); permute_567 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_2477, [3], True); view_2477 = None + squeeze_13 = torch.ops.aten.squeeze.dim(sum_42, 3); sum_42 = None + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(squeeze_13, torch.float32); squeeze_13 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(permute_568, torch.float32); permute_568 = None + view_2478 = torch.ops.aten.view.default(convert_element_type_1427, [2, 8192, 1, 64, 2]); convert_element_type_1427 = None + view_as_complex_76 = torch.ops.aten.view_as_complex.default(view_2478); view_2478 = None + mul_396 = torch.ops.aten.mul.Tensor(view_as_complex_76, _conj); view_as_complex_76 = None + view_2479 = torch.ops.aten.view.default(convert_element_type_1428, [2, 8192, 4, 64, 2]); convert_element_type_1428 = None + view_as_complex_77 = torch.ops.aten.view_as_complex.default(view_2479); view_2479 = None + mul_397 = torch.ops.aten.mul.Tensor(view_as_complex_77, _conj); view_as_complex_77 = None + view_as_real_76 = torch.ops.aten.view_as_real.default(mul_396); mul_396 = None + view_2480 = torch.ops.aten.view.default(view_as_real_76, [2, 8192, 1, 128]); view_as_real_76 = None + convert_element_type_1429 = torch.ops.prims.convert_element_type.default(view_2480, torch.bfloat16); view_2480 = None + view_as_real_77 = torch.ops.aten.view_as_real.default(mul_397); mul_397 = None + view_2481 = torch.ops.aten.view.default(view_as_real_77, [2, 8192, 4, 128]); view_as_real_77 = None + convert_element_type_1430 = torch.ops.prims.convert_element_type.default(view_2481, torch.bfloat16); view_2481 = None + view_2482 = torch.ops.aten.view.default(squeeze_12, [2, 8192, 128]); squeeze_12 = None + view_2483 = torch.ops.aten.view.default(convert_element_type_1429, [2, 8192, 128]); convert_element_type_1429 = None + view_2484 = torch.ops.aten.view.default(convert_element_type_1430, [2, 8192, 512]); convert_element_type_1430 = None + view_2485 = torch.ops.aten.view.default(view_2482, [16384, 128]); view_2482 = None + permute_569 = torch.ops.aten.permute.default(view_2485, [1, 0]) + mm_319 = torch.ops.aten.mm.default(permute_569, view_1815); permute_569 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 32, '0'); convert_element_type_835 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + permute_571 = torch.ops.aten.permute.default(permute_277, [1, 0]); permute_277 = None + mm_320 = torch.ops.aten.mm.default(view_2485, permute_571); view_2485 = permute_571 = None + view_2486 = torch.ops.aten.view.default(mm_320, [2, 8192, 4096]); mm_320 = None + convert_element_type_1435 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1435, 'avg', 32, '0'); convert_element_type_1435 = None + wait_tensor_524 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + view_2487 = torch.ops.aten.view.default(view_2483, [16384, 128]); view_2483 = None + permute_573 = torch.ops.aten.permute.default(view_2487, [1, 0]) + mm_321 = torch.ops.aten.mm.default(permute_573, view_1815); permute_573 = None + permute_575 = torch.ops.aten.permute.default(permute_276, [1, 0]); permute_276 = None + mm_322 = torch.ops.aten.mm.default(view_2487, permute_575); view_2487 = permute_575 = None + view_2488 = torch.ops.aten.view.default(mm_322, [2, 8192, 4096]); mm_322 = None + add_175 = torch.ops.aten.add.Tensor(view_2486, view_2488); view_2486 = view_2488 = None + convert_element_type_1440 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1440, 'avg', 32, '0'); convert_element_type_1440 = None + wait_tensor_525 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_2489 = torch.ops.aten.view.default(view_2484, [16384, 512]); view_2484 = None + permute_577 = torch.ops.aten.permute.default(view_2489, [1, 0]) + mm_323 = torch.ops.aten.mm.default(permute_577, view_1815); permute_577 = view_1815 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 32, '0'); convert_element_type_829 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_329, [1, 0]); wait_tensor_329 = None + permute_579 = torch.ops.aten.permute.default(permute_275, [1, 0]); permute_275 = None + mm_324 = torch.ops.aten.mm.default(view_2489, permute_579); view_2489 = permute_579 = None + view_2490 = torch.ops.aten.view.default(mm_324, [2, 8192, 4096]); mm_324 = None + add_176 = torch.ops.aten.add.Tensor(add_175, view_2490); add_175 = view_2490 = None + convert_element_type_1445 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1445, 'avg', 32, '0'); convert_element_type_1445 = None + wait_tensor_526 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + split_166 = torch.ops.aten.split.Tensor(add_176, 1024, 1); add_176 = None + getitem_1637 = split_166[0] + getitem_1638 = split_166[1] + getitem_1639 = split_166[2] + getitem_1640 = split_166[3] + getitem_1641 = split_166[4] + getitem_1642 = split_166[5] + getitem_1643 = split_166[6] + getitem_1644 = split_166[7]; split_166 = None + cat_158 = torch.ops.aten.cat.default([getitem_1637, getitem_1638, getitem_1639, getitem_1640, getitem_1641, getitem_1642, getitem_1643, getitem_1644]); getitem_1637 = getitem_1638 = getitem_1639 = getitem_1640 = getitem_1641 = getitem_1642 = getitem_1643 = getitem_1644 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_158, 'sum', 8, '1'); cat_158 = None + wait_tensor_527 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + convert_element_type_1446 = torch.ops.prims.convert_element_type.default(wait_tensor_527, torch.float32); wait_tensor_527 = None + convert_element_type_1448 = torch.ops.prims.convert_element_type.default(wait_tensor_327, torch.float32); wait_tensor_327 = None + mul_398 = torch.ops.aten.mul.Tensor(convert_element_type_1446, convert_element_type_1448); convert_element_type_1448 = None + mul_400 = torch.ops.aten.mul.Tensor(mul_200, mul_398) + sum_43 = torch.ops.aten.sum.dim_IntList(mul_400, [2], True); mul_400 = None + div_14 = torch.ops.aten.div.Tensor(mul_200, 4096) + mul_401 = torch.ops.aten.mul.Tensor(div_14, sum_43); div_14 = sum_43 = None + sub_22 = torch.ops.aten.sub.Tensor(mul_398, mul_401); mul_398 = mul_401 = None + mul_402 = torch.ops.aten.mul.Tensor(sub_22, rsqrt_50); sub_22 = rsqrt_50 = None + mul_403 = torch.ops.aten.mul.Tensor(convert_element_type_1446, mul_200); convert_element_type_1446 = mul_200 = None + sum_44 = torch.ops.aten.sum.dim_IntList(mul_403, [0, 1]); mul_403 = None + convert_element_type_1449 = torch.ops.prims.convert_element_type.default(mul_402, torch.bfloat16); mul_402 = None + convert_element_type_1450 = torch.ops.prims.convert_element_type.default(sum_44, torch.bfloat16); sum_44 = None + all_reduce_14 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1450, 'sum', '1'); convert_element_type_1450 = None + wait_tensor_528 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_14); all_reduce_14 = None + convert_element_type_1451 = torch.ops.prims.convert_element_type.default(wait_tensor_528, torch.float32); wait_tensor_528 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1451, 'avg', 32, '0'); convert_element_type_1451 = None + wait_tensor_529 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + add_177 = torch.ops.aten.add.Tensor(add_174, convert_element_type_1449); add_174 = convert_element_type_1449 = None + all_gather_into_tensor_370 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_177, 8, '1') + wait_tensor_530 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_370); all_gather_into_tensor_370 = None + split_167 = torch.ops.aten.split.Tensor(wait_tensor_530, 2); wait_tensor_530 = None + getitem_1645 = split_167[0] + getitem_1646 = split_167[1] + getitem_1647 = split_167[2] + getitem_1648 = split_167[3] + getitem_1649 = split_167[4] + getitem_1650 = split_167[5] + getitem_1651 = split_167[6] + getitem_1652 = split_167[7]; split_167 = None + cat_159 = torch.ops.aten.cat.default([getitem_1645, getitem_1646, getitem_1647, getitem_1648, getitem_1649, getitem_1650, getitem_1651, getitem_1652], 1); getitem_1645 = getitem_1646 = getitem_1647 = getitem_1648 = getitem_1649 = getitem_1650 = getitem_1651 = getitem_1652 = None + view_2491 = torch.ops.aten.view.default(cat_159, [16384, 4096]); cat_159 = None + permute_581 = torch.ops.aten.permute.default(view_2491, [1, 0]) + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + add_97 = torch.ops.aten.add.Tensor(add_95, wait_tensor_320); wait_tensor_320 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16); primals_225 = None + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 32, '0'); convert_element_type_812 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32); add_97 = None + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_321) + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_814, 8, '1'); convert_element_type_814 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + split_107 = torch.ops.aten.split.Tensor(wait_tensor_322, 2); wait_tensor_322 = None + getitem_1081 = split_107[0] + getitem_1082 = split_107[1] + getitem_1083 = split_107[2] + getitem_1084 = split_107[3] + getitem_1085 = split_107[4] + getitem_1086 = split_107[5] + getitem_1087 = split_107[6] + getitem_1088 = split_107[7]; split_107 = None + cat_99 = torch.ops.aten.cat.default([getitem_1081, getitem_1082, getitem_1083, getitem_1084, getitem_1085, getitem_1086, getitem_1087, getitem_1088], 1); getitem_1081 = getitem_1082 = getitem_1083 = getitem_1084 = getitem_1085 = getitem_1086 = getitem_1087 = getitem_1088 = None + view_1788 = torch.ops.aten.view.default(cat_99, [16384, 4096]); cat_99 = None + view_1789 = torch.ops.aten.view.default(mm_172, [2, 8192, 1792]); mm_172 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_1789, torch.float32); view_1789 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16); primals_227 = None + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 32, '0'); convert_element_type_820 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + mm_173 = torch.ops.aten.mm.default(view_1788, permute_273) + view_1796 = torch.ops.aten.view.default(mm_173, [2, 8192, 1792]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_1796) + view_1803 = torch.ops.aten.view.default(mul_199, [16384, 1792]); mul_199 = None + mm_325 = torch.ops.aten.mm.default(permute_581, view_1803); permute_581 = view_1803 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 32, '0'); convert_element_type_823 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + permute_583 = torch.ops.aten.permute.default(permute_274, [1, 0]); permute_274 = None + mm_326 = torch.ops.aten.mm.default(view_2491, permute_583); view_2491 = permute_583 = None + view_2492 = torch.ops.aten.view.default(mm_326, [2, 8192, 1792]); mm_326 = None + convert_element_type_1456 = torch.ops.prims.convert_element_type.default(mm_325, torch.float32); mm_325 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1456, 'avg', 32, '0'); convert_element_type_1456 = None + wait_tensor_531 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + mul_404 = torch.ops.aten.mul.Tensor(view_2492, convert_element_type_819); convert_element_type_819 = None + mul_405 = torch.ops.aten.mul.Tensor(view_2492, view_1796); view_2492 = view_1796 = None + view_2493 = torch.ops.aten.view.default(mul_404, [16384, 1792]); mul_404 = None + permute_585 = torch.ops.aten.permute.default(view_2493, [1, 0]) + mm_327 = torch.ops.aten.mm.default(permute_585, view_1788); permute_585 = None + permute_587 = torch.ops.aten.permute.default(permute_273, [1, 0]); permute_273 = None + mm_328 = torch.ops.aten.mm.default(view_2493, permute_587); view_2493 = permute_587 = None + view_2494 = torch.ops.aten.view.default(mm_328, [2, 8192, 4096]); mm_328 = None + convert_element_type_1461 = torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1461, 'avg', 32, '0'); convert_element_type_1461 = None + wait_tensor_532 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + convert_element_type_1462 = torch.ops.prims.convert_element_type.default(mul_405, torch.float32); mul_405 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_818) + exp_7 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_178 = torch.ops.aten.add.Tensor(exp_7, 1); exp_7 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_178); add_178 = None + mul_406 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_1462, mul_406); convert_element_type_1462 = None + sub_23 = torch.ops.aten.sub.Tensor(1, mul_406); mul_406 = None + mul_408 = torch.ops.aten.mul.Tensor(convert_element_type_818, sub_23); convert_element_type_818 = sub_23 = None + add_179 = torch.ops.aten.add.Tensor(mul_408, 1); mul_408 = None + mul_409 = torch.ops.aten.mul.Tensor(mul_407, add_179); mul_407 = add_179 = None + convert_element_type_1464 = torch.ops.prims.convert_element_type.default(mul_409, torch.bfloat16); mul_409 = None + view_2495 = torch.ops.aten.view.default(convert_element_type_1464, [16384, 1792]); convert_element_type_1464 = None + permute_589 = torch.ops.aten.permute.default(view_2495, [1, 0]) + mm_329 = torch.ops.aten.mm.default(permute_589, view_1788); permute_589 = view_1788 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16); primals_226 = None + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 32, '0'); convert_element_type_815 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + permute_591 = torch.ops.aten.permute.default(permute_272, [1, 0]); permute_272 = None + mm_330 = torch.ops.aten.mm.default(view_2495, permute_591); view_2495 = permute_591 = None + view_2496 = torch.ops.aten.view.default(mm_330, [2, 8192, 4096]); mm_330 = None + add_180 = torch.ops.aten.add.Tensor(view_2494, view_2496); view_2494 = view_2496 = None + convert_element_type_1469 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1469, 'avg', 32, '0'); convert_element_type_1469 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + split_168 = torch.ops.aten.split.Tensor(add_180, 1024, 1); add_180 = None + getitem_1653 = split_168[0] + getitem_1654 = split_168[1] + getitem_1655 = split_168[2] + getitem_1656 = split_168[3] + getitem_1657 = split_168[4] + getitem_1658 = split_168[5] + getitem_1659 = split_168[6] + getitem_1660 = split_168[7]; split_168 = None + cat_160 = torch.ops.aten.cat.default([getitem_1653, getitem_1654, getitem_1655, getitem_1656, getitem_1657, getitem_1658, getitem_1659, getitem_1660]); getitem_1653 = getitem_1654 = getitem_1655 = getitem_1656 = getitem_1657 = getitem_1658 = getitem_1659 = getitem_1660 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_160, 'sum', 8, '1'); cat_160 = None + wait_tensor_534 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + convert_element_type_1470 = torch.ops.prims.convert_element_type.default(wait_tensor_534, torch.float32); wait_tensor_534 = None + convert_element_type_1472 = torch.ops.prims.convert_element_type.default(wait_tensor_321, torch.float32); wait_tensor_321 = None + mul_410 = torch.ops.aten.mul.Tensor(convert_element_type_1470, convert_element_type_1472); convert_element_type_1472 = None + mul_412 = torch.ops.aten.mul.Tensor(mul_196, mul_410) + sum_45 = torch.ops.aten.sum.dim_IntList(mul_412, [2], True); mul_412 = None + div_15 = torch.ops.aten.div.Tensor(mul_196, 4096) + mul_413 = torch.ops.aten.mul.Tensor(div_15, sum_45); div_15 = sum_45 = None + sub_24 = torch.ops.aten.sub.Tensor(mul_410, mul_413); mul_410 = mul_413 = None + mul_414 = torch.ops.aten.mul.Tensor(sub_24, rsqrt_49); sub_24 = rsqrt_49 = None + mul_415 = torch.ops.aten.mul.Tensor(convert_element_type_1470, mul_196); convert_element_type_1470 = mul_196 = None + sum_46 = torch.ops.aten.sum.dim_IntList(mul_415, [0, 1]); mul_415 = None + convert_element_type_1473 = torch.ops.prims.convert_element_type.default(mul_414, torch.bfloat16); mul_414 = None + convert_element_type_1474 = torch.ops.prims.convert_element_type.default(sum_46, torch.bfloat16); sum_46 = None + all_reduce_15 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1474, 'sum', '1'); convert_element_type_1474 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_15); all_reduce_15 = None + convert_element_type_1475 = torch.ops.prims.convert_element_type.default(wait_tensor_535, torch.float32); wait_tensor_535 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1475, 'avg', 32, '0'); convert_element_type_1475 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + add_181 = torch.ops.aten.add.Tensor(add_177, convert_element_type_1473); add_177 = convert_element_type_1473 = None + all_gather_into_tensor_371 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_181, 8, '1') + wait_tensor_537 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_371); all_gather_into_tensor_371 = None + split_169 = torch.ops.aten.split.Tensor(wait_tensor_537, 2); wait_tensor_537 = None + getitem_1661 = split_169[0] + getitem_1662 = split_169[1] + getitem_1663 = split_169[2] + getitem_1664 = split_169[3] + getitem_1665 = split_169[4] + getitem_1666 = split_169[5] + getitem_1667 = split_169[6] + getitem_1668 = split_169[7]; split_169 = None + cat_161 = torch.ops.aten.cat.default([getitem_1661, getitem_1662, getitem_1663, getitem_1664, getitem_1665, getitem_1666, getitem_1667, getitem_1668], 1); getitem_1661 = getitem_1662 = getitem_1663 = getitem_1664 = getitem_1665 = getitem_1666 = getitem_1667 = getitem_1668 = None + view_2497 = torch.ops.aten.view.default(cat_161, [16384, 4096]); cat_161 = None + permute_593 = torch.ops.aten.permute.default(view_2497, [1, 0]) + permute_270 = torch.ops.aten.permute.default(getitem_1064, [0, 2, 1, 3]) + view_1770 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + view_1776 = torch.ops.aten.view.default(view_1770, [16384, 512]); view_1770 = None + mm_331 = torch.ops.aten.mm.default(permute_593, view_1776); permute_593 = view_1776 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16); primals_224 = None + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 32, '0'); convert_element_type_809 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_319, [1, 0]); wait_tensor_319 = None + permute_595 = torch.ops.aten.permute.default(permute_271, [1, 0]); permute_271 = None + mm_332 = torch.ops.aten.mm.default(view_2497, permute_595); view_2497 = permute_595 = None + view_2498 = torch.ops.aten.view.default(mm_332, [2, 8192, 512]); mm_332 = None + convert_element_type_1480 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1480, 'avg', 32, '0'); convert_element_type_1480 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = None + view_2499 = torch.ops.aten.view.default(view_2498, [2, 8192, 4, 128]); view_2498 = None + permute_597 = torch.ops.aten.permute.default(view_2499, [0, 2, 1, 3]); view_2499 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 32, '0'); convert_element_type_793 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32); add_95 = None + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_314) + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_795, 8, '1'); convert_element_type_795 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + split_105 = torch.ops.aten.split.Tensor(wait_tensor_315, 2); wait_tensor_315 = None + getitem_1056 = split_105[0] + getitem_1057 = split_105[1] + getitem_1058 = split_105[2] + getitem_1059 = split_105[3] + getitem_1060 = split_105[4] + getitem_1061 = split_105[5] + getitem_1062 = split_105[6] + getitem_1063 = split_105[7]; split_105 = None + cat_97 = torch.ops.aten.cat.default([getitem_1056, getitem_1057, getitem_1058, getitem_1059, getitem_1060, getitem_1061, getitem_1062, getitem_1063], 1); getitem_1056 = getitem_1057 = getitem_1058 = getitem_1059 = getitem_1060 = getitem_1061 = getitem_1062 = getitem_1063 = None + view_1743 = torch.ops.aten.view.default(cat_97, [16384, 4096]); cat_97 = None + view_1744 = torch.ops.aten.view.default(mm_168, [2, 8192, 512]); mm_168 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16); primals_222 = None + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 32, '0'); convert_element_type_799 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_317, [1, 0]); wait_tensor_317 = None + mm_169 = torch.ops.aten.mm.default(view_1743, permute_265) + view_1751 = torch.ops.aten.view.default(mm_169, [2, 8192, 128]); mm_169 = None + view_1758 = torch.ops.aten.view.default(mm_170, [2, 8192, 128]); mm_170 = None + view_1760 = torch.ops.aten.view.default(view_1744, [2, 8192, -1, 128]); view_1744 = None + view_1761 = torch.ops.aten.view.default(view_1751, [2, 8192, -1, 128]); view_1751 = None + view_1762 = torch.ops.aten.view.default(view_1758, [2, 8192, -1, 128]); view_1758 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_1760, torch.float32); view_1760 = None + view_1763 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 4, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1763); view_1763 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_1761, torch.float32); view_1761 = None + view_1764 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 1, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1764); view_1764 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_37); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_1766 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 4, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_37); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_1767 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 1, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_1766, torch.bfloat16); view_1766 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_1767, torch.bfloat16); view_1767 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 1, 4, 128]); unsqueeze_48 = None + view_1768 = torch.ops.aten.view.default(expand_48, [2, 8192, 4, 128]); expand_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_1762, 3); view_1762 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 1, 4, 128]); unsqueeze_49 = None + view_1769 = torch.ops.aten.view.default(expand_49, [2, 8192, 4, 128]); expand_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_1768, [0, 2, 1, 3]); view_1768 = None + permute_269 = torch.ops.aten.permute.default(view_1769, [0, 2, 1, 3]); view_1769 = None + _scaled_dot_product_cudnn_attention_backward_7 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_597, permute_267, permute_268, permute_269, getitem_1064, getitem_1065, getitem_1070, getitem_1071, None, None, None, 8192, 8192, 0.0, True); permute_597 = permute_267 = permute_268 = permute_269 = getitem_1064 = getitem_1065 = getitem_1070 = getitem_1071 = None + getitem_1669 = _scaled_dot_product_cudnn_attention_backward_7[0] + getitem_1670 = _scaled_dot_product_cudnn_attention_backward_7[1] + getitem_1671 = _scaled_dot_product_cudnn_attention_backward_7[2]; _scaled_dot_product_cudnn_attention_backward_7 = None + permute_598 = torch.ops.aten.permute.default(getitem_1671, [0, 2, 1, 3]); getitem_1671 = None + permute_599 = torch.ops.aten.permute.default(getitem_1670, [0, 2, 1, 3]); getitem_1670 = None + permute_600 = torch.ops.aten.permute.default(getitem_1669, [0, 2, 1, 3]); getitem_1669 = None + view_2500 = torch.ops.aten.view.default(permute_598, [2, 8192, 1, 4, 128]); permute_598 = None + sum_47 = torch.ops.aten.sum.dim_IntList(view_2500, [3], True); view_2500 = None + squeeze_14 = torch.ops.aten.squeeze.dim(sum_47, 3); sum_47 = None + view_2501 = torch.ops.aten.view.default(permute_599, [2, 8192, 1, 4, 128]); permute_599 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_2501, [3], True); view_2501 = None + squeeze_15 = torch.ops.aten.squeeze.dim(sum_48, 3); sum_48 = None + convert_element_type_1481 = torch.ops.prims.convert_element_type.default(squeeze_15, torch.float32); squeeze_15 = None + convert_element_type_1482 = torch.ops.prims.convert_element_type.default(permute_600, torch.float32); permute_600 = None + view_2502 = torch.ops.aten.view.default(convert_element_type_1481, [2, 8192, 1, 64, 2]); convert_element_type_1481 = None + view_as_complex_78 = torch.ops.aten.view_as_complex.default(view_2502); view_2502 = None + mul_416 = torch.ops.aten.mul.Tensor(view_as_complex_78, _conj); view_as_complex_78 = None + view_2503 = torch.ops.aten.view.default(convert_element_type_1482, [2, 8192, 4, 64, 2]); convert_element_type_1482 = None + view_as_complex_79 = torch.ops.aten.view_as_complex.default(view_2503); view_2503 = None + mul_417 = torch.ops.aten.mul.Tensor(view_as_complex_79, _conj); view_as_complex_79 = None + view_as_real_78 = torch.ops.aten.view_as_real.default(mul_416); mul_416 = None + view_2504 = torch.ops.aten.view.default(view_as_real_78, [2, 8192, 1, 128]); view_as_real_78 = None + convert_element_type_1483 = torch.ops.prims.convert_element_type.default(view_2504, torch.bfloat16); view_2504 = None + view_as_real_79 = torch.ops.aten.view_as_real.default(mul_417); mul_417 = None + view_2505 = torch.ops.aten.view.default(view_as_real_79, [2, 8192, 4, 128]); view_as_real_79 = None + convert_element_type_1484 = torch.ops.prims.convert_element_type.default(view_2505, torch.bfloat16); view_2505 = None + view_2506 = torch.ops.aten.view.default(squeeze_14, [2, 8192, 128]); squeeze_14 = None + view_2507 = torch.ops.aten.view.default(convert_element_type_1483, [2, 8192, 128]); convert_element_type_1483 = None + view_2508 = torch.ops.aten.view.default(convert_element_type_1484, [2, 8192, 512]); convert_element_type_1484 = None + view_2509 = torch.ops.aten.view.default(view_2506, [16384, 128]); view_2506 = None + permute_601 = torch.ops.aten.permute.default(view_2509, [1, 0]) + mm_333 = torch.ops.aten.mm.default(permute_601, view_1743); permute_601 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16); primals_223 = None + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 32, '0'); convert_element_type_802 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_318, [1, 0]); wait_tensor_318 = None + permute_603 = torch.ops.aten.permute.default(permute_266, [1, 0]); permute_266 = None + mm_334 = torch.ops.aten.mm.default(view_2509, permute_603); view_2509 = permute_603 = None + view_2510 = torch.ops.aten.view.default(mm_334, [2, 8192, 4096]); mm_334 = None + convert_element_type_1489 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1489, 'avg', 32, '0'); convert_element_type_1489 = None + wait_tensor_539 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + view_2511 = torch.ops.aten.view.default(view_2507, [16384, 128]); view_2507 = None + permute_605 = torch.ops.aten.permute.default(view_2511, [1, 0]) + mm_335 = torch.ops.aten.mm.default(permute_605, view_1743); permute_605 = None + permute_607 = torch.ops.aten.permute.default(permute_265, [1, 0]); permute_265 = None + mm_336 = torch.ops.aten.mm.default(view_2511, permute_607); view_2511 = permute_607 = None + view_2512 = torch.ops.aten.view.default(mm_336, [2, 8192, 4096]); mm_336 = None + add_182 = torch.ops.aten.add.Tensor(view_2510, view_2512); view_2510 = view_2512 = None + convert_element_type_1494 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1494, 'avg', 32, '0'); convert_element_type_1494 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + view_2513 = torch.ops.aten.view.default(view_2508, [16384, 512]); view_2508 = None + permute_609 = torch.ops.aten.permute.default(view_2513, [1, 0]) + mm_337 = torch.ops.aten.mm.default(permute_609, view_1743); permute_609 = view_1743 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16); primals_221 = None + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 32, '0'); convert_element_type_796 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_316, [1, 0]); wait_tensor_316 = None + permute_611 = torch.ops.aten.permute.default(permute_264, [1, 0]); permute_264 = None + mm_338 = torch.ops.aten.mm.default(view_2513, permute_611); view_2513 = permute_611 = None + view_2514 = torch.ops.aten.view.default(mm_338, [2, 8192, 4096]); mm_338 = None + add_183 = torch.ops.aten.add.Tensor(add_182, view_2514); add_182 = view_2514 = None + convert_element_type_1499 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1499, 'avg', 32, '0'); convert_element_type_1499 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + split_170 = torch.ops.aten.split.Tensor(add_183, 1024, 1); add_183 = None + getitem_1672 = split_170[0] + getitem_1673 = split_170[1] + getitem_1674 = split_170[2] + getitem_1675 = split_170[3] + getitem_1676 = split_170[4] + getitem_1677 = split_170[5] + getitem_1678 = split_170[6] + getitem_1679 = split_170[7]; split_170 = None + cat_162 = torch.ops.aten.cat.default([getitem_1672, getitem_1673, getitem_1674, getitem_1675, getitem_1676, getitem_1677, getitem_1678, getitem_1679]); getitem_1672 = getitem_1673 = getitem_1674 = getitem_1675 = getitem_1676 = getitem_1677 = getitem_1678 = getitem_1679 = None + reduce_scatter_tensor_154 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_162, 'sum', 8, '1'); cat_162 = None + wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + convert_element_type_1500 = torch.ops.prims.convert_element_type.default(wait_tensor_542, torch.float32); wait_tensor_542 = None + convert_element_type_1502 = torch.ops.prims.convert_element_type.default(wait_tensor_314, torch.float32); wait_tensor_314 = None + mul_418 = torch.ops.aten.mul.Tensor(convert_element_type_1500, convert_element_type_1502); convert_element_type_1502 = None + mul_420 = torch.ops.aten.mul.Tensor(mul_192, mul_418) + sum_49 = torch.ops.aten.sum.dim_IntList(mul_420, [2], True); mul_420 = None + div_16 = torch.ops.aten.div.Tensor(mul_192, 4096) + mul_421 = torch.ops.aten.mul.Tensor(div_16, sum_49); div_16 = sum_49 = None + sub_25 = torch.ops.aten.sub.Tensor(mul_418, mul_421); mul_418 = mul_421 = None + mul_422 = torch.ops.aten.mul.Tensor(sub_25, rsqrt_48); sub_25 = rsqrt_48 = None + mul_423 = torch.ops.aten.mul.Tensor(convert_element_type_1500, mul_192); convert_element_type_1500 = mul_192 = None + sum_50 = torch.ops.aten.sum.dim_IntList(mul_423, [0, 1]); mul_423 = None + convert_element_type_1503 = torch.ops.prims.convert_element_type.default(mul_422, torch.bfloat16); mul_422 = None + convert_element_type_1504 = torch.ops.prims.convert_element_type.default(sum_50, torch.bfloat16); sum_50 = None + all_reduce_16 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1504, 'sum', '1'); convert_element_type_1504 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_16); all_reduce_16 = None + convert_element_type_1505 = torch.ops.prims.convert_element_type.default(wait_tensor_543, torch.float32); wait_tensor_543 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1505, 'avg', 32, '0'); convert_element_type_1505 = None + wait_tensor_544 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); reduce_scatter_tensor_155 = None + add_184 = torch.ops.aten.add.Tensor(add_181, convert_element_type_1503); add_181 = convert_element_type_1503 = None + all_gather_into_tensor_372 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_184, 8, '1') + wait_tensor_545 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_372); all_gather_into_tensor_372 = None + split_171 = torch.ops.aten.split.Tensor(wait_tensor_545, 2); wait_tensor_545 = None + getitem_1680 = split_171[0] + getitem_1681 = split_171[1] + getitem_1682 = split_171[2] + getitem_1683 = split_171[3] + getitem_1684 = split_171[4] + getitem_1685 = split_171[5] + getitem_1686 = split_171[6] + getitem_1687 = split_171[7]; split_171 = None + cat_163 = torch.ops.aten.cat.default([getitem_1680, getitem_1681, getitem_1682, getitem_1683, getitem_1684, getitem_1685, getitem_1686, getitem_1687], 1); getitem_1680 = getitem_1681 = getitem_1682 = getitem_1683 = getitem_1684 = getitem_1685 = getitem_1686 = getitem_1687 = None + view_2515 = torch.ops.aten.view.default(cat_163, [16384, 4096]); cat_163 = None + permute_613 = torch.ops.aten.permute.default(view_2515, [1, 0]) + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + add_93 = torch.ops.aten.add.Tensor(add_91, wait_tensor_307); wait_tensor_307 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 32, '0'); convert_element_type_779 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32); add_93 = None + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_308) + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_781, 8, '1'); convert_element_type_781 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + split_103 = torch.ops.aten.split.Tensor(wait_tensor_309, 2); wait_tensor_309 = None + getitem_1040 = split_103[0] + getitem_1041 = split_103[1] + getitem_1042 = split_103[2] + getitem_1043 = split_103[3] + getitem_1044 = split_103[4] + getitem_1045 = split_103[5] + getitem_1046 = split_103[6] + getitem_1047 = split_103[7]; split_103 = None + cat_95 = torch.ops.aten.cat.default([getitem_1040, getitem_1041, getitem_1042, getitem_1043, getitem_1044, getitem_1045, getitem_1046, getitem_1047], 1); getitem_1040 = getitem_1041 = getitem_1042 = getitem_1043 = getitem_1044 = getitem_1045 = getitem_1046 = getitem_1047 = None + view_1716 = torch.ops.aten.view.default(cat_95, [16384, 4096]); cat_95 = None + view_1717 = torch.ops.aten.view.default(mm_165, [2, 8192, 1792]); mm_165 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_1717, torch.float32); view_1717 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16); primals_218 = None + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 32, '0'); convert_element_type_787 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_311, [1, 0]); wait_tensor_311 = None + mm_166 = torch.ops.aten.mm.default(view_1716, permute_262) + view_1724 = torch.ops.aten.view.default(mm_166, [2, 8192, 1792]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_1724) + view_1731 = torch.ops.aten.view.default(mul_191, [16384, 1792]); mul_191 = None + mm_339 = torch.ops.aten.mm.default(permute_613, view_1731); permute_613 = view_1731 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 32, '0'); convert_element_type_790 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_312, [1, 0]); wait_tensor_312 = None + permute_615 = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None + mm_340 = torch.ops.aten.mm.default(view_2515, permute_615); view_2515 = permute_615 = None + view_2516 = torch.ops.aten.view.default(mm_340, [2, 8192, 1792]); mm_340 = None + convert_element_type_1510 = torch.ops.prims.convert_element_type.default(mm_339, torch.float32); mm_339 = None + reduce_scatter_tensor_156 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1510, 'avg', 32, '0'); convert_element_type_1510 = None + wait_tensor_546 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + mul_424 = torch.ops.aten.mul.Tensor(view_2516, convert_element_type_786); convert_element_type_786 = None + mul_425 = torch.ops.aten.mul.Tensor(view_2516, view_1724); view_2516 = view_1724 = None + view_2517 = torch.ops.aten.view.default(mul_424, [16384, 1792]); mul_424 = None + permute_617 = torch.ops.aten.permute.default(view_2517, [1, 0]) + mm_341 = torch.ops.aten.mm.default(permute_617, view_1716); permute_617 = None + permute_619 = torch.ops.aten.permute.default(permute_262, [1, 0]); permute_262 = None + mm_342 = torch.ops.aten.mm.default(view_2517, permute_619); view_2517 = permute_619 = None + view_2518 = torch.ops.aten.view.default(mm_342, [2, 8192, 4096]); mm_342 = None + convert_element_type_1515 = torch.ops.prims.convert_element_type.default(mm_341, torch.float32); mm_341 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1515, 'avg', 32, '0'); convert_element_type_1515 = None + wait_tensor_547 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + convert_element_type_1516 = torch.ops.prims.convert_element_type.default(mul_425, torch.float32); mul_425 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_785) + exp_8 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_185 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_185); add_185 = None + mul_426 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_1516, mul_426); convert_element_type_1516 = None + sub_26 = torch.ops.aten.sub.Tensor(1, mul_426); mul_426 = None + mul_428 = torch.ops.aten.mul.Tensor(convert_element_type_785, sub_26); convert_element_type_785 = sub_26 = None + add_186 = torch.ops.aten.add.Tensor(mul_428, 1); mul_428 = None + mul_429 = torch.ops.aten.mul.Tensor(mul_427, add_186); mul_427 = add_186 = None + convert_element_type_1518 = torch.ops.prims.convert_element_type.default(mul_429, torch.bfloat16); mul_429 = None + view_2519 = torch.ops.aten.view.default(convert_element_type_1518, [16384, 1792]); convert_element_type_1518 = None + permute_621 = torch.ops.aten.permute.default(view_2519, [1, 0]) + mm_343 = torch.ops.aten.mm.default(permute_621, view_1716); permute_621 = view_1716 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16); primals_217 = None + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 32, '0'); convert_element_type_782 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + permute_623 = torch.ops.aten.permute.default(permute_261, [1, 0]); permute_261 = None + mm_344 = torch.ops.aten.mm.default(view_2519, permute_623); view_2519 = permute_623 = None + view_2520 = torch.ops.aten.view.default(mm_344, [2, 8192, 4096]); mm_344 = None + add_187 = torch.ops.aten.add.Tensor(view_2518, view_2520); view_2518 = view_2520 = None + convert_element_type_1523 = torch.ops.prims.convert_element_type.default(mm_343, torch.float32); mm_343 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1523, 'avg', 32, '0'); convert_element_type_1523 = None + wait_tensor_548 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + split_172 = torch.ops.aten.split.Tensor(add_187, 1024, 1); add_187 = None + getitem_1688 = split_172[0] + getitem_1689 = split_172[1] + getitem_1690 = split_172[2] + getitem_1691 = split_172[3] + getitem_1692 = split_172[4] + getitem_1693 = split_172[5] + getitem_1694 = split_172[6] + getitem_1695 = split_172[7]; split_172 = None + cat_164 = torch.ops.aten.cat.default([getitem_1688, getitem_1689, getitem_1690, getitem_1691, getitem_1692, getitem_1693, getitem_1694, getitem_1695]); getitem_1688 = getitem_1689 = getitem_1690 = getitem_1691 = getitem_1692 = getitem_1693 = getitem_1694 = getitem_1695 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_164, 'sum', 8, '1'); cat_164 = None + wait_tensor_549 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + convert_element_type_1524 = torch.ops.prims.convert_element_type.default(wait_tensor_549, torch.float32); wait_tensor_549 = None + convert_element_type_1526 = torch.ops.prims.convert_element_type.default(wait_tensor_308, torch.float32); wait_tensor_308 = None + mul_430 = torch.ops.aten.mul.Tensor(convert_element_type_1524, convert_element_type_1526); convert_element_type_1526 = None + mul_432 = torch.ops.aten.mul.Tensor(mul_188, mul_430) + sum_51 = torch.ops.aten.sum.dim_IntList(mul_432, [2], True); mul_432 = None + div_17 = torch.ops.aten.div.Tensor(mul_188, 4096) + mul_433 = torch.ops.aten.mul.Tensor(div_17, sum_51); div_17 = sum_51 = None + sub_27 = torch.ops.aten.sub.Tensor(mul_430, mul_433); mul_430 = mul_433 = None + mul_434 = torch.ops.aten.mul.Tensor(sub_27, rsqrt_47); sub_27 = rsqrt_47 = None + mul_435 = torch.ops.aten.mul.Tensor(convert_element_type_1524, mul_188); convert_element_type_1524 = mul_188 = None + sum_52 = torch.ops.aten.sum.dim_IntList(mul_435, [0, 1]); mul_435 = None + convert_element_type_1527 = torch.ops.prims.convert_element_type.default(mul_434, torch.bfloat16); mul_434 = None + convert_element_type_1528 = torch.ops.prims.convert_element_type.default(sum_52, torch.bfloat16); sum_52 = None + all_reduce_17 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1528, 'sum', '1'); convert_element_type_1528 = None + wait_tensor_550 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_17); all_reduce_17 = None + convert_element_type_1529 = torch.ops.prims.convert_element_type.default(wait_tensor_550, torch.float32); wait_tensor_550 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1529, 'avg', 32, '0'); convert_element_type_1529 = None + wait_tensor_551 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + add_188 = torch.ops.aten.add.Tensor(add_184, convert_element_type_1527); add_184 = convert_element_type_1527 = None + all_gather_into_tensor_373 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_188, 8, '1') + wait_tensor_552 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_373); all_gather_into_tensor_373 = None + split_173 = torch.ops.aten.split.Tensor(wait_tensor_552, 2); wait_tensor_552 = None + getitem_1696 = split_173[0] + getitem_1697 = split_173[1] + getitem_1698 = split_173[2] + getitem_1699 = split_173[3] + getitem_1700 = split_173[4] + getitem_1701 = split_173[5] + getitem_1702 = split_173[6] + getitem_1703 = split_173[7]; split_173 = None + cat_165 = torch.ops.aten.cat.default([getitem_1696, getitem_1697, getitem_1698, getitem_1699, getitem_1700, getitem_1701, getitem_1702, getitem_1703], 1); getitem_1696 = getitem_1697 = getitem_1698 = getitem_1699 = getitem_1700 = getitem_1701 = getitem_1702 = getitem_1703 = None + view_2521 = torch.ops.aten.view.default(cat_165, [16384, 4096]); cat_165 = None + permute_625 = torch.ops.aten.permute.default(view_2521, [1, 0]) + permute_259 = torch.ops.aten.permute.default(getitem_1023, [0, 2, 1, 3]) + view_1698 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + view_1704 = torch.ops.aten.view.default(view_1698, [16384, 512]); view_1698 = None + mm_345 = torch.ops.aten.mm.default(permute_625, view_1704); permute_625 = view_1704 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 32, '0'); convert_element_type_776 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + permute_627 = torch.ops.aten.permute.default(permute_260, [1, 0]); permute_260 = None + mm_346 = torch.ops.aten.mm.default(view_2521, permute_627); view_2521 = permute_627 = None + view_2522 = torch.ops.aten.view.default(mm_346, [2, 8192, 512]); mm_346 = None + convert_element_type_1534 = torch.ops.prims.convert_element_type.default(mm_345, torch.float32); mm_345 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1534, 'avg', 32, '0'); convert_element_type_1534 = None + wait_tensor_553 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + view_2523 = torch.ops.aten.view.default(view_2522, [2, 8192, 4, 128]); view_2522 = None + permute_629 = torch.ops.aten.permute.default(view_2523, [0, 2, 1, 3]); view_2523 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16); primals_211 = None + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 32, '0'); convert_element_type_760 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32); add_91 = None + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_301) + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_762, 8, '1'); convert_element_type_762 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + split_101 = torch.ops.aten.split.Tensor(wait_tensor_302, 2); wait_tensor_302 = None + getitem_1015 = split_101[0] + getitem_1016 = split_101[1] + getitem_1017 = split_101[2] + getitem_1018 = split_101[3] + getitem_1019 = split_101[4] + getitem_1020 = split_101[5] + getitem_1021 = split_101[6] + getitem_1022 = split_101[7]; split_101 = None + cat_93 = torch.ops.aten.cat.default([getitem_1015, getitem_1016, getitem_1017, getitem_1018, getitem_1019, getitem_1020, getitem_1021, getitem_1022], 1); getitem_1015 = getitem_1016 = getitem_1017 = getitem_1018 = getitem_1019 = getitem_1020 = getitem_1021 = getitem_1022 = None + view_1671 = torch.ops.aten.view.default(cat_93, [16384, 4096]); cat_93 = None + view_1672 = torch.ops.aten.view.default(mm_161, [2, 8192, 512]); mm_161 = None + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 32, '0'); convert_element_type_766 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); wait_tensor_304 = None + mm_162 = torch.ops.aten.mm.default(view_1671, permute_254) + view_1679 = torch.ops.aten.view.default(mm_162, [2, 8192, 128]); mm_162 = None + view_1686 = torch.ops.aten.view.default(mm_163, [2, 8192, 128]); mm_163 = None + view_1688 = torch.ops.aten.view.default(view_1672, [2, 8192, -1, 128]); view_1672 = None + view_1689 = torch.ops.aten.view.default(view_1679, [2, 8192, -1, 128]); view_1679 = None + view_1690 = torch.ops.aten.view.default(view_1686, [2, 8192, -1, 128]); view_1686 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_1688, torch.float32); view_1688 = None + view_1691 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 4, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1691); view_1691 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_1689, torch.float32); view_1689 = None + view_1692 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 1, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1692); view_1692 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_37); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_1694 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 4, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_37); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_1695 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 1, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_1694, torch.bfloat16); view_1694 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_1695, torch.bfloat16); view_1695 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 1, 4, 128]); unsqueeze_46 = None + view_1696 = torch.ops.aten.view.default(expand_46, [2, 8192, 4, 128]); expand_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_1690, 3); view_1690 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 1, 4, 128]); unsqueeze_47 = None + view_1697 = torch.ops.aten.view.default(expand_47, [2, 8192, 4, 128]); expand_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_1696, [0, 2, 1, 3]); view_1696 = None + permute_258 = torch.ops.aten.permute.default(view_1697, [0, 2, 1, 3]); view_1697 = None + _scaled_dot_product_cudnn_attention_backward_8 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_629, permute_256, permute_257, permute_258, getitem_1023, getitem_1024, getitem_1029, getitem_1030, None, None, None, 8192, 8192, 0.0, True); permute_629 = permute_256 = permute_257 = permute_258 = getitem_1023 = getitem_1024 = getitem_1029 = getitem_1030 = None + getitem_1704 = _scaled_dot_product_cudnn_attention_backward_8[0] + getitem_1705 = _scaled_dot_product_cudnn_attention_backward_8[1] + getitem_1706 = _scaled_dot_product_cudnn_attention_backward_8[2]; _scaled_dot_product_cudnn_attention_backward_8 = None + permute_630 = torch.ops.aten.permute.default(getitem_1706, [0, 2, 1, 3]); getitem_1706 = None + permute_631 = torch.ops.aten.permute.default(getitem_1705, [0, 2, 1, 3]); getitem_1705 = None + permute_632 = torch.ops.aten.permute.default(getitem_1704, [0, 2, 1, 3]); getitem_1704 = None + view_2524 = torch.ops.aten.view.default(permute_630, [2, 8192, 1, 4, 128]); permute_630 = None + sum_53 = torch.ops.aten.sum.dim_IntList(view_2524, [3], True); view_2524 = None + squeeze_16 = torch.ops.aten.squeeze.dim(sum_53, 3); sum_53 = None + view_2525 = torch.ops.aten.view.default(permute_631, [2, 8192, 1, 4, 128]); permute_631 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_2525, [3], True); view_2525 = None + squeeze_17 = torch.ops.aten.squeeze.dim(sum_54, 3); sum_54 = None + convert_element_type_1535 = torch.ops.prims.convert_element_type.default(squeeze_17, torch.float32); squeeze_17 = None + convert_element_type_1536 = torch.ops.prims.convert_element_type.default(permute_632, torch.float32); permute_632 = None + view_2526 = torch.ops.aten.view.default(convert_element_type_1535, [2, 8192, 1, 64, 2]); convert_element_type_1535 = None + view_as_complex_80 = torch.ops.aten.view_as_complex.default(view_2526); view_2526 = None + mul_436 = torch.ops.aten.mul.Tensor(view_as_complex_80, _conj); view_as_complex_80 = None + view_2527 = torch.ops.aten.view.default(convert_element_type_1536, [2, 8192, 4, 64, 2]); convert_element_type_1536 = None + view_as_complex_81 = torch.ops.aten.view_as_complex.default(view_2527); view_2527 = None + mul_437 = torch.ops.aten.mul.Tensor(view_as_complex_81, _conj); view_as_complex_81 = None + view_as_real_80 = torch.ops.aten.view_as_real.default(mul_436); mul_436 = None + view_2528 = torch.ops.aten.view.default(view_as_real_80, [2, 8192, 1, 128]); view_as_real_80 = None + convert_element_type_1537 = torch.ops.prims.convert_element_type.default(view_2528, torch.bfloat16); view_2528 = None + view_as_real_81 = torch.ops.aten.view_as_real.default(mul_437); mul_437 = None + view_2529 = torch.ops.aten.view.default(view_as_real_81, [2, 8192, 4, 128]); view_as_real_81 = None + convert_element_type_1538 = torch.ops.prims.convert_element_type.default(view_2529, torch.bfloat16); view_2529 = None + view_2530 = torch.ops.aten.view.default(squeeze_16, [2, 8192, 128]); squeeze_16 = None + view_2531 = torch.ops.aten.view.default(convert_element_type_1537, [2, 8192, 128]); convert_element_type_1537 = None + view_2532 = torch.ops.aten.view.default(convert_element_type_1538, [2, 8192, 512]); convert_element_type_1538 = None + view_2533 = torch.ops.aten.view.default(view_2530, [16384, 128]); view_2530 = None + permute_633 = torch.ops.aten.permute.default(view_2533, [1, 0]) + mm_347 = torch.ops.aten.mm.default(permute_633, view_1671); permute_633 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 32, '0'); convert_element_type_769 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_305, [1, 0]); wait_tensor_305 = None + permute_635 = torch.ops.aten.permute.default(permute_255, [1, 0]); permute_255 = None + mm_348 = torch.ops.aten.mm.default(view_2533, permute_635); view_2533 = permute_635 = None + view_2534 = torch.ops.aten.view.default(mm_348, [2, 8192, 4096]); mm_348 = None + convert_element_type_1543 = torch.ops.prims.convert_element_type.default(mm_347, torch.float32); mm_347 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1543, 'avg', 32, '0'); convert_element_type_1543 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + view_2535 = torch.ops.aten.view.default(view_2531, [16384, 128]); view_2531 = None + permute_637 = torch.ops.aten.permute.default(view_2535, [1, 0]) + mm_349 = torch.ops.aten.mm.default(permute_637, view_1671); permute_637 = None + permute_639 = torch.ops.aten.permute.default(permute_254, [1, 0]); permute_254 = None + mm_350 = torch.ops.aten.mm.default(view_2535, permute_639); view_2535 = permute_639 = None + view_2536 = torch.ops.aten.view.default(mm_350, [2, 8192, 4096]); mm_350 = None + add_189 = torch.ops.aten.add.Tensor(view_2534, view_2536); view_2534 = view_2536 = None + convert_element_type_1548 = torch.ops.prims.convert_element_type.default(mm_349, torch.float32); mm_349 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1548, 'avg', 32, '0'); convert_element_type_1548 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_2537 = torch.ops.aten.view.default(view_2532, [16384, 512]); view_2532 = None + permute_641 = torch.ops.aten.permute.default(view_2537, [1, 0]) + mm_351 = torch.ops.aten.mm.default(permute_641, view_1671); permute_641 = view_1671 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 32, '0'); convert_element_type_763 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + permute_643 = torch.ops.aten.permute.default(permute_253, [1, 0]); permute_253 = None + mm_352 = torch.ops.aten.mm.default(view_2537, permute_643); view_2537 = permute_643 = None + view_2538 = torch.ops.aten.view.default(mm_352, [2, 8192, 4096]); mm_352 = None + add_190 = torch.ops.aten.add.Tensor(add_189, view_2538); add_189 = view_2538 = None + convert_element_type_1553 = torch.ops.prims.convert_element_type.default(mm_351, torch.float32); mm_351 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1553, 'avg', 32, '0'); convert_element_type_1553 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 = None + split_174 = torch.ops.aten.split.Tensor(add_190, 1024, 1); add_190 = None + getitem_1707 = split_174[0] + getitem_1708 = split_174[1] + getitem_1709 = split_174[2] + getitem_1710 = split_174[3] + getitem_1711 = split_174[4] + getitem_1712 = split_174[5] + getitem_1713 = split_174[6] + getitem_1714 = split_174[7]; split_174 = None + cat_166 = torch.ops.aten.cat.default([getitem_1707, getitem_1708, getitem_1709, getitem_1710, getitem_1711, getitem_1712, getitem_1713, getitem_1714]); getitem_1707 = getitem_1708 = getitem_1709 = getitem_1710 = getitem_1711 = getitem_1712 = getitem_1713 = getitem_1714 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_166, 'sum', 8, '1'); cat_166 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + convert_element_type_1554 = torch.ops.prims.convert_element_type.default(wait_tensor_557, torch.float32); wait_tensor_557 = None + convert_element_type_1556 = torch.ops.prims.convert_element_type.default(wait_tensor_301, torch.float32); wait_tensor_301 = None + mul_438 = torch.ops.aten.mul.Tensor(convert_element_type_1554, convert_element_type_1556); convert_element_type_1556 = None + mul_440 = torch.ops.aten.mul.Tensor(mul_184, mul_438) + sum_55 = torch.ops.aten.sum.dim_IntList(mul_440, [2], True); mul_440 = None + div_18 = torch.ops.aten.div.Tensor(mul_184, 4096) + mul_441 = torch.ops.aten.mul.Tensor(div_18, sum_55); div_18 = sum_55 = None + sub_28 = torch.ops.aten.sub.Tensor(mul_438, mul_441); mul_438 = mul_441 = None + mul_442 = torch.ops.aten.mul.Tensor(sub_28, rsqrt_46); sub_28 = rsqrt_46 = None + mul_443 = torch.ops.aten.mul.Tensor(convert_element_type_1554, mul_184); convert_element_type_1554 = mul_184 = None + sum_56 = torch.ops.aten.sum.dim_IntList(mul_443, [0, 1]); mul_443 = None + convert_element_type_1557 = torch.ops.prims.convert_element_type.default(mul_442, torch.bfloat16); mul_442 = None + convert_element_type_1558 = torch.ops.prims.convert_element_type.default(sum_56, torch.bfloat16); sum_56 = None + all_reduce_18 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1558, 'sum', '1'); convert_element_type_1558 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_18); all_reduce_18 = None + convert_element_type_1559 = torch.ops.prims.convert_element_type.default(wait_tensor_558, torch.float32); wait_tensor_558 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1559, 'avg', 32, '0'); convert_element_type_1559 = None + wait_tensor_559 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + add_191 = torch.ops.aten.add.Tensor(add_188, convert_element_type_1557); add_188 = convert_element_type_1557 = None + all_gather_into_tensor_374 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_191, 8, '1') + wait_tensor_560 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_374); all_gather_into_tensor_374 = None + split_175 = torch.ops.aten.split.Tensor(wait_tensor_560, 2); wait_tensor_560 = None + getitem_1715 = split_175[0] + getitem_1716 = split_175[1] + getitem_1717 = split_175[2] + getitem_1718 = split_175[3] + getitem_1719 = split_175[4] + getitem_1720 = split_175[5] + getitem_1721 = split_175[6] + getitem_1722 = split_175[7]; split_175 = None + cat_167 = torch.ops.aten.cat.default([getitem_1715, getitem_1716, getitem_1717, getitem_1718, getitem_1719, getitem_1720, getitem_1721, getitem_1722], 1); getitem_1715 = getitem_1716 = getitem_1717 = getitem_1718 = getitem_1719 = getitem_1720 = getitem_1721 = getitem_1722 = None + view_2539 = torch.ops.aten.view.default(cat_167, [16384, 4096]); cat_167 = None + permute_645 = torch.ops.aten.permute.default(view_2539, [1, 0]) + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); reduce_scatter_tensor_45 = None + add_89 = torch.ops.aten.add.Tensor(add_87, wait_tensor_294); wait_tensor_294 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16); primals_207 = None + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 32, '0'); convert_element_type_746 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32); add_89 = None + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_295) + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_748, 8, '1'); convert_element_type_748 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + split_99 = torch.ops.aten.split.Tensor(wait_tensor_296, 2); wait_tensor_296 = None + getitem_999 = split_99[0] + getitem_1000 = split_99[1] + getitem_1001 = split_99[2] + getitem_1002 = split_99[3] + getitem_1003 = split_99[4] + getitem_1004 = split_99[5] + getitem_1005 = split_99[6] + getitem_1006 = split_99[7]; split_99 = None + cat_91 = torch.ops.aten.cat.default([getitem_999, getitem_1000, getitem_1001, getitem_1002, getitem_1003, getitem_1004, getitem_1005, getitem_1006], 1); getitem_999 = getitem_1000 = getitem_1001 = getitem_1002 = getitem_1003 = getitem_1004 = getitem_1005 = getitem_1006 = None + view_1644 = torch.ops.aten.view.default(cat_91, [16384, 4096]); cat_91 = None + view_1645 = torch.ops.aten.view.default(mm_158, [2, 8192, 1792]); mm_158 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_1645, torch.float32); view_1645 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16); primals_209 = None + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 32, '0'); convert_element_type_754 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_298, [1, 0]); wait_tensor_298 = None + mm_159 = torch.ops.aten.mm.default(view_1644, permute_251) + view_1652 = torch.ops.aten.view.default(mm_159, [2, 8192, 1792]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_1652) + view_1659 = torch.ops.aten.view.default(mul_183, [16384, 1792]); mul_183 = None + mm_353 = torch.ops.aten.mm.default(permute_645, view_1659); permute_645 = view_1659 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16); primals_210 = None + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 32, '0'); convert_element_type_757 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_299, [1, 0]); wait_tensor_299 = None + permute_647 = torch.ops.aten.permute.default(permute_252, [1, 0]); permute_252 = None + mm_354 = torch.ops.aten.mm.default(view_2539, permute_647); view_2539 = permute_647 = None + view_2540 = torch.ops.aten.view.default(mm_354, [2, 8192, 1792]); mm_354 = None + convert_element_type_1564 = torch.ops.prims.convert_element_type.default(mm_353, torch.float32); mm_353 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1564, 'avg', 32, '0'); convert_element_type_1564 = None + wait_tensor_561 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + mul_444 = torch.ops.aten.mul.Tensor(view_2540, convert_element_type_753); convert_element_type_753 = None + mul_445 = torch.ops.aten.mul.Tensor(view_2540, view_1652); view_2540 = view_1652 = None + view_2541 = torch.ops.aten.view.default(mul_444, [16384, 1792]); mul_444 = None + permute_649 = torch.ops.aten.permute.default(view_2541, [1, 0]) + mm_355 = torch.ops.aten.mm.default(permute_649, view_1644); permute_649 = None + permute_651 = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None + mm_356 = torch.ops.aten.mm.default(view_2541, permute_651); view_2541 = permute_651 = None + view_2542 = torch.ops.aten.view.default(mm_356, [2, 8192, 4096]); mm_356 = None + convert_element_type_1569 = torch.ops.prims.convert_element_type.default(mm_355, torch.float32); mm_355 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1569, 'avg', 32, '0'); convert_element_type_1569 = None + wait_tensor_562 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + convert_element_type_1570 = torch.ops.prims.convert_element_type.default(mul_445, torch.float32); mul_445 = None + neg_9 = torch.ops.aten.neg.default(convert_element_type_752) + exp_9 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_192 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_192); add_192 = None + mul_446 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_1570, mul_446); convert_element_type_1570 = None + sub_29 = torch.ops.aten.sub.Tensor(1, mul_446); mul_446 = None + mul_448 = torch.ops.aten.mul.Tensor(convert_element_type_752, sub_29); convert_element_type_752 = sub_29 = None + add_193 = torch.ops.aten.add.Tensor(mul_448, 1); mul_448 = None + mul_449 = torch.ops.aten.mul.Tensor(mul_447, add_193); mul_447 = add_193 = None + convert_element_type_1572 = torch.ops.prims.convert_element_type.default(mul_449, torch.bfloat16); mul_449 = None + view_2543 = torch.ops.aten.view.default(convert_element_type_1572, [16384, 1792]); convert_element_type_1572 = None + permute_653 = torch.ops.aten.permute.default(view_2543, [1, 0]) + mm_357 = torch.ops.aten.mm.default(permute_653, view_1644); permute_653 = view_1644 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16); primals_208 = None + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 32, '0'); convert_element_type_749 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_297, [1, 0]); wait_tensor_297 = None + permute_655 = torch.ops.aten.permute.default(permute_250, [1, 0]); permute_250 = None + mm_358 = torch.ops.aten.mm.default(view_2543, permute_655); view_2543 = permute_655 = None + view_2544 = torch.ops.aten.view.default(mm_358, [2, 8192, 4096]); mm_358 = None + add_194 = torch.ops.aten.add.Tensor(view_2542, view_2544); view_2542 = view_2544 = None + convert_element_type_1577 = torch.ops.prims.convert_element_type.default(mm_357, torch.float32); mm_357 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1577, 'avg', 32, '0'); convert_element_type_1577 = None + wait_tensor_563 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + split_176 = torch.ops.aten.split.Tensor(add_194, 1024, 1); add_194 = None + getitem_1723 = split_176[0] + getitem_1724 = split_176[1] + getitem_1725 = split_176[2] + getitem_1726 = split_176[3] + getitem_1727 = split_176[4] + getitem_1728 = split_176[5] + getitem_1729 = split_176[6] + getitem_1730 = split_176[7]; split_176 = None + cat_168 = torch.ops.aten.cat.default([getitem_1723, getitem_1724, getitem_1725, getitem_1726, getitem_1727, getitem_1728, getitem_1729, getitem_1730]); getitem_1723 = getitem_1724 = getitem_1725 = getitem_1726 = getitem_1727 = getitem_1728 = getitem_1729 = getitem_1730 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_168, 'sum', 8, '1'); cat_168 = None + wait_tensor_564 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + convert_element_type_1578 = torch.ops.prims.convert_element_type.default(wait_tensor_564, torch.float32); wait_tensor_564 = None + convert_element_type_1580 = torch.ops.prims.convert_element_type.default(wait_tensor_295, torch.float32); wait_tensor_295 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_1578, convert_element_type_1580); convert_element_type_1580 = None + mul_452 = torch.ops.aten.mul.Tensor(mul_180, mul_450) + sum_57 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True); mul_452 = None + div_19 = torch.ops.aten.div.Tensor(mul_180, 4096) + mul_453 = torch.ops.aten.mul.Tensor(div_19, sum_57); div_19 = sum_57 = None + sub_30 = torch.ops.aten.sub.Tensor(mul_450, mul_453); mul_450 = mul_453 = None + mul_454 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_45); sub_30 = rsqrt_45 = None + mul_455 = torch.ops.aten.mul.Tensor(convert_element_type_1578, mul_180); convert_element_type_1578 = mul_180 = None + sum_58 = torch.ops.aten.sum.dim_IntList(mul_455, [0, 1]); mul_455 = None + convert_element_type_1581 = torch.ops.prims.convert_element_type.default(mul_454, torch.bfloat16); mul_454 = None + convert_element_type_1582 = torch.ops.prims.convert_element_type.default(sum_58, torch.bfloat16); sum_58 = None + all_reduce_19 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1582, 'sum', '1'); convert_element_type_1582 = None + wait_tensor_565 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_19); all_reduce_19 = None + convert_element_type_1583 = torch.ops.prims.convert_element_type.default(wait_tensor_565, torch.float32); wait_tensor_565 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1583, 'avg', 32, '0'); convert_element_type_1583 = None + wait_tensor_566 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + add_195 = torch.ops.aten.add.Tensor(add_191, convert_element_type_1581); add_191 = convert_element_type_1581 = None + all_gather_into_tensor_375 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_195, 8, '1') + wait_tensor_567 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_375); all_gather_into_tensor_375 = None + split_177 = torch.ops.aten.split.Tensor(wait_tensor_567, 2); wait_tensor_567 = None + getitem_1731 = split_177[0] + getitem_1732 = split_177[1] + getitem_1733 = split_177[2] + getitem_1734 = split_177[3] + getitem_1735 = split_177[4] + getitem_1736 = split_177[5] + getitem_1737 = split_177[6] + getitem_1738 = split_177[7]; split_177 = None + cat_169 = torch.ops.aten.cat.default([getitem_1731, getitem_1732, getitem_1733, getitem_1734, getitem_1735, getitem_1736, getitem_1737, getitem_1738], 1); getitem_1731 = getitem_1732 = getitem_1733 = getitem_1734 = getitem_1735 = getitem_1736 = getitem_1737 = getitem_1738 = None + view_2545 = torch.ops.aten.view.default(cat_169, [16384, 4096]); cat_169 = None + permute_657 = torch.ops.aten.permute.default(view_2545, [1, 0]) + permute_248 = torch.ops.aten.permute.default(getitem_982, [0, 2, 1, 3]) + view_1626 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + view_1632 = torch.ops.aten.view.default(view_1626, [16384, 512]); view_1626 = None + mm_359 = torch.ops.aten.mm.default(permute_657, view_1632); permute_657 = view_1632 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16); primals_206 = None + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 32, '0'); convert_element_type_743 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_293, [1, 0]); wait_tensor_293 = None + permute_659 = torch.ops.aten.permute.default(permute_249, [1, 0]); permute_249 = None + mm_360 = torch.ops.aten.mm.default(view_2545, permute_659); view_2545 = permute_659 = None + view_2546 = torch.ops.aten.view.default(mm_360, [2, 8192, 512]); mm_360 = None + convert_element_type_1588 = torch.ops.prims.convert_element_type.default(mm_359, torch.float32); mm_359 = None + reduce_scatter_tensor_172 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1588, 'avg', 32, '0'); convert_element_type_1588 = None + wait_tensor_568 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + view_2547 = torch.ops.aten.view.default(view_2546, [2, 8192, 4, 128]); view_2546 = None + permute_661 = torch.ops.aten.permute.default(view_2547, [0, 2, 1, 3]); view_2547 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16); primals_202 = None + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 32, '0'); convert_element_type_727 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32); add_87 = None + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_288) + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_729, 8, '1'); convert_element_type_729 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + split_97 = torch.ops.aten.split.Tensor(wait_tensor_289, 2); wait_tensor_289 = None + getitem_974 = split_97[0] + getitem_975 = split_97[1] + getitem_976 = split_97[2] + getitem_977 = split_97[3] + getitem_978 = split_97[4] + getitem_979 = split_97[5] + getitem_980 = split_97[6] + getitem_981 = split_97[7]; split_97 = None + cat_89 = torch.ops.aten.cat.default([getitem_974, getitem_975, getitem_976, getitem_977, getitem_978, getitem_979, getitem_980, getitem_981], 1); getitem_974 = getitem_975 = getitem_976 = getitem_977 = getitem_978 = getitem_979 = getitem_980 = getitem_981 = None + view_1599 = torch.ops.aten.view.default(cat_89, [16384, 4096]); cat_89 = None + view_1600 = torch.ops.aten.view.default(mm_154, [2, 8192, 512]); mm_154 = None + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16); primals_204 = None + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 32, '0'); convert_element_type_733 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_291, [1, 0]); wait_tensor_291 = None + mm_155 = torch.ops.aten.mm.default(view_1599, permute_243) + view_1607 = torch.ops.aten.view.default(mm_155, [2, 8192, 128]); mm_155 = None + view_1614 = torch.ops.aten.view.default(mm_156, [2, 8192, 128]); mm_156 = None + view_1616 = torch.ops.aten.view.default(view_1600, [2, 8192, -1, 128]); view_1600 = None + view_1617 = torch.ops.aten.view.default(view_1607, [2, 8192, -1, 128]); view_1607 = None + view_1618 = torch.ops.aten.view.default(view_1614, [2, 8192, -1, 128]); view_1614 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_1616, torch.float32); view_1616 = None + view_1619 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 4, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1619); view_1619 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_1617, torch.float32); view_1617 = None + view_1620 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 1, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1620); view_1620 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_37); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_1622 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 4, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_37); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_1623 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 1, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_1622, torch.bfloat16); view_1622 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_1623, torch.bfloat16); view_1623 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 1, 4, 128]); unsqueeze_44 = None + view_1624 = torch.ops.aten.view.default(expand_44, [2, 8192, 4, 128]); expand_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_1618, 3); view_1618 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 1, 4, 128]); unsqueeze_45 = None + view_1625 = torch.ops.aten.view.default(expand_45, [2, 8192, 4, 128]); expand_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_1624, [0, 2, 1, 3]); view_1624 = None + permute_247 = torch.ops.aten.permute.default(view_1625, [0, 2, 1, 3]); view_1625 = None + _scaled_dot_product_cudnn_attention_backward_9 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_661, permute_245, permute_246, permute_247, getitem_982, getitem_983, getitem_988, getitem_989, None, None, None, 8192, 8192, 0.0, True); permute_661 = permute_245 = permute_246 = permute_247 = getitem_982 = getitem_983 = getitem_988 = getitem_989 = None + getitem_1739 = _scaled_dot_product_cudnn_attention_backward_9[0] + getitem_1740 = _scaled_dot_product_cudnn_attention_backward_9[1] + getitem_1741 = _scaled_dot_product_cudnn_attention_backward_9[2]; _scaled_dot_product_cudnn_attention_backward_9 = None + permute_662 = torch.ops.aten.permute.default(getitem_1741, [0, 2, 1, 3]); getitem_1741 = None + permute_663 = torch.ops.aten.permute.default(getitem_1740, [0, 2, 1, 3]); getitem_1740 = None + permute_664 = torch.ops.aten.permute.default(getitem_1739, [0, 2, 1, 3]); getitem_1739 = None + view_2548 = torch.ops.aten.view.default(permute_662, [2, 8192, 1, 4, 128]); permute_662 = None + sum_59 = torch.ops.aten.sum.dim_IntList(view_2548, [3], True); view_2548 = None + squeeze_18 = torch.ops.aten.squeeze.dim(sum_59, 3); sum_59 = None + view_2549 = torch.ops.aten.view.default(permute_663, [2, 8192, 1, 4, 128]); permute_663 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_2549, [3], True); view_2549 = None + squeeze_19 = torch.ops.aten.squeeze.dim(sum_60, 3); sum_60 = None + convert_element_type_1589 = torch.ops.prims.convert_element_type.default(squeeze_19, torch.float32); squeeze_19 = None + convert_element_type_1590 = torch.ops.prims.convert_element_type.default(permute_664, torch.float32); permute_664 = None + view_2550 = torch.ops.aten.view.default(convert_element_type_1589, [2, 8192, 1, 64, 2]); convert_element_type_1589 = None + view_as_complex_82 = torch.ops.aten.view_as_complex.default(view_2550); view_2550 = None + mul_456 = torch.ops.aten.mul.Tensor(view_as_complex_82, _conj); view_as_complex_82 = None + view_2551 = torch.ops.aten.view.default(convert_element_type_1590, [2, 8192, 4, 64, 2]); convert_element_type_1590 = None + view_as_complex_83 = torch.ops.aten.view_as_complex.default(view_2551); view_2551 = None + mul_457 = torch.ops.aten.mul.Tensor(view_as_complex_83, _conj); view_as_complex_83 = None + view_as_real_82 = torch.ops.aten.view_as_real.default(mul_456); mul_456 = None + view_2552 = torch.ops.aten.view.default(view_as_real_82, [2, 8192, 1, 128]); view_as_real_82 = None + convert_element_type_1591 = torch.ops.prims.convert_element_type.default(view_2552, torch.bfloat16); view_2552 = None + view_as_real_83 = torch.ops.aten.view_as_real.default(mul_457); mul_457 = None + view_2553 = torch.ops.aten.view.default(view_as_real_83, [2, 8192, 4, 128]); view_as_real_83 = None + convert_element_type_1592 = torch.ops.prims.convert_element_type.default(view_2553, torch.bfloat16); view_2553 = None + view_2554 = torch.ops.aten.view.default(squeeze_18, [2, 8192, 128]); squeeze_18 = None + view_2555 = torch.ops.aten.view.default(convert_element_type_1591, [2, 8192, 128]); convert_element_type_1591 = None + view_2556 = torch.ops.aten.view.default(convert_element_type_1592, [2, 8192, 512]); convert_element_type_1592 = None + view_2557 = torch.ops.aten.view.default(view_2554, [16384, 128]); view_2554 = None + permute_665 = torch.ops.aten.permute.default(view_2557, [1, 0]) + mm_361 = torch.ops.aten.mm.default(permute_665, view_1599); permute_665 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16); primals_205 = None + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 32, '0'); convert_element_type_736 = None + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_292, [1, 0]); wait_tensor_292 = None + permute_667 = torch.ops.aten.permute.default(permute_244, [1, 0]); permute_244 = None + mm_362 = torch.ops.aten.mm.default(view_2557, permute_667); view_2557 = permute_667 = None + view_2558 = torch.ops.aten.view.default(mm_362, [2, 8192, 4096]); mm_362 = None + convert_element_type_1597 = torch.ops.prims.convert_element_type.default(mm_361, torch.float32); mm_361 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1597, 'avg', 32, '0'); convert_element_type_1597 = None + wait_tensor_569 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + view_2559 = torch.ops.aten.view.default(view_2555, [16384, 128]); view_2555 = None + permute_669 = torch.ops.aten.permute.default(view_2559, [1, 0]) + mm_363 = torch.ops.aten.mm.default(permute_669, view_1599); permute_669 = None + permute_671 = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None + mm_364 = torch.ops.aten.mm.default(view_2559, permute_671); view_2559 = permute_671 = None + view_2560 = torch.ops.aten.view.default(mm_364, [2, 8192, 4096]); mm_364 = None + add_196 = torch.ops.aten.add.Tensor(view_2558, view_2560); view_2558 = view_2560 = None + convert_element_type_1602 = torch.ops.prims.convert_element_type.default(mm_363, torch.float32); mm_363 = None + reduce_scatter_tensor_174 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1602, 'avg', 32, '0'); convert_element_type_1602 = None + wait_tensor_570 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + view_2561 = torch.ops.aten.view.default(view_2556, [16384, 512]); view_2556 = None + permute_673 = torch.ops.aten.permute.default(view_2561, [1, 0]) + mm_365 = torch.ops.aten.mm.default(permute_673, view_1599); permute_673 = view_1599 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16); primals_203 = None + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 32, '0'); convert_element_type_730 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + permute_675 = torch.ops.aten.permute.default(permute_242, [1, 0]); permute_242 = None + mm_366 = torch.ops.aten.mm.default(view_2561, permute_675); view_2561 = permute_675 = None + view_2562 = torch.ops.aten.view.default(mm_366, [2, 8192, 4096]); mm_366 = None + add_197 = torch.ops.aten.add.Tensor(add_196, view_2562); add_196 = view_2562 = None + convert_element_type_1607 = torch.ops.prims.convert_element_type.default(mm_365, torch.float32); mm_365 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1607, 'avg', 32, '0'); convert_element_type_1607 = None + wait_tensor_571 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + split_178 = torch.ops.aten.split.Tensor(add_197, 1024, 1); add_197 = None + getitem_1742 = split_178[0] + getitem_1743 = split_178[1] + getitem_1744 = split_178[2] + getitem_1745 = split_178[3] + getitem_1746 = split_178[4] + getitem_1747 = split_178[5] + getitem_1748 = split_178[6] + getitem_1749 = split_178[7]; split_178 = None + cat_170 = torch.ops.aten.cat.default([getitem_1742, getitem_1743, getitem_1744, getitem_1745, getitem_1746, getitem_1747, getitem_1748, getitem_1749]); getitem_1742 = getitem_1743 = getitem_1744 = getitem_1745 = getitem_1746 = getitem_1747 = getitem_1748 = getitem_1749 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_170, 'sum', 8, '1'); cat_170 = None + wait_tensor_572 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + convert_element_type_1608 = torch.ops.prims.convert_element_type.default(wait_tensor_572, torch.float32); wait_tensor_572 = None + convert_element_type_1610 = torch.ops.prims.convert_element_type.default(wait_tensor_288, torch.float32); wait_tensor_288 = None + mul_458 = torch.ops.aten.mul.Tensor(convert_element_type_1608, convert_element_type_1610); convert_element_type_1610 = None + mul_460 = torch.ops.aten.mul.Tensor(mul_176, mul_458) + sum_61 = torch.ops.aten.sum.dim_IntList(mul_460, [2], True); mul_460 = None + div_20 = torch.ops.aten.div.Tensor(mul_176, 4096) + mul_461 = torch.ops.aten.mul.Tensor(div_20, sum_61); div_20 = sum_61 = None + sub_31 = torch.ops.aten.sub.Tensor(mul_458, mul_461); mul_458 = mul_461 = None + mul_462 = torch.ops.aten.mul.Tensor(sub_31, rsqrt_44); sub_31 = rsqrt_44 = None + mul_463 = torch.ops.aten.mul.Tensor(convert_element_type_1608, mul_176); convert_element_type_1608 = mul_176 = None + sum_62 = torch.ops.aten.sum.dim_IntList(mul_463, [0, 1]); mul_463 = None + convert_element_type_1611 = torch.ops.prims.convert_element_type.default(mul_462, torch.bfloat16); mul_462 = None + convert_element_type_1612 = torch.ops.prims.convert_element_type.default(sum_62, torch.bfloat16); sum_62 = None + all_reduce_20 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1612, 'sum', '1'); convert_element_type_1612 = None + wait_tensor_573 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_20); all_reduce_20 = None + convert_element_type_1613 = torch.ops.prims.convert_element_type.default(wait_tensor_573, torch.float32); wait_tensor_573 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1613, 'avg', 32, '0'); convert_element_type_1613 = None + wait_tensor_574 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + add_198 = torch.ops.aten.add.Tensor(add_195, convert_element_type_1611); add_195 = convert_element_type_1611 = None + all_gather_into_tensor_376 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_198, 8, '1') + wait_tensor_575 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_376); all_gather_into_tensor_376 = None + split_179 = torch.ops.aten.split.Tensor(wait_tensor_575, 2); wait_tensor_575 = None + getitem_1750 = split_179[0] + getitem_1751 = split_179[1] + getitem_1752 = split_179[2] + getitem_1753 = split_179[3] + getitem_1754 = split_179[4] + getitem_1755 = split_179[5] + getitem_1756 = split_179[6] + getitem_1757 = split_179[7]; split_179 = None + cat_171 = torch.ops.aten.cat.default([getitem_1750, getitem_1751, getitem_1752, getitem_1753, getitem_1754, getitem_1755, getitem_1756, getitem_1757], 1); getitem_1750 = getitem_1751 = getitem_1752 = getitem_1753 = getitem_1754 = getitem_1755 = getitem_1756 = getitem_1757 = None + view_2563 = torch.ops.aten.view.default(cat_171, [16384, 4096]); cat_171 = None + permute_677 = torch.ops.aten.permute.default(view_2563, [1, 0]) + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + add_85 = torch.ops.aten.add.Tensor(add_83, wait_tensor_281); wait_tensor_281 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 32, '0'); convert_element_type_713 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32); add_85 = None + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_282) + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_715, 8, '1'); convert_element_type_715 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + split_95 = torch.ops.aten.split.Tensor(wait_tensor_283, 2); wait_tensor_283 = None + getitem_958 = split_95[0] + getitem_959 = split_95[1] + getitem_960 = split_95[2] + getitem_961 = split_95[3] + getitem_962 = split_95[4] + getitem_963 = split_95[5] + getitem_964 = split_95[6] + getitem_965 = split_95[7]; split_95 = None + cat_87 = torch.ops.aten.cat.default([getitem_958, getitem_959, getitem_960, getitem_961, getitem_962, getitem_963, getitem_964, getitem_965], 1); getitem_958 = getitem_959 = getitem_960 = getitem_961 = getitem_962 = getitem_963 = getitem_964 = getitem_965 = None + view_1572 = torch.ops.aten.view.default(cat_87, [16384, 4096]); cat_87 = None + view_1573 = torch.ops.aten.view.default(mm_151, [2, 8192, 1792]); mm_151 = None + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_1573, torch.float32); view_1573 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16); primals_200 = None + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 32, '0'); convert_element_type_721 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + mm_152 = torch.ops.aten.mm.default(view_1572, permute_240) + view_1580 = torch.ops.aten.view.default(mm_152, [2, 8192, 1792]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_1580) + view_1587 = torch.ops.aten.view.default(mul_175, [16384, 1792]); mul_175 = None + mm_367 = torch.ops.aten.mm.default(permute_677, view_1587); permute_677 = view_1587 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16); primals_201 = None + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 32, '0'); convert_element_type_724 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + permute_679 = torch.ops.aten.permute.default(permute_241, [1, 0]); permute_241 = None + mm_368 = torch.ops.aten.mm.default(view_2563, permute_679); view_2563 = permute_679 = None + view_2564 = torch.ops.aten.view.default(mm_368, [2, 8192, 1792]); mm_368 = None + convert_element_type_1618 = torch.ops.prims.convert_element_type.default(mm_367, torch.float32); mm_367 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1618, 'avg', 32, '0'); convert_element_type_1618 = None + wait_tensor_576 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + mul_464 = torch.ops.aten.mul.Tensor(view_2564, convert_element_type_720); convert_element_type_720 = None + mul_465 = torch.ops.aten.mul.Tensor(view_2564, view_1580); view_2564 = view_1580 = None + view_2565 = torch.ops.aten.view.default(mul_464, [16384, 1792]); mul_464 = None + permute_681 = torch.ops.aten.permute.default(view_2565, [1, 0]) + mm_369 = torch.ops.aten.mm.default(permute_681, view_1572); permute_681 = None + permute_683 = torch.ops.aten.permute.default(permute_240, [1, 0]); permute_240 = None + mm_370 = torch.ops.aten.mm.default(view_2565, permute_683); view_2565 = permute_683 = None + view_2566 = torch.ops.aten.view.default(mm_370, [2, 8192, 4096]); mm_370 = None + convert_element_type_1623 = torch.ops.prims.convert_element_type.default(mm_369, torch.float32); mm_369 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1623, 'avg', 32, '0'); convert_element_type_1623 = None + wait_tensor_577 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + convert_element_type_1624 = torch.ops.prims.convert_element_type.default(mul_465, torch.float32); mul_465 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_719) + exp_10 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_199 = torch.ops.aten.add.Tensor(exp_10, 1); exp_10 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_199); add_199 = None + mul_466 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_467 = torch.ops.aten.mul.Tensor(convert_element_type_1624, mul_466); convert_element_type_1624 = None + sub_32 = torch.ops.aten.sub.Tensor(1, mul_466); mul_466 = None + mul_468 = torch.ops.aten.mul.Tensor(convert_element_type_719, sub_32); convert_element_type_719 = sub_32 = None + add_200 = torch.ops.aten.add.Tensor(mul_468, 1); mul_468 = None + mul_469 = torch.ops.aten.mul.Tensor(mul_467, add_200); mul_467 = add_200 = None + convert_element_type_1626 = torch.ops.prims.convert_element_type.default(mul_469, torch.bfloat16); mul_469 = None + view_2567 = torch.ops.aten.view.default(convert_element_type_1626, [16384, 1792]); convert_element_type_1626 = None + permute_685 = torch.ops.aten.permute.default(view_2567, [1, 0]) + mm_371 = torch.ops.aten.mm.default(permute_685, view_1572); permute_685 = view_1572 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16); primals_199 = None + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 32, '0'); convert_element_type_716 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + permute_687 = torch.ops.aten.permute.default(permute_239, [1, 0]); permute_239 = None + mm_372 = torch.ops.aten.mm.default(view_2567, permute_687); view_2567 = permute_687 = None + view_2568 = torch.ops.aten.view.default(mm_372, [2, 8192, 4096]); mm_372 = None + add_201 = torch.ops.aten.add.Tensor(view_2566, view_2568); view_2566 = view_2568 = None + convert_element_type_1631 = torch.ops.prims.convert_element_type.default(mm_371, torch.float32); mm_371 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1631, 'avg', 32, '0'); convert_element_type_1631 = None + wait_tensor_578 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + split_180 = torch.ops.aten.split.Tensor(add_201, 1024, 1); add_201 = None + getitem_1758 = split_180[0] + getitem_1759 = split_180[1] + getitem_1760 = split_180[2] + getitem_1761 = split_180[3] + getitem_1762 = split_180[4] + getitem_1763 = split_180[5] + getitem_1764 = split_180[6] + getitem_1765 = split_180[7]; split_180 = None + cat_172 = torch.ops.aten.cat.default([getitem_1758, getitem_1759, getitem_1760, getitem_1761, getitem_1762, getitem_1763, getitem_1764, getitem_1765]); getitem_1758 = getitem_1759 = getitem_1760 = getitem_1761 = getitem_1762 = getitem_1763 = getitem_1764 = getitem_1765 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_172, 'sum', 8, '1'); cat_172 = None + wait_tensor_579 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + convert_element_type_1632 = torch.ops.prims.convert_element_type.default(wait_tensor_579, torch.float32); wait_tensor_579 = None + convert_element_type_1634 = torch.ops.prims.convert_element_type.default(wait_tensor_282, torch.float32); wait_tensor_282 = None + mul_470 = torch.ops.aten.mul.Tensor(convert_element_type_1632, convert_element_type_1634); convert_element_type_1634 = None + mul_472 = torch.ops.aten.mul.Tensor(mul_172, mul_470) + sum_63 = torch.ops.aten.sum.dim_IntList(mul_472, [2], True); mul_472 = None + div_21 = torch.ops.aten.div.Tensor(mul_172, 4096) + mul_473 = torch.ops.aten.mul.Tensor(div_21, sum_63); div_21 = sum_63 = None + sub_33 = torch.ops.aten.sub.Tensor(mul_470, mul_473); mul_470 = mul_473 = None + mul_474 = torch.ops.aten.mul.Tensor(sub_33, rsqrt_43); sub_33 = rsqrt_43 = None + mul_475 = torch.ops.aten.mul.Tensor(convert_element_type_1632, mul_172); convert_element_type_1632 = mul_172 = None + sum_64 = torch.ops.aten.sum.dim_IntList(mul_475, [0, 1]); mul_475 = None + convert_element_type_1635 = torch.ops.prims.convert_element_type.default(mul_474, torch.bfloat16); mul_474 = None + convert_element_type_1636 = torch.ops.prims.convert_element_type.default(sum_64, torch.bfloat16); sum_64 = None + all_reduce_21 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1636, 'sum', '1'); convert_element_type_1636 = None + wait_tensor_580 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_21); all_reduce_21 = None + convert_element_type_1637 = torch.ops.prims.convert_element_type.default(wait_tensor_580, torch.float32); wait_tensor_580 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1637, 'avg', 32, '0'); convert_element_type_1637 = None + wait_tensor_581 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + add_202 = torch.ops.aten.add.Tensor(add_198, convert_element_type_1635); add_198 = convert_element_type_1635 = None + all_gather_into_tensor_377 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_202, 8, '1') + wait_tensor_582 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_377); all_gather_into_tensor_377 = None + split_181 = torch.ops.aten.split.Tensor(wait_tensor_582, 2); wait_tensor_582 = None + getitem_1766 = split_181[0] + getitem_1767 = split_181[1] + getitem_1768 = split_181[2] + getitem_1769 = split_181[3] + getitem_1770 = split_181[4] + getitem_1771 = split_181[5] + getitem_1772 = split_181[6] + getitem_1773 = split_181[7]; split_181 = None + cat_173 = torch.ops.aten.cat.default([getitem_1766, getitem_1767, getitem_1768, getitem_1769, getitem_1770, getitem_1771, getitem_1772, getitem_1773], 1); getitem_1766 = getitem_1767 = getitem_1768 = getitem_1769 = getitem_1770 = getitem_1771 = getitem_1772 = getitem_1773 = None + view_2569 = torch.ops.aten.view.default(cat_173, [16384, 4096]); cat_173 = None + permute_689 = torch.ops.aten.permute.default(view_2569, [1, 0]) + permute_237 = torch.ops.aten.permute.default(getitem_941, [0, 2, 1, 3]) + view_1554 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + view_1560 = torch.ops.aten.view.default(view_1554, [16384, 512]); view_1554 = None + mm_373 = torch.ops.aten.mm.default(permute_689, view_1560); permute_689 = view_1560 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 32, '0'); convert_element_type_710 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_280, [1, 0]); wait_tensor_280 = None + permute_691 = torch.ops.aten.permute.default(permute_238, [1, 0]); permute_238 = None + mm_374 = torch.ops.aten.mm.default(view_2569, permute_691); view_2569 = permute_691 = None + view_2570 = torch.ops.aten.view.default(mm_374, [2, 8192, 512]); mm_374 = None + convert_element_type_1642 = torch.ops.prims.convert_element_type.default(mm_373, torch.float32); mm_373 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1642, 'avg', 32, '0'); convert_element_type_1642 = None + wait_tensor_583 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + view_2571 = torch.ops.aten.view.default(view_2570, [2, 8192, 4, 128]); view_2570 = None + permute_693 = torch.ops.aten.permute.default(view_2571, [0, 2, 1, 3]); view_2571 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16); primals_193 = None + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 32, '0'); convert_element_type_694 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32); add_83 = None + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_275) + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_696, 8, '1'); convert_element_type_696 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + split_93 = torch.ops.aten.split.Tensor(wait_tensor_276, 2); wait_tensor_276 = None + getitem_933 = split_93[0] + getitem_934 = split_93[1] + getitem_935 = split_93[2] + getitem_936 = split_93[3] + getitem_937 = split_93[4] + getitem_938 = split_93[5] + getitem_939 = split_93[6] + getitem_940 = split_93[7]; split_93 = None + cat_85 = torch.ops.aten.cat.default([getitem_933, getitem_934, getitem_935, getitem_936, getitem_937, getitem_938, getitem_939, getitem_940], 1); getitem_933 = getitem_934 = getitem_935 = getitem_936 = getitem_937 = getitem_938 = getitem_939 = getitem_940 = None + view_1527 = torch.ops.aten.view.default(cat_85, [16384, 4096]); cat_85 = None + view_1528 = torch.ops.aten.view.default(mm_147, [2, 8192, 512]); mm_147 = None + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 32, '0'); convert_element_type_700 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_148 = torch.ops.aten.mm.default(view_1527, permute_232) + view_1535 = torch.ops.aten.view.default(mm_148, [2, 8192, 128]); mm_148 = None + view_1542 = torch.ops.aten.view.default(mm_149, [2, 8192, 128]); mm_149 = None + view_1544 = torch.ops.aten.view.default(view_1528, [2, 8192, -1, 128]); view_1528 = None + view_1545 = torch.ops.aten.view.default(view_1535, [2, 8192, -1, 128]); view_1535 = None + view_1546 = torch.ops.aten.view.default(view_1542, [2, 8192, -1, 128]); view_1542 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_1544, torch.float32); view_1544 = None + view_1547 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 4, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1547); view_1547 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_1545, torch.float32); view_1545 = None + view_1548 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 1, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1548); view_1548 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_37); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_1550 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 4, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_37); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_1551 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 1, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_1550, torch.bfloat16); view_1550 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_1551, torch.bfloat16); view_1551 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 1, 4, 128]); unsqueeze_42 = None + view_1552 = torch.ops.aten.view.default(expand_42, [2, 8192, 4, 128]); expand_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_1546, 3); view_1546 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 1, 4, 128]); unsqueeze_43 = None + view_1553 = torch.ops.aten.view.default(expand_43, [2, 8192, 4, 128]); expand_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_1552, [0, 2, 1, 3]); view_1552 = None + permute_236 = torch.ops.aten.permute.default(view_1553, [0, 2, 1, 3]); view_1553 = None + _scaled_dot_product_cudnn_attention_backward_10 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_693, permute_234, permute_235, permute_236, getitem_941, getitem_942, getitem_947, getitem_948, None, None, None, 8192, 8192, 0.0, True); permute_693 = permute_234 = permute_235 = permute_236 = getitem_941 = getitem_942 = getitem_947 = getitem_948 = None + getitem_1774 = _scaled_dot_product_cudnn_attention_backward_10[0] + getitem_1775 = _scaled_dot_product_cudnn_attention_backward_10[1] + getitem_1776 = _scaled_dot_product_cudnn_attention_backward_10[2]; _scaled_dot_product_cudnn_attention_backward_10 = None + permute_694 = torch.ops.aten.permute.default(getitem_1776, [0, 2, 1, 3]); getitem_1776 = None + permute_695 = torch.ops.aten.permute.default(getitem_1775, [0, 2, 1, 3]); getitem_1775 = None + permute_696 = torch.ops.aten.permute.default(getitem_1774, [0, 2, 1, 3]); getitem_1774 = None + view_2572 = torch.ops.aten.view.default(permute_694, [2, 8192, 1, 4, 128]); permute_694 = None + sum_65 = torch.ops.aten.sum.dim_IntList(view_2572, [3], True); view_2572 = None + squeeze_20 = torch.ops.aten.squeeze.dim(sum_65, 3); sum_65 = None + view_2573 = torch.ops.aten.view.default(permute_695, [2, 8192, 1, 4, 128]); permute_695 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_2573, [3], True); view_2573 = None + squeeze_21 = torch.ops.aten.squeeze.dim(sum_66, 3); sum_66 = None + convert_element_type_1643 = torch.ops.prims.convert_element_type.default(squeeze_21, torch.float32); squeeze_21 = None + convert_element_type_1644 = torch.ops.prims.convert_element_type.default(permute_696, torch.float32); permute_696 = None + view_2574 = torch.ops.aten.view.default(convert_element_type_1643, [2, 8192, 1, 64, 2]); convert_element_type_1643 = None + view_as_complex_84 = torch.ops.aten.view_as_complex.default(view_2574); view_2574 = None + mul_476 = torch.ops.aten.mul.Tensor(view_as_complex_84, _conj); view_as_complex_84 = None + view_2575 = torch.ops.aten.view.default(convert_element_type_1644, [2, 8192, 4, 64, 2]); convert_element_type_1644 = None + view_as_complex_85 = torch.ops.aten.view_as_complex.default(view_2575); view_2575 = None + mul_477 = torch.ops.aten.mul.Tensor(view_as_complex_85, _conj); view_as_complex_85 = None + view_as_real_84 = torch.ops.aten.view_as_real.default(mul_476); mul_476 = None + view_2576 = torch.ops.aten.view.default(view_as_real_84, [2, 8192, 1, 128]); view_as_real_84 = None + convert_element_type_1645 = torch.ops.prims.convert_element_type.default(view_2576, torch.bfloat16); view_2576 = None + view_as_real_85 = torch.ops.aten.view_as_real.default(mul_477); mul_477 = None + view_2577 = torch.ops.aten.view.default(view_as_real_85, [2, 8192, 4, 128]); view_as_real_85 = None + convert_element_type_1646 = torch.ops.prims.convert_element_type.default(view_2577, torch.bfloat16); view_2577 = None + view_2578 = torch.ops.aten.view.default(squeeze_20, [2, 8192, 128]); squeeze_20 = None + view_2579 = torch.ops.aten.view.default(convert_element_type_1645, [2, 8192, 128]); convert_element_type_1645 = None + view_2580 = torch.ops.aten.view.default(convert_element_type_1646, [2, 8192, 512]); convert_element_type_1646 = None + view_2581 = torch.ops.aten.view.default(view_2578, [16384, 128]); view_2578 = None + permute_697 = torch.ops.aten.permute.default(view_2581, [1, 0]) + mm_375 = torch.ops.aten.mm.default(permute_697, view_1527); permute_697 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 32, '0'); convert_element_type_703 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + permute_699 = torch.ops.aten.permute.default(permute_233, [1, 0]); permute_233 = None + mm_376 = torch.ops.aten.mm.default(view_2581, permute_699); view_2581 = permute_699 = None + view_2582 = torch.ops.aten.view.default(mm_376, [2, 8192, 4096]); mm_376 = None + convert_element_type_1651 = torch.ops.prims.convert_element_type.default(mm_375, torch.float32); mm_375 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1651, 'avg', 32, '0'); convert_element_type_1651 = None + wait_tensor_584 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + view_2583 = torch.ops.aten.view.default(view_2579, [16384, 128]); view_2579 = None + permute_701 = torch.ops.aten.permute.default(view_2583, [1, 0]) + mm_377 = torch.ops.aten.mm.default(permute_701, view_1527); permute_701 = None + permute_703 = torch.ops.aten.permute.default(permute_232, [1, 0]); permute_232 = None + mm_378 = torch.ops.aten.mm.default(view_2583, permute_703); view_2583 = permute_703 = None + view_2584 = torch.ops.aten.view.default(mm_378, [2, 8192, 4096]); mm_378 = None + add_203 = torch.ops.aten.add.Tensor(view_2582, view_2584); view_2582 = view_2584 = None + convert_element_type_1656 = torch.ops.prims.convert_element_type.default(mm_377, torch.float32); mm_377 = None + reduce_scatter_tensor_185 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1656, 'avg', 32, '0'); convert_element_type_1656 = None + wait_tensor_585 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + view_2585 = torch.ops.aten.view.default(view_2580, [16384, 512]); view_2580 = None + permute_705 = torch.ops.aten.permute.default(view_2585, [1, 0]) + mm_379 = torch.ops.aten.mm.default(permute_705, view_1527); permute_705 = view_1527 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16); primals_194 = None + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 32, '0'); convert_element_type_697 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + permute_707 = torch.ops.aten.permute.default(permute_231, [1, 0]); permute_231 = None + mm_380 = torch.ops.aten.mm.default(view_2585, permute_707); view_2585 = permute_707 = None + view_2586 = torch.ops.aten.view.default(mm_380, [2, 8192, 4096]); mm_380 = None + add_204 = torch.ops.aten.add.Tensor(add_203, view_2586); add_203 = view_2586 = None + convert_element_type_1661 = torch.ops.prims.convert_element_type.default(mm_379, torch.float32); mm_379 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1661, 'avg', 32, '0'); convert_element_type_1661 = None + wait_tensor_586 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + split_182 = torch.ops.aten.split.Tensor(add_204, 1024, 1); add_204 = None + getitem_1777 = split_182[0] + getitem_1778 = split_182[1] + getitem_1779 = split_182[2] + getitem_1780 = split_182[3] + getitem_1781 = split_182[4] + getitem_1782 = split_182[5] + getitem_1783 = split_182[6] + getitem_1784 = split_182[7]; split_182 = None + cat_174 = torch.ops.aten.cat.default([getitem_1777, getitem_1778, getitem_1779, getitem_1780, getitem_1781, getitem_1782, getitem_1783, getitem_1784]); getitem_1777 = getitem_1778 = getitem_1779 = getitem_1780 = getitem_1781 = getitem_1782 = getitem_1783 = getitem_1784 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_174, 'sum', 8, '1'); cat_174 = None + wait_tensor_587 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + convert_element_type_1662 = torch.ops.prims.convert_element_type.default(wait_tensor_587, torch.float32); wait_tensor_587 = None + convert_element_type_1664 = torch.ops.prims.convert_element_type.default(wait_tensor_275, torch.float32); wait_tensor_275 = None + mul_478 = torch.ops.aten.mul.Tensor(convert_element_type_1662, convert_element_type_1664); convert_element_type_1664 = None + mul_480 = torch.ops.aten.mul.Tensor(mul_168, mul_478) + sum_67 = torch.ops.aten.sum.dim_IntList(mul_480, [2], True); mul_480 = None + div_22 = torch.ops.aten.div.Tensor(mul_168, 4096) + mul_481 = torch.ops.aten.mul.Tensor(div_22, sum_67); div_22 = sum_67 = None + sub_34 = torch.ops.aten.sub.Tensor(mul_478, mul_481); mul_478 = mul_481 = None + mul_482 = torch.ops.aten.mul.Tensor(sub_34, rsqrt_42); sub_34 = rsqrt_42 = None + mul_483 = torch.ops.aten.mul.Tensor(convert_element_type_1662, mul_168); convert_element_type_1662 = mul_168 = None + sum_68 = torch.ops.aten.sum.dim_IntList(mul_483, [0, 1]); mul_483 = None + convert_element_type_1665 = torch.ops.prims.convert_element_type.default(mul_482, torch.bfloat16); mul_482 = None + convert_element_type_1666 = torch.ops.prims.convert_element_type.default(sum_68, torch.bfloat16); sum_68 = None + all_reduce_22 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1666, 'sum', '1'); convert_element_type_1666 = None + wait_tensor_588 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_22); all_reduce_22 = None + convert_element_type_1667 = torch.ops.prims.convert_element_type.default(wait_tensor_588, torch.float32); wait_tensor_588 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1667, 'avg', 32, '0'); convert_element_type_1667 = None + wait_tensor_589 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + add_205 = torch.ops.aten.add.Tensor(add_202, convert_element_type_1665); add_202 = convert_element_type_1665 = None + all_gather_into_tensor_378 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_205, 8, '1') + wait_tensor_590 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_378); all_gather_into_tensor_378 = None + split_183 = torch.ops.aten.split.Tensor(wait_tensor_590, 2); wait_tensor_590 = None + getitem_1785 = split_183[0] + getitem_1786 = split_183[1] + getitem_1787 = split_183[2] + getitem_1788 = split_183[3] + getitem_1789 = split_183[4] + getitem_1790 = split_183[5] + getitem_1791 = split_183[6] + getitem_1792 = split_183[7]; split_183 = None + cat_175 = torch.ops.aten.cat.default([getitem_1785, getitem_1786, getitem_1787, getitem_1788, getitem_1789, getitem_1790, getitem_1791, getitem_1792], 1); getitem_1785 = getitem_1786 = getitem_1787 = getitem_1788 = getitem_1789 = getitem_1790 = getitem_1791 = getitem_1792 = None + view_2587 = torch.ops.aten.view.default(cat_175, [16384, 4096]); cat_175 = None + permute_709 = torch.ops.aten.permute.default(view_2587, [1, 0]) + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + add_81 = torch.ops.aten.add.Tensor(add_79, wait_tensor_268); wait_tensor_268 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16); primals_189 = None + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 32, '0'); convert_element_type_680 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32); add_81 = None + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_269) + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_682, 8, '1'); convert_element_type_682 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + split_91 = torch.ops.aten.split.Tensor(wait_tensor_270, 2); wait_tensor_270 = None + getitem_917 = split_91[0] + getitem_918 = split_91[1] + getitem_919 = split_91[2] + getitem_920 = split_91[3] + getitem_921 = split_91[4] + getitem_922 = split_91[5] + getitem_923 = split_91[6] + getitem_924 = split_91[7]; split_91 = None + cat_83 = torch.ops.aten.cat.default([getitem_917, getitem_918, getitem_919, getitem_920, getitem_921, getitem_922, getitem_923, getitem_924], 1); getitem_917 = getitem_918 = getitem_919 = getitem_920 = getitem_921 = getitem_922 = getitem_923 = getitem_924 = None + view_1500 = torch.ops.aten.view.default(cat_83, [16384, 4096]); cat_83 = None + view_1501 = torch.ops.aten.view.default(mm_144, [2, 8192, 1792]); mm_144 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_1501, torch.float32); view_1501 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 32, '0'); convert_element_type_688 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + mm_145 = torch.ops.aten.mm.default(view_1500, permute_229) + view_1508 = torch.ops.aten.view.default(mm_145, [2, 8192, 1792]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_1508) + view_1515 = torch.ops.aten.view.default(mul_167, [16384, 1792]); mul_167 = None + mm_381 = torch.ops.aten.mm.default(permute_709, view_1515); permute_709 = view_1515 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16); primals_192 = None + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 32, '0'); convert_element_type_691 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + permute_711 = torch.ops.aten.permute.default(permute_230, [1, 0]); permute_230 = None + mm_382 = torch.ops.aten.mm.default(view_2587, permute_711); view_2587 = permute_711 = None + view_2588 = torch.ops.aten.view.default(mm_382, [2, 8192, 1792]); mm_382 = None + convert_element_type_1672 = torch.ops.prims.convert_element_type.default(mm_381, torch.float32); mm_381 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1672, 'avg', 32, '0'); convert_element_type_1672 = None + wait_tensor_591 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + mul_484 = torch.ops.aten.mul.Tensor(view_2588, convert_element_type_687); convert_element_type_687 = None + mul_485 = torch.ops.aten.mul.Tensor(view_2588, view_1508); view_2588 = view_1508 = None + view_2589 = torch.ops.aten.view.default(mul_484, [16384, 1792]); mul_484 = None + permute_713 = torch.ops.aten.permute.default(view_2589, [1, 0]) + mm_383 = torch.ops.aten.mm.default(permute_713, view_1500); permute_713 = None + permute_715 = torch.ops.aten.permute.default(permute_229, [1, 0]); permute_229 = None + mm_384 = torch.ops.aten.mm.default(view_2589, permute_715); view_2589 = permute_715 = None + view_2590 = torch.ops.aten.view.default(mm_384, [2, 8192, 4096]); mm_384 = None + convert_element_type_1677 = torch.ops.prims.convert_element_type.default(mm_383, torch.float32); mm_383 = None + reduce_scatter_tensor_190 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1677, 'avg', 32, '0'); convert_element_type_1677 = None + wait_tensor_592 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + convert_element_type_1678 = torch.ops.prims.convert_element_type.default(mul_485, torch.float32); mul_485 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_686) + exp_11 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_206 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_206); add_206 = None + mul_486 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_487 = torch.ops.aten.mul.Tensor(convert_element_type_1678, mul_486); convert_element_type_1678 = None + sub_35 = torch.ops.aten.sub.Tensor(1, mul_486); mul_486 = None + mul_488 = torch.ops.aten.mul.Tensor(convert_element_type_686, sub_35); convert_element_type_686 = sub_35 = None + add_207 = torch.ops.aten.add.Tensor(mul_488, 1); mul_488 = None + mul_489 = torch.ops.aten.mul.Tensor(mul_487, add_207); mul_487 = add_207 = None + convert_element_type_1680 = torch.ops.prims.convert_element_type.default(mul_489, torch.bfloat16); mul_489 = None + view_2591 = torch.ops.aten.view.default(convert_element_type_1680, [16384, 1792]); convert_element_type_1680 = None + permute_717 = torch.ops.aten.permute.default(view_2591, [1, 0]) + mm_385 = torch.ops.aten.mm.default(permute_717, view_1500); permute_717 = view_1500 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16); primals_190 = None + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 32, '0'); convert_element_type_683 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_271, [1, 0]); wait_tensor_271 = None + permute_719 = torch.ops.aten.permute.default(permute_228, [1, 0]); permute_228 = None + mm_386 = torch.ops.aten.mm.default(view_2591, permute_719); view_2591 = permute_719 = None + view_2592 = torch.ops.aten.view.default(mm_386, [2, 8192, 4096]); mm_386 = None + add_208 = torch.ops.aten.add.Tensor(view_2590, view_2592); view_2590 = view_2592 = None + convert_element_type_1685 = torch.ops.prims.convert_element_type.default(mm_385, torch.float32); mm_385 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1685, 'avg', 32, '0'); convert_element_type_1685 = None + wait_tensor_593 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + split_184 = torch.ops.aten.split.Tensor(add_208, 1024, 1); add_208 = None + getitem_1793 = split_184[0] + getitem_1794 = split_184[1] + getitem_1795 = split_184[2] + getitem_1796 = split_184[3] + getitem_1797 = split_184[4] + getitem_1798 = split_184[5] + getitem_1799 = split_184[6] + getitem_1800 = split_184[7]; split_184 = None + cat_176 = torch.ops.aten.cat.default([getitem_1793, getitem_1794, getitem_1795, getitem_1796, getitem_1797, getitem_1798, getitem_1799, getitem_1800]); getitem_1793 = getitem_1794 = getitem_1795 = getitem_1796 = getitem_1797 = getitem_1798 = getitem_1799 = getitem_1800 = None + reduce_scatter_tensor_192 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_176, 'sum', 8, '1'); cat_176 = None + wait_tensor_594 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + convert_element_type_1686 = torch.ops.prims.convert_element_type.default(wait_tensor_594, torch.float32); wait_tensor_594 = None + convert_element_type_1688 = torch.ops.prims.convert_element_type.default(wait_tensor_269, torch.float32); wait_tensor_269 = None + mul_490 = torch.ops.aten.mul.Tensor(convert_element_type_1686, convert_element_type_1688); convert_element_type_1688 = None + mul_492 = torch.ops.aten.mul.Tensor(mul_164, mul_490) + sum_69 = torch.ops.aten.sum.dim_IntList(mul_492, [2], True); mul_492 = None + div_23 = torch.ops.aten.div.Tensor(mul_164, 4096) + mul_493 = torch.ops.aten.mul.Tensor(div_23, sum_69); div_23 = sum_69 = None + sub_36 = torch.ops.aten.sub.Tensor(mul_490, mul_493); mul_490 = mul_493 = None + mul_494 = torch.ops.aten.mul.Tensor(sub_36, rsqrt_41); sub_36 = rsqrt_41 = None + mul_495 = torch.ops.aten.mul.Tensor(convert_element_type_1686, mul_164); convert_element_type_1686 = mul_164 = None + sum_70 = torch.ops.aten.sum.dim_IntList(mul_495, [0, 1]); mul_495 = None + convert_element_type_1689 = torch.ops.prims.convert_element_type.default(mul_494, torch.bfloat16); mul_494 = None + convert_element_type_1690 = torch.ops.prims.convert_element_type.default(sum_70, torch.bfloat16); sum_70 = None + all_reduce_23 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1690, 'sum', '1'); convert_element_type_1690 = None + wait_tensor_595 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_23); all_reduce_23 = None + convert_element_type_1691 = torch.ops.prims.convert_element_type.default(wait_tensor_595, torch.float32); wait_tensor_595 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1691, 'avg', 32, '0'); convert_element_type_1691 = None + wait_tensor_596 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + add_209 = torch.ops.aten.add.Tensor(add_205, convert_element_type_1689); add_205 = convert_element_type_1689 = None + all_gather_into_tensor_379 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_209, 8, '1') + wait_tensor_597 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_379); all_gather_into_tensor_379 = None + split_185 = torch.ops.aten.split.Tensor(wait_tensor_597, 2); wait_tensor_597 = None + getitem_1801 = split_185[0] + getitem_1802 = split_185[1] + getitem_1803 = split_185[2] + getitem_1804 = split_185[3] + getitem_1805 = split_185[4] + getitem_1806 = split_185[5] + getitem_1807 = split_185[6] + getitem_1808 = split_185[7]; split_185 = None + cat_177 = torch.ops.aten.cat.default([getitem_1801, getitem_1802, getitem_1803, getitem_1804, getitem_1805, getitem_1806, getitem_1807, getitem_1808], 1); getitem_1801 = getitem_1802 = getitem_1803 = getitem_1804 = getitem_1805 = getitem_1806 = getitem_1807 = getitem_1808 = None + view_2593 = torch.ops.aten.view.default(cat_177, [16384, 4096]); cat_177 = None + permute_721 = torch.ops.aten.permute.default(view_2593, [1, 0]) + permute_226 = torch.ops.aten.permute.default(getitem_900, [0, 2, 1, 3]) + view_1482 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + view_1488 = torch.ops.aten.view.default(view_1482, [16384, 512]); view_1482 = None + mm_387 = torch.ops.aten.mm.default(permute_721, view_1488); permute_721 = view_1488 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16); primals_188 = None + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 32, '0'); convert_element_type_677 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + permute_723 = torch.ops.aten.permute.default(permute_227, [1, 0]); permute_227 = None + mm_388 = torch.ops.aten.mm.default(view_2593, permute_723); view_2593 = permute_723 = None + view_2594 = torch.ops.aten.view.default(mm_388, [2, 8192, 512]); mm_388 = None + convert_element_type_1696 = torch.ops.prims.convert_element_type.default(mm_387, torch.float32); mm_387 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1696, 'avg', 32, '0'); convert_element_type_1696 = None + wait_tensor_598 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + view_2595 = torch.ops.aten.view.default(view_2594, [2, 8192, 4, 128]); view_2594 = None + permute_725 = torch.ops.aten.permute.default(view_2595, [0, 2, 1, 3]); view_2595 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16); primals_184 = None + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 32, '0'); convert_element_type_661 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32); add_79 = None + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_262) + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_663, 8, '1'); convert_element_type_663 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + split_89 = torch.ops.aten.split.Tensor(wait_tensor_263, 2); wait_tensor_263 = None + getitem_892 = split_89[0] + getitem_893 = split_89[1] + getitem_894 = split_89[2] + getitem_895 = split_89[3] + getitem_896 = split_89[4] + getitem_897 = split_89[5] + getitem_898 = split_89[6] + getitem_899 = split_89[7]; split_89 = None + cat_81 = torch.ops.aten.cat.default([getitem_892, getitem_893, getitem_894, getitem_895, getitem_896, getitem_897, getitem_898, getitem_899], 1); getitem_892 = getitem_893 = getitem_894 = getitem_895 = getitem_896 = getitem_897 = getitem_898 = getitem_899 = None + view_1455 = torch.ops.aten.view.default(cat_81, [16384, 4096]); cat_81 = None + view_1456 = torch.ops.aten.view.default(mm_140, [2, 8192, 512]); mm_140 = None + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16); primals_186 = None + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 32, '0'); convert_element_type_667 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_141 = torch.ops.aten.mm.default(view_1455, permute_221) + view_1463 = torch.ops.aten.view.default(mm_141, [2, 8192, 128]); mm_141 = None + view_1470 = torch.ops.aten.view.default(mm_142, [2, 8192, 128]); mm_142 = None + view_1472 = torch.ops.aten.view.default(view_1456, [2, 8192, -1, 128]); view_1456 = None + view_1473 = torch.ops.aten.view.default(view_1463, [2, 8192, -1, 128]); view_1463 = None + view_1474 = torch.ops.aten.view.default(view_1470, [2, 8192, -1, 128]); view_1470 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_1472, torch.float32); view_1472 = None + view_1475 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 4, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1475); view_1475 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_1473, torch.float32); view_1473 = None + view_1476 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 1, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1476); view_1476 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_37); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_1478 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 4, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_37); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_1479 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 1, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_1478, torch.bfloat16); view_1478 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_1479, torch.bfloat16); view_1479 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 1, 4, 128]); unsqueeze_40 = None + view_1480 = torch.ops.aten.view.default(expand_40, [2, 8192, 4, 128]); expand_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_1474, 3); view_1474 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 1, 4, 128]); unsqueeze_41 = None + view_1481 = torch.ops.aten.view.default(expand_41, [2, 8192, 4, 128]); expand_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_1480, [0, 2, 1, 3]); view_1480 = None + permute_225 = torch.ops.aten.permute.default(view_1481, [0, 2, 1, 3]); view_1481 = None + _scaled_dot_product_cudnn_attention_backward_11 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_725, permute_223, permute_224, permute_225, getitem_900, getitem_901, getitem_906, getitem_907, None, None, None, 8192, 8192, 0.0, True); permute_725 = permute_223 = permute_224 = permute_225 = getitem_900 = getitem_901 = getitem_906 = getitem_907 = None + getitem_1809 = _scaled_dot_product_cudnn_attention_backward_11[0] + getitem_1810 = _scaled_dot_product_cudnn_attention_backward_11[1] + getitem_1811 = _scaled_dot_product_cudnn_attention_backward_11[2]; _scaled_dot_product_cudnn_attention_backward_11 = None + permute_726 = torch.ops.aten.permute.default(getitem_1811, [0, 2, 1, 3]); getitem_1811 = None + permute_727 = torch.ops.aten.permute.default(getitem_1810, [0, 2, 1, 3]); getitem_1810 = None + permute_728 = torch.ops.aten.permute.default(getitem_1809, [0, 2, 1, 3]); getitem_1809 = None + view_2596 = torch.ops.aten.view.default(permute_726, [2, 8192, 1, 4, 128]); permute_726 = None + sum_71 = torch.ops.aten.sum.dim_IntList(view_2596, [3], True); view_2596 = None + squeeze_22 = torch.ops.aten.squeeze.dim(sum_71, 3); sum_71 = None + view_2597 = torch.ops.aten.view.default(permute_727, [2, 8192, 1, 4, 128]); permute_727 = None + sum_72 = torch.ops.aten.sum.dim_IntList(view_2597, [3], True); view_2597 = None + squeeze_23 = torch.ops.aten.squeeze.dim(sum_72, 3); sum_72 = None + convert_element_type_1697 = torch.ops.prims.convert_element_type.default(squeeze_23, torch.float32); squeeze_23 = None + convert_element_type_1698 = torch.ops.prims.convert_element_type.default(permute_728, torch.float32); permute_728 = None + view_2598 = torch.ops.aten.view.default(convert_element_type_1697, [2, 8192, 1, 64, 2]); convert_element_type_1697 = None + view_as_complex_86 = torch.ops.aten.view_as_complex.default(view_2598); view_2598 = None + mul_496 = torch.ops.aten.mul.Tensor(view_as_complex_86, _conj); view_as_complex_86 = None + view_2599 = torch.ops.aten.view.default(convert_element_type_1698, [2, 8192, 4, 64, 2]); convert_element_type_1698 = None + view_as_complex_87 = torch.ops.aten.view_as_complex.default(view_2599); view_2599 = None + mul_497 = torch.ops.aten.mul.Tensor(view_as_complex_87, _conj); view_as_complex_87 = None + view_as_real_86 = torch.ops.aten.view_as_real.default(mul_496); mul_496 = None + view_2600 = torch.ops.aten.view.default(view_as_real_86, [2, 8192, 1, 128]); view_as_real_86 = None + convert_element_type_1699 = torch.ops.prims.convert_element_type.default(view_2600, torch.bfloat16); view_2600 = None + view_as_real_87 = torch.ops.aten.view_as_real.default(mul_497); mul_497 = None + view_2601 = torch.ops.aten.view.default(view_as_real_87, [2, 8192, 4, 128]); view_as_real_87 = None + convert_element_type_1700 = torch.ops.prims.convert_element_type.default(view_2601, torch.bfloat16); view_2601 = None + view_2602 = torch.ops.aten.view.default(squeeze_22, [2, 8192, 128]); squeeze_22 = None + view_2603 = torch.ops.aten.view.default(convert_element_type_1699, [2, 8192, 128]); convert_element_type_1699 = None + view_2604 = torch.ops.aten.view.default(convert_element_type_1700, [2, 8192, 512]); convert_element_type_1700 = None + view_2605 = torch.ops.aten.view.default(view_2602, [16384, 128]); view_2602 = None + permute_729 = torch.ops.aten.permute.default(view_2605, [1, 0]) + mm_389 = torch.ops.aten.mm.default(permute_729, view_1455); permute_729 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16); primals_187 = None + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 32, '0'); convert_element_type_670 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + permute_731 = torch.ops.aten.permute.default(permute_222, [1, 0]); permute_222 = None + mm_390 = torch.ops.aten.mm.default(view_2605, permute_731); view_2605 = permute_731 = None + view_2606 = torch.ops.aten.view.default(mm_390, [2, 8192, 4096]); mm_390 = None + convert_element_type_1705 = torch.ops.prims.convert_element_type.default(mm_389, torch.float32); mm_389 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1705, 'avg', 32, '0'); convert_element_type_1705 = None + wait_tensor_599 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + view_2607 = torch.ops.aten.view.default(view_2603, [16384, 128]); view_2603 = None + permute_733 = torch.ops.aten.permute.default(view_2607, [1, 0]) + mm_391 = torch.ops.aten.mm.default(permute_733, view_1455); permute_733 = None + permute_735 = torch.ops.aten.permute.default(permute_221, [1, 0]); permute_221 = None + mm_392 = torch.ops.aten.mm.default(view_2607, permute_735); view_2607 = permute_735 = None + view_2608 = torch.ops.aten.view.default(mm_392, [2, 8192, 4096]); mm_392 = None + add_210 = torch.ops.aten.add.Tensor(view_2606, view_2608); view_2606 = view_2608 = None + convert_element_type_1710 = torch.ops.prims.convert_element_type.default(mm_391, torch.float32); mm_391 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1710, 'avg', 32, '0'); convert_element_type_1710 = None + wait_tensor_600 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + view_2609 = torch.ops.aten.view.default(view_2604, [16384, 512]); view_2604 = None + permute_737 = torch.ops.aten.permute.default(view_2609, [1, 0]) + mm_393 = torch.ops.aten.mm.default(permute_737, view_1455); permute_737 = view_1455 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16); primals_185 = None + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 32, '0'); convert_element_type_664 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + permute_739 = torch.ops.aten.permute.default(permute_220, [1, 0]); permute_220 = None + mm_394 = torch.ops.aten.mm.default(view_2609, permute_739); view_2609 = permute_739 = None + view_2610 = torch.ops.aten.view.default(mm_394, [2, 8192, 4096]); mm_394 = None + add_211 = torch.ops.aten.add.Tensor(add_210, view_2610); add_210 = view_2610 = None + convert_element_type_1715 = torch.ops.prims.convert_element_type.default(mm_393, torch.float32); mm_393 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1715, 'avg', 32, '0'); convert_element_type_1715 = None + wait_tensor_601 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + split_186 = torch.ops.aten.split.Tensor(add_211, 1024, 1); add_211 = None + getitem_1812 = split_186[0] + getitem_1813 = split_186[1] + getitem_1814 = split_186[2] + getitem_1815 = split_186[3] + getitem_1816 = split_186[4] + getitem_1817 = split_186[5] + getitem_1818 = split_186[6] + getitem_1819 = split_186[7]; split_186 = None + cat_178 = torch.ops.aten.cat.default([getitem_1812, getitem_1813, getitem_1814, getitem_1815, getitem_1816, getitem_1817, getitem_1818, getitem_1819]); getitem_1812 = getitem_1813 = getitem_1814 = getitem_1815 = getitem_1816 = getitem_1817 = getitem_1818 = getitem_1819 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_178, 'sum', 8, '1'); cat_178 = None + wait_tensor_602 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + convert_element_type_1716 = torch.ops.prims.convert_element_type.default(wait_tensor_602, torch.float32); wait_tensor_602 = None + convert_element_type_1718 = torch.ops.prims.convert_element_type.default(wait_tensor_262, torch.float32); wait_tensor_262 = None + mul_498 = torch.ops.aten.mul.Tensor(convert_element_type_1716, convert_element_type_1718); convert_element_type_1718 = None + mul_500 = torch.ops.aten.mul.Tensor(mul_160, mul_498) + sum_73 = torch.ops.aten.sum.dim_IntList(mul_500, [2], True); mul_500 = None + div_24 = torch.ops.aten.div.Tensor(mul_160, 4096) + mul_501 = torch.ops.aten.mul.Tensor(div_24, sum_73); div_24 = sum_73 = None + sub_37 = torch.ops.aten.sub.Tensor(mul_498, mul_501); mul_498 = mul_501 = None + mul_502 = torch.ops.aten.mul.Tensor(sub_37, rsqrt_40); sub_37 = rsqrt_40 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_1716, mul_160); convert_element_type_1716 = mul_160 = None + sum_74 = torch.ops.aten.sum.dim_IntList(mul_503, [0, 1]); mul_503 = None + convert_element_type_1719 = torch.ops.prims.convert_element_type.default(mul_502, torch.bfloat16); mul_502 = None + convert_element_type_1720 = torch.ops.prims.convert_element_type.default(sum_74, torch.bfloat16); sum_74 = None + all_reduce_24 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1720, 'sum', '1'); convert_element_type_1720 = None + wait_tensor_603 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_24); all_reduce_24 = None + convert_element_type_1721 = torch.ops.prims.convert_element_type.default(wait_tensor_603, torch.float32); wait_tensor_603 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1721, 'avg', 32, '0'); convert_element_type_1721 = None + wait_tensor_604 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + add_212 = torch.ops.aten.add.Tensor(add_209, convert_element_type_1719); add_209 = convert_element_type_1719 = None + all_gather_into_tensor_380 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_212, 8, '1') + wait_tensor_605 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_380); all_gather_into_tensor_380 = None + split_187 = torch.ops.aten.split.Tensor(wait_tensor_605, 2); wait_tensor_605 = None + getitem_1820 = split_187[0] + getitem_1821 = split_187[1] + getitem_1822 = split_187[2] + getitem_1823 = split_187[3] + getitem_1824 = split_187[4] + getitem_1825 = split_187[5] + getitem_1826 = split_187[6] + getitem_1827 = split_187[7]; split_187 = None + cat_179 = torch.ops.aten.cat.default([getitem_1820, getitem_1821, getitem_1822, getitem_1823, getitem_1824, getitem_1825, getitem_1826, getitem_1827], 1); getitem_1820 = getitem_1821 = getitem_1822 = getitem_1823 = getitem_1824 = getitem_1825 = getitem_1826 = getitem_1827 = None + view_2611 = torch.ops.aten.view.default(cat_179, [16384, 4096]); cat_179 = None + permute_741 = torch.ops.aten.permute.default(view_2611, [1, 0]) + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + add_77 = torch.ops.aten.add.Tensor(add_75, wait_tensor_255); wait_tensor_255 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 32, '0'); convert_element_type_647 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32); add_77 = None + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_256) + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_649, 8, '1'); convert_element_type_649 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + split_87 = torch.ops.aten.split.Tensor(wait_tensor_257, 2); wait_tensor_257 = None + getitem_876 = split_87[0] + getitem_877 = split_87[1] + getitem_878 = split_87[2] + getitem_879 = split_87[3] + getitem_880 = split_87[4] + getitem_881 = split_87[5] + getitem_882 = split_87[6] + getitem_883 = split_87[7]; split_87 = None + cat_79 = torch.ops.aten.cat.default([getitem_876, getitem_877, getitem_878, getitem_879, getitem_880, getitem_881, getitem_882, getitem_883], 1); getitem_876 = getitem_877 = getitem_878 = getitem_879 = getitem_880 = getitem_881 = getitem_882 = getitem_883 = None + view_1428 = torch.ops.aten.view.default(cat_79, [16384, 4096]); cat_79 = None + view_1429 = torch.ops.aten.view.default(mm_137, [2, 8192, 1792]); mm_137 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_1429, torch.float32); view_1429 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16); primals_182 = None + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 32, '0'); convert_element_type_655 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + mm_138 = torch.ops.aten.mm.default(view_1428, permute_218) + view_1436 = torch.ops.aten.view.default(mm_138, [2, 8192, 1792]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_1436) + view_1443 = torch.ops.aten.view.default(mul_159, [16384, 1792]); mul_159 = None + mm_395 = torch.ops.aten.mm.default(permute_741, view_1443); permute_741 = view_1443 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16); primals_183 = None + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 32, '0'); convert_element_type_658 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + permute_743 = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None + mm_396 = torch.ops.aten.mm.default(view_2611, permute_743); view_2611 = permute_743 = None + view_2612 = torch.ops.aten.view.default(mm_396, [2, 8192, 1792]); mm_396 = None + convert_element_type_1726 = torch.ops.prims.convert_element_type.default(mm_395, torch.float32); mm_395 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1726, 'avg', 32, '0'); convert_element_type_1726 = None + wait_tensor_606 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + mul_504 = torch.ops.aten.mul.Tensor(view_2612, convert_element_type_654); convert_element_type_654 = None + mul_505 = torch.ops.aten.mul.Tensor(view_2612, view_1436); view_2612 = view_1436 = None + view_2613 = torch.ops.aten.view.default(mul_504, [16384, 1792]); mul_504 = None + permute_745 = torch.ops.aten.permute.default(view_2613, [1, 0]) + mm_397 = torch.ops.aten.mm.default(permute_745, view_1428); permute_745 = None + permute_747 = torch.ops.aten.permute.default(permute_218, [1, 0]); permute_218 = None + mm_398 = torch.ops.aten.mm.default(view_2613, permute_747); view_2613 = permute_747 = None + view_2614 = torch.ops.aten.view.default(mm_398, [2, 8192, 4096]); mm_398 = None + convert_element_type_1731 = torch.ops.prims.convert_element_type.default(mm_397, torch.float32); mm_397 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1731, 'avg', 32, '0'); convert_element_type_1731 = None + wait_tensor_607 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + convert_element_type_1732 = torch.ops.prims.convert_element_type.default(mul_505, torch.float32); mul_505 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_653) + exp_12 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_213 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_213); add_213 = None + mul_506 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_507 = torch.ops.aten.mul.Tensor(convert_element_type_1732, mul_506); convert_element_type_1732 = None + sub_38 = torch.ops.aten.sub.Tensor(1, mul_506); mul_506 = None + mul_508 = torch.ops.aten.mul.Tensor(convert_element_type_653, sub_38); convert_element_type_653 = sub_38 = None + add_214 = torch.ops.aten.add.Tensor(mul_508, 1); mul_508 = None + mul_509 = torch.ops.aten.mul.Tensor(mul_507, add_214); mul_507 = add_214 = None + convert_element_type_1734 = torch.ops.prims.convert_element_type.default(mul_509, torch.bfloat16); mul_509 = None + view_2615 = torch.ops.aten.view.default(convert_element_type_1734, [16384, 1792]); convert_element_type_1734 = None + permute_749 = torch.ops.aten.permute.default(view_2615, [1, 0]) + mm_399 = torch.ops.aten.mm.default(permute_749, view_1428); permute_749 = view_1428 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16); primals_181 = None + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 32, '0'); convert_element_type_650 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_258, [1, 0]); wait_tensor_258 = None + permute_751 = torch.ops.aten.permute.default(permute_217, [1, 0]); permute_217 = None + mm_400 = torch.ops.aten.mm.default(view_2615, permute_751); view_2615 = permute_751 = None + view_2616 = torch.ops.aten.view.default(mm_400, [2, 8192, 4096]); mm_400 = None + add_215 = torch.ops.aten.add.Tensor(view_2614, view_2616); view_2614 = view_2616 = None + convert_element_type_1739 = torch.ops.prims.convert_element_type.default(mm_399, torch.float32); mm_399 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1739, 'avg', 32, '0'); convert_element_type_1739 = None + wait_tensor_608 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + split_188 = torch.ops.aten.split.Tensor(add_215, 1024, 1); add_215 = None + getitem_1828 = split_188[0] + getitem_1829 = split_188[1] + getitem_1830 = split_188[2] + getitem_1831 = split_188[3] + getitem_1832 = split_188[4] + getitem_1833 = split_188[5] + getitem_1834 = split_188[6] + getitem_1835 = split_188[7]; split_188 = None + cat_180 = torch.ops.aten.cat.default([getitem_1828, getitem_1829, getitem_1830, getitem_1831, getitem_1832, getitem_1833, getitem_1834, getitem_1835]); getitem_1828 = getitem_1829 = getitem_1830 = getitem_1831 = getitem_1832 = getitem_1833 = getitem_1834 = getitem_1835 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_180, 'sum', 8, '1'); cat_180 = None + wait_tensor_609 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + convert_element_type_1740 = torch.ops.prims.convert_element_type.default(wait_tensor_609, torch.float32); wait_tensor_609 = None + convert_element_type_1742 = torch.ops.prims.convert_element_type.default(wait_tensor_256, torch.float32); wait_tensor_256 = None + mul_510 = torch.ops.aten.mul.Tensor(convert_element_type_1740, convert_element_type_1742); convert_element_type_1742 = None + mul_512 = torch.ops.aten.mul.Tensor(mul_156, mul_510) + sum_75 = torch.ops.aten.sum.dim_IntList(mul_512, [2], True); mul_512 = None + div_25 = torch.ops.aten.div.Tensor(mul_156, 4096) + mul_513 = torch.ops.aten.mul.Tensor(div_25, sum_75); div_25 = sum_75 = None + sub_39 = torch.ops.aten.sub.Tensor(mul_510, mul_513); mul_510 = mul_513 = None + mul_514 = torch.ops.aten.mul.Tensor(sub_39, rsqrt_39); sub_39 = rsqrt_39 = None + mul_515 = torch.ops.aten.mul.Tensor(convert_element_type_1740, mul_156); convert_element_type_1740 = mul_156 = None + sum_76 = torch.ops.aten.sum.dim_IntList(mul_515, [0, 1]); mul_515 = None + convert_element_type_1743 = torch.ops.prims.convert_element_type.default(mul_514, torch.bfloat16); mul_514 = None + convert_element_type_1744 = torch.ops.prims.convert_element_type.default(sum_76, torch.bfloat16); sum_76 = None + all_reduce_25 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1744, 'sum', '1'); convert_element_type_1744 = None + wait_tensor_610 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_25); all_reduce_25 = None + convert_element_type_1745 = torch.ops.prims.convert_element_type.default(wait_tensor_610, torch.float32); wait_tensor_610 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1745, 'avg', 32, '0'); convert_element_type_1745 = None + wait_tensor_611 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + add_216 = torch.ops.aten.add.Tensor(add_212, convert_element_type_1743); add_212 = convert_element_type_1743 = None + all_gather_into_tensor_381 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_216, 8, '1') + wait_tensor_612 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_381); all_gather_into_tensor_381 = None + split_189 = torch.ops.aten.split.Tensor(wait_tensor_612, 2); wait_tensor_612 = None + getitem_1836 = split_189[0] + getitem_1837 = split_189[1] + getitem_1838 = split_189[2] + getitem_1839 = split_189[3] + getitem_1840 = split_189[4] + getitem_1841 = split_189[5] + getitem_1842 = split_189[6] + getitem_1843 = split_189[7]; split_189 = None + cat_181 = torch.ops.aten.cat.default([getitem_1836, getitem_1837, getitem_1838, getitem_1839, getitem_1840, getitem_1841, getitem_1842, getitem_1843], 1); getitem_1836 = getitem_1837 = getitem_1838 = getitem_1839 = getitem_1840 = getitem_1841 = getitem_1842 = getitem_1843 = None + view_2617 = torch.ops.aten.view.default(cat_181, [16384, 4096]); cat_181 = None + permute_753 = torch.ops.aten.permute.default(view_2617, [1, 0]) + permute_215 = torch.ops.aten.permute.default(getitem_859, [0, 2, 1, 3]) + view_1410 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + view_1416 = torch.ops.aten.view.default(view_1410, [16384, 512]); view_1410 = None + mm_401 = torch.ops.aten.mm.default(permute_753, view_1416); permute_753 = view_1416 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 32, '0'); convert_element_type_644 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + permute_755 = torch.ops.aten.permute.default(permute_216, [1, 0]); permute_216 = None + mm_402 = torch.ops.aten.mm.default(view_2617, permute_755); view_2617 = permute_755 = None + view_2618 = torch.ops.aten.view.default(mm_402, [2, 8192, 512]); mm_402 = None + convert_element_type_1750 = torch.ops.prims.convert_element_type.default(mm_401, torch.float32); mm_401 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1750, 'avg', 32, '0'); convert_element_type_1750 = None + wait_tensor_613 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + view_2619 = torch.ops.aten.view.default(view_2618, [2, 8192, 4, 128]); view_2618 = None + permute_757 = torch.ops.aten.permute.default(view_2619, [0, 2, 1, 3]); view_2619 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 32, '0'); convert_element_type_628 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32); add_75 = None + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_249) + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_630, 8, '1'); convert_element_type_630 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + split_85 = torch.ops.aten.split.Tensor(wait_tensor_250, 2); wait_tensor_250 = None + getitem_851 = split_85[0] + getitem_852 = split_85[1] + getitem_853 = split_85[2] + getitem_854 = split_85[3] + getitem_855 = split_85[4] + getitem_856 = split_85[5] + getitem_857 = split_85[6] + getitem_858 = split_85[7]; split_85 = None + cat_77 = torch.ops.aten.cat.default([getitem_851, getitem_852, getitem_853, getitem_854, getitem_855, getitem_856, getitem_857, getitem_858], 1); getitem_851 = getitem_852 = getitem_853 = getitem_854 = getitem_855 = getitem_856 = getitem_857 = getitem_858 = None + view_1383 = torch.ops.aten.view.default(cat_77, [16384, 4096]); cat_77 = None + view_1384 = torch.ops.aten.view.default(mm_133, [2, 8192, 512]); mm_133 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16); primals_177 = None + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 32, '0'); convert_element_type_634 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + mm_134 = torch.ops.aten.mm.default(view_1383, permute_210) + view_1391 = torch.ops.aten.view.default(mm_134, [2, 8192, 128]); mm_134 = None + view_1398 = torch.ops.aten.view.default(mm_135, [2, 8192, 128]); mm_135 = None + view_1400 = torch.ops.aten.view.default(view_1384, [2, 8192, -1, 128]); view_1384 = None + view_1401 = torch.ops.aten.view.default(view_1391, [2, 8192, -1, 128]); view_1391 = None + view_1402 = torch.ops.aten.view.default(view_1398, [2, 8192, -1, 128]); view_1398 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_1400, torch.float32); view_1400 = None + view_1403 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 4, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1403); view_1403 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_1401, torch.float32); view_1401 = None + view_1404 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 1, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1404); view_1404 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_37); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_1406 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 4, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_37); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_1407 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 1, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_1406, torch.bfloat16); view_1406 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_1407, torch.bfloat16); view_1407 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 1, 4, 128]); unsqueeze_38 = None + view_1408 = torch.ops.aten.view.default(expand_38, [2, 8192, 4, 128]); expand_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_1402, 3); view_1402 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 1, 4, 128]); unsqueeze_39 = None + view_1409 = torch.ops.aten.view.default(expand_39, [2, 8192, 4, 128]); expand_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_1408, [0, 2, 1, 3]); view_1408 = None + permute_214 = torch.ops.aten.permute.default(view_1409, [0, 2, 1, 3]); view_1409 = None + _scaled_dot_product_cudnn_attention_backward_12 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_757, permute_212, permute_213, permute_214, getitem_859, getitem_860, getitem_865, getitem_866, None, None, None, 8192, 8192, 0.0, True); permute_757 = permute_212 = permute_213 = permute_214 = getitem_859 = getitem_860 = getitem_865 = getitem_866 = None + getitem_1844 = _scaled_dot_product_cudnn_attention_backward_12[0] + getitem_1845 = _scaled_dot_product_cudnn_attention_backward_12[1] + getitem_1846 = _scaled_dot_product_cudnn_attention_backward_12[2]; _scaled_dot_product_cudnn_attention_backward_12 = None + permute_758 = torch.ops.aten.permute.default(getitem_1846, [0, 2, 1, 3]); getitem_1846 = None + permute_759 = torch.ops.aten.permute.default(getitem_1845, [0, 2, 1, 3]); getitem_1845 = None + permute_760 = torch.ops.aten.permute.default(getitem_1844, [0, 2, 1, 3]); getitem_1844 = None + view_2620 = torch.ops.aten.view.default(permute_758, [2, 8192, 1, 4, 128]); permute_758 = None + sum_77 = torch.ops.aten.sum.dim_IntList(view_2620, [3], True); view_2620 = None + squeeze_24 = torch.ops.aten.squeeze.dim(sum_77, 3); sum_77 = None + view_2621 = torch.ops.aten.view.default(permute_759, [2, 8192, 1, 4, 128]); permute_759 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_2621, [3], True); view_2621 = None + squeeze_25 = torch.ops.aten.squeeze.dim(sum_78, 3); sum_78 = None + convert_element_type_1751 = torch.ops.prims.convert_element_type.default(squeeze_25, torch.float32); squeeze_25 = None + convert_element_type_1752 = torch.ops.prims.convert_element_type.default(permute_760, torch.float32); permute_760 = None + view_2622 = torch.ops.aten.view.default(convert_element_type_1751, [2, 8192, 1, 64, 2]); convert_element_type_1751 = None + view_as_complex_88 = torch.ops.aten.view_as_complex.default(view_2622); view_2622 = None + mul_516 = torch.ops.aten.mul.Tensor(view_as_complex_88, _conj); view_as_complex_88 = None + view_2623 = torch.ops.aten.view.default(convert_element_type_1752, [2, 8192, 4, 64, 2]); convert_element_type_1752 = None + view_as_complex_89 = torch.ops.aten.view_as_complex.default(view_2623); view_2623 = None + mul_517 = torch.ops.aten.mul.Tensor(view_as_complex_89, _conj); view_as_complex_89 = None + view_as_real_88 = torch.ops.aten.view_as_real.default(mul_516); mul_516 = None + view_2624 = torch.ops.aten.view.default(view_as_real_88, [2, 8192, 1, 128]); view_as_real_88 = None + convert_element_type_1753 = torch.ops.prims.convert_element_type.default(view_2624, torch.bfloat16); view_2624 = None + view_as_real_89 = torch.ops.aten.view_as_real.default(mul_517); mul_517 = None + view_2625 = torch.ops.aten.view.default(view_as_real_89, [2, 8192, 4, 128]); view_as_real_89 = None + convert_element_type_1754 = torch.ops.prims.convert_element_type.default(view_2625, torch.bfloat16); view_2625 = None + view_2626 = torch.ops.aten.view.default(squeeze_24, [2, 8192, 128]); squeeze_24 = None + view_2627 = torch.ops.aten.view.default(convert_element_type_1753, [2, 8192, 128]); convert_element_type_1753 = None + view_2628 = torch.ops.aten.view.default(convert_element_type_1754, [2, 8192, 512]); convert_element_type_1754 = None + view_2629 = torch.ops.aten.view.default(view_2626, [16384, 128]); view_2626 = None + permute_761 = torch.ops.aten.permute.default(view_2629, [1, 0]) + mm_403 = torch.ops.aten.mm.default(permute_761, view_1383); permute_761 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16); primals_178 = None + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 32, '0'); convert_element_type_637 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_253, [1, 0]); wait_tensor_253 = None + permute_763 = torch.ops.aten.permute.default(permute_211, [1, 0]); permute_211 = None + mm_404 = torch.ops.aten.mm.default(view_2629, permute_763); view_2629 = permute_763 = None + view_2630 = torch.ops.aten.view.default(mm_404, [2, 8192, 4096]); mm_404 = None + convert_element_type_1759 = torch.ops.prims.convert_element_type.default(mm_403, torch.float32); mm_403 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1759, 'avg', 32, '0'); convert_element_type_1759 = None + wait_tensor_614 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_2631 = torch.ops.aten.view.default(view_2627, [16384, 128]); view_2627 = None + permute_765 = torch.ops.aten.permute.default(view_2631, [1, 0]) + mm_405 = torch.ops.aten.mm.default(permute_765, view_1383); permute_765 = None + permute_767 = torch.ops.aten.permute.default(permute_210, [1, 0]); permute_210 = None + mm_406 = torch.ops.aten.mm.default(view_2631, permute_767); view_2631 = permute_767 = None + view_2632 = torch.ops.aten.view.default(mm_406, [2, 8192, 4096]); mm_406 = None + add_217 = torch.ops.aten.add.Tensor(view_2630, view_2632); view_2630 = view_2632 = None + convert_element_type_1764 = torch.ops.prims.convert_element_type.default(mm_405, torch.float32); mm_405 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1764, 'avg', 32, '0'); convert_element_type_1764 = None + wait_tensor_615 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + view_2633 = torch.ops.aten.view.default(view_2628, [16384, 512]); view_2628 = None + permute_769 = torch.ops.aten.permute.default(view_2633, [1, 0]) + mm_407 = torch.ops.aten.mm.default(permute_769, view_1383); permute_769 = view_1383 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16); primals_176 = None + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 32, '0'); convert_element_type_631 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + permute_771 = torch.ops.aten.permute.default(permute_209, [1, 0]); permute_209 = None + mm_408 = torch.ops.aten.mm.default(view_2633, permute_771); view_2633 = permute_771 = None + view_2634 = torch.ops.aten.view.default(mm_408, [2, 8192, 4096]); mm_408 = None + add_218 = torch.ops.aten.add.Tensor(add_217, view_2634); add_217 = view_2634 = None + convert_element_type_1769 = torch.ops.prims.convert_element_type.default(mm_407, torch.float32); mm_407 = None + reduce_scatter_tensor_208 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1769, 'avg', 32, '0'); convert_element_type_1769 = None + wait_tensor_616 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + split_190 = torch.ops.aten.split.Tensor(add_218, 1024, 1); add_218 = None + getitem_1847 = split_190[0] + getitem_1848 = split_190[1] + getitem_1849 = split_190[2] + getitem_1850 = split_190[3] + getitem_1851 = split_190[4] + getitem_1852 = split_190[5] + getitem_1853 = split_190[6] + getitem_1854 = split_190[7]; split_190 = None + cat_182 = torch.ops.aten.cat.default([getitem_1847, getitem_1848, getitem_1849, getitem_1850, getitem_1851, getitem_1852, getitem_1853, getitem_1854]); getitem_1847 = getitem_1848 = getitem_1849 = getitem_1850 = getitem_1851 = getitem_1852 = getitem_1853 = getitem_1854 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_182, 'sum', 8, '1'); cat_182 = None + wait_tensor_617 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + convert_element_type_1770 = torch.ops.prims.convert_element_type.default(wait_tensor_617, torch.float32); wait_tensor_617 = None + convert_element_type_1772 = torch.ops.prims.convert_element_type.default(wait_tensor_249, torch.float32); wait_tensor_249 = None + mul_518 = torch.ops.aten.mul.Tensor(convert_element_type_1770, convert_element_type_1772); convert_element_type_1772 = None + mul_520 = torch.ops.aten.mul.Tensor(mul_152, mul_518) + sum_79 = torch.ops.aten.sum.dim_IntList(mul_520, [2], True); mul_520 = None + div_26 = torch.ops.aten.div.Tensor(mul_152, 4096) + mul_521 = torch.ops.aten.mul.Tensor(div_26, sum_79); div_26 = sum_79 = None + sub_40 = torch.ops.aten.sub.Tensor(mul_518, mul_521); mul_518 = mul_521 = None + mul_522 = torch.ops.aten.mul.Tensor(sub_40, rsqrt_38); sub_40 = rsqrt_38 = None + mul_523 = torch.ops.aten.mul.Tensor(convert_element_type_1770, mul_152); convert_element_type_1770 = mul_152 = None + sum_80 = torch.ops.aten.sum.dim_IntList(mul_523, [0, 1]); mul_523 = None + convert_element_type_1773 = torch.ops.prims.convert_element_type.default(mul_522, torch.bfloat16); mul_522 = None + convert_element_type_1774 = torch.ops.prims.convert_element_type.default(sum_80, torch.bfloat16); sum_80 = None + all_reduce_26 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1774, 'sum', '1'); convert_element_type_1774 = None + wait_tensor_618 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_26); all_reduce_26 = None + convert_element_type_1775 = torch.ops.prims.convert_element_type.default(wait_tensor_618, torch.float32); wait_tensor_618 = None + reduce_scatter_tensor_210 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1775, 'avg', 32, '0'); convert_element_type_1775 = None + wait_tensor_619 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); reduce_scatter_tensor_210 = None + add_219 = torch.ops.aten.add.Tensor(add_216, convert_element_type_1773); add_216 = convert_element_type_1773 = None + all_gather_into_tensor_382 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_219, 8, '1') + wait_tensor_620 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_382); all_gather_into_tensor_382 = None + split_191 = torch.ops.aten.split.Tensor(wait_tensor_620, 2); wait_tensor_620 = None + getitem_1855 = split_191[0] + getitem_1856 = split_191[1] + getitem_1857 = split_191[2] + getitem_1858 = split_191[3] + getitem_1859 = split_191[4] + getitem_1860 = split_191[5] + getitem_1861 = split_191[6] + getitem_1862 = split_191[7]; split_191 = None + cat_183 = torch.ops.aten.cat.default([getitem_1855, getitem_1856, getitem_1857, getitem_1858, getitem_1859, getitem_1860, getitem_1861, getitem_1862], 1); getitem_1855 = getitem_1856 = getitem_1857 = getitem_1858 = getitem_1859 = getitem_1860 = getitem_1861 = getitem_1862 = None + view_2635 = torch.ops.aten.view.default(cat_183, [16384, 4096]); cat_183 = None + permute_773 = torch.ops.aten.permute.default(view_2635, [1, 0]) + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + add_73 = torch.ops.aten.add.Tensor(add_71, wait_tensor_242); wait_tensor_242 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16); primals_171 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 32, '0'); convert_element_type_614 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32); add_73 = None + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_243) + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_616, 8, '1'); convert_element_type_616 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + split_83 = torch.ops.aten.split.Tensor(wait_tensor_244, 2); wait_tensor_244 = None + getitem_835 = split_83[0] + getitem_836 = split_83[1] + getitem_837 = split_83[2] + getitem_838 = split_83[3] + getitem_839 = split_83[4] + getitem_840 = split_83[5] + getitem_841 = split_83[6] + getitem_842 = split_83[7]; split_83 = None + cat_75 = torch.ops.aten.cat.default([getitem_835, getitem_836, getitem_837, getitem_838, getitem_839, getitem_840, getitem_841, getitem_842], 1); getitem_835 = getitem_836 = getitem_837 = getitem_838 = getitem_839 = getitem_840 = getitem_841 = getitem_842 = None + view_1356 = torch.ops.aten.view.default(cat_75, [16384, 4096]); cat_75 = None + view_1357 = torch.ops.aten.view.default(mm_130, [2, 8192, 1792]); mm_130 = None + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_1357, torch.float32); view_1357 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 32, '0'); convert_element_type_622 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_131 = torch.ops.aten.mm.default(view_1356, permute_207) + view_1364 = torch.ops.aten.view.default(mm_131, [2, 8192, 1792]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_1364) + view_1371 = torch.ops.aten.view.default(mul_151, [16384, 1792]); mul_151 = None + mm_409 = torch.ops.aten.mm.default(permute_773, view_1371); permute_773 = view_1371 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16); primals_174 = None + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 32, '0'); convert_element_type_625 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + permute_775 = torch.ops.aten.permute.default(permute_208, [1, 0]); permute_208 = None + mm_410 = torch.ops.aten.mm.default(view_2635, permute_775); view_2635 = permute_775 = None + view_2636 = torch.ops.aten.view.default(mm_410, [2, 8192, 1792]); mm_410 = None + convert_element_type_1780 = torch.ops.prims.convert_element_type.default(mm_409, torch.float32); mm_409 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1780, 'avg', 32, '0'); convert_element_type_1780 = None + wait_tensor_621 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + mul_524 = torch.ops.aten.mul.Tensor(view_2636, convert_element_type_621); convert_element_type_621 = None + mul_525 = torch.ops.aten.mul.Tensor(view_2636, view_1364); view_2636 = view_1364 = None + view_2637 = torch.ops.aten.view.default(mul_524, [16384, 1792]); mul_524 = None + permute_777 = torch.ops.aten.permute.default(view_2637, [1, 0]) + mm_411 = torch.ops.aten.mm.default(permute_777, view_1356); permute_777 = None + permute_779 = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None + mm_412 = torch.ops.aten.mm.default(view_2637, permute_779); view_2637 = permute_779 = None + view_2638 = torch.ops.aten.view.default(mm_412, [2, 8192, 4096]); mm_412 = None + convert_element_type_1785 = torch.ops.prims.convert_element_type.default(mm_411, torch.float32); mm_411 = None + reduce_scatter_tensor_212 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1785, 'avg', 32, '0'); convert_element_type_1785 = None + wait_tensor_622 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + convert_element_type_1786 = torch.ops.prims.convert_element_type.default(mul_525, torch.float32); mul_525 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_620) + exp_13 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_220 = torch.ops.aten.add.Tensor(exp_13, 1); exp_13 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_220); add_220 = None + mul_526 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_527 = torch.ops.aten.mul.Tensor(convert_element_type_1786, mul_526); convert_element_type_1786 = None + sub_41 = torch.ops.aten.sub.Tensor(1, mul_526); mul_526 = None + mul_528 = torch.ops.aten.mul.Tensor(convert_element_type_620, sub_41); convert_element_type_620 = sub_41 = None + add_221 = torch.ops.aten.add.Tensor(mul_528, 1); mul_528 = None + mul_529 = torch.ops.aten.mul.Tensor(mul_527, add_221); mul_527 = add_221 = None + convert_element_type_1788 = torch.ops.prims.convert_element_type.default(mul_529, torch.bfloat16); mul_529 = None + view_2639 = torch.ops.aten.view.default(convert_element_type_1788, [16384, 1792]); convert_element_type_1788 = None + permute_781 = torch.ops.aten.permute.default(view_2639, [1, 0]) + mm_413 = torch.ops.aten.mm.default(permute_781, view_1356); permute_781 = view_1356 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 32, '0'); convert_element_type_617 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + permute_783 = torch.ops.aten.permute.default(permute_206, [1, 0]); permute_206 = None + mm_414 = torch.ops.aten.mm.default(view_2639, permute_783); view_2639 = permute_783 = None + view_2640 = torch.ops.aten.view.default(mm_414, [2, 8192, 4096]); mm_414 = None + add_222 = torch.ops.aten.add.Tensor(view_2638, view_2640); view_2638 = view_2640 = None + convert_element_type_1793 = torch.ops.prims.convert_element_type.default(mm_413, torch.float32); mm_413 = None + reduce_scatter_tensor_213 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1793, 'avg', 32, '0'); convert_element_type_1793 = None + wait_tensor_623 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_213); reduce_scatter_tensor_213 = None + split_192 = torch.ops.aten.split.Tensor(add_222, 1024, 1); add_222 = None + getitem_1863 = split_192[0] + getitem_1864 = split_192[1] + getitem_1865 = split_192[2] + getitem_1866 = split_192[3] + getitem_1867 = split_192[4] + getitem_1868 = split_192[5] + getitem_1869 = split_192[6] + getitem_1870 = split_192[7]; split_192 = None + cat_184 = torch.ops.aten.cat.default([getitem_1863, getitem_1864, getitem_1865, getitem_1866, getitem_1867, getitem_1868, getitem_1869, getitem_1870]); getitem_1863 = getitem_1864 = getitem_1865 = getitem_1866 = getitem_1867 = getitem_1868 = getitem_1869 = getitem_1870 = None + reduce_scatter_tensor_214 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_184, 'sum', 8, '1'); cat_184 = None + wait_tensor_624 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_214); reduce_scatter_tensor_214 = None + convert_element_type_1794 = torch.ops.prims.convert_element_type.default(wait_tensor_624, torch.float32); wait_tensor_624 = None + convert_element_type_1796 = torch.ops.prims.convert_element_type.default(wait_tensor_243, torch.float32); wait_tensor_243 = None + mul_530 = torch.ops.aten.mul.Tensor(convert_element_type_1794, convert_element_type_1796); convert_element_type_1796 = None + mul_532 = torch.ops.aten.mul.Tensor(mul_148, mul_530) + sum_81 = torch.ops.aten.sum.dim_IntList(mul_532, [2], True); mul_532 = None + div_27 = torch.ops.aten.div.Tensor(mul_148, 4096) + mul_533 = torch.ops.aten.mul.Tensor(div_27, sum_81); div_27 = sum_81 = None + sub_42 = torch.ops.aten.sub.Tensor(mul_530, mul_533); mul_530 = mul_533 = None + mul_534 = torch.ops.aten.mul.Tensor(sub_42, rsqrt_37); sub_42 = rsqrt_37 = None + mul_535 = torch.ops.aten.mul.Tensor(convert_element_type_1794, mul_148); convert_element_type_1794 = mul_148 = None + sum_82 = torch.ops.aten.sum.dim_IntList(mul_535, [0, 1]); mul_535 = None + convert_element_type_1797 = torch.ops.prims.convert_element_type.default(mul_534, torch.bfloat16); mul_534 = None + convert_element_type_1798 = torch.ops.prims.convert_element_type.default(sum_82, torch.bfloat16); sum_82 = None + all_reduce_27 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1798, 'sum', '1'); convert_element_type_1798 = None + wait_tensor_625 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_27); all_reduce_27 = None + convert_element_type_1799 = torch.ops.prims.convert_element_type.default(wait_tensor_625, torch.float32); wait_tensor_625 = None + reduce_scatter_tensor_215 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1799, 'avg', 32, '0'); convert_element_type_1799 = None + wait_tensor_626 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_215); reduce_scatter_tensor_215 = None + add_223 = torch.ops.aten.add.Tensor(add_219, convert_element_type_1797); add_219 = convert_element_type_1797 = None + all_gather_into_tensor_383 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_223, 8, '1') + wait_tensor_627 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_383); all_gather_into_tensor_383 = None + split_193 = torch.ops.aten.split.Tensor(wait_tensor_627, 2); wait_tensor_627 = None + getitem_1871 = split_193[0] + getitem_1872 = split_193[1] + getitem_1873 = split_193[2] + getitem_1874 = split_193[3] + getitem_1875 = split_193[4] + getitem_1876 = split_193[5] + getitem_1877 = split_193[6] + getitem_1878 = split_193[7]; split_193 = None + cat_185 = torch.ops.aten.cat.default([getitem_1871, getitem_1872, getitem_1873, getitem_1874, getitem_1875, getitem_1876, getitem_1877, getitem_1878], 1); getitem_1871 = getitem_1872 = getitem_1873 = getitem_1874 = getitem_1875 = getitem_1876 = getitem_1877 = getitem_1878 = None + view_2641 = torch.ops.aten.view.default(cat_185, [16384, 4096]); cat_185 = None + permute_785 = torch.ops.aten.permute.default(view_2641, [1, 0]) + permute_204 = torch.ops.aten.permute.default(getitem_818, [0, 2, 1, 3]) + view_1338 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + view_1344 = torch.ops.aten.view.default(view_1338, [16384, 512]); view_1338 = None + mm_415 = torch.ops.aten.mm.default(permute_785, view_1344); permute_785 = view_1344 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16); primals_170 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 32, '0'); convert_element_type_611 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + permute_787 = torch.ops.aten.permute.default(permute_205, [1, 0]); permute_205 = None + mm_416 = torch.ops.aten.mm.default(view_2641, permute_787); view_2641 = permute_787 = None + view_2642 = torch.ops.aten.view.default(mm_416, [2, 8192, 512]); mm_416 = None + convert_element_type_1804 = torch.ops.prims.convert_element_type.default(mm_415, torch.float32); mm_415 = None + reduce_scatter_tensor_216 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1804, 'avg', 32, '0'); convert_element_type_1804 = None + wait_tensor_628 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_216); reduce_scatter_tensor_216 = None + view_2643 = torch.ops.aten.view.default(view_2642, [2, 8192, 4, 128]); view_2642 = None + permute_789 = torch.ops.aten.permute.default(view_2643, [0, 2, 1, 3]); view_2643 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16); primals_166 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 32, '0'); convert_element_type_595 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32); add_71 = None + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_236) + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_597, 8, '1'); convert_element_type_597 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + split_81 = torch.ops.aten.split.Tensor(wait_tensor_237, 2); wait_tensor_237 = None + getitem_810 = split_81[0] + getitem_811 = split_81[1] + getitem_812 = split_81[2] + getitem_813 = split_81[3] + getitem_814 = split_81[4] + getitem_815 = split_81[5] + getitem_816 = split_81[6] + getitem_817 = split_81[7]; split_81 = None + cat_73 = torch.ops.aten.cat.default([getitem_810, getitem_811, getitem_812, getitem_813, getitem_814, getitem_815, getitem_816, getitem_817], 1); getitem_810 = getitem_811 = getitem_812 = getitem_813 = getitem_814 = getitem_815 = getitem_816 = getitem_817 = None + view_1311 = torch.ops.aten.view.default(cat_73, [16384, 4096]); cat_73 = None + view_1312 = torch.ops.aten.view.default(mm_126, [2, 8192, 512]); mm_126 = None + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16); primals_168 = None + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 32, '0'); convert_element_type_601 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + mm_127 = torch.ops.aten.mm.default(view_1311, permute_199) + view_1319 = torch.ops.aten.view.default(mm_127, [2, 8192, 128]); mm_127 = None + view_1326 = torch.ops.aten.view.default(mm_128, [2, 8192, 128]); mm_128 = None + view_1328 = torch.ops.aten.view.default(view_1312, [2, 8192, -1, 128]); view_1312 = None + view_1329 = torch.ops.aten.view.default(view_1319, [2, 8192, -1, 128]); view_1319 = None + view_1330 = torch.ops.aten.view.default(view_1326, [2, 8192, -1, 128]); view_1326 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_1328, torch.float32); view_1328 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 4, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1331); view_1331 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_1329, torch.float32); view_1329 = None + view_1332 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 1, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1332); view_1332 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_37); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_1334 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 4, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_37); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_1335 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 1, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_1334, torch.bfloat16); view_1334 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_1335, torch.bfloat16); view_1335 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 1, 4, 128]); unsqueeze_36 = None + view_1336 = torch.ops.aten.view.default(expand_36, [2, 8192, 4, 128]); expand_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_1330, 3); view_1330 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 1, 4, 128]); unsqueeze_37 = None + view_1337 = torch.ops.aten.view.default(expand_37, [2, 8192, 4, 128]); expand_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_1336, [0, 2, 1, 3]); view_1336 = None + permute_203 = torch.ops.aten.permute.default(view_1337, [0, 2, 1, 3]); view_1337 = None + _scaled_dot_product_cudnn_attention_backward_13 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_789, permute_201, permute_202, permute_203, getitem_818, getitem_819, getitem_824, getitem_825, None, None, None, 8192, 8192, 0.0, True); permute_789 = permute_201 = permute_202 = permute_203 = getitem_818 = getitem_819 = getitem_824 = getitem_825 = None + getitem_1879 = _scaled_dot_product_cudnn_attention_backward_13[0] + getitem_1880 = _scaled_dot_product_cudnn_attention_backward_13[1] + getitem_1881 = _scaled_dot_product_cudnn_attention_backward_13[2]; _scaled_dot_product_cudnn_attention_backward_13 = None + permute_790 = torch.ops.aten.permute.default(getitem_1881, [0, 2, 1, 3]); getitem_1881 = None + permute_791 = torch.ops.aten.permute.default(getitem_1880, [0, 2, 1, 3]); getitem_1880 = None + permute_792 = torch.ops.aten.permute.default(getitem_1879, [0, 2, 1, 3]); getitem_1879 = None + view_2644 = torch.ops.aten.view.default(permute_790, [2, 8192, 1, 4, 128]); permute_790 = None + sum_83 = torch.ops.aten.sum.dim_IntList(view_2644, [3], True); view_2644 = None + squeeze_26 = torch.ops.aten.squeeze.dim(sum_83, 3); sum_83 = None + view_2645 = torch.ops.aten.view.default(permute_791, [2, 8192, 1, 4, 128]); permute_791 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_2645, [3], True); view_2645 = None + squeeze_27 = torch.ops.aten.squeeze.dim(sum_84, 3); sum_84 = None + convert_element_type_1805 = torch.ops.prims.convert_element_type.default(squeeze_27, torch.float32); squeeze_27 = None + convert_element_type_1806 = torch.ops.prims.convert_element_type.default(permute_792, torch.float32); permute_792 = None + view_2646 = torch.ops.aten.view.default(convert_element_type_1805, [2, 8192, 1, 64, 2]); convert_element_type_1805 = None + view_as_complex_90 = torch.ops.aten.view_as_complex.default(view_2646); view_2646 = None + mul_536 = torch.ops.aten.mul.Tensor(view_as_complex_90, _conj); view_as_complex_90 = None + view_2647 = torch.ops.aten.view.default(convert_element_type_1806, [2, 8192, 4, 64, 2]); convert_element_type_1806 = None + view_as_complex_91 = torch.ops.aten.view_as_complex.default(view_2647); view_2647 = None + mul_537 = torch.ops.aten.mul.Tensor(view_as_complex_91, _conj); view_as_complex_91 = None + view_as_real_90 = torch.ops.aten.view_as_real.default(mul_536); mul_536 = None + view_2648 = torch.ops.aten.view.default(view_as_real_90, [2, 8192, 1, 128]); view_as_real_90 = None + convert_element_type_1807 = torch.ops.prims.convert_element_type.default(view_2648, torch.bfloat16); view_2648 = None + view_as_real_91 = torch.ops.aten.view_as_real.default(mul_537); mul_537 = None + view_2649 = torch.ops.aten.view.default(view_as_real_91, [2, 8192, 4, 128]); view_as_real_91 = None + convert_element_type_1808 = torch.ops.prims.convert_element_type.default(view_2649, torch.bfloat16); view_2649 = None + view_2650 = torch.ops.aten.view.default(squeeze_26, [2, 8192, 128]); squeeze_26 = None + view_2651 = torch.ops.aten.view.default(convert_element_type_1807, [2, 8192, 128]); convert_element_type_1807 = None + view_2652 = torch.ops.aten.view.default(convert_element_type_1808, [2, 8192, 512]); convert_element_type_1808 = None + view_2653 = torch.ops.aten.view.default(view_2650, [16384, 128]); view_2650 = None + permute_793 = torch.ops.aten.permute.default(view_2653, [1, 0]) + mm_417 = torch.ops.aten.mm.default(permute_793, view_1311); permute_793 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16); primals_169 = None + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 32, '0'); convert_element_type_604 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + permute_795 = torch.ops.aten.permute.default(permute_200, [1, 0]); permute_200 = None + mm_418 = torch.ops.aten.mm.default(view_2653, permute_795); view_2653 = permute_795 = None + view_2654 = torch.ops.aten.view.default(mm_418, [2, 8192, 4096]); mm_418 = None + convert_element_type_1813 = torch.ops.prims.convert_element_type.default(mm_417, torch.float32); mm_417 = None + reduce_scatter_tensor_217 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1813, 'avg', 32, '0'); convert_element_type_1813 = None + wait_tensor_629 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_217); reduce_scatter_tensor_217 = None + view_2655 = torch.ops.aten.view.default(view_2651, [16384, 128]); view_2651 = None + permute_797 = torch.ops.aten.permute.default(view_2655, [1, 0]) + mm_419 = torch.ops.aten.mm.default(permute_797, view_1311); permute_797 = None + permute_799 = torch.ops.aten.permute.default(permute_199, [1, 0]); permute_199 = None + mm_420 = torch.ops.aten.mm.default(view_2655, permute_799); view_2655 = permute_799 = None + view_2656 = torch.ops.aten.view.default(mm_420, [2, 8192, 4096]); mm_420 = None + add_224 = torch.ops.aten.add.Tensor(view_2654, view_2656); view_2654 = view_2656 = None + convert_element_type_1818 = torch.ops.prims.convert_element_type.default(mm_419, torch.float32); mm_419 = None + reduce_scatter_tensor_218 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1818, 'avg', 32, '0'); convert_element_type_1818 = None + wait_tensor_630 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_218); reduce_scatter_tensor_218 = None + view_2657 = torch.ops.aten.view.default(view_2652, [16384, 512]); view_2652 = None + permute_801 = torch.ops.aten.permute.default(view_2657, [1, 0]) + mm_421 = torch.ops.aten.mm.default(permute_801, view_1311); permute_801 = view_1311 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16); primals_167 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 32, '0'); convert_element_type_598 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + permute_803 = torch.ops.aten.permute.default(permute_198, [1, 0]); permute_198 = None + mm_422 = torch.ops.aten.mm.default(view_2657, permute_803); view_2657 = permute_803 = None + view_2658 = torch.ops.aten.view.default(mm_422, [2, 8192, 4096]); mm_422 = None + add_225 = torch.ops.aten.add.Tensor(add_224, view_2658); add_224 = view_2658 = None + convert_element_type_1823 = torch.ops.prims.convert_element_type.default(mm_421, torch.float32); mm_421 = None + reduce_scatter_tensor_219 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1823, 'avg', 32, '0'); convert_element_type_1823 = None + wait_tensor_631 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_219); reduce_scatter_tensor_219 = None + split_194 = torch.ops.aten.split.Tensor(add_225, 1024, 1); add_225 = None + getitem_1882 = split_194[0] + getitem_1883 = split_194[1] + getitem_1884 = split_194[2] + getitem_1885 = split_194[3] + getitem_1886 = split_194[4] + getitem_1887 = split_194[5] + getitem_1888 = split_194[6] + getitem_1889 = split_194[7]; split_194 = None + cat_186 = torch.ops.aten.cat.default([getitem_1882, getitem_1883, getitem_1884, getitem_1885, getitem_1886, getitem_1887, getitem_1888, getitem_1889]); getitem_1882 = getitem_1883 = getitem_1884 = getitem_1885 = getitem_1886 = getitem_1887 = getitem_1888 = getitem_1889 = None + reduce_scatter_tensor_220 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_186, 'sum', 8, '1'); cat_186 = None + wait_tensor_632 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_220); reduce_scatter_tensor_220 = None + convert_element_type_1824 = torch.ops.prims.convert_element_type.default(wait_tensor_632, torch.float32); wait_tensor_632 = None + convert_element_type_1826 = torch.ops.prims.convert_element_type.default(wait_tensor_236, torch.float32); wait_tensor_236 = None + mul_538 = torch.ops.aten.mul.Tensor(convert_element_type_1824, convert_element_type_1826); convert_element_type_1826 = None + mul_540 = torch.ops.aten.mul.Tensor(mul_144, mul_538) + sum_85 = torch.ops.aten.sum.dim_IntList(mul_540, [2], True); mul_540 = None + div_28 = torch.ops.aten.div.Tensor(mul_144, 4096) + mul_541 = torch.ops.aten.mul.Tensor(div_28, sum_85); div_28 = sum_85 = None + sub_43 = torch.ops.aten.sub.Tensor(mul_538, mul_541); mul_538 = mul_541 = None + mul_542 = torch.ops.aten.mul.Tensor(sub_43, rsqrt_36); sub_43 = rsqrt_36 = None + mul_543 = torch.ops.aten.mul.Tensor(convert_element_type_1824, mul_144); convert_element_type_1824 = mul_144 = None + sum_86 = torch.ops.aten.sum.dim_IntList(mul_543, [0, 1]); mul_543 = None + convert_element_type_1827 = torch.ops.prims.convert_element_type.default(mul_542, torch.bfloat16); mul_542 = None + convert_element_type_1828 = torch.ops.prims.convert_element_type.default(sum_86, torch.bfloat16); sum_86 = None + all_reduce_28 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1828, 'sum', '1'); convert_element_type_1828 = None + wait_tensor_633 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_28); all_reduce_28 = None + convert_element_type_1829 = torch.ops.prims.convert_element_type.default(wait_tensor_633, torch.float32); wait_tensor_633 = None + reduce_scatter_tensor_221 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1829, 'avg', 32, '0'); convert_element_type_1829 = None + wait_tensor_634 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_221); reduce_scatter_tensor_221 = None + add_226 = torch.ops.aten.add.Tensor(add_223, convert_element_type_1827); add_223 = convert_element_type_1827 = None + all_gather_into_tensor_384 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_226, 8, '1') + wait_tensor_635 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_384); all_gather_into_tensor_384 = None + split_195 = torch.ops.aten.split.Tensor(wait_tensor_635, 2); wait_tensor_635 = None + getitem_1890 = split_195[0] + getitem_1891 = split_195[1] + getitem_1892 = split_195[2] + getitem_1893 = split_195[3] + getitem_1894 = split_195[4] + getitem_1895 = split_195[5] + getitem_1896 = split_195[6] + getitem_1897 = split_195[7]; split_195 = None + cat_187 = torch.ops.aten.cat.default([getitem_1890, getitem_1891, getitem_1892, getitem_1893, getitem_1894, getitem_1895, getitem_1896, getitem_1897], 1); getitem_1890 = getitem_1891 = getitem_1892 = getitem_1893 = getitem_1894 = getitem_1895 = getitem_1896 = getitem_1897 = None + view_2659 = torch.ops.aten.view.default(cat_187, [16384, 4096]); cat_187 = None + permute_805 = torch.ops.aten.permute.default(view_2659, [1, 0]) + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + add_69 = torch.ops.aten.add.Tensor(add_67, wait_tensor_229); wait_tensor_229 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16); primals_162 = None + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 32, '0'); convert_element_type_581 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32); add_69 = None + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_230) + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_583, 8, '1'); convert_element_type_583 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + split_79 = torch.ops.aten.split.Tensor(wait_tensor_231, 2); wait_tensor_231 = None + getitem_794 = split_79[0] + getitem_795 = split_79[1] + getitem_796 = split_79[2] + getitem_797 = split_79[3] + getitem_798 = split_79[4] + getitem_799 = split_79[5] + getitem_800 = split_79[6] + getitem_801 = split_79[7]; split_79 = None + cat_71 = torch.ops.aten.cat.default([getitem_794, getitem_795, getitem_796, getitem_797, getitem_798, getitem_799, getitem_800, getitem_801], 1); getitem_794 = getitem_795 = getitem_796 = getitem_797 = getitem_798 = getitem_799 = getitem_800 = getitem_801 = None + view_1284 = torch.ops.aten.view.default(cat_71, [16384, 4096]); cat_71 = None + view_1285 = torch.ops.aten.view.default(mm_123, [2, 8192, 1792]); mm_123 = None + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_1285, torch.float32); view_1285 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 32, '0'); convert_element_type_589 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_124 = torch.ops.aten.mm.default(view_1284, permute_196) + view_1292 = torch.ops.aten.view.default(mm_124, [2, 8192, 1792]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_1292) + view_1299 = torch.ops.aten.view.default(mul_143, [16384, 1792]); mul_143 = None + mm_423 = torch.ops.aten.mm.default(permute_805, view_1299); permute_805 = view_1299 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16); primals_165 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 32, '0'); convert_element_type_592 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + permute_807 = torch.ops.aten.permute.default(permute_197, [1, 0]); permute_197 = None + mm_424 = torch.ops.aten.mm.default(view_2659, permute_807); view_2659 = permute_807 = None + view_2660 = torch.ops.aten.view.default(mm_424, [2, 8192, 1792]); mm_424 = None + convert_element_type_1834 = torch.ops.prims.convert_element_type.default(mm_423, torch.float32); mm_423 = None + reduce_scatter_tensor_222 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1834, 'avg', 32, '0'); convert_element_type_1834 = None + wait_tensor_636 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_222); reduce_scatter_tensor_222 = None + mul_544 = torch.ops.aten.mul.Tensor(view_2660, convert_element_type_588); convert_element_type_588 = None + mul_545 = torch.ops.aten.mul.Tensor(view_2660, view_1292); view_2660 = view_1292 = None + view_2661 = torch.ops.aten.view.default(mul_544, [16384, 1792]); mul_544 = None + permute_809 = torch.ops.aten.permute.default(view_2661, [1, 0]) + mm_425 = torch.ops.aten.mm.default(permute_809, view_1284); permute_809 = None + permute_811 = torch.ops.aten.permute.default(permute_196, [1, 0]); permute_196 = None + mm_426 = torch.ops.aten.mm.default(view_2661, permute_811); view_2661 = permute_811 = None + view_2662 = torch.ops.aten.view.default(mm_426, [2, 8192, 4096]); mm_426 = None + convert_element_type_1839 = torch.ops.prims.convert_element_type.default(mm_425, torch.float32); mm_425 = None + reduce_scatter_tensor_223 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1839, 'avg', 32, '0'); convert_element_type_1839 = None + wait_tensor_637 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_223); reduce_scatter_tensor_223 = None + convert_element_type_1840 = torch.ops.prims.convert_element_type.default(mul_545, torch.float32); mul_545 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_587) + exp_14 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_227 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_227); add_227 = None + mul_546 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_547 = torch.ops.aten.mul.Tensor(convert_element_type_1840, mul_546); convert_element_type_1840 = None + sub_44 = torch.ops.aten.sub.Tensor(1, mul_546); mul_546 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_587, sub_44); convert_element_type_587 = sub_44 = None + add_228 = torch.ops.aten.add.Tensor(mul_548, 1); mul_548 = None + mul_549 = torch.ops.aten.mul.Tensor(mul_547, add_228); mul_547 = add_228 = None + convert_element_type_1842 = torch.ops.prims.convert_element_type.default(mul_549, torch.bfloat16); mul_549 = None + view_2663 = torch.ops.aten.view.default(convert_element_type_1842, [16384, 1792]); convert_element_type_1842 = None + permute_813 = torch.ops.aten.permute.default(view_2663, [1, 0]) + mm_427 = torch.ops.aten.mm.default(permute_813, view_1284); permute_813 = view_1284 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16); primals_163 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 32, '0'); convert_element_type_584 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + permute_815 = torch.ops.aten.permute.default(permute_195, [1, 0]); permute_195 = None + mm_428 = torch.ops.aten.mm.default(view_2663, permute_815); view_2663 = permute_815 = None + view_2664 = torch.ops.aten.view.default(mm_428, [2, 8192, 4096]); mm_428 = None + add_229 = torch.ops.aten.add.Tensor(view_2662, view_2664); view_2662 = view_2664 = None + convert_element_type_1847 = torch.ops.prims.convert_element_type.default(mm_427, torch.float32); mm_427 = None + reduce_scatter_tensor_224 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1847, 'avg', 32, '0'); convert_element_type_1847 = None + wait_tensor_638 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_224); reduce_scatter_tensor_224 = None + split_196 = torch.ops.aten.split.Tensor(add_229, 1024, 1); add_229 = None + getitem_1898 = split_196[0] + getitem_1899 = split_196[1] + getitem_1900 = split_196[2] + getitem_1901 = split_196[3] + getitem_1902 = split_196[4] + getitem_1903 = split_196[5] + getitem_1904 = split_196[6] + getitem_1905 = split_196[7]; split_196 = None + cat_188 = torch.ops.aten.cat.default([getitem_1898, getitem_1899, getitem_1900, getitem_1901, getitem_1902, getitem_1903, getitem_1904, getitem_1905]); getitem_1898 = getitem_1899 = getitem_1900 = getitem_1901 = getitem_1902 = getitem_1903 = getitem_1904 = getitem_1905 = None + reduce_scatter_tensor_225 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_188, 'sum', 8, '1'); cat_188 = None + wait_tensor_639 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_225); reduce_scatter_tensor_225 = None + convert_element_type_1848 = torch.ops.prims.convert_element_type.default(wait_tensor_639, torch.float32); wait_tensor_639 = None + convert_element_type_1850 = torch.ops.prims.convert_element_type.default(wait_tensor_230, torch.float32); wait_tensor_230 = None + mul_550 = torch.ops.aten.mul.Tensor(convert_element_type_1848, convert_element_type_1850); convert_element_type_1850 = None + mul_552 = torch.ops.aten.mul.Tensor(mul_140, mul_550) + sum_87 = torch.ops.aten.sum.dim_IntList(mul_552, [2], True); mul_552 = None + div_29 = torch.ops.aten.div.Tensor(mul_140, 4096) + mul_553 = torch.ops.aten.mul.Tensor(div_29, sum_87); div_29 = sum_87 = None + sub_45 = torch.ops.aten.sub.Tensor(mul_550, mul_553); mul_550 = mul_553 = None + mul_554 = torch.ops.aten.mul.Tensor(sub_45, rsqrt_35); sub_45 = rsqrt_35 = None + mul_555 = torch.ops.aten.mul.Tensor(convert_element_type_1848, mul_140); convert_element_type_1848 = mul_140 = None + sum_88 = torch.ops.aten.sum.dim_IntList(mul_555, [0, 1]); mul_555 = None + convert_element_type_1851 = torch.ops.prims.convert_element_type.default(mul_554, torch.bfloat16); mul_554 = None + convert_element_type_1852 = torch.ops.prims.convert_element_type.default(sum_88, torch.bfloat16); sum_88 = None + all_reduce_29 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1852, 'sum', '1'); convert_element_type_1852 = None + wait_tensor_640 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_29); all_reduce_29 = None + convert_element_type_1853 = torch.ops.prims.convert_element_type.default(wait_tensor_640, torch.float32); wait_tensor_640 = None + reduce_scatter_tensor_226 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1853, 'avg', 32, '0'); convert_element_type_1853 = None + wait_tensor_641 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_226); reduce_scatter_tensor_226 = None + add_230 = torch.ops.aten.add.Tensor(add_226, convert_element_type_1851); add_226 = convert_element_type_1851 = None + all_gather_into_tensor_385 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_230, 8, '1') + wait_tensor_642 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_385); all_gather_into_tensor_385 = None + split_197 = torch.ops.aten.split.Tensor(wait_tensor_642, 2); wait_tensor_642 = None + getitem_1906 = split_197[0] + getitem_1907 = split_197[1] + getitem_1908 = split_197[2] + getitem_1909 = split_197[3] + getitem_1910 = split_197[4] + getitem_1911 = split_197[5] + getitem_1912 = split_197[6] + getitem_1913 = split_197[7]; split_197 = None + cat_189 = torch.ops.aten.cat.default([getitem_1906, getitem_1907, getitem_1908, getitem_1909, getitem_1910, getitem_1911, getitem_1912, getitem_1913], 1); getitem_1906 = getitem_1907 = getitem_1908 = getitem_1909 = getitem_1910 = getitem_1911 = getitem_1912 = getitem_1913 = None + view_2665 = torch.ops.aten.view.default(cat_189, [16384, 4096]); cat_189 = None + permute_817 = torch.ops.aten.permute.default(view_2665, [1, 0]) + permute_193 = torch.ops.aten.permute.default(getitem_777, [0, 2, 1, 3]) + view_1266 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + view_1272 = torch.ops.aten.view.default(view_1266, [16384, 512]); view_1266 = None + mm_429 = torch.ops.aten.mm.default(permute_817, view_1272); permute_817 = view_1272 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16); primals_161 = None + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 32, '0'); convert_element_type_578 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + permute_819 = torch.ops.aten.permute.default(permute_194, [1, 0]); permute_194 = None + mm_430 = torch.ops.aten.mm.default(view_2665, permute_819); view_2665 = permute_819 = None + view_2666 = torch.ops.aten.view.default(mm_430, [2, 8192, 512]); mm_430 = None + convert_element_type_1858 = torch.ops.prims.convert_element_type.default(mm_429, torch.float32); mm_429 = None + reduce_scatter_tensor_227 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1858, 'avg', 32, '0'); convert_element_type_1858 = None + wait_tensor_643 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_227); reduce_scatter_tensor_227 = None + view_2667 = torch.ops.aten.view.default(view_2666, [2, 8192, 4, 128]); view_2666 = None + permute_821 = torch.ops.aten.permute.default(view_2667, [0, 2, 1, 3]); view_2667 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 32, '0'); convert_element_type_562 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32); add_67 = None + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_223) + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 8, '1'); convert_element_type_564 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + split_77 = torch.ops.aten.split.Tensor(wait_tensor_224, 2); wait_tensor_224 = None + getitem_769 = split_77[0] + getitem_770 = split_77[1] + getitem_771 = split_77[2] + getitem_772 = split_77[3] + getitem_773 = split_77[4] + getitem_774 = split_77[5] + getitem_775 = split_77[6] + getitem_776 = split_77[7]; split_77 = None + cat_69 = torch.ops.aten.cat.default([getitem_769, getitem_770, getitem_771, getitem_772, getitem_773, getitem_774, getitem_775, getitem_776], 1); getitem_769 = getitem_770 = getitem_771 = getitem_772 = getitem_773 = getitem_774 = getitem_775 = getitem_776 = None + view_1239 = torch.ops.aten.view.default(cat_69, [16384, 4096]); cat_69 = None + view_1240 = torch.ops.aten.view.default(mm_119, [2, 8192, 512]); mm_119 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 32, '0'); convert_element_type_568 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + mm_120 = torch.ops.aten.mm.default(view_1239, permute_188) + view_1247 = torch.ops.aten.view.default(mm_120, [2, 8192, 128]); mm_120 = None + view_1254 = torch.ops.aten.view.default(mm_121, [2, 8192, 128]); mm_121 = None + view_1256 = torch.ops.aten.view.default(view_1240, [2, 8192, -1, 128]); view_1240 = None + view_1257 = torch.ops.aten.view.default(view_1247, [2, 8192, -1, 128]); view_1247 = None + view_1258 = torch.ops.aten.view.default(view_1254, [2, 8192, -1, 128]); view_1254 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_1256, torch.float32); view_1256 = None + view_1259 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 4, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1259); view_1259 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_1257, torch.float32); view_1257 = None + view_1260 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 1, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1260); view_1260 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_37); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_1262 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 4, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_37); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_1263 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 1, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_1262, torch.bfloat16); view_1262 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_1263, torch.bfloat16); view_1263 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 1, 4, 128]); unsqueeze_34 = None + view_1264 = torch.ops.aten.view.default(expand_34, [2, 8192, 4, 128]); expand_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_1258, 3); view_1258 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 1, 4, 128]); unsqueeze_35 = None + view_1265 = torch.ops.aten.view.default(expand_35, [2, 8192, 4, 128]); expand_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_1264, [0, 2, 1, 3]); view_1264 = None + permute_192 = torch.ops.aten.permute.default(view_1265, [0, 2, 1, 3]); view_1265 = None + _scaled_dot_product_cudnn_attention_backward_14 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_821, permute_190, permute_191, permute_192, getitem_777, getitem_778, getitem_783, getitem_784, None, None, None, 8192, 8192, 0.0, True); permute_821 = permute_190 = permute_191 = permute_192 = getitem_777 = getitem_778 = getitem_783 = getitem_784 = None + getitem_1914 = _scaled_dot_product_cudnn_attention_backward_14[0] + getitem_1915 = _scaled_dot_product_cudnn_attention_backward_14[1] + getitem_1916 = _scaled_dot_product_cudnn_attention_backward_14[2]; _scaled_dot_product_cudnn_attention_backward_14 = None + permute_822 = torch.ops.aten.permute.default(getitem_1916, [0, 2, 1, 3]); getitem_1916 = None + permute_823 = torch.ops.aten.permute.default(getitem_1915, [0, 2, 1, 3]); getitem_1915 = None + permute_824 = torch.ops.aten.permute.default(getitem_1914, [0, 2, 1, 3]); getitem_1914 = None + view_2668 = torch.ops.aten.view.default(permute_822, [2, 8192, 1, 4, 128]); permute_822 = None + sum_89 = torch.ops.aten.sum.dim_IntList(view_2668, [3], True); view_2668 = None + squeeze_28 = torch.ops.aten.squeeze.dim(sum_89, 3); sum_89 = None + view_2669 = torch.ops.aten.view.default(permute_823, [2, 8192, 1, 4, 128]); permute_823 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_2669, [3], True); view_2669 = None + squeeze_29 = torch.ops.aten.squeeze.dim(sum_90, 3); sum_90 = None + convert_element_type_1859 = torch.ops.prims.convert_element_type.default(squeeze_29, torch.float32); squeeze_29 = None + convert_element_type_1860 = torch.ops.prims.convert_element_type.default(permute_824, torch.float32); permute_824 = None + view_2670 = torch.ops.aten.view.default(convert_element_type_1859, [2, 8192, 1, 64, 2]); convert_element_type_1859 = None + view_as_complex_92 = torch.ops.aten.view_as_complex.default(view_2670); view_2670 = None + mul_556 = torch.ops.aten.mul.Tensor(view_as_complex_92, _conj); view_as_complex_92 = None + view_2671 = torch.ops.aten.view.default(convert_element_type_1860, [2, 8192, 4, 64, 2]); convert_element_type_1860 = None + view_as_complex_93 = torch.ops.aten.view_as_complex.default(view_2671); view_2671 = None + mul_557 = torch.ops.aten.mul.Tensor(view_as_complex_93, _conj); view_as_complex_93 = None + view_as_real_92 = torch.ops.aten.view_as_real.default(mul_556); mul_556 = None + view_2672 = torch.ops.aten.view.default(view_as_real_92, [2, 8192, 1, 128]); view_as_real_92 = None + convert_element_type_1861 = torch.ops.prims.convert_element_type.default(view_2672, torch.bfloat16); view_2672 = None + view_as_real_93 = torch.ops.aten.view_as_real.default(mul_557); mul_557 = None + view_2673 = torch.ops.aten.view.default(view_as_real_93, [2, 8192, 4, 128]); view_as_real_93 = None + convert_element_type_1862 = torch.ops.prims.convert_element_type.default(view_2673, torch.bfloat16); view_2673 = None + view_2674 = torch.ops.aten.view.default(squeeze_28, [2, 8192, 128]); squeeze_28 = None + view_2675 = torch.ops.aten.view.default(convert_element_type_1861, [2, 8192, 128]); convert_element_type_1861 = None + view_2676 = torch.ops.aten.view.default(convert_element_type_1862, [2, 8192, 512]); convert_element_type_1862 = None + view_2677 = torch.ops.aten.view.default(view_2674, [16384, 128]); view_2674 = None + permute_825 = torch.ops.aten.permute.default(view_2677, [1, 0]) + mm_431 = torch.ops.aten.mm.default(permute_825, view_1239); permute_825 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16); primals_160 = None + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 32, '0'); convert_element_type_571 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + permute_827 = torch.ops.aten.permute.default(permute_189, [1, 0]); permute_189 = None + mm_432 = torch.ops.aten.mm.default(view_2677, permute_827); view_2677 = permute_827 = None + view_2678 = torch.ops.aten.view.default(mm_432, [2, 8192, 4096]); mm_432 = None + convert_element_type_1867 = torch.ops.prims.convert_element_type.default(mm_431, torch.float32); mm_431 = None + reduce_scatter_tensor_228 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1867, 'avg', 32, '0'); convert_element_type_1867 = None + wait_tensor_644 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_228); reduce_scatter_tensor_228 = None + view_2679 = torch.ops.aten.view.default(view_2675, [16384, 128]); view_2675 = None + permute_829 = torch.ops.aten.permute.default(view_2679, [1, 0]) + mm_433 = torch.ops.aten.mm.default(permute_829, view_1239); permute_829 = None + permute_831 = torch.ops.aten.permute.default(permute_188, [1, 0]); permute_188 = None + mm_434 = torch.ops.aten.mm.default(view_2679, permute_831); view_2679 = permute_831 = None + view_2680 = torch.ops.aten.view.default(mm_434, [2, 8192, 4096]); mm_434 = None + add_231 = torch.ops.aten.add.Tensor(view_2678, view_2680); view_2678 = view_2680 = None + convert_element_type_1872 = torch.ops.prims.convert_element_type.default(mm_433, torch.float32); mm_433 = None + reduce_scatter_tensor_229 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1872, 'avg', 32, '0'); convert_element_type_1872 = None + wait_tensor_645 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_229); reduce_scatter_tensor_229 = None + view_2681 = torch.ops.aten.view.default(view_2676, [16384, 512]); view_2676 = None + permute_833 = torch.ops.aten.permute.default(view_2681, [1, 0]) + mm_435 = torch.ops.aten.mm.default(permute_833, view_1239); permute_833 = view_1239 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16); primals_158 = None + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 32, '0'); convert_element_type_565 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + permute_835 = torch.ops.aten.permute.default(permute_187, [1, 0]); permute_187 = None + mm_436 = torch.ops.aten.mm.default(view_2681, permute_835); view_2681 = permute_835 = None + view_2682 = torch.ops.aten.view.default(mm_436, [2, 8192, 4096]); mm_436 = None + add_232 = torch.ops.aten.add.Tensor(add_231, view_2682); add_231 = view_2682 = None + convert_element_type_1877 = torch.ops.prims.convert_element_type.default(mm_435, torch.float32); mm_435 = None + reduce_scatter_tensor_230 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1877, 'avg', 32, '0'); convert_element_type_1877 = None + wait_tensor_646 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_230); reduce_scatter_tensor_230 = None + split_198 = torch.ops.aten.split.Tensor(add_232, 1024, 1); add_232 = None + getitem_1917 = split_198[0] + getitem_1918 = split_198[1] + getitem_1919 = split_198[2] + getitem_1920 = split_198[3] + getitem_1921 = split_198[4] + getitem_1922 = split_198[5] + getitem_1923 = split_198[6] + getitem_1924 = split_198[7]; split_198 = None + cat_190 = torch.ops.aten.cat.default([getitem_1917, getitem_1918, getitem_1919, getitem_1920, getitem_1921, getitem_1922, getitem_1923, getitem_1924]); getitem_1917 = getitem_1918 = getitem_1919 = getitem_1920 = getitem_1921 = getitem_1922 = getitem_1923 = getitem_1924 = None + reduce_scatter_tensor_231 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_190, 'sum', 8, '1'); cat_190 = None + wait_tensor_647 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_231); reduce_scatter_tensor_231 = None + convert_element_type_1878 = torch.ops.prims.convert_element_type.default(wait_tensor_647, torch.float32); wait_tensor_647 = None + convert_element_type_1880 = torch.ops.prims.convert_element_type.default(wait_tensor_223, torch.float32); wait_tensor_223 = None + mul_558 = torch.ops.aten.mul.Tensor(convert_element_type_1878, convert_element_type_1880); convert_element_type_1880 = None + mul_560 = torch.ops.aten.mul.Tensor(mul_136, mul_558) + sum_91 = torch.ops.aten.sum.dim_IntList(mul_560, [2], True); mul_560 = None + div_30 = torch.ops.aten.div.Tensor(mul_136, 4096) + mul_561 = torch.ops.aten.mul.Tensor(div_30, sum_91); div_30 = sum_91 = None + sub_46 = torch.ops.aten.sub.Tensor(mul_558, mul_561); mul_558 = mul_561 = None + mul_562 = torch.ops.aten.mul.Tensor(sub_46, rsqrt_34); sub_46 = rsqrt_34 = None + mul_563 = torch.ops.aten.mul.Tensor(convert_element_type_1878, mul_136); convert_element_type_1878 = mul_136 = None + sum_92 = torch.ops.aten.sum.dim_IntList(mul_563, [0, 1]); mul_563 = None + convert_element_type_1881 = torch.ops.prims.convert_element_type.default(mul_562, torch.bfloat16); mul_562 = None + convert_element_type_1882 = torch.ops.prims.convert_element_type.default(sum_92, torch.bfloat16); sum_92 = None + all_reduce_30 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1882, 'sum', '1'); convert_element_type_1882 = None + wait_tensor_648 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_30); all_reduce_30 = None + convert_element_type_1883 = torch.ops.prims.convert_element_type.default(wait_tensor_648, torch.float32); wait_tensor_648 = None + reduce_scatter_tensor_232 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1883, 'avg', 32, '0'); convert_element_type_1883 = None + wait_tensor_649 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_232); reduce_scatter_tensor_232 = None + add_233 = torch.ops.aten.add.Tensor(add_230, convert_element_type_1881); add_230 = convert_element_type_1881 = None + all_gather_into_tensor_386 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_233, 8, '1') + wait_tensor_650 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_386); all_gather_into_tensor_386 = None + split_199 = torch.ops.aten.split.Tensor(wait_tensor_650, 2); wait_tensor_650 = None + getitem_1925 = split_199[0] + getitem_1926 = split_199[1] + getitem_1927 = split_199[2] + getitem_1928 = split_199[3] + getitem_1929 = split_199[4] + getitem_1930 = split_199[5] + getitem_1931 = split_199[6] + getitem_1932 = split_199[7]; split_199 = None + cat_191 = torch.ops.aten.cat.default([getitem_1925, getitem_1926, getitem_1927, getitem_1928, getitem_1929, getitem_1930, getitem_1931, getitem_1932], 1); getitem_1925 = getitem_1926 = getitem_1927 = getitem_1928 = getitem_1929 = getitem_1930 = getitem_1931 = getitem_1932 = None + view_2683 = torch.ops.aten.view.default(cat_191, [16384, 4096]); cat_191 = None + permute_837 = torch.ops.aten.permute.default(view_2683, [1, 0]) + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + add_65 = torch.ops.aten.add.Tensor(add_63, wait_tensor_216); wait_tensor_216 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16); primals_153 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 32, '0'); convert_element_type_548 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32); add_65 = None + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_217) + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_550, 8, '1'); convert_element_type_550 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + split_75 = torch.ops.aten.split.Tensor(wait_tensor_218, 2); wait_tensor_218 = None + getitem_753 = split_75[0] + getitem_754 = split_75[1] + getitem_755 = split_75[2] + getitem_756 = split_75[3] + getitem_757 = split_75[4] + getitem_758 = split_75[5] + getitem_759 = split_75[6] + getitem_760 = split_75[7]; split_75 = None + cat_67 = torch.ops.aten.cat.default([getitem_753, getitem_754, getitem_755, getitem_756, getitem_757, getitem_758, getitem_759, getitem_760], 1); getitem_753 = getitem_754 = getitem_755 = getitem_756 = getitem_757 = getitem_758 = getitem_759 = getitem_760 = None + view_1212 = torch.ops.aten.view.default(cat_67, [16384, 4096]); cat_67 = None + view_1213 = torch.ops.aten.view.default(mm_116, [2, 8192, 1792]); mm_116 = None + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_1213, torch.float32); view_1213 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 32, '0'); convert_element_type_556 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_117 = torch.ops.aten.mm.default(view_1212, permute_185) + view_1220 = torch.ops.aten.view.default(mm_117, [2, 8192, 1792]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_1220) + view_1227 = torch.ops.aten.view.default(mul_135, [16384, 1792]); mul_135 = None + mm_437 = torch.ops.aten.mm.default(permute_837, view_1227); permute_837 = view_1227 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 32, '0'); convert_element_type_559 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + permute_839 = torch.ops.aten.permute.default(permute_186, [1, 0]); permute_186 = None + mm_438 = torch.ops.aten.mm.default(view_2683, permute_839); view_2683 = permute_839 = None + view_2684 = torch.ops.aten.view.default(mm_438, [2, 8192, 1792]); mm_438 = None + convert_element_type_1888 = torch.ops.prims.convert_element_type.default(mm_437, torch.float32); mm_437 = None + reduce_scatter_tensor_233 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1888, 'avg', 32, '0'); convert_element_type_1888 = None + wait_tensor_651 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_233); reduce_scatter_tensor_233 = None + mul_564 = torch.ops.aten.mul.Tensor(view_2684, convert_element_type_555); convert_element_type_555 = None + mul_565 = torch.ops.aten.mul.Tensor(view_2684, view_1220); view_2684 = view_1220 = None + view_2685 = torch.ops.aten.view.default(mul_564, [16384, 1792]); mul_564 = None + permute_841 = torch.ops.aten.permute.default(view_2685, [1, 0]) + mm_439 = torch.ops.aten.mm.default(permute_841, view_1212); permute_841 = None + permute_843 = torch.ops.aten.permute.default(permute_185, [1, 0]); permute_185 = None + mm_440 = torch.ops.aten.mm.default(view_2685, permute_843); view_2685 = permute_843 = None + view_2686 = torch.ops.aten.view.default(mm_440, [2, 8192, 4096]); mm_440 = None + convert_element_type_1893 = torch.ops.prims.convert_element_type.default(mm_439, torch.float32); mm_439 = None + reduce_scatter_tensor_234 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1893, 'avg', 32, '0'); convert_element_type_1893 = None + wait_tensor_652 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_234); reduce_scatter_tensor_234 = None + convert_element_type_1894 = torch.ops.prims.convert_element_type.default(mul_565, torch.float32); mul_565 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_554) + exp_15 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_234 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_234); add_234 = None + mul_566 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_567 = torch.ops.aten.mul.Tensor(convert_element_type_1894, mul_566); convert_element_type_1894 = None + sub_47 = torch.ops.aten.sub.Tensor(1, mul_566); mul_566 = None + mul_568 = torch.ops.aten.mul.Tensor(convert_element_type_554, sub_47); convert_element_type_554 = sub_47 = None + add_235 = torch.ops.aten.add.Tensor(mul_568, 1); mul_568 = None + mul_569 = torch.ops.aten.mul.Tensor(mul_567, add_235); mul_567 = add_235 = None + convert_element_type_1896 = torch.ops.prims.convert_element_type.default(mul_569, torch.bfloat16); mul_569 = None + view_2687 = torch.ops.aten.view.default(convert_element_type_1896, [16384, 1792]); convert_element_type_1896 = None + permute_845 = torch.ops.aten.permute.default(view_2687, [1, 0]) + mm_441 = torch.ops.aten.mm.default(permute_845, view_1212); permute_845 = view_1212 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 32, '0'); convert_element_type_551 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + permute_847 = torch.ops.aten.permute.default(permute_184, [1, 0]); permute_184 = None + mm_442 = torch.ops.aten.mm.default(view_2687, permute_847); view_2687 = permute_847 = None + view_2688 = torch.ops.aten.view.default(mm_442, [2, 8192, 4096]); mm_442 = None + add_236 = torch.ops.aten.add.Tensor(view_2686, view_2688); view_2686 = view_2688 = None + convert_element_type_1901 = torch.ops.prims.convert_element_type.default(mm_441, torch.float32); mm_441 = None + reduce_scatter_tensor_235 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1901, 'avg', 32, '0'); convert_element_type_1901 = None + wait_tensor_653 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_235); reduce_scatter_tensor_235 = None + split_200 = torch.ops.aten.split.Tensor(add_236, 1024, 1); add_236 = None + getitem_1933 = split_200[0] + getitem_1934 = split_200[1] + getitem_1935 = split_200[2] + getitem_1936 = split_200[3] + getitem_1937 = split_200[4] + getitem_1938 = split_200[5] + getitem_1939 = split_200[6] + getitem_1940 = split_200[7]; split_200 = None + cat_192 = torch.ops.aten.cat.default([getitem_1933, getitem_1934, getitem_1935, getitem_1936, getitem_1937, getitem_1938, getitem_1939, getitem_1940]); getitem_1933 = getitem_1934 = getitem_1935 = getitem_1936 = getitem_1937 = getitem_1938 = getitem_1939 = getitem_1940 = None + reduce_scatter_tensor_236 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_192, 'sum', 8, '1'); cat_192 = None + wait_tensor_654 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_236); reduce_scatter_tensor_236 = None + convert_element_type_1902 = torch.ops.prims.convert_element_type.default(wait_tensor_654, torch.float32); wait_tensor_654 = None + convert_element_type_1904 = torch.ops.prims.convert_element_type.default(wait_tensor_217, torch.float32); wait_tensor_217 = None + mul_570 = torch.ops.aten.mul.Tensor(convert_element_type_1902, convert_element_type_1904); convert_element_type_1904 = None + mul_572 = torch.ops.aten.mul.Tensor(mul_132, mul_570) + sum_93 = torch.ops.aten.sum.dim_IntList(mul_572, [2], True); mul_572 = None + div_31 = torch.ops.aten.div.Tensor(mul_132, 4096) + mul_573 = torch.ops.aten.mul.Tensor(div_31, sum_93); div_31 = sum_93 = None + sub_48 = torch.ops.aten.sub.Tensor(mul_570, mul_573); mul_570 = mul_573 = None + mul_574 = torch.ops.aten.mul.Tensor(sub_48, rsqrt_33); sub_48 = rsqrt_33 = None + mul_575 = torch.ops.aten.mul.Tensor(convert_element_type_1902, mul_132); convert_element_type_1902 = mul_132 = None + sum_94 = torch.ops.aten.sum.dim_IntList(mul_575, [0, 1]); mul_575 = None + convert_element_type_1905 = torch.ops.prims.convert_element_type.default(mul_574, torch.bfloat16); mul_574 = None + convert_element_type_1906 = torch.ops.prims.convert_element_type.default(sum_94, torch.bfloat16); sum_94 = None + all_reduce_31 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1906, 'sum', '1'); convert_element_type_1906 = None + wait_tensor_655 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_31); all_reduce_31 = None + convert_element_type_1907 = torch.ops.prims.convert_element_type.default(wait_tensor_655, torch.float32); wait_tensor_655 = None + reduce_scatter_tensor_237 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1907, 'avg', 32, '0'); convert_element_type_1907 = None + wait_tensor_656 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_237); reduce_scatter_tensor_237 = None + add_237 = torch.ops.aten.add.Tensor(add_233, convert_element_type_1905); add_233 = convert_element_type_1905 = None + all_gather_into_tensor_387 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_237, 8, '1') + wait_tensor_657 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_387); all_gather_into_tensor_387 = None + split_201 = torch.ops.aten.split.Tensor(wait_tensor_657, 2); wait_tensor_657 = None + getitem_1941 = split_201[0] + getitem_1942 = split_201[1] + getitem_1943 = split_201[2] + getitem_1944 = split_201[3] + getitem_1945 = split_201[4] + getitem_1946 = split_201[5] + getitem_1947 = split_201[6] + getitem_1948 = split_201[7]; split_201 = None + cat_193 = torch.ops.aten.cat.default([getitem_1941, getitem_1942, getitem_1943, getitem_1944, getitem_1945, getitem_1946, getitem_1947, getitem_1948], 1); getitem_1941 = getitem_1942 = getitem_1943 = getitem_1944 = getitem_1945 = getitem_1946 = getitem_1947 = getitem_1948 = None + view_2689 = torch.ops.aten.view.default(cat_193, [16384, 4096]); cat_193 = None + permute_849 = torch.ops.aten.permute.default(view_2689, [1, 0]) + permute_182 = torch.ops.aten.permute.default(getitem_736, [0, 2, 1, 3]) + view_1194 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + view_1200 = torch.ops.aten.view.default(view_1194, [16384, 512]); view_1194 = None + mm_443 = torch.ops.aten.mm.default(permute_849, view_1200); permute_849 = view_1200 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16); primals_152 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 32, '0'); convert_element_type_545 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + permute_851 = torch.ops.aten.permute.default(permute_183, [1, 0]); permute_183 = None + mm_444 = torch.ops.aten.mm.default(view_2689, permute_851); view_2689 = permute_851 = None + view_2690 = torch.ops.aten.view.default(mm_444, [2, 8192, 512]); mm_444 = None + convert_element_type_1912 = torch.ops.prims.convert_element_type.default(mm_443, torch.float32); mm_443 = None + reduce_scatter_tensor_238 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1912, 'avg', 32, '0'); convert_element_type_1912 = None + wait_tensor_658 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_238); reduce_scatter_tensor_238 = None + view_2691 = torch.ops.aten.view.default(view_2690, [2, 8192, 4, 128]); view_2690 = None + permute_853 = torch.ops.aten.permute.default(view_2691, [0, 2, 1, 3]); view_2691 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 32, '0'); convert_element_type_529 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32); add_63 = None + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_210) + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 8, '1'); convert_element_type_531 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + split_73 = torch.ops.aten.split.Tensor(wait_tensor_211, 2); wait_tensor_211 = None + getitem_728 = split_73[0] + getitem_729 = split_73[1] + getitem_730 = split_73[2] + getitem_731 = split_73[3] + getitem_732 = split_73[4] + getitem_733 = split_73[5] + getitem_734 = split_73[6] + getitem_735 = split_73[7]; split_73 = None + cat_65 = torch.ops.aten.cat.default([getitem_728, getitem_729, getitem_730, getitem_731, getitem_732, getitem_733, getitem_734, getitem_735], 1); getitem_728 = getitem_729 = getitem_730 = getitem_731 = getitem_732 = getitem_733 = getitem_734 = getitem_735 = None + view_1167 = torch.ops.aten.view.default(cat_65, [16384, 4096]); cat_65 = None + view_1168 = torch.ops.aten.view.default(mm_112, [2, 8192, 512]); mm_112 = None + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16); primals_150 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 32, '0'); convert_element_type_535 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_213, [1, 0]); wait_tensor_213 = None + mm_113 = torch.ops.aten.mm.default(view_1167, permute_177) + view_1175 = torch.ops.aten.view.default(mm_113, [2, 8192, 128]); mm_113 = None + view_1182 = torch.ops.aten.view.default(mm_114, [2, 8192, 128]); mm_114 = None + view_1184 = torch.ops.aten.view.default(view_1168, [2, 8192, -1, 128]); view_1168 = None + view_1185 = torch.ops.aten.view.default(view_1175, [2, 8192, -1, 128]); view_1175 = None + view_1186 = torch.ops.aten.view.default(view_1182, [2, 8192, -1, 128]); view_1182 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_1184, torch.float32); view_1184 = None + view_1187 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 4, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1187); view_1187 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_1185, torch.float32); view_1185 = None + view_1188 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 1, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1188); view_1188 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_37); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_1190 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 4, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_37); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_1191 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 1, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_1190, torch.bfloat16); view_1190 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_1191, torch.bfloat16); view_1191 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 1, 4, 128]); unsqueeze_32 = None + view_1192 = torch.ops.aten.view.default(expand_32, [2, 8192, 4, 128]); expand_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_1186, 3); view_1186 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 1, 4, 128]); unsqueeze_33 = None + view_1193 = torch.ops.aten.view.default(expand_33, [2, 8192, 4, 128]); expand_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_1192, [0, 2, 1, 3]); view_1192 = None + permute_181 = torch.ops.aten.permute.default(view_1193, [0, 2, 1, 3]); view_1193 = None + _scaled_dot_product_cudnn_attention_backward_15 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_853, permute_179, permute_180, permute_181, getitem_736, getitem_737, getitem_742, getitem_743, None, None, None, 8192, 8192, 0.0, True); permute_853 = permute_179 = permute_180 = permute_181 = getitem_736 = getitem_737 = getitem_742 = getitem_743 = None + getitem_1949 = _scaled_dot_product_cudnn_attention_backward_15[0] + getitem_1950 = _scaled_dot_product_cudnn_attention_backward_15[1] + getitem_1951 = _scaled_dot_product_cudnn_attention_backward_15[2]; _scaled_dot_product_cudnn_attention_backward_15 = None + permute_854 = torch.ops.aten.permute.default(getitem_1951, [0, 2, 1, 3]); getitem_1951 = None + permute_855 = torch.ops.aten.permute.default(getitem_1950, [0, 2, 1, 3]); getitem_1950 = None + permute_856 = torch.ops.aten.permute.default(getitem_1949, [0, 2, 1, 3]); getitem_1949 = None + view_2692 = torch.ops.aten.view.default(permute_854, [2, 8192, 1, 4, 128]); permute_854 = None + sum_95 = torch.ops.aten.sum.dim_IntList(view_2692, [3], True); view_2692 = None + squeeze_30 = torch.ops.aten.squeeze.dim(sum_95, 3); sum_95 = None + view_2693 = torch.ops.aten.view.default(permute_855, [2, 8192, 1, 4, 128]); permute_855 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_2693, [3], True); view_2693 = None + squeeze_31 = torch.ops.aten.squeeze.dim(sum_96, 3); sum_96 = None + convert_element_type_1913 = torch.ops.prims.convert_element_type.default(squeeze_31, torch.float32); squeeze_31 = None + convert_element_type_1914 = torch.ops.prims.convert_element_type.default(permute_856, torch.float32); permute_856 = None + view_2694 = torch.ops.aten.view.default(convert_element_type_1913, [2, 8192, 1, 64, 2]); convert_element_type_1913 = None + view_as_complex_94 = torch.ops.aten.view_as_complex.default(view_2694); view_2694 = None + mul_576 = torch.ops.aten.mul.Tensor(view_as_complex_94, _conj); view_as_complex_94 = None + view_2695 = torch.ops.aten.view.default(convert_element_type_1914, [2, 8192, 4, 64, 2]); convert_element_type_1914 = None + view_as_complex_95 = torch.ops.aten.view_as_complex.default(view_2695); view_2695 = None + mul_577 = torch.ops.aten.mul.Tensor(view_as_complex_95, _conj); view_as_complex_95 = None + view_as_real_94 = torch.ops.aten.view_as_real.default(mul_576); mul_576 = None + view_2696 = torch.ops.aten.view.default(view_as_real_94, [2, 8192, 1, 128]); view_as_real_94 = None + convert_element_type_1915 = torch.ops.prims.convert_element_type.default(view_2696, torch.bfloat16); view_2696 = None + view_as_real_95 = torch.ops.aten.view_as_real.default(mul_577); mul_577 = None + view_2697 = torch.ops.aten.view.default(view_as_real_95, [2, 8192, 4, 128]); view_as_real_95 = None + convert_element_type_1916 = torch.ops.prims.convert_element_type.default(view_2697, torch.bfloat16); view_2697 = None + view_2698 = torch.ops.aten.view.default(squeeze_30, [2, 8192, 128]); squeeze_30 = None + view_2699 = torch.ops.aten.view.default(convert_element_type_1915, [2, 8192, 128]); convert_element_type_1915 = None + view_2700 = torch.ops.aten.view.default(convert_element_type_1916, [2, 8192, 512]); convert_element_type_1916 = None + view_2701 = torch.ops.aten.view.default(view_2698, [16384, 128]); view_2698 = None + permute_857 = torch.ops.aten.permute.default(view_2701, [1, 0]) + mm_445 = torch.ops.aten.mm.default(permute_857, view_1167); permute_857 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16); primals_151 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 32, '0'); convert_element_type_538 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + permute_859 = torch.ops.aten.permute.default(permute_178, [1, 0]); permute_178 = None + mm_446 = torch.ops.aten.mm.default(view_2701, permute_859); view_2701 = permute_859 = None + view_2702 = torch.ops.aten.view.default(mm_446, [2, 8192, 4096]); mm_446 = None + convert_element_type_1921 = torch.ops.prims.convert_element_type.default(mm_445, torch.float32); mm_445 = None + reduce_scatter_tensor_239 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1921, 'avg', 32, '0'); convert_element_type_1921 = None + wait_tensor_659 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_239); reduce_scatter_tensor_239 = None + view_2703 = torch.ops.aten.view.default(view_2699, [16384, 128]); view_2699 = None + permute_861 = torch.ops.aten.permute.default(view_2703, [1, 0]) + mm_447 = torch.ops.aten.mm.default(permute_861, view_1167); permute_861 = None + permute_863 = torch.ops.aten.permute.default(permute_177, [1, 0]); permute_177 = None + mm_448 = torch.ops.aten.mm.default(view_2703, permute_863); view_2703 = permute_863 = None + view_2704 = torch.ops.aten.view.default(mm_448, [2, 8192, 4096]); mm_448 = None + add_238 = torch.ops.aten.add.Tensor(view_2702, view_2704); view_2702 = view_2704 = None + convert_element_type_1926 = torch.ops.prims.convert_element_type.default(mm_447, torch.float32); mm_447 = None + reduce_scatter_tensor_240 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1926, 'avg', 32, '0'); convert_element_type_1926 = None + wait_tensor_660 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_240); reduce_scatter_tensor_240 = None + view_2705 = torch.ops.aten.view.default(view_2700, [16384, 512]); view_2700 = None + permute_865 = torch.ops.aten.permute.default(view_2705, [1, 0]) + mm_449 = torch.ops.aten.mm.default(permute_865, view_1167); permute_865 = view_1167 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); primals_149 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 32, '0'); convert_element_type_532 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + permute_867 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_450 = torch.ops.aten.mm.default(view_2705, permute_867); view_2705 = permute_867 = None + view_2706 = torch.ops.aten.view.default(mm_450, [2, 8192, 4096]); mm_450 = None + add_239 = torch.ops.aten.add.Tensor(add_238, view_2706); add_238 = view_2706 = None + convert_element_type_1931 = torch.ops.prims.convert_element_type.default(mm_449, torch.float32); mm_449 = None + reduce_scatter_tensor_241 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1931, 'avg', 32, '0'); convert_element_type_1931 = None + wait_tensor_661 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_241); reduce_scatter_tensor_241 = None + split_202 = torch.ops.aten.split.Tensor(add_239, 1024, 1); add_239 = None + getitem_1952 = split_202[0] + getitem_1953 = split_202[1] + getitem_1954 = split_202[2] + getitem_1955 = split_202[3] + getitem_1956 = split_202[4] + getitem_1957 = split_202[5] + getitem_1958 = split_202[6] + getitem_1959 = split_202[7]; split_202 = None + cat_194 = torch.ops.aten.cat.default([getitem_1952, getitem_1953, getitem_1954, getitem_1955, getitem_1956, getitem_1957, getitem_1958, getitem_1959]); getitem_1952 = getitem_1953 = getitem_1954 = getitem_1955 = getitem_1956 = getitem_1957 = getitem_1958 = getitem_1959 = None + reduce_scatter_tensor_242 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_194, 'sum', 8, '1'); cat_194 = None + wait_tensor_662 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_242); reduce_scatter_tensor_242 = None + convert_element_type_1932 = torch.ops.prims.convert_element_type.default(wait_tensor_662, torch.float32); wait_tensor_662 = None + convert_element_type_1934 = torch.ops.prims.convert_element_type.default(wait_tensor_210, torch.float32); wait_tensor_210 = None + mul_578 = torch.ops.aten.mul.Tensor(convert_element_type_1932, convert_element_type_1934); convert_element_type_1934 = None + mul_580 = torch.ops.aten.mul.Tensor(mul_128, mul_578) + sum_97 = torch.ops.aten.sum.dim_IntList(mul_580, [2], True); mul_580 = None + div_32 = torch.ops.aten.div.Tensor(mul_128, 4096) + mul_581 = torch.ops.aten.mul.Tensor(div_32, sum_97); div_32 = sum_97 = None + sub_49 = torch.ops.aten.sub.Tensor(mul_578, mul_581); mul_578 = mul_581 = None + mul_582 = torch.ops.aten.mul.Tensor(sub_49, rsqrt_32); sub_49 = rsqrt_32 = None + mul_583 = torch.ops.aten.mul.Tensor(convert_element_type_1932, mul_128); convert_element_type_1932 = mul_128 = None + sum_98 = torch.ops.aten.sum.dim_IntList(mul_583, [0, 1]); mul_583 = None + convert_element_type_1935 = torch.ops.prims.convert_element_type.default(mul_582, torch.bfloat16); mul_582 = None + convert_element_type_1936 = torch.ops.prims.convert_element_type.default(sum_98, torch.bfloat16); sum_98 = None + all_reduce_32 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1936, 'sum', '1'); convert_element_type_1936 = None + wait_tensor_663 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_32); all_reduce_32 = None + convert_element_type_1937 = torch.ops.prims.convert_element_type.default(wait_tensor_663, torch.float32); wait_tensor_663 = None + reduce_scatter_tensor_243 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1937, 'avg', 32, '0'); convert_element_type_1937 = None + wait_tensor_664 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_243); reduce_scatter_tensor_243 = None + add_240 = torch.ops.aten.add.Tensor(add_237, convert_element_type_1935); add_237 = convert_element_type_1935 = None + all_gather_into_tensor_388 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_240, 8, '1') + wait_tensor_665 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_388); all_gather_into_tensor_388 = None + split_203 = torch.ops.aten.split.Tensor(wait_tensor_665, 2); wait_tensor_665 = None + getitem_1960 = split_203[0] + getitem_1961 = split_203[1] + getitem_1962 = split_203[2] + getitem_1963 = split_203[3] + getitem_1964 = split_203[4] + getitem_1965 = split_203[5] + getitem_1966 = split_203[6] + getitem_1967 = split_203[7]; split_203 = None + cat_195 = torch.ops.aten.cat.default([getitem_1960, getitem_1961, getitem_1962, getitem_1963, getitem_1964, getitem_1965, getitem_1966, getitem_1967], 1); getitem_1960 = getitem_1961 = getitem_1962 = getitem_1963 = getitem_1964 = getitem_1965 = getitem_1966 = getitem_1967 = None + view_2707 = torch.ops.aten.view.default(cat_195, [16384, 4096]); cat_195 = None + permute_869 = torch.ops.aten.permute.default(view_2707, [1, 0]) + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + add_61 = torch.ops.aten.add.Tensor(add_59, wait_tensor_203); wait_tensor_203 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 32, '0'); convert_element_type_515 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32); add_61 = None + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_204) + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_517, 8, '1'); convert_element_type_517 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + split_71 = torch.ops.aten.split.Tensor(wait_tensor_205, 2); wait_tensor_205 = None + getitem_712 = split_71[0] + getitem_713 = split_71[1] + getitem_714 = split_71[2] + getitem_715 = split_71[3] + getitem_716 = split_71[4] + getitem_717 = split_71[5] + getitem_718 = split_71[6] + getitem_719 = split_71[7]; split_71 = None + cat_63 = torch.ops.aten.cat.default([getitem_712, getitem_713, getitem_714, getitem_715, getitem_716, getitem_717, getitem_718, getitem_719], 1); getitem_712 = getitem_713 = getitem_714 = getitem_715 = getitem_716 = getitem_717 = getitem_718 = getitem_719 = None + view_1140 = torch.ops.aten.view.default(cat_63, [16384, 4096]); cat_63 = None + view_1141 = torch.ops.aten.view.default(mm_109, [2, 8192, 1792]); mm_109 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_1141, torch.float32); view_1141 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16); primals_146 = None + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 32, '0'); convert_element_type_523 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + mm_110 = torch.ops.aten.mm.default(view_1140, permute_174) + view_1148 = torch.ops.aten.view.default(mm_110, [2, 8192, 1792]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_1148) + view_1155 = torch.ops.aten.view.default(mul_127, [16384, 1792]); mul_127 = None + mm_451 = torch.ops.aten.mm.default(permute_869, view_1155); permute_869 = view_1155 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16); primals_147 = None + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 32, '0'); convert_element_type_526 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_208, [1, 0]); wait_tensor_208 = None + permute_871 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_452 = torch.ops.aten.mm.default(view_2707, permute_871); view_2707 = permute_871 = None + view_2708 = torch.ops.aten.view.default(mm_452, [2, 8192, 1792]); mm_452 = None + convert_element_type_1942 = torch.ops.prims.convert_element_type.default(mm_451, torch.float32); mm_451 = None + reduce_scatter_tensor_244 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1942, 'avg', 32, '0'); convert_element_type_1942 = None + wait_tensor_666 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_244); reduce_scatter_tensor_244 = None + mul_584 = torch.ops.aten.mul.Tensor(view_2708, convert_element_type_522); convert_element_type_522 = None + mul_585 = torch.ops.aten.mul.Tensor(view_2708, view_1148); view_2708 = view_1148 = None + view_2709 = torch.ops.aten.view.default(mul_584, [16384, 1792]); mul_584 = None + permute_873 = torch.ops.aten.permute.default(view_2709, [1, 0]) + mm_453 = torch.ops.aten.mm.default(permute_873, view_1140); permute_873 = None + permute_875 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_454 = torch.ops.aten.mm.default(view_2709, permute_875); view_2709 = permute_875 = None + view_2710 = torch.ops.aten.view.default(mm_454, [2, 8192, 4096]); mm_454 = None + convert_element_type_1947 = torch.ops.prims.convert_element_type.default(mm_453, torch.float32); mm_453 = None + reduce_scatter_tensor_245 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1947, 'avg', 32, '0'); convert_element_type_1947 = None + wait_tensor_667 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_245); reduce_scatter_tensor_245 = None + convert_element_type_1948 = torch.ops.prims.convert_element_type.default(mul_585, torch.float32); mul_585 = None + neg_16 = torch.ops.aten.neg.default(convert_element_type_521) + exp_16 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_241 = torch.ops.aten.add.Tensor(exp_16, 1); exp_16 = None + reciprocal_16 = torch.ops.aten.reciprocal.default(add_241); add_241 = None + mul_586 = torch.ops.aten.mul.Tensor(reciprocal_16, 1); reciprocal_16 = None + mul_587 = torch.ops.aten.mul.Tensor(convert_element_type_1948, mul_586); convert_element_type_1948 = None + sub_50 = torch.ops.aten.sub.Tensor(1, mul_586); mul_586 = None + mul_588 = torch.ops.aten.mul.Tensor(convert_element_type_521, sub_50); convert_element_type_521 = sub_50 = None + add_242 = torch.ops.aten.add.Tensor(mul_588, 1); mul_588 = None + mul_589 = torch.ops.aten.mul.Tensor(mul_587, add_242); mul_587 = add_242 = None + convert_element_type_1950 = torch.ops.prims.convert_element_type.default(mul_589, torch.bfloat16); mul_589 = None + view_2711 = torch.ops.aten.view.default(convert_element_type_1950, [16384, 1792]); convert_element_type_1950 = None + permute_877 = torch.ops.aten.permute.default(view_2711, [1, 0]) + mm_455 = torch.ops.aten.mm.default(permute_877, view_1140); permute_877 = view_1140 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16); primals_145 = None + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 32, '0'); convert_element_type_518 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + permute_879 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_456 = torch.ops.aten.mm.default(view_2711, permute_879); view_2711 = permute_879 = None + view_2712 = torch.ops.aten.view.default(mm_456, [2, 8192, 4096]); mm_456 = None + add_243 = torch.ops.aten.add.Tensor(view_2710, view_2712); view_2710 = view_2712 = None + convert_element_type_1955 = torch.ops.prims.convert_element_type.default(mm_455, torch.float32); mm_455 = None + reduce_scatter_tensor_246 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1955, 'avg', 32, '0'); convert_element_type_1955 = None + wait_tensor_668 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_246); reduce_scatter_tensor_246 = None + split_204 = torch.ops.aten.split.Tensor(add_243, 1024, 1); add_243 = None + getitem_1968 = split_204[0] + getitem_1969 = split_204[1] + getitem_1970 = split_204[2] + getitem_1971 = split_204[3] + getitem_1972 = split_204[4] + getitem_1973 = split_204[5] + getitem_1974 = split_204[6] + getitem_1975 = split_204[7]; split_204 = None + cat_196 = torch.ops.aten.cat.default([getitem_1968, getitem_1969, getitem_1970, getitem_1971, getitem_1972, getitem_1973, getitem_1974, getitem_1975]); getitem_1968 = getitem_1969 = getitem_1970 = getitem_1971 = getitem_1972 = getitem_1973 = getitem_1974 = getitem_1975 = None + reduce_scatter_tensor_247 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_196, 'sum', 8, '1'); cat_196 = None + wait_tensor_669 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_247); reduce_scatter_tensor_247 = None + convert_element_type_1956 = torch.ops.prims.convert_element_type.default(wait_tensor_669, torch.float32); wait_tensor_669 = None + convert_element_type_1958 = torch.ops.prims.convert_element_type.default(wait_tensor_204, torch.float32); wait_tensor_204 = None + mul_590 = torch.ops.aten.mul.Tensor(convert_element_type_1956, convert_element_type_1958); convert_element_type_1958 = None + mul_592 = torch.ops.aten.mul.Tensor(mul_124, mul_590) + sum_99 = torch.ops.aten.sum.dim_IntList(mul_592, [2], True); mul_592 = None + div_33 = torch.ops.aten.div.Tensor(mul_124, 4096) + mul_593 = torch.ops.aten.mul.Tensor(div_33, sum_99); div_33 = sum_99 = None + sub_51 = torch.ops.aten.sub.Tensor(mul_590, mul_593); mul_590 = mul_593 = None + mul_594 = torch.ops.aten.mul.Tensor(sub_51, rsqrt_31); sub_51 = rsqrt_31 = None + mul_595 = torch.ops.aten.mul.Tensor(convert_element_type_1956, mul_124); convert_element_type_1956 = mul_124 = None + sum_100 = torch.ops.aten.sum.dim_IntList(mul_595, [0, 1]); mul_595 = None + convert_element_type_1959 = torch.ops.prims.convert_element_type.default(mul_594, torch.bfloat16); mul_594 = None + convert_element_type_1960 = torch.ops.prims.convert_element_type.default(sum_100, torch.bfloat16); sum_100 = None + all_reduce_33 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1960, 'sum', '1'); convert_element_type_1960 = None + wait_tensor_670 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_33); all_reduce_33 = None + convert_element_type_1961 = torch.ops.prims.convert_element_type.default(wait_tensor_670, torch.float32); wait_tensor_670 = None + reduce_scatter_tensor_248 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1961, 'avg', 32, '0'); convert_element_type_1961 = None + wait_tensor_671 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_248); reduce_scatter_tensor_248 = None + add_244 = torch.ops.aten.add.Tensor(add_240, convert_element_type_1959); add_240 = convert_element_type_1959 = None + all_gather_into_tensor_389 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_244, 8, '1') + wait_tensor_672 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_389); all_gather_into_tensor_389 = None + split_205 = torch.ops.aten.split.Tensor(wait_tensor_672, 2); wait_tensor_672 = None + getitem_1976 = split_205[0] + getitem_1977 = split_205[1] + getitem_1978 = split_205[2] + getitem_1979 = split_205[3] + getitem_1980 = split_205[4] + getitem_1981 = split_205[5] + getitem_1982 = split_205[6] + getitem_1983 = split_205[7]; split_205 = None + cat_197 = torch.ops.aten.cat.default([getitem_1976, getitem_1977, getitem_1978, getitem_1979, getitem_1980, getitem_1981, getitem_1982, getitem_1983], 1); getitem_1976 = getitem_1977 = getitem_1978 = getitem_1979 = getitem_1980 = getitem_1981 = getitem_1982 = getitem_1983 = None + view_2713 = torch.ops.aten.view.default(cat_197, [16384, 4096]); cat_197 = None + permute_881 = torch.ops.aten.permute.default(view_2713, [1, 0]) + permute_171 = torch.ops.aten.permute.default(getitem_695, [0, 2, 1, 3]) + view_1122 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + view_1128 = torch.ops.aten.view.default(view_1122, [16384, 512]); view_1122 = None + mm_457 = torch.ops.aten.mm.default(permute_881, view_1128); permute_881 = view_1128 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 32, '0'); convert_element_type_512 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + permute_883 = torch.ops.aten.permute.default(permute_172, [1, 0]); permute_172 = None + mm_458 = torch.ops.aten.mm.default(view_2713, permute_883); view_2713 = permute_883 = None + view_2714 = torch.ops.aten.view.default(mm_458, [2, 8192, 512]); mm_458 = None + convert_element_type_1966 = torch.ops.prims.convert_element_type.default(mm_457, torch.float32); mm_457 = None + reduce_scatter_tensor_249 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1966, 'avg', 32, '0'); convert_element_type_1966 = None + wait_tensor_673 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_249); reduce_scatter_tensor_249 = None + view_2715 = torch.ops.aten.view.default(view_2714, [2, 8192, 4, 128]); view_2714 = None + permute_885 = torch.ops.aten.permute.default(view_2715, [0, 2, 1, 3]); view_2715 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 32, '0'); convert_element_type_496 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32); add_59 = None + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_197) + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_498, 8, '1'); convert_element_type_498 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + split_69 = torch.ops.aten.split.Tensor(wait_tensor_198, 2); wait_tensor_198 = None + getitem_687 = split_69[0] + getitem_688 = split_69[1] + getitem_689 = split_69[2] + getitem_690 = split_69[3] + getitem_691 = split_69[4] + getitem_692 = split_69[5] + getitem_693 = split_69[6] + getitem_694 = split_69[7]; split_69 = None + cat_61 = torch.ops.aten.cat.default([getitem_687, getitem_688, getitem_689, getitem_690, getitem_691, getitem_692, getitem_693, getitem_694], 1); getitem_687 = getitem_688 = getitem_689 = getitem_690 = getitem_691 = getitem_692 = getitem_693 = getitem_694 = None + view_1095 = torch.ops.aten.view.default(cat_61, [16384, 4096]); cat_61 = None + view_1096 = torch.ops.aten.view.default(mm_105, [2, 8192, 512]); mm_105 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 32, '0'); convert_element_type_502 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + mm_106 = torch.ops.aten.mm.default(view_1095, permute_166) + view_1103 = torch.ops.aten.view.default(mm_106, [2, 8192, 128]); mm_106 = None + view_1110 = torch.ops.aten.view.default(mm_107, [2, 8192, 128]); mm_107 = None + view_1112 = torch.ops.aten.view.default(view_1096, [2, 8192, -1, 128]); view_1096 = None + view_1113 = torch.ops.aten.view.default(view_1103, [2, 8192, -1, 128]); view_1103 = None + view_1114 = torch.ops.aten.view.default(view_1110, [2, 8192, -1, 128]); view_1110 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_1112, torch.float32); view_1112 = None + view_1115 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 4, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_1115); view_1115 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_1113, torch.float32); view_1113 = None + view_1116 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 1, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_1116); view_1116 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_37); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_1118 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 4, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_37); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_1119 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 1, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_1118, torch.bfloat16); view_1118 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 1, 4, 128]); unsqueeze_30 = None + view_1120 = torch.ops.aten.view.default(expand_30, [2, 8192, 4, 128]); expand_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_1114, 3); view_1114 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 1, 4, 128]); unsqueeze_31 = None + view_1121 = torch.ops.aten.view.default(expand_31, [2, 8192, 4, 128]); expand_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_1120, [0, 2, 1, 3]); view_1120 = None + permute_170 = torch.ops.aten.permute.default(view_1121, [0, 2, 1, 3]); view_1121 = None + _scaled_dot_product_cudnn_attention_backward_16 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_885, permute_168, permute_169, permute_170, getitem_695, getitem_696, getitem_701, getitem_702, None, None, None, 8192, 8192, 0.0, True); permute_885 = permute_168 = permute_169 = permute_170 = getitem_695 = getitem_696 = getitem_701 = getitem_702 = None + getitem_1984 = _scaled_dot_product_cudnn_attention_backward_16[0] + getitem_1985 = _scaled_dot_product_cudnn_attention_backward_16[1] + getitem_1986 = _scaled_dot_product_cudnn_attention_backward_16[2]; _scaled_dot_product_cudnn_attention_backward_16 = None + permute_886 = torch.ops.aten.permute.default(getitem_1986, [0, 2, 1, 3]); getitem_1986 = None + permute_887 = torch.ops.aten.permute.default(getitem_1985, [0, 2, 1, 3]); getitem_1985 = None + permute_888 = torch.ops.aten.permute.default(getitem_1984, [0, 2, 1, 3]); getitem_1984 = None + view_2716 = torch.ops.aten.view.default(permute_886, [2, 8192, 1, 4, 128]); permute_886 = None + sum_101 = torch.ops.aten.sum.dim_IntList(view_2716, [3], True); view_2716 = None + squeeze_32 = torch.ops.aten.squeeze.dim(sum_101, 3); sum_101 = None + view_2717 = torch.ops.aten.view.default(permute_887, [2, 8192, 1, 4, 128]); permute_887 = None + sum_102 = torch.ops.aten.sum.dim_IntList(view_2717, [3], True); view_2717 = None + squeeze_33 = torch.ops.aten.squeeze.dim(sum_102, 3); sum_102 = None + convert_element_type_1967 = torch.ops.prims.convert_element_type.default(squeeze_33, torch.float32); squeeze_33 = None + convert_element_type_1968 = torch.ops.prims.convert_element_type.default(permute_888, torch.float32); permute_888 = None + view_2718 = torch.ops.aten.view.default(convert_element_type_1967, [2, 8192, 1, 64, 2]); convert_element_type_1967 = None + view_as_complex_96 = torch.ops.aten.view_as_complex.default(view_2718); view_2718 = None + mul_596 = torch.ops.aten.mul.Tensor(view_as_complex_96, _conj); view_as_complex_96 = None + view_2719 = torch.ops.aten.view.default(convert_element_type_1968, [2, 8192, 4, 64, 2]); convert_element_type_1968 = None + view_as_complex_97 = torch.ops.aten.view_as_complex.default(view_2719); view_2719 = None + mul_597 = torch.ops.aten.mul.Tensor(view_as_complex_97, _conj); view_as_complex_97 = None + view_as_real_96 = torch.ops.aten.view_as_real.default(mul_596); mul_596 = None + view_2720 = torch.ops.aten.view.default(view_as_real_96, [2, 8192, 1, 128]); view_as_real_96 = None + convert_element_type_1969 = torch.ops.prims.convert_element_type.default(view_2720, torch.bfloat16); view_2720 = None + view_as_real_97 = torch.ops.aten.view_as_real.default(mul_597); mul_597 = None + view_2721 = torch.ops.aten.view.default(view_as_real_97, [2, 8192, 4, 128]); view_as_real_97 = None + convert_element_type_1970 = torch.ops.prims.convert_element_type.default(view_2721, torch.bfloat16); view_2721 = None + view_2722 = torch.ops.aten.view.default(squeeze_32, [2, 8192, 128]); squeeze_32 = None + view_2723 = torch.ops.aten.view.default(convert_element_type_1969, [2, 8192, 128]); convert_element_type_1969 = None + view_2724 = torch.ops.aten.view.default(convert_element_type_1970, [2, 8192, 512]); convert_element_type_1970 = None + view_2725 = torch.ops.aten.view.default(view_2722, [16384, 128]); view_2722 = None + permute_889 = torch.ops.aten.permute.default(view_2725, [1, 0]) + mm_459 = torch.ops.aten.mm.default(permute_889, view_1095); permute_889 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 32, '0'); convert_element_type_505 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + permute_891 = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None + mm_460 = torch.ops.aten.mm.default(view_2725, permute_891); view_2725 = permute_891 = None + view_2726 = torch.ops.aten.view.default(mm_460, [2, 8192, 4096]); mm_460 = None + convert_element_type_1975 = torch.ops.prims.convert_element_type.default(mm_459, torch.float32); mm_459 = None + reduce_scatter_tensor_250 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1975, 'avg', 32, '0'); convert_element_type_1975 = None + wait_tensor_674 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_250); reduce_scatter_tensor_250 = None + view_2727 = torch.ops.aten.view.default(view_2723, [16384, 128]); view_2723 = None + permute_893 = torch.ops.aten.permute.default(view_2727, [1, 0]) + mm_461 = torch.ops.aten.mm.default(permute_893, view_1095); permute_893 = None + permute_895 = torch.ops.aten.permute.default(permute_166, [1, 0]); permute_166 = None + mm_462 = torch.ops.aten.mm.default(view_2727, permute_895); view_2727 = permute_895 = None + view_2728 = torch.ops.aten.view.default(mm_462, [2, 8192, 4096]); mm_462 = None + add_245 = torch.ops.aten.add.Tensor(view_2726, view_2728); view_2726 = view_2728 = None + convert_element_type_1980 = torch.ops.prims.convert_element_type.default(mm_461, torch.float32); mm_461 = None + reduce_scatter_tensor_251 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1980, 'avg', 32, '0'); convert_element_type_1980 = None + wait_tensor_675 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_251); reduce_scatter_tensor_251 = None + view_2729 = torch.ops.aten.view.default(view_2724, [16384, 512]); view_2724 = None + permute_897 = torch.ops.aten.permute.default(view_2729, [1, 0]) + mm_463 = torch.ops.aten.mm.default(permute_897, view_1095); permute_897 = view_1095 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 32, '0'); convert_element_type_499 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + permute_899 = torch.ops.aten.permute.default(permute_165, [1, 0]); permute_165 = None + mm_464 = torch.ops.aten.mm.default(view_2729, permute_899); view_2729 = permute_899 = None + view_2730 = torch.ops.aten.view.default(mm_464, [2, 8192, 4096]); mm_464 = None + add_246 = torch.ops.aten.add.Tensor(add_245, view_2730); add_245 = view_2730 = None + convert_element_type_1985 = torch.ops.prims.convert_element_type.default(mm_463, torch.float32); mm_463 = None + reduce_scatter_tensor_252 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1985, 'avg', 32, '0'); convert_element_type_1985 = None + wait_tensor_676 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_252); reduce_scatter_tensor_252 = None + split_206 = torch.ops.aten.split.Tensor(add_246, 1024, 1); add_246 = None + getitem_1987 = split_206[0] + getitem_1988 = split_206[1] + getitem_1989 = split_206[2] + getitem_1990 = split_206[3] + getitem_1991 = split_206[4] + getitem_1992 = split_206[5] + getitem_1993 = split_206[6] + getitem_1994 = split_206[7]; split_206 = None + cat_198 = torch.ops.aten.cat.default([getitem_1987, getitem_1988, getitem_1989, getitem_1990, getitem_1991, getitem_1992, getitem_1993, getitem_1994]); getitem_1987 = getitem_1988 = getitem_1989 = getitem_1990 = getitem_1991 = getitem_1992 = getitem_1993 = getitem_1994 = None + reduce_scatter_tensor_253 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_198, 'sum', 8, '1'); cat_198 = None + wait_tensor_677 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_253); reduce_scatter_tensor_253 = None + convert_element_type_1986 = torch.ops.prims.convert_element_type.default(wait_tensor_677, torch.float32); wait_tensor_677 = None + convert_element_type_1988 = torch.ops.prims.convert_element_type.default(wait_tensor_197, torch.float32); wait_tensor_197 = None + mul_598 = torch.ops.aten.mul.Tensor(convert_element_type_1986, convert_element_type_1988); convert_element_type_1988 = None + mul_600 = torch.ops.aten.mul.Tensor(mul_120, mul_598) + sum_103 = torch.ops.aten.sum.dim_IntList(mul_600, [2], True); mul_600 = None + div_34 = torch.ops.aten.div.Tensor(mul_120, 4096) + mul_601 = torch.ops.aten.mul.Tensor(div_34, sum_103); div_34 = sum_103 = None + sub_52 = torch.ops.aten.sub.Tensor(mul_598, mul_601); mul_598 = mul_601 = None + mul_602 = torch.ops.aten.mul.Tensor(sub_52, rsqrt_30); sub_52 = rsqrt_30 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_1986, mul_120); convert_element_type_1986 = mul_120 = None + sum_104 = torch.ops.aten.sum.dim_IntList(mul_603, [0, 1]); mul_603 = None + convert_element_type_1989 = torch.ops.prims.convert_element_type.default(mul_602, torch.bfloat16); mul_602 = None + convert_element_type_1990 = torch.ops.prims.convert_element_type.default(sum_104, torch.bfloat16); sum_104 = None + all_reduce_34 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1990, 'sum', '1'); convert_element_type_1990 = None + wait_tensor_678 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_34); all_reduce_34 = None + convert_element_type_1991 = torch.ops.prims.convert_element_type.default(wait_tensor_678, torch.float32); wait_tensor_678 = None + reduce_scatter_tensor_254 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1991, 'avg', 32, '0'); convert_element_type_1991 = None + wait_tensor_679 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_254); reduce_scatter_tensor_254 = None + add_247 = torch.ops.aten.add.Tensor(add_244, convert_element_type_1989); add_244 = convert_element_type_1989 = None + all_gather_into_tensor_390 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_247, 8, '1') + wait_tensor_680 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_390); all_gather_into_tensor_390 = None + split_207 = torch.ops.aten.split.Tensor(wait_tensor_680, 2); wait_tensor_680 = None + getitem_1995 = split_207[0] + getitem_1996 = split_207[1] + getitem_1997 = split_207[2] + getitem_1998 = split_207[3] + getitem_1999 = split_207[4] + getitem_2000 = split_207[5] + getitem_2001 = split_207[6] + getitem_2002 = split_207[7]; split_207 = None + cat_199 = torch.ops.aten.cat.default([getitem_1995, getitem_1996, getitem_1997, getitem_1998, getitem_1999, getitem_2000, getitem_2001, getitem_2002], 1); getitem_1995 = getitem_1996 = getitem_1997 = getitem_1998 = getitem_1999 = getitem_2000 = getitem_2001 = getitem_2002 = None + view_2731 = torch.ops.aten.view.default(cat_199, [16384, 4096]); cat_199 = None + permute_901 = torch.ops.aten.permute.default(view_2731, [1, 0]) + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + add_57 = torch.ops.aten.add.Tensor(add_55, wait_tensor_190); wait_tensor_190 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 32, '0'); convert_element_type_482 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32); add_57 = None + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_191) + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_484, 8, '1'); convert_element_type_484 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + split_67 = torch.ops.aten.split.Tensor(wait_tensor_192, 2); wait_tensor_192 = None + getitem_671 = split_67[0] + getitem_672 = split_67[1] + getitem_673 = split_67[2] + getitem_674 = split_67[3] + getitem_675 = split_67[4] + getitem_676 = split_67[5] + getitem_677 = split_67[6] + getitem_678 = split_67[7]; split_67 = None + cat_59 = torch.ops.aten.cat.default([getitem_671, getitem_672, getitem_673, getitem_674, getitem_675, getitem_676, getitem_677, getitem_678], 1); getitem_671 = getitem_672 = getitem_673 = getitem_674 = getitem_675 = getitem_676 = getitem_677 = getitem_678 = None + view_1068 = torch.ops.aten.view.default(cat_59, [16384, 4096]); cat_59 = None + view_1069 = torch.ops.aten.view.default(mm_102, [2, 8192, 1792]); mm_102 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_1069, torch.float32); view_1069 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 32, '0'); convert_element_type_490 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + mm_103 = torch.ops.aten.mm.default(view_1068, permute_163) + view_1076 = torch.ops.aten.view.default(mm_103, [2, 8192, 1792]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_1076) + view_1083 = torch.ops.aten.view.default(mul_119, [16384, 1792]); mul_119 = None + mm_465 = torch.ops.aten.mm.default(permute_901, view_1083); permute_901 = view_1083 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 32, '0'); convert_element_type_493 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_195, [1, 0]); wait_tensor_195 = None + permute_903 = torch.ops.aten.permute.default(permute_164, [1, 0]); permute_164 = None + mm_466 = torch.ops.aten.mm.default(view_2731, permute_903); view_2731 = permute_903 = None + view_2732 = torch.ops.aten.view.default(mm_466, [2, 8192, 1792]); mm_466 = None + convert_element_type_1996 = torch.ops.prims.convert_element_type.default(mm_465, torch.float32); mm_465 = None + reduce_scatter_tensor_255 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1996, 'avg', 32, '0'); convert_element_type_1996 = None + wait_tensor_681 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_255); reduce_scatter_tensor_255 = None + mul_604 = torch.ops.aten.mul.Tensor(view_2732, convert_element_type_489); convert_element_type_489 = None + mul_605 = torch.ops.aten.mul.Tensor(view_2732, view_1076); view_2732 = view_1076 = None + view_2733 = torch.ops.aten.view.default(mul_604, [16384, 1792]); mul_604 = None + permute_905 = torch.ops.aten.permute.default(view_2733, [1, 0]) + mm_467 = torch.ops.aten.mm.default(permute_905, view_1068); permute_905 = None + permute_907 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_468 = torch.ops.aten.mm.default(view_2733, permute_907); view_2733 = permute_907 = None + view_2734 = torch.ops.aten.view.default(mm_468, [2, 8192, 4096]); mm_468 = None + convert_element_type_2001 = torch.ops.prims.convert_element_type.default(mm_467, torch.float32); mm_467 = None + reduce_scatter_tensor_256 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2001, 'avg', 32, '0'); convert_element_type_2001 = None + wait_tensor_682 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_256); reduce_scatter_tensor_256 = None + convert_element_type_2002 = torch.ops.prims.convert_element_type.default(mul_605, torch.float32); mul_605 = None + neg_17 = torch.ops.aten.neg.default(convert_element_type_488) + exp_17 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_248 = torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + reciprocal_17 = torch.ops.aten.reciprocal.default(add_248); add_248 = None + mul_606 = torch.ops.aten.mul.Tensor(reciprocal_17, 1); reciprocal_17 = None + mul_607 = torch.ops.aten.mul.Tensor(convert_element_type_2002, mul_606); convert_element_type_2002 = None + sub_53 = torch.ops.aten.sub.Tensor(1, mul_606); mul_606 = None + mul_608 = torch.ops.aten.mul.Tensor(convert_element_type_488, sub_53); convert_element_type_488 = sub_53 = None + add_249 = torch.ops.aten.add.Tensor(mul_608, 1); mul_608 = None + mul_609 = torch.ops.aten.mul.Tensor(mul_607, add_249); mul_607 = add_249 = None + convert_element_type_2004 = torch.ops.prims.convert_element_type.default(mul_609, torch.bfloat16); mul_609 = None + view_2735 = torch.ops.aten.view.default(convert_element_type_2004, [16384, 1792]); convert_element_type_2004 = None + permute_909 = torch.ops.aten.permute.default(view_2735, [1, 0]) + mm_469 = torch.ops.aten.mm.default(permute_909, view_1068); permute_909 = view_1068 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 32, '0'); convert_element_type_485 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + permute_911 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_470 = torch.ops.aten.mm.default(view_2735, permute_911); view_2735 = permute_911 = None + view_2736 = torch.ops.aten.view.default(mm_470, [2, 8192, 4096]); mm_470 = None + add_250 = torch.ops.aten.add.Tensor(view_2734, view_2736); view_2734 = view_2736 = None + convert_element_type_2009 = torch.ops.prims.convert_element_type.default(mm_469, torch.float32); mm_469 = None + reduce_scatter_tensor_257 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2009, 'avg', 32, '0'); convert_element_type_2009 = None + wait_tensor_683 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_257); reduce_scatter_tensor_257 = None + split_208 = torch.ops.aten.split.Tensor(add_250, 1024, 1); add_250 = None + getitem_2003 = split_208[0] + getitem_2004 = split_208[1] + getitem_2005 = split_208[2] + getitem_2006 = split_208[3] + getitem_2007 = split_208[4] + getitem_2008 = split_208[5] + getitem_2009 = split_208[6] + getitem_2010 = split_208[7]; split_208 = None + cat_200 = torch.ops.aten.cat.default([getitem_2003, getitem_2004, getitem_2005, getitem_2006, getitem_2007, getitem_2008, getitem_2009, getitem_2010]); getitem_2003 = getitem_2004 = getitem_2005 = getitem_2006 = getitem_2007 = getitem_2008 = getitem_2009 = getitem_2010 = None + reduce_scatter_tensor_258 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_200, 'sum', 8, '1'); cat_200 = None + wait_tensor_684 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_258); reduce_scatter_tensor_258 = None + convert_element_type_2010 = torch.ops.prims.convert_element_type.default(wait_tensor_684, torch.float32); wait_tensor_684 = None + convert_element_type_2012 = torch.ops.prims.convert_element_type.default(wait_tensor_191, torch.float32); wait_tensor_191 = None + mul_610 = torch.ops.aten.mul.Tensor(convert_element_type_2010, convert_element_type_2012); convert_element_type_2012 = None + mul_612 = torch.ops.aten.mul.Tensor(mul_116, mul_610) + sum_105 = torch.ops.aten.sum.dim_IntList(mul_612, [2], True); mul_612 = None + div_35 = torch.ops.aten.div.Tensor(mul_116, 4096) + mul_613 = torch.ops.aten.mul.Tensor(div_35, sum_105); div_35 = sum_105 = None + sub_54 = torch.ops.aten.sub.Tensor(mul_610, mul_613); mul_610 = mul_613 = None + mul_614 = torch.ops.aten.mul.Tensor(sub_54, rsqrt_29); sub_54 = rsqrt_29 = None + mul_615 = torch.ops.aten.mul.Tensor(convert_element_type_2010, mul_116); convert_element_type_2010 = mul_116 = None + sum_106 = torch.ops.aten.sum.dim_IntList(mul_615, [0, 1]); mul_615 = None + convert_element_type_2013 = torch.ops.prims.convert_element_type.default(mul_614, torch.bfloat16); mul_614 = None + convert_element_type_2014 = torch.ops.prims.convert_element_type.default(sum_106, torch.bfloat16); sum_106 = None + all_reduce_35 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2014, 'sum', '1'); convert_element_type_2014 = None + wait_tensor_685 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_35); all_reduce_35 = None + convert_element_type_2015 = torch.ops.prims.convert_element_type.default(wait_tensor_685, torch.float32); wait_tensor_685 = None + reduce_scatter_tensor_259 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2015, 'avg', 32, '0'); convert_element_type_2015 = None + wait_tensor_686 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_259); reduce_scatter_tensor_259 = None + add_251 = torch.ops.aten.add.Tensor(add_247, convert_element_type_2013); add_247 = convert_element_type_2013 = None + all_gather_into_tensor_391 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_251, 8, '1') + wait_tensor_687 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_391); all_gather_into_tensor_391 = None + split_209 = torch.ops.aten.split.Tensor(wait_tensor_687, 2); wait_tensor_687 = None + getitem_2011 = split_209[0] + getitem_2012 = split_209[1] + getitem_2013 = split_209[2] + getitem_2014 = split_209[3] + getitem_2015 = split_209[4] + getitem_2016 = split_209[5] + getitem_2017 = split_209[6] + getitem_2018 = split_209[7]; split_209 = None + cat_201 = torch.ops.aten.cat.default([getitem_2011, getitem_2012, getitem_2013, getitem_2014, getitem_2015, getitem_2016, getitem_2017, getitem_2018], 1); getitem_2011 = getitem_2012 = getitem_2013 = getitem_2014 = getitem_2015 = getitem_2016 = getitem_2017 = getitem_2018 = None + view_2737 = torch.ops.aten.view.default(cat_201, [16384, 4096]); cat_201 = None + permute_913 = torch.ops.aten.permute.default(view_2737, [1, 0]) + permute_160 = torch.ops.aten.permute.default(getitem_654, [0, 2, 1, 3]) + view_1050 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + view_1056 = torch.ops.aten.view.default(view_1050, [16384, 512]); view_1050 = None + mm_471 = torch.ops.aten.mm.default(permute_913, view_1056); permute_913 = view_1056 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 32, '0'); convert_element_type_479 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + permute_915 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_472 = torch.ops.aten.mm.default(view_2737, permute_915); view_2737 = permute_915 = None + view_2738 = torch.ops.aten.view.default(mm_472, [2, 8192, 512]); mm_472 = None + convert_element_type_2020 = torch.ops.prims.convert_element_type.default(mm_471, torch.float32); mm_471 = None + reduce_scatter_tensor_260 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2020, 'avg', 32, '0'); convert_element_type_2020 = None + wait_tensor_688 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_260); reduce_scatter_tensor_260 = None + view_2739 = torch.ops.aten.view.default(view_2738, [2, 8192, 4, 128]); view_2738 = None + permute_917 = torch.ops.aten.permute.default(view_2739, [0, 2, 1, 3]); view_2739 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16); primals_130 = None + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 32, '0'); convert_element_type_463 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32); add_55 = None + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_184) + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_465, 8, '1'); convert_element_type_465 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + split_65 = torch.ops.aten.split.Tensor(wait_tensor_185, 2); wait_tensor_185 = None + getitem_646 = split_65[0] + getitem_647 = split_65[1] + getitem_648 = split_65[2] + getitem_649 = split_65[3] + getitem_650 = split_65[4] + getitem_651 = split_65[5] + getitem_652 = split_65[6] + getitem_653 = split_65[7]; split_65 = None + cat_57 = torch.ops.aten.cat.default([getitem_646, getitem_647, getitem_648, getitem_649, getitem_650, getitem_651, getitem_652, getitem_653], 1); getitem_646 = getitem_647 = getitem_648 = getitem_649 = getitem_650 = getitem_651 = getitem_652 = getitem_653 = None + view_1023 = torch.ops.aten.view.default(cat_57, [16384, 4096]); cat_57 = None + view_1024 = torch.ops.aten.view.default(mm_98, [2, 8192, 512]); mm_98 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 32, '0'); convert_element_type_469 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + mm_99 = torch.ops.aten.mm.default(view_1023, permute_155) + view_1031 = torch.ops.aten.view.default(mm_99, [2, 8192, 128]); mm_99 = None + view_1038 = torch.ops.aten.view.default(mm_100, [2, 8192, 128]); mm_100 = None + view_1040 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1041 = torch.ops.aten.view.default(view_1031, [2, 8192, -1, 128]); view_1031 = None + view_1042 = torch.ops.aten.view.default(view_1038, [2, 8192, -1, 128]); view_1038 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_1040, torch.float32); view_1040 = None + view_1043 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 4, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_1043); view_1043 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_1041, torch.float32); view_1041 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 1, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_37); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_1046 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 4, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_37); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_1047 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 1, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_1047, torch.bfloat16); view_1047 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 1, 4, 128]); unsqueeze_28 = None + view_1048 = torch.ops.aten.view.default(expand_28, [2, 8192, 4, 128]); expand_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_1042, 3); view_1042 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 1, 4, 128]); unsqueeze_29 = None + view_1049 = torch.ops.aten.view.default(expand_29, [2, 8192, 4, 128]); expand_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_1048, [0, 2, 1, 3]); view_1048 = None + permute_159 = torch.ops.aten.permute.default(view_1049, [0, 2, 1, 3]); view_1049 = None + _scaled_dot_product_cudnn_attention_backward_17 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_917, permute_157, permute_158, permute_159, getitem_654, getitem_655, getitem_660, getitem_661, None, None, None, 8192, 8192, 0.0, True); permute_917 = permute_157 = permute_158 = permute_159 = getitem_654 = getitem_655 = getitem_660 = getitem_661 = None + getitem_2019 = _scaled_dot_product_cudnn_attention_backward_17[0] + getitem_2020 = _scaled_dot_product_cudnn_attention_backward_17[1] + getitem_2021 = _scaled_dot_product_cudnn_attention_backward_17[2]; _scaled_dot_product_cudnn_attention_backward_17 = None + permute_918 = torch.ops.aten.permute.default(getitem_2021, [0, 2, 1, 3]); getitem_2021 = None + permute_919 = torch.ops.aten.permute.default(getitem_2020, [0, 2, 1, 3]); getitem_2020 = None + permute_920 = torch.ops.aten.permute.default(getitem_2019, [0, 2, 1, 3]); getitem_2019 = None + view_2740 = torch.ops.aten.view.default(permute_918, [2, 8192, 1, 4, 128]); permute_918 = None + sum_107 = torch.ops.aten.sum.dim_IntList(view_2740, [3], True); view_2740 = None + squeeze_34 = torch.ops.aten.squeeze.dim(sum_107, 3); sum_107 = None + view_2741 = torch.ops.aten.view.default(permute_919, [2, 8192, 1, 4, 128]); permute_919 = None + sum_108 = torch.ops.aten.sum.dim_IntList(view_2741, [3], True); view_2741 = None + squeeze_35 = torch.ops.aten.squeeze.dim(sum_108, 3); sum_108 = None + convert_element_type_2021 = torch.ops.prims.convert_element_type.default(squeeze_35, torch.float32); squeeze_35 = None + convert_element_type_2022 = torch.ops.prims.convert_element_type.default(permute_920, torch.float32); permute_920 = None + view_2742 = torch.ops.aten.view.default(convert_element_type_2021, [2, 8192, 1, 64, 2]); convert_element_type_2021 = None + view_as_complex_98 = torch.ops.aten.view_as_complex.default(view_2742); view_2742 = None + mul_616 = torch.ops.aten.mul.Tensor(view_as_complex_98, _conj); view_as_complex_98 = None + view_2743 = torch.ops.aten.view.default(convert_element_type_2022, [2, 8192, 4, 64, 2]); convert_element_type_2022 = None + view_as_complex_99 = torch.ops.aten.view_as_complex.default(view_2743); view_2743 = None + mul_617 = torch.ops.aten.mul.Tensor(view_as_complex_99, _conj); view_as_complex_99 = None + view_as_real_98 = torch.ops.aten.view_as_real.default(mul_616); mul_616 = None + view_2744 = torch.ops.aten.view.default(view_as_real_98, [2, 8192, 1, 128]); view_as_real_98 = None + convert_element_type_2023 = torch.ops.prims.convert_element_type.default(view_2744, torch.bfloat16); view_2744 = None + view_as_real_99 = torch.ops.aten.view_as_real.default(mul_617); mul_617 = None + view_2745 = torch.ops.aten.view.default(view_as_real_99, [2, 8192, 4, 128]); view_as_real_99 = None + convert_element_type_2024 = torch.ops.prims.convert_element_type.default(view_2745, torch.bfloat16); view_2745 = None + view_2746 = torch.ops.aten.view.default(squeeze_34, [2, 8192, 128]); squeeze_34 = None + view_2747 = torch.ops.aten.view.default(convert_element_type_2023, [2, 8192, 128]); convert_element_type_2023 = None + view_2748 = torch.ops.aten.view.default(convert_element_type_2024, [2, 8192, 512]); convert_element_type_2024 = None + view_2749 = torch.ops.aten.view.default(view_2746, [16384, 128]); view_2746 = None + permute_921 = torch.ops.aten.permute.default(view_2749, [1, 0]) + mm_473 = torch.ops.aten.mm.default(permute_921, view_1023); permute_921 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 32, '0'); convert_element_type_472 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + permute_923 = torch.ops.aten.permute.default(permute_156, [1, 0]); permute_156 = None + mm_474 = torch.ops.aten.mm.default(view_2749, permute_923); view_2749 = permute_923 = None + view_2750 = torch.ops.aten.view.default(mm_474, [2, 8192, 4096]); mm_474 = None + convert_element_type_2029 = torch.ops.prims.convert_element_type.default(mm_473, torch.float32); mm_473 = None + reduce_scatter_tensor_261 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2029, 'avg', 32, '0'); convert_element_type_2029 = None + wait_tensor_689 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_261); reduce_scatter_tensor_261 = None + view_2751 = torch.ops.aten.view.default(view_2747, [16384, 128]); view_2747 = None + permute_925 = torch.ops.aten.permute.default(view_2751, [1, 0]) + mm_475 = torch.ops.aten.mm.default(permute_925, view_1023); permute_925 = None + permute_927 = torch.ops.aten.permute.default(permute_155, [1, 0]); permute_155 = None + mm_476 = torch.ops.aten.mm.default(view_2751, permute_927); view_2751 = permute_927 = None + view_2752 = torch.ops.aten.view.default(mm_476, [2, 8192, 4096]); mm_476 = None + add_252 = torch.ops.aten.add.Tensor(view_2750, view_2752); view_2750 = view_2752 = None + convert_element_type_2034 = torch.ops.prims.convert_element_type.default(mm_475, torch.float32); mm_475 = None + reduce_scatter_tensor_262 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2034, 'avg', 32, '0'); convert_element_type_2034 = None + wait_tensor_690 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_262); reduce_scatter_tensor_262 = None + view_2753 = torch.ops.aten.view.default(view_2748, [16384, 512]); view_2748 = None + permute_929 = torch.ops.aten.permute.default(view_2753, [1, 0]) + mm_477 = torch.ops.aten.mm.default(permute_929, view_1023); permute_929 = view_1023 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16); primals_131 = None + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 32, '0'); convert_element_type_466 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + permute_931 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_478 = torch.ops.aten.mm.default(view_2753, permute_931); view_2753 = permute_931 = None + view_2754 = torch.ops.aten.view.default(mm_478, [2, 8192, 4096]); mm_478 = None + add_253 = torch.ops.aten.add.Tensor(add_252, view_2754); add_252 = view_2754 = None + convert_element_type_2039 = torch.ops.prims.convert_element_type.default(mm_477, torch.float32); mm_477 = None + reduce_scatter_tensor_263 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2039, 'avg', 32, '0'); convert_element_type_2039 = None + wait_tensor_691 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_263); reduce_scatter_tensor_263 = None + split_210 = torch.ops.aten.split.Tensor(add_253, 1024, 1); add_253 = None + getitem_2022 = split_210[0] + getitem_2023 = split_210[1] + getitem_2024 = split_210[2] + getitem_2025 = split_210[3] + getitem_2026 = split_210[4] + getitem_2027 = split_210[5] + getitem_2028 = split_210[6] + getitem_2029 = split_210[7]; split_210 = None + cat_202 = torch.ops.aten.cat.default([getitem_2022, getitem_2023, getitem_2024, getitem_2025, getitem_2026, getitem_2027, getitem_2028, getitem_2029]); getitem_2022 = getitem_2023 = getitem_2024 = getitem_2025 = getitem_2026 = getitem_2027 = getitem_2028 = getitem_2029 = None + reduce_scatter_tensor_264 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_202, 'sum', 8, '1'); cat_202 = None + wait_tensor_692 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_264); reduce_scatter_tensor_264 = None + convert_element_type_2040 = torch.ops.prims.convert_element_type.default(wait_tensor_692, torch.float32); wait_tensor_692 = None + convert_element_type_2042 = torch.ops.prims.convert_element_type.default(wait_tensor_184, torch.float32); wait_tensor_184 = None + mul_618 = torch.ops.aten.mul.Tensor(convert_element_type_2040, convert_element_type_2042); convert_element_type_2042 = None + mul_620 = torch.ops.aten.mul.Tensor(mul_112, mul_618) + sum_109 = torch.ops.aten.sum.dim_IntList(mul_620, [2], True); mul_620 = None + div_36 = torch.ops.aten.div.Tensor(mul_112, 4096) + mul_621 = torch.ops.aten.mul.Tensor(div_36, sum_109); div_36 = sum_109 = None + sub_55 = torch.ops.aten.sub.Tensor(mul_618, mul_621); mul_618 = mul_621 = None + mul_622 = torch.ops.aten.mul.Tensor(sub_55, rsqrt_28); sub_55 = rsqrt_28 = None + mul_623 = torch.ops.aten.mul.Tensor(convert_element_type_2040, mul_112); convert_element_type_2040 = mul_112 = None + sum_110 = torch.ops.aten.sum.dim_IntList(mul_623, [0, 1]); mul_623 = None + convert_element_type_2043 = torch.ops.prims.convert_element_type.default(mul_622, torch.bfloat16); mul_622 = None + convert_element_type_2044 = torch.ops.prims.convert_element_type.default(sum_110, torch.bfloat16); sum_110 = None + all_reduce_36 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2044, 'sum', '1'); convert_element_type_2044 = None + wait_tensor_693 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_36); all_reduce_36 = None + convert_element_type_2045 = torch.ops.prims.convert_element_type.default(wait_tensor_693, torch.float32); wait_tensor_693 = None + reduce_scatter_tensor_265 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2045, 'avg', 32, '0'); convert_element_type_2045 = None + wait_tensor_694 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_265); reduce_scatter_tensor_265 = None + add_254 = torch.ops.aten.add.Tensor(add_251, convert_element_type_2043); add_251 = convert_element_type_2043 = None + all_gather_into_tensor_392 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_254, 8, '1') + wait_tensor_695 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_392); all_gather_into_tensor_392 = None + split_211 = torch.ops.aten.split.Tensor(wait_tensor_695, 2); wait_tensor_695 = None + getitem_2030 = split_211[0] + getitem_2031 = split_211[1] + getitem_2032 = split_211[2] + getitem_2033 = split_211[3] + getitem_2034 = split_211[4] + getitem_2035 = split_211[5] + getitem_2036 = split_211[6] + getitem_2037 = split_211[7]; split_211 = None + cat_203 = torch.ops.aten.cat.default([getitem_2030, getitem_2031, getitem_2032, getitem_2033, getitem_2034, getitem_2035, getitem_2036, getitem_2037], 1); getitem_2030 = getitem_2031 = getitem_2032 = getitem_2033 = getitem_2034 = getitem_2035 = getitem_2036 = getitem_2037 = None + view_2755 = torch.ops.aten.view.default(cat_203, [16384, 4096]); cat_203 = None + permute_933 = torch.ops.aten.permute.default(view_2755, [1, 0]) + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + add_53 = torch.ops.aten.add.Tensor(add_51, wait_tensor_177); wait_tensor_177 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 32, '0'); convert_element_type_449 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32); add_53 = None + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_178) + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '1'); convert_element_type_451 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + split_63 = torch.ops.aten.split.Tensor(wait_tensor_179, 2); wait_tensor_179 = None + getitem_630 = split_63[0] + getitem_631 = split_63[1] + getitem_632 = split_63[2] + getitem_633 = split_63[3] + getitem_634 = split_63[4] + getitem_635 = split_63[5] + getitem_636 = split_63[6] + getitem_637 = split_63[7]; split_63 = None + cat_55 = torch.ops.aten.cat.default([getitem_630, getitem_631, getitem_632, getitem_633, getitem_634, getitem_635, getitem_636, getitem_637], 1); getitem_630 = getitem_631 = getitem_632 = getitem_633 = getitem_634 = getitem_635 = getitem_636 = getitem_637 = None + view_996 = torch.ops.aten.view.default(cat_55, [16384, 4096]); cat_55 = None + view_997 = torch.ops.aten.view.default(mm_95, [2, 8192, 1792]); mm_95 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16); primals_128 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 32, '0'); convert_element_type_457 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + mm_96 = torch.ops.aten.mm.default(view_996, permute_152) + view_1004 = torch.ops.aten.view.default(mm_96, [2, 8192, 1792]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_1004) + view_1011 = torch.ops.aten.view.default(mul_111, [16384, 1792]); mul_111 = None + mm_479 = torch.ops.aten.mm.default(permute_933, view_1011); permute_933 = view_1011 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16); primals_129 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 32, '0'); convert_element_type_460 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + permute_935 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_480 = torch.ops.aten.mm.default(view_2755, permute_935); view_2755 = permute_935 = None + view_2756 = torch.ops.aten.view.default(mm_480, [2, 8192, 1792]); mm_480 = None + convert_element_type_2050 = torch.ops.prims.convert_element_type.default(mm_479, torch.float32); mm_479 = None + reduce_scatter_tensor_266 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2050, 'avg', 32, '0'); convert_element_type_2050 = None + wait_tensor_696 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_266); reduce_scatter_tensor_266 = None + mul_624 = torch.ops.aten.mul.Tensor(view_2756, convert_element_type_456); convert_element_type_456 = None + mul_625 = torch.ops.aten.mul.Tensor(view_2756, view_1004); view_2756 = view_1004 = None + view_2757 = torch.ops.aten.view.default(mul_624, [16384, 1792]); mul_624 = None + permute_937 = torch.ops.aten.permute.default(view_2757, [1, 0]) + mm_481 = torch.ops.aten.mm.default(permute_937, view_996); permute_937 = None + permute_939 = torch.ops.aten.permute.default(permute_152, [1, 0]); permute_152 = None + mm_482 = torch.ops.aten.mm.default(view_2757, permute_939); view_2757 = permute_939 = None + view_2758 = torch.ops.aten.view.default(mm_482, [2, 8192, 4096]); mm_482 = None + convert_element_type_2055 = torch.ops.prims.convert_element_type.default(mm_481, torch.float32); mm_481 = None + reduce_scatter_tensor_267 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2055, 'avg', 32, '0'); convert_element_type_2055 = None + wait_tensor_697 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_267); reduce_scatter_tensor_267 = None + convert_element_type_2056 = torch.ops.prims.convert_element_type.default(mul_625, torch.float32); mul_625 = None + neg_18 = torch.ops.aten.neg.default(convert_element_type_455) + exp_18 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_255 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + reciprocal_18 = torch.ops.aten.reciprocal.default(add_255); add_255 = None + mul_626 = torch.ops.aten.mul.Tensor(reciprocal_18, 1); reciprocal_18 = None + mul_627 = torch.ops.aten.mul.Tensor(convert_element_type_2056, mul_626); convert_element_type_2056 = None + sub_56 = torch.ops.aten.sub.Tensor(1, mul_626); mul_626 = None + mul_628 = torch.ops.aten.mul.Tensor(convert_element_type_455, sub_56); convert_element_type_455 = sub_56 = None + add_256 = torch.ops.aten.add.Tensor(mul_628, 1); mul_628 = None + mul_629 = torch.ops.aten.mul.Tensor(mul_627, add_256); mul_627 = add_256 = None + convert_element_type_2058 = torch.ops.prims.convert_element_type.default(mul_629, torch.bfloat16); mul_629 = None + view_2759 = torch.ops.aten.view.default(convert_element_type_2058, [16384, 1792]); convert_element_type_2058 = None + permute_941 = torch.ops.aten.permute.default(view_2759, [1, 0]) + mm_483 = torch.ops.aten.mm.default(permute_941, view_996); permute_941 = view_996 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 32, '0'); convert_element_type_452 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_943 = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None + mm_484 = torch.ops.aten.mm.default(view_2759, permute_943); view_2759 = permute_943 = None + view_2760 = torch.ops.aten.view.default(mm_484, [2, 8192, 4096]); mm_484 = None + add_257 = torch.ops.aten.add.Tensor(view_2758, view_2760); view_2758 = view_2760 = None + convert_element_type_2063 = torch.ops.prims.convert_element_type.default(mm_483, torch.float32); mm_483 = None + reduce_scatter_tensor_268 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2063, 'avg', 32, '0'); convert_element_type_2063 = None + wait_tensor_698 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_268); reduce_scatter_tensor_268 = None + split_212 = torch.ops.aten.split.Tensor(add_257, 1024, 1); add_257 = None + getitem_2038 = split_212[0] + getitem_2039 = split_212[1] + getitem_2040 = split_212[2] + getitem_2041 = split_212[3] + getitem_2042 = split_212[4] + getitem_2043 = split_212[5] + getitem_2044 = split_212[6] + getitem_2045 = split_212[7]; split_212 = None + cat_204 = torch.ops.aten.cat.default([getitem_2038, getitem_2039, getitem_2040, getitem_2041, getitem_2042, getitem_2043, getitem_2044, getitem_2045]); getitem_2038 = getitem_2039 = getitem_2040 = getitem_2041 = getitem_2042 = getitem_2043 = getitem_2044 = getitem_2045 = None + reduce_scatter_tensor_269 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_204, 'sum', 8, '1'); cat_204 = None + wait_tensor_699 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_269); reduce_scatter_tensor_269 = None + convert_element_type_2064 = torch.ops.prims.convert_element_type.default(wait_tensor_699, torch.float32); wait_tensor_699 = None + convert_element_type_2066 = torch.ops.prims.convert_element_type.default(wait_tensor_178, torch.float32); wait_tensor_178 = None + mul_630 = torch.ops.aten.mul.Tensor(convert_element_type_2064, convert_element_type_2066); convert_element_type_2066 = None + mul_632 = torch.ops.aten.mul.Tensor(mul_108, mul_630) + sum_111 = torch.ops.aten.sum.dim_IntList(mul_632, [2], True); mul_632 = None + div_37 = torch.ops.aten.div.Tensor(mul_108, 4096) + mul_633 = torch.ops.aten.mul.Tensor(div_37, sum_111); div_37 = sum_111 = None + sub_57 = torch.ops.aten.sub.Tensor(mul_630, mul_633); mul_630 = mul_633 = None + mul_634 = torch.ops.aten.mul.Tensor(sub_57, rsqrt_27); sub_57 = rsqrt_27 = None + mul_635 = torch.ops.aten.mul.Tensor(convert_element_type_2064, mul_108); convert_element_type_2064 = mul_108 = None + sum_112 = torch.ops.aten.sum.dim_IntList(mul_635, [0, 1]); mul_635 = None + convert_element_type_2067 = torch.ops.prims.convert_element_type.default(mul_634, torch.bfloat16); mul_634 = None + convert_element_type_2068 = torch.ops.prims.convert_element_type.default(sum_112, torch.bfloat16); sum_112 = None + all_reduce_37 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2068, 'sum', '1'); convert_element_type_2068 = None + wait_tensor_700 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_37); all_reduce_37 = None + convert_element_type_2069 = torch.ops.prims.convert_element_type.default(wait_tensor_700, torch.float32); wait_tensor_700 = None + reduce_scatter_tensor_270 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2069, 'avg', 32, '0'); convert_element_type_2069 = None + wait_tensor_701 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_270); reduce_scatter_tensor_270 = None + add_258 = torch.ops.aten.add.Tensor(add_254, convert_element_type_2067); add_254 = convert_element_type_2067 = None + all_gather_into_tensor_393 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_258, 8, '1') + wait_tensor_702 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_393); all_gather_into_tensor_393 = None + split_213 = torch.ops.aten.split.Tensor(wait_tensor_702, 2); wait_tensor_702 = None + getitem_2046 = split_213[0] + getitem_2047 = split_213[1] + getitem_2048 = split_213[2] + getitem_2049 = split_213[3] + getitem_2050 = split_213[4] + getitem_2051 = split_213[5] + getitem_2052 = split_213[6] + getitem_2053 = split_213[7]; split_213 = None + cat_205 = torch.ops.aten.cat.default([getitem_2046, getitem_2047, getitem_2048, getitem_2049, getitem_2050, getitem_2051, getitem_2052, getitem_2053], 1); getitem_2046 = getitem_2047 = getitem_2048 = getitem_2049 = getitem_2050 = getitem_2051 = getitem_2052 = getitem_2053 = None + view_2761 = torch.ops.aten.view.default(cat_205, [16384, 4096]); cat_205 = None + permute_945 = torch.ops.aten.permute.default(view_2761, [1, 0]) + permute_149 = torch.ops.aten.permute.default(getitem_613, [0, 2, 1, 3]) + view_978 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + view_984 = torch.ops.aten.view.default(view_978, [16384, 512]); view_978 = None + mm_485 = torch.ops.aten.mm.default(permute_945, view_984); permute_945 = view_984 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 32, '0'); convert_element_type_446 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + permute_947 = torch.ops.aten.permute.default(permute_150, [1, 0]); permute_150 = None + mm_486 = torch.ops.aten.mm.default(view_2761, permute_947); view_2761 = permute_947 = None + view_2762 = torch.ops.aten.view.default(mm_486, [2, 8192, 512]); mm_486 = None + convert_element_type_2074 = torch.ops.prims.convert_element_type.default(mm_485, torch.float32); mm_485 = None + reduce_scatter_tensor_271 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2074, 'avg', 32, '0'); convert_element_type_2074 = None + wait_tensor_703 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_271); reduce_scatter_tensor_271 = None + view_2763 = torch.ops.aten.view.default(view_2762, [2, 8192, 4, 128]); view_2762 = None + permute_949 = torch.ops.aten.permute.default(view_2763, [0, 2, 1, 3]); view_2763 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 32, '0'); convert_element_type_430 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32); add_51 = None + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_171) + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_432, 8, '1'); convert_element_type_432 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + split_61 = torch.ops.aten.split.Tensor(wait_tensor_172, 2); wait_tensor_172 = None + getitem_605 = split_61[0] + getitem_606 = split_61[1] + getitem_607 = split_61[2] + getitem_608 = split_61[3] + getitem_609 = split_61[4] + getitem_610 = split_61[5] + getitem_611 = split_61[6] + getitem_612 = split_61[7]; split_61 = None + cat_53 = torch.ops.aten.cat.default([getitem_605, getitem_606, getitem_607, getitem_608, getitem_609, getitem_610, getitem_611, getitem_612], 1); getitem_605 = getitem_606 = getitem_607 = getitem_608 = getitem_609 = getitem_610 = getitem_611 = getitem_612 = None + view_951 = torch.ops.aten.view.default(cat_53, [16384, 4096]); cat_53 = None + view_952 = torch.ops.aten.view.default(mm_91, [2, 8192, 512]); mm_91 = None + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 32, '0'); convert_element_type_436 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_92 = torch.ops.aten.mm.default(view_951, permute_144) + view_959 = torch.ops.aten.view.default(mm_92, [2, 8192, 128]); mm_92 = None + view_966 = torch.ops.aten.view.default(mm_93, [2, 8192, 128]); mm_93 = None + view_968 = torch.ops.aten.view.default(view_952, [2, 8192, -1, 128]); view_952 = None + view_969 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_970 = torch.ops.aten.view.default(view_966, [2, 8192, -1, 128]); view_966 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_968, torch.float32); view_968 = None + view_971 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 4, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_971); view_971 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_969, torch.float32); view_969 = None + view_972 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 1, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_972); view_972 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_37); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_974 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 4, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_37); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_975 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 1, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_974, torch.bfloat16); view_974 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_975, torch.bfloat16); view_975 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 1, 4, 128]); unsqueeze_26 = None + view_976 = torch.ops.aten.view.default(expand_26, [2, 8192, 4, 128]); expand_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_970, 3); view_970 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 1, 4, 128]); unsqueeze_27 = None + view_977 = torch.ops.aten.view.default(expand_27, [2, 8192, 4, 128]); expand_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_976, [0, 2, 1, 3]); view_976 = None + permute_148 = torch.ops.aten.permute.default(view_977, [0, 2, 1, 3]); view_977 = None + _scaled_dot_product_cudnn_attention_backward_18 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_949, permute_146, permute_147, permute_148, getitem_613, getitem_614, getitem_619, getitem_620, None, None, None, 8192, 8192, 0.0, True); permute_949 = permute_146 = permute_147 = permute_148 = getitem_613 = getitem_614 = getitem_619 = getitem_620 = None + getitem_2054 = _scaled_dot_product_cudnn_attention_backward_18[0] + getitem_2055 = _scaled_dot_product_cudnn_attention_backward_18[1] + getitem_2056 = _scaled_dot_product_cudnn_attention_backward_18[2]; _scaled_dot_product_cudnn_attention_backward_18 = None + permute_950 = torch.ops.aten.permute.default(getitem_2056, [0, 2, 1, 3]); getitem_2056 = None + permute_951 = torch.ops.aten.permute.default(getitem_2055, [0, 2, 1, 3]); getitem_2055 = None + permute_952 = torch.ops.aten.permute.default(getitem_2054, [0, 2, 1, 3]); getitem_2054 = None + view_2764 = torch.ops.aten.view.default(permute_950, [2, 8192, 1, 4, 128]); permute_950 = None + sum_113 = torch.ops.aten.sum.dim_IntList(view_2764, [3], True); view_2764 = None + squeeze_36 = torch.ops.aten.squeeze.dim(sum_113, 3); sum_113 = None + view_2765 = torch.ops.aten.view.default(permute_951, [2, 8192, 1, 4, 128]); permute_951 = None + sum_114 = torch.ops.aten.sum.dim_IntList(view_2765, [3], True); view_2765 = None + squeeze_37 = torch.ops.aten.squeeze.dim(sum_114, 3); sum_114 = None + convert_element_type_2075 = torch.ops.prims.convert_element_type.default(squeeze_37, torch.float32); squeeze_37 = None + convert_element_type_2076 = torch.ops.prims.convert_element_type.default(permute_952, torch.float32); permute_952 = None + view_2766 = torch.ops.aten.view.default(convert_element_type_2075, [2, 8192, 1, 64, 2]); convert_element_type_2075 = None + view_as_complex_100 = torch.ops.aten.view_as_complex.default(view_2766); view_2766 = None + mul_636 = torch.ops.aten.mul.Tensor(view_as_complex_100, _conj); view_as_complex_100 = None + view_2767 = torch.ops.aten.view.default(convert_element_type_2076, [2, 8192, 4, 64, 2]); convert_element_type_2076 = None + view_as_complex_101 = torch.ops.aten.view_as_complex.default(view_2767); view_2767 = None + mul_637 = torch.ops.aten.mul.Tensor(view_as_complex_101, _conj); view_as_complex_101 = None + view_as_real_100 = torch.ops.aten.view_as_real.default(mul_636); mul_636 = None + view_2768 = torch.ops.aten.view.default(view_as_real_100, [2, 8192, 1, 128]); view_as_real_100 = None + convert_element_type_2077 = torch.ops.prims.convert_element_type.default(view_2768, torch.bfloat16); view_2768 = None + view_as_real_101 = torch.ops.aten.view_as_real.default(mul_637); mul_637 = None + view_2769 = torch.ops.aten.view.default(view_as_real_101, [2, 8192, 4, 128]); view_as_real_101 = None + convert_element_type_2078 = torch.ops.prims.convert_element_type.default(view_2769, torch.bfloat16); view_2769 = None + view_2770 = torch.ops.aten.view.default(squeeze_36, [2, 8192, 128]); squeeze_36 = None + view_2771 = torch.ops.aten.view.default(convert_element_type_2077, [2, 8192, 128]); convert_element_type_2077 = None + view_2772 = torch.ops.aten.view.default(convert_element_type_2078, [2, 8192, 512]); convert_element_type_2078 = None + view_2773 = torch.ops.aten.view.default(view_2770, [16384, 128]); view_2770 = None + permute_953 = torch.ops.aten.permute.default(view_2773, [1, 0]) + mm_487 = torch.ops.aten.mm.default(permute_953, view_951); permute_953 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 32, '0'); convert_element_type_439 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + permute_955 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_488 = torch.ops.aten.mm.default(view_2773, permute_955); view_2773 = permute_955 = None + view_2774 = torch.ops.aten.view.default(mm_488, [2, 8192, 4096]); mm_488 = None + convert_element_type_2083 = torch.ops.prims.convert_element_type.default(mm_487, torch.float32); mm_487 = None + reduce_scatter_tensor_272 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2083, 'avg', 32, '0'); convert_element_type_2083 = None + wait_tensor_704 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_272); reduce_scatter_tensor_272 = None + view_2775 = torch.ops.aten.view.default(view_2771, [16384, 128]); view_2771 = None + permute_957 = torch.ops.aten.permute.default(view_2775, [1, 0]) + mm_489 = torch.ops.aten.mm.default(permute_957, view_951); permute_957 = None + permute_959 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_490 = torch.ops.aten.mm.default(view_2775, permute_959); view_2775 = permute_959 = None + view_2776 = torch.ops.aten.view.default(mm_490, [2, 8192, 4096]); mm_490 = None + add_259 = torch.ops.aten.add.Tensor(view_2774, view_2776); view_2774 = view_2776 = None + convert_element_type_2088 = torch.ops.prims.convert_element_type.default(mm_489, torch.float32); mm_489 = None + reduce_scatter_tensor_273 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2088, 'avg', 32, '0'); convert_element_type_2088 = None + wait_tensor_705 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_273); reduce_scatter_tensor_273 = None + view_2777 = torch.ops.aten.view.default(view_2772, [16384, 512]); view_2772 = None + permute_961 = torch.ops.aten.permute.default(view_2777, [1, 0]) + mm_491 = torch.ops.aten.mm.default(permute_961, view_951); permute_961 = view_951 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 32, '0'); convert_element_type_433 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + permute_963 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_492 = torch.ops.aten.mm.default(view_2777, permute_963); view_2777 = permute_963 = None + view_2778 = torch.ops.aten.view.default(mm_492, [2, 8192, 4096]); mm_492 = None + add_260 = torch.ops.aten.add.Tensor(add_259, view_2778); add_259 = view_2778 = None + convert_element_type_2093 = torch.ops.prims.convert_element_type.default(mm_491, torch.float32); mm_491 = None + reduce_scatter_tensor_274 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2093, 'avg', 32, '0'); convert_element_type_2093 = None + wait_tensor_706 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_274); reduce_scatter_tensor_274 = None + split_214 = torch.ops.aten.split.Tensor(add_260, 1024, 1); add_260 = None + getitem_2057 = split_214[0] + getitem_2058 = split_214[1] + getitem_2059 = split_214[2] + getitem_2060 = split_214[3] + getitem_2061 = split_214[4] + getitem_2062 = split_214[5] + getitem_2063 = split_214[6] + getitem_2064 = split_214[7]; split_214 = None + cat_206 = torch.ops.aten.cat.default([getitem_2057, getitem_2058, getitem_2059, getitem_2060, getitem_2061, getitem_2062, getitem_2063, getitem_2064]); getitem_2057 = getitem_2058 = getitem_2059 = getitem_2060 = getitem_2061 = getitem_2062 = getitem_2063 = getitem_2064 = None + reduce_scatter_tensor_275 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_206, 'sum', 8, '1'); cat_206 = None + wait_tensor_707 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_275); reduce_scatter_tensor_275 = None + convert_element_type_2094 = torch.ops.prims.convert_element_type.default(wait_tensor_707, torch.float32); wait_tensor_707 = None + convert_element_type_2096 = torch.ops.prims.convert_element_type.default(wait_tensor_171, torch.float32); wait_tensor_171 = None + mul_638 = torch.ops.aten.mul.Tensor(convert_element_type_2094, convert_element_type_2096); convert_element_type_2096 = None + mul_640 = torch.ops.aten.mul.Tensor(mul_104, mul_638) + sum_115 = torch.ops.aten.sum.dim_IntList(mul_640, [2], True); mul_640 = None + div_38 = torch.ops.aten.div.Tensor(mul_104, 4096) + mul_641 = torch.ops.aten.mul.Tensor(div_38, sum_115); div_38 = sum_115 = None + sub_58 = torch.ops.aten.sub.Tensor(mul_638, mul_641); mul_638 = mul_641 = None + mul_642 = torch.ops.aten.mul.Tensor(sub_58, rsqrt_26); sub_58 = rsqrt_26 = None + mul_643 = torch.ops.aten.mul.Tensor(convert_element_type_2094, mul_104); convert_element_type_2094 = mul_104 = None + sum_116 = torch.ops.aten.sum.dim_IntList(mul_643, [0, 1]); mul_643 = None + convert_element_type_2097 = torch.ops.prims.convert_element_type.default(mul_642, torch.bfloat16); mul_642 = None + convert_element_type_2098 = torch.ops.prims.convert_element_type.default(sum_116, torch.bfloat16); sum_116 = None + all_reduce_38 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2098, 'sum', '1'); convert_element_type_2098 = None + wait_tensor_708 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_38); all_reduce_38 = None + convert_element_type_2099 = torch.ops.prims.convert_element_type.default(wait_tensor_708, torch.float32); wait_tensor_708 = None + reduce_scatter_tensor_276 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2099, 'avg', 32, '0'); convert_element_type_2099 = None + wait_tensor_709 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_276); reduce_scatter_tensor_276 = None + add_261 = torch.ops.aten.add.Tensor(add_258, convert_element_type_2097); add_258 = convert_element_type_2097 = None + all_gather_into_tensor_394 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_261, 8, '1') + wait_tensor_710 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_394); all_gather_into_tensor_394 = None + split_215 = torch.ops.aten.split.Tensor(wait_tensor_710, 2); wait_tensor_710 = None + getitem_2065 = split_215[0] + getitem_2066 = split_215[1] + getitem_2067 = split_215[2] + getitem_2068 = split_215[3] + getitem_2069 = split_215[4] + getitem_2070 = split_215[5] + getitem_2071 = split_215[6] + getitem_2072 = split_215[7]; split_215 = None + cat_207 = torch.ops.aten.cat.default([getitem_2065, getitem_2066, getitem_2067, getitem_2068, getitem_2069, getitem_2070, getitem_2071, getitem_2072], 1); getitem_2065 = getitem_2066 = getitem_2067 = getitem_2068 = getitem_2069 = getitem_2070 = getitem_2071 = getitem_2072 = None + view_2779 = torch.ops.aten.view.default(cat_207, [16384, 4096]); cat_207 = None + permute_965 = torch.ops.aten.permute.default(view_2779, [1, 0]) + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + add_49 = torch.ops.aten.add.Tensor(add_47, wait_tensor_164); wait_tensor_164 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 32, '0'); convert_element_type_416 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32); add_49 = None + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_165) + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 8, '1'); convert_element_type_418 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + split_59 = torch.ops.aten.split.Tensor(wait_tensor_166, 2); wait_tensor_166 = None + getitem_589 = split_59[0] + getitem_590 = split_59[1] + getitem_591 = split_59[2] + getitem_592 = split_59[3] + getitem_593 = split_59[4] + getitem_594 = split_59[5] + getitem_595 = split_59[6] + getitem_596 = split_59[7]; split_59 = None + cat_51 = torch.ops.aten.cat.default([getitem_589, getitem_590, getitem_591, getitem_592, getitem_593, getitem_594, getitem_595, getitem_596], 1); getitem_589 = getitem_590 = getitem_591 = getitem_592 = getitem_593 = getitem_594 = getitem_595 = getitem_596 = None + view_924 = torch.ops.aten.view.default(cat_51, [16384, 4096]); cat_51 = None + view_925 = torch.ops.aten.view.default(mm_88, [2, 8192, 1792]); mm_88 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_925, torch.float32); view_925 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 32, '0'); convert_element_type_424 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_168, [1, 0]); wait_tensor_168 = None + mm_89 = torch.ops.aten.mm.default(view_924, permute_141) + view_932 = torch.ops.aten.view.default(mm_89, [2, 8192, 1792]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_932) + view_939 = torch.ops.aten.view.default(mul_103, [16384, 1792]); mul_103 = None + mm_493 = torch.ops.aten.mm.default(permute_965, view_939); permute_965 = view_939 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 32, '0'); convert_element_type_427 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + permute_967 = torch.ops.aten.permute.default(permute_142, [1, 0]); permute_142 = None + mm_494 = torch.ops.aten.mm.default(view_2779, permute_967); view_2779 = permute_967 = None + view_2780 = torch.ops.aten.view.default(mm_494, [2, 8192, 1792]); mm_494 = None + convert_element_type_2104 = torch.ops.prims.convert_element_type.default(mm_493, torch.float32); mm_493 = None + reduce_scatter_tensor_277 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2104, 'avg', 32, '0'); convert_element_type_2104 = None + wait_tensor_711 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_277); reduce_scatter_tensor_277 = None + mul_644 = torch.ops.aten.mul.Tensor(view_2780, convert_element_type_423); convert_element_type_423 = None + mul_645 = torch.ops.aten.mul.Tensor(view_2780, view_932); view_2780 = view_932 = None + view_2781 = torch.ops.aten.view.default(mul_644, [16384, 1792]); mul_644 = None + permute_969 = torch.ops.aten.permute.default(view_2781, [1, 0]) + mm_495 = torch.ops.aten.mm.default(permute_969, view_924); permute_969 = None + permute_971 = torch.ops.aten.permute.default(permute_141, [1, 0]); permute_141 = None + mm_496 = torch.ops.aten.mm.default(view_2781, permute_971); view_2781 = permute_971 = None + view_2782 = torch.ops.aten.view.default(mm_496, [2, 8192, 4096]); mm_496 = None + convert_element_type_2109 = torch.ops.prims.convert_element_type.default(mm_495, torch.float32); mm_495 = None + reduce_scatter_tensor_278 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2109, 'avg', 32, '0'); convert_element_type_2109 = None + wait_tensor_712 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_278); reduce_scatter_tensor_278 = None + convert_element_type_2110 = torch.ops.prims.convert_element_type.default(mul_645, torch.float32); mul_645 = None + neg_19 = torch.ops.aten.neg.default(convert_element_type_422) + exp_19 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_262 = torch.ops.aten.add.Tensor(exp_19, 1); exp_19 = None + reciprocal_19 = torch.ops.aten.reciprocal.default(add_262); add_262 = None + mul_646 = torch.ops.aten.mul.Tensor(reciprocal_19, 1); reciprocal_19 = None + mul_647 = torch.ops.aten.mul.Tensor(convert_element_type_2110, mul_646); convert_element_type_2110 = None + sub_59 = torch.ops.aten.sub.Tensor(1, mul_646); mul_646 = None + mul_648 = torch.ops.aten.mul.Tensor(convert_element_type_422, sub_59); convert_element_type_422 = sub_59 = None + add_263 = torch.ops.aten.add.Tensor(mul_648, 1); mul_648 = None + mul_649 = torch.ops.aten.mul.Tensor(mul_647, add_263); mul_647 = add_263 = None + convert_element_type_2112 = torch.ops.prims.convert_element_type.default(mul_649, torch.bfloat16); mul_649 = None + view_2783 = torch.ops.aten.view.default(convert_element_type_2112, [16384, 1792]); convert_element_type_2112 = None + permute_973 = torch.ops.aten.permute.default(view_2783, [1, 0]) + mm_497 = torch.ops.aten.mm.default(permute_973, view_924); permute_973 = view_924 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 32, '0'); convert_element_type_419 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + permute_975 = torch.ops.aten.permute.default(permute_140, [1, 0]); permute_140 = None + mm_498 = torch.ops.aten.mm.default(view_2783, permute_975); view_2783 = permute_975 = None + view_2784 = torch.ops.aten.view.default(mm_498, [2, 8192, 4096]); mm_498 = None + add_264 = torch.ops.aten.add.Tensor(view_2782, view_2784); view_2782 = view_2784 = None + convert_element_type_2117 = torch.ops.prims.convert_element_type.default(mm_497, torch.float32); mm_497 = None + reduce_scatter_tensor_279 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2117, 'avg', 32, '0'); convert_element_type_2117 = None + wait_tensor_713 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_279); reduce_scatter_tensor_279 = None + split_216 = torch.ops.aten.split.Tensor(add_264, 1024, 1); add_264 = None + getitem_2073 = split_216[0] + getitem_2074 = split_216[1] + getitem_2075 = split_216[2] + getitem_2076 = split_216[3] + getitem_2077 = split_216[4] + getitem_2078 = split_216[5] + getitem_2079 = split_216[6] + getitem_2080 = split_216[7]; split_216 = None + cat_208 = torch.ops.aten.cat.default([getitem_2073, getitem_2074, getitem_2075, getitem_2076, getitem_2077, getitem_2078, getitem_2079, getitem_2080]); getitem_2073 = getitem_2074 = getitem_2075 = getitem_2076 = getitem_2077 = getitem_2078 = getitem_2079 = getitem_2080 = None + reduce_scatter_tensor_280 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_208, 'sum', 8, '1'); cat_208 = None + wait_tensor_714 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_280); reduce_scatter_tensor_280 = None + convert_element_type_2118 = torch.ops.prims.convert_element_type.default(wait_tensor_714, torch.float32); wait_tensor_714 = None + convert_element_type_2120 = torch.ops.prims.convert_element_type.default(wait_tensor_165, torch.float32); wait_tensor_165 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_2118, convert_element_type_2120); convert_element_type_2120 = None + mul_652 = torch.ops.aten.mul.Tensor(mul_100, mul_650) + sum_117 = torch.ops.aten.sum.dim_IntList(mul_652, [2], True); mul_652 = None + div_39 = torch.ops.aten.div.Tensor(mul_100, 4096) + mul_653 = torch.ops.aten.mul.Tensor(div_39, sum_117); div_39 = sum_117 = None + sub_60 = torch.ops.aten.sub.Tensor(mul_650, mul_653); mul_650 = mul_653 = None + mul_654 = torch.ops.aten.mul.Tensor(sub_60, rsqrt_25); sub_60 = rsqrt_25 = None + mul_655 = torch.ops.aten.mul.Tensor(convert_element_type_2118, mul_100); convert_element_type_2118 = mul_100 = None + sum_118 = torch.ops.aten.sum.dim_IntList(mul_655, [0, 1]); mul_655 = None + convert_element_type_2121 = torch.ops.prims.convert_element_type.default(mul_654, torch.bfloat16); mul_654 = None + convert_element_type_2122 = torch.ops.prims.convert_element_type.default(sum_118, torch.bfloat16); sum_118 = None + all_reduce_39 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2122, 'sum', '1'); convert_element_type_2122 = None + wait_tensor_715 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_39); all_reduce_39 = None + convert_element_type_2123 = torch.ops.prims.convert_element_type.default(wait_tensor_715, torch.float32); wait_tensor_715 = None + reduce_scatter_tensor_281 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2123, 'avg', 32, '0'); convert_element_type_2123 = None + wait_tensor_716 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_281); reduce_scatter_tensor_281 = None + add_265 = torch.ops.aten.add.Tensor(add_261, convert_element_type_2121); add_261 = convert_element_type_2121 = None + all_gather_into_tensor_395 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_265, 8, '1') + wait_tensor_717 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_395); all_gather_into_tensor_395 = None + split_217 = torch.ops.aten.split.Tensor(wait_tensor_717, 2); wait_tensor_717 = None + getitem_2081 = split_217[0] + getitem_2082 = split_217[1] + getitem_2083 = split_217[2] + getitem_2084 = split_217[3] + getitem_2085 = split_217[4] + getitem_2086 = split_217[5] + getitem_2087 = split_217[6] + getitem_2088 = split_217[7]; split_217 = None + cat_209 = torch.ops.aten.cat.default([getitem_2081, getitem_2082, getitem_2083, getitem_2084, getitem_2085, getitem_2086, getitem_2087, getitem_2088], 1); getitem_2081 = getitem_2082 = getitem_2083 = getitem_2084 = getitem_2085 = getitem_2086 = getitem_2087 = getitem_2088 = None + view_2785 = torch.ops.aten.view.default(cat_209, [16384, 4096]); cat_209 = None + permute_977 = torch.ops.aten.permute.default(view_2785, [1, 0]) + permute_138 = torch.ops.aten.permute.default(getitem_572, [0, 2, 1, 3]) + view_906 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + view_912 = torch.ops.aten.view.default(view_906, [16384, 512]); view_906 = None + mm_499 = torch.ops.aten.mm.default(permute_977, view_912); permute_977 = view_912 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 32, '0'); convert_element_type_413 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + permute_979 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_500 = torch.ops.aten.mm.default(view_2785, permute_979); view_2785 = permute_979 = None + view_2786 = torch.ops.aten.view.default(mm_500, [2, 8192, 512]); mm_500 = None + convert_element_type_2128 = torch.ops.prims.convert_element_type.default(mm_499, torch.float32); mm_499 = None + reduce_scatter_tensor_282 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2128, 'avg', 32, '0'); convert_element_type_2128 = None + wait_tensor_718 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_282); reduce_scatter_tensor_282 = None + view_2787 = torch.ops.aten.view.default(view_2786, [2, 8192, 4, 128]); view_2786 = None + permute_981 = torch.ops.aten.permute.default(view_2787, [0, 2, 1, 3]); view_2787 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16); primals_112 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 32, '0'); convert_element_type_397 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32); add_47 = None + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_158) + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_399, 8, '1'); convert_element_type_399 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + split_57 = torch.ops.aten.split.Tensor(wait_tensor_159, 2); wait_tensor_159 = None + getitem_564 = split_57[0] + getitem_565 = split_57[1] + getitem_566 = split_57[2] + getitem_567 = split_57[3] + getitem_568 = split_57[4] + getitem_569 = split_57[5] + getitem_570 = split_57[6] + getitem_571 = split_57[7]; split_57 = None + cat_49 = torch.ops.aten.cat.default([getitem_564, getitem_565, getitem_566, getitem_567, getitem_568, getitem_569, getitem_570, getitem_571], 1); getitem_564 = getitem_565 = getitem_566 = getitem_567 = getitem_568 = getitem_569 = getitem_570 = getitem_571 = None + view_879 = torch.ops.aten.view.default(cat_49, [16384, 4096]); cat_49 = None + view_880 = torch.ops.aten.view.default(mm_84, [2, 8192, 512]); mm_84 = None + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16); primals_114 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 32, '0'); convert_element_type_403 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_85 = torch.ops.aten.mm.default(view_879, permute_133) + view_887 = torch.ops.aten.view.default(mm_85, [2, 8192, 128]); mm_85 = None + view_894 = torch.ops.aten.view.default(mm_86, [2, 8192, 128]); mm_86 = None + view_896 = torch.ops.aten.view.default(view_880, [2, 8192, -1, 128]); view_880 = None + view_897 = torch.ops.aten.view.default(view_887, [2, 8192, -1, 128]); view_887 = None + view_898 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 4, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_897, torch.float32); view_897 = None + view_900 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 1, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_900); view_900 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_37); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_902 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 4, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_37); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_903 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 1, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_903, torch.bfloat16); view_903 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 1, 4, 128]); unsqueeze_24 = None + view_904 = torch.ops.aten.view.default(expand_24, [2, 8192, 4, 128]); expand_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_898, 3); view_898 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 1, 4, 128]); unsqueeze_25 = None + view_905 = torch.ops.aten.view.default(expand_25, [2, 8192, 4, 128]); expand_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + permute_137 = torch.ops.aten.permute.default(view_905, [0, 2, 1, 3]); view_905 = None + _scaled_dot_product_cudnn_attention_backward_19 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_981, permute_135, permute_136, permute_137, getitem_572, getitem_573, getitem_578, getitem_579, None, None, None, 8192, 8192, 0.0, True); permute_981 = permute_135 = permute_136 = permute_137 = getitem_572 = getitem_573 = getitem_578 = getitem_579 = None + getitem_2089 = _scaled_dot_product_cudnn_attention_backward_19[0] + getitem_2090 = _scaled_dot_product_cudnn_attention_backward_19[1] + getitem_2091 = _scaled_dot_product_cudnn_attention_backward_19[2]; _scaled_dot_product_cudnn_attention_backward_19 = None + permute_982 = torch.ops.aten.permute.default(getitem_2091, [0, 2, 1, 3]); getitem_2091 = None + permute_983 = torch.ops.aten.permute.default(getitem_2090, [0, 2, 1, 3]); getitem_2090 = None + permute_984 = torch.ops.aten.permute.default(getitem_2089, [0, 2, 1, 3]); getitem_2089 = None + view_2788 = torch.ops.aten.view.default(permute_982, [2, 8192, 1, 4, 128]); permute_982 = None + sum_119 = torch.ops.aten.sum.dim_IntList(view_2788, [3], True); view_2788 = None + squeeze_38 = torch.ops.aten.squeeze.dim(sum_119, 3); sum_119 = None + view_2789 = torch.ops.aten.view.default(permute_983, [2, 8192, 1, 4, 128]); permute_983 = None + sum_120 = torch.ops.aten.sum.dim_IntList(view_2789, [3], True); view_2789 = None + squeeze_39 = torch.ops.aten.squeeze.dim(sum_120, 3); sum_120 = None + convert_element_type_2129 = torch.ops.prims.convert_element_type.default(squeeze_39, torch.float32); squeeze_39 = None + convert_element_type_2130 = torch.ops.prims.convert_element_type.default(permute_984, torch.float32); permute_984 = None + view_2790 = torch.ops.aten.view.default(convert_element_type_2129, [2, 8192, 1, 64, 2]); convert_element_type_2129 = None + view_as_complex_102 = torch.ops.aten.view_as_complex.default(view_2790); view_2790 = None + mul_656 = torch.ops.aten.mul.Tensor(view_as_complex_102, _conj); view_as_complex_102 = None + view_2791 = torch.ops.aten.view.default(convert_element_type_2130, [2, 8192, 4, 64, 2]); convert_element_type_2130 = None + view_as_complex_103 = torch.ops.aten.view_as_complex.default(view_2791); view_2791 = None + mul_657 = torch.ops.aten.mul.Tensor(view_as_complex_103, _conj); view_as_complex_103 = None + view_as_real_102 = torch.ops.aten.view_as_real.default(mul_656); mul_656 = None + view_2792 = torch.ops.aten.view.default(view_as_real_102, [2, 8192, 1, 128]); view_as_real_102 = None + convert_element_type_2131 = torch.ops.prims.convert_element_type.default(view_2792, torch.bfloat16); view_2792 = None + view_as_real_103 = torch.ops.aten.view_as_real.default(mul_657); mul_657 = None + view_2793 = torch.ops.aten.view.default(view_as_real_103, [2, 8192, 4, 128]); view_as_real_103 = None + convert_element_type_2132 = torch.ops.prims.convert_element_type.default(view_2793, torch.bfloat16); view_2793 = None + view_2794 = torch.ops.aten.view.default(squeeze_38, [2, 8192, 128]); squeeze_38 = None + view_2795 = torch.ops.aten.view.default(convert_element_type_2131, [2, 8192, 128]); convert_element_type_2131 = None + view_2796 = torch.ops.aten.view.default(convert_element_type_2132, [2, 8192, 512]); convert_element_type_2132 = None + view_2797 = torch.ops.aten.view.default(view_2794, [16384, 128]); view_2794 = None + permute_985 = torch.ops.aten.permute.default(view_2797, [1, 0]) + mm_501 = torch.ops.aten.mm.default(permute_985, view_879); permute_985 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16); primals_115 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 32, '0'); convert_element_type_406 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_987 = torch.ops.aten.permute.default(permute_134, [1, 0]); permute_134 = None + mm_502 = torch.ops.aten.mm.default(view_2797, permute_987); view_2797 = permute_987 = None + view_2798 = torch.ops.aten.view.default(mm_502, [2, 8192, 4096]); mm_502 = None + convert_element_type_2137 = torch.ops.prims.convert_element_type.default(mm_501, torch.float32); mm_501 = None + reduce_scatter_tensor_283 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2137, 'avg', 32, '0'); convert_element_type_2137 = None + wait_tensor_719 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_283); reduce_scatter_tensor_283 = None + view_2799 = torch.ops.aten.view.default(view_2795, [16384, 128]); view_2795 = None + permute_989 = torch.ops.aten.permute.default(view_2799, [1, 0]) + mm_503 = torch.ops.aten.mm.default(permute_989, view_879); permute_989 = None + permute_991 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_504 = torch.ops.aten.mm.default(view_2799, permute_991); view_2799 = permute_991 = None + view_2800 = torch.ops.aten.view.default(mm_504, [2, 8192, 4096]); mm_504 = None + add_266 = torch.ops.aten.add.Tensor(view_2798, view_2800); view_2798 = view_2800 = None + convert_element_type_2142 = torch.ops.prims.convert_element_type.default(mm_503, torch.float32); mm_503 = None + reduce_scatter_tensor_284 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2142, 'avg', 32, '0'); convert_element_type_2142 = None + wait_tensor_720 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_284); reduce_scatter_tensor_284 = None + view_2801 = torch.ops.aten.view.default(view_2796, [16384, 512]); view_2796 = None + permute_993 = torch.ops.aten.permute.default(view_2801, [1, 0]) + mm_505 = torch.ops.aten.mm.default(permute_993, view_879); permute_993 = view_879 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16); primals_113 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 32, '0'); convert_element_type_400 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + permute_995 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_506 = torch.ops.aten.mm.default(view_2801, permute_995); view_2801 = permute_995 = None + view_2802 = torch.ops.aten.view.default(mm_506, [2, 8192, 4096]); mm_506 = None + add_267 = torch.ops.aten.add.Tensor(add_266, view_2802); add_266 = view_2802 = None + convert_element_type_2147 = torch.ops.prims.convert_element_type.default(mm_505, torch.float32); mm_505 = None + reduce_scatter_tensor_285 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2147, 'avg', 32, '0'); convert_element_type_2147 = None + wait_tensor_721 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_285); reduce_scatter_tensor_285 = None + split_218 = torch.ops.aten.split.Tensor(add_267, 1024, 1); add_267 = None + getitem_2092 = split_218[0] + getitem_2093 = split_218[1] + getitem_2094 = split_218[2] + getitem_2095 = split_218[3] + getitem_2096 = split_218[4] + getitem_2097 = split_218[5] + getitem_2098 = split_218[6] + getitem_2099 = split_218[7]; split_218 = None + cat_210 = torch.ops.aten.cat.default([getitem_2092, getitem_2093, getitem_2094, getitem_2095, getitem_2096, getitem_2097, getitem_2098, getitem_2099]); getitem_2092 = getitem_2093 = getitem_2094 = getitem_2095 = getitem_2096 = getitem_2097 = getitem_2098 = getitem_2099 = None + reduce_scatter_tensor_286 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_210, 'sum', 8, '1'); cat_210 = None + wait_tensor_722 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_286); reduce_scatter_tensor_286 = None + convert_element_type_2148 = torch.ops.prims.convert_element_type.default(wait_tensor_722, torch.float32); wait_tensor_722 = None + convert_element_type_2150 = torch.ops.prims.convert_element_type.default(wait_tensor_158, torch.float32); wait_tensor_158 = None + mul_658 = torch.ops.aten.mul.Tensor(convert_element_type_2148, convert_element_type_2150); convert_element_type_2150 = None + mul_660 = torch.ops.aten.mul.Tensor(mul_96, mul_658) + sum_121 = torch.ops.aten.sum.dim_IntList(mul_660, [2], True); mul_660 = None + div_40 = torch.ops.aten.div.Tensor(mul_96, 4096) + mul_661 = torch.ops.aten.mul.Tensor(div_40, sum_121); div_40 = sum_121 = None + sub_61 = torch.ops.aten.sub.Tensor(mul_658, mul_661); mul_658 = mul_661 = None + mul_662 = torch.ops.aten.mul.Tensor(sub_61, rsqrt_24); sub_61 = rsqrt_24 = None + mul_663 = torch.ops.aten.mul.Tensor(convert_element_type_2148, mul_96); convert_element_type_2148 = mul_96 = None + sum_122 = torch.ops.aten.sum.dim_IntList(mul_663, [0, 1]); mul_663 = None + convert_element_type_2151 = torch.ops.prims.convert_element_type.default(mul_662, torch.bfloat16); mul_662 = None + convert_element_type_2152 = torch.ops.prims.convert_element_type.default(sum_122, torch.bfloat16); sum_122 = None + all_reduce_40 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2152, 'sum', '1'); convert_element_type_2152 = None + wait_tensor_723 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_40); all_reduce_40 = None + convert_element_type_2153 = torch.ops.prims.convert_element_type.default(wait_tensor_723, torch.float32); wait_tensor_723 = None + reduce_scatter_tensor_287 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2153, 'avg', 32, '0'); convert_element_type_2153 = None + wait_tensor_724 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_287); reduce_scatter_tensor_287 = None + add_268 = torch.ops.aten.add.Tensor(add_265, convert_element_type_2151); add_265 = convert_element_type_2151 = None + all_gather_into_tensor_396 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_268, 8, '1') + wait_tensor_725 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_396); all_gather_into_tensor_396 = None + split_219 = torch.ops.aten.split.Tensor(wait_tensor_725, 2); wait_tensor_725 = None + getitem_2100 = split_219[0] + getitem_2101 = split_219[1] + getitem_2102 = split_219[2] + getitem_2103 = split_219[3] + getitem_2104 = split_219[4] + getitem_2105 = split_219[5] + getitem_2106 = split_219[6] + getitem_2107 = split_219[7]; split_219 = None + cat_211 = torch.ops.aten.cat.default([getitem_2100, getitem_2101, getitem_2102, getitem_2103, getitem_2104, getitem_2105, getitem_2106, getitem_2107], 1); getitem_2100 = getitem_2101 = getitem_2102 = getitem_2103 = getitem_2104 = getitem_2105 = getitem_2106 = getitem_2107 = None + view_2803 = torch.ops.aten.view.default(cat_211, [16384, 4096]); cat_211 = None + permute_997 = torch.ops.aten.permute.default(view_2803, [1, 0]) + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = None + add_45 = torch.ops.aten.add.Tensor(add_43, wait_tensor_151); wait_tensor_151 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 32, '0'); convert_element_type_383 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32); add_45 = None + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_152) + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_385, 8, '1'); convert_element_type_385 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + split_55 = torch.ops.aten.split.Tensor(wait_tensor_153, 2); wait_tensor_153 = None + getitem_548 = split_55[0] + getitem_549 = split_55[1] + getitem_550 = split_55[2] + getitem_551 = split_55[3] + getitem_552 = split_55[4] + getitem_553 = split_55[5] + getitem_554 = split_55[6] + getitem_555 = split_55[7]; split_55 = None + cat_47 = torch.ops.aten.cat.default([getitem_548, getitem_549, getitem_550, getitem_551, getitem_552, getitem_553, getitem_554, getitem_555], 1); getitem_548 = getitem_549 = getitem_550 = getitem_551 = getitem_552 = getitem_553 = getitem_554 = getitem_555 = None + view_852 = torch.ops.aten.view.default(cat_47, [16384, 4096]); cat_47 = None + view_853 = torch.ops.aten.view.default(mm_81, [2, 8192, 1792]); mm_81 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_853, torch.float32); view_853 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16); primals_110 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 32, '0'); convert_element_type_391 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_82 = torch.ops.aten.mm.default(view_852, permute_130) + view_860 = torch.ops.aten.view.default(mm_82, [2, 8192, 1792]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_860) + view_867 = torch.ops.aten.view.default(mul_95, [16384, 1792]); mul_95 = None + mm_507 = torch.ops.aten.mm.default(permute_997, view_867); permute_997 = view_867 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 32, '0'); convert_element_type_394 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + permute_999 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_508 = torch.ops.aten.mm.default(view_2803, permute_999); view_2803 = permute_999 = None + view_2804 = torch.ops.aten.view.default(mm_508, [2, 8192, 1792]); mm_508 = None + convert_element_type_2158 = torch.ops.prims.convert_element_type.default(mm_507, torch.float32); mm_507 = None + reduce_scatter_tensor_288 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2158, 'avg', 32, '0'); convert_element_type_2158 = None + wait_tensor_726 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_288); reduce_scatter_tensor_288 = None + mul_664 = torch.ops.aten.mul.Tensor(view_2804, convert_element_type_390); convert_element_type_390 = None + mul_665 = torch.ops.aten.mul.Tensor(view_2804, view_860); view_2804 = view_860 = None + view_2805 = torch.ops.aten.view.default(mul_664, [16384, 1792]); mul_664 = None + permute_1001 = torch.ops.aten.permute.default(view_2805, [1, 0]) + mm_509 = torch.ops.aten.mm.default(permute_1001, view_852); permute_1001 = None + permute_1003 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_510 = torch.ops.aten.mm.default(view_2805, permute_1003); view_2805 = permute_1003 = None + view_2806 = torch.ops.aten.view.default(mm_510, [2, 8192, 4096]); mm_510 = None + convert_element_type_2163 = torch.ops.prims.convert_element_type.default(mm_509, torch.float32); mm_509 = None + reduce_scatter_tensor_289 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2163, 'avg', 32, '0'); convert_element_type_2163 = None + wait_tensor_727 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_289); reduce_scatter_tensor_289 = None + convert_element_type_2164 = torch.ops.prims.convert_element_type.default(mul_665, torch.float32); mul_665 = None + neg_20 = torch.ops.aten.neg.default(convert_element_type_389) + exp_20 = torch.ops.aten.exp.default(neg_20); neg_20 = None + add_269 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + reciprocal_20 = torch.ops.aten.reciprocal.default(add_269); add_269 = None + mul_666 = torch.ops.aten.mul.Tensor(reciprocal_20, 1); reciprocal_20 = None + mul_667 = torch.ops.aten.mul.Tensor(convert_element_type_2164, mul_666); convert_element_type_2164 = None + sub_62 = torch.ops.aten.sub.Tensor(1, mul_666); mul_666 = None + mul_668 = torch.ops.aten.mul.Tensor(convert_element_type_389, sub_62); convert_element_type_389 = sub_62 = None + add_270 = torch.ops.aten.add.Tensor(mul_668, 1); mul_668 = None + mul_669 = torch.ops.aten.mul.Tensor(mul_667, add_270); mul_667 = add_270 = None + convert_element_type_2166 = torch.ops.prims.convert_element_type.default(mul_669, torch.bfloat16); mul_669 = None + view_2807 = torch.ops.aten.view.default(convert_element_type_2166, [16384, 1792]); convert_element_type_2166 = None + permute_1005 = torch.ops.aten.permute.default(view_2807, [1, 0]) + mm_511 = torch.ops.aten.mm.default(permute_1005, view_852); permute_1005 = view_852 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 32, '0'); convert_element_type_386 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_154, [1, 0]); wait_tensor_154 = None + permute_1007 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_512 = torch.ops.aten.mm.default(view_2807, permute_1007); view_2807 = permute_1007 = None + view_2808 = torch.ops.aten.view.default(mm_512, [2, 8192, 4096]); mm_512 = None + add_271 = torch.ops.aten.add.Tensor(view_2806, view_2808); view_2806 = view_2808 = None + convert_element_type_2171 = torch.ops.prims.convert_element_type.default(mm_511, torch.float32); mm_511 = None + reduce_scatter_tensor_290 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2171, 'avg', 32, '0'); convert_element_type_2171 = None + wait_tensor_728 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_290); reduce_scatter_tensor_290 = None + split_220 = torch.ops.aten.split.Tensor(add_271, 1024, 1); add_271 = None + getitem_2108 = split_220[0] + getitem_2109 = split_220[1] + getitem_2110 = split_220[2] + getitem_2111 = split_220[3] + getitem_2112 = split_220[4] + getitem_2113 = split_220[5] + getitem_2114 = split_220[6] + getitem_2115 = split_220[7]; split_220 = None + cat_212 = torch.ops.aten.cat.default([getitem_2108, getitem_2109, getitem_2110, getitem_2111, getitem_2112, getitem_2113, getitem_2114, getitem_2115]); getitem_2108 = getitem_2109 = getitem_2110 = getitem_2111 = getitem_2112 = getitem_2113 = getitem_2114 = getitem_2115 = None + reduce_scatter_tensor_291 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_212, 'sum', 8, '1'); cat_212 = None + wait_tensor_729 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_291); reduce_scatter_tensor_291 = None + convert_element_type_2172 = torch.ops.prims.convert_element_type.default(wait_tensor_729, torch.float32); wait_tensor_729 = None + convert_element_type_2174 = torch.ops.prims.convert_element_type.default(wait_tensor_152, torch.float32); wait_tensor_152 = None + mul_670 = torch.ops.aten.mul.Tensor(convert_element_type_2172, convert_element_type_2174); convert_element_type_2174 = None + mul_672 = torch.ops.aten.mul.Tensor(mul_92, mul_670) + sum_123 = torch.ops.aten.sum.dim_IntList(mul_672, [2], True); mul_672 = None + div_41 = torch.ops.aten.div.Tensor(mul_92, 4096) + mul_673 = torch.ops.aten.mul.Tensor(div_41, sum_123); div_41 = sum_123 = None + sub_63 = torch.ops.aten.sub.Tensor(mul_670, mul_673); mul_670 = mul_673 = None + mul_674 = torch.ops.aten.mul.Tensor(sub_63, rsqrt_23); sub_63 = rsqrt_23 = None + mul_675 = torch.ops.aten.mul.Tensor(convert_element_type_2172, mul_92); convert_element_type_2172 = mul_92 = None + sum_124 = torch.ops.aten.sum.dim_IntList(mul_675, [0, 1]); mul_675 = None + convert_element_type_2175 = torch.ops.prims.convert_element_type.default(mul_674, torch.bfloat16); mul_674 = None + convert_element_type_2176 = torch.ops.prims.convert_element_type.default(sum_124, torch.bfloat16); sum_124 = None + all_reduce_41 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2176, 'sum', '1'); convert_element_type_2176 = None + wait_tensor_730 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_41); all_reduce_41 = None + convert_element_type_2177 = torch.ops.prims.convert_element_type.default(wait_tensor_730, torch.float32); wait_tensor_730 = None + reduce_scatter_tensor_292 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2177, 'avg', 32, '0'); convert_element_type_2177 = None + wait_tensor_731 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_292); reduce_scatter_tensor_292 = None + add_272 = torch.ops.aten.add.Tensor(add_268, convert_element_type_2175); add_268 = convert_element_type_2175 = None + all_gather_into_tensor_397 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_272, 8, '1') + wait_tensor_732 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_397); all_gather_into_tensor_397 = None + split_221 = torch.ops.aten.split.Tensor(wait_tensor_732, 2); wait_tensor_732 = None + getitem_2116 = split_221[0] + getitem_2117 = split_221[1] + getitem_2118 = split_221[2] + getitem_2119 = split_221[3] + getitem_2120 = split_221[4] + getitem_2121 = split_221[5] + getitem_2122 = split_221[6] + getitem_2123 = split_221[7]; split_221 = None + cat_213 = torch.ops.aten.cat.default([getitem_2116, getitem_2117, getitem_2118, getitem_2119, getitem_2120, getitem_2121, getitem_2122, getitem_2123], 1); getitem_2116 = getitem_2117 = getitem_2118 = getitem_2119 = getitem_2120 = getitem_2121 = getitem_2122 = getitem_2123 = None + view_2809 = torch.ops.aten.view.default(cat_213, [16384, 4096]); cat_213 = None + permute_1009 = torch.ops.aten.permute.default(view_2809, [1, 0]) + permute_127 = torch.ops.aten.permute.default(getitem_531, [0, 2, 1, 3]) + view_834 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + view_840 = torch.ops.aten.view.default(view_834, [16384, 512]); view_834 = None + mm_513 = torch.ops.aten.mm.default(permute_1009, view_840); permute_1009 = view_840 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 32, '0'); convert_element_type_380 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_150, [1, 0]); wait_tensor_150 = None + permute_1011 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_514 = torch.ops.aten.mm.default(view_2809, permute_1011); view_2809 = permute_1011 = None + view_2810 = torch.ops.aten.view.default(mm_514, [2, 8192, 512]); mm_514 = None + convert_element_type_2182 = torch.ops.prims.convert_element_type.default(mm_513, torch.float32); mm_513 = None + reduce_scatter_tensor_293 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2182, 'avg', 32, '0'); convert_element_type_2182 = None + wait_tensor_733 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_293); reduce_scatter_tensor_293 = None + view_2811 = torch.ops.aten.view.default(view_2810, [2, 8192, 4, 128]); view_2810 = None + permute_1013 = torch.ops.aten.permute.default(view_2811, [0, 2, 1, 3]); view_2811 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 32, '0'); convert_element_type_364 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32); add_43 = None + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_145) + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_366, 8, '1'); convert_element_type_366 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + split_53 = torch.ops.aten.split.Tensor(wait_tensor_146, 2); wait_tensor_146 = None + getitem_523 = split_53[0] + getitem_524 = split_53[1] + getitem_525 = split_53[2] + getitem_526 = split_53[3] + getitem_527 = split_53[4] + getitem_528 = split_53[5] + getitem_529 = split_53[6] + getitem_530 = split_53[7]; split_53 = None + cat_45 = torch.ops.aten.cat.default([getitem_523, getitem_524, getitem_525, getitem_526, getitem_527, getitem_528, getitem_529, getitem_530], 1); getitem_523 = getitem_524 = getitem_525 = getitem_526 = getitem_527 = getitem_528 = getitem_529 = getitem_530 = None + view_807 = torch.ops.aten.view.default(cat_45, [16384, 4096]); cat_45 = None + view_808 = torch.ops.aten.view.default(mm_77, [2, 8192, 512]); mm_77 = None + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 32, '0'); convert_element_type_370 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_78 = torch.ops.aten.mm.default(view_807, permute_122) + view_815 = torch.ops.aten.view.default(mm_78, [2, 8192, 128]); mm_78 = None + view_822 = torch.ops.aten.view.default(mm_79, [2, 8192, 128]); mm_79 = None + view_824 = torch.ops.aten.view.default(view_808, [2, 8192, -1, 128]); view_808 = None + view_825 = torch.ops.aten.view.default(view_815, [2, 8192, -1, 128]); view_815 = None + view_826 = torch.ops.aten.view.default(view_822, [2, 8192, -1, 128]); view_822 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_824, torch.float32); view_824 = None + view_827 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 4, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_827); view_827 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_825, torch.float32); view_825 = None + view_828 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 1, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_828); view_828 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_37); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_830 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 4, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_37); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_831 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 1, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_830, torch.bfloat16); view_830 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_831, torch.bfloat16); view_831 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 1, 4, 128]); unsqueeze_22 = None + view_832 = torch.ops.aten.view.default(expand_22, [2, 8192, 4, 128]); expand_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_826, 3); view_826 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 1, 4, 128]); unsqueeze_23 = None + view_833 = torch.ops.aten.view.default(expand_23, [2, 8192, 4, 128]); expand_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_832, [0, 2, 1, 3]); view_832 = None + permute_126 = torch.ops.aten.permute.default(view_833, [0, 2, 1, 3]); view_833 = None + _scaled_dot_product_cudnn_attention_backward_20 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1013, permute_124, permute_125, permute_126, getitem_531, getitem_532, getitem_537, getitem_538, None, None, None, 8192, 8192, 0.0, True); permute_1013 = permute_124 = permute_125 = permute_126 = getitem_531 = getitem_532 = getitem_537 = getitem_538 = None + getitem_2124 = _scaled_dot_product_cudnn_attention_backward_20[0] + getitem_2125 = _scaled_dot_product_cudnn_attention_backward_20[1] + getitem_2126 = _scaled_dot_product_cudnn_attention_backward_20[2]; _scaled_dot_product_cudnn_attention_backward_20 = None + permute_1014 = torch.ops.aten.permute.default(getitem_2126, [0, 2, 1, 3]); getitem_2126 = None + permute_1015 = torch.ops.aten.permute.default(getitem_2125, [0, 2, 1, 3]); getitem_2125 = None + permute_1016 = torch.ops.aten.permute.default(getitem_2124, [0, 2, 1, 3]); getitem_2124 = None + view_2812 = torch.ops.aten.view.default(permute_1014, [2, 8192, 1, 4, 128]); permute_1014 = None + sum_125 = torch.ops.aten.sum.dim_IntList(view_2812, [3], True); view_2812 = None + squeeze_40 = torch.ops.aten.squeeze.dim(sum_125, 3); sum_125 = None + view_2813 = torch.ops.aten.view.default(permute_1015, [2, 8192, 1, 4, 128]); permute_1015 = None + sum_126 = torch.ops.aten.sum.dim_IntList(view_2813, [3], True); view_2813 = None + squeeze_41 = torch.ops.aten.squeeze.dim(sum_126, 3); sum_126 = None + convert_element_type_2183 = torch.ops.prims.convert_element_type.default(squeeze_41, torch.float32); squeeze_41 = None + convert_element_type_2184 = torch.ops.prims.convert_element_type.default(permute_1016, torch.float32); permute_1016 = None + view_2814 = torch.ops.aten.view.default(convert_element_type_2183, [2, 8192, 1, 64, 2]); convert_element_type_2183 = None + view_as_complex_104 = torch.ops.aten.view_as_complex.default(view_2814); view_2814 = None + mul_676 = torch.ops.aten.mul.Tensor(view_as_complex_104, _conj); view_as_complex_104 = None + view_2815 = torch.ops.aten.view.default(convert_element_type_2184, [2, 8192, 4, 64, 2]); convert_element_type_2184 = None + view_as_complex_105 = torch.ops.aten.view_as_complex.default(view_2815); view_2815 = None + mul_677 = torch.ops.aten.mul.Tensor(view_as_complex_105, _conj); view_as_complex_105 = None + view_as_real_104 = torch.ops.aten.view_as_real.default(mul_676); mul_676 = None + view_2816 = torch.ops.aten.view.default(view_as_real_104, [2, 8192, 1, 128]); view_as_real_104 = None + convert_element_type_2185 = torch.ops.prims.convert_element_type.default(view_2816, torch.bfloat16); view_2816 = None + view_as_real_105 = torch.ops.aten.view_as_real.default(mul_677); mul_677 = None + view_2817 = torch.ops.aten.view.default(view_as_real_105, [2, 8192, 4, 128]); view_as_real_105 = None + convert_element_type_2186 = torch.ops.prims.convert_element_type.default(view_2817, torch.bfloat16); view_2817 = None + view_2818 = torch.ops.aten.view.default(squeeze_40, [2, 8192, 128]); squeeze_40 = None + view_2819 = torch.ops.aten.view.default(convert_element_type_2185, [2, 8192, 128]); convert_element_type_2185 = None + view_2820 = torch.ops.aten.view.default(convert_element_type_2186, [2, 8192, 512]); convert_element_type_2186 = None + view_2821 = torch.ops.aten.view.default(view_2818, [16384, 128]); view_2818 = None + permute_1017 = torch.ops.aten.permute.default(view_2821, [1, 0]) + mm_515 = torch.ops.aten.mm.default(permute_1017, view_807); permute_1017 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 32, '0'); convert_element_type_373 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + permute_1019 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_516 = torch.ops.aten.mm.default(view_2821, permute_1019); view_2821 = permute_1019 = None + view_2822 = torch.ops.aten.view.default(mm_516, [2, 8192, 4096]); mm_516 = None + convert_element_type_2191 = torch.ops.prims.convert_element_type.default(mm_515, torch.float32); mm_515 = None + reduce_scatter_tensor_294 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2191, 'avg', 32, '0'); convert_element_type_2191 = None + wait_tensor_734 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_294); reduce_scatter_tensor_294 = None + view_2823 = torch.ops.aten.view.default(view_2819, [16384, 128]); view_2819 = None + permute_1021 = torch.ops.aten.permute.default(view_2823, [1, 0]) + mm_517 = torch.ops.aten.mm.default(permute_1021, view_807); permute_1021 = None + permute_1023 = torch.ops.aten.permute.default(permute_122, [1, 0]); permute_122 = None + mm_518 = torch.ops.aten.mm.default(view_2823, permute_1023); view_2823 = permute_1023 = None + view_2824 = torch.ops.aten.view.default(mm_518, [2, 8192, 4096]); mm_518 = None + add_273 = torch.ops.aten.add.Tensor(view_2822, view_2824); view_2822 = view_2824 = None + convert_element_type_2196 = torch.ops.prims.convert_element_type.default(mm_517, torch.float32); mm_517 = None + reduce_scatter_tensor_295 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2196, 'avg', 32, '0'); convert_element_type_2196 = None + wait_tensor_735 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_295); reduce_scatter_tensor_295 = None + view_2825 = torch.ops.aten.view.default(view_2820, [16384, 512]); view_2820 = None + permute_1025 = torch.ops.aten.permute.default(view_2825, [1, 0]) + mm_519 = torch.ops.aten.mm.default(permute_1025, view_807); permute_1025 = view_807 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 32, '0'); convert_element_type_367 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + permute_1027 = torch.ops.aten.permute.default(permute_121, [1, 0]); permute_121 = None + mm_520 = torch.ops.aten.mm.default(view_2825, permute_1027); view_2825 = permute_1027 = None + view_2826 = torch.ops.aten.view.default(mm_520, [2, 8192, 4096]); mm_520 = None + add_274 = torch.ops.aten.add.Tensor(add_273, view_2826); add_273 = view_2826 = None + convert_element_type_2201 = torch.ops.prims.convert_element_type.default(mm_519, torch.float32); mm_519 = None + reduce_scatter_tensor_296 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2201, 'avg', 32, '0'); convert_element_type_2201 = None + wait_tensor_736 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_296); reduce_scatter_tensor_296 = None + split_222 = torch.ops.aten.split.Tensor(add_274, 1024, 1); add_274 = None + getitem_2127 = split_222[0] + getitem_2128 = split_222[1] + getitem_2129 = split_222[2] + getitem_2130 = split_222[3] + getitem_2131 = split_222[4] + getitem_2132 = split_222[5] + getitem_2133 = split_222[6] + getitem_2134 = split_222[7]; split_222 = None + cat_214 = torch.ops.aten.cat.default([getitem_2127, getitem_2128, getitem_2129, getitem_2130, getitem_2131, getitem_2132, getitem_2133, getitem_2134]); getitem_2127 = getitem_2128 = getitem_2129 = getitem_2130 = getitem_2131 = getitem_2132 = getitem_2133 = getitem_2134 = None + reduce_scatter_tensor_297 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_214, 'sum', 8, '1'); cat_214 = None + wait_tensor_737 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_297); reduce_scatter_tensor_297 = None + convert_element_type_2202 = torch.ops.prims.convert_element_type.default(wait_tensor_737, torch.float32); wait_tensor_737 = None + convert_element_type_2204 = torch.ops.prims.convert_element_type.default(wait_tensor_145, torch.float32); wait_tensor_145 = None + mul_678 = torch.ops.aten.mul.Tensor(convert_element_type_2202, convert_element_type_2204); convert_element_type_2204 = None + mul_680 = torch.ops.aten.mul.Tensor(mul_88, mul_678) + sum_127 = torch.ops.aten.sum.dim_IntList(mul_680, [2], True); mul_680 = None + div_42 = torch.ops.aten.div.Tensor(mul_88, 4096) + mul_681 = torch.ops.aten.mul.Tensor(div_42, sum_127); div_42 = sum_127 = None + sub_64 = torch.ops.aten.sub.Tensor(mul_678, mul_681); mul_678 = mul_681 = None + mul_682 = torch.ops.aten.mul.Tensor(sub_64, rsqrt_22); sub_64 = rsqrt_22 = None + mul_683 = torch.ops.aten.mul.Tensor(convert_element_type_2202, mul_88); convert_element_type_2202 = mul_88 = None + sum_128 = torch.ops.aten.sum.dim_IntList(mul_683, [0, 1]); mul_683 = None + convert_element_type_2205 = torch.ops.prims.convert_element_type.default(mul_682, torch.bfloat16); mul_682 = None + convert_element_type_2206 = torch.ops.prims.convert_element_type.default(sum_128, torch.bfloat16); sum_128 = None + all_reduce_42 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2206, 'sum', '1'); convert_element_type_2206 = None + wait_tensor_738 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_42); all_reduce_42 = None + convert_element_type_2207 = torch.ops.prims.convert_element_type.default(wait_tensor_738, torch.float32); wait_tensor_738 = None + reduce_scatter_tensor_298 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2207, 'avg', 32, '0'); convert_element_type_2207 = None + wait_tensor_739 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_298); reduce_scatter_tensor_298 = None + add_275 = torch.ops.aten.add.Tensor(add_272, convert_element_type_2205); add_272 = convert_element_type_2205 = None + all_gather_into_tensor_398 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_275, 8, '1') + wait_tensor_740 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_398); all_gather_into_tensor_398 = None + split_223 = torch.ops.aten.split.Tensor(wait_tensor_740, 2); wait_tensor_740 = None + getitem_2135 = split_223[0] + getitem_2136 = split_223[1] + getitem_2137 = split_223[2] + getitem_2138 = split_223[3] + getitem_2139 = split_223[4] + getitem_2140 = split_223[5] + getitem_2141 = split_223[6] + getitem_2142 = split_223[7]; split_223 = None + cat_215 = torch.ops.aten.cat.default([getitem_2135, getitem_2136, getitem_2137, getitem_2138, getitem_2139, getitem_2140, getitem_2141, getitem_2142], 1); getitem_2135 = getitem_2136 = getitem_2137 = getitem_2138 = getitem_2139 = getitem_2140 = getitem_2141 = getitem_2142 = None + view_2827 = torch.ops.aten.view.default(cat_215, [16384, 4096]); cat_215 = None + permute_1029 = torch.ops.aten.permute.default(view_2827, [1, 0]) + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + add_41 = torch.ops.aten.add.Tensor(add_39, wait_tensor_138); wait_tensor_138 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16); primals_99 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 32, '0'); convert_element_type_350 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32); add_41 = None + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_139) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_352, 8, '1'); convert_element_type_352 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + split_51 = torch.ops.aten.split.Tensor(wait_tensor_140, 2); wait_tensor_140 = None + getitem_507 = split_51[0] + getitem_508 = split_51[1] + getitem_509 = split_51[2] + getitem_510 = split_51[3] + getitem_511 = split_51[4] + getitem_512 = split_51[5] + getitem_513 = split_51[6] + getitem_514 = split_51[7]; split_51 = None + cat_43 = torch.ops.aten.cat.default([getitem_507, getitem_508, getitem_509, getitem_510, getitem_511, getitem_512, getitem_513, getitem_514], 1); getitem_507 = getitem_508 = getitem_509 = getitem_510 = getitem_511 = getitem_512 = getitem_513 = getitem_514 = None + view_780 = torch.ops.aten.view.default(cat_43, [16384, 4096]); cat_43 = None + view_781 = torch.ops.aten.view.default(mm_74, [2, 8192, 1792]); mm_74 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_781, torch.float32); view_781 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 32, '0'); convert_element_type_358 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + mm_75 = torch.ops.aten.mm.default(view_780, permute_119) + view_788 = torch.ops.aten.view.default(mm_75, [2, 8192, 1792]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_788) + view_795 = torch.ops.aten.view.default(mul_87, [16384, 1792]); mul_87 = None + mm_521 = torch.ops.aten.mm.default(permute_1029, view_795); permute_1029 = view_795 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 32, '0'); convert_element_type_361 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + permute_1031 = torch.ops.aten.permute.default(permute_120, [1, 0]); permute_120 = None + mm_522 = torch.ops.aten.mm.default(view_2827, permute_1031); view_2827 = permute_1031 = None + view_2828 = torch.ops.aten.view.default(mm_522, [2, 8192, 1792]); mm_522 = None + convert_element_type_2212 = torch.ops.prims.convert_element_type.default(mm_521, torch.float32); mm_521 = None + reduce_scatter_tensor_299 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2212, 'avg', 32, '0'); convert_element_type_2212 = None + wait_tensor_741 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_299); reduce_scatter_tensor_299 = None + mul_684 = torch.ops.aten.mul.Tensor(view_2828, convert_element_type_357); convert_element_type_357 = None + mul_685 = torch.ops.aten.mul.Tensor(view_2828, view_788); view_2828 = view_788 = None + view_2829 = torch.ops.aten.view.default(mul_684, [16384, 1792]); mul_684 = None + permute_1033 = torch.ops.aten.permute.default(view_2829, [1, 0]) + mm_523 = torch.ops.aten.mm.default(permute_1033, view_780); permute_1033 = None + permute_1035 = torch.ops.aten.permute.default(permute_119, [1, 0]); permute_119 = None + mm_524 = torch.ops.aten.mm.default(view_2829, permute_1035); view_2829 = permute_1035 = None + view_2830 = torch.ops.aten.view.default(mm_524, [2, 8192, 4096]); mm_524 = None + convert_element_type_2217 = torch.ops.prims.convert_element_type.default(mm_523, torch.float32); mm_523 = None + reduce_scatter_tensor_300 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2217, 'avg', 32, '0'); convert_element_type_2217 = None + wait_tensor_742 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_300); reduce_scatter_tensor_300 = None + convert_element_type_2218 = torch.ops.prims.convert_element_type.default(mul_685, torch.float32); mul_685 = None + neg_21 = torch.ops.aten.neg.default(convert_element_type_356) + exp_21 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_276 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + reciprocal_21 = torch.ops.aten.reciprocal.default(add_276); add_276 = None + mul_686 = torch.ops.aten.mul.Tensor(reciprocal_21, 1); reciprocal_21 = None + mul_687 = torch.ops.aten.mul.Tensor(convert_element_type_2218, mul_686); convert_element_type_2218 = None + sub_65 = torch.ops.aten.sub.Tensor(1, mul_686); mul_686 = None + mul_688 = torch.ops.aten.mul.Tensor(convert_element_type_356, sub_65); convert_element_type_356 = sub_65 = None + add_277 = torch.ops.aten.add.Tensor(mul_688, 1); mul_688 = None + mul_689 = torch.ops.aten.mul.Tensor(mul_687, add_277); mul_687 = add_277 = None + convert_element_type_2220 = torch.ops.prims.convert_element_type.default(mul_689, torch.bfloat16); mul_689 = None + view_2831 = torch.ops.aten.view.default(convert_element_type_2220, [16384, 1792]); convert_element_type_2220 = None + permute_1037 = torch.ops.aten.permute.default(view_2831, [1, 0]) + mm_525 = torch.ops.aten.mm.default(permute_1037, view_780); permute_1037 = view_780 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 32, '0'); convert_element_type_353 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + permute_1039 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_526 = torch.ops.aten.mm.default(view_2831, permute_1039); view_2831 = permute_1039 = None + view_2832 = torch.ops.aten.view.default(mm_526, [2, 8192, 4096]); mm_526 = None + add_278 = torch.ops.aten.add.Tensor(view_2830, view_2832); view_2830 = view_2832 = None + convert_element_type_2225 = torch.ops.prims.convert_element_type.default(mm_525, torch.float32); mm_525 = None + reduce_scatter_tensor_301 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2225, 'avg', 32, '0'); convert_element_type_2225 = None + wait_tensor_743 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_301); reduce_scatter_tensor_301 = None + split_224 = torch.ops.aten.split.Tensor(add_278, 1024, 1); add_278 = None + getitem_2143 = split_224[0] + getitem_2144 = split_224[1] + getitem_2145 = split_224[2] + getitem_2146 = split_224[3] + getitem_2147 = split_224[4] + getitem_2148 = split_224[5] + getitem_2149 = split_224[6] + getitem_2150 = split_224[7]; split_224 = None + cat_216 = torch.ops.aten.cat.default([getitem_2143, getitem_2144, getitem_2145, getitem_2146, getitem_2147, getitem_2148, getitem_2149, getitem_2150]); getitem_2143 = getitem_2144 = getitem_2145 = getitem_2146 = getitem_2147 = getitem_2148 = getitem_2149 = getitem_2150 = None + reduce_scatter_tensor_302 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_216, 'sum', 8, '1'); cat_216 = None + wait_tensor_744 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_302); reduce_scatter_tensor_302 = None + convert_element_type_2226 = torch.ops.prims.convert_element_type.default(wait_tensor_744, torch.float32); wait_tensor_744 = None + convert_element_type_2228 = torch.ops.prims.convert_element_type.default(wait_tensor_139, torch.float32); wait_tensor_139 = None + mul_690 = torch.ops.aten.mul.Tensor(convert_element_type_2226, convert_element_type_2228); convert_element_type_2228 = None + mul_692 = torch.ops.aten.mul.Tensor(mul_84, mul_690) + sum_129 = torch.ops.aten.sum.dim_IntList(mul_692, [2], True); mul_692 = None + div_43 = torch.ops.aten.div.Tensor(mul_84, 4096) + mul_693 = torch.ops.aten.mul.Tensor(div_43, sum_129); div_43 = sum_129 = None + sub_66 = torch.ops.aten.sub.Tensor(mul_690, mul_693); mul_690 = mul_693 = None + mul_694 = torch.ops.aten.mul.Tensor(sub_66, rsqrt_21); sub_66 = rsqrt_21 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_2226, mul_84); convert_element_type_2226 = mul_84 = None + sum_130 = torch.ops.aten.sum.dim_IntList(mul_695, [0, 1]); mul_695 = None + convert_element_type_2229 = torch.ops.prims.convert_element_type.default(mul_694, torch.bfloat16); mul_694 = None + convert_element_type_2230 = torch.ops.prims.convert_element_type.default(sum_130, torch.bfloat16); sum_130 = None + all_reduce_43 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2230, 'sum', '1'); convert_element_type_2230 = None + wait_tensor_745 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_43); all_reduce_43 = None + convert_element_type_2231 = torch.ops.prims.convert_element_type.default(wait_tensor_745, torch.float32); wait_tensor_745 = None + reduce_scatter_tensor_303 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2231, 'avg', 32, '0'); convert_element_type_2231 = None + wait_tensor_746 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_303); reduce_scatter_tensor_303 = None + add_279 = torch.ops.aten.add.Tensor(add_275, convert_element_type_2229); add_275 = convert_element_type_2229 = None + all_gather_into_tensor_399 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_279, 8, '1') + wait_tensor_747 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_399); all_gather_into_tensor_399 = None + split_225 = torch.ops.aten.split.Tensor(wait_tensor_747, 2); wait_tensor_747 = None + getitem_2151 = split_225[0] + getitem_2152 = split_225[1] + getitem_2153 = split_225[2] + getitem_2154 = split_225[3] + getitem_2155 = split_225[4] + getitem_2156 = split_225[5] + getitem_2157 = split_225[6] + getitem_2158 = split_225[7]; split_225 = None + cat_217 = torch.ops.aten.cat.default([getitem_2151, getitem_2152, getitem_2153, getitem_2154, getitem_2155, getitem_2156, getitem_2157, getitem_2158], 1); getitem_2151 = getitem_2152 = getitem_2153 = getitem_2154 = getitem_2155 = getitem_2156 = getitem_2157 = getitem_2158 = None + view_2833 = torch.ops.aten.view.default(cat_217, [16384, 4096]); cat_217 = None + permute_1041 = torch.ops.aten.permute.default(view_2833, [1, 0]) + permute_116 = torch.ops.aten.permute.default(getitem_490, [0, 2, 1, 3]) + view_762 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + view_768 = torch.ops.aten.view.default(view_762, [16384, 512]); view_762 = None + mm_527 = torch.ops.aten.mm.default(permute_1041, view_768); permute_1041 = view_768 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16); primals_98 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 32, '0'); convert_element_type_347 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + permute_1043 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_528 = torch.ops.aten.mm.default(view_2833, permute_1043); view_2833 = permute_1043 = None + view_2834 = torch.ops.aten.view.default(mm_528, [2, 8192, 512]); mm_528 = None + convert_element_type_2236 = torch.ops.prims.convert_element_type.default(mm_527, torch.float32); mm_527 = None + reduce_scatter_tensor_304 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2236, 'avg', 32, '0'); convert_element_type_2236 = None + wait_tensor_748 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_304); reduce_scatter_tensor_304 = None + view_2835 = torch.ops.aten.view.default(view_2834, [2, 8192, 4, 128]); view_2834 = None + permute_1045 = torch.ops.aten.permute.default(view_2835, [0, 2, 1, 3]); view_2835 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16); primals_94 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 32, '0'); convert_element_type_331 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32); add_39 = None + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_132) + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_333, 8, '1'); convert_element_type_333 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + split_49 = torch.ops.aten.split.Tensor(wait_tensor_133, 2); wait_tensor_133 = None + getitem_482 = split_49[0] + getitem_483 = split_49[1] + getitem_484 = split_49[2] + getitem_485 = split_49[3] + getitem_486 = split_49[4] + getitem_487 = split_49[5] + getitem_488 = split_49[6] + getitem_489 = split_49[7]; split_49 = None + cat_41 = torch.ops.aten.cat.default([getitem_482, getitem_483, getitem_484, getitem_485, getitem_486, getitem_487, getitem_488, getitem_489], 1); getitem_482 = getitem_483 = getitem_484 = getitem_485 = getitem_486 = getitem_487 = getitem_488 = getitem_489 = None + view_735 = torch.ops.aten.view.default(cat_41, [16384, 4096]); cat_41 = None + view_736 = torch.ops.aten.view.default(mm_70, [2, 8192, 512]); mm_70 = None + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16); primals_96 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 32, '0'); convert_element_type_337 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_71 = torch.ops.aten.mm.default(view_735, permute_111) + view_743 = torch.ops.aten.view.default(mm_71, [2, 8192, 128]); mm_71 = None + view_750 = torch.ops.aten.view.default(mm_72, [2, 8192, 128]); mm_72 = None + view_752 = torch.ops.aten.view.default(view_736, [2, 8192, -1, 128]); view_736 = None + view_753 = torch.ops.aten.view.default(view_743, [2, 8192, -1, 128]); view_743 = None + view_754 = torch.ops.aten.view.default(view_750, [2, 8192, -1, 128]); view_750 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_752, torch.float32); view_752 = None + view_755 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 4, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_755); view_755 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_753, torch.float32); view_753 = None + view_756 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 1, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_756); view_756 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_37); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_758 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 4, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_37); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_759 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 1, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_758, torch.bfloat16); view_758 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_759, torch.bfloat16); view_759 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 1, 4, 128]); unsqueeze_20 = None + view_760 = torch.ops.aten.view.default(expand_20, [2, 8192, 4, 128]); expand_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_754, 3); view_754 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 1, 4, 128]); unsqueeze_21 = None + view_761 = torch.ops.aten.view.default(expand_21, [2, 8192, 4, 128]); expand_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_760, [0, 2, 1, 3]); view_760 = None + permute_115 = torch.ops.aten.permute.default(view_761, [0, 2, 1, 3]); view_761 = None + _scaled_dot_product_cudnn_attention_backward_21 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1045, permute_113, permute_114, permute_115, getitem_490, getitem_491, getitem_496, getitem_497, None, None, None, 8192, 8192, 0.0, True); permute_1045 = permute_113 = permute_114 = permute_115 = getitem_490 = getitem_491 = getitem_496 = getitem_497 = None + getitem_2159 = _scaled_dot_product_cudnn_attention_backward_21[0] + getitem_2160 = _scaled_dot_product_cudnn_attention_backward_21[1] + getitem_2161 = _scaled_dot_product_cudnn_attention_backward_21[2]; _scaled_dot_product_cudnn_attention_backward_21 = None + permute_1046 = torch.ops.aten.permute.default(getitem_2161, [0, 2, 1, 3]); getitem_2161 = None + permute_1047 = torch.ops.aten.permute.default(getitem_2160, [0, 2, 1, 3]); getitem_2160 = None + permute_1048 = torch.ops.aten.permute.default(getitem_2159, [0, 2, 1, 3]); getitem_2159 = None + view_2836 = torch.ops.aten.view.default(permute_1046, [2, 8192, 1, 4, 128]); permute_1046 = None + sum_131 = torch.ops.aten.sum.dim_IntList(view_2836, [3], True); view_2836 = None + squeeze_42 = torch.ops.aten.squeeze.dim(sum_131, 3); sum_131 = None + view_2837 = torch.ops.aten.view.default(permute_1047, [2, 8192, 1, 4, 128]); permute_1047 = None + sum_132 = torch.ops.aten.sum.dim_IntList(view_2837, [3], True); view_2837 = None + squeeze_43 = torch.ops.aten.squeeze.dim(sum_132, 3); sum_132 = None + convert_element_type_2237 = torch.ops.prims.convert_element_type.default(squeeze_43, torch.float32); squeeze_43 = None + convert_element_type_2238 = torch.ops.prims.convert_element_type.default(permute_1048, torch.float32); permute_1048 = None + view_2838 = torch.ops.aten.view.default(convert_element_type_2237, [2, 8192, 1, 64, 2]); convert_element_type_2237 = None + view_as_complex_106 = torch.ops.aten.view_as_complex.default(view_2838); view_2838 = None + mul_696 = torch.ops.aten.mul.Tensor(view_as_complex_106, _conj); view_as_complex_106 = None + view_2839 = torch.ops.aten.view.default(convert_element_type_2238, [2, 8192, 4, 64, 2]); convert_element_type_2238 = None + view_as_complex_107 = torch.ops.aten.view_as_complex.default(view_2839); view_2839 = None + mul_697 = torch.ops.aten.mul.Tensor(view_as_complex_107, _conj); view_as_complex_107 = None + view_as_real_106 = torch.ops.aten.view_as_real.default(mul_696); mul_696 = None + view_2840 = torch.ops.aten.view.default(view_as_real_106, [2, 8192, 1, 128]); view_as_real_106 = None + convert_element_type_2239 = torch.ops.prims.convert_element_type.default(view_2840, torch.bfloat16); view_2840 = None + view_as_real_107 = torch.ops.aten.view_as_real.default(mul_697); mul_697 = None + view_2841 = torch.ops.aten.view.default(view_as_real_107, [2, 8192, 4, 128]); view_as_real_107 = None + convert_element_type_2240 = torch.ops.prims.convert_element_type.default(view_2841, torch.bfloat16); view_2841 = None + view_2842 = torch.ops.aten.view.default(squeeze_42, [2, 8192, 128]); squeeze_42 = None + view_2843 = torch.ops.aten.view.default(convert_element_type_2239, [2, 8192, 128]); convert_element_type_2239 = None + view_2844 = torch.ops.aten.view.default(convert_element_type_2240, [2, 8192, 512]); convert_element_type_2240 = None + view_2845 = torch.ops.aten.view.default(view_2842, [16384, 128]); view_2842 = None + permute_1049 = torch.ops.aten.permute.default(view_2845, [1, 0]) + mm_529 = torch.ops.aten.mm.default(permute_1049, view_735); permute_1049 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16); primals_97 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 32, '0'); convert_element_type_340 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + permute_1051 = torch.ops.aten.permute.default(permute_112, [1, 0]); permute_112 = None + mm_530 = torch.ops.aten.mm.default(view_2845, permute_1051); view_2845 = permute_1051 = None + view_2846 = torch.ops.aten.view.default(mm_530, [2, 8192, 4096]); mm_530 = None + convert_element_type_2245 = torch.ops.prims.convert_element_type.default(mm_529, torch.float32); mm_529 = None + reduce_scatter_tensor_305 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2245, 'avg', 32, '0'); convert_element_type_2245 = None + wait_tensor_749 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_305); reduce_scatter_tensor_305 = None + view_2847 = torch.ops.aten.view.default(view_2843, [16384, 128]); view_2843 = None + permute_1053 = torch.ops.aten.permute.default(view_2847, [1, 0]) + mm_531 = torch.ops.aten.mm.default(permute_1053, view_735); permute_1053 = None + permute_1055 = torch.ops.aten.permute.default(permute_111, [1, 0]); permute_111 = None + mm_532 = torch.ops.aten.mm.default(view_2847, permute_1055); view_2847 = permute_1055 = None + view_2848 = torch.ops.aten.view.default(mm_532, [2, 8192, 4096]); mm_532 = None + add_280 = torch.ops.aten.add.Tensor(view_2846, view_2848); view_2846 = view_2848 = None + convert_element_type_2250 = torch.ops.prims.convert_element_type.default(mm_531, torch.float32); mm_531 = None + reduce_scatter_tensor_306 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2250, 'avg', 32, '0'); convert_element_type_2250 = None + wait_tensor_750 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_306); reduce_scatter_tensor_306 = None + view_2849 = torch.ops.aten.view.default(view_2844, [16384, 512]); view_2844 = None + permute_1057 = torch.ops.aten.permute.default(view_2849, [1, 0]) + mm_533 = torch.ops.aten.mm.default(permute_1057, view_735); permute_1057 = view_735 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 32, '0'); convert_element_type_334 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + permute_1059 = torch.ops.aten.permute.default(permute_110, [1, 0]); permute_110 = None + mm_534 = torch.ops.aten.mm.default(view_2849, permute_1059); view_2849 = permute_1059 = None + view_2850 = torch.ops.aten.view.default(mm_534, [2, 8192, 4096]); mm_534 = None + add_281 = torch.ops.aten.add.Tensor(add_280, view_2850); add_280 = view_2850 = None + convert_element_type_2255 = torch.ops.prims.convert_element_type.default(mm_533, torch.float32); mm_533 = None + reduce_scatter_tensor_307 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2255, 'avg', 32, '0'); convert_element_type_2255 = None + wait_tensor_751 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_307); reduce_scatter_tensor_307 = None + split_226 = torch.ops.aten.split.Tensor(add_281, 1024, 1); add_281 = None + getitem_2162 = split_226[0] + getitem_2163 = split_226[1] + getitem_2164 = split_226[2] + getitem_2165 = split_226[3] + getitem_2166 = split_226[4] + getitem_2167 = split_226[5] + getitem_2168 = split_226[6] + getitem_2169 = split_226[7]; split_226 = None + cat_218 = torch.ops.aten.cat.default([getitem_2162, getitem_2163, getitem_2164, getitem_2165, getitem_2166, getitem_2167, getitem_2168, getitem_2169]); getitem_2162 = getitem_2163 = getitem_2164 = getitem_2165 = getitem_2166 = getitem_2167 = getitem_2168 = getitem_2169 = None + reduce_scatter_tensor_308 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_218, 'sum', 8, '1'); cat_218 = None + wait_tensor_752 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_308); reduce_scatter_tensor_308 = None + convert_element_type_2256 = torch.ops.prims.convert_element_type.default(wait_tensor_752, torch.float32); wait_tensor_752 = None + convert_element_type_2258 = torch.ops.prims.convert_element_type.default(wait_tensor_132, torch.float32); wait_tensor_132 = None + mul_698 = torch.ops.aten.mul.Tensor(convert_element_type_2256, convert_element_type_2258); convert_element_type_2258 = None + mul_700 = torch.ops.aten.mul.Tensor(mul_80, mul_698) + sum_133 = torch.ops.aten.sum.dim_IntList(mul_700, [2], True); mul_700 = None + div_44 = torch.ops.aten.div.Tensor(mul_80, 4096) + mul_701 = torch.ops.aten.mul.Tensor(div_44, sum_133); div_44 = sum_133 = None + sub_67 = torch.ops.aten.sub.Tensor(mul_698, mul_701); mul_698 = mul_701 = None + mul_702 = torch.ops.aten.mul.Tensor(sub_67, rsqrt_20); sub_67 = rsqrt_20 = None + mul_703 = torch.ops.aten.mul.Tensor(convert_element_type_2256, mul_80); convert_element_type_2256 = mul_80 = None + sum_134 = torch.ops.aten.sum.dim_IntList(mul_703, [0, 1]); mul_703 = None + convert_element_type_2259 = torch.ops.prims.convert_element_type.default(mul_702, torch.bfloat16); mul_702 = None + convert_element_type_2260 = torch.ops.prims.convert_element_type.default(sum_134, torch.bfloat16); sum_134 = None + all_reduce_44 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2260, 'sum', '1'); convert_element_type_2260 = None + wait_tensor_753 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_44); all_reduce_44 = None + convert_element_type_2261 = torch.ops.prims.convert_element_type.default(wait_tensor_753, torch.float32); wait_tensor_753 = None + reduce_scatter_tensor_309 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2261, 'avg', 32, '0'); convert_element_type_2261 = None + wait_tensor_754 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_309); reduce_scatter_tensor_309 = None + add_282 = torch.ops.aten.add.Tensor(add_279, convert_element_type_2259); add_279 = convert_element_type_2259 = None + all_gather_into_tensor_400 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_282, 8, '1') + wait_tensor_755 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_400); all_gather_into_tensor_400 = None + split_227 = torch.ops.aten.split.Tensor(wait_tensor_755, 2); wait_tensor_755 = None + getitem_2170 = split_227[0] + getitem_2171 = split_227[1] + getitem_2172 = split_227[2] + getitem_2173 = split_227[3] + getitem_2174 = split_227[4] + getitem_2175 = split_227[5] + getitem_2176 = split_227[6] + getitem_2177 = split_227[7]; split_227 = None + cat_219 = torch.ops.aten.cat.default([getitem_2170, getitem_2171, getitem_2172, getitem_2173, getitem_2174, getitem_2175, getitem_2176, getitem_2177], 1); getitem_2170 = getitem_2171 = getitem_2172 = getitem_2173 = getitem_2174 = getitem_2175 = getitem_2176 = getitem_2177 = None + view_2851 = torch.ops.aten.view.default(cat_219, [16384, 4096]); cat_219 = None + permute_1061 = torch.ops.aten.permute.default(view_2851, [1, 0]) + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + add_37 = torch.ops.aten.add.Tensor(add_35, wait_tensor_125); wait_tensor_125 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 32, '0'); convert_element_type_317 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32); add_37 = None + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_126) + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_319, 8, '1'); convert_element_type_319 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + split_47 = torch.ops.aten.split.Tensor(wait_tensor_127, 2); wait_tensor_127 = None + getitem_466 = split_47[0] + getitem_467 = split_47[1] + getitem_468 = split_47[2] + getitem_469 = split_47[3] + getitem_470 = split_47[4] + getitem_471 = split_47[5] + getitem_472 = split_47[6] + getitem_473 = split_47[7]; split_47 = None + cat_39 = torch.ops.aten.cat.default([getitem_466, getitem_467, getitem_468, getitem_469, getitem_470, getitem_471, getitem_472, getitem_473], 1); getitem_466 = getitem_467 = getitem_468 = getitem_469 = getitem_470 = getitem_471 = getitem_472 = getitem_473 = None + view_708 = torch.ops.aten.view.default(cat_39, [16384, 4096]); cat_39 = None + view_709 = torch.ops.aten.view.default(mm_67, [2, 8192, 1792]); mm_67 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_709, torch.float32); view_709 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16); primals_92 = None + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 32, '0'); convert_element_type_325 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_68 = torch.ops.aten.mm.default(view_708, permute_108) + view_716 = torch.ops.aten.view.default(mm_68, [2, 8192, 1792]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_716) + view_723 = torch.ops.aten.view.default(mul_79, [16384, 1792]); mul_79 = None + mm_535 = torch.ops.aten.mm.default(permute_1061, view_723); permute_1061 = view_723 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 32, '0'); convert_element_type_328 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + permute_1063 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_536 = torch.ops.aten.mm.default(view_2851, permute_1063); view_2851 = permute_1063 = None + view_2852 = torch.ops.aten.view.default(mm_536, [2, 8192, 1792]); mm_536 = None + convert_element_type_2266 = torch.ops.prims.convert_element_type.default(mm_535, torch.float32); mm_535 = None + reduce_scatter_tensor_310 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2266, 'avg', 32, '0'); convert_element_type_2266 = None + wait_tensor_756 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_310); reduce_scatter_tensor_310 = None + mul_704 = torch.ops.aten.mul.Tensor(view_2852, convert_element_type_324); convert_element_type_324 = None + mul_705 = torch.ops.aten.mul.Tensor(view_2852, view_716); view_2852 = view_716 = None + view_2853 = torch.ops.aten.view.default(mul_704, [16384, 1792]); mul_704 = None + permute_1065 = torch.ops.aten.permute.default(view_2853, [1, 0]) + mm_537 = torch.ops.aten.mm.default(permute_1065, view_708); permute_1065 = None + permute_1067 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_538 = torch.ops.aten.mm.default(view_2853, permute_1067); view_2853 = permute_1067 = None + view_2854 = torch.ops.aten.view.default(mm_538, [2, 8192, 4096]); mm_538 = None + convert_element_type_2271 = torch.ops.prims.convert_element_type.default(mm_537, torch.float32); mm_537 = None + reduce_scatter_tensor_311 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2271, 'avg', 32, '0'); convert_element_type_2271 = None + wait_tensor_757 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_311); reduce_scatter_tensor_311 = None + convert_element_type_2272 = torch.ops.prims.convert_element_type.default(mul_705, torch.float32); mul_705 = None + neg_22 = torch.ops.aten.neg.default(convert_element_type_323) + exp_22 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_283 = torch.ops.aten.add.Tensor(exp_22, 1); exp_22 = None + reciprocal_22 = torch.ops.aten.reciprocal.default(add_283); add_283 = None + mul_706 = torch.ops.aten.mul.Tensor(reciprocal_22, 1); reciprocal_22 = None + mul_707 = torch.ops.aten.mul.Tensor(convert_element_type_2272, mul_706); convert_element_type_2272 = None + sub_68 = torch.ops.aten.sub.Tensor(1, mul_706); mul_706 = None + mul_708 = torch.ops.aten.mul.Tensor(convert_element_type_323, sub_68); convert_element_type_323 = sub_68 = None + add_284 = torch.ops.aten.add.Tensor(mul_708, 1); mul_708 = None + mul_709 = torch.ops.aten.mul.Tensor(mul_707, add_284); mul_707 = add_284 = None + convert_element_type_2274 = torch.ops.prims.convert_element_type.default(mul_709, torch.bfloat16); mul_709 = None + view_2855 = torch.ops.aten.view.default(convert_element_type_2274, [16384, 1792]); convert_element_type_2274 = None + permute_1069 = torch.ops.aten.permute.default(view_2855, [1, 0]) + mm_539 = torch.ops.aten.mm.default(permute_1069, view_708); permute_1069 = view_708 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 32, '0'); convert_element_type_320 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + permute_1071 = torch.ops.aten.permute.default(permute_107, [1, 0]); permute_107 = None + mm_540 = torch.ops.aten.mm.default(view_2855, permute_1071); view_2855 = permute_1071 = None + view_2856 = torch.ops.aten.view.default(mm_540, [2, 8192, 4096]); mm_540 = None + add_285 = torch.ops.aten.add.Tensor(view_2854, view_2856); view_2854 = view_2856 = None + convert_element_type_2279 = torch.ops.prims.convert_element_type.default(mm_539, torch.float32); mm_539 = None + reduce_scatter_tensor_312 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2279, 'avg', 32, '0'); convert_element_type_2279 = None + wait_tensor_758 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_312); reduce_scatter_tensor_312 = None + split_228 = torch.ops.aten.split.Tensor(add_285, 1024, 1); add_285 = None + getitem_2178 = split_228[0] + getitem_2179 = split_228[1] + getitem_2180 = split_228[2] + getitem_2181 = split_228[3] + getitem_2182 = split_228[4] + getitem_2183 = split_228[5] + getitem_2184 = split_228[6] + getitem_2185 = split_228[7]; split_228 = None + cat_220 = torch.ops.aten.cat.default([getitem_2178, getitem_2179, getitem_2180, getitem_2181, getitem_2182, getitem_2183, getitem_2184, getitem_2185]); getitem_2178 = getitem_2179 = getitem_2180 = getitem_2181 = getitem_2182 = getitem_2183 = getitem_2184 = getitem_2185 = None + reduce_scatter_tensor_313 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_220, 'sum', 8, '1'); cat_220 = None + wait_tensor_759 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_313); reduce_scatter_tensor_313 = None + convert_element_type_2280 = torch.ops.prims.convert_element_type.default(wait_tensor_759, torch.float32); wait_tensor_759 = None + convert_element_type_2282 = torch.ops.prims.convert_element_type.default(wait_tensor_126, torch.float32); wait_tensor_126 = None + mul_710 = torch.ops.aten.mul.Tensor(convert_element_type_2280, convert_element_type_2282); convert_element_type_2282 = None + mul_712 = torch.ops.aten.mul.Tensor(mul_76, mul_710) + sum_135 = torch.ops.aten.sum.dim_IntList(mul_712, [2], True); mul_712 = None + div_45 = torch.ops.aten.div.Tensor(mul_76, 4096) + mul_713 = torch.ops.aten.mul.Tensor(div_45, sum_135); div_45 = sum_135 = None + sub_69 = torch.ops.aten.sub.Tensor(mul_710, mul_713); mul_710 = mul_713 = None + mul_714 = torch.ops.aten.mul.Tensor(sub_69, rsqrt_19); sub_69 = rsqrt_19 = None + mul_715 = torch.ops.aten.mul.Tensor(convert_element_type_2280, mul_76); convert_element_type_2280 = mul_76 = None + sum_136 = torch.ops.aten.sum.dim_IntList(mul_715, [0, 1]); mul_715 = None + convert_element_type_2283 = torch.ops.prims.convert_element_type.default(mul_714, torch.bfloat16); mul_714 = None + convert_element_type_2284 = torch.ops.prims.convert_element_type.default(sum_136, torch.bfloat16); sum_136 = None + all_reduce_45 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2284, 'sum', '1'); convert_element_type_2284 = None + wait_tensor_760 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_45); all_reduce_45 = None + convert_element_type_2285 = torch.ops.prims.convert_element_type.default(wait_tensor_760, torch.float32); wait_tensor_760 = None + reduce_scatter_tensor_314 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2285, 'avg', 32, '0'); convert_element_type_2285 = None + wait_tensor_761 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_314); reduce_scatter_tensor_314 = None + add_286 = torch.ops.aten.add.Tensor(add_282, convert_element_type_2283); add_282 = convert_element_type_2283 = None + all_gather_into_tensor_401 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_286, 8, '1') + wait_tensor_762 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_401); all_gather_into_tensor_401 = None + split_229 = torch.ops.aten.split.Tensor(wait_tensor_762, 2); wait_tensor_762 = None + getitem_2186 = split_229[0] + getitem_2187 = split_229[1] + getitem_2188 = split_229[2] + getitem_2189 = split_229[3] + getitem_2190 = split_229[4] + getitem_2191 = split_229[5] + getitem_2192 = split_229[6] + getitem_2193 = split_229[7]; split_229 = None + cat_221 = torch.ops.aten.cat.default([getitem_2186, getitem_2187, getitem_2188, getitem_2189, getitem_2190, getitem_2191, getitem_2192, getitem_2193], 1); getitem_2186 = getitem_2187 = getitem_2188 = getitem_2189 = getitem_2190 = getitem_2191 = getitem_2192 = getitem_2193 = None + view_2857 = torch.ops.aten.view.default(cat_221, [16384, 4096]); cat_221 = None + permute_1073 = torch.ops.aten.permute.default(view_2857, [1, 0]) + permute_105 = torch.ops.aten.permute.default(getitem_449, [0, 2, 1, 3]) + view_690 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + view_696 = torch.ops.aten.view.default(view_690, [16384, 512]); view_690 = None + mm_541 = torch.ops.aten.mm.default(permute_1073, view_696); permute_1073 = view_696 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 32, '0'); convert_element_type_314 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + permute_1075 = torch.ops.aten.permute.default(permute_106, [1, 0]); permute_106 = None + mm_542 = torch.ops.aten.mm.default(view_2857, permute_1075); view_2857 = permute_1075 = None + view_2858 = torch.ops.aten.view.default(mm_542, [2, 8192, 512]); mm_542 = None + convert_element_type_2290 = torch.ops.prims.convert_element_type.default(mm_541, torch.float32); mm_541 = None + reduce_scatter_tensor_315 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2290, 'avg', 32, '0'); convert_element_type_2290 = None + wait_tensor_763 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_315); reduce_scatter_tensor_315 = None + view_2859 = torch.ops.aten.view.default(view_2858, [2, 8192, 4, 128]); view_2858 = None + permute_1077 = torch.ops.aten.permute.default(view_2859, [0, 2, 1, 3]); view_2859 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 32, '0'); convert_element_type_298 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_119) + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_300, 8, '1'); convert_element_type_300 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + split_45 = torch.ops.aten.split.Tensor(wait_tensor_120, 2); wait_tensor_120 = None + getitem_441 = split_45[0] + getitem_442 = split_45[1] + getitem_443 = split_45[2] + getitem_444 = split_45[3] + getitem_445 = split_45[4] + getitem_446 = split_45[5] + getitem_447 = split_45[6] + getitem_448 = split_45[7]; split_45 = None + cat_37 = torch.ops.aten.cat.default([getitem_441, getitem_442, getitem_443, getitem_444, getitem_445, getitem_446, getitem_447, getitem_448], 1); getitem_441 = getitem_442 = getitem_443 = getitem_444 = getitem_445 = getitem_446 = getitem_447 = getitem_448 = None + view_663 = torch.ops.aten.view.default(cat_37, [16384, 4096]); cat_37 = None + view_664 = torch.ops.aten.view.default(mm_63, [2, 8192, 512]); mm_63 = None + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 32, '0'); convert_element_type_304 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + mm_64 = torch.ops.aten.mm.default(view_663, permute_100) + view_671 = torch.ops.aten.view.default(mm_64, [2, 8192, 128]); mm_64 = None + view_678 = torch.ops.aten.view.default(mm_65, [2, 8192, 128]); mm_65 = None + view_680 = torch.ops.aten.view.default(view_664, [2, 8192, -1, 128]); view_664 = None + view_681 = torch.ops.aten.view.default(view_671, [2, 8192, -1, 128]); view_671 = None + view_682 = torch.ops.aten.view.default(view_678, [2, 8192, -1, 128]); view_678 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_680, torch.float32); view_680 = None + view_683 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 4, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_683); view_683 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_681, torch.float32); view_681 = None + view_684 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 1, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_684); view_684 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_37); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_686 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 4, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_37); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_687 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 1, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_686, torch.bfloat16); view_686 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_687, torch.bfloat16); view_687 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 1, 4, 128]); unsqueeze_18 = None + view_688 = torch.ops.aten.view.default(expand_18, [2, 8192, 4, 128]); expand_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_682, 3); view_682 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 1, 4, 128]); unsqueeze_19 = None + view_689 = torch.ops.aten.view.default(expand_19, [2, 8192, 4, 128]); expand_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_688, [0, 2, 1, 3]); view_688 = None + permute_104 = torch.ops.aten.permute.default(view_689, [0, 2, 1, 3]); view_689 = None + _scaled_dot_product_cudnn_attention_backward_22 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1077, permute_102, permute_103, permute_104, getitem_449, getitem_450, getitem_455, getitem_456, None, None, None, 8192, 8192, 0.0, True); permute_1077 = permute_102 = permute_103 = permute_104 = getitem_449 = getitem_450 = getitem_455 = getitem_456 = None + getitem_2194 = _scaled_dot_product_cudnn_attention_backward_22[0] + getitem_2195 = _scaled_dot_product_cudnn_attention_backward_22[1] + getitem_2196 = _scaled_dot_product_cudnn_attention_backward_22[2]; _scaled_dot_product_cudnn_attention_backward_22 = None + permute_1078 = torch.ops.aten.permute.default(getitem_2196, [0, 2, 1, 3]); getitem_2196 = None + permute_1079 = torch.ops.aten.permute.default(getitem_2195, [0, 2, 1, 3]); getitem_2195 = None + permute_1080 = torch.ops.aten.permute.default(getitem_2194, [0, 2, 1, 3]); getitem_2194 = None + view_2860 = torch.ops.aten.view.default(permute_1078, [2, 8192, 1, 4, 128]); permute_1078 = None + sum_137 = torch.ops.aten.sum.dim_IntList(view_2860, [3], True); view_2860 = None + squeeze_44 = torch.ops.aten.squeeze.dim(sum_137, 3); sum_137 = None + view_2861 = torch.ops.aten.view.default(permute_1079, [2, 8192, 1, 4, 128]); permute_1079 = None + sum_138 = torch.ops.aten.sum.dim_IntList(view_2861, [3], True); view_2861 = None + squeeze_45 = torch.ops.aten.squeeze.dim(sum_138, 3); sum_138 = None + convert_element_type_2291 = torch.ops.prims.convert_element_type.default(squeeze_45, torch.float32); squeeze_45 = None + convert_element_type_2292 = torch.ops.prims.convert_element_type.default(permute_1080, torch.float32); permute_1080 = None + view_2862 = torch.ops.aten.view.default(convert_element_type_2291, [2, 8192, 1, 64, 2]); convert_element_type_2291 = None + view_as_complex_108 = torch.ops.aten.view_as_complex.default(view_2862); view_2862 = None + mul_716 = torch.ops.aten.mul.Tensor(view_as_complex_108, _conj); view_as_complex_108 = None + view_2863 = torch.ops.aten.view.default(convert_element_type_2292, [2, 8192, 4, 64, 2]); convert_element_type_2292 = None + view_as_complex_109 = torch.ops.aten.view_as_complex.default(view_2863); view_2863 = None + mul_717 = torch.ops.aten.mul.Tensor(view_as_complex_109, _conj); view_as_complex_109 = None + view_as_real_108 = torch.ops.aten.view_as_real.default(mul_716); mul_716 = None + view_2864 = torch.ops.aten.view.default(view_as_real_108, [2, 8192, 1, 128]); view_as_real_108 = None + convert_element_type_2293 = torch.ops.prims.convert_element_type.default(view_2864, torch.bfloat16); view_2864 = None + view_as_real_109 = torch.ops.aten.view_as_real.default(mul_717); mul_717 = None + view_2865 = torch.ops.aten.view.default(view_as_real_109, [2, 8192, 4, 128]); view_as_real_109 = None + convert_element_type_2294 = torch.ops.prims.convert_element_type.default(view_2865, torch.bfloat16); view_2865 = None + view_2866 = torch.ops.aten.view.default(squeeze_44, [2, 8192, 128]); squeeze_44 = None + view_2867 = torch.ops.aten.view.default(convert_element_type_2293, [2, 8192, 128]); convert_element_type_2293 = None + view_2868 = torch.ops.aten.view.default(convert_element_type_2294, [2, 8192, 512]); convert_element_type_2294 = None + view_2869 = torch.ops.aten.view.default(view_2866, [16384, 128]); view_2866 = None + permute_1081 = torch.ops.aten.permute.default(view_2869, [1, 0]) + mm_543 = torch.ops.aten.mm.default(permute_1081, view_663); permute_1081 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 32, '0'); convert_element_type_307 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + permute_1083 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_544 = torch.ops.aten.mm.default(view_2869, permute_1083); view_2869 = permute_1083 = None + view_2870 = torch.ops.aten.view.default(mm_544, [2, 8192, 4096]); mm_544 = None + convert_element_type_2299 = torch.ops.prims.convert_element_type.default(mm_543, torch.float32); mm_543 = None + reduce_scatter_tensor_316 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2299, 'avg', 32, '0'); convert_element_type_2299 = None + wait_tensor_764 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_316); reduce_scatter_tensor_316 = None + view_2871 = torch.ops.aten.view.default(view_2867, [16384, 128]); view_2867 = None + permute_1085 = torch.ops.aten.permute.default(view_2871, [1, 0]) + mm_545 = torch.ops.aten.mm.default(permute_1085, view_663); permute_1085 = None + permute_1087 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_546 = torch.ops.aten.mm.default(view_2871, permute_1087); view_2871 = permute_1087 = None + view_2872 = torch.ops.aten.view.default(mm_546, [2, 8192, 4096]); mm_546 = None + add_287 = torch.ops.aten.add.Tensor(view_2870, view_2872); view_2870 = view_2872 = None + convert_element_type_2304 = torch.ops.prims.convert_element_type.default(mm_545, torch.float32); mm_545 = None + reduce_scatter_tensor_317 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2304, 'avg', 32, '0'); convert_element_type_2304 = None + wait_tensor_765 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_317); reduce_scatter_tensor_317 = None + view_2873 = torch.ops.aten.view.default(view_2868, [16384, 512]); view_2868 = None + permute_1089 = torch.ops.aten.permute.default(view_2873, [1, 0]) + mm_547 = torch.ops.aten.mm.default(permute_1089, view_663); permute_1089 = view_663 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 32, '0'); convert_element_type_301 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_1091 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_548 = torch.ops.aten.mm.default(view_2873, permute_1091); view_2873 = permute_1091 = None + view_2874 = torch.ops.aten.view.default(mm_548, [2, 8192, 4096]); mm_548 = None + add_288 = torch.ops.aten.add.Tensor(add_287, view_2874); add_287 = view_2874 = None + convert_element_type_2309 = torch.ops.prims.convert_element_type.default(mm_547, torch.float32); mm_547 = None + reduce_scatter_tensor_318 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2309, 'avg', 32, '0'); convert_element_type_2309 = None + wait_tensor_766 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_318); reduce_scatter_tensor_318 = None + split_230 = torch.ops.aten.split.Tensor(add_288, 1024, 1); add_288 = None + getitem_2197 = split_230[0] + getitem_2198 = split_230[1] + getitem_2199 = split_230[2] + getitem_2200 = split_230[3] + getitem_2201 = split_230[4] + getitem_2202 = split_230[5] + getitem_2203 = split_230[6] + getitem_2204 = split_230[7]; split_230 = None + cat_222 = torch.ops.aten.cat.default([getitem_2197, getitem_2198, getitem_2199, getitem_2200, getitem_2201, getitem_2202, getitem_2203, getitem_2204]); getitem_2197 = getitem_2198 = getitem_2199 = getitem_2200 = getitem_2201 = getitem_2202 = getitem_2203 = getitem_2204 = None + reduce_scatter_tensor_319 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_222, 'sum', 8, '1'); cat_222 = None + wait_tensor_767 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_319); reduce_scatter_tensor_319 = None + convert_element_type_2310 = torch.ops.prims.convert_element_type.default(wait_tensor_767, torch.float32); wait_tensor_767 = None + convert_element_type_2312 = torch.ops.prims.convert_element_type.default(wait_tensor_119, torch.float32); wait_tensor_119 = None + mul_718 = torch.ops.aten.mul.Tensor(convert_element_type_2310, convert_element_type_2312); convert_element_type_2312 = None + mul_720 = torch.ops.aten.mul.Tensor(mul_72, mul_718) + sum_139 = torch.ops.aten.sum.dim_IntList(mul_720, [2], True); mul_720 = None + div_46 = torch.ops.aten.div.Tensor(mul_72, 4096) + mul_721 = torch.ops.aten.mul.Tensor(div_46, sum_139); div_46 = sum_139 = None + sub_70 = torch.ops.aten.sub.Tensor(mul_718, mul_721); mul_718 = mul_721 = None + mul_722 = torch.ops.aten.mul.Tensor(sub_70, rsqrt_18); sub_70 = rsqrt_18 = None + mul_723 = torch.ops.aten.mul.Tensor(convert_element_type_2310, mul_72); convert_element_type_2310 = mul_72 = None + sum_140 = torch.ops.aten.sum.dim_IntList(mul_723, [0, 1]); mul_723 = None + convert_element_type_2313 = torch.ops.prims.convert_element_type.default(mul_722, torch.bfloat16); mul_722 = None + convert_element_type_2314 = torch.ops.prims.convert_element_type.default(sum_140, torch.bfloat16); sum_140 = None + all_reduce_46 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2314, 'sum', '1'); convert_element_type_2314 = None + wait_tensor_768 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_46); all_reduce_46 = None + convert_element_type_2315 = torch.ops.prims.convert_element_type.default(wait_tensor_768, torch.float32); wait_tensor_768 = None + reduce_scatter_tensor_320 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2315, 'avg', 32, '0'); convert_element_type_2315 = None + wait_tensor_769 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_320); reduce_scatter_tensor_320 = None + add_289 = torch.ops.aten.add.Tensor(add_286, convert_element_type_2313); add_286 = convert_element_type_2313 = None + all_gather_into_tensor_402 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_289, 8, '1') + wait_tensor_770 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_402); all_gather_into_tensor_402 = None + split_231 = torch.ops.aten.split.Tensor(wait_tensor_770, 2); wait_tensor_770 = None + getitem_2205 = split_231[0] + getitem_2206 = split_231[1] + getitem_2207 = split_231[2] + getitem_2208 = split_231[3] + getitem_2209 = split_231[4] + getitem_2210 = split_231[5] + getitem_2211 = split_231[6] + getitem_2212 = split_231[7]; split_231 = None + cat_223 = torch.ops.aten.cat.default([getitem_2205, getitem_2206, getitem_2207, getitem_2208, getitem_2209, getitem_2210, getitem_2211, getitem_2212], 1); getitem_2205 = getitem_2206 = getitem_2207 = getitem_2208 = getitem_2209 = getitem_2210 = getitem_2211 = getitem_2212 = None + view_2875 = torch.ops.aten.view.default(cat_223, [16384, 4096]); cat_223 = None + permute_1093 = torch.ops.aten.permute.default(view_2875, [1, 0]) + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + add_33 = torch.ops.aten.add.Tensor(add_31, wait_tensor_112); wait_tensor_112 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16); primals_81 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 32, '0'); convert_element_type_284 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_113) + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '1'); convert_element_type_286 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + split_43 = torch.ops.aten.split.Tensor(wait_tensor_114, 2); wait_tensor_114 = None + getitem_425 = split_43[0] + getitem_426 = split_43[1] + getitem_427 = split_43[2] + getitem_428 = split_43[3] + getitem_429 = split_43[4] + getitem_430 = split_43[5] + getitem_431 = split_43[6] + getitem_432 = split_43[7]; split_43 = None + cat_35 = torch.ops.aten.cat.default([getitem_425, getitem_426, getitem_427, getitem_428, getitem_429, getitem_430, getitem_431, getitem_432], 1); getitem_425 = getitem_426 = getitem_427 = getitem_428 = getitem_429 = getitem_430 = getitem_431 = getitem_432 = None + view_636 = torch.ops.aten.view.default(cat_35, [16384, 4096]); cat_35 = None + view_637 = torch.ops.aten.view.default(mm_60, [2, 8192, 1792]); mm_60 = None + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_637, torch.float32); view_637 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 32, '0'); convert_element_type_292 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_61 = torch.ops.aten.mm.default(view_636, permute_97) + view_644 = torch.ops.aten.view.default(mm_61, [2, 8192, 1792]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_644) + view_651 = torch.ops.aten.view.default(mul_71, [16384, 1792]); mul_71 = None + mm_549 = torch.ops.aten.mm.default(permute_1093, view_651); permute_1093 = view_651 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 32, '0'); convert_element_type_295 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_1095 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None + mm_550 = torch.ops.aten.mm.default(view_2875, permute_1095); view_2875 = permute_1095 = None + view_2876 = torch.ops.aten.view.default(mm_550, [2, 8192, 1792]); mm_550 = None + convert_element_type_2320 = torch.ops.prims.convert_element_type.default(mm_549, torch.float32); mm_549 = None + reduce_scatter_tensor_321 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2320, 'avg', 32, '0'); convert_element_type_2320 = None + wait_tensor_771 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_321); reduce_scatter_tensor_321 = None + mul_724 = torch.ops.aten.mul.Tensor(view_2876, convert_element_type_291); convert_element_type_291 = None + mul_725 = torch.ops.aten.mul.Tensor(view_2876, view_644); view_2876 = view_644 = None + view_2877 = torch.ops.aten.view.default(mul_724, [16384, 1792]); mul_724 = None + permute_1097 = torch.ops.aten.permute.default(view_2877, [1, 0]) + mm_551 = torch.ops.aten.mm.default(permute_1097, view_636); permute_1097 = None + permute_1099 = torch.ops.aten.permute.default(permute_97, [1, 0]); permute_97 = None + mm_552 = torch.ops.aten.mm.default(view_2877, permute_1099); view_2877 = permute_1099 = None + view_2878 = torch.ops.aten.view.default(mm_552, [2, 8192, 4096]); mm_552 = None + convert_element_type_2325 = torch.ops.prims.convert_element_type.default(mm_551, torch.float32); mm_551 = None + reduce_scatter_tensor_322 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2325, 'avg', 32, '0'); convert_element_type_2325 = None + wait_tensor_772 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_322); reduce_scatter_tensor_322 = None + convert_element_type_2326 = torch.ops.prims.convert_element_type.default(mul_725, torch.float32); mul_725 = None + neg_23 = torch.ops.aten.neg.default(convert_element_type_290) + exp_23 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_290 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + reciprocal_23 = torch.ops.aten.reciprocal.default(add_290); add_290 = None + mul_726 = torch.ops.aten.mul.Tensor(reciprocal_23, 1); reciprocal_23 = None + mul_727 = torch.ops.aten.mul.Tensor(convert_element_type_2326, mul_726); convert_element_type_2326 = None + sub_71 = torch.ops.aten.sub.Tensor(1, mul_726); mul_726 = None + mul_728 = torch.ops.aten.mul.Tensor(convert_element_type_290, sub_71); convert_element_type_290 = sub_71 = None + add_291 = torch.ops.aten.add.Tensor(mul_728, 1); mul_728 = None + mul_729 = torch.ops.aten.mul.Tensor(mul_727, add_291); mul_727 = add_291 = None + convert_element_type_2328 = torch.ops.prims.convert_element_type.default(mul_729, torch.bfloat16); mul_729 = None + view_2879 = torch.ops.aten.view.default(convert_element_type_2328, [16384, 1792]); convert_element_type_2328 = None + permute_1101 = torch.ops.aten.permute.default(view_2879, [1, 0]) + mm_553 = torch.ops.aten.mm.default(permute_1101, view_636); permute_1101 = view_636 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 32, '0'); convert_element_type_287 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_1103 = torch.ops.aten.permute.default(permute_96, [1, 0]); permute_96 = None + mm_554 = torch.ops.aten.mm.default(view_2879, permute_1103); view_2879 = permute_1103 = None + view_2880 = torch.ops.aten.view.default(mm_554, [2, 8192, 4096]); mm_554 = None + add_292 = torch.ops.aten.add.Tensor(view_2878, view_2880); view_2878 = view_2880 = None + convert_element_type_2333 = torch.ops.prims.convert_element_type.default(mm_553, torch.float32); mm_553 = None + reduce_scatter_tensor_323 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2333, 'avg', 32, '0'); convert_element_type_2333 = None + wait_tensor_773 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_323); reduce_scatter_tensor_323 = None + split_232 = torch.ops.aten.split.Tensor(add_292, 1024, 1); add_292 = None + getitem_2213 = split_232[0] + getitem_2214 = split_232[1] + getitem_2215 = split_232[2] + getitem_2216 = split_232[3] + getitem_2217 = split_232[4] + getitem_2218 = split_232[5] + getitem_2219 = split_232[6] + getitem_2220 = split_232[7]; split_232 = None + cat_224 = torch.ops.aten.cat.default([getitem_2213, getitem_2214, getitem_2215, getitem_2216, getitem_2217, getitem_2218, getitem_2219, getitem_2220]); getitem_2213 = getitem_2214 = getitem_2215 = getitem_2216 = getitem_2217 = getitem_2218 = getitem_2219 = getitem_2220 = None + reduce_scatter_tensor_324 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_224, 'sum', 8, '1'); cat_224 = None + wait_tensor_774 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_324); reduce_scatter_tensor_324 = None + convert_element_type_2334 = torch.ops.prims.convert_element_type.default(wait_tensor_774, torch.float32); wait_tensor_774 = None + convert_element_type_2336 = torch.ops.prims.convert_element_type.default(wait_tensor_113, torch.float32); wait_tensor_113 = None + mul_730 = torch.ops.aten.mul.Tensor(convert_element_type_2334, convert_element_type_2336); convert_element_type_2336 = None + mul_732 = torch.ops.aten.mul.Tensor(mul_68, mul_730) + sum_141 = torch.ops.aten.sum.dim_IntList(mul_732, [2], True); mul_732 = None + div_47 = torch.ops.aten.div.Tensor(mul_68, 4096) + mul_733 = torch.ops.aten.mul.Tensor(div_47, sum_141); div_47 = sum_141 = None + sub_72 = torch.ops.aten.sub.Tensor(mul_730, mul_733); mul_730 = mul_733 = None + mul_734 = torch.ops.aten.mul.Tensor(sub_72, rsqrt_17); sub_72 = rsqrt_17 = None + mul_735 = torch.ops.aten.mul.Tensor(convert_element_type_2334, mul_68); convert_element_type_2334 = mul_68 = None + sum_142 = torch.ops.aten.sum.dim_IntList(mul_735, [0, 1]); mul_735 = None + convert_element_type_2337 = torch.ops.prims.convert_element_type.default(mul_734, torch.bfloat16); mul_734 = None + convert_element_type_2338 = torch.ops.prims.convert_element_type.default(sum_142, torch.bfloat16); sum_142 = None + all_reduce_47 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2338, 'sum', '1'); convert_element_type_2338 = None + wait_tensor_775 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_47); all_reduce_47 = None + convert_element_type_2339 = torch.ops.prims.convert_element_type.default(wait_tensor_775, torch.float32); wait_tensor_775 = None + reduce_scatter_tensor_325 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2339, 'avg', 32, '0'); convert_element_type_2339 = None + wait_tensor_776 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_325); reduce_scatter_tensor_325 = None + add_293 = torch.ops.aten.add.Tensor(add_289, convert_element_type_2337); add_289 = convert_element_type_2337 = None + all_gather_into_tensor_403 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_293, 8, '1') + wait_tensor_777 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_403); all_gather_into_tensor_403 = None + split_233 = torch.ops.aten.split.Tensor(wait_tensor_777, 2); wait_tensor_777 = None + getitem_2221 = split_233[0] + getitem_2222 = split_233[1] + getitem_2223 = split_233[2] + getitem_2224 = split_233[3] + getitem_2225 = split_233[4] + getitem_2226 = split_233[5] + getitem_2227 = split_233[6] + getitem_2228 = split_233[7]; split_233 = None + cat_225 = torch.ops.aten.cat.default([getitem_2221, getitem_2222, getitem_2223, getitem_2224, getitem_2225, getitem_2226, getitem_2227, getitem_2228], 1); getitem_2221 = getitem_2222 = getitem_2223 = getitem_2224 = getitem_2225 = getitem_2226 = getitem_2227 = getitem_2228 = None + view_2881 = torch.ops.aten.view.default(cat_225, [16384, 4096]); cat_225 = None + permute_1105 = torch.ops.aten.permute.default(view_2881, [1, 0]) + permute_94 = torch.ops.aten.permute.default(getitem_408, [0, 2, 1, 3]) + view_618 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + view_624 = torch.ops.aten.view.default(view_618, [16384, 512]); view_618 = None + mm_555 = torch.ops.aten.mm.default(permute_1105, view_624); permute_1105 = view_624 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16); primals_80 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 32, '0'); convert_element_type_281 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + permute_1107 = torch.ops.aten.permute.default(permute_95, [1, 0]); permute_95 = None + mm_556 = torch.ops.aten.mm.default(view_2881, permute_1107); view_2881 = permute_1107 = None + view_2882 = torch.ops.aten.view.default(mm_556, [2, 8192, 512]); mm_556 = None + convert_element_type_2344 = torch.ops.prims.convert_element_type.default(mm_555, torch.float32); mm_555 = None + reduce_scatter_tensor_326 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2344, 'avg', 32, '0'); convert_element_type_2344 = None + wait_tensor_778 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_326); reduce_scatter_tensor_326 = None + view_2883 = torch.ops.aten.view.default(view_2882, [2, 8192, 4, 128]); view_2882 = None + permute_1109 = torch.ops.aten.permute.default(view_2883, [0, 2, 1, 3]); view_2883 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16); primals_76 = None + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 32, '0'); convert_element_type_265 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32); add_31 = None + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_106) + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_267, 8, '1'); convert_element_type_267 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + split_41 = torch.ops.aten.split.Tensor(wait_tensor_107, 2); wait_tensor_107 = None + getitem_400 = split_41[0] + getitem_401 = split_41[1] + getitem_402 = split_41[2] + getitem_403 = split_41[3] + getitem_404 = split_41[4] + getitem_405 = split_41[5] + getitem_406 = split_41[6] + getitem_407 = split_41[7]; split_41 = None + cat_33 = torch.ops.aten.cat.default([getitem_400, getitem_401, getitem_402, getitem_403, getitem_404, getitem_405, getitem_406, getitem_407], 1); getitem_400 = getitem_401 = getitem_402 = getitem_403 = getitem_404 = getitem_405 = getitem_406 = getitem_407 = None + view_591 = torch.ops.aten.view.default(cat_33, [16384, 4096]); cat_33 = None + view_592 = torch.ops.aten.view.default(mm_56, [2, 8192, 512]); mm_56 = None + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16); primals_78 = None + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 32, '0'); convert_element_type_271 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_109, [1, 0]); wait_tensor_109 = None + mm_57 = torch.ops.aten.mm.default(view_591, permute_89) + view_599 = torch.ops.aten.view.default(mm_57, [2, 8192, 128]); mm_57 = None + view_606 = torch.ops.aten.view.default(mm_58, [2, 8192, 128]); mm_58 = None + view_608 = torch.ops.aten.view.default(view_592, [2, 8192, -1, 128]); view_592 = None + view_609 = torch.ops.aten.view.default(view_599, [2, 8192, -1, 128]); view_599 = None + view_610 = torch.ops.aten.view.default(view_606, [2, 8192, -1, 128]); view_606 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_608, torch.float32); view_608 = None + view_611 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 4, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_611); view_611 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_609, torch.float32); view_609 = None + view_612 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 1, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_612); view_612 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_37); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_614 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 4, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_37); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_615 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 1, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_614, torch.bfloat16); view_614 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_615, torch.bfloat16); view_615 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 1, 4, 128]); unsqueeze_16 = None + view_616 = torch.ops.aten.view.default(expand_16, [2, 8192, 4, 128]); expand_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_610, 3); view_610 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 1, 4, 128]); unsqueeze_17 = None + view_617 = torch.ops.aten.view.default(expand_17, [2, 8192, 4, 128]); expand_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_616, [0, 2, 1, 3]); view_616 = None + permute_93 = torch.ops.aten.permute.default(view_617, [0, 2, 1, 3]); view_617 = None + _scaled_dot_product_cudnn_attention_backward_23 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1109, permute_91, permute_92, permute_93, getitem_408, getitem_409, getitem_414, getitem_415, None, None, None, 8192, 8192, 0.0, True); permute_1109 = permute_91 = permute_92 = permute_93 = getitem_408 = getitem_409 = getitem_414 = getitem_415 = None + getitem_2229 = _scaled_dot_product_cudnn_attention_backward_23[0] + getitem_2230 = _scaled_dot_product_cudnn_attention_backward_23[1] + getitem_2231 = _scaled_dot_product_cudnn_attention_backward_23[2]; _scaled_dot_product_cudnn_attention_backward_23 = None + permute_1110 = torch.ops.aten.permute.default(getitem_2231, [0, 2, 1, 3]); getitem_2231 = None + permute_1111 = torch.ops.aten.permute.default(getitem_2230, [0, 2, 1, 3]); getitem_2230 = None + permute_1112 = torch.ops.aten.permute.default(getitem_2229, [0, 2, 1, 3]); getitem_2229 = None + view_2884 = torch.ops.aten.view.default(permute_1110, [2, 8192, 1, 4, 128]); permute_1110 = None + sum_143 = torch.ops.aten.sum.dim_IntList(view_2884, [3], True); view_2884 = None + squeeze_46 = torch.ops.aten.squeeze.dim(sum_143, 3); sum_143 = None + view_2885 = torch.ops.aten.view.default(permute_1111, [2, 8192, 1, 4, 128]); permute_1111 = None + sum_144 = torch.ops.aten.sum.dim_IntList(view_2885, [3], True); view_2885 = None + squeeze_47 = torch.ops.aten.squeeze.dim(sum_144, 3); sum_144 = None + convert_element_type_2345 = torch.ops.prims.convert_element_type.default(squeeze_47, torch.float32); squeeze_47 = None + convert_element_type_2346 = torch.ops.prims.convert_element_type.default(permute_1112, torch.float32); permute_1112 = None + view_2886 = torch.ops.aten.view.default(convert_element_type_2345, [2, 8192, 1, 64, 2]); convert_element_type_2345 = None + view_as_complex_110 = torch.ops.aten.view_as_complex.default(view_2886); view_2886 = None + mul_736 = torch.ops.aten.mul.Tensor(view_as_complex_110, _conj); view_as_complex_110 = None + view_2887 = torch.ops.aten.view.default(convert_element_type_2346, [2, 8192, 4, 64, 2]); convert_element_type_2346 = None + view_as_complex_111 = torch.ops.aten.view_as_complex.default(view_2887); view_2887 = None + mul_737 = torch.ops.aten.mul.Tensor(view_as_complex_111, _conj); view_as_complex_111 = None + view_as_real_110 = torch.ops.aten.view_as_real.default(mul_736); mul_736 = None + view_2888 = torch.ops.aten.view.default(view_as_real_110, [2, 8192, 1, 128]); view_as_real_110 = None + convert_element_type_2347 = torch.ops.prims.convert_element_type.default(view_2888, torch.bfloat16); view_2888 = None + view_as_real_111 = torch.ops.aten.view_as_real.default(mul_737); mul_737 = None + view_2889 = torch.ops.aten.view.default(view_as_real_111, [2, 8192, 4, 128]); view_as_real_111 = None + convert_element_type_2348 = torch.ops.prims.convert_element_type.default(view_2889, torch.bfloat16); view_2889 = None + view_2890 = torch.ops.aten.view.default(squeeze_46, [2, 8192, 128]); squeeze_46 = None + view_2891 = torch.ops.aten.view.default(convert_element_type_2347, [2, 8192, 128]); convert_element_type_2347 = None + view_2892 = torch.ops.aten.view.default(convert_element_type_2348, [2, 8192, 512]); convert_element_type_2348 = None + view_2893 = torch.ops.aten.view.default(view_2890, [16384, 128]); view_2890 = None + permute_1113 = torch.ops.aten.permute.default(view_2893, [1, 0]) + mm_557 = torch.ops.aten.mm.default(permute_1113, view_591); permute_1113 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); primals_79 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 32, '0'); convert_element_type_274 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + permute_1115 = torch.ops.aten.permute.default(permute_90, [1, 0]); permute_90 = None + mm_558 = torch.ops.aten.mm.default(view_2893, permute_1115); view_2893 = permute_1115 = None + view_2894 = torch.ops.aten.view.default(mm_558, [2, 8192, 4096]); mm_558 = None + convert_element_type_2353 = torch.ops.prims.convert_element_type.default(mm_557, torch.float32); mm_557 = None + reduce_scatter_tensor_327 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2353, 'avg', 32, '0'); convert_element_type_2353 = None + wait_tensor_779 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_327); reduce_scatter_tensor_327 = None + view_2895 = torch.ops.aten.view.default(view_2891, [16384, 128]); view_2891 = None + permute_1117 = torch.ops.aten.permute.default(view_2895, [1, 0]) + mm_559 = torch.ops.aten.mm.default(permute_1117, view_591); permute_1117 = None + permute_1119 = torch.ops.aten.permute.default(permute_89, [1, 0]); permute_89 = None + mm_560 = torch.ops.aten.mm.default(view_2895, permute_1119); view_2895 = permute_1119 = None + view_2896 = torch.ops.aten.view.default(mm_560, [2, 8192, 4096]); mm_560 = None + add_294 = torch.ops.aten.add.Tensor(view_2894, view_2896); view_2894 = view_2896 = None + convert_element_type_2358 = torch.ops.prims.convert_element_type.default(mm_559, torch.float32); mm_559 = None + reduce_scatter_tensor_328 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2358, 'avg', 32, '0'); convert_element_type_2358 = None + wait_tensor_780 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_328); reduce_scatter_tensor_328 = None + view_2897 = torch.ops.aten.view.default(view_2892, [16384, 512]); view_2892 = None + permute_1121 = torch.ops.aten.permute.default(view_2897, [1, 0]) + mm_561 = torch.ops.aten.mm.default(permute_1121, view_591); permute_1121 = view_591 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 32, '0'); convert_element_type_268 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + permute_1123 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_562 = torch.ops.aten.mm.default(view_2897, permute_1123); view_2897 = permute_1123 = None + view_2898 = torch.ops.aten.view.default(mm_562, [2, 8192, 4096]); mm_562 = None + add_295 = torch.ops.aten.add.Tensor(add_294, view_2898); add_294 = view_2898 = None + convert_element_type_2363 = torch.ops.prims.convert_element_type.default(mm_561, torch.float32); mm_561 = None + reduce_scatter_tensor_329 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2363, 'avg', 32, '0'); convert_element_type_2363 = None + wait_tensor_781 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_329); reduce_scatter_tensor_329 = None + split_234 = torch.ops.aten.split.Tensor(add_295, 1024, 1); add_295 = None + getitem_2232 = split_234[0] + getitem_2233 = split_234[1] + getitem_2234 = split_234[2] + getitem_2235 = split_234[3] + getitem_2236 = split_234[4] + getitem_2237 = split_234[5] + getitem_2238 = split_234[6] + getitem_2239 = split_234[7]; split_234 = None + cat_226 = torch.ops.aten.cat.default([getitem_2232, getitem_2233, getitem_2234, getitem_2235, getitem_2236, getitem_2237, getitem_2238, getitem_2239]); getitem_2232 = getitem_2233 = getitem_2234 = getitem_2235 = getitem_2236 = getitem_2237 = getitem_2238 = getitem_2239 = None + reduce_scatter_tensor_330 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_226, 'sum', 8, '1'); cat_226 = None + wait_tensor_782 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_330); reduce_scatter_tensor_330 = None + convert_element_type_2364 = torch.ops.prims.convert_element_type.default(wait_tensor_782, torch.float32); wait_tensor_782 = None + convert_element_type_2366 = torch.ops.prims.convert_element_type.default(wait_tensor_106, torch.float32); wait_tensor_106 = None + mul_738 = torch.ops.aten.mul.Tensor(convert_element_type_2364, convert_element_type_2366); convert_element_type_2366 = None + mul_740 = torch.ops.aten.mul.Tensor(mul_64, mul_738) + sum_145 = torch.ops.aten.sum.dim_IntList(mul_740, [2], True); mul_740 = None + div_48 = torch.ops.aten.div.Tensor(mul_64, 4096) + mul_741 = torch.ops.aten.mul.Tensor(div_48, sum_145); div_48 = sum_145 = None + sub_73 = torch.ops.aten.sub.Tensor(mul_738, mul_741); mul_738 = mul_741 = None + mul_742 = torch.ops.aten.mul.Tensor(sub_73, rsqrt_16); sub_73 = rsqrt_16 = None + mul_743 = torch.ops.aten.mul.Tensor(convert_element_type_2364, mul_64); convert_element_type_2364 = mul_64 = None + sum_146 = torch.ops.aten.sum.dim_IntList(mul_743, [0, 1]); mul_743 = None + convert_element_type_2367 = torch.ops.prims.convert_element_type.default(mul_742, torch.bfloat16); mul_742 = None + convert_element_type_2368 = torch.ops.prims.convert_element_type.default(sum_146, torch.bfloat16); sum_146 = None + all_reduce_48 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2368, 'sum', '1'); convert_element_type_2368 = None + wait_tensor_783 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_48); all_reduce_48 = None + convert_element_type_2369 = torch.ops.prims.convert_element_type.default(wait_tensor_783, torch.float32); wait_tensor_783 = None + reduce_scatter_tensor_331 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2369, 'avg', 32, '0'); convert_element_type_2369 = None + wait_tensor_784 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_331); reduce_scatter_tensor_331 = None + add_296 = torch.ops.aten.add.Tensor(add_293, convert_element_type_2367); add_293 = convert_element_type_2367 = None + all_gather_into_tensor_404 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_296, 8, '1') + wait_tensor_785 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_404); all_gather_into_tensor_404 = None + split_235 = torch.ops.aten.split.Tensor(wait_tensor_785, 2); wait_tensor_785 = None + getitem_2240 = split_235[0] + getitem_2241 = split_235[1] + getitem_2242 = split_235[2] + getitem_2243 = split_235[3] + getitem_2244 = split_235[4] + getitem_2245 = split_235[5] + getitem_2246 = split_235[6] + getitem_2247 = split_235[7]; split_235 = None + cat_227 = torch.ops.aten.cat.default([getitem_2240, getitem_2241, getitem_2242, getitem_2243, getitem_2244, getitem_2245, getitem_2246, getitem_2247], 1); getitem_2240 = getitem_2241 = getitem_2242 = getitem_2243 = getitem_2244 = getitem_2245 = getitem_2246 = getitem_2247 = None + view_2899 = torch.ops.aten.view.default(cat_227, [16384, 4096]); cat_227 = None + permute_1125 = torch.ops.aten.permute.default(view_2899, [1, 0]) + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + add_29 = torch.ops.aten.add.Tensor(add_27, wait_tensor_99); wait_tensor_99 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 32, '0'); convert_element_type_251 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32); add_29 = None + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_100) + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 8, '1'); convert_element_type_253 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + split_39 = torch.ops.aten.split.Tensor(wait_tensor_101, 2); wait_tensor_101 = None + getitem_384 = split_39[0] + getitem_385 = split_39[1] + getitem_386 = split_39[2] + getitem_387 = split_39[3] + getitem_388 = split_39[4] + getitem_389 = split_39[5] + getitem_390 = split_39[6] + getitem_391 = split_39[7]; split_39 = None + cat_31 = torch.ops.aten.cat.default([getitem_384, getitem_385, getitem_386, getitem_387, getitem_388, getitem_389, getitem_390, getitem_391], 1); getitem_384 = getitem_385 = getitem_386 = getitem_387 = getitem_388 = getitem_389 = getitem_390 = getitem_391 = None + view_564 = torch.ops.aten.view.default(cat_31, [16384, 4096]); cat_31 = None + view_565 = torch.ops.aten.view.default(mm_53, [2, 8192, 1792]); mm_53 = None + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_565, torch.float32); view_565 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 32, '0'); convert_element_type_259 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_54 = torch.ops.aten.mm.default(view_564, permute_86) + view_572 = torch.ops.aten.view.default(mm_54, [2, 8192, 1792]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_572) + view_579 = torch.ops.aten.view.default(mul_63, [16384, 1792]); mul_63 = None + mm_563 = torch.ops.aten.mm.default(permute_1125, view_579); permute_1125 = view_579 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 32, '0'); convert_element_type_262 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + permute_1127 = torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_564 = torch.ops.aten.mm.default(view_2899, permute_1127); view_2899 = permute_1127 = None + view_2900 = torch.ops.aten.view.default(mm_564, [2, 8192, 1792]); mm_564 = None + convert_element_type_2374 = torch.ops.prims.convert_element_type.default(mm_563, torch.float32); mm_563 = None + reduce_scatter_tensor_332 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2374, 'avg', 32, '0'); convert_element_type_2374 = None + wait_tensor_786 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_332); reduce_scatter_tensor_332 = None + mul_744 = torch.ops.aten.mul.Tensor(view_2900, convert_element_type_258); convert_element_type_258 = None + mul_745 = torch.ops.aten.mul.Tensor(view_2900, view_572); view_2900 = view_572 = None + view_2901 = torch.ops.aten.view.default(mul_744, [16384, 1792]); mul_744 = None + permute_1129 = torch.ops.aten.permute.default(view_2901, [1, 0]) + mm_565 = torch.ops.aten.mm.default(permute_1129, view_564); permute_1129 = None + permute_1131 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_566 = torch.ops.aten.mm.default(view_2901, permute_1131); view_2901 = permute_1131 = None + view_2902 = torch.ops.aten.view.default(mm_566, [2, 8192, 4096]); mm_566 = None + convert_element_type_2379 = torch.ops.prims.convert_element_type.default(mm_565, torch.float32); mm_565 = None + reduce_scatter_tensor_333 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2379, 'avg', 32, '0'); convert_element_type_2379 = None + wait_tensor_787 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_333); reduce_scatter_tensor_333 = None + convert_element_type_2380 = torch.ops.prims.convert_element_type.default(mul_745, torch.float32); mul_745 = None + neg_24 = torch.ops.aten.neg.default(convert_element_type_257) + exp_24 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_297 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + reciprocal_24 = torch.ops.aten.reciprocal.default(add_297); add_297 = None + mul_746 = torch.ops.aten.mul.Tensor(reciprocal_24, 1); reciprocal_24 = None + mul_747 = torch.ops.aten.mul.Tensor(convert_element_type_2380, mul_746); convert_element_type_2380 = None + sub_74 = torch.ops.aten.sub.Tensor(1, mul_746); mul_746 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_257, sub_74); convert_element_type_257 = sub_74 = None + add_298 = torch.ops.aten.add.Tensor(mul_748, 1); mul_748 = None + mul_749 = torch.ops.aten.mul.Tensor(mul_747, add_298); mul_747 = add_298 = None + convert_element_type_2382 = torch.ops.prims.convert_element_type.default(mul_749, torch.bfloat16); mul_749 = None + view_2903 = torch.ops.aten.view.default(convert_element_type_2382, [16384, 1792]); convert_element_type_2382 = None + permute_1133 = torch.ops.aten.permute.default(view_2903, [1, 0]) + mm_567 = torch.ops.aten.mm.default(permute_1133, view_564); permute_1133 = view_564 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 32, '0'); convert_element_type_254 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + permute_1135 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_568 = torch.ops.aten.mm.default(view_2903, permute_1135); view_2903 = permute_1135 = None + view_2904 = torch.ops.aten.view.default(mm_568, [2, 8192, 4096]); mm_568 = None + add_299 = torch.ops.aten.add.Tensor(view_2902, view_2904); view_2902 = view_2904 = None + convert_element_type_2387 = torch.ops.prims.convert_element_type.default(mm_567, torch.float32); mm_567 = None + reduce_scatter_tensor_334 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2387, 'avg', 32, '0'); convert_element_type_2387 = None + wait_tensor_788 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_334); reduce_scatter_tensor_334 = None + split_236 = torch.ops.aten.split.Tensor(add_299, 1024, 1); add_299 = None + getitem_2248 = split_236[0] + getitem_2249 = split_236[1] + getitem_2250 = split_236[2] + getitem_2251 = split_236[3] + getitem_2252 = split_236[4] + getitem_2253 = split_236[5] + getitem_2254 = split_236[6] + getitem_2255 = split_236[7]; split_236 = None + cat_228 = torch.ops.aten.cat.default([getitem_2248, getitem_2249, getitem_2250, getitem_2251, getitem_2252, getitem_2253, getitem_2254, getitem_2255]); getitem_2248 = getitem_2249 = getitem_2250 = getitem_2251 = getitem_2252 = getitem_2253 = getitem_2254 = getitem_2255 = None + reduce_scatter_tensor_335 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_228, 'sum', 8, '1'); cat_228 = None + wait_tensor_789 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_335); reduce_scatter_tensor_335 = None + convert_element_type_2388 = torch.ops.prims.convert_element_type.default(wait_tensor_789, torch.float32); wait_tensor_789 = None + convert_element_type_2390 = torch.ops.prims.convert_element_type.default(wait_tensor_100, torch.float32); wait_tensor_100 = None + mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_2388, convert_element_type_2390); convert_element_type_2390 = None + mul_752 = torch.ops.aten.mul.Tensor(mul_60, mul_750) + sum_147 = torch.ops.aten.sum.dim_IntList(mul_752, [2], True); mul_752 = None + div_49 = torch.ops.aten.div.Tensor(mul_60, 4096) + mul_753 = torch.ops.aten.mul.Tensor(div_49, sum_147); div_49 = sum_147 = None + sub_75 = torch.ops.aten.sub.Tensor(mul_750, mul_753); mul_750 = mul_753 = None + mul_754 = torch.ops.aten.mul.Tensor(sub_75, rsqrt_15); sub_75 = rsqrt_15 = None + mul_755 = torch.ops.aten.mul.Tensor(convert_element_type_2388, mul_60); convert_element_type_2388 = mul_60 = None + sum_148 = torch.ops.aten.sum.dim_IntList(mul_755, [0, 1]); mul_755 = None + convert_element_type_2391 = torch.ops.prims.convert_element_type.default(mul_754, torch.bfloat16); mul_754 = None + convert_element_type_2392 = torch.ops.prims.convert_element_type.default(sum_148, torch.bfloat16); sum_148 = None + all_reduce_49 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2392, 'sum', '1'); convert_element_type_2392 = None + wait_tensor_790 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_49); all_reduce_49 = None + convert_element_type_2393 = torch.ops.prims.convert_element_type.default(wait_tensor_790, torch.float32); wait_tensor_790 = None + reduce_scatter_tensor_336 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2393, 'avg', 32, '0'); convert_element_type_2393 = None + wait_tensor_791 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_336); reduce_scatter_tensor_336 = None + add_300 = torch.ops.aten.add.Tensor(add_296, convert_element_type_2391); add_296 = convert_element_type_2391 = None + all_gather_into_tensor_405 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_300, 8, '1') + wait_tensor_792 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_405); all_gather_into_tensor_405 = None + split_237 = torch.ops.aten.split.Tensor(wait_tensor_792, 2); wait_tensor_792 = None + getitem_2256 = split_237[0] + getitem_2257 = split_237[1] + getitem_2258 = split_237[2] + getitem_2259 = split_237[3] + getitem_2260 = split_237[4] + getitem_2261 = split_237[5] + getitem_2262 = split_237[6] + getitem_2263 = split_237[7]; split_237 = None + cat_229 = torch.ops.aten.cat.default([getitem_2256, getitem_2257, getitem_2258, getitem_2259, getitem_2260, getitem_2261, getitem_2262, getitem_2263], 1); getitem_2256 = getitem_2257 = getitem_2258 = getitem_2259 = getitem_2260 = getitem_2261 = getitem_2262 = getitem_2263 = None + view_2905 = torch.ops.aten.view.default(cat_229, [16384, 4096]); cat_229 = None + permute_1137 = torch.ops.aten.permute.default(view_2905, [1, 0]) + permute_83 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]) + view_546 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + view_552 = torch.ops.aten.view.default(view_546, [16384, 512]); view_546 = None + mm_569 = torch.ops.aten.mm.default(permute_1137, view_552); permute_1137 = view_552 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 32, '0'); convert_element_type_248 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + permute_1139 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_570 = torch.ops.aten.mm.default(view_2905, permute_1139); view_2905 = permute_1139 = None + view_2906 = torch.ops.aten.view.default(mm_570, [2, 8192, 512]); mm_570 = None + convert_element_type_2398 = torch.ops.prims.convert_element_type.default(mm_569, torch.float32); mm_569 = None + reduce_scatter_tensor_337 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2398, 'avg', 32, '0'); convert_element_type_2398 = None + wait_tensor_793 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_337); reduce_scatter_tensor_337 = None + view_2907 = torch.ops.aten.view.default(view_2906, [2, 8192, 4, 128]); view_2906 = None + permute_1141 = torch.ops.aten.permute.default(view_2907, [0, 2, 1, 3]); view_2907 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 32, '0'); convert_element_type_232 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32); add_27 = None + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_93) + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '1'); convert_element_type_234 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + split_37 = torch.ops.aten.split.Tensor(wait_tensor_94, 2); wait_tensor_94 = None + getitem_359 = split_37[0] + getitem_360 = split_37[1] + getitem_361 = split_37[2] + getitem_362 = split_37[3] + getitem_363 = split_37[4] + getitem_364 = split_37[5] + getitem_365 = split_37[6] + getitem_366 = split_37[7]; split_37 = None + cat_29 = torch.ops.aten.cat.default([getitem_359, getitem_360, getitem_361, getitem_362, getitem_363, getitem_364, getitem_365, getitem_366], 1); getitem_359 = getitem_360 = getitem_361 = getitem_362 = getitem_363 = getitem_364 = getitem_365 = getitem_366 = None + view_519 = torch.ops.aten.view.default(cat_29, [16384, 4096]); cat_29 = None + view_520 = torch.ops.aten.view.default(mm_49, [2, 8192, 512]); mm_49 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 32, '0'); convert_element_type_238 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + mm_50 = torch.ops.aten.mm.default(view_519, permute_78) + view_527 = torch.ops.aten.view.default(mm_50, [2, 8192, 128]); mm_50 = None + view_534 = torch.ops.aten.view.default(mm_51, [2, 8192, 128]); mm_51 = None + view_536 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + view_537 = torch.ops.aten.view.default(view_527, [2, 8192, -1, 128]); view_527 = None + view_538 = torch.ops.aten.view.default(view_534, [2, 8192, -1, 128]); view_534 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_536, torch.float32); view_536 = None + view_539 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 4, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_539); view_539 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_537, torch.float32); view_537 = None + view_540 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 1, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_540); view_540 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_37); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_542 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 4, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_37); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_543 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 1, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_542, torch.bfloat16); view_542 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_543, torch.bfloat16); view_543 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 1, 4, 128]); unsqueeze_14 = None + view_544 = torch.ops.aten.view.default(expand_14, [2, 8192, 4, 128]); expand_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_538, 3); view_538 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 1, 4, 128]); unsqueeze_15 = None + view_545 = torch.ops.aten.view.default(expand_15, [2, 8192, 4, 128]); expand_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_544, [0, 2, 1, 3]); view_544 = None + permute_82 = torch.ops.aten.permute.default(view_545, [0, 2, 1, 3]); view_545 = None + _scaled_dot_product_cudnn_attention_backward_24 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1141, permute_80, permute_81, permute_82, getitem_367, getitem_368, getitem_373, getitem_374, None, None, None, 8192, 8192, 0.0, True); permute_1141 = permute_80 = permute_81 = permute_82 = getitem_367 = getitem_368 = getitem_373 = getitem_374 = None + getitem_2264 = _scaled_dot_product_cudnn_attention_backward_24[0] + getitem_2265 = _scaled_dot_product_cudnn_attention_backward_24[1] + getitem_2266 = _scaled_dot_product_cudnn_attention_backward_24[2]; _scaled_dot_product_cudnn_attention_backward_24 = None + permute_1142 = torch.ops.aten.permute.default(getitem_2266, [0, 2, 1, 3]); getitem_2266 = None + permute_1143 = torch.ops.aten.permute.default(getitem_2265, [0, 2, 1, 3]); getitem_2265 = None + permute_1144 = torch.ops.aten.permute.default(getitem_2264, [0, 2, 1, 3]); getitem_2264 = None + view_2908 = torch.ops.aten.view.default(permute_1142, [2, 8192, 1, 4, 128]); permute_1142 = None + sum_149 = torch.ops.aten.sum.dim_IntList(view_2908, [3], True); view_2908 = None + squeeze_48 = torch.ops.aten.squeeze.dim(sum_149, 3); sum_149 = None + view_2909 = torch.ops.aten.view.default(permute_1143, [2, 8192, 1, 4, 128]); permute_1143 = None + sum_150 = torch.ops.aten.sum.dim_IntList(view_2909, [3], True); view_2909 = None + squeeze_49 = torch.ops.aten.squeeze.dim(sum_150, 3); sum_150 = None + convert_element_type_2399 = torch.ops.prims.convert_element_type.default(squeeze_49, torch.float32); squeeze_49 = None + convert_element_type_2400 = torch.ops.prims.convert_element_type.default(permute_1144, torch.float32); permute_1144 = None + view_2910 = torch.ops.aten.view.default(convert_element_type_2399, [2, 8192, 1, 64, 2]); convert_element_type_2399 = None + view_as_complex_112 = torch.ops.aten.view_as_complex.default(view_2910); view_2910 = None + mul_756 = torch.ops.aten.mul.Tensor(view_as_complex_112, _conj); view_as_complex_112 = None + view_2911 = torch.ops.aten.view.default(convert_element_type_2400, [2, 8192, 4, 64, 2]); convert_element_type_2400 = None + view_as_complex_113 = torch.ops.aten.view_as_complex.default(view_2911); view_2911 = None + mul_757 = torch.ops.aten.mul.Tensor(view_as_complex_113, _conj); view_as_complex_113 = None + view_as_real_112 = torch.ops.aten.view_as_real.default(mul_756); mul_756 = None + view_2912 = torch.ops.aten.view.default(view_as_real_112, [2, 8192, 1, 128]); view_as_real_112 = None + convert_element_type_2401 = torch.ops.prims.convert_element_type.default(view_2912, torch.bfloat16); view_2912 = None + view_as_real_113 = torch.ops.aten.view_as_real.default(mul_757); mul_757 = None + view_2913 = torch.ops.aten.view.default(view_as_real_113, [2, 8192, 4, 128]); view_as_real_113 = None + convert_element_type_2402 = torch.ops.prims.convert_element_type.default(view_2913, torch.bfloat16); view_2913 = None + view_2914 = torch.ops.aten.view.default(squeeze_48, [2, 8192, 128]); squeeze_48 = None + view_2915 = torch.ops.aten.view.default(convert_element_type_2401, [2, 8192, 128]); convert_element_type_2401 = None + view_2916 = torch.ops.aten.view.default(convert_element_type_2402, [2, 8192, 512]); convert_element_type_2402 = None + view_2917 = torch.ops.aten.view.default(view_2914, [16384, 128]); view_2914 = None + permute_1145 = torch.ops.aten.permute.default(view_2917, [1, 0]) + mm_571 = torch.ops.aten.mm.default(permute_1145, view_519); permute_1145 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 32, '0'); convert_element_type_241 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + permute_1147 = torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_572 = torch.ops.aten.mm.default(view_2917, permute_1147); view_2917 = permute_1147 = None + view_2918 = torch.ops.aten.view.default(mm_572, [2, 8192, 4096]); mm_572 = None + convert_element_type_2407 = torch.ops.prims.convert_element_type.default(mm_571, torch.float32); mm_571 = None + reduce_scatter_tensor_338 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2407, 'avg', 32, '0'); convert_element_type_2407 = None + wait_tensor_794 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_338); reduce_scatter_tensor_338 = None + view_2919 = torch.ops.aten.view.default(view_2915, [16384, 128]); view_2915 = None + permute_1149 = torch.ops.aten.permute.default(view_2919, [1, 0]) + mm_573 = torch.ops.aten.mm.default(permute_1149, view_519); permute_1149 = None + permute_1151 = torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_574 = torch.ops.aten.mm.default(view_2919, permute_1151); view_2919 = permute_1151 = None + view_2920 = torch.ops.aten.view.default(mm_574, [2, 8192, 4096]); mm_574 = None + add_301 = torch.ops.aten.add.Tensor(view_2918, view_2920); view_2918 = view_2920 = None + convert_element_type_2412 = torch.ops.prims.convert_element_type.default(mm_573, torch.float32); mm_573 = None + reduce_scatter_tensor_339 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2412, 'avg', 32, '0'); convert_element_type_2412 = None + wait_tensor_795 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_339); reduce_scatter_tensor_339 = None + view_2921 = torch.ops.aten.view.default(view_2916, [16384, 512]); view_2916 = None + permute_1153 = torch.ops.aten.permute.default(view_2921, [1, 0]) + mm_575 = torch.ops.aten.mm.default(permute_1153, view_519); permute_1153 = view_519 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 32, '0'); convert_element_type_235 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + permute_1155 = torch.ops.aten.permute.default(permute_77, [1, 0]); permute_77 = None + mm_576 = torch.ops.aten.mm.default(view_2921, permute_1155); view_2921 = permute_1155 = None + view_2922 = torch.ops.aten.view.default(mm_576, [2, 8192, 4096]); mm_576 = None + add_302 = torch.ops.aten.add.Tensor(add_301, view_2922); add_301 = view_2922 = None + convert_element_type_2417 = torch.ops.prims.convert_element_type.default(mm_575, torch.float32); mm_575 = None + reduce_scatter_tensor_340 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2417, 'avg', 32, '0'); convert_element_type_2417 = None + wait_tensor_796 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_340); reduce_scatter_tensor_340 = None + split_238 = torch.ops.aten.split.Tensor(add_302, 1024, 1); add_302 = None + getitem_2267 = split_238[0] + getitem_2268 = split_238[1] + getitem_2269 = split_238[2] + getitem_2270 = split_238[3] + getitem_2271 = split_238[4] + getitem_2272 = split_238[5] + getitem_2273 = split_238[6] + getitem_2274 = split_238[7]; split_238 = None + cat_230 = torch.ops.aten.cat.default([getitem_2267, getitem_2268, getitem_2269, getitem_2270, getitem_2271, getitem_2272, getitem_2273, getitem_2274]); getitem_2267 = getitem_2268 = getitem_2269 = getitem_2270 = getitem_2271 = getitem_2272 = getitem_2273 = getitem_2274 = None + reduce_scatter_tensor_341 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_230, 'sum', 8, '1'); cat_230 = None + wait_tensor_797 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_341); reduce_scatter_tensor_341 = None + convert_element_type_2418 = torch.ops.prims.convert_element_type.default(wait_tensor_797, torch.float32); wait_tensor_797 = None + convert_element_type_2420 = torch.ops.prims.convert_element_type.default(wait_tensor_93, torch.float32); wait_tensor_93 = None + mul_758 = torch.ops.aten.mul.Tensor(convert_element_type_2418, convert_element_type_2420); convert_element_type_2420 = None + mul_760 = torch.ops.aten.mul.Tensor(mul_56, mul_758) + sum_151 = torch.ops.aten.sum.dim_IntList(mul_760, [2], True); mul_760 = None + div_50 = torch.ops.aten.div.Tensor(mul_56, 4096) + mul_761 = torch.ops.aten.mul.Tensor(div_50, sum_151); div_50 = sum_151 = None + sub_76 = torch.ops.aten.sub.Tensor(mul_758, mul_761); mul_758 = mul_761 = None + mul_762 = torch.ops.aten.mul.Tensor(sub_76, rsqrt_14); sub_76 = rsqrt_14 = None + mul_763 = torch.ops.aten.mul.Tensor(convert_element_type_2418, mul_56); convert_element_type_2418 = mul_56 = None + sum_152 = torch.ops.aten.sum.dim_IntList(mul_763, [0, 1]); mul_763 = None + convert_element_type_2421 = torch.ops.prims.convert_element_type.default(mul_762, torch.bfloat16); mul_762 = None + convert_element_type_2422 = torch.ops.prims.convert_element_type.default(sum_152, torch.bfloat16); sum_152 = None + all_reduce_50 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2422, 'sum', '1'); convert_element_type_2422 = None + wait_tensor_798 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_50); all_reduce_50 = None + convert_element_type_2423 = torch.ops.prims.convert_element_type.default(wait_tensor_798, torch.float32); wait_tensor_798 = None + reduce_scatter_tensor_342 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2423, 'avg', 32, '0'); convert_element_type_2423 = None + wait_tensor_799 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_342); reduce_scatter_tensor_342 = None + add_303 = torch.ops.aten.add.Tensor(add_300, convert_element_type_2421); add_300 = convert_element_type_2421 = None + all_gather_into_tensor_406 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_303, 8, '1') + wait_tensor_800 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_406); all_gather_into_tensor_406 = None + split_239 = torch.ops.aten.split.Tensor(wait_tensor_800, 2); wait_tensor_800 = None + getitem_2275 = split_239[0] + getitem_2276 = split_239[1] + getitem_2277 = split_239[2] + getitem_2278 = split_239[3] + getitem_2279 = split_239[4] + getitem_2280 = split_239[5] + getitem_2281 = split_239[6] + getitem_2282 = split_239[7]; split_239 = None + cat_231 = torch.ops.aten.cat.default([getitem_2275, getitem_2276, getitem_2277, getitem_2278, getitem_2279, getitem_2280, getitem_2281, getitem_2282], 1); getitem_2275 = getitem_2276 = getitem_2277 = getitem_2278 = getitem_2279 = getitem_2280 = getitem_2281 = getitem_2282 = None + view_2923 = torch.ops.aten.view.default(cat_231, [16384, 4096]); cat_231 = None + permute_1157 = torch.ops.aten.permute.default(view_2923, [1, 0]) + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + add_25 = torch.ops.aten.add.Tensor(add_23, wait_tensor_86); wait_tensor_86 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 32, '0'); convert_element_type_218 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32); add_25 = None + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_87) + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_220, 8, '1'); convert_element_type_220 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + split_35 = torch.ops.aten.split.Tensor(wait_tensor_88, 2); wait_tensor_88 = None + getitem_343 = split_35[0] + getitem_344 = split_35[1] + getitem_345 = split_35[2] + getitem_346 = split_35[3] + getitem_347 = split_35[4] + getitem_348 = split_35[5] + getitem_349 = split_35[6] + getitem_350 = split_35[7]; split_35 = None + cat_27 = torch.ops.aten.cat.default([getitem_343, getitem_344, getitem_345, getitem_346, getitem_347, getitem_348, getitem_349, getitem_350], 1); getitem_343 = getitem_344 = getitem_345 = getitem_346 = getitem_347 = getitem_348 = getitem_349 = getitem_350 = None + view_492 = torch.ops.aten.view.default(cat_27, [16384, 4096]); cat_27 = None + view_493 = torch.ops.aten.view.default(mm_46, [2, 8192, 1792]); mm_46 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_493, torch.float32); view_493 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 32, '0'); convert_element_type_226 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + mm_47 = torch.ops.aten.mm.default(view_492, permute_75) + view_500 = torch.ops.aten.view.default(mm_47, [2, 8192, 1792]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_500) + view_507 = torch.ops.aten.view.default(mul_55, [16384, 1792]); mul_55 = None + mm_577 = torch.ops.aten.mm.default(permute_1157, view_507); permute_1157 = view_507 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 32, '0'); convert_element_type_229 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_91, [1, 0]); wait_tensor_91 = None + permute_1159 = torch.ops.aten.permute.default(permute_76, [1, 0]); permute_76 = None + mm_578 = torch.ops.aten.mm.default(view_2923, permute_1159); view_2923 = permute_1159 = None + view_2924 = torch.ops.aten.view.default(mm_578, [2, 8192, 1792]); mm_578 = None + convert_element_type_2428 = torch.ops.prims.convert_element_type.default(mm_577, torch.float32); mm_577 = None + reduce_scatter_tensor_343 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2428, 'avg', 32, '0'); convert_element_type_2428 = None + wait_tensor_801 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_343); reduce_scatter_tensor_343 = None + mul_764 = torch.ops.aten.mul.Tensor(view_2924, convert_element_type_225); convert_element_type_225 = None + mul_765 = torch.ops.aten.mul.Tensor(view_2924, view_500); view_2924 = view_500 = None + view_2925 = torch.ops.aten.view.default(mul_764, [16384, 1792]); mul_764 = None + permute_1161 = torch.ops.aten.permute.default(view_2925, [1, 0]) + mm_579 = torch.ops.aten.mm.default(permute_1161, view_492); permute_1161 = None + permute_1163 = torch.ops.aten.permute.default(permute_75, [1, 0]); permute_75 = None + mm_580 = torch.ops.aten.mm.default(view_2925, permute_1163); view_2925 = permute_1163 = None + view_2926 = torch.ops.aten.view.default(mm_580, [2, 8192, 4096]); mm_580 = None + convert_element_type_2433 = torch.ops.prims.convert_element_type.default(mm_579, torch.float32); mm_579 = None + reduce_scatter_tensor_344 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2433, 'avg', 32, '0'); convert_element_type_2433 = None + wait_tensor_802 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_344); reduce_scatter_tensor_344 = None + convert_element_type_2434 = torch.ops.prims.convert_element_type.default(mul_765, torch.float32); mul_765 = None + neg_25 = torch.ops.aten.neg.default(convert_element_type_224) + exp_25 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_304 = torch.ops.aten.add.Tensor(exp_25, 1); exp_25 = None + reciprocal_25 = torch.ops.aten.reciprocal.default(add_304); add_304 = None + mul_766 = torch.ops.aten.mul.Tensor(reciprocal_25, 1); reciprocal_25 = None + mul_767 = torch.ops.aten.mul.Tensor(convert_element_type_2434, mul_766); convert_element_type_2434 = None + sub_77 = torch.ops.aten.sub.Tensor(1, mul_766); mul_766 = None + mul_768 = torch.ops.aten.mul.Tensor(convert_element_type_224, sub_77); convert_element_type_224 = sub_77 = None + add_305 = torch.ops.aten.add.Tensor(mul_768, 1); mul_768 = None + mul_769 = torch.ops.aten.mul.Tensor(mul_767, add_305); mul_767 = add_305 = None + convert_element_type_2436 = torch.ops.prims.convert_element_type.default(mul_769, torch.bfloat16); mul_769 = None + view_2927 = torch.ops.aten.view.default(convert_element_type_2436, [16384, 1792]); convert_element_type_2436 = None + permute_1165 = torch.ops.aten.permute.default(view_2927, [1, 0]) + mm_581 = torch.ops.aten.mm.default(permute_1165, view_492); permute_1165 = view_492 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 32, '0'); convert_element_type_221 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + permute_1167 = torch.ops.aten.permute.default(permute_74, [1, 0]); permute_74 = None + mm_582 = torch.ops.aten.mm.default(view_2927, permute_1167); view_2927 = permute_1167 = None + view_2928 = torch.ops.aten.view.default(mm_582, [2, 8192, 4096]); mm_582 = None + add_306 = torch.ops.aten.add.Tensor(view_2926, view_2928); view_2926 = view_2928 = None + convert_element_type_2441 = torch.ops.prims.convert_element_type.default(mm_581, torch.float32); mm_581 = None + reduce_scatter_tensor_345 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2441, 'avg', 32, '0'); convert_element_type_2441 = None + wait_tensor_803 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_345); reduce_scatter_tensor_345 = None + split_240 = torch.ops.aten.split.Tensor(add_306, 1024, 1); add_306 = None + getitem_2283 = split_240[0] + getitem_2284 = split_240[1] + getitem_2285 = split_240[2] + getitem_2286 = split_240[3] + getitem_2287 = split_240[4] + getitem_2288 = split_240[5] + getitem_2289 = split_240[6] + getitem_2290 = split_240[7]; split_240 = None + cat_232 = torch.ops.aten.cat.default([getitem_2283, getitem_2284, getitem_2285, getitem_2286, getitem_2287, getitem_2288, getitem_2289, getitem_2290]); getitem_2283 = getitem_2284 = getitem_2285 = getitem_2286 = getitem_2287 = getitem_2288 = getitem_2289 = getitem_2290 = None + reduce_scatter_tensor_346 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_232, 'sum', 8, '1'); cat_232 = None + wait_tensor_804 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_346); reduce_scatter_tensor_346 = None + convert_element_type_2442 = torch.ops.prims.convert_element_type.default(wait_tensor_804, torch.float32); wait_tensor_804 = None + convert_element_type_2444 = torch.ops.prims.convert_element_type.default(wait_tensor_87, torch.float32); wait_tensor_87 = None + mul_770 = torch.ops.aten.mul.Tensor(convert_element_type_2442, convert_element_type_2444); convert_element_type_2444 = None + mul_772 = torch.ops.aten.mul.Tensor(mul_52, mul_770) + sum_153 = torch.ops.aten.sum.dim_IntList(mul_772, [2], True); mul_772 = None + div_51 = torch.ops.aten.div.Tensor(mul_52, 4096) + mul_773 = torch.ops.aten.mul.Tensor(div_51, sum_153); div_51 = sum_153 = None + sub_78 = torch.ops.aten.sub.Tensor(mul_770, mul_773); mul_770 = mul_773 = None + mul_774 = torch.ops.aten.mul.Tensor(sub_78, rsqrt_13); sub_78 = rsqrt_13 = None + mul_775 = torch.ops.aten.mul.Tensor(convert_element_type_2442, mul_52); convert_element_type_2442 = mul_52 = None + sum_154 = torch.ops.aten.sum.dim_IntList(mul_775, [0, 1]); mul_775 = None + convert_element_type_2445 = torch.ops.prims.convert_element_type.default(mul_774, torch.bfloat16); mul_774 = None + convert_element_type_2446 = torch.ops.prims.convert_element_type.default(sum_154, torch.bfloat16); sum_154 = None + all_reduce_51 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2446, 'sum', '1'); convert_element_type_2446 = None + wait_tensor_805 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_51); all_reduce_51 = None + convert_element_type_2447 = torch.ops.prims.convert_element_type.default(wait_tensor_805, torch.float32); wait_tensor_805 = None + reduce_scatter_tensor_347 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2447, 'avg', 32, '0'); convert_element_type_2447 = None + wait_tensor_806 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_347); reduce_scatter_tensor_347 = None + add_307 = torch.ops.aten.add.Tensor(add_303, convert_element_type_2445); add_303 = convert_element_type_2445 = None + all_gather_into_tensor_407 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_307, 8, '1') + wait_tensor_807 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_407); all_gather_into_tensor_407 = None + split_241 = torch.ops.aten.split.Tensor(wait_tensor_807, 2); wait_tensor_807 = None + getitem_2291 = split_241[0] + getitem_2292 = split_241[1] + getitem_2293 = split_241[2] + getitem_2294 = split_241[3] + getitem_2295 = split_241[4] + getitem_2296 = split_241[5] + getitem_2297 = split_241[6] + getitem_2298 = split_241[7]; split_241 = None + cat_233 = torch.ops.aten.cat.default([getitem_2291, getitem_2292, getitem_2293, getitem_2294, getitem_2295, getitem_2296, getitem_2297, getitem_2298], 1); getitem_2291 = getitem_2292 = getitem_2293 = getitem_2294 = getitem_2295 = getitem_2296 = getitem_2297 = getitem_2298 = None + view_2929 = torch.ops.aten.view.default(cat_233, [16384, 4096]); cat_233 = None + permute_1169 = torch.ops.aten.permute.default(view_2929, [1, 0]) + permute_72 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]) + view_474 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + view_480 = torch.ops.aten.view.default(view_474, [16384, 512]); view_474 = None + mm_583 = torch.ops.aten.mm.default(permute_1169, view_480); permute_1169 = view_480 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16); primals_62 = None + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 32, '0'); convert_element_type_215 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + permute_1171 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_584 = torch.ops.aten.mm.default(view_2929, permute_1171); view_2929 = permute_1171 = None + view_2930 = torch.ops.aten.view.default(mm_584, [2, 8192, 512]); mm_584 = None + convert_element_type_2452 = torch.ops.prims.convert_element_type.default(mm_583, torch.float32); mm_583 = None + reduce_scatter_tensor_348 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2452, 'avg', 32, '0'); convert_element_type_2452 = None + wait_tensor_808 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_348); reduce_scatter_tensor_348 = None + view_2931 = torch.ops.aten.view.default(view_2930, [2, 8192, 4, 128]); view_2930 = None + permute_1173 = torch.ops.aten.permute.default(view_2931, [0, 2, 1, 3]); view_2931 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 32, '0'); convert_element_type_199 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32); add_23 = None + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_80) + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_201, 8, '1'); convert_element_type_201 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + split_33 = torch.ops.aten.split.Tensor(wait_tensor_81, 2); wait_tensor_81 = None + getitem_318 = split_33[0] + getitem_319 = split_33[1] + getitem_320 = split_33[2] + getitem_321 = split_33[3] + getitem_322 = split_33[4] + getitem_323 = split_33[5] + getitem_324 = split_33[6] + getitem_325 = split_33[7]; split_33 = None + cat_25 = torch.ops.aten.cat.default([getitem_318, getitem_319, getitem_320, getitem_321, getitem_322, getitem_323, getitem_324, getitem_325], 1); getitem_318 = getitem_319 = getitem_320 = getitem_321 = getitem_322 = getitem_323 = getitem_324 = getitem_325 = None + view_447 = torch.ops.aten.view.default(cat_25, [16384, 4096]); cat_25 = None + view_448 = torch.ops.aten.view.default(mm_42, [2, 8192, 512]); mm_42 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 32, '0'); convert_element_type_205 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + mm_43 = torch.ops.aten.mm.default(view_447, permute_67) + view_455 = torch.ops.aten.view.default(mm_43, [2, 8192, 128]); mm_43 = None + view_462 = torch.ops.aten.view.default(mm_44, [2, 8192, 128]); mm_44 = None + view_464 = torch.ops.aten.view.default(view_448, [2, 8192, -1, 128]); view_448 = None + view_465 = torch.ops.aten.view.default(view_455, [2, 8192, -1, 128]); view_455 = None + view_466 = torch.ops.aten.view.default(view_462, [2, 8192, -1, 128]); view_462 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_464, torch.float32); view_464 = None + view_467 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 4, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_467); view_467 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_465, torch.float32); view_465 = None + view_468 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 1, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_468); view_468 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_37); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_470 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 4, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_37); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_471 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 1, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_470, torch.bfloat16); view_470 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_471, torch.bfloat16); view_471 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 1, 4, 128]); unsqueeze_12 = None + view_472 = torch.ops.aten.view.default(expand_12, [2, 8192, 4, 128]); expand_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_466, 3); view_466 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 1, 4, 128]); unsqueeze_13 = None + view_473 = torch.ops.aten.view.default(expand_13, [2, 8192, 4, 128]); expand_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_472, [0, 2, 1, 3]); view_472 = None + permute_71 = torch.ops.aten.permute.default(view_473, [0, 2, 1, 3]); view_473 = None + _scaled_dot_product_cudnn_attention_backward_25 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1173, permute_69, permute_70, permute_71, getitem_326, getitem_327, getitem_332, getitem_333, None, None, None, 8192, 8192, 0.0, True); permute_1173 = permute_69 = permute_70 = permute_71 = getitem_326 = getitem_327 = getitem_332 = getitem_333 = None + getitem_2299 = _scaled_dot_product_cudnn_attention_backward_25[0] + getitem_2300 = _scaled_dot_product_cudnn_attention_backward_25[1] + getitem_2301 = _scaled_dot_product_cudnn_attention_backward_25[2]; _scaled_dot_product_cudnn_attention_backward_25 = None + permute_1174 = torch.ops.aten.permute.default(getitem_2301, [0, 2, 1, 3]); getitem_2301 = None + permute_1175 = torch.ops.aten.permute.default(getitem_2300, [0, 2, 1, 3]); getitem_2300 = None + permute_1176 = torch.ops.aten.permute.default(getitem_2299, [0, 2, 1, 3]); getitem_2299 = None + view_2932 = torch.ops.aten.view.default(permute_1174, [2, 8192, 1, 4, 128]); permute_1174 = None + sum_155 = torch.ops.aten.sum.dim_IntList(view_2932, [3], True); view_2932 = None + squeeze_50 = torch.ops.aten.squeeze.dim(sum_155, 3); sum_155 = None + view_2933 = torch.ops.aten.view.default(permute_1175, [2, 8192, 1, 4, 128]); permute_1175 = None + sum_156 = torch.ops.aten.sum.dim_IntList(view_2933, [3], True); view_2933 = None + squeeze_51 = torch.ops.aten.squeeze.dim(sum_156, 3); sum_156 = None + convert_element_type_2453 = torch.ops.prims.convert_element_type.default(squeeze_51, torch.float32); squeeze_51 = None + convert_element_type_2454 = torch.ops.prims.convert_element_type.default(permute_1176, torch.float32); permute_1176 = None + view_2934 = torch.ops.aten.view.default(convert_element_type_2453, [2, 8192, 1, 64, 2]); convert_element_type_2453 = None + view_as_complex_114 = torch.ops.aten.view_as_complex.default(view_2934); view_2934 = None + mul_776 = torch.ops.aten.mul.Tensor(view_as_complex_114, _conj); view_as_complex_114 = None + view_2935 = torch.ops.aten.view.default(convert_element_type_2454, [2, 8192, 4, 64, 2]); convert_element_type_2454 = None + view_as_complex_115 = torch.ops.aten.view_as_complex.default(view_2935); view_2935 = None + mul_777 = torch.ops.aten.mul.Tensor(view_as_complex_115, _conj); view_as_complex_115 = None + view_as_real_114 = torch.ops.aten.view_as_real.default(mul_776); mul_776 = None + view_2936 = torch.ops.aten.view.default(view_as_real_114, [2, 8192, 1, 128]); view_as_real_114 = None + convert_element_type_2455 = torch.ops.prims.convert_element_type.default(view_2936, torch.bfloat16); view_2936 = None + view_as_real_115 = torch.ops.aten.view_as_real.default(mul_777); mul_777 = None + view_2937 = torch.ops.aten.view.default(view_as_real_115, [2, 8192, 4, 128]); view_as_real_115 = None + convert_element_type_2456 = torch.ops.prims.convert_element_type.default(view_2937, torch.bfloat16); view_2937 = None + view_2938 = torch.ops.aten.view.default(squeeze_50, [2, 8192, 128]); squeeze_50 = None + view_2939 = torch.ops.aten.view.default(convert_element_type_2455, [2, 8192, 128]); convert_element_type_2455 = None + view_2940 = torch.ops.aten.view.default(convert_element_type_2456, [2, 8192, 512]); convert_element_type_2456 = None + view_2941 = torch.ops.aten.view.default(view_2938, [16384, 128]); view_2938 = None + permute_1177 = torch.ops.aten.permute.default(view_2941, [1, 0]) + mm_585 = torch.ops.aten.mm.default(permute_1177, view_447); permute_1177 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 32, '0'); convert_element_type_208 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + permute_1179 = torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_586 = torch.ops.aten.mm.default(view_2941, permute_1179); view_2941 = permute_1179 = None + view_2942 = torch.ops.aten.view.default(mm_586, [2, 8192, 4096]); mm_586 = None + convert_element_type_2461 = torch.ops.prims.convert_element_type.default(mm_585, torch.float32); mm_585 = None + reduce_scatter_tensor_349 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2461, 'avg', 32, '0'); convert_element_type_2461 = None + wait_tensor_809 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_349); reduce_scatter_tensor_349 = None + view_2943 = torch.ops.aten.view.default(view_2939, [16384, 128]); view_2939 = None + permute_1181 = torch.ops.aten.permute.default(view_2943, [1, 0]) + mm_587 = torch.ops.aten.mm.default(permute_1181, view_447); permute_1181 = None + permute_1183 = torch.ops.aten.permute.default(permute_67, [1, 0]); permute_67 = None + mm_588 = torch.ops.aten.mm.default(view_2943, permute_1183); view_2943 = permute_1183 = None + view_2944 = torch.ops.aten.view.default(mm_588, [2, 8192, 4096]); mm_588 = None + add_308 = torch.ops.aten.add.Tensor(view_2942, view_2944); view_2942 = view_2944 = None + convert_element_type_2466 = torch.ops.prims.convert_element_type.default(mm_587, torch.float32); mm_587 = None + reduce_scatter_tensor_350 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2466, 'avg', 32, '0'); convert_element_type_2466 = None + wait_tensor_810 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_350); reduce_scatter_tensor_350 = None + view_2945 = torch.ops.aten.view.default(view_2940, [16384, 512]); view_2940 = None + permute_1185 = torch.ops.aten.permute.default(view_2945, [1, 0]) + mm_589 = torch.ops.aten.mm.default(permute_1185, view_447); permute_1185 = view_447 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); primals_59 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 32, '0'); convert_element_type_202 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_82, [1, 0]); wait_tensor_82 = None + permute_1187 = torch.ops.aten.permute.default(permute_66, [1, 0]); permute_66 = None + mm_590 = torch.ops.aten.mm.default(view_2945, permute_1187); view_2945 = permute_1187 = None + view_2946 = torch.ops.aten.view.default(mm_590, [2, 8192, 4096]); mm_590 = None + add_309 = torch.ops.aten.add.Tensor(add_308, view_2946); add_308 = view_2946 = None + convert_element_type_2471 = torch.ops.prims.convert_element_type.default(mm_589, torch.float32); mm_589 = None + reduce_scatter_tensor_351 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2471, 'avg', 32, '0'); convert_element_type_2471 = None + wait_tensor_811 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_351); reduce_scatter_tensor_351 = None + split_242 = torch.ops.aten.split.Tensor(add_309, 1024, 1); add_309 = None + getitem_2302 = split_242[0] + getitem_2303 = split_242[1] + getitem_2304 = split_242[2] + getitem_2305 = split_242[3] + getitem_2306 = split_242[4] + getitem_2307 = split_242[5] + getitem_2308 = split_242[6] + getitem_2309 = split_242[7]; split_242 = None + cat_234 = torch.ops.aten.cat.default([getitem_2302, getitem_2303, getitem_2304, getitem_2305, getitem_2306, getitem_2307, getitem_2308, getitem_2309]); getitem_2302 = getitem_2303 = getitem_2304 = getitem_2305 = getitem_2306 = getitem_2307 = getitem_2308 = getitem_2309 = None + reduce_scatter_tensor_352 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_234, 'sum', 8, '1'); cat_234 = None + wait_tensor_812 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_352); reduce_scatter_tensor_352 = None + convert_element_type_2472 = torch.ops.prims.convert_element_type.default(wait_tensor_812, torch.float32); wait_tensor_812 = None + convert_element_type_2474 = torch.ops.prims.convert_element_type.default(wait_tensor_80, torch.float32); wait_tensor_80 = None + mul_778 = torch.ops.aten.mul.Tensor(convert_element_type_2472, convert_element_type_2474); convert_element_type_2474 = None + mul_780 = torch.ops.aten.mul.Tensor(mul_48, mul_778) + sum_157 = torch.ops.aten.sum.dim_IntList(mul_780, [2], True); mul_780 = None + div_52 = torch.ops.aten.div.Tensor(mul_48, 4096) + mul_781 = torch.ops.aten.mul.Tensor(div_52, sum_157); div_52 = sum_157 = None + sub_79 = torch.ops.aten.sub.Tensor(mul_778, mul_781); mul_778 = mul_781 = None + mul_782 = torch.ops.aten.mul.Tensor(sub_79, rsqrt_12); sub_79 = rsqrt_12 = None + mul_783 = torch.ops.aten.mul.Tensor(convert_element_type_2472, mul_48); convert_element_type_2472 = mul_48 = None + sum_158 = torch.ops.aten.sum.dim_IntList(mul_783, [0, 1]); mul_783 = None + convert_element_type_2475 = torch.ops.prims.convert_element_type.default(mul_782, torch.bfloat16); mul_782 = None + convert_element_type_2476 = torch.ops.prims.convert_element_type.default(sum_158, torch.bfloat16); sum_158 = None + all_reduce_52 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2476, 'sum', '1'); convert_element_type_2476 = None + wait_tensor_813 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_52); all_reduce_52 = None + convert_element_type_2477 = torch.ops.prims.convert_element_type.default(wait_tensor_813, torch.float32); wait_tensor_813 = None + reduce_scatter_tensor_353 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2477, 'avg', 32, '0'); convert_element_type_2477 = None + wait_tensor_814 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_353); reduce_scatter_tensor_353 = None + add_310 = torch.ops.aten.add.Tensor(add_307, convert_element_type_2475); add_307 = convert_element_type_2475 = None + all_gather_into_tensor_408 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_310, 8, '1') + wait_tensor_815 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_408); all_gather_into_tensor_408 = None + split_243 = torch.ops.aten.split.Tensor(wait_tensor_815, 2); wait_tensor_815 = None + getitem_2310 = split_243[0] + getitem_2311 = split_243[1] + getitem_2312 = split_243[2] + getitem_2313 = split_243[3] + getitem_2314 = split_243[4] + getitem_2315 = split_243[5] + getitem_2316 = split_243[6] + getitem_2317 = split_243[7]; split_243 = None + cat_235 = torch.ops.aten.cat.default([getitem_2310, getitem_2311, getitem_2312, getitem_2313, getitem_2314, getitem_2315, getitem_2316, getitem_2317], 1); getitem_2310 = getitem_2311 = getitem_2312 = getitem_2313 = getitem_2314 = getitem_2315 = getitem_2316 = getitem_2317 = None + view_2947 = torch.ops.aten.view.default(cat_235, [16384, 4096]); cat_235 = None + permute_1189 = torch.ops.aten.permute.default(view_2947, [1, 0]) + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + add_21 = torch.ops.aten.add.Tensor(add_19, wait_tensor_73); wait_tensor_73 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 32, '0'); convert_element_type_185 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32); add_21 = None + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_74) + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_187, 8, '1'); convert_element_type_187 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + split_31 = torch.ops.aten.split.Tensor(wait_tensor_75, 2); wait_tensor_75 = None + getitem_302 = split_31[0] + getitem_303 = split_31[1] + getitem_304 = split_31[2] + getitem_305 = split_31[3] + getitem_306 = split_31[4] + getitem_307 = split_31[5] + getitem_308 = split_31[6] + getitem_309 = split_31[7]; split_31 = None + cat_23 = torch.ops.aten.cat.default([getitem_302, getitem_303, getitem_304, getitem_305, getitem_306, getitem_307, getitem_308, getitem_309], 1); getitem_302 = getitem_303 = getitem_304 = getitem_305 = getitem_306 = getitem_307 = getitem_308 = getitem_309 = None + view_420 = torch.ops.aten.view.default(cat_23, [16384, 4096]); cat_23 = None + view_421 = torch.ops.aten.view.default(mm_39, [2, 8192, 1792]); mm_39 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_421, torch.float32); view_421 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 32, '0'); convert_element_type_193 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + mm_40 = torch.ops.aten.mm.default(view_420, permute_64) + view_428 = torch.ops.aten.view.default(mm_40, [2, 8192, 1792]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_428) + view_435 = torch.ops.aten.view.default(mul_47, [16384, 1792]); mul_47 = None + mm_591 = torch.ops.aten.mm.default(permute_1189, view_435); permute_1189 = view_435 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 32, '0'); convert_element_type_196 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + permute_1191 = torch.ops.aten.permute.default(permute_65, [1, 0]); permute_65 = None + mm_592 = torch.ops.aten.mm.default(view_2947, permute_1191); view_2947 = permute_1191 = None + view_2948 = torch.ops.aten.view.default(mm_592, [2, 8192, 1792]); mm_592 = None + convert_element_type_2482 = torch.ops.prims.convert_element_type.default(mm_591, torch.float32); mm_591 = None + reduce_scatter_tensor_354 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2482, 'avg', 32, '0'); convert_element_type_2482 = None + wait_tensor_816 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_354); reduce_scatter_tensor_354 = None + mul_784 = torch.ops.aten.mul.Tensor(view_2948, convert_element_type_192); convert_element_type_192 = None + mul_785 = torch.ops.aten.mul.Tensor(view_2948, view_428); view_2948 = view_428 = None + view_2949 = torch.ops.aten.view.default(mul_784, [16384, 1792]); mul_784 = None + permute_1193 = torch.ops.aten.permute.default(view_2949, [1, 0]) + mm_593 = torch.ops.aten.mm.default(permute_1193, view_420); permute_1193 = None + permute_1195 = torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_594 = torch.ops.aten.mm.default(view_2949, permute_1195); view_2949 = permute_1195 = None + view_2950 = torch.ops.aten.view.default(mm_594, [2, 8192, 4096]); mm_594 = None + convert_element_type_2487 = torch.ops.prims.convert_element_type.default(mm_593, torch.float32); mm_593 = None + reduce_scatter_tensor_355 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2487, 'avg', 32, '0'); convert_element_type_2487 = None + wait_tensor_817 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_355); reduce_scatter_tensor_355 = None + convert_element_type_2488 = torch.ops.prims.convert_element_type.default(mul_785, torch.float32); mul_785 = None + neg_26 = torch.ops.aten.neg.default(convert_element_type_191) + exp_26 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_311 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + reciprocal_26 = torch.ops.aten.reciprocal.default(add_311); add_311 = None + mul_786 = torch.ops.aten.mul.Tensor(reciprocal_26, 1); reciprocal_26 = None + mul_787 = torch.ops.aten.mul.Tensor(convert_element_type_2488, mul_786); convert_element_type_2488 = None + sub_80 = torch.ops.aten.sub.Tensor(1, mul_786); mul_786 = None + mul_788 = torch.ops.aten.mul.Tensor(convert_element_type_191, sub_80); convert_element_type_191 = sub_80 = None + add_312 = torch.ops.aten.add.Tensor(mul_788, 1); mul_788 = None + mul_789 = torch.ops.aten.mul.Tensor(mul_787, add_312); mul_787 = add_312 = None + convert_element_type_2490 = torch.ops.prims.convert_element_type.default(mul_789, torch.bfloat16); mul_789 = None + view_2951 = torch.ops.aten.view.default(convert_element_type_2490, [16384, 1792]); convert_element_type_2490 = None + permute_1197 = torch.ops.aten.permute.default(view_2951, [1, 0]) + mm_595 = torch.ops.aten.mm.default(permute_1197, view_420); permute_1197 = view_420 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 32, '0'); convert_element_type_188 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + permute_1199 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_596 = torch.ops.aten.mm.default(view_2951, permute_1199); view_2951 = permute_1199 = None + view_2952 = torch.ops.aten.view.default(mm_596, [2, 8192, 4096]); mm_596 = None + add_313 = torch.ops.aten.add.Tensor(view_2950, view_2952); view_2950 = view_2952 = None + convert_element_type_2495 = torch.ops.prims.convert_element_type.default(mm_595, torch.float32); mm_595 = None + reduce_scatter_tensor_356 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2495, 'avg', 32, '0'); convert_element_type_2495 = None + wait_tensor_818 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_356); reduce_scatter_tensor_356 = None + split_244 = torch.ops.aten.split.Tensor(add_313, 1024, 1); add_313 = None + getitem_2318 = split_244[0] + getitem_2319 = split_244[1] + getitem_2320 = split_244[2] + getitem_2321 = split_244[3] + getitem_2322 = split_244[4] + getitem_2323 = split_244[5] + getitem_2324 = split_244[6] + getitem_2325 = split_244[7]; split_244 = None + cat_236 = torch.ops.aten.cat.default([getitem_2318, getitem_2319, getitem_2320, getitem_2321, getitem_2322, getitem_2323, getitem_2324, getitem_2325]); getitem_2318 = getitem_2319 = getitem_2320 = getitem_2321 = getitem_2322 = getitem_2323 = getitem_2324 = getitem_2325 = None + reduce_scatter_tensor_357 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_236, 'sum', 8, '1'); cat_236 = None + wait_tensor_819 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_357); reduce_scatter_tensor_357 = None + convert_element_type_2496 = torch.ops.prims.convert_element_type.default(wait_tensor_819, torch.float32); wait_tensor_819 = None + convert_element_type_2498 = torch.ops.prims.convert_element_type.default(wait_tensor_74, torch.float32); wait_tensor_74 = None + mul_790 = torch.ops.aten.mul.Tensor(convert_element_type_2496, convert_element_type_2498); convert_element_type_2498 = None + mul_792 = torch.ops.aten.mul.Tensor(mul_44, mul_790) + sum_159 = torch.ops.aten.sum.dim_IntList(mul_792, [2], True); mul_792 = None + div_53 = torch.ops.aten.div.Tensor(mul_44, 4096) + mul_793 = torch.ops.aten.mul.Tensor(div_53, sum_159); div_53 = sum_159 = None + sub_81 = torch.ops.aten.sub.Tensor(mul_790, mul_793); mul_790 = mul_793 = None + mul_794 = torch.ops.aten.mul.Tensor(sub_81, rsqrt_11); sub_81 = rsqrt_11 = None + mul_795 = torch.ops.aten.mul.Tensor(convert_element_type_2496, mul_44); convert_element_type_2496 = mul_44 = None + sum_160 = torch.ops.aten.sum.dim_IntList(mul_795, [0, 1]); mul_795 = None + convert_element_type_2499 = torch.ops.prims.convert_element_type.default(mul_794, torch.bfloat16); mul_794 = None + convert_element_type_2500 = torch.ops.prims.convert_element_type.default(sum_160, torch.bfloat16); sum_160 = None + all_reduce_53 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2500, 'sum', '1'); convert_element_type_2500 = None + wait_tensor_820 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_53); all_reduce_53 = None + convert_element_type_2501 = torch.ops.prims.convert_element_type.default(wait_tensor_820, torch.float32); wait_tensor_820 = None + reduce_scatter_tensor_358 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2501, 'avg', 32, '0'); convert_element_type_2501 = None + wait_tensor_821 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_358); reduce_scatter_tensor_358 = None + add_314 = torch.ops.aten.add.Tensor(add_310, convert_element_type_2499); add_310 = convert_element_type_2499 = None + all_gather_into_tensor_409 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_314, 8, '1') + wait_tensor_822 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_409); all_gather_into_tensor_409 = None + split_245 = torch.ops.aten.split.Tensor(wait_tensor_822, 2); wait_tensor_822 = None + getitem_2326 = split_245[0] + getitem_2327 = split_245[1] + getitem_2328 = split_245[2] + getitem_2329 = split_245[3] + getitem_2330 = split_245[4] + getitem_2331 = split_245[5] + getitem_2332 = split_245[6] + getitem_2333 = split_245[7]; split_245 = None + cat_237 = torch.ops.aten.cat.default([getitem_2326, getitem_2327, getitem_2328, getitem_2329, getitem_2330, getitem_2331, getitem_2332, getitem_2333], 1); getitem_2326 = getitem_2327 = getitem_2328 = getitem_2329 = getitem_2330 = getitem_2331 = getitem_2332 = getitem_2333 = None + view_2953 = torch.ops.aten.view.default(cat_237, [16384, 4096]); cat_237 = None + permute_1201 = torch.ops.aten.permute.default(view_2953, [1, 0]) + permute_61 = torch.ops.aten.permute.default(getitem_285, [0, 2, 1, 3]) + view_402 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + view_408 = torch.ops.aten.view.default(view_402, [16384, 512]); view_402 = None + mm_597 = torch.ops.aten.mm.default(permute_1201, view_408); permute_1201 = view_408 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 32, '0'); convert_element_type_182 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_1203 = torch.ops.aten.permute.default(permute_62, [1, 0]); permute_62 = None + mm_598 = torch.ops.aten.mm.default(view_2953, permute_1203); view_2953 = permute_1203 = None + view_2954 = torch.ops.aten.view.default(mm_598, [2, 8192, 512]); mm_598 = None + convert_element_type_2506 = torch.ops.prims.convert_element_type.default(mm_597, torch.float32); mm_597 = None + reduce_scatter_tensor_359 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2506, 'avg', 32, '0'); convert_element_type_2506 = None + wait_tensor_823 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_359); reduce_scatter_tensor_359 = None + view_2955 = torch.ops.aten.view.default(view_2954, [2, 8192, 4, 128]); view_2954 = None + permute_1205 = torch.ops.aten.permute.default(view_2955, [0, 2, 1, 3]); view_2955 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 32, '0'); convert_element_type_166 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32); add_19 = None + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_67) + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_168, 8, '1'); convert_element_type_168 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + split_29 = torch.ops.aten.split.Tensor(wait_tensor_68, 2); wait_tensor_68 = None + getitem_277 = split_29[0] + getitem_278 = split_29[1] + getitem_279 = split_29[2] + getitem_280 = split_29[3] + getitem_281 = split_29[4] + getitem_282 = split_29[5] + getitem_283 = split_29[6] + getitem_284 = split_29[7]; split_29 = None + cat_21 = torch.ops.aten.cat.default([getitem_277, getitem_278, getitem_279, getitem_280, getitem_281, getitem_282, getitem_283, getitem_284], 1); getitem_277 = getitem_278 = getitem_279 = getitem_280 = getitem_281 = getitem_282 = getitem_283 = getitem_284 = None + view_375 = torch.ops.aten.view.default(cat_21, [16384, 4096]); cat_21 = None + view_376 = torch.ops.aten.view.default(mm_35, [2, 8192, 512]); mm_35 = None + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 32, '0'); convert_element_type_172 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + mm_36 = torch.ops.aten.mm.default(view_375, permute_56) + view_383 = torch.ops.aten.view.default(mm_36, [2, 8192, 128]); mm_36 = None + view_390 = torch.ops.aten.view.default(mm_37, [2, 8192, 128]); mm_37 = None + view_392 = torch.ops.aten.view.default(view_376, [2, 8192, -1, 128]); view_376 = None + view_393 = torch.ops.aten.view.default(view_383, [2, 8192, -1, 128]); view_383 = None + view_394 = torch.ops.aten.view.default(view_390, [2, 8192, -1, 128]); view_390 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_392, torch.float32); view_392 = None + view_395 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 4, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_395); view_395 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_393, torch.float32); view_393 = None + view_396 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 1, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_396); view_396 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_37); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_398 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 4, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_37); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_399 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 1, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_398, torch.bfloat16); view_398 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_399, torch.bfloat16); view_399 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 1, 4, 128]); unsqueeze_10 = None + view_400 = torch.ops.aten.view.default(expand_10, [2, 8192, 4, 128]); expand_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_394, 3); view_394 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 1, 4, 128]); unsqueeze_11 = None + view_401 = torch.ops.aten.view.default(expand_11, [2, 8192, 4, 128]); expand_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_400, [0, 2, 1, 3]); view_400 = None + permute_60 = torch.ops.aten.permute.default(view_401, [0, 2, 1, 3]); view_401 = None + _scaled_dot_product_cudnn_attention_backward_26 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1205, permute_58, permute_59, permute_60, getitem_285, getitem_286, getitem_291, getitem_292, None, None, None, 8192, 8192, 0.0, True); permute_1205 = permute_58 = permute_59 = permute_60 = getitem_285 = getitem_286 = getitem_291 = getitem_292 = None + getitem_2334 = _scaled_dot_product_cudnn_attention_backward_26[0] + getitem_2335 = _scaled_dot_product_cudnn_attention_backward_26[1] + getitem_2336 = _scaled_dot_product_cudnn_attention_backward_26[2]; _scaled_dot_product_cudnn_attention_backward_26 = None + permute_1206 = torch.ops.aten.permute.default(getitem_2336, [0, 2, 1, 3]); getitem_2336 = None + permute_1207 = torch.ops.aten.permute.default(getitem_2335, [0, 2, 1, 3]); getitem_2335 = None + permute_1208 = torch.ops.aten.permute.default(getitem_2334, [0, 2, 1, 3]); getitem_2334 = None + view_2956 = torch.ops.aten.view.default(permute_1206, [2, 8192, 1, 4, 128]); permute_1206 = None + sum_161 = torch.ops.aten.sum.dim_IntList(view_2956, [3], True); view_2956 = None + squeeze_52 = torch.ops.aten.squeeze.dim(sum_161, 3); sum_161 = None + view_2957 = torch.ops.aten.view.default(permute_1207, [2, 8192, 1, 4, 128]); permute_1207 = None + sum_162 = torch.ops.aten.sum.dim_IntList(view_2957, [3], True); view_2957 = None + squeeze_53 = torch.ops.aten.squeeze.dim(sum_162, 3); sum_162 = None + convert_element_type_2507 = torch.ops.prims.convert_element_type.default(squeeze_53, torch.float32); squeeze_53 = None + convert_element_type_2508 = torch.ops.prims.convert_element_type.default(permute_1208, torch.float32); permute_1208 = None + view_2958 = torch.ops.aten.view.default(convert_element_type_2507, [2, 8192, 1, 64, 2]); convert_element_type_2507 = None + view_as_complex_116 = torch.ops.aten.view_as_complex.default(view_2958); view_2958 = None + mul_796 = torch.ops.aten.mul.Tensor(view_as_complex_116, _conj); view_as_complex_116 = None + view_2959 = torch.ops.aten.view.default(convert_element_type_2508, [2, 8192, 4, 64, 2]); convert_element_type_2508 = None + view_as_complex_117 = torch.ops.aten.view_as_complex.default(view_2959); view_2959 = None + mul_797 = torch.ops.aten.mul.Tensor(view_as_complex_117, _conj); view_as_complex_117 = None + view_as_real_116 = torch.ops.aten.view_as_real.default(mul_796); mul_796 = None + view_2960 = torch.ops.aten.view.default(view_as_real_116, [2, 8192, 1, 128]); view_as_real_116 = None + convert_element_type_2509 = torch.ops.prims.convert_element_type.default(view_2960, torch.bfloat16); view_2960 = None + view_as_real_117 = torch.ops.aten.view_as_real.default(mul_797); mul_797 = None + view_2961 = torch.ops.aten.view.default(view_as_real_117, [2, 8192, 4, 128]); view_as_real_117 = None + convert_element_type_2510 = torch.ops.prims.convert_element_type.default(view_2961, torch.bfloat16); view_2961 = None + view_2962 = torch.ops.aten.view.default(squeeze_52, [2, 8192, 128]); squeeze_52 = None + view_2963 = torch.ops.aten.view.default(convert_element_type_2509, [2, 8192, 128]); convert_element_type_2509 = None + view_2964 = torch.ops.aten.view.default(convert_element_type_2510, [2, 8192, 512]); convert_element_type_2510 = None + view_2965 = torch.ops.aten.view.default(view_2962, [16384, 128]); view_2962 = None + permute_1209 = torch.ops.aten.permute.default(view_2965, [1, 0]) + mm_599 = torch.ops.aten.mm.default(permute_1209, view_375); permute_1209 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 32, '0'); convert_element_type_175 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + permute_1211 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_600 = torch.ops.aten.mm.default(view_2965, permute_1211); view_2965 = permute_1211 = None + view_2966 = torch.ops.aten.view.default(mm_600, [2, 8192, 4096]); mm_600 = None + convert_element_type_2515 = torch.ops.prims.convert_element_type.default(mm_599, torch.float32); mm_599 = None + reduce_scatter_tensor_360 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2515, 'avg', 32, '0'); convert_element_type_2515 = None + wait_tensor_824 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_360); reduce_scatter_tensor_360 = None + view_2967 = torch.ops.aten.view.default(view_2963, [16384, 128]); view_2963 = None + permute_1213 = torch.ops.aten.permute.default(view_2967, [1, 0]) + mm_601 = torch.ops.aten.mm.default(permute_1213, view_375); permute_1213 = None + permute_1215 = torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_602 = torch.ops.aten.mm.default(view_2967, permute_1215); view_2967 = permute_1215 = None + view_2968 = torch.ops.aten.view.default(mm_602, [2, 8192, 4096]); mm_602 = None + add_315 = torch.ops.aten.add.Tensor(view_2966, view_2968); view_2966 = view_2968 = None + convert_element_type_2520 = torch.ops.prims.convert_element_type.default(mm_601, torch.float32); mm_601 = None + reduce_scatter_tensor_361 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2520, 'avg', 32, '0'); convert_element_type_2520 = None + wait_tensor_825 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_361); reduce_scatter_tensor_361 = None + view_2969 = torch.ops.aten.view.default(view_2964, [16384, 512]); view_2964 = None + permute_1217 = torch.ops.aten.permute.default(view_2969, [1, 0]) + mm_603 = torch.ops.aten.mm.default(permute_1217, view_375); permute_1217 = view_375 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 32, '0'); convert_element_type_169 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_69, [1, 0]); wait_tensor_69 = None + permute_1219 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_604 = torch.ops.aten.mm.default(view_2969, permute_1219); view_2969 = permute_1219 = None + view_2970 = torch.ops.aten.view.default(mm_604, [2, 8192, 4096]); mm_604 = None + add_316 = torch.ops.aten.add.Tensor(add_315, view_2970); add_315 = view_2970 = None + convert_element_type_2525 = torch.ops.prims.convert_element_type.default(mm_603, torch.float32); mm_603 = None + reduce_scatter_tensor_362 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2525, 'avg', 32, '0'); convert_element_type_2525 = None + wait_tensor_826 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_362); reduce_scatter_tensor_362 = None + split_246 = torch.ops.aten.split.Tensor(add_316, 1024, 1); add_316 = None + getitem_2337 = split_246[0] + getitem_2338 = split_246[1] + getitem_2339 = split_246[2] + getitem_2340 = split_246[3] + getitem_2341 = split_246[4] + getitem_2342 = split_246[5] + getitem_2343 = split_246[6] + getitem_2344 = split_246[7]; split_246 = None + cat_238 = torch.ops.aten.cat.default([getitem_2337, getitem_2338, getitem_2339, getitem_2340, getitem_2341, getitem_2342, getitem_2343, getitem_2344]); getitem_2337 = getitem_2338 = getitem_2339 = getitem_2340 = getitem_2341 = getitem_2342 = getitem_2343 = getitem_2344 = None + reduce_scatter_tensor_363 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_238, 'sum', 8, '1'); cat_238 = None + wait_tensor_827 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_363); reduce_scatter_tensor_363 = None + convert_element_type_2526 = torch.ops.prims.convert_element_type.default(wait_tensor_827, torch.float32); wait_tensor_827 = None + convert_element_type_2528 = torch.ops.prims.convert_element_type.default(wait_tensor_67, torch.float32); wait_tensor_67 = None + mul_798 = torch.ops.aten.mul.Tensor(convert_element_type_2526, convert_element_type_2528); convert_element_type_2528 = None + mul_800 = torch.ops.aten.mul.Tensor(mul_40, mul_798) + sum_163 = torch.ops.aten.sum.dim_IntList(mul_800, [2], True); mul_800 = None + div_54 = torch.ops.aten.div.Tensor(mul_40, 4096) + mul_801 = torch.ops.aten.mul.Tensor(div_54, sum_163); div_54 = sum_163 = None + sub_82 = torch.ops.aten.sub.Tensor(mul_798, mul_801); mul_798 = mul_801 = None + mul_802 = torch.ops.aten.mul.Tensor(sub_82, rsqrt_10); sub_82 = rsqrt_10 = None + mul_803 = torch.ops.aten.mul.Tensor(convert_element_type_2526, mul_40); convert_element_type_2526 = mul_40 = None + sum_164 = torch.ops.aten.sum.dim_IntList(mul_803, [0, 1]); mul_803 = None + convert_element_type_2529 = torch.ops.prims.convert_element_type.default(mul_802, torch.bfloat16); mul_802 = None + convert_element_type_2530 = torch.ops.prims.convert_element_type.default(sum_164, torch.bfloat16); sum_164 = None + all_reduce_54 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2530, 'sum', '1'); convert_element_type_2530 = None + wait_tensor_828 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_54); all_reduce_54 = None + convert_element_type_2531 = torch.ops.prims.convert_element_type.default(wait_tensor_828, torch.float32); wait_tensor_828 = None + reduce_scatter_tensor_364 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2531, 'avg', 32, '0'); convert_element_type_2531 = None + wait_tensor_829 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_364); reduce_scatter_tensor_364 = None + add_317 = torch.ops.aten.add.Tensor(add_314, convert_element_type_2529); add_314 = convert_element_type_2529 = None + all_gather_into_tensor_410 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_317, 8, '1') + wait_tensor_830 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_410); all_gather_into_tensor_410 = None + split_247 = torch.ops.aten.split.Tensor(wait_tensor_830, 2); wait_tensor_830 = None + getitem_2345 = split_247[0] + getitem_2346 = split_247[1] + getitem_2347 = split_247[2] + getitem_2348 = split_247[3] + getitem_2349 = split_247[4] + getitem_2350 = split_247[5] + getitem_2351 = split_247[6] + getitem_2352 = split_247[7]; split_247 = None + cat_239 = torch.ops.aten.cat.default([getitem_2345, getitem_2346, getitem_2347, getitem_2348, getitem_2349, getitem_2350, getitem_2351, getitem_2352], 1); getitem_2345 = getitem_2346 = getitem_2347 = getitem_2348 = getitem_2349 = getitem_2350 = getitem_2351 = getitem_2352 = None + view_2971 = torch.ops.aten.view.default(cat_239, [16384, 4096]); cat_239 = None + permute_1221 = torch.ops.aten.permute.default(view_2971, [1, 0]) + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + add_17 = torch.ops.aten.add.Tensor(add_15, wait_tensor_60); wait_tensor_60 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 32, '0'); convert_element_type_152 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32); add_17 = None + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_61) + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_154, 8, '1'); convert_element_type_154 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + split_27 = torch.ops.aten.split.Tensor(wait_tensor_62, 2); wait_tensor_62 = None + getitem_261 = split_27[0] + getitem_262 = split_27[1] + getitem_263 = split_27[2] + getitem_264 = split_27[3] + getitem_265 = split_27[4] + getitem_266 = split_27[5] + getitem_267 = split_27[6] + getitem_268 = split_27[7]; split_27 = None + cat_19 = torch.ops.aten.cat.default([getitem_261, getitem_262, getitem_263, getitem_264, getitem_265, getitem_266, getitem_267, getitem_268], 1); getitem_261 = getitem_262 = getitem_263 = getitem_264 = getitem_265 = getitem_266 = getitem_267 = getitem_268 = None + view_348 = torch.ops.aten.view.default(cat_19, [16384, 4096]); cat_19 = None + view_349 = torch.ops.aten.view.default(mm_32, [2, 8192, 1792]); mm_32 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 32, '0'); convert_element_type_160 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_64, [1, 0]); wait_tensor_64 = None + mm_33 = torch.ops.aten.mm.default(view_348, permute_53) + view_356 = torch.ops.aten.view.default(mm_33, [2, 8192, 1792]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_356) + view_363 = torch.ops.aten.view.default(mul_39, [16384, 1792]); mul_39 = None + mm_605 = torch.ops.aten.mm.default(permute_1221, view_363); permute_1221 = view_363 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 32, '0'); convert_element_type_163 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + permute_1223 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_606 = torch.ops.aten.mm.default(view_2971, permute_1223); view_2971 = permute_1223 = None + view_2972 = torch.ops.aten.view.default(mm_606, [2, 8192, 1792]); mm_606 = None + convert_element_type_2536 = torch.ops.prims.convert_element_type.default(mm_605, torch.float32); mm_605 = None + reduce_scatter_tensor_365 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2536, 'avg', 32, '0'); convert_element_type_2536 = None + wait_tensor_831 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_365); reduce_scatter_tensor_365 = None + mul_804 = torch.ops.aten.mul.Tensor(view_2972, convert_element_type_159); convert_element_type_159 = None + mul_805 = torch.ops.aten.mul.Tensor(view_2972, view_356); view_2972 = view_356 = None + view_2973 = torch.ops.aten.view.default(mul_804, [16384, 1792]); mul_804 = None + permute_1225 = torch.ops.aten.permute.default(view_2973, [1, 0]) + mm_607 = torch.ops.aten.mm.default(permute_1225, view_348); permute_1225 = None + permute_1227 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_608 = torch.ops.aten.mm.default(view_2973, permute_1227); view_2973 = permute_1227 = None + view_2974 = torch.ops.aten.view.default(mm_608, [2, 8192, 4096]); mm_608 = None + convert_element_type_2541 = torch.ops.prims.convert_element_type.default(mm_607, torch.float32); mm_607 = None + reduce_scatter_tensor_366 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2541, 'avg', 32, '0'); convert_element_type_2541 = None + wait_tensor_832 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_366); reduce_scatter_tensor_366 = None + convert_element_type_2542 = torch.ops.prims.convert_element_type.default(mul_805, torch.float32); mul_805 = None + neg_27 = torch.ops.aten.neg.default(convert_element_type_158) + exp_27 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_318 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + reciprocal_27 = torch.ops.aten.reciprocal.default(add_318); add_318 = None + mul_806 = torch.ops.aten.mul.Tensor(reciprocal_27, 1); reciprocal_27 = None + mul_807 = torch.ops.aten.mul.Tensor(convert_element_type_2542, mul_806); convert_element_type_2542 = None + sub_83 = torch.ops.aten.sub.Tensor(1, mul_806); mul_806 = None + mul_808 = torch.ops.aten.mul.Tensor(convert_element_type_158, sub_83); convert_element_type_158 = sub_83 = None + add_319 = torch.ops.aten.add.Tensor(mul_808, 1); mul_808 = None + mul_809 = torch.ops.aten.mul.Tensor(mul_807, add_319); mul_807 = add_319 = None + convert_element_type_2544 = torch.ops.prims.convert_element_type.default(mul_809, torch.bfloat16); mul_809 = None + view_2975 = torch.ops.aten.view.default(convert_element_type_2544, [16384, 1792]); convert_element_type_2544 = None + permute_1229 = torch.ops.aten.permute.default(view_2975, [1, 0]) + mm_609 = torch.ops.aten.mm.default(permute_1229, view_348); permute_1229 = view_348 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 32, '0'); convert_element_type_155 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + permute_1231 = torch.ops.aten.permute.default(permute_52, [1, 0]); permute_52 = None + mm_610 = torch.ops.aten.mm.default(view_2975, permute_1231); view_2975 = permute_1231 = None + view_2976 = torch.ops.aten.view.default(mm_610, [2, 8192, 4096]); mm_610 = None + add_320 = torch.ops.aten.add.Tensor(view_2974, view_2976); view_2974 = view_2976 = None + convert_element_type_2549 = torch.ops.prims.convert_element_type.default(mm_609, torch.float32); mm_609 = None + reduce_scatter_tensor_367 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2549, 'avg', 32, '0'); convert_element_type_2549 = None + wait_tensor_833 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_367); reduce_scatter_tensor_367 = None + split_248 = torch.ops.aten.split.Tensor(add_320, 1024, 1); add_320 = None + getitem_2353 = split_248[0] + getitem_2354 = split_248[1] + getitem_2355 = split_248[2] + getitem_2356 = split_248[3] + getitem_2357 = split_248[4] + getitem_2358 = split_248[5] + getitem_2359 = split_248[6] + getitem_2360 = split_248[7]; split_248 = None + cat_240 = torch.ops.aten.cat.default([getitem_2353, getitem_2354, getitem_2355, getitem_2356, getitem_2357, getitem_2358, getitem_2359, getitem_2360]); getitem_2353 = getitem_2354 = getitem_2355 = getitem_2356 = getitem_2357 = getitem_2358 = getitem_2359 = getitem_2360 = None + reduce_scatter_tensor_368 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_240, 'sum', 8, '1'); cat_240 = None + wait_tensor_834 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_368); reduce_scatter_tensor_368 = None + convert_element_type_2550 = torch.ops.prims.convert_element_type.default(wait_tensor_834, torch.float32); wait_tensor_834 = None + convert_element_type_2552 = torch.ops.prims.convert_element_type.default(wait_tensor_61, torch.float32); wait_tensor_61 = None + mul_810 = torch.ops.aten.mul.Tensor(convert_element_type_2550, convert_element_type_2552); convert_element_type_2552 = None + mul_812 = torch.ops.aten.mul.Tensor(mul_36, mul_810) + sum_165 = torch.ops.aten.sum.dim_IntList(mul_812, [2], True); mul_812 = None + div_55 = torch.ops.aten.div.Tensor(mul_36, 4096) + mul_813 = torch.ops.aten.mul.Tensor(div_55, sum_165); div_55 = sum_165 = None + sub_84 = torch.ops.aten.sub.Tensor(mul_810, mul_813); mul_810 = mul_813 = None + mul_814 = torch.ops.aten.mul.Tensor(sub_84, rsqrt_9); sub_84 = rsqrt_9 = None + mul_815 = torch.ops.aten.mul.Tensor(convert_element_type_2550, mul_36); convert_element_type_2550 = mul_36 = None + sum_166 = torch.ops.aten.sum.dim_IntList(mul_815, [0, 1]); mul_815 = None + convert_element_type_2553 = torch.ops.prims.convert_element_type.default(mul_814, torch.bfloat16); mul_814 = None + convert_element_type_2554 = torch.ops.prims.convert_element_type.default(sum_166, torch.bfloat16); sum_166 = None + all_reduce_55 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2554, 'sum', '1'); convert_element_type_2554 = None + wait_tensor_835 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_55); all_reduce_55 = None + convert_element_type_2555 = torch.ops.prims.convert_element_type.default(wait_tensor_835, torch.float32); wait_tensor_835 = None + reduce_scatter_tensor_369 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2555, 'avg', 32, '0'); convert_element_type_2555 = None + wait_tensor_836 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_369); reduce_scatter_tensor_369 = None + add_321 = torch.ops.aten.add.Tensor(add_317, convert_element_type_2553); add_317 = convert_element_type_2553 = None + all_gather_into_tensor_411 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_321, 8, '1') + wait_tensor_837 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_411); all_gather_into_tensor_411 = None + split_249 = torch.ops.aten.split.Tensor(wait_tensor_837, 2); wait_tensor_837 = None + getitem_2361 = split_249[0] + getitem_2362 = split_249[1] + getitem_2363 = split_249[2] + getitem_2364 = split_249[3] + getitem_2365 = split_249[4] + getitem_2366 = split_249[5] + getitem_2367 = split_249[6] + getitem_2368 = split_249[7]; split_249 = None + cat_241 = torch.ops.aten.cat.default([getitem_2361, getitem_2362, getitem_2363, getitem_2364, getitem_2365, getitem_2366, getitem_2367, getitem_2368], 1); getitem_2361 = getitem_2362 = getitem_2363 = getitem_2364 = getitem_2365 = getitem_2366 = getitem_2367 = getitem_2368 = None + view_2977 = torch.ops.aten.view.default(cat_241, [16384, 4096]); cat_241 = None + permute_1233 = torch.ops.aten.permute.default(view_2977, [1, 0]) + permute_50 = torch.ops.aten.permute.default(getitem_244, [0, 2, 1, 3]) + view_330 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + view_336 = torch.ops.aten.view.default(view_330, [16384, 512]); view_330 = None + mm_611 = torch.ops.aten.mm.default(permute_1233, view_336); permute_1233 = view_336 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 32, '0'); convert_element_type_149 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + permute_1235 = torch.ops.aten.permute.default(permute_51, [1, 0]); permute_51 = None + mm_612 = torch.ops.aten.mm.default(view_2977, permute_1235); view_2977 = permute_1235 = None + view_2978 = torch.ops.aten.view.default(mm_612, [2, 8192, 512]); mm_612 = None + convert_element_type_2560 = torch.ops.prims.convert_element_type.default(mm_611, torch.float32); mm_611 = None + reduce_scatter_tensor_370 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2560, 'avg', 32, '0'); convert_element_type_2560 = None + wait_tensor_838 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_370); reduce_scatter_tensor_370 = None + view_2979 = torch.ops.aten.view.default(view_2978, [2, 8192, 4, 128]); view_2978 = None + permute_1237 = torch.ops.aten.permute.default(view_2979, [0, 2, 1, 3]); view_2979 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 32, '0'); convert_element_type_133 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32); add_15 = None + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_54) + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_135, 8, '1'); convert_element_type_135 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + split_25 = torch.ops.aten.split.Tensor(wait_tensor_55, 2); wait_tensor_55 = None + getitem_236 = split_25[0] + getitem_237 = split_25[1] + getitem_238 = split_25[2] + getitem_239 = split_25[3] + getitem_240 = split_25[4] + getitem_241 = split_25[5] + getitem_242 = split_25[6] + getitem_243 = split_25[7]; split_25 = None + cat_17 = torch.ops.aten.cat.default([getitem_236, getitem_237, getitem_238, getitem_239, getitem_240, getitem_241, getitem_242, getitem_243], 1); getitem_236 = getitem_237 = getitem_238 = getitem_239 = getitem_240 = getitem_241 = getitem_242 = getitem_243 = None + view_303 = torch.ops.aten.view.default(cat_17, [16384, 4096]); cat_17 = None + view_304 = torch.ops.aten.view.default(mm_28, [2, 8192, 512]); mm_28 = None + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 32, '0'); convert_element_type_139 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_29 = torch.ops.aten.mm.default(view_303, permute_45) + view_311 = torch.ops.aten.view.default(mm_29, [2, 8192, 128]); mm_29 = None + view_318 = torch.ops.aten.view.default(mm_30, [2, 8192, 128]); mm_30 = None + view_320 = torch.ops.aten.view.default(view_304, [2, 8192, -1, 128]); view_304 = None + view_321 = torch.ops.aten.view.default(view_311, [2, 8192, -1, 128]); view_311 = None + view_322 = torch.ops.aten.view.default(view_318, [2, 8192, -1, 128]); view_318 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None + view_323 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 4, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_323); view_323 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_321, torch.float32); view_321 = None + view_324 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 1, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_324); view_324 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_37); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_326 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 4, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_37); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_327 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 1, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_327, torch.bfloat16); view_327 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 1, 4, 128]); unsqueeze_8 = None + view_328 = torch.ops.aten.view.default(expand_8, [2, 8192, 4, 128]); expand_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_322, 3); view_322 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 1, 4, 128]); unsqueeze_9 = None + view_329 = torch.ops.aten.view.default(expand_9, [2, 8192, 4, 128]); expand_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_328, [0, 2, 1, 3]); view_328 = None + permute_49 = torch.ops.aten.permute.default(view_329, [0, 2, 1, 3]); view_329 = None + _scaled_dot_product_cudnn_attention_backward_27 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1237, permute_47, permute_48, permute_49, getitem_244, getitem_245, getitem_250, getitem_251, None, None, None, 8192, 8192, 0.0, True); permute_1237 = permute_47 = permute_48 = permute_49 = getitem_244 = getitem_245 = getitem_250 = getitem_251 = None + getitem_2369 = _scaled_dot_product_cudnn_attention_backward_27[0] + getitem_2370 = _scaled_dot_product_cudnn_attention_backward_27[1] + getitem_2371 = _scaled_dot_product_cudnn_attention_backward_27[2]; _scaled_dot_product_cudnn_attention_backward_27 = None + permute_1238 = torch.ops.aten.permute.default(getitem_2371, [0, 2, 1, 3]); getitem_2371 = None + permute_1239 = torch.ops.aten.permute.default(getitem_2370, [0, 2, 1, 3]); getitem_2370 = None + permute_1240 = torch.ops.aten.permute.default(getitem_2369, [0, 2, 1, 3]); getitem_2369 = None + view_2980 = torch.ops.aten.view.default(permute_1238, [2, 8192, 1, 4, 128]); permute_1238 = None + sum_167 = torch.ops.aten.sum.dim_IntList(view_2980, [3], True); view_2980 = None + squeeze_54 = torch.ops.aten.squeeze.dim(sum_167, 3); sum_167 = None + view_2981 = torch.ops.aten.view.default(permute_1239, [2, 8192, 1, 4, 128]); permute_1239 = None + sum_168 = torch.ops.aten.sum.dim_IntList(view_2981, [3], True); view_2981 = None + squeeze_55 = torch.ops.aten.squeeze.dim(sum_168, 3); sum_168 = None + convert_element_type_2561 = torch.ops.prims.convert_element_type.default(squeeze_55, torch.float32); squeeze_55 = None + convert_element_type_2562 = torch.ops.prims.convert_element_type.default(permute_1240, torch.float32); permute_1240 = None + view_2982 = torch.ops.aten.view.default(convert_element_type_2561, [2, 8192, 1, 64, 2]); convert_element_type_2561 = None + view_as_complex_118 = torch.ops.aten.view_as_complex.default(view_2982); view_2982 = None + mul_816 = torch.ops.aten.mul.Tensor(view_as_complex_118, _conj); view_as_complex_118 = None + view_2983 = torch.ops.aten.view.default(convert_element_type_2562, [2, 8192, 4, 64, 2]); convert_element_type_2562 = None + view_as_complex_119 = torch.ops.aten.view_as_complex.default(view_2983); view_2983 = None + mul_817 = torch.ops.aten.mul.Tensor(view_as_complex_119, _conj); view_as_complex_119 = None + view_as_real_118 = torch.ops.aten.view_as_real.default(mul_816); mul_816 = None + view_2984 = torch.ops.aten.view.default(view_as_real_118, [2, 8192, 1, 128]); view_as_real_118 = None + convert_element_type_2563 = torch.ops.prims.convert_element_type.default(view_2984, torch.bfloat16); view_2984 = None + view_as_real_119 = torch.ops.aten.view_as_real.default(mul_817); mul_817 = None + view_2985 = torch.ops.aten.view.default(view_as_real_119, [2, 8192, 4, 128]); view_as_real_119 = None + convert_element_type_2564 = torch.ops.prims.convert_element_type.default(view_2985, torch.bfloat16); view_2985 = None + view_2986 = torch.ops.aten.view.default(squeeze_54, [2, 8192, 128]); squeeze_54 = None + view_2987 = torch.ops.aten.view.default(convert_element_type_2563, [2, 8192, 128]); convert_element_type_2563 = None + view_2988 = torch.ops.aten.view.default(convert_element_type_2564, [2, 8192, 512]); convert_element_type_2564 = None + view_2989 = torch.ops.aten.view.default(view_2986, [16384, 128]); view_2986 = None + permute_1241 = torch.ops.aten.permute.default(view_2989, [1, 0]) + mm_613 = torch.ops.aten.mm.default(permute_1241, view_303); permute_1241 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 32, '0'); convert_element_type_142 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_1243 = torch.ops.aten.permute.default(permute_46, [1, 0]); permute_46 = None + mm_614 = torch.ops.aten.mm.default(view_2989, permute_1243); view_2989 = permute_1243 = None + view_2990 = torch.ops.aten.view.default(mm_614, [2, 8192, 4096]); mm_614 = None + convert_element_type_2569 = torch.ops.prims.convert_element_type.default(mm_613, torch.float32); mm_613 = None + reduce_scatter_tensor_371 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2569, 'avg', 32, '0'); convert_element_type_2569 = None + wait_tensor_839 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_371); reduce_scatter_tensor_371 = None + view_2991 = torch.ops.aten.view.default(view_2987, [16384, 128]); view_2987 = None + permute_1245 = torch.ops.aten.permute.default(view_2991, [1, 0]) + mm_615 = torch.ops.aten.mm.default(permute_1245, view_303); permute_1245 = None + permute_1247 = torch.ops.aten.permute.default(permute_45, [1, 0]); permute_45 = None + mm_616 = torch.ops.aten.mm.default(view_2991, permute_1247); view_2991 = permute_1247 = None + view_2992 = torch.ops.aten.view.default(mm_616, [2, 8192, 4096]); mm_616 = None + add_322 = torch.ops.aten.add.Tensor(view_2990, view_2992); view_2990 = view_2992 = None + convert_element_type_2574 = torch.ops.prims.convert_element_type.default(mm_615, torch.float32); mm_615 = None + reduce_scatter_tensor_372 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2574, 'avg', 32, '0'); convert_element_type_2574 = None + wait_tensor_840 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_372); reduce_scatter_tensor_372 = None + view_2993 = torch.ops.aten.view.default(view_2988, [16384, 512]); view_2988 = None + permute_1249 = torch.ops.aten.permute.default(view_2993, [1, 0]) + mm_617 = torch.ops.aten.mm.default(permute_1249, view_303); permute_1249 = view_303 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 32, '0'); convert_element_type_136 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + permute_1251 = torch.ops.aten.permute.default(permute_44, [1, 0]); permute_44 = None + mm_618 = torch.ops.aten.mm.default(view_2993, permute_1251); view_2993 = permute_1251 = None + view_2994 = torch.ops.aten.view.default(mm_618, [2, 8192, 4096]); mm_618 = None + add_323 = torch.ops.aten.add.Tensor(add_322, view_2994); add_322 = view_2994 = None + convert_element_type_2579 = torch.ops.prims.convert_element_type.default(mm_617, torch.float32); mm_617 = None + reduce_scatter_tensor_373 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2579, 'avg', 32, '0'); convert_element_type_2579 = None + wait_tensor_841 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_373); reduce_scatter_tensor_373 = None + split_250 = torch.ops.aten.split.Tensor(add_323, 1024, 1); add_323 = None + getitem_2372 = split_250[0] + getitem_2373 = split_250[1] + getitem_2374 = split_250[2] + getitem_2375 = split_250[3] + getitem_2376 = split_250[4] + getitem_2377 = split_250[5] + getitem_2378 = split_250[6] + getitem_2379 = split_250[7]; split_250 = None + cat_242 = torch.ops.aten.cat.default([getitem_2372, getitem_2373, getitem_2374, getitem_2375, getitem_2376, getitem_2377, getitem_2378, getitem_2379]); getitem_2372 = getitem_2373 = getitem_2374 = getitem_2375 = getitem_2376 = getitem_2377 = getitem_2378 = getitem_2379 = None + reduce_scatter_tensor_374 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_242, 'sum', 8, '1'); cat_242 = None + wait_tensor_842 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_374); reduce_scatter_tensor_374 = None + convert_element_type_2580 = torch.ops.prims.convert_element_type.default(wait_tensor_842, torch.float32); wait_tensor_842 = None + convert_element_type_2582 = torch.ops.prims.convert_element_type.default(wait_tensor_54, torch.float32); wait_tensor_54 = None + mul_818 = torch.ops.aten.mul.Tensor(convert_element_type_2580, convert_element_type_2582); convert_element_type_2582 = None + mul_820 = torch.ops.aten.mul.Tensor(mul_32, mul_818) + sum_169 = torch.ops.aten.sum.dim_IntList(mul_820, [2], True); mul_820 = None + div_56 = torch.ops.aten.div.Tensor(mul_32, 4096) + mul_821 = torch.ops.aten.mul.Tensor(div_56, sum_169); div_56 = sum_169 = None + sub_85 = torch.ops.aten.sub.Tensor(mul_818, mul_821); mul_818 = mul_821 = None + mul_822 = torch.ops.aten.mul.Tensor(sub_85, rsqrt_8); sub_85 = rsqrt_8 = None + mul_823 = torch.ops.aten.mul.Tensor(convert_element_type_2580, mul_32); convert_element_type_2580 = mul_32 = None + sum_170 = torch.ops.aten.sum.dim_IntList(mul_823, [0, 1]); mul_823 = None + convert_element_type_2583 = torch.ops.prims.convert_element_type.default(mul_822, torch.bfloat16); mul_822 = None + convert_element_type_2584 = torch.ops.prims.convert_element_type.default(sum_170, torch.bfloat16); sum_170 = None + all_reduce_56 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2584, 'sum', '1'); convert_element_type_2584 = None + wait_tensor_843 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_56); all_reduce_56 = None + convert_element_type_2585 = torch.ops.prims.convert_element_type.default(wait_tensor_843, torch.float32); wait_tensor_843 = None + reduce_scatter_tensor_375 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2585, 'avg', 32, '0'); convert_element_type_2585 = None + wait_tensor_844 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_375); reduce_scatter_tensor_375 = None + add_324 = torch.ops.aten.add.Tensor(add_321, convert_element_type_2583); add_321 = convert_element_type_2583 = None + all_gather_into_tensor_412 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_324, 8, '1') + wait_tensor_845 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_412); all_gather_into_tensor_412 = None + split_251 = torch.ops.aten.split.Tensor(wait_tensor_845, 2); wait_tensor_845 = None + getitem_2380 = split_251[0] + getitem_2381 = split_251[1] + getitem_2382 = split_251[2] + getitem_2383 = split_251[3] + getitem_2384 = split_251[4] + getitem_2385 = split_251[5] + getitem_2386 = split_251[6] + getitem_2387 = split_251[7]; split_251 = None + cat_243 = torch.ops.aten.cat.default([getitem_2380, getitem_2381, getitem_2382, getitem_2383, getitem_2384, getitem_2385, getitem_2386, getitem_2387], 1); getitem_2380 = getitem_2381 = getitem_2382 = getitem_2383 = getitem_2384 = getitem_2385 = getitem_2386 = getitem_2387 = None + view_2995 = torch.ops.aten.view.default(cat_243, [16384, 4096]); cat_243 = None + permute_1253 = torch.ops.aten.permute.default(view_2995, [1, 0]) + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + add_13 = torch.ops.aten.add.Tensor(add_11, wait_tensor_47); wait_tensor_47 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 32, '0'); convert_element_type_119 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32); add_13 = None + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_48) + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_121, 8, '1'); convert_element_type_121 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + split_23 = torch.ops.aten.split.Tensor(wait_tensor_49, 2); wait_tensor_49 = None + getitem_220 = split_23[0] + getitem_221 = split_23[1] + getitem_222 = split_23[2] + getitem_223 = split_23[3] + getitem_224 = split_23[4] + getitem_225 = split_23[5] + getitem_226 = split_23[6] + getitem_227 = split_23[7]; split_23 = None + cat_15 = torch.ops.aten.cat.default([getitem_220, getitem_221, getitem_222, getitem_223, getitem_224, getitem_225, getitem_226, getitem_227], 1); getitem_220 = getitem_221 = getitem_222 = getitem_223 = getitem_224 = getitem_225 = getitem_226 = getitem_227 = None + view_276 = torch.ops.aten.view.default(cat_15, [16384, 4096]); cat_15 = None + view_277 = torch.ops.aten.view.default(mm_25, [2, 8192, 1792]); mm_25 = None + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_277, torch.float32); view_277 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 32, '0'); convert_element_type_127 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + mm_26 = torch.ops.aten.mm.default(view_276, permute_42) + view_284 = torch.ops.aten.view.default(mm_26, [2, 8192, 1792]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_284) + view_291 = torch.ops.aten.view.default(mul_31, [16384, 1792]); mul_31 = None + mm_619 = torch.ops.aten.mm.default(permute_1253, view_291); permute_1253 = view_291 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 32, '0'); convert_element_type_130 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_1255 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_620 = torch.ops.aten.mm.default(view_2995, permute_1255); view_2995 = permute_1255 = None + view_2996 = torch.ops.aten.view.default(mm_620, [2, 8192, 1792]); mm_620 = None + convert_element_type_2590 = torch.ops.prims.convert_element_type.default(mm_619, torch.float32); mm_619 = None + reduce_scatter_tensor_376 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2590, 'avg', 32, '0'); convert_element_type_2590 = None + wait_tensor_846 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_376); reduce_scatter_tensor_376 = None + mul_824 = torch.ops.aten.mul.Tensor(view_2996, convert_element_type_126); convert_element_type_126 = None + mul_825 = torch.ops.aten.mul.Tensor(view_2996, view_284); view_2996 = view_284 = None + view_2997 = torch.ops.aten.view.default(mul_824, [16384, 1792]); mul_824 = None + permute_1257 = torch.ops.aten.permute.default(view_2997, [1, 0]) + mm_621 = torch.ops.aten.mm.default(permute_1257, view_276); permute_1257 = None + permute_1259 = torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_622 = torch.ops.aten.mm.default(view_2997, permute_1259); view_2997 = permute_1259 = None + view_2998 = torch.ops.aten.view.default(mm_622, [2, 8192, 4096]); mm_622 = None + convert_element_type_2595 = torch.ops.prims.convert_element_type.default(mm_621, torch.float32); mm_621 = None + reduce_scatter_tensor_377 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2595, 'avg', 32, '0'); convert_element_type_2595 = None + wait_tensor_847 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_377); reduce_scatter_tensor_377 = None + convert_element_type_2596 = torch.ops.prims.convert_element_type.default(mul_825, torch.float32); mul_825 = None + neg_28 = torch.ops.aten.neg.default(convert_element_type_125) + exp_28 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_325 = torch.ops.aten.add.Tensor(exp_28, 1); exp_28 = None + reciprocal_28 = torch.ops.aten.reciprocal.default(add_325); add_325 = None + mul_826 = torch.ops.aten.mul.Tensor(reciprocal_28, 1); reciprocal_28 = None + mul_827 = torch.ops.aten.mul.Tensor(convert_element_type_2596, mul_826); convert_element_type_2596 = None + sub_86 = torch.ops.aten.sub.Tensor(1, mul_826); mul_826 = None + mul_828 = torch.ops.aten.mul.Tensor(convert_element_type_125, sub_86); convert_element_type_125 = sub_86 = None + add_326 = torch.ops.aten.add.Tensor(mul_828, 1); mul_828 = None + mul_829 = torch.ops.aten.mul.Tensor(mul_827, add_326); mul_827 = add_326 = None + convert_element_type_2598 = torch.ops.prims.convert_element_type.default(mul_829, torch.bfloat16); mul_829 = None + view_2999 = torch.ops.aten.view.default(convert_element_type_2598, [16384, 1792]); convert_element_type_2598 = None + permute_1261 = torch.ops.aten.permute.default(view_2999, [1, 0]) + mm_623 = torch.ops.aten.mm.default(permute_1261, view_276); permute_1261 = view_276 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 32, '0'); convert_element_type_122 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + permute_1263 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_624 = torch.ops.aten.mm.default(view_2999, permute_1263); view_2999 = permute_1263 = None + view_3000 = torch.ops.aten.view.default(mm_624, [2, 8192, 4096]); mm_624 = None + add_327 = torch.ops.aten.add.Tensor(view_2998, view_3000); view_2998 = view_3000 = None + convert_element_type_2603 = torch.ops.prims.convert_element_type.default(mm_623, torch.float32); mm_623 = None + reduce_scatter_tensor_378 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2603, 'avg', 32, '0'); convert_element_type_2603 = None + wait_tensor_848 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_378); reduce_scatter_tensor_378 = None + split_252 = torch.ops.aten.split.Tensor(add_327, 1024, 1); add_327 = None + getitem_2388 = split_252[0] + getitem_2389 = split_252[1] + getitem_2390 = split_252[2] + getitem_2391 = split_252[3] + getitem_2392 = split_252[4] + getitem_2393 = split_252[5] + getitem_2394 = split_252[6] + getitem_2395 = split_252[7]; split_252 = None + cat_244 = torch.ops.aten.cat.default([getitem_2388, getitem_2389, getitem_2390, getitem_2391, getitem_2392, getitem_2393, getitem_2394, getitem_2395]); getitem_2388 = getitem_2389 = getitem_2390 = getitem_2391 = getitem_2392 = getitem_2393 = getitem_2394 = getitem_2395 = None + reduce_scatter_tensor_379 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_244, 'sum', 8, '1'); cat_244 = None + wait_tensor_849 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_379); reduce_scatter_tensor_379 = None + convert_element_type_2604 = torch.ops.prims.convert_element_type.default(wait_tensor_849, torch.float32); wait_tensor_849 = None + convert_element_type_2606 = torch.ops.prims.convert_element_type.default(wait_tensor_48, torch.float32); wait_tensor_48 = None + mul_830 = torch.ops.aten.mul.Tensor(convert_element_type_2604, convert_element_type_2606); convert_element_type_2606 = None + mul_832 = torch.ops.aten.mul.Tensor(mul_28, mul_830) + sum_171 = torch.ops.aten.sum.dim_IntList(mul_832, [2], True); mul_832 = None + div_57 = torch.ops.aten.div.Tensor(mul_28, 4096) + mul_833 = torch.ops.aten.mul.Tensor(div_57, sum_171); div_57 = sum_171 = None + sub_87 = torch.ops.aten.sub.Tensor(mul_830, mul_833); mul_830 = mul_833 = None + mul_834 = torch.ops.aten.mul.Tensor(sub_87, rsqrt_7); sub_87 = rsqrt_7 = None + mul_835 = torch.ops.aten.mul.Tensor(convert_element_type_2604, mul_28); convert_element_type_2604 = mul_28 = None + sum_172 = torch.ops.aten.sum.dim_IntList(mul_835, [0, 1]); mul_835 = None + convert_element_type_2607 = torch.ops.prims.convert_element_type.default(mul_834, torch.bfloat16); mul_834 = None + convert_element_type_2608 = torch.ops.prims.convert_element_type.default(sum_172, torch.bfloat16); sum_172 = None + all_reduce_57 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2608, 'sum', '1'); convert_element_type_2608 = None + wait_tensor_850 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_57); all_reduce_57 = None + convert_element_type_2609 = torch.ops.prims.convert_element_type.default(wait_tensor_850, torch.float32); wait_tensor_850 = None + reduce_scatter_tensor_380 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2609, 'avg', 32, '0'); convert_element_type_2609 = None + wait_tensor_851 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_380); reduce_scatter_tensor_380 = None + add_328 = torch.ops.aten.add.Tensor(add_324, convert_element_type_2607); add_324 = convert_element_type_2607 = None + all_gather_into_tensor_413 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_328, 8, '1') + wait_tensor_852 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_413); all_gather_into_tensor_413 = None + split_253 = torch.ops.aten.split.Tensor(wait_tensor_852, 2); wait_tensor_852 = None + getitem_2396 = split_253[0] + getitem_2397 = split_253[1] + getitem_2398 = split_253[2] + getitem_2399 = split_253[3] + getitem_2400 = split_253[4] + getitem_2401 = split_253[5] + getitem_2402 = split_253[6] + getitem_2403 = split_253[7]; split_253 = None + cat_245 = torch.ops.aten.cat.default([getitem_2396, getitem_2397, getitem_2398, getitem_2399, getitem_2400, getitem_2401, getitem_2402, getitem_2403], 1); getitem_2396 = getitem_2397 = getitem_2398 = getitem_2399 = getitem_2400 = getitem_2401 = getitem_2402 = getitem_2403 = None + view_3001 = torch.ops.aten.view.default(cat_245, [16384, 4096]); cat_245 = None + permute_1265 = torch.ops.aten.permute.default(view_3001, [1, 0]) + permute_39 = torch.ops.aten.permute.default(getitem_203, [0, 2, 1, 3]) + view_258 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + view_264 = torch.ops.aten.view.default(view_258, [16384, 512]); view_258 = None + mm_625 = torch.ops.aten.mm.default(permute_1265, view_264); permute_1265 = view_264 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 32, '0'); convert_element_type_116 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_46, [1, 0]); wait_tensor_46 = None + permute_1267 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_626 = torch.ops.aten.mm.default(view_3001, permute_1267); view_3001 = permute_1267 = None + view_3002 = torch.ops.aten.view.default(mm_626, [2, 8192, 512]); mm_626 = None + convert_element_type_2614 = torch.ops.prims.convert_element_type.default(mm_625, torch.float32); mm_625 = None + reduce_scatter_tensor_381 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2614, 'avg', 32, '0'); convert_element_type_2614 = None + wait_tensor_853 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_381); reduce_scatter_tensor_381 = None + view_3003 = torch.ops.aten.view.default(view_3002, [2, 8192, 4, 128]); view_3002 = None + permute_1269 = torch.ops.aten.permute.default(view_3003, [0, 2, 1, 3]); view_3003 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 32, '0'); convert_element_type_100 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32); add_11 = None + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_41) + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_102, 8, '1'); convert_element_type_102 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + split_21 = torch.ops.aten.split.Tensor(wait_tensor_42, 2); wait_tensor_42 = None + getitem_195 = split_21[0] + getitem_196 = split_21[1] + getitem_197 = split_21[2] + getitem_198 = split_21[3] + getitem_199 = split_21[4] + getitem_200 = split_21[5] + getitem_201 = split_21[6] + getitem_202 = split_21[7]; split_21 = None + cat_13 = torch.ops.aten.cat.default([getitem_195, getitem_196, getitem_197, getitem_198, getitem_199, getitem_200, getitem_201, getitem_202], 1); getitem_195 = getitem_196 = getitem_197 = getitem_198 = getitem_199 = getitem_200 = getitem_201 = getitem_202 = None + view_231 = torch.ops.aten.view.default(cat_13, [16384, 4096]); cat_13 = None + view_232 = torch.ops.aten.view.default(mm_21, [2, 8192, 512]); mm_21 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 32, '0'); convert_element_type_106 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_22 = torch.ops.aten.mm.default(view_231, permute_34) + view_239 = torch.ops.aten.view.default(mm_22, [2, 8192, 128]); mm_22 = None + view_246 = torch.ops.aten.view.default(mm_23, [2, 8192, 128]); mm_23 = None + view_248 = torch.ops.aten.view.default(view_232, [2, 8192, -1, 128]); view_232 = None + view_249 = torch.ops.aten.view.default(view_239, [2, 8192, -1, 128]); view_239 = None + view_250 = torch.ops.aten.view.default(view_246, [2, 8192, -1, 128]); view_246 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_248, torch.float32); view_248 = None + view_251 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 4, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_251); view_251 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 1, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_37); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_254 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 4, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_37); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_255 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 1, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_254, torch.bfloat16); view_254 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 1, 4, 128]); unsqueeze_6 = None + view_256 = torch.ops.aten.view.default(expand_6, [2, 8192, 4, 128]); expand_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_250, 3); view_250 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 1, 4, 128]); unsqueeze_7 = None + view_257 = torch.ops.aten.view.default(expand_7, [2, 8192, 4, 128]); expand_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_256, [0, 2, 1, 3]); view_256 = None + permute_38 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + _scaled_dot_product_cudnn_attention_backward_28 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1269, permute_36, permute_37, permute_38, getitem_203, getitem_204, getitem_209, getitem_210, None, None, None, 8192, 8192, 0.0, True); permute_1269 = permute_36 = permute_37 = permute_38 = getitem_203 = getitem_204 = getitem_209 = getitem_210 = None + getitem_2404 = _scaled_dot_product_cudnn_attention_backward_28[0] + getitem_2405 = _scaled_dot_product_cudnn_attention_backward_28[1] + getitem_2406 = _scaled_dot_product_cudnn_attention_backward_28[2]; _scaled_dot_product_cudnn_attention_backward_28 = None + permute_1270 = torch.ops.aten.permute.default(getitem_2406, [0, 2, 1, 3]); getitem_2406 = None + permute_1271 = torch.ops.aten.permute.default(getitem_2405, [0, 2, 1, 3]); getitem_2405 = None + permute_1272 = torch.ops.aten.permute.default(getitem_2404, [0, 2, 1, 3]); getitem_2404 = None + view_3004 = torch.ops.aten.view.default(permute_1270, [2, 8192, 1, 4, 128]); permute_1270 = None + sum_173 = torch.ops.aten.sum.dim_IntList(view_3004, [3], True); view_3004 = None + squeeze_56 = torch.ops.aten.squeeze.dim(sum_173, 3); sum_173 = None + view_3005 = torch.ops.aten.view.default(permute_1271, [2, 8192, 1, 4, 128]); permute_1271 = None + sum_174 = torch.ops.aten.sum.dim_IntList(view_3005, [3], True); view_3005 = None + squeeze_57 = torch.ops.aten.squeeze.dim(sum_174, 3); sum_174 = None + convert_element_type_2615 = torch.ops.prims.convert_element_type.default(squeeze_57, torch.float32); squeeze_57 = None + convert_element_type_2616 = torch.ops.prims.convert_element_type.default(permute_1272, torch.float32); permute_1272 = None + view_3006 = torch.ops.aten.view.default(convert_element_type_2615, [2, 8192, 1, 64, 2]); convert_element_type_2615 = None + view_as_complex_120 = torch.ops.aten.view_as_complex.default(view_3006); view_3006 = None + mul_836 = torch.ops.aten.mul.Tensor(view_as_complex_120, _conj); view_as_complex_120 = None + view_3007 = torch.ops.aten.view.default(convert_element_type_2616, [2, 8192, 4, 64, 2]); convert_element_type_2616 = None + view_as_complex_121 = torch.ops.aten.view_as_complex.default(view_3007); view_3007 = None + mul_837 = torch.ops.aten.mul.Tensor(view_as_complex_121, _conj); view_as_complex_121 = None + view_as_real_120 = torch.ops.aten.view_as_real.default(mul_836); mul_836 = None + view_3008 = torch.ops.aten.view.default(view_as_real_120, [2, 8192, 1, 128]); view_as_real_120 = None + convert_element_type_2617 = torch.ops.prims.convert_element_type.default(view_3008, torch.bfloat16); view_3008 = None + view_as_real_121 = torch.ops.aten.view_as_real.default(mul_837); mul_837 = None + view_3009 = torch.ops.aten.view.default(view_as_real_121, [2, 8192, 4, 128]); view_as_real_121 = None + convert_element_type_2618 = torch.ops.prims.convert_element_type.default(view_3009, torch.bfloat16); view_3009 = None + view_3010 = torch.ops.aten.view.default(squeeze_56, [2, 8192, 128]); squeeze_56 = None + view_3011 = torch.ops.aten.view.default(convert_element_type_2617, [2, 8192, 128]); convert_element_type_2617 = None + view_3012 = torch.ops.aten.view.default(convert_element_type_2618, [2, 8192, 512]); convert_element_type_2618 = None + view_3013 = torch.ops.aten.view.default(view_3010, [16384, 128]); view_3010 = None + permute_1273 = torch.ops.aten.permute.default(view_3013, [1, 0]) + mm_627 = torch.ops.aten.mm.default(permute_1273, view_231); permute_1273 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 32, '0'); convert_element_type_109 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + permute_1275 = torch.ops.aten.permute.default(permute_35, [1, 0]); permute_35 = None + mm_628 = torch.ops.aten.mm.default(view_3013, permute_1275); view_3013 = permute_1275 = None + view_3014 = torch.ops.aten.view.default(mm_628, [2, 8192, 4096]); mm_628 = None + convert_element_type_2623 = torch.ops.prims.convert_element_type.default(mm_627, torch.float32); mm_627 = None + reduce_scatter_tensor_382 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2623, 'avg', 32, '0'); convert_element_type_2623 = None + wait_tensor_854 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_382); reduce_scatter_tensor_382 = None + view_3015 = torch.ops.aten.view.default(view_3011, [16384, 128]); view_3011 = None + permute_1277 = torch.ops.aten.permute.default(view_3015, [1, 0]) + mm_629 = torch.ops.aten.mm.default(permute_1277, view_231); permute_1277 = None + permute_1279 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_630 = torch.ops.aten.mm.default(view_3015, permute_1279); view_3015 = permute_1279 = None + view_3016 = torch.ops.aten.view.default(mm_630, [2, 8192, 4096]); mm_630 = None + add_329 = torch.ops.aten.add.Tensor(view_3014, view_3016); view_3014 = view_3016 = None + convert_element_type_2628 = torch.ops.prims.convert_element_type.default(mm_629, torch.float32); mm_629 = None + reduce_scatter_tensor_383 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2628, 'avg', 32, '0'); convert_element_type_2628 = None + wait_tensor_855 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_383); reduce_scatter_tensor_383 = None + view_3017 = torch.ops.aten.view.default(view_3012, [16384, 512]); view_3012 = None + permute_1281 = torch.ops.aten.permute.default(view_3017, [1, 0]) + mm_631 = torch.ops.aten.mm.default(permute_1281, view_231); permute_1281 = view_231 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 32, '0'); convert_element_type_103 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + permute_1283 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_632 = torch.ops.aten.mm.default(view_3017, permute_1283); view_3017 = permute_1283 = None + view_3018 = torch.ops.aten.view.default(mm_632, [2, 8192, 4096]); mm_632 = None + add_330 = torch.ops.aten.add.Tensor(add_329, view_3018); add_329 = view_3018 = None + convert_element_type_2633 = torch.ops.prims.convert_element_type.default(mm_631, torch.float32); mm_631 = None + reduce_scatter_tensor_384 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2633, 'avg', 32, '0'); convert_element_type_2633 = None + wait_tensor_856 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_384); reduce_scatter_tensor_384 = None + split_254 = torch.ops.aten.split.Tensor(add_330, 1024, 1); add_330 = None + getitem_2407 = split_254[0] + getitem_2408 = split_254[1] + getitem_2409 = split_254[2] + getitem_2410 = split_254[3] + getitem_2411 = split_254[4] + getitem_2412 = split_254[5] + getitem_2413 = split_254[6] + getitem_2414 = split_254[7]; split_254 = None + cat_246 = torch.ops.aten.cat.default([getitem_2407, getitem_2408, getitem_2409, getitem_2410, getitem_2411, getitem_2412, getitem_2413, getitem_2414]); getitem_2407 = getitem_2408 = getitem_2409 = getitem_2410 = getitem_2411 = getitem_2412 = getitem_2413 = getitem_2414 = None + reduce_scatter_tensor_385 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_246, 'sum', 8, '1'); cat_246 = None + wait_tensor_857 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_385); reduce_scatter_tensor_385 = None + convert_element_type_2634 = torch.ops.prims.convert_element_type.default(wait_tensor_857, torch.float32); wait_tensor_857 = None + convert_element_type_2636 = torch.ops.prims.convert_element_type.default(wait_tensor_41, torch.float32); wait_tensor_41 = None + mul_838 = torch.ops.aten.mul.Tensor(convert_element_type_2634, convert_element_type_2636); convert_element_type_2636 = None + mul_840 = torch.ops.aten.mul.Tensor(mul_24, mul_838) + sum_175 = torch.ops.aten.sum.dim_IntList(mul_840, [2], True); mul_840 = None + div_58 = torch.ops.aten.div.Tensor(mul_24, 4096) + mul_841 = torch.ops.aten.mul.Tensor(div_58, sum_175); div_58 = sum_175 = None + sub_88 = torch.ops.aten.sub.Tensor(mul_838, mul_841); mul_838 = mul_841 = None + mul_842 = torch.ops.aten.mul.Tensor(sub_88, rsqrt_6); sub_88 = rsqrt_6 = None + mul_843 = torch.ops.aten.mul.Tensor(convert_element_type_2634, mul_24); convert_element_type_2634 = mul_24 = None + sum_176 = torch.ops.aten.sum.dim_IntList(mul_843, [0, 1]); mul_843 = None + convert_element_type_2637 = torch.ops.prims.convert_element_type.default(mul_842, torch.bfloat16); mul_842 = None + convert_element_type_2638 = torch.ops.prims.convert_element_type.default(sum_176, torch.bfloat16); sum_176 = None + all_reduce_58 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2638, 'sum', '1'); convert_element_type_2638 = None + wait_tensor_858 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_58); all_reduce_58 = None + convert_element_type_2639 = torch.ops.prims.convert_element_type.default(wait_tensor_858, torch.float32); wait_tensor_858 = None + reduce_scatter_tensor_386 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2639, 'avg', 32, '0'); convert_element_type_2639 = None + wait_tensor_859 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_386); reduce_scatter_tensor_386 = None + add_331 = torch.ops.aten.add.Tensor(add_328, convert_element_type_2637); add_328 = convert_element_type_2637 = None + all_gather_into_tensor_414 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_331, 8, '1') + wait_tensor_860 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_414); all_gather_into_tensor_414 = None + split_255 = torch.ops.aten.split.Tensor(wait_tensor_860, 2); wait_tensor_860 = None + getitem_2415 = split_255[0] + getitem_2416 = split_255[1] + getitem_2417 = split_255[2] + getitem_2418 = split_255[3] + getitem_2419 = split_255[4] + getitem_2420 = split_255[5] + getitem_2421 = split_255[6] + getitem_2422 = split_255[7]; split_255 = None + cat_247 = torch.ops.aten.cat.default([getitem_2415, getitem_2416, getitem_2417, getitem_2418, getitem_2419, getitem_2420, getitem_2421, getitem_2422], 1); getitem_2415 = getitem_2416 = getitem_2417 = getitem_2418 = getitem_2419 = getitem_2420 = getitem_2421 = getitem_2422 = None + view_3019 = torch.ops.aten.view.default(cat_247, [16384, 4096]); cat_247 = None + permute_1285 = torch.ops.aten.permute.default(view_3019, [1, 0]) + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + add_9 = torch.ops.aten.add.Tensor(add_7, wait_tensor_34); wait_tensor_34 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 32, '0'); convert_element_type_86 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32); add_9 = None + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_35) + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_88, 8, '1'); convert_element_type_88 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + split_19 = torch.ops.aten.split.Tensor(wait_tensor_36, 2); wait_tensor_36 = None + getitem_179 = split_19[0] + getitem_180 = split_19[1] + getitem_181 = split_19[2] + getitem_182 = split_19[3] + getitem_183 = split_19[4] + getitem_184 = split_19[5] + getitem_185 = split_19[6] + getitem_186 = split_19[7]; split_19 = None + cat_11 = torch.ops.aten.cat.default([getitem_179, getitem_180, getitem_181, getitem_182, getitem_183, getitem_184, getitem_185, getitem_186], 1); getitem_179 = getitem_180 = getitem_181 = getitem_182 = getitem_183 = getitem_184 = getitem_185 = getitem_186 = None + view_204 = torch.ops.aten.view.default(cat_11, [16384, 4096]); cat_11 = None + view_205 = torch.ops.aten.view.default(mm_18, [2, 8192, 1792]); mm_18 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 32, '0'); convert_element_type_94 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + mm_19 = torch.ops.aten.mm.default(view_204, permute_31) + view_212 = torch.ops.aten.view.default(mm_19, [2, 8192, 1792]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_212) + view_219 = torch.ops.aten.view.default(mul_23, [16384, 1792]); mul_23 = None + mm_633 = torch.ops.aten.mm.default(permute_1285, view_219); permute_1285 = view_219 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 32, '0'); convert_element_type_97 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + permute_1287 = torch.ops.aten.permute.default(permute_32, [1, 0]); permute_32 = None + mm_634 = torch.ops.aten.mm.default(view_3019, permute_1287); view_3019 = permute_1287 = None + view_3020 = torch.ops.aten.view.default(mm_634, [2, 8192, 1792]); mm_634 = None + convert_element_type_2644 = torch.ops.prims.convert_element_type.default(mm_633, torch.float32); mm_633 = None + reduce_scatter_tensor_387 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2644, 'avg', 32, '0'); convert_element_type_2644 = None + wait_tensor_861 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_387); reduce_scatter_tensor_387 = None + mul_844 = torch.ops.aten.mul.Tensor(view_3020, convert_element_type_93); convert_element_type_93 = None + mul_845 = torch.ops.aten.mul.Tensor(view_3020, view_212); view_3020 = view_212 = None + view_3021 = torch.ops.aten.view.default(mul_844, [16384, 1792]); mul_844 = None + permute_1289 = torch.ops.aten.permute.default(view_3021, [1, 0]) + mm_635 = torch.ops.aten.mm.default(permute_1289, view_204); permute_1289 = None + permute_1291 = torch.ops.aten.permute.default(permute_31, [1, 0]); permute_31 = None + mm_636 = torch.ops.aten.mm.default(view_3021, permute_1291); view_3021 = permute_1291 = None + view_3022 = torch.ops.aten.view.default(mm_636, [2, 8192, 4096]); mm_636 = None + convert_element_type_2649 = torch.ops.prims.convert_element_type.default(mm_635, torch.float32); mm_635 = None + reduce_scatter_tensor_388 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2649, 'avg', 32, '0'); convert_element_type_2649 = None + wait_tensor_862 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_388); reduce_scatter_tensor_388 = None + convert_element_type_2650 = torch.ops.prims.convert_element_type.default(mul_845, torch.float32); mul_845 = None + neg_29 = torch.ops.aten.neg.default(convert_element_type_92) + exp_29 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_332 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + reciprocal_29 = torch.ops.aten.reciprocal.default(add_332); add_332 = None + mul_846 = torch.ops.aten.mul.Tensor(reciprocal_29, 1); reciprocal_29 = None + mul_847 = torch.ops.aten.mul.Tensor(convert_element_type_2650, mul_846); convert_element_type_2650 = None + sub_89 = torch.ops.aten.sub.Tensor(1, mul_846); mul_846 = None + mul_848 = torch.ops.aten.mul.Tensor(convert_element_type_92, sub_89); convert_element_type_92 = sub_89 = None + add_333 = torch.ops.aten.add.Tensor(mul_848, 1); mul_848 = None + mul_849 = torch.ops.aten.mul.Tensor(mul_847, add_333); mul_847 = add_333 = None + convert_element_type_2652 = torch.ops.prims.convert_element_type.default(mul_849, torch.bfloat16); mul_849 = None + view_3023 = torch.ops.aten.view.default(convert_element_type_2652, [16384, 1792]); convert_element_type_2652 = None + permute_1293 = torch.ops.aten.permute.default(view_3023, [1, 0]) + mm_637 = torch.ops.aten.mm.default(permute_1293, view_204); permute_1293 = view_204 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 32, '0'); convert_element_type_89 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + permute_1295 = torch.ops.aten.permute.default(permute_30, [1, 0]); permute_30 = None + mm_638 = torch.ops.aten.mm.default(view_3023, permute_1295); view_3023 = permute_1295 = None + view_3024 = torch.ops.aten.view.default(mm_638, [2, 8192, 4096]); mm_638 = None + add_334 = torch.ops.aten.add.Tensor(view_3022, view_3024); view_3022 = view_3024 = None + convert_element_type_2657 = torch.ops.prims.convert_element_type.default(mm_637, torch.float32); mm_637 = None + reduce_scatter_tensor_389 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2657, 'avg', 32, '0'); convert_element_type_2657 = None + wait_tensor_863 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_389); reduce_scatter_tensor_389 = None + split_256 = torch.ops.aten.split.Tensor(add_334, 1024, 1); add_334 = None + getitem_2423 = split_256[0] + getitem_2424 = split_256[1] + getitem_2425 = split_256[2] + getitem_2426 = split_256[3] + getitem_2427 = split_256[4] + getitem_2428 = split_256[5] + getitem_2429 = split_256[6] + getitem_2430 = split_256[7]; split_256 = None + cat_248 = torch.ops.aten.cat.default([getitem_2423, getitem_2424, getitem_2425, getitem_2426, getitem_2427, getitem_2428, getitem_2429, getitem_2430]); getitem_2423 = getitem_2424 = getitem_2425 = getitem_2426 = getitem_2427 = getitem_2428 = getitem_2429 = getitem_2430 = None + reduce_scatter_tensor_390 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_248, 'sum', 8, '1'); cat_248 = None + wait_tensor_864 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_390); reduce_scatter_tensor_390 = None + convert_element_type_2658 = torch.ops.prims.convert_element_type.default(wait_tensor_864, torch.float32); wait_tensor_864 = None + convert_element_type_2660 = torch.ops.prims.convert_element_type.default(wait_tensor_35, torch.float32); wait_tensor_35 = None + mul_850 = torch.ops.aten.mul.Tensor(convert_element_type_2658, convert_element_type_2660); convert_element_type_2660 = None + mul_852 = torch.ops.aten.mul.Tensor(mul_20, mul_850) + sum_177 = torch.ops.aten.sum.dim_IntList(mul_852, [2], True); mul_852 = None + div_59 = torch.ops.aten.div.Tensor(mul_20, 4096) + mul_853 = torch.ops.aten.mul.Tensor(div_59, sum_177); div_59 = sum_177 = None + sub_90 = torch.ops.aten.sub.Tensor(mul_850, mul_853); mul_850 = mul_853 = None + mul_854 = torch.ops.aten.mul.Tensor(sub_90, rsqrt_5); sub_90 = rsqrt_5 = None + mul_855 = torch.ops.aten.mul.Tensor(convert_element_type_2658, mul_20); convert_element_type_2658 = mul_20 = None + sum_178 = torch.ops.aten.sum.dim_IntList(mul_855, [0, 1]); mul_855 = None + convert_element_type_2661 = torch.ops.prims.convert_element_type.default(mul_854, torch.bfloat16); mul_854 = None + convert_element_type_2662 = torch.ops.prims.convert_element_type.default(sum_178, torch.bfloat16); sum_178 = None + all_reduce_59 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2662, 'sum', '1'); convert_element_type_2662 = None + wait_tensor_865 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_59); all_reduce_59 = None + convert_element_type_2663 = torch.ops.prims.convert_element_type.default(wait_tensor_865, torch.float32); wait_tensor_865 = None + reduce_scatter_tensor_391 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2663, 'avg', 32, '0'); convert_element_type_2663 = None + wait_tensor_866 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_391); reduce_scatter_tensor_391 = None + add_335 = torch.ops.aten.add.Tensor(add_331, convert_element_type_2661); add_331 = convert_element_type_2661 = None + all_gather_into_tensor_415 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_335, 8, '1') + wait_tensor_867 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_415); all_gather_into_tensor_415 = None + split_257 = torch.ops.aten.split.Tensor(wait_tensor_867, 2); wait_tensor_867 = None + getitem_2431 = split_257[0] + getitem_2432 = split_257[1] + getitem_2433 = split_257[2] + getitem_2434 = split_257[3] + getitem_2435 = split_257[4] + getitem_2436 = split_257[5] + getitem_2437 = split_257[6] + getitem_2438 = split_257[7]; split_257 = None + cat_249 = torch.ops.aten.cat.default([getitem_2431, getitem_2432, getitem_2433, getitem_2434, getitem_2435, getitem_2436, getitem_2437, getitem_2438], 1); getitem_2431 = getitem_2432 = getitem_2433 = getitem_2434 = getitem_2435 = getitem_2436 = getitem_2437 = getitem_2438 = None + view_3025 = torch.ops.aten.view.default(cat_249, [16384, 4096]); cat_249 = None + permute_1297 = torch.ops.aten.permute.default(view_3025, [1, 0]) + permute_28 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_186 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + view_192 = torch.ops.aten.view.default(view_186, [16384, 512]); view_186 = None + mm_639 = torch.ops.aten.mm.default(permute_1297, view_192); permute_1297 = view_192 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 32, '0'); convert_element_type_83 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + permute_1299 = torch.ops.aten.permute.default(permute_29, [1, 0]); permute_29 = None + mm_640 = torch.ops.aten.mm.default(view_3025, permute_1299); view_3025 = permute_1299 = None + view_3026 = torch.ops.aten.view.default(mm_640, [2, 8192, 512]); mm_640 = None + convert_element_type_2668 = torch.ops.prims.convert_element_type.default(mm_639, torch.float32); mm_639 = None + reduce_scatter_tensor_392 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2668, 'avg', 32, '0'); convert_element_type_2668 = None + wait_tensor_868 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_392); reduce_scatter_tensor_392 = None + view_3027 = torch.ops.aten.view.default(view_3026, [2, 8192, 4, 128]); view_3026 = None + permute_1301 = torch.ops.aten.permute.default(view_3027, [0, 2, 1, 3]); view_3027 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 32, '0'); convert_element_type_67 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32); add_7 = None + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_28) + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_69, 8, '1'); convert_element_type_69 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + split_17 = torch.ops.aten.split.Tensor(wait_tensor_29, 2); wait_tensor_29 = None + getitem_154 = split_17[0] + getitem_155 = split_17[1] + getitem_156 = split_17[2] + getitem_157 = split_17[3] + getitem_158 = split_17[4] + getitem_159 = split_17[5] + getitem_160 = split_17[6] + getitem_161 = split_17[7]; split_17 = None + cat_9 = torch.ops.aten.cat.default([getitem_154, getitem_155, getitem_156, getitem_157, getitem_158, getitem_159, getitem_160, getitem_161], 1); getitem_154 = getitem_155 = getitem_156 = getitem_157 = getitem_158 = getitem_159 = getitem_160 = getitem_161 = None + view_159 = torch.ops.aten.view.default(cat_9, [16384, 4096]); cat_9 = None + view_160 = torch.ops.aten.view.default(mm_14, [2, 8192, 512]); mm_14 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 32, '0'); convert_element_type_73 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_15 = torch.ops.aten.mm.default(view_159, permute_23) + view_167 = torch.ops.aten.view.default(mm_15, [2, 8192, 128]); mm_15 = None + view_174 = torch.ops.aten.view.default(mm_16, [2, 8192, 128]); mm_16 = None + view_176 = torch.ops.aten.view.default(view_160, [2, 8192, -1, 128]); view_160 = None + view_177 = torch.ops.aten.view.default(view_167, [2, 8192, -1, 128]); view_167 = None + view_178 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_176, torch.float32); view_176 = None + view_179 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 4, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_177, torch.float32); view_177 = None + view_180 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 1, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_180); view_180 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_37); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_182 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 4, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_37); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_183 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 1, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_182, torch.bfloat16); view_182 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_183, torch.bfloat16); view_183 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 1, 4, 128]); unsqueeze_4 = None + view_184 = torch.ops.aten.view.default(expand_4, [2, 8192, 4, 128]); expand_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_178, 3); view_178 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 1, 4, 128]); unsqueeze_5 = None + view_185 = torch.ops.aten.view.default(expand_5, [2, 8192, 4, 128]); expand_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_184, [0, 2, 1, 3]); view_184 = None + permute_27 = torch.ops.aten.permute.default(view_185, [0, 2, 1, 3]); view_185 = None + _scaled_dot_product_cudnn_attention_backward_29 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1301, permute_25, permute_26, permute_27, getitem_162, getitem_163, getitem_168, getitem_169, None, None, None, 8192, 8192, 0.0, True); permute_1301 = permute_25 = permute_26 = permute_27 = getitem_162 = getitem_163 = getitem_168 = getitem_169 = None + getitem_2439 = _scaled_dot_product_cudnn_attention_backward_29[0] + getitem_2440 = _scaled_dot_product_cudnn_attention_backward_29[1] + getitem_2441 = _scaled_dot_product_cudnn_attention_backward_29[2]; _scaled_dot_product_cudnn_attention_backward_29 = None + permute_1302 = torch.ops.aten.permute.default(getitem_2441, [0, 2, 1, 3]); getitem_2441 = None + permute_1303 = torch.ops.aten.permute.default(getitem_2440, [0, 2, 1, 3]); getitem_2440 = None + permute_1304 = torch.ops.aten.permute.default(getitem_2439, [0, 2, 1, 3]); getitem_2439 = None + view_3028 = torch.ops.aten.view.default(permute_1302, [2, 8192, 1, 4, 128]); permute_1302 = None + sum_179 = torch.ops.aten.sum.dim_IntList(view_3028, [3], True); view_3028 = None + squeeze_58 = torch.ops.aten.squeeze.dim(sum_179, 3); sum_179 = None + view_3029 = torch.ops.aten.view.default(permute_1303, [2, 8192, 1, 4, 128]); permute_1303 = None + sum_180 = torch.ops.aten.sum.dim_IntList(view_3029, [3], True); view_3029 = None + squeeze_59 = torch.ops.aten.squeeze.dim(sum_180, 3); sum_180 = None + convert_element_type_2669 = torch.ops.prims.convert_element_type.default(squeeze_59, torch.float32); squeeze_59 = None + convert_element_type_2670 = torch.ops.prims.convert_element_type.default(permute_1304, torch.float32); permute_1304 = None + view_3030 = torch.ops.aten.view.default(convert_element_type_2669, [2, 8192, 1, 64, 2]); convert_element_type_2669 = None + view_as_complex_122 = torch.ops.aten.view_as_complex.default(view_3030); view_3030 = None + mul_856 = torch.ops.aten.mul.Tensor(view_as_complex_122, _conj); view_as_complex_122 = None + view_3031 = torch.ops.aten.view.default(convert_element_type_2670, [2, 8192, 4, 64, 2]); convert_element_type_2670 = None + view_as_complex_123 = torch.ops.aten.view_as_complex.default(view_3031); view_3031 = None + mul_857 = torch.ops.aten.mul.Tensor(view_as_complex_123, _conj); view_as_complex_123 = None + view_as_real_122 = torch.ops.aten.view_as_real.default(mul_856); mul_856 = None + view_3032 = torch.ops.aten.view.default(view_as_real_122, [2, 8192, 1, 128]); view_as_real_122 = None + convert_element_type_2671 = torch.ops.prims.convert_element_type.default(view_3032, torch.bfloat16); view_3032 = None + view_as_real_123 = torch.ops.aten.view_as_real.default(mul_857); mul_857 = None + view_3033 = torch.ops.aten.view.default(view_as_real_123, [2, 8192, 4, 128]); view_as_real_123 = None + convert_element_type_2672 = torch.ops.prims.convert_element_type.default(view_3033, torch.bfloat16); view_3033 = None + view_3034 = torch.ops.aten.view.default(squeeze_58, [2, 8192, 128]); squeeze_58 = None + view_3035 = torch.ops.aten.view.default(convert_element_type_2671, [2, 8192, 128]); convert_element_type_2671 = None + view_3036 = torch.ops.aten.view.default(convert_element_type_2672, [2, 8192, 512]); convert_element_type_2672 = None + view_3037 = torch.ops.aten.view.default(view_3034, [16384, 128]); view_3034 = None + permute_1305 = torch.ops.aten.permute.default(view_3037, [1, 0]) + mm_641 = torch.ops.aten.mm.default(permute_1305, view_159); permute_1305 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 32, '0'); convert_element_type_76 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + permute_1307 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_642 = torch.ops.aten.mm.default(view_3037, permute_1307); view_3037 = permute_1307 = None + view_3038 = torch.ops.aten.view.default(mm_642, [2, 8192, 4096]); mm_642 = None + convert_element_type_2677 = torch.ops.prims.convert_element_type.default(mm_641, torch.float32); mm_641 = None + reduce_scatter_tensor_393 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2677, 'avg', 32, '0'); convert_element_type_2677 = None + wait_tensor_869 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_393); reduce_scatter_tensor_393 = None + view_3039 = torch.ops.aten.view.default(view_3035, [16384, 128]); view_3035 = None + permute_1309 = torch.ops.aten.permute.default(view_3039, [1, 0]) + mm_643 = torch.ops.aten.mm.default(permute_1309, view_159); permute_1309 = None + permute_1311 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_644 = torch.ops.aten.mm.default(view_3039, permute_1311); view_3039 = permute_1311 = None + view_3040 = torch.ops.aten.view.default(mm_644, [2, 8192, 4096]); mm_644 = None + add_336 = torch.ops.aten.add.Tensor(view_3038, view_3040); view_3038 = view_3040 = None + convert_element_type_2682 = torch.ops.prims.convert_element_type.default(mm_643, torch.float32); mm_643 = None + reduce_scatter_tensor_394 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2682, 'avg', 32, '0'); convert_element_type_2682 = None + wait_tensor_870 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_394); reduce_scatter_tensor_394 = None + view_3041 = torch.ops.aten.view.default(view_3036, [16384, 512]); view_3036 = None + permute_1313 = torch.ops.aten.permute.default(view_3041, [1, 0]) + mm_645 = torch.ops.aten.mm.default(permute_1313, view_159); permute_1313 = view_159 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 32, '0'); convert_element_type_70 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + permute_1315 = torch.ops.aten.permute.default(permute_22, [1, 0]); permute_22 = None + mm_646 = torch.ops.aten.mm.default(view_3041, permute_1315); view_3041 = permute_1315 = None + view_3042 = torch.ops.aten.view.default(mm_646, [2, 8192, 4096]); mm_646 = None + add_337 = torch.ops.aten.add.Tensor(add_336, view_3042); add_336 = view_3042 = None + convert_element_type_2687 = torch.ops.prims.convert_element_type.default(mm_645, torch.float32); mm_645 = None + reduce_scatter_tensor_395 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2687, 'avg', 32, '0'); convert_element_type_2687 = None + wait_tensor_871 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_395); reduce_scatter_tensor_395 = None + split_258 = torch.ops.aten.split.Tensor(add_337, 1024, 1); add_337 = None + getitem_2442 = split_258[0] + getitem_2443 = split_258[1] + getitem_2444 = split_258[2] + getitem_2445 = split_258[3] + getitem_2446 = split_258[4] + getitem_2447 = split_258[5] + getitem_2448 = split_258[6] + getitem_2449 = split_258[7]; split_258 = None + cat_250 = torch.ops.aten.cat.default([getitem_2442, getitem_2443, getitem_2444, getitem_2445, getitem_2446, getitem_2447, getitem_2448, getitem_2449]); getitem_2442 = getitem_2443 = getitem_2444 = getitem_2445 = getitem_2446 = getitem_2447 = getitem_2448 = getitem_2449 = None + reduce_scatter_tensor_396 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_250, 'sum', 8, '1'); cat_250 = None + wait_tensor_872 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_396); reduce_scatter_tensor_396 = None + convert_element_type_2688 = torch.ops.prims.convert_element_type.default(wait_tensor_872, torch.float32); wait_tensor_872 = None + convert_element_type_2690 = torch.ops.prims.convert_element_type.default(wait_tensor_28, torch.float32); wait_tensor_28 = None + mul_858 = torch.ops.aten.mul.Tensor(convert_element_type_2688, convert_element_type_2690); convert_element_type_2690 = None + mul_860 = torch.ops.aten.mul.Tensor(mul_16, mul_858) + sum_181 = torch.ops.aten.sum.dim_IntList(mul_860, [2], True); mul_860 = None + div_60 = torch.ops.aten.div.Tensor(mul_16, 4096) + mul_861 = torch.ops.aten.mul.Tensor(div_60, sum_181); div_60 = sum_181 = None + sub_91 = torch.ops.aten.sub.Tensor(mul_858, mul_861); mul_858 = mul_861 = None + mul_862 = torch.ops.aten.mul.Tensor(sub_91, rsqrt_4); sub_91 = rsqrt_4 = None + mul_863 = torch.ops.aten.mul.Tensor(convert_element_type_2688, mul_16); convert_element_type_2688 = mul_16 = None + sum_182 = torch.ops.aten.sum.dim_IntList(mul_863, [0, 1]); mul_863 = None + convert_element_type_2691 = torch.ops.prims.convert_element_type.default(mul_862, torch.bfloat16); mul_862 = None + convert_element_type_2692 = torch.ops.prims.convert_element_type.default(sum_182, torch.bfloat16); sum_182 = None + all_reduce_60 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2692, 'sum', '1'); convert_element_type_2692 = None + wait_tensor_873 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_60); all_reduce_60 = None + convert_element_type_2693 = torch.ops.prims.convert_element_type.default(wait_tensor_873, torch.float32); wait_tensor_873 = None + reduce_scatter_tensor_397 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2693, 'avg', 32, '0'); convert_element_type_2693 = None + wait_tensor_874 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_397); reduce_scatter_tensor_397 = None + add_338 = torch.ops.aten.add.Tensor(add_335, convert_element_type_2691); add_335 = convert_element_type_2691 = None + all_gather_into_tensor_416 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_338, 8, '1') + wait_tensor_875 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_416); all_gather_into_tensor_416 = None + split_259 = torch.ops.aten.split.Tensor(wait_tensor_875, 2); wait_tensor_875 = None + getitem_2450 = split_259[0] + getitem_2451 = split_259[1] + getitem_2452 = split_259[2] + getitem_2453 = split_259[3] + getitem_2454 = split_259[4] + getitem_2455 = split_259[5] + getitem_2456 = split_259[6] + getitem_2457 = split_259[7]; split_259 = None + cat_251 = torch.ops.aten.cat.default([getitem_2450, getitem_2451, getitem_2452, getitem_2453, getitem_2454, getitem_2455, getitem_2456, getitem_2457], 1); getitem_2450 = getitem_2451 = getitem_2452 = getitem_2453 = getitem_2454 = getitem_2455 = getitem_2456 = getitem_2457 = None + view_3043 = torch.ops.aten.view.default(cat_251, [16384, 4096]); cat_251 = None + permute_1317 = torch.ops.aten.permute.default(view_3043, [1, 0]) + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); reduce_scatter_tensor_3 = None + add_5 = torch.ops.aten.add.Tensor(add_3, wait_tensor_21); wait_tensor_21 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 32, '0'); convert_element_type_53 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_22) + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_55, 8, '1'); convert_element_type_55 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + split_15 = torch.ops.aten.split.Tensor(wait_tensor_23, 2); wait_tensor_23 = None + getitem_138 = split_15[0] + getitem_139 = split_15[1] + getitem_140 = split_15[2] + getitem_141 = split_15[3] + getitem_142 = split_15[4] + getitem_143 = split_15[5] + getitem_144 = split_15[6] + getitem_145 = split_15[7]; split_15 = None + cat_7 = torch.ops.aten.cat.default([getitem_138, getitem_139, getitem_140, getitem_141, getitem_142, getitem_143, getitem_144, getitem_145], 1); getitem_138 = getitem_139 = getitem_140 = getitem_141 = getitem_142 = getitem_143 = getitem_144 = getitem_145 = None + view_132 = torch.ops.aten.view.default(cat_7, [16384, 4096]); cat_7 = None + view_133 = torch.ops.aten.view.default(mm_11, [2, 8192, 1792]); mm_11 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_133, torch.float32); view_133 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 32, '0'); convert_element_type_61 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + mm_12 = torch.ops.aten.mm.default(view_132, permute_20) + view_140 = torch.ops.aten.view.default(mm_12, [2, 8192, 1792]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_140) + view_147 = torch.ops.aten.view.default(mul_15, [16384, 1792]); mul_15 = None + mm_647 = torch.ops.aten.mm.default(permute_1317, view_147); permute_1317 = view_147 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 32, '0'); convert_element_type_64 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + permute_1319 = torch.ops.aten.permute.default(permute_21, [1, 0]); permute_21 = None + mm_648 = torch.ops.aten.mm.default(view_3043, permute_1319); view_3043 = permute_1319 = None + view_3044 = torch.ops.aten.view.default(mm_648, [2, 8192, 1792]); mm_648 = None + convert_element_type_2698 = torch.ops.prims.convert_element_type.default(mm_647, torch.float32); mm_647 = None + reduce_scatter_tensor_398 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2698, 'avg', 32, '0'); convert_element_type_2698 = None + wait_tensor_876 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_398); reduce_scatter_tensor_398 = None + mul_864 = torch.ops.aten.mul.Tensor(view_3044, convert_element_type_60); convert_element_type_60 = None + mul_865 = torch.ops.aten.mul.Tensor(view_3044, view_140); view_3044 = view_140 = None + view_3045 = torch.ops.aten.view.default(mul_864, [16384, 1792]); mul_864 = None + permute_1321 = torch.ops.aten.permute.default(view_3045, [1, 0]) + mm_649 = torch.ops.aten.mm.default(permute_1321, view_132); permute_1321 = None + permute_1323 = torch.ops.aten.permute.default(permute_20, [1, 0]); permute_20 = None + mm_650 = torch.ops.aten.mm.default(view_3045, permute_1323); view_3045 = permute_1323 = None + view_3046 = torch.ops.aten.view.default(mm_650, [2, 8192, 4096]); mm_650 = None + convert_element_type_2703 = torch.ops.prims.convert_element_type.default(mm_649, torch.float32); mm_649 = None + reduce_scatter_tensor_399 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2703, 'avg', 32, '0'); convert_element_type_2703 = None + wait_tensor_877 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_399); reduce_scatter_tensor_399 = None + convert_element_type_2704 = torch.ops.prims.convert_element_type.default(mul_865, torch.float32); mul_865 = None + neg_30 = torch.ops.aten.neg.default(convert_element_type_59) + exp_30 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_339 = torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + reciprocal_30 = torch.ops.aten.reciprocal.default(add_339); add_339 = None + mul_866 = torch.ops.aten.mul.Tensor(reciprocal_30, 1); reciprocal_30 = None + mul_867 = torch.ops.aten.mul.Tensor(convert_element_type_2704, mul_866); convert_element_type_2704 = None + sub_92 = torch.ops.aten.sub.Tensor(1, mul_866); mul_866 = None + mul_868 = torch.ops.aten.mul.Tensor(convert_element_type_59, sub_92); convert_element_type_59 = sub_92 = None + add_340 = torch.ops.aten.add.Tensor(mul_868, 1); mul_868 = None + mul_869 = torch.ops.aten.mul.Tensor(mul_867, add_340); mul_867 = add_340 = None + convert_element_type_2706 = torch.ops.prims.convert_element_type.default(mul_869, torch.bfloat16); mul_869 = None + view_3047 = torch.ops.aten.view.default(convert_element_type_2706, [16384, 1792]); convert_element_type_2706 = None + permute_1325 = torch.ops.aten.permute.default(view_3047, [1, 0]) + mm_651 = torch.ops.aten.mm.default(permute_1325, view_132); permute_1325 = view_132 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 32, '0'); convert_element_type_56 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_24, [1, 0]); wait_tensor_24 = None + permute_1327 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_652 = torch.ops.aten.mm.default(view_3047, permute_1327); view_3047 = permute_1327 = None + view_3048 = torch.ops.aten.view.default(mm_652, [2, 8192, 4096]); mm_652 = None + add_341 = torch.ops.aten.add.Tensor(view_3046, view_3048); view_3046 = view_3048 = None + convert_element_type_2711 = torch.ops.prims.convert_element_type.default(mm_651, torch.float32); mm_651 = None + reduce_scatter_tensor_400 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2711, 'avg', 32, '0'); convert_element_type_2711 = None + wait_tensor_878 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_400); reduce_scatter_tensor_400 = None + split_260 = torch.ops.aten.split.Tensor(add_341, 1024, 1); add_341 = None + getitem_2458 = split_260[0] + getitem_2459 = split_260[1] + getitem_2460 = split_260[2] + getitem_2461 = split_260[3] + getitem_2462 = split_260[4] + getitem_2463 = split_260[5] + getitem_2464 = split_260[6] + getitem_2465 = split_260[7]; split_260 = None + cat_252 = torch.ops.aten.cat.default([getitem_2458, getitem_2459, getitem_2460, getitem_2461, getitem_2462, getitem_2463, getitem_2464, getitem_2465]); getitem_2458 = getitem_2459 = getitem_2460 = getitem_2461 = getitem_2462 = getitem_2463 = getitem_2464 = getitem_2465 = None + reduce_scatter_tensor_401 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_252, 'sum', 8, '1'); cat_252 = None + wait_tensor_879 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_401); reduce_scatter_tensor_401 = None + convert_element_type_2712 = torch.ops.prims.convert_element_type.default(wait_tensor_879, torch.float32); wait_tensor_879 = None + convert_element_type_2714 = torch.ops.prims.convert_element_type.default(wait_tensor_22, torch.float32); wait_tensor_22 = None + mul_870 = torch.ops.aten.mul.Tensor(convert_element_type_2712, convert_element_type_2714); convert_element_type_2714 = None + mul_872 = torch.ops.aten.mul.Tensor(mul_12, mul_870) + sum_183 = torch.ops.aten.sum.dim_IntList(mul_872, [2], True); mul_872 = None + div_61 = torch.ops.aten.div.Tensor(mul_12, 4096) + mul_873 = torch.ops.aten.mul.Tensor(div_61, sum_183); div_61 = sum_183 = None + sub_93 = torch.ops.aten.sub.Tensor(mul_870, mul_873); mul_870 = mul_873 = None + mul_874 = torch.ops.aten.mul.Tensor(sub_93, rsqrt_3); sub_93 = rsqrt_3 = None + mul_875 = torch.ops.aten.mul.Tensor(convert_element_type_2712, mul_12); convert_element_type_2712 = mul_12 = None + sum_184 = torch.ops.aten.sum.dim_IntList(mul_875, [0, 1]); mul_875 = None + convert_element_type_2715 = torch.ops.prims.convert_element_type.default(mul_874, torch.bfloat16); mul_874 = None + convert_element_type_2716 = torch.ops.prims.convert_element_type.default(sum_184, torch.bfloat16); sum_184 = None + all_reduce_61 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2716, 'sum', '1'); convert_element_type_2716 = None + wait_tensor_880 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_61); all_reduce_61 = None + convert_element_type_2717 = torch.ops.prims.convert_element_type.default(wait_tensor_880, torch.float32); wait_tensor_880 = None + reduce_scatter_tensor_402 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2717, 'avg', 32, '0'); convert_element_type_2717 = None + wait_tensor_881 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_402); reduce_scatter_tensor_402 = None + add_342 = torch.ops.aten.add.Tensor(add_338, convert_element_type_2715); add_338 = convert_element_type_2715 = None + all_gather_into_tensor_417 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_342, 8, '1') + wait_tensor_882 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_417); all_gather_into_tensor_417 = None + split_261 = torch.ops.aten.split.Tensor(wait_tensor_882, 2); wait_tensor_882 = None + getitem_2466 = split_261[0] + getitem_2467 = split_261[1] + getitem_2468 = split_261[2] + getitem_2469 = split_261[3] + getitem_2470 = split_261[4] + getitem_2471 = split_261[5] + getitem_2472 = split_261[6] + getitem_2473 = split_261[7]; split_261 = None + cat_253 = torch.ops.aten.cat.default([getitem_2466, getitem_2467, getitem_2468, getitem_2469, getitem_2470, getitem_2471, getitem_2472, getitem_2473], 1); getitem_2466 = getitem_2467 = getitem_2468 = getitem_2469 = getitem_2470 = getitem_2471 = getitem_2472 = getitem_2473 = None + view_3049 = torch.ops.aten.view.default(cat_253, [16384, 4096]); cat_253 = None + permute_1329 = torch.ops.aten.permute.default(view_3049, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_121, [0, 2, 1, 3]) + view_114 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + view_120 = torch.ops.aten.view.default(view_114, [16384, 512]); view_114 = None + mm_653 = torch.ops.aten.mm.default(permute_1329, view_120); permute_1329 = view_120 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 32, '0'); convert_element_type_50 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + permute_1331 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_654 = torch.ops.aten.mm.default(view_3049, permute_1331); view_3049 = permute_1331 = None + view_3050 = torch.ops.aten.view.default(mm_654, [2, 8192, 512]); mm_654 = None + convert_element_type_2722 = torch.ops.prims.convert_element_type.default(mm_653, torch.float32); mm_653 = None + reduce_scatter_tensor_403 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2722, 'avg', 32, '0'); convert_element_type_2722 = None + wait_tensor_883 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_403); reduce_scatter_tensor_403 = None + view_3051 = torch.ops.aten.view.default(view_3050, [2, 8192, 4, 128]); view_3050 = None + permute_1333 = torch.ops.aten.permute.default(view_3051, [0, 2, 1, 3]); view_3051 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); primals_13 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 32, '0'); convert_element_type_34 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32); add_3 = None + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_15) + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_36, 8, '1'); convert_element_type_36 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + split_13 = torch.ops.aten.split.Tensor(wait_tensor_16, 2); wait_tensor_16 = None + getitem_113 = split_13[0] + getitem_114 = split_13[1] + getitem_115 = split_13[2] + getitem_116 = split_13[3] + getitem_117 = split_13[4] + getitem_118 = split_13[5] + getitem_119 = split_13[6] + getitem_120 = split_13[7]; split_13 = None + cat_5 = torch.ops.aten.cat.default([getitem_113, getitem_114, getitem_115, getitem_116, getitem_117, getitem_118, getitem_119, getitem_120], 1); getitem_113 = getitem_114 = getitem_115 = getitem_116 = getitem_117 = getitem_118 = getitem_119 = getitem_120 = None + view_87 = torch.ops.aten.view.default(cat_5, [16384, 4096]); cat_5 = None + view_88 = torch.ops.aten.view.default(mm_7, [2, 8192, 512]); mm_7 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 32, '0'); convert_element_type_40 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + mm_8 = torch.ops.aten.mm.default(view_87, permute_12) + view_95 = torch.ops.aten.view.default(mm_8, [2, 8192, 128]); mm_8 = None + view_102 = torch.ops.aten.view.default(mm_9, [2, 8192, 128]); mm_9 = None + view_104 = torch.ops.aten.view.default(view_88, [2, 8192, -1, 128]); view_88 = None + view_105 = torch.ops.aten.view.default(view_95, [2, 8192, -1, 128]); view_95 = None + view_106 = torch.ops.aten.view.default(view_102, [2, 8192, -1, 128]); view_102 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_104, torch.float32); view_104 = None + view_107 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 4, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_107); view_107 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_105, torch.float32); view_105 = None + view_108 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 1, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_108); view_108 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_37); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_110 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 4, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_37); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_111 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 1, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_110, torch.bfloat16); view_110 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_111, torch.bfloat16); view_111 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 1, 4, 128]); unsqueeze_2 = None + view_112 = torch.ops.aten.view.default(expand_2, [2, 8192, 4, 128]); expand_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_106, 3); view_106 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 1, 4, 128]); unsqueeze_3 = None + view_113 = torch.ops.aten.view.default(expand_3, [2, 8192, 4, 128]); expand_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None + permute_16 = torch.ops.aten.permute.default(view_113, [0, 2, 1, 3]); view_113 = None + _scaled_dot_product_cudnn_attention_backward_30 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1333, permute_14, permute_15, permute_16, getitem_121, getitem_122, getitem_127, getitem_128, None, None, None, 8192, 8192, 0.0, True); permute_1333 = permute_14 = permute_15 = permute_16 = getitem_121 = getitem_122 = getitem_127 = getitem_128 = None + getitem_2474 = _scaled_dot_product_cudnn_attention_backward_30[0] + getitem_2475 = _scaled_dot_product_cudnn_attention_backward_30[1] + getitem_2476 = _scaled_dot_product_cudnn_attention_backward_30[2]; _scaled_dot_product_cudnn_attention_backward_30 = None + permute_1334 = torch.ops.aten.permute.default(getitem_2476, [0, 2, 1, 3]); getitem_2476 = None + permute_1335 = torch.ops.aten.permute.default(getitem_2475, [0, 2, 1, 3]); getitem_2475 = None + permute_1336 = torch.ops.aten.permute.default(getitem_2474, [0, 2, 1, 3]); getitem_2474 = None + view_3052 = torch.ops.aten.view.default(permute_1334, [2, 8192, 1, 4, 128]); permute_1334 = None + sum_185 = torch.ops.aten.sum.dim_IntList(view_3052, [3], True); view_3052 = None + squeeze_60 = torch.ops.aten.squeeze.dim(sum_185, 3); sum_185 = None + view_3053 = torch.ops.aten.view.default(permute_1335, [2, 8192, 1, 4, 128]); permute_1335 = None + sum_186 = torch.ops.aten.sum.dim_IntList(view_3053, [3], True); view_3053 = None + squeeze_61 = torch.ops.aten.squeeze.dim(sum_186, 3); sum_186 = None + convert_element_type_2723 = torch.ops.prims.convert_element_type.default(squeeze_61, torch.float32); squeeze_61 = None + convert_element_type_2724 = torch.ops.prims.convert_element_type.default(permute_1336, torch.float32); permute_1336 = None + view_3054 = torch.ops.aten.view.default(convert_element_type_2723, [2, 8192, 1, 64, 2]); convert_element_type_2723 = None + view_as_complex_124 = torch.ops.aten.view_as_complex.default(view_3054); view_3054 = None + mul_876 = torch.ops.aten.mul.Tensor(view_as_complex_124, _conj); view_as_complex_124 = None + view_3055 = torch.ops.aten.view.default(convert_element_type_2724, [2, 8192, 4, 64, 2]); convert_element_type_2724 = None + view_as_complex_125 = torch.ops.aten.view_as_complex.default(view_3055); view_3055 = None + mul_877 = torch.ops.aten.mul.Tensor(view_as_complex_125, _conj); view_as_complex_125 = None + view_as_real_124 = torch.ops.aten.view_as_real.default(mul_876); mul_876 = None + view_3056 = torch.ops.aten.view.default(view_as_real_124, [2, 8192, 1, 128]); view_as_real_124 = None + convert_element_type_2725 = torch.ops.prims.convert_element_type.default(view_3056, torch.bfloat16); view_3056 = None + view_as_real_125 = torch.ops.aten.view_as_real.default(mul_877); mul_877 = None + view_3057 = torch.ops.aten.view.default(view_as_real_125, [2, 8192, 4, 128]); view_as_real_125 = None + convert_element_type_2726 = torch.ops.prims.convert_element_type.default(view_3057, torch.bfloat16); view_3057 = None + view_3058 = torch.ops.aten.view.default(squeeze_60, [2, 8192, 128]); squeeze_60 = None + view_3059 = torch.ops.aten.view.default(convert_element_type_2725, [2, 8192, 128]); convert_element_type_2725 = None + view_3060 = torch.ops.aten.view.default(convert_element_type_2726, [2, 8192, 512]); convert_element_type_2726 = None + view_3061 = torch.ops.aten.view.default(view_3058, [16384, 128]); view_3058 = None + permute_1337 = torch.ops.aten.permute.default(view_3061, [1, 0]) + mm_655 = torch.ops.aten.mm.default(permute_1337, view_87); permute_1337 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 32, '0'); convert_element_type_43 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_19, [1, 0]); wait_tensor_19 = None + permute_1339 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_656 = torch.ops.aten.mm.default(view_3061, permute_1339); view_3061 = permute_1339 = None + view_3062 = torch.ops.aten.view.default(mm_656, [2, 8192, 4096]); mm_656 = None + convert_element_type_2731 = torch.ops.prims.convert_element_type.default(mm_655, torch.float32); mm_655 = None + reduce_scatter_tensor_404 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2731, 'avg', 32, '0'); convert_element_type_2731 = None + wait_tensor_884 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_404); reduce_scatter_tensor_404 = None + view_3063 = torch.ops.aten.view.default(view_3059, [16384, 128]); view_3059 = None + permute_1341 = torch.ops.aten.permute.default(view_3063, [1, 0]) + mm_657 = torch.ops.aten.mm.default(permute_1341, view_87); permute_1341 = None + permute_1343 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_658 = torch.ops.aten.mm.default(view_3063, permute_1343); view_3063 = permute_1343 = None + view_3064 = torch.ops.aten.view.default(mm_658, [2, 8192, 4096]); mm_658 = None + add_343 = torch.ops.aten.add.Tensor(view_3062, view_3064); view_3062 = view_3064 = None + convert_element_type_2736 = torch.ops.prims.convert_element_type.default(mm_657, torch.float32); mm_657 = None + reduce_scatter_tensor_405 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2736, 'avg', 32, '0'); convert_element_type_2736 = None + wait_tensor_885 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_405); reduce_scatter_tensor_405 = None + view_3065 = torch.ops.aten.view.default(view_3060, [16384, 512]); view_3060 = None + permute_1345 = torch.ops.aten.permute.default(view_3065, [1, 0]) + mm_659 = torch.ops.aten.mm.default(permute_1345, view_87); permute_1345 = view_87 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); primals_14 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 32, '0'); convert_element_type_37 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + permute_1347 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_660 = torch.ops.aten.mm.default(view_3065, permute_1347); view_3065 = permute_1347 = None + view_3066 = torch.ops.aten.view.default(mm_660, [2, 8192, 4096]); mm_660 = None + add_344 = torch.ops.aten.add.Tensor(add_343, view_3066); add_343 = view_3066 = None + convert_element_type_2741 = torch.ops.prims.convert_element_type.default(mm_659, torch.float32); mm_659 = None + reduce_scatter_tensor_406 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2741, 'avg', 32, '0'); convert_element_type_2741 = None + wait_tensor_886 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_406); reduce_scatter_tensor_406 = None + split_262 = torch.ops.aten.split.Tensor(add_344, 1024, 1); add_344 = None + getitem_2477 = split_262[0] + getitem_2478 = split_262[1] + getitem_2479 = split_262[2] + getitem_2480 = split_262[3] + getitem_2481 = split_262[4] + getitem_2482 = split_262[5] + getitem_2483 = split_262[6] + getitem_2484 = split_262[7]; split_262 = None + cat_254 = torch.ops.aten.cat.default([getitem_2477, getitem_2478, getitem_2479, getitem_2480, getitem_2481, getitem_2482, getitem_2483, getitem_2484]); getitem_2477 = getitem_2478 = getitem_2479 = getitem_2480 = getitem_2481 = getitem_2482 = getitem_2483 = getitem_2484 = None + reduce_scatter_tensor_407 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_254, 'sum', 8, '1'); cat_254 = None + wait_tensor_887 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_407); reduce_scatter_tensor_407 = None + convert_element_type_2742 = torch.ops.prims.convert_element_type.default(wait_tensor_887, torch.float32); wait_tensor_887 = None + convert_element_type_2744 = torch.ops.prims.convert_element_type.default(wait_tensor_15, torch.float32); wait_tensor_15 = None + mul_878 = torch.ops.aten.mul.Tensor(convert_element_type_2742, convert_element_type_2744); convert_element_type_2744 = None + mul_880 = torch.ops.aten.mul.Tensor(mul_8, mul_878) + sum_187 = torch.ops.aten.sum.dim_IntList(mul_880, [2], True); mul_880 = None + div_62 = torch.ops.aten.div.Tensor(mul_8, 4096) + mul_881 = torch.ops.aten.mul.Tensor(div_62, sum_187); div_62 = sum_187 = None + sub_94 = torch.ops.aten.sub.Tensor(mul_878, mul_881); mul_878 = mul_881 = None + mul_882 = torch.ops.aten.mul.Tensor(sub_94, rsqrt_2); sub_94 = rsqrt_2 = None + mul_883 = torch.ops.aten.mul.Tensor(convert_element_type_2742, mul_8); convert_element_type_2742 = mul_8 = None + sum_188 = torch.ops.aten.sum.dim_IntList(mul_883, [0, 1]); mul_883 = None + convert_element_type_2745 = torch.ops.prims.convert_element_type.default(mul_882, torch.bfloat16); mul_882 = None + convert_element_type_2746 = torch.ops.prims.convert_element_type.default(sum_188, torch.bfloat16); sum_188 = None + all_reduce_62 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2746, 'sum', '1'); convert_element_type_2746 = None + wait_tensor_888 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_62); all_reduce_62 = None + convert_element_type_2747 = torch.ops.prims.convert_element_type.default(wait_tensor_888, torch.float32); wait_tensor_888 = None + reduce_scatter_tensor_408 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2747, 'avg', 32, '0'); convert_element_type_2747 = None + wait_tensor_889 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_408); reduce_scatter_tensor_408 = None + add_345 = torch.ops.aten.add.Tensor(add_342, convert_element_type_2745); add_342 = convert_element_type_2745 = None + all_gather_into_tensor_418 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_345, 8, '1') + wait_tensor_890 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_418); all_gather_into_tensor_418 = None + split_263 = torch.ops.aten.split.Tensor(wait_tensor_890, 2); wait_tensor_890 = None + getitem_2485 = split_263[0] + getitem_2486 = split_263[1] + getitem_2487 = split_263[2] + getitem_2488 = split_263[3] + getitem_2489 = split_263[4] + getitem_2490 = split_263[5] + getitem_2491 = split_263[6] + getitem_2492 = split_263[7]; split_263 = None + cat_255 = torch.ops.aten.cat.default([getitem_2485, getitem_2486, getitem_2487, getitem_2488, getitem_2489, getitem_2490, getitem_2491, getitem_2492], 1); getitem_2485 = getitem_2486 = getitem_2487 = getitem_2488 = getitem_2489 = getitem_2490 = getitem_2491 = getitem_2492 = None + view_3067 = torch.ops.aten.view.default(cat_255, [16384, 4096]); cat_255 = None + permute_1349 = torch.ops.aten.permute.default(view_3067, [1, 0]) + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + add_1 = torch.ops.aten.add.Tensor(wait_tensor_1, wait_tensor_8); wait_tensor_8 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 32, '0'); convert_element_type_20 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32); add_1 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_9) + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_22, 8, '1'); convert_element_type_22 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + split_11 = torch.ops.aten.split.Tensor(wait_tensor_10, 2); wait_tensor_10 = None + getitem_97 = split_11[0] + getitem_98 = split_11[1] + getitem_99 = split_11[2] + getitem_100 = split_11[3] + getitem_101 = split_11[4] + getitem_102 = split_11[5] + getitem_103 = split_11[6] + getitem_104 = split_11[7]; split_11 = None + cat_3 = torch.ops.aten.cat.default([getitem_97, getitem_98, getitem_99, getitem_100, getitem_101, getitem_102, getitem_103, getitem_104], 1); getitem_97 = getitem_98 = getitem_99 = getitem_100 = getitem_101 = getitem_102 = getitem_103 = getitem_104 = None + view_60 = torch.ops.aten.view.default(cat_3, [16384, 4096]); cat_3 = None + view_61 = torch.ops.aten.view.default(mm_4, [2, 8192, 1792]); mm_4 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_61, torch.float32); view_61 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 32, '0'); convert_element_type_28 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_5 = torch.ops.aten.mm.default(view_60, permute_9) + view_68 = torch.ops.aten.view.default(mm_5, [2, 8192, 1792]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_68) + view_75 = torch.ops.aten.view.default(mul_7, [16384, 1792]); mul_7 = None + mm_661 = torch.ops.aten.mm.default(permute_1349, view_75); permute_1349 = view_75 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 32, '0'); convert_element_type_31 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + permute_1351 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_662 = torch.ops.aten.mm.default(view_3067, permute_1351); view_3067 = permute_1351 = None + view_3068 = torch.ops.aten.view.default(mm_662, [2, 8192, 1792]); mm_662 = None + convert_element_type_2752 = torch.ops.prims.convert_element_type.default(mm_661, torch.float32); mm_661 = None + reduce_scatter_tensor_409 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2752, 'avg', 32, '0'); convert_element_type_2752 = None + wait_tensor_891 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_409); reduce_scatter_tensor_409 = None + mul_884 = torch.ops.aten.mul.Tensor(view_3068, convert_element_type_27); convert_element_type_27 = None + mul_885 = torch.ops.aten.mul.Tensor(view_3068, view_68); view_3068 = view_68 = None + view_3069 = torch.ops.aten.view.default(mul_884, [16384, 1792]); mul_884 = None + permute_1353 = torch.ops.aten.permute.default(view_3069, [1, 0]) + mm_663 = torch.ops.aten.mm.default(permute_1353, view_60); permute_1353 = None + permute_1355 = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_664 = torch.ops.aten.mm.default(view_3069, permute_1355); view_3069 = permute_1355 = None + view_3070 = torch.ops.aten.view.default(mm_664, [2, 8192, 4096]); mm_664 = None + convert_element_type_2757 = torch.ops.prims.convert_element_type.default(mm_663, torch.float32); mm_663 = None + reduce_scatter_tensor_410 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2757, 'avg', 32, '0'); convert_element_type_2757 = None + wait_tensor_892 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_410); reduce_scatter_tensor_410 = None + convert_element_type_2758 = torch.ops.prims.convert_element_type.default(mul_885, torch.float32); mul_885 = None + neg_31 = torch.ops.aten.neg.default(convert_element_type_26) + exp_31 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_346 = torch.ops.aten.add.Tensor(exp_31, 1); exp_31 = None + reciprocal_31 = torch.ops.aten.reciprocal.default(add_346); add_346 = None + mul_886 = torch.ops.aten.mul.Tensor(reciprocal_31, 1); reciprocal_31 = None + mul_887 = torch.ops.aten.mul.Tensor(convert_element_type_2758, mul_886); convert_element_type_2758 = None + sub_95 = torch.ops.aten.sub.Tensor(1, mul_886); mul_886 = None + mul_888 = torch.ops.aten.mul.Tensor(convert_element_type_26, sub_95); convert_element_type_26 = sub_95 = None + add_347 = torch.ops.aten.add.Tensor(mul_888, 1); mul_888 = None + mul_889 = torch.ops.aten.mul.Tensor(mul_887, add_347); mul_887 = add_347 = None + convert_element_type_2760 = torch.ops.prims.convert_element_type.default(mul_889, torch.bfloat16); mul_889 = None + view_3071 = torch.ops.aten.view.default(convert_element_type_2760, [16384, 1792]); convert_element_type_2760 = None + permute_1357 = torch.ops.aten.permute.default(view_3071, [1, 0]) + mm_665 = torch.ops.aten.mm.default(permute_1357, view_60); permute_1357 = view_60 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 32, '0'); convert_element_type_23 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + permute_1359 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_666 = torch.ops.aten.mm.default(view_3071, permute_1359); view_3071 = permute_1359 = None + view_3072 = torch.ops.aten.view.default(mm_666, [2, 8192, 4096]); mm_666 = None + add_348 = torch.ops.aten.add.Tensor(view_3070, view_3072); view_3070 = view_3072 = None + convert_element_type_2765 = torch.ops.prims.convert_element_type.default(mm_665, torch.float32); mm_665 = None + reduce_scatter_tensor_411 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2765, 'avg', 32, '0'); convert_element_type_2765 = None + wait_tensor_893 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_411); reduce_scatter_tensor_411 = None + split_264 = torch.ops.aten.split.Tensor(add_348, 1024, 1); add_348 = None + getitem_2493 = split_264[0] + getitem_2494 = split_264[1] + getitem_2495 = split_264[2] + getitem_2496 = split_264[3] + getitem_2497 = split_264[4] + getitem_2498 = split_264[5] + getitem_2499 = split_264[6] + getitem_2500 = split_264[7]; split_264 = None + cat_256 = torch.ops.aten.cat.default([getitem_2493, getitem_2494, getitem_2495, getitem_2496, getitem_2497, getitem_2498, getitem_2499, getitem_2500]); getitem_2493 = getitem_2494 = getitem_2495 = getitem_2496 = getitem_2497 = getitem_2498 = getitem_2499 = getitem_2500 = None + reduce_scatter_tensor_412 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_256, 'sum', 8, '1'); cat_256 = None + wait_tensor_894 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_412); reduce_scatter_tensor_412 = None + convert_element_type_2766 = torch.ops.prims.convert_element_type.default(wait_tensor_894, torch.float32); wait_tensor_894 = None + convert_element_type_2768 = torch.ops.prims.convert_element_type.default(wait_tensor_9, torch.float32); wait_tensor_9 = None + mul_890 = torch.ops.aten.mul.Tensor(convert_element_type_2766, convert_element_type_2768); convert_element_type_2768 = None + mul_892 = torch.ops.aten.mul.Tensor(mul_4, mul_890) + sum_189 = torch.ops.aten.sum.dim_IntList(mul_892, [2], True); mul_892 = None + div_63 = torch.ops.aten.div.Tensor(mul_4, 4096) + mul_893 = torch.ops.aten.mul.Tensor(div_63, sum_189); div_63 = sum_189 = None + sub_96 = torch.ops.aten.sub.Tensor(mul_890, mul_893); mul_890 = mul_893 = None + mul_894 = torch.ops.aten.mul.Tensor(sub_96, rsqrt_1); sub_96 = rsqrt_1 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_2766, mul_4); convert_element_type_2766 = mul_4 = None + sum_190 = torch.ops.aten.sum.dim_IntList(mul_895, [0, 1]); mul_895 = None + convert_element_type_2769 = torch.ops.prims.convert_element_type.default(mul_894, torch.bfloat16); mul_894 = None + convert_element_type_2770 = torch.ops.prims.convert_element_type.default(sum_190, torch.bfloat16); sum_190 = None + all_reduce_63 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2770, 'sum', '1'); convert_element_type_2770 = None + wait_tensor_895 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_63); all_reduce_63 = None + convert_element_type_2771 = torch.ops.prims.convert_element_type.default(wait_tensor_895, torch.float32); wait_tensor_895 = None + reduce_scatter_tensor_413 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2771, 'avg', 32, '0'); convert_element_type_2771 = None + wait_tensor_896 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_413); reduce_scatter_tensor_413 = None + add_349 = torch.ops.aten.add.Tensor(add_345, convert_element_type_2769); add_345 = convert_element_type_2769 = None + all_gather_into_tensor_419 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_349, 8, '1') + wait_tensor_897 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_419); all_gather_into_tensor_419 = None + split_265 = torch.ops.aten.split.Tensor(wait_tensor_897, 2); wait_tensor_897 = None + getitem_2501 = split_265[0] + getitem_2502 = split_265[1] + getitem_2503 = split_265[2] + getitem_2504 = split_265[3] + getitem_2505 = split_265[4] + getitem_2506 = split_265[5] + getitem_2507 = split_265[6] + getitem_2508 = split_265[7]; split_265 = None + cat_257 = torch.ops.aten.cat.default([getitem_2501, getitem_2502, getitem_2503, getitem_2504, getitem_2505, getitem_2506, getitem_2507, getitem_2508], 1); getitem_2501 = getitem_2502 = getitem_2503 = getitem_2504 = getitem_2505 = getitem_2506 = getitem_2507 = getitem_2508 = None + view_3073 = torch.ops.aten.view.default(cat_257, [16384, 4096]); cat_257 = None + permute_1361 = torch.ops.aten.permute.default(view_3073, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem_80, [0, 2, 1, 3]) + view_42 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + view_48 = torch.ops.aten.view.default(view_42, [16384, 512]); view_42 = None + mm_667 = torch.ops.aten.mm.default(permute_1361, view_48); permute_1361 = view_48 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 32, '0'); convert_element_type_17 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + permute_1363 = torch.ops.aten.permute.default(permute_7, [1, 0]); permute_7 = None + mm_668 = torch.ops.aten.mm.default(view_3073, permute_1363); view_3073 = permute_1363 = None + view_3074 = torch.ops.aten.view.default(mm_668, [2, 8192, 512]); mm_668 = None + convert_element_type_2776 = torch.ops.prims.convert_element_type.default(mm_667, torch.float32); mm_667 = None + reduce_scatter_tensor_414 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2776, 'avg', 32, '0'); convert_element_type_2776 = None + wait_tensor_898 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_414); reduce_scatter_tensor_414 = None + view_3075 = torch.ops.aten.view.default(view_3074, [2, 8192, 4, 128]); view_3074 = None + permute_1365 = torch.ops.aten.permute.default(view_3075, [0, 2, 1, 3]); view_3075 = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 32, '0'); convert_element_type_1 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_2) + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_3, 8, '1'); convert_element_type_3 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + split_9 = torch.ops.aten.split.Tensor(wait_tensor_3, 2); wait_tensor_3 = None + getitem_72 = split_9[0] + getitem_73 = split_9[1] + getitem_74 = split_9[2] + getitem_75 = split_9[3] + getitem_76 = split_9[4] + getitem_77 = split_9[5] + getitem_78 = split_9[6] + getitem_79 = split_9[7]; split_9 = None + cat_1 = torch.ops.aten.cat.default([getitem_72, getitem_73, getitem_74, getitem_75, getitem_76, getitem_77, getitem_78, getitem_79], 1); getitem_72 = getitem_73 = getitem_74 = getitem_75 = getitem_76 = getitem_77 = getitem_78 = getitem_79 = None + view_15 = torch.ops.aten.view.default(cat_1, [16384, 4096]); cat_1 = None + view_16 = torch.ops.aten.view.default(mm, [2, 8192, 512]); mm = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 32, '0'); convert_element_type_7 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + mm_1 = torch.ops.aten.mm.default(view_15, permute_1) + view_23 = torch.ops.aten.view.default(mm_1, [2, 8192, 128]); mm_1 = None + view_30 = torch.ops.aten.view.default(mm_2, [2, 8192, 128]); mm_2 = None + view_32 = torch.ops.aten.view.default(view_16, [2, 8192, -1, 128]); view_16 = None + view_33 = torch.ops.aten.view.default(view_23, [2, 8192, -1, 128]); view_23 = None + view_34 = torch.ops.aten.view.default(view_30, [2, 8192, -1, 128]); view_30 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_32, torch.float32); view_32 = None + view_35 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 4, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_35); view_35 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_33, torch.float32); view_33 = None + view_36 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 1, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_36); view_36 = None + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_37); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_38 = torch.ops.aten.view.default(view_as_real, [2, 8192, 4, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_37); view_as_complex_1 = view_37 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_39 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 1, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_38, torch.bfloat16); view_38 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_39, torch.bfloat16); view_39 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 1, 4, 128]); unsqueeze = None + view_40 = torch.ops.aten.view.default(expand, [2, 8192, 4, 128]); expand = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_34, 3); view_34 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 1, 4, 128]); unsqueeze_1 = None + view_41 = torch.ops.aten.view.default(expand_1, [2, 8192, 4, 128]); expand_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_40, [0, 2, 1, 3]); view_40 = None + permute_5 = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None + _scaled_dot_product_cudnn_attention_backward_31 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1365, permute_3, permute_4, permute_5, getitem_80, getitem_81, getitem_86, getitem_87, None, None, None, 8192, 8192, 0.0, True); permute_1365 = permute_3 = permute_4 = permute_5 = getitem_80 = getitem_81 = getitem_86 = getitem_87 = None + getitem_2509 = _scaled_dot_product_cudnn_attention_backward_31[0] + getitem_2510 = _scaled_dot_product_cudnn_attention_backward_31[1] + getitem_2511 = _scaled_dot_product_cudnn_attention_backward_31[2]; _scaled_dot_product_cudnn_attention_backward_31 = None + permute_1366 = torch.ops.aten.permute.default(getitem_2511, [0, 2, 1, 3]); getitem_2511 = None + permute_1367 = torch.ops.aten.permute.default(getitem_2510, [0, 2, 1, 3]); getitem_2510 = None + permute_1368 = torch.ops.aten.permute.default(getitem_2509, [0, 2, 1, 3]); getitem_2509 = None + view_3076 = torch.ops.aten.view.default(permute_1366, [2, 8192, 1, 4, 128]); permute_1366 = None + sum_191 = torch.ops.aten.sum.dim_IntList(view_3076, [3], True); view_3076 = None + squeeze_62 = torch.ops.aten.squeeze.dim(sum_191, 3); sum_191 = None + view_3077 = torch.ops.aten.view.default(permute_1367, [2, 8192, 1, 4, 128]); permute_1367 = None + sum_192 = torch.ops.aten.sum.dim_IntList(view_3077, [3], True); view_3077 = None + squeeze_63 = torch.ops.aten.squeeze.dim(sum_192, 3); sum_192 = None + convert_element_type_2777 = torch.ops.prims.convert_element_type.default(squeeze_63, torch.float32); squeeze_63 = None + convert_element_type_2778 = torch.ops.prims.convert_element_type.default(permute_1368, torch.float32); permute_1368 = None + view_3078 = torch.ops.aten.view.default(convert_element_type_2777, [2, 8192, 1, 64, 2]); convert_element_type_2777 = None + view_as_complex_126 = torch.ops.aten.view_as_complex.default(view_3078); view_3078 = None + mul_896 = torch.ops.aten.mul.Tensor(view_as_complex_126, _conj); view_as_complex_126 = None + view_3079 = torch.ops.aten.view.default(convert_element_type_2778, [2, 8192, 4, 64, 2]); convert_element_type_2778 = None + view_as_complex_127 = torch.ops.aten.view_as_complex.default(view_3079); view_3079 = None + mul_897 = torch.ops.aten.mul.Tensor(view_as_complex_127, _conj); view_as_complex_127 = _conj = None + view_as_real_126 = torch.ops.aten.view_as_real.default(mul_896); mul_896 = None + view_3080 = torch.ops.aten.view.default(view_as_real_126, [2, 8192, 1, 128]); view_as_real_126 = None + convert_element_type_2779 = torch.ops.prims.convert_element_type.default(view_3080, torch.bfloat16); view_3080 = None + view_as_real_127 = torch.ops.aten.view_as_real.default(mul_897); mul_897 = None + view_3081 = torch.ops.aten.view.default(view_as_real_127, [2, 8192, 4, 128]); view_as_real_127 = None + convert_element_type_2780 = torch.ops.prims.convert_element_type.default(view_3081, torch.bfloat16); view_3081 = None + view_3082 = torch.ops.aten.view.default(squeeze_62, [2, 8192, 128]); squeeze_62 = None + view_3083 = torch.ops.aten.view.default(convert_element_type_2779, [2, 8192, 128]); convert_element_type_2779 = None + view_3084 = torch.ops.aten.view.default(convert_element_type_2780, [2, 8192, 512]); convert_element_type_2780 = None + view_3085 = torch.ops.aten.view.default(view_3082, [16384, 128]); view_3082 = None + permute_1369 = torch.ops.aten.permute.default(view_3085, [1, 0]) + mm_669 = torch.ops.aten.mm.default(permute_1369, view_15); permute_1369 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 32, '0'); convert_element_type_10 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + permute_1371 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_670 = torch.ops.aten.mm.default(view_3085, permute_1371); view_3085 = permute_1371 = None + view_3086 = torch.ops.aten.view.default(mm_670, [2, 8192, 4096]); mm_670 = None + convert_element_type_2785 = torch.ops.prims.convert_element_type.default(mm_669, torch.float32); mm_669 = None + reduce_scatter_tensor_415 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2785, 'avg', 32, '0'); convert_element_type_2785 = None + wait_tensor_899 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_415); reduce_scatter_tensor_415 = None + view_3087 = torch.ops.aten.view.default(view_3083, [16384, 128]); view_3083 = None + permute_1373 = torch.ops.aten.permute.default(view_3087, [1, 0]) + mm_671 = torch.ops.aten.mm.default(permute_1373, view_15); permute_1373 = None + permute_1375 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_672 = torch.ops.aten.mm.default(view_3087, permute_1375); view_3087 = permute_1375 = None + view_3088 = torch.ops.aten.view.default(mm_672, [2, 8192, 4096]); mm_672 = None + add_350 = torch.ops.aten.add.Tensor(view_3086, view_3088); view_3086 = view_3088 = None + convert_element_type_2790 = torch.ops.prims.convert_element_type.default(mm_671, torch.float32); mm_671 = None + reduce_scatter_tensor_416 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2790, 'avg', 32, '0'); convert_element_type_2790 = None + wait_tensor_900 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_416); reduce_scatter_tensor_416 = None + view_3089 = torch.ops.aten.view.default(view_3084, [16384, 512]); view_3084 = None + permute_1377 = torch.ops.aten.permute.default(view_3089, [1, 0]) + mm_673 = torch.ops.aten.mm.default(permute_1377, view_15); permute_1377 = view_15 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 32, '0'); convert_element_type_4 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + permute_1379 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_674 = torch.ops.aten.mm.default(view_3089, permute_1379); view_3089 = permute_1379 = None + view_3090 = torch.ops.aten.view.default(mm_674, [2, 8192, 4096]); mm_674 = None + add_351 = torch.ops.aten.add.Tensor(add_350, view_3090); add_350 = view_3090 = None + convert_element_type_2795 = torch.ops.prims.convert_element_type.default(mm_673, torch.float32); mm_673 = None + reduce_scatter_tensor_417 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2795, 'avg', 32, '0'); convert_element_type_2795 = None + wait_tensor_901 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_417); reduce_scatter_tensor_417 = None + split_266 = torch.ops.aten.split.Tensor(add_351, 1024, 1); add_351 = None + getitem_2512 = split_266[0] + getitem_2513 = split_266[1] + getitem_2514 = split_266[2] + getitem_2515 = split_266[3] + getitem_2516 = split_266[4] + getitem_2517 = split_266[5] + getitem_2518 = split_266[6] + getitem_2519 = split_266[7]; split_266 = None + cat_258 = torch.ops.aten.cat.default([getitem_2512, getitem_2513, getitem_2514, getitem_2515, getitem_2516, getitem_2517, getitem_2518, getitem_2519]); getitem_2512 = getitem_2513 = getitem_2514 = getitem_2515 = getitem_2516 = getitem_2517 = getitem_2518 = getitem_2519 = None + reduce_scatter_tensor_418 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_258, 'sum', 8, '1'); cat_258 = None + wait_tensor_902 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_418); reduce_scatter_tensor_418 = None + convert_element_type_2796 = torch.ops.prims.convert_element_type.default(wait_tensor_902, torch.float32); wait_tensor_902 = None + convert_element_type_2798 = torch.ops.prims.convert_element_type.default(wait_tensor_2, torch.float32); wait_tensor_2 = None + mul_898 = torch.ops.aten.mul.Tensor(convert_element_type_2796, convert_element_type_2798); convert_element_type_2798 = None + mul_900 = torch.ops.aten.mul.Tensor(mul, mul_898) + sum_193 = torch.ops.aten.sum.dim_IntList(mul_900, [2], True); mul_900 = None + div_64 = torch.ops.aten.div.Tensor(mul, 4096) + mul_901 = torch.ops.aten.mul.Tensor(div_64, sum_193); div_64 = sum_193 = None + sub_97 = torch.ops.aten.sub.Tensor(mul_898, mul_901); mul_898 = mul_901 = None + mul_902 = torch.ops.aten.mul.Tensor(sub_97, rsqrt); sub_97 = rsqrt = None + mul_903 = torch.ops.aten.mul.Tensor(convert_element_type_2796, mul); convert_element_type_2796 = mul = None + sum_194 = torch.ops.aten.sum.dim_IntList(mul_903, [0, 1]); mul_903 = None + convert_element_type_2799 = torch.ops.prims.convert_element_type.default(mul_902, torch.bfloat16); mul_902 = None + convert_element_type_2800 = torch.ops.prims.convert_element_type.default(sum_194, torch.bfloat16); sum_194 = None + all_reduce_64 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_2800, 'sum', '1'); convert_element_type_2800 = None + wait_tensor_903 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_64); all_reduce_64 = None + convert_element_type_2801 = torch.ops.prims.convert_element_type.default(wait_tensor_903, torch.float32); wait_tensor_903 = None + reduce_scatter_tensor_419 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2801, 'avg', 32, '0'); convert_element_type_2801 = None + wait_tensor_904 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_419); reduce_scatter_tensor_419 = None + add_352 = torch.ops.aten.add.Tensor(add_349, convert_element_type_2799); add_349 = convert_element_type_2799 = None + all_gather_into_tensor_420 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_352, 8, '1'); add_352 = None + wait_tensor_905 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_420); all_gather_into_tensor_420 = None + split_267 = torch.ops.aten.split.Tensor(wait_tensor_905, 2); wait_tensor_905 = None + getitem_2520 = split_267[0] + getitem_2521 = split_267[1] + getitem_2522 = split_267[2] + getitem_2523 = split_267[3] + getitem_2524 = split_267[4] + getitem_2525 = split_267[5] + getitem_2526 = split_267[6] + getitem_2527 = split_267[7]; split_267 = None + cat_259 = torch.ops.aten.cat.default([getitem_2520, getitem_2521, getitem_2522, getitem_2523, getitem_2524, getitem_2525, getitem_2526, getitem_2527], 1); getitem_2520 = getitem_2521 = getitem_2522 = getitem_2523 = getitem_2524 = getitem_2525 = getitem_2526 = getitem_2527 = None + convert_element_type_2802 = torch.ops.prims.convert_element_type.default(cat_259, torch.float32); cat_259 = None + eq = torch.ops.aten.eq.Scalar(primals_1, -1) + unsqueeze_64 = torch.ops.aten.unsqueeze.default(eq, -1); eq = None + full_default_2 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_64, full_default_2, convert_element_type_2802); unsqueeze_64 = full_default_2 = convert_element_type_2802 = None + full_default_3 = torch.ops.aten.full.default([128256, 4096], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_2 = torch.ops.aten.index_put.default(full_default_3, [primals_1], where, True); full_default_3 = primals_1 = where = None + convert_element_type_2803 = torch.ops.prims.convert_element_type.default(index_put_2, torch.bfloat16); index_put_2 = None + split_268 = torch.ops.aten.split.Tensor(convert_element_type_2803, 16032); convert_element_type_2803 = None + getitem_2528 = split_268[0]; split_268 = None + convert_element_type_2804 = torch.ops.prims.convert_element_type.default(getitem_2528, torch.float32); getitem_2528 = None + reduce_scatter_tensor_420 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2804, 'avg', 32, '0'); convert_element_type_2804 = None + wait_tensor_906 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_420); reduce_scatter_tensor_420 = None + return (None, wait_tensor_906, None, wait_tensor_904, wait_tensor_901, wait_tensor_900, wait_tensor_899, wait_tensor_898, wait_tensor_896, wait_tensor_893, wait_tensor_892, wait_tensor_891, wait_tensor_889, wait_tensor_886, wait_tensor_885, wait_tensor_884, wait_tensor_883, wait_tensor_881, wait_tensor_878, wait_tensor_877, wait_tensor_876, wait_tensor_874, wait_tensor_871, wait_tensor_870, wait_tensor_869, wait_tensor_868, wait_tensor_866, wait_tensor_863, wait_tensor_862, wait_tensor_861, wait_tensor_859, wait_tensor_856, wait_tensor_855, wait_tensor_854, wait_tensor_853, wait_tensor_851, wait_tensor_848, wait_tensor_847, wait_tensor_846, wait_tensor_844, wait_tensor_841, wait_tensor_840, wait_tensor_839, wait_tensor_838, wait_tensor_836, wait_tensor_833, wait_tensor_832, wait_tensor_831, wait_tensor_829, wait_tensor_826, wait_tensor_825, wait_tensor_824, wait_tensor_823, wait_tensor_821, wait_tensor_818, wait_tensor_817, wait_tensor_816, wait_tensor_814, wait_tensor_811, wait_tensor_810, wait_tensor_809, wait_tensor_808, wait_tensor_806, wait_tensor_803, wait_tensor_802, wait_tensor_801, wait_tensor_799, wait_tensor_796, wait_tensor_795, wait_tensor_794, wait_tensor_793, wait_tensor_791, wait_tensor_788, wait_tensor_787, wait_tensor_786, wait_tensor_784, wait_tensor_781, wait_tensor_780, wait_tensor_779, wait_tensor_778, wait_tensor_776, wait_tensor_773, wait_tensor_772, wait_tensor_771, wait_tensor_769, wait_tensor_766, wait_tensor_765, wait_tensor_764, wait_tensor_763, wait_tensor_761, wait_tensor_758, wait_tensor_757, wait_tensor_756, wait_tensor_754, wait_tensor_751, wait_tensor_750, wait_tensor_749, wait_tensor_748, wait_tensor_746, wait_tensor_743, wait_tensor_742, wait_tensor_741, wait_tensor_739, wait_tensor_736, wait_tensor_735, wait_tensor_734, wait_tensor_733, wait_tensor_731, wait_tensor_728, wait_tensor_727, wait_tensor_726, wait_tensor_724, wait_tensor_721, wait_tensor_720, wait_tensor_719, wait_tensor_718, wait_tensor_716, wait_tensor_713, wait_tensor_712, wait_tensor_711, wait_tensor_709, wait_tensor_706, wait_tensor_705, wait_tensor_704, wait_tensor_703, wait_tensor_701, wait_tensor_698, wait_tensor_697, wait_tensor_696, wait_tensor_694, wait_tensor_691, wait_tensor_690, wait_tensor_689, wait_tensor_688, wait_tensor_686, wait_tensor_683, wait_tensor_682, wait_tensor_681, wait_tensor_679, wait_tensor_676, wait_tensor_675, wait_tensor_674, wait_tensor_673, wait_tensor_671, wait_tensor_668, wait_tensor_667, wait_tensor_666, wait_tensor_664, wait_tensor_661, wait_tensor_660, wait_tensor_659, wait_tensor_658, wait_tensor_656, wait_tensor_653, wait_tensor_652, wait_tensor_651, wait_tensor_649, wait_tensor_646, wait_tensor_645, wait_tensor_644, wait_tensor_643, wait_tensor_641, wait_tensor_638, wait_tensor_637, wait_tensor_636, wait_tensor_634, wait_tensor_631, wait_tensor_630, wait_tensor_629, wait_tensor_628, wait_tensor_626, wait_tensor_623, wait_tensor_622, wait_tensor_621, wait_tensor_619, wait_tensor_616, wait_tensor_615, wait_tensor_614, wait_tensor_613, wait_tensor_611, wait_tensor_608, wait_tensor_607, wait_tensor_606, wait_tensor_604, wait_tensor_601, wait_tensor_600, wait_tensor_599, wait_tensor_598, wait_tensor_596, wait_tensor_593, wait_tensor_592, wait_tensor_591, wait_tensor_589, wait_tensor_586, wait_tensor_585, wait_tensor_584, wait_tensor_583, wait_tensor_581, wait_tensor_578, wait_tensor_577, wait_tensor_576, wait_tensor_574, wait_tensor_571, wait_tensor_570, wait_tensor_569, wait_tensor_568, wait_tensor_566, wait_tensor_563, wait_tensor_562, wait_tensor_561, wait_tensor_559, wait_tensor_556, wait_tensor_555, wait_tensor_554, wait_tensor_553, wait_tensor_551, wait_tensor_548, wait_tensor_547, wait_tensor_546, wait_tensor_544, wait_tensor_541, wait_tensor_540, wait_tensor_539, wait_tensor_538, wait_tensor_536, wait_tensor_533, wait_tensor_532, wait_tensor_531, wait_tensor_529, wait_tensor_526, wait_tensor_525, wait_tensor_524, wait_tensor_523, wait_tensor_521, wait_tensor_518, wait_tensor_517, wait_tensor_516, wait_tensor_514, wait_tensor_511, wait_tensor_510, wait_tensor_509, wait_tensor_508, wait_tensor_506, wait_tensor_503, wait_tensor_502, wait_tensor_501, wait_tensor_499, wait_tensor_496, wait_tensor_495, wait_tensor_494, wait_tensor_493, wait_tensor_491, wait_tensor_488, wait_tensor_487, wait_tensor_486, wait_tensor_484, wait_tensor_481, wait_tensor_480, wait_tensor_479, wait_tensor_478, wait_tensor_476, wait_tensor_473, wait_tensor_472, wait_tensor_471, wait_tensor_469, wait_tensor_466, wait_tensor_465, wait_tensor_464, wait_tensor_463, wait_tensor_461, wait_tensor_458, wait_tensor_457, wait_tensor_456, wait_tensor_454, wait_tensor_451, wait_tensor_450, wait_tensor_449, wait_tensor_448, wait_tensor_446, wait_tensor_443, wait_tensor_442, wait_tensor_441, wait_tensor_439, wait_tensor_436, wait_tensor_435, wait_tensor_434, wait_tensor_433, wait_tensor_431, wait_tensor_428, wait_tensor_427, wait_tensor_426, wait_tensor_424, wait_tensor_421) + +def load_args(reader): + buf0 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf0, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_1 + buf1 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf1, (501, 4096), is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf3, (128,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf4, (16, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf5, (4, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf7, (128, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf8, (128,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf9, (56, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf10, (56, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf11, (128, 1792), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf12, (128,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf13, (16, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf14, (4, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf15, (4, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf16, (128, 512), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf17, (128,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf18, (56, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf19, (56, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf20, (128, 1792), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf21, (128,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf23, (4, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf24, (4, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf25, (128, 512), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf26, (128,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf27, (56, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf28, (56, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf29, (128, 1792), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf30, (128,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf31, (16, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf32, (4, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf33, (4, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf34, (128, 512), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf35, (128,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf36, (56, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf37, (56, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf38, (128, 1792), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf39, (128,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf40, (16, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf41, (4, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf42, (4, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf43, (128, 512), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf44, (128,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf45, (56, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf46, (56, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf47, (128, 1792), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf48, (128,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf49, (16, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf50, (4, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf51, (4, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf52, (128, 512), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf53, (128,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf54, (56, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf55, (56, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf56, (128, 1792), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf57, (128,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf58, (16, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf59, (4, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf60, (4, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf61, (128, 512), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf62, (128,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf63, (56, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf64, (56, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf65, (128, 1792), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf66, (128,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf67, (16, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf68, (4, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf69, (4, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf70, (128, 512), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf71, (128,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf72, (56, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf73, (56, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf74, (128, 1792), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf75, (128,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf76, (16, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf77, (4, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf78, (4, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf79, (128, 512), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf80, (128,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf81, (56, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf82, (56, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf83, (128, 1792), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf84, (128,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf85, (16, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf86, (4, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf87, (4, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf88, (128, 512), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf89, (128,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf90, (56, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf91, (56, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf92, (128, 1792), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf93, (128,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf94, (16, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf95, (4, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf96, (4, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf97, (128, 512), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf98, (128,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf99, (56, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf100, (56, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf101, (128, 1792), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf102, (128,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf103, (16, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf104, (4, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf105, (4, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf106, (128, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf107, (128,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf108, (56, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf109, (56, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf110, (128, 1792), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf111, (128,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf112, (16, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf113, (4, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf114, (4, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf115, (128, 512), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf116, (128,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf117, (56, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf118, (56, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf119, (128, 1792), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf120, (128,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf121, (16, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf122, (4, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf123, (4, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf124, (128, 512), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf125, (128,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf126, (56, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf127, (56, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf128, (128, 1792), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf129, (128,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf130, (16, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf131, (4, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf132, (4, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf133, (128, 512), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf134, (128,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf135, (56, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf136, (56, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf137, (128, 1792), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf138, (128,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf140, (4, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf141, (4, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf142, (128, 512), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf143, (128,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf144, (56, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf145, (56, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf146, (128, 1792), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf147, (128,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf148, (16, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf149, (4, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf150, (4, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf151, (128, 512), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf152, (128,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf153, (56, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf154, (56, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf155, (128, 1792), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf156, (128,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf157, (16, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf158, (4, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf159, (4, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf160, (128, 512), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf161, (128,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf162, (56, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf163, (56, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf164, (128, 1792), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf165, (128,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf166, (16, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf167, (4, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf168, (4, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf169, (128, 512), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf170, (128,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf171, (56, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf172, (56, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf173, (128, 1792), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf174, (128,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf175, (16, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf176, (4, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf177, (4, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf178, (128, 512), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf179, (128,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf180, (56, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf181, (56, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf182, (128, 1792), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf183, (128,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf184, (16, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf185, (4, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf186, (4, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf187, (128, 512), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf188, (128,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf189, (56, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf190, (56, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf191, (128, 1792), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf192, (128,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf193, (16, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf194, (4, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf195, (4, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf196, (128, 512), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf197, (128,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf198, (56, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf199, (56, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf200, (128, 1792), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf201, (128,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf202, (16, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf203, (4, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf204, (4, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf205, (128, 512), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf206, (128,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf207, (56, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf208, (56, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf209, (128, 1792), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf210, (128,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf211, (16, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf212, (4, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf213, (4, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf214, (128, 512), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf215, (128,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf216, (56, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf217, (56, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf218, (128, 1792), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf219, (128,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf220, (16, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf221, (4, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf222, (4, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf223, (128, 512), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf224, (128,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf225, (56, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf226, (56, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf227, (128, 1792), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf228, (128,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf229, (16, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf230, (4, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf231, (4, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf232, (128, 512), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf233, (128,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf234, (56, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf235, (56, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf236, (128, 1792), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf237, (128,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf238, (16, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf239, (4, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf240, (4, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf241, (128, 512), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf242, (128,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf243, (56, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf244, (56, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf245, (128, 1792), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf246, (128,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf247, (16, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf248, (4, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf250, (128, 512), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf251, (128,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf252, (56, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf253, (56, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf254, (128, 1792), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf255, (128,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf256, (16, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf257, (4, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf258, (4, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf259, (128, 512), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf260, (128,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf261, (56, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf262, (56, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf263, (128, 1792), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf264, (128,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf265, (16, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf266, (4, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf267, (4, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf268, (128, 512), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf269, (128,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf270, (56, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf271, (56, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf272, (128, 1792), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf273, (128,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf274, (16, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf275, (4, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf276, (4, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf277, (128, 512), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf278, (128,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf279, (56, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf280, (56, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf281, (128, 1792), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf282, (128,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf283, (16, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf284, (4, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf285, (4, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf286, (128, 512), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf287, (128,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf288, (56, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf289, (56, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf290, (128, 1792), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf291, (128,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf292, (501, 4096), is_leaf=True) # primals_293 + buf293 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf293, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # wait_tensor_1 + buf294 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf294, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm + buf295 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf295, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_2 + buf296 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf296, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_80 + buf297 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf297, (2, 4, 8192, 1), is_leaf=True) # getitem_81 + buf298 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf298, (), dtype=torch.int64, is_leaf=True) # getitem_86 + buf299 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf299, (), dtype=torch.int64, is_leaf=True) # getitem_87 + buf300 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf300, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_1 + buf301 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf301, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf302 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf302, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_3 + buf303 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf303, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_7 + buf304 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf304, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_9 + buf305 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf305, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_121 + buf306 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf306, (2, 4, 8192, 1), is_leaf=True) # getitem_122 + buf307 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf307, (), dtype=torch.int64, is_leaf=True) # getitem_127 + buf308 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf308, (), dtype=torch.int64, is_leaf=True) # getitem_128 + buf309 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf309, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_3 + buf310 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf310, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf311 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf311, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_7 + buf312 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf312, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_14 + buf313 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf313, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_16 + buf314 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf314, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_162 + buf315 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf315, (2, 4, 8192, 1), is_leaf=True) # getitem_163 + buf316 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf316, (), dtype=torch.int64, is_leaf=True) # getitem_168 + buf317 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf317, (), dtype=torch.int64, is_leaf=True) # getitem_169 + buf318 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf318, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_5 + buf319 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf319, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_18 + buf320 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf320, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_11 + buf321 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf321, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf322 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf322, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_23 + buf323 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf323, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_203 + buf324 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf324, (2, 4, 8192, 1), is_leaf=True) # getitem_204 + buf325 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf325, (), dtype=torch.int64, is_leaf=True) # getitem_209 + buf326 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf326, (), dtype=torch.int64, is_leaf=True) # getitem_210 + buf327 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf327, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_7 + buf328 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf328, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_25 + buf329 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf329, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_15 + buf330 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf330, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf331 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf331, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_30 + buf332 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf332, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_244 + buf333 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf333, (2, 4, 8192, 1), is_leaf=True) # getitem_245 + buf334 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf334, (), dtype=torch.int64, is_leaf=True) # getitem_250 + buf335 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf335, (), dtype=torch.int64, is_leaf=True) # getitem_251 + buf336 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf336, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_9 + buf337 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf337, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_32 + buf338 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf338, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_19 + buf339 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf339, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf340 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf340, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf341 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf341, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_285 + buf342 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf342, (2, 4, 8192, 1), is_leaf=True) # getitem_286 + buf343 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf343, (), dtype=torch.int64, is_leaf=True) # getitem_291 + buf344 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf344, (), dtype=torch.int64, is_leaf=True) # getitem_292 + buf345 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf345, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_11 + buf346 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf346, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_39 + buf347 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf347, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_23 + buf348 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf348, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_42 + buf349 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf349, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf350 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf350, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_326 + buf351 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf351, (2, 4, 8192, 1), is_leaf=True) # getitem_327 + buf352 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf352, (), dtype=torch.int64, is_leaf=True) # getitem_332 + buf353 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf353, (), dtype=torch.int64, is_leaf=True) # getitem_333 + buf354 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf354, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_13 + buf355 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf355, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_46 + buf356 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf356, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_27 + buf357 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf357, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_49 + buf358 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf358, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf359 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf359, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_367 + buf360 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf360, (2, 4, 8192, 1), is_leaf=True) # getitem_368 + buf361 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf361, (), dtype=torch.int64, is_leaf=True) # getitem_373 + buf362 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf362, (), dtype=torch.int64, is_leaf=True) # getitem_374 + buf363 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf363, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_15 + buf364 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf364, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf365 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf365, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_31 + buf366 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf366, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_56 + buf367 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf367, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_58 + buf368 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf368, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_408 + buf369 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf369, (2, 4, 8192, 1), is_leaf=True) # getitem_409 + buf370 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf370, (), dtype=torch.int64, is_leaf=True) # getitem_414 + buf371 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf371, (), dtype=torch.int64, is_leaf=True) # getitem_415 + buf372 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf372, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_17 + buf373 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf373, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf374 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf374, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_35 + buf375 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf375, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_63 + buf376 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf376, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_65 + buf377 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf377, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_449 + buf378 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf378, (2, 4, 8192, 1), is_leaf=True) # getitem_450 + buf379 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf379, (), dtype=torch.int64, is_leaf=True) # getitem_455 + buf380 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf380, (), dtype=torch.int64, is_leaf=True) # getitem_456 + buf381 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf381, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_19 + buf382 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf382, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf383 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf383, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_39 + buf384 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf384, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_70 + buf385 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf385, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_72 + buf386 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf386, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_490 + buf387 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf387, (2, 4, 8192, 1), is_leaf=True) # getitem_491 + buf388 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf388, (), dtype=torch.int64, is_leaf=True) # getitem_496 + buf389 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf389, (), dtype=torch.int64, is_leaf=True) # getitem_497 + buf390 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf390, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_21 + buf391 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf391, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_74 + buf392 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf392, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_43 + buf393 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf393, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf394 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf394, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_79 + buf395 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf395, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_531 + buf396 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf396, (2, 4, 8192, 1), is_leaf=True) # getitem_532 + buf397 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf397, (), dtype=torch.int64, is_leaf=True) # getitem_537 + buf398 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf398, (), dtype=torch.int64, is_leaf=True) # getitem_538 + buf399 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf399, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_23 + buf400 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf400, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_81 + buf401 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf401, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_47 + buf402 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf402, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf403 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf403, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_86 + buf404 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf404, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_572 + buf405 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf405, (2, 4, 8192, 1), is_leaf=True) # getitem_573 + buf406 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf406, (), dtype=torch.int64, is_leaf=True) # getitem_578 + buf407 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf407, (), dtype=torch.int64, is_leaf=True) # getitem_579 + buf408 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf408, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_25 + buf409 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf409, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_88 + buf410 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf410, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_51 + buf411 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf411, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_91 + buf412 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf412, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf413 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf413, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_613 + buf414 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf414, (2, 4, 8192, 1), is_leaf=True) # getitem_614 + buf415 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf415, (), dtype=torch.int64, is_leaf=True) # getitem_619 + buf416 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf416, (), dtype=torch.int64, is_leaf=True) # getitem_620 + buf417 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf417, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_27 + buf418 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf418, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_95 + buf419 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf419, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_55 + buf420 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf420, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_98 + buf421 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf421, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf422 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf422, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_654 + buf423 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf423, (2, 4, 8192, 1), is_leaf=True) # getitem_655 + buf424 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf424, (), dtype=torch.int64, is_leaf=True) # getitem_660 + buf425 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf425, (), dtype=torch.int64, is_leaf=True) # getitem_661 + buf426 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf426, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_29 + buf427 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf427, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_102 + buf428 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf428, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_59 + buf429 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf429, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_105 + buf430 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf430, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf431 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf431, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_695 + buf432 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf432, (2, 4, 8192, 1), is_leaf=True) # getitem_696 + buf433 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf433, (), dtype=torch.int64, is_leaf=True) # getitem_701 + buf434 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf434, (), dtype=torch.int64, is_leaf=True) # getitem_702 + buf435 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf435, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_31 + buf436 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf436, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf437 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf437, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_63 + buf438 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf438, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_112 + buf439 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf439, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_114 + buf440 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf440, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_736 + buf441 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf441, (2, 4, 8192, 1), is_leaf=True) # getitem_737 + buf442 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf442, (), dtype=torch.int64, is_leaf=True) # getitem_742 + buf443 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf443, (), dtype=torch.int64, is_leaf=True) # getitem_743 + buf444 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf444, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_33 + buf445 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf445, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_116 + buf446 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf446, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_67 + buf447 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf447, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_119 + buf448 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf448, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_121 + buf449 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf449, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_777 + buf450 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf450, (2, 4, 8192, 1), is_leaf=True) # getitem_778 + buf451 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf451, (), dtype=torch.int64, is_leaf=True) # getitem_783 + buf452 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf452, (), dtype=torch.int64, is_leaf=True) # getitem_784 + buf453 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf453, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_35 + buf454 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf454, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_123 + buf455 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf455, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_71 + buf456 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf456, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_126 + buf457 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf457, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_128 + buf458 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf458, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_818 + buf459 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf459, (2, 4, 8192, 1), is_leaf=True) # getitem_819 + buf460 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf460, (), dtype=torch.int64, is_leaf=True) # getitem_824 + buf461 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf461, (), dtype=torch.int64, is_leaf=True) # getitem_825 + buf462 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf462, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_37 + buf463 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf463, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_130 + buf464 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf464, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_75 + buf465 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf465, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_133 + buf466 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf466, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_135 + buf467 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf467, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_859 + buf468 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf468, (2, 4, 8192, 1), is_leaf=True) # getitem_860 + buf469 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf469, (), dtype=torch.int64, is_leaf=True) # getitem_865 + buf470 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf470, (), dtype=torch.int64, is_leaf=True) # getitem_866 + buf471 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf471, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_39 + buf472 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf472, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_137 + buf473 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf473, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_79 + buf474 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf474, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_140 + buf475 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf475, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_142 + buf476 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf476, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_900 + buf477 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf477, (2, 4, 8192, 1), is_leaf=True) # getitem_901 + buf478 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf478, (), dtype=torch.int64, is_leaf=True) # getitem_906 + buf479 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf479, (), dtype=torch.int64, is_leaf=True) # getitem_907 + buf480 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf480, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_41 + buf481 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf481, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_144 + buf482 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf482, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_83 + buf483 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf483, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_147 + buf484 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf484, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_149 + buf485 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf485, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_941 + buf486 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf486, (2, 4, 8192, 1), is_leaf=True) # getitem_942 + buf487 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf487, (), dtype=torch.int64, is_leaf=True) # getitem_947 + buf488 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf488, (), dtype=torch.int64, is_leaf=True) # getitem_948 + buf489 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf489, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_43 + buf490 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf490, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_151 + buf491 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf491, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_87 + buf492 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf492, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_154 + buf493 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf493, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_156 + buf494 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf494, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_982 + buf495 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf495, (2, 4, 8192, 1), is_leaf=True) # getitem_983 + buf496 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf496, (), dtype=torch.int64, is_leaf=True) # getitem_988 + buf497 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf497, (), dtype=torch.int64, is_leaf=True) # getitem_989 + buf498 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf498, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_45 + buf499 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf499, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_158 + buf500 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf500, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_91 + buf501 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf501, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_161 + buf502 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf502, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_163 + buf503 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf503, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1023 + buf504 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf504, (2, 4, 8192, 1), is_leaf=True) # getitem_1024 + buf505 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf505, (), dtype=torch.int64, is_leaf=True) # getitem_1029 + buf506 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf506, (), dtype=torch.int64, is_leaf=True) # getitem_1030 + buf507 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf507, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_47 + buf508 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf508, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_165 + buf509 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf509, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_95 + buf510 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf510, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_168 + buf511 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf511, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_170 + buf512 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf512, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1064 + buf513 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf513, (2, 4, 8192, 1), is_leaf=True) # getitem_1065 + buf514 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf514, (), dtype=torch.int64, is_leaf=True) # getitem_1070 + buf515 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf515, (), dtype=torch.int64, is_leaf=True) # getitem_1071 + buf516 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf516, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_49 + buf517 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf517, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_172 + buf518 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf518, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_99 + buf519 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf519, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_175 + buf520 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf520, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_177 + buf521 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf521, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1105 + buf522 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf522, (2, 4, 8192, 1), is_leaf=True) # getitem_1106 + buf523 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf523, (), dtype=torch.int64, is_leaf=True) # getitem_1111 + buf524 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf524, (), dtype=torch.int64, is_leaf=True) # getitem_1112 + buf525 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf525, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_51 + buf526 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf526, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_179 + buf527 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf527, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_103 + buf528 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf528, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_182 + buf529 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf529, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_184 + buf530 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf530, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1146 + buf531 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf531, (2, 4, 8192, 1), is_leaf=True) # getitem_1147 + buf532 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf532, (), dtype=torch.int64, is_leaf=True) # getitem_1152 + buf533 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf533, (), dtype=torch.int64, is_leaf=True) # getitem_1153 + buf534 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf534, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_53 + buf535 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf535, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_186 + buf536 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf536, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_107 + buf537 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf537, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_189 + buf538 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf538, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_191 + buf539 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf539, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1187 + buf540 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf540, (2, 4, 8192, 1), is_leaf=True) # getitem_1188 + buf541 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf541, (), dtype=torch.int64, is_leaf=True) # getitem_1193 + buf542 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf542, (), dtype=torch.int64, is_leaf=True) # getitem_1194 + buf543 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf543, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_55 + buf544 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf544, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_193 + buf545 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf545, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_111 + buf546 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf546, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_196 + buf547 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf547, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_198 + buf548 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf548, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1228 + buf549 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf549, (2, 4, 8192, 1), is_leaf=True) # getitem_1229 + buf550 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf550, (), dtype=torch.int64, is_leaf=True) # getitem_1234 + buf551 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf551, (), dtype=torch.int64, is_leaf=True) # getitem_1235 + buf552 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf552, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_57 + buf553 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf553, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_200 + buf554 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf554, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_115 + buf555 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf555, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_203 + buf556 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf556, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_205 + buf557 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf557, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1269 + buf558 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf558, (2, 4, 8192, 1), is_leaf=True) # getitem_1270 + buf559 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf559, (), dtype=torch.int64, is_leaf=True) # getitem_1275 + buf560 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf560, (), dtype=torch.int64, is_leaf=True) # getitem_1276 + buf561 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf561, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_59 + buf562 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf562, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_207 + buf563 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf563, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_119 + buf564 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf564, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_210 + buf565 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf565, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_212 + buf566 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf566, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1310 + buf567 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf567, (2, 4, 8192, 1), is_leaf=True) # getitem_1311 + buf568 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf568, (), dtype=torch.int64, is_leaf=True) # getitem_1316 + buf569 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf569, (), dtype=torch.int64, is_leaf=True) # getitem_1317 + buf570 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf570, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_61 + buf571 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf571, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_214 + buf572 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf572, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_123 + buf573 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf573, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_217 + buf574 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf574, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_219 + buf575 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf575, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_1351 + buf576 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf576, (2, 4, 8192, 1), is_leaf=True) # getitem_1352 + buf577 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf577, (), dtype=torch.int64, is_leaf=True) # getitem_1357 + buf578 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf578, (), dtype=torch.int64, is_leaf=True) # getitem_1358 + buf579 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf579, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_63 + buf580 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf580, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_221 + buf581 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf581, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_64 + buf582 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf582, (2, 1024, 1), is_leaf=True) # rsqrt_64 + buf583 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf583, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # view_2319 + buf584 = reader.storage(None, 525336576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf584, (2, 8192, 16032), dtype=torch.bfloat16, is_leaf=True) # tangents_1 +load_args._version = 0 + +def get_pg_config(): + return {'0': {'size': 32, 'rank': 0}, '1': {'size': 8, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls32_8.table" diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_1d.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_1d.py new file mode 100644 index 00000000..828dc735 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_1d.py @@ -0,0 +1,8953 @@ +# fmt: off +# flake8: noqa +# isort: skip_file +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, embedding, mm, mm_2, getitem, getitem_1, getitem_6, getitem_7, mm_4, add_3, mm_7, mm_9, getitem_9, getitem_10, getitem_15, getitem_16, mm_11, add_7, mm_14, mm_16, getitem_18, getitem_19, getitem_24, getitem_25, mm_18, add_11, mm_21, mm_23, getitem_27, getitem_28, getitem_33, getitem_34, mm_25, add_15, mm_28, mm_30, getitem_36, getitem_37, getitem_42, getitem_43, mm_32, add_19, mm_35, mm_37, getitem_45, getitem_46, getitem_51, getitem_52, mm_39, add_23, mm_42, mm_44, getitem_54, getitem_55, getitem_60, getitem_61, mm_46, add_27, mm_49, mm_51, getitem_63, getitem_64, getitem_69, getitem_70, mm_53, add_31, mm_56, mm_58, getitem_72, getitem_73, getitem_78, getitem_79, mm_60, add_35, mm_63, mm_65, getitem_81, getitem_82, getitem_87, getitem_88, mm_67, add_39, mm_70, mm_72, getitem_90, getitem_91, getitem_96, getitem_97, mm_74, add_43, mm_77, mm_79, getitem_99, getitem_100, getitem_105, getitem_106, mm_81, add_47, mm_84, mm_86, getitem_108, getitem_109, getitem_114, getitem_115, mm_88, add_51, mm_91, mm_93, getitem_117, getitem_118, getitem_123, getitem_124, mm_95, add_55, mm_98, mm_100, getitem_126, getitem_127, getitem_132, getitem_133, mm_102, add_59, mm_105, mm_107, getitem_135, getitem_136, getitem_141, getitem_142, mm_109, add_63, mm_112, mm_114, getitem_144, getitem_145, getitem_150, getitem_151, mm_116, add_67, mm_119, mm_121, getitem_153, getitem_154, getitem_159, getitem_160, mm_123, add_71, mm_126, mm_128, getitem_162, getitem_163, getitem_168, getitem_169, mm_130, add_75, mm_133, mm_135, getitem_171, getitem_172, getitem_177, getitem_178, mm_137, add_79, mm_140, mm_142, getitem_180, getitem_181, getitem_186, getitem_187, mm_144, add_83, mm_147, mm_149, getitem_189, getitem_190, getitem_195, getitem_196, mm_151, add_87, mm_154, mm_156, getitem_198, getitem_199, getitem_204, getitem_205, mm_158, add_91, mm_161, mm_163, getitem_207, getitem_208, getitem_213, getitem_214, mm_165, add_95, mm_168, mm_170, getitem_216, getitem_217, getitem_222, getitem_223, mm_172, add_99, mm_175, mm_177, getitem_225, getitem_226, getitem_231, getitem_232, mm_179, add_103, mm_182, mm_184, getitem_234, getitem_235, getitem_240, getitem_241, mm_186, add_107, mm_189, mm_191, getitem_243, getitem_244, getitem_249, getitem_250, mm_193, add_111, mm_196, mm_198, getitem_252, getitem_253, getitem_258, getitem_259, mm_200, add_115, mm_203, mm_205, getitem_261, getitem_262, getitem_267, getitem_268, mm_207, add_119, mm_210, mm_212, getitem_270, getitem_271, getitem_276, getitem_277, mm_214, add_123, mm_217, mm_219, getitem_279, getitem_280, getitem_285, getitem_286, mm_221, mm_223, rsqrt_64, view_1091, tangents_1): + view_1093 = torch.ops.aten.view.default(tangents_1, [16384, 128256]); tangents_1 = None + permute_353 = torch.ops.aten.permute.default(view_1093, [1, 0]) + mm_225 = torch.ops.aten.mm.default(permute_353, view_1091); permute_353 = view_1091 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16); primals_293 = None + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 64, '0'); convert_element_type_1060 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + permute_355 = torch.ops.aten.permute.default(permute_352, [1, 0]); permute_352 = None + mm_226 = torch.ops.aten.mm.default(view_1093, permute_355); view_1093 = permute_355 = None + view_1094 = torch.ops.aten.view.default(mm_226, [2, 8192, 4096]); mm_226 = None + convert_element_type_1067 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1067, 'avg', 64, '0'); convert_element_type_1067 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1068 = torch.ops.prims.convert_element_type.default(view_1094, torch.float32); view_1094 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16); primals_292 = None + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 64, '0'); convert_element_type_1057 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_1070 = torch.ops.prims.convert_element_type.default(wait_tensor_289, torch.float32); wait_tensor_289 = None + mul_258 = torch.ops.aten.mul.Tensor(convert_element_type_1068, convert_element_type_1070); convert_element_type_1070 = None + permute_347 = torch.ops.aten.permute.default(getitem_279, [0, 2, 1, 3]) + view_1075 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 64, '0'); convert_element_type_1040 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1077 = torch.ops.aten.view.default(view_1075, [16384, 4096]); view_1075 = None + mm_220 = torch.ops.aten.mm.default(view_1077, permute_348) + view_1078 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + add_125 = torch.ops.aten.add.Tensor(add_123, view_1078); view_1078 = None + view_1088 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]); mm_223 = None + add_127 = torch.ops.aten.add.Tensor(add_125, view_1088); view_1088 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_260 = torch.ops.aten.mul.Tensor(mul_256, mul_258) + sum_1 = torch.ops.aten.sum.dim_IntList(mul_260, [2], True); mul_260 = None + div = torch.ops.aten.div.Tensor(mul_256, 4096) + mul_261 = torch.ops.aten.mul.Tensor(div, sum_1); div = sum_1 = None + sub = torch.ops.aten.sub.Tensor(mul_258, mul_261); mul_258 = mul_261 = None + mul_262 = torch.ops.aten.mul.Tensor(sub, rsqrt_64); sub = rsqrt_64 = None + mul_263 = torch.ops.aten.mul.Tensor(convert_element_type_1068, mul_256); convert_element_type_1068 = mul_256 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_263, [0, 1]); mul_263 = None + convert_element_type_1071 = torch.ops.prims.convert_element_type.default(mul_262, torch.bfloat16); mul_262 = None + convert_element_type_default_65 = torch.ops.prims.convert_element_type.default(sum_2, torch.float32); sum_2 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_65, 'avg', 64, '0'); convert_element_type_default_65 = None + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + view_1095 = torch.ops.aten.view.default(convert_element_type_1071, [16384, 4096]) + permute_357 = torch.ops.aten.permute.default(view_1095, [1, 0]) + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 64, '0'); convert_element_type_1043 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32); add_125 = None + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_285) + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + view_1081 = torch.ops.aten.view.default(convert_element_type_1045, [16384, 4096]); convert_element_type_1045 = None + view_1082 = torch.ops.aten.view.default(mm_221, [2, 8192, 14336]); mm_221 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_1082, torch.float32); view_1082 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16); primals_290 = None + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 64, '0'); convert_element_type_1051 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_287, [1, 0]); wait_tensor_287 = None + mm_222 = torch.ops.aten.mm.default(view_1081, permute_350) + view_1085 = torch.ops.aten.view.default(mm_222, [2, 8192, 14336]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_1085) + view_1087 = torch.ops.aten.view.default(mul_255, [16384, 14336]); mul_255 = None + mm_227 = torch.ops.aten.mm.default(permute_357, view_1087); permute_357 = view_1087 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16); primals_291 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 64, '0'); convert_element_type_1054 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + permute_359 = torch.ops.aten.permute.default(permute_351, [1, 0]); permute_351 = None + mm_228 = torch.ops.aten.mm.default(view_1095, permute_359); view_1095 = permute_359 = None + view_1096 = torch.ops.aten.view.default(mm_228, [2, 8192, 14336]); mm_228 = None + convert_element_type_1078 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1078, 'avg', 64, '0'); convert_element_type_1078 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None + mul_264 = torch.ops.aten.mul.Tensor(view_1096, convert_element_type_1050); convert_element_type_1050 = None + mul_265 = torch.ops.aten.mul.Tensor(view_1096, view_1085); view_1096 = view_1085 = None + view_1097 = torch.ops.aten.view.default(mul_264, [16384, 14336]); mul_264 = None + permute_361 = torch.ops.aten.permute.default(view_1097, [1, 0]) + mm_229 = torch.ops.aten.mm.default(permute_361, view_1081); permute_361 = None + permute_363 = torch.ops.aten.permute.default(permute_350, [1, 0]); permute_350 = None + mm_230 = torch.ops.aten.mm.default(view_1097, permute_363); view_1097 = permute_363 = None + view_1098 = torch.ops.aten.view.default(mm_230, [2, 8192, 4096]); mm_230 = None + convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1083, 'avg', 64, '0'); convert_element_type_1083 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); reduce_scatter_tensor_3 = None + convert_element_type_1084 = torch.ops.prims.convert_element_type.default(mul_265, torch.float32); mul_265 = None + neg = torch.ops.aten.neg.default(convert_element_type_1049) + exp = torch.ops.aten.exp.default(neg); neg = None + add_129 = torch.ops.aten.add.Tensor(exp, 1); exp = None + reciprocal = torch.ops.aten.reciprocal.default(add_129); add_129 = None + mul_266 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_267 = torch.ops.aten.mul.Tensor(convert_element_type_1084, mul_266); convert_element_type_1084 = None + sub_1 = torch.ops.aten.sub.Tensor(1, mul_266); mul_266 = None + mul_268 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sub_1); convert_element_type_1049 = sub_1 = None + add_130 = torch.ops.aten.add.Tensor(mul_268, 1); mul_268 = None + mul_269 = torch.ops.aten.mul.Tensor(mul_267, add_130); mul_267 = add_130 = None + convert_element_type_1086 = torch.ops.prims.convert_element_type.default(mul_269, torch.bfloat16); mul_269 = None + view_1099 = torch.ops.aten.view.default(convert_element_type_1086, [16384, 14336]); convert_element_type_1086 = None + permute_365 = torch.ops.aten.permute.default(view_1099, [1, 0]) + mm_231 = torch.ops.aten.mm.default(permute_365, view_1081); permute_365 = view_1081 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16); primals_289 = None + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 64, '0'); convert_element_type_1046 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + permute_367 = torch.ops.aten.permute.default(permute_349, [1, 0]); permute_349 = None + mm_232 = torch.ops.aten.mm.default(view_1099, permute_367); view_1099 = permute_367 = None + view_1100 = torch.ops.aten.view.default(mm_232, [2, 8192, 4096]); mm_232 = None + add_131 = torch.ops.aten.add.Tensor(view_1098, view_1100); view_1098 = view_1100 = None + convert_element_type_1091 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None + reduce_scatter_tensor_4 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1091, 'avg', 64, '0'); convert_element_type_1091 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + convert_element_type_1092 = torch.ops.prims.convert_element_type.default(add_131, torch.float32); add_131 = None + convert_element_type_1094 = torch.ops.prims.convert_element_type.default(wait_tensor_285, torch.float32); wait_tensor_285 = None + mul_270 = torch.ops.aten.mul.Tensor(convert_element_type_1092, convert_element_type_1094); convert_element_type_1094 = None + mul_272 = torch.ops.aten.mul.Tensor(mul_252, mul_270) + sum_3 = torch.ops.aten.sum.dim_IntList(mul_272, [2], True); mul_272 = None + div_1 = torch.ops.aten.div.Tensor(mul_252, 4096) + mul_273 = torch.ops.aten.mul.Tensor(div_1, sum_3); div_1 = sum_3 = None + sub_2 = torch.ops.aten.sub.Tensor(mul_270, mul_273); mul_270 = mul_273 = None + mul_274 = torch.ops.aten.mul.Tensor(sub_2, rsqrt_63); sub_2 = rsqrt_63 = None + mul_275 = torch.ops.aten.mul.Tensor(convert_element_type_1092, mul_252); convert_element_type_1092 = mul_252 = None + sum_4 = torch.ops.aten.sum.dim_IntList(mul_275, [0, 1]); mul_275 = None + convert_element_type_1095 = torch.ops.prims.convert_element_type.default(mul_274, torch.bfloat16); mul_274 = None + add_132 = torch.ops.aten.add.Tensor(convert_element_type_1071, convert_element_type_1095); convert_element_type_1071 = convert_element_type_1095 = None + convert_element_type_default_64 = torch.ops.prims.convert_element_type.default(sum_4, torch.float32); sum_4 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_64, 'avg', 64, '0'); convert_element_type_default_64 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + view_1101 = torch.ops.aten.view.default(add_132, [16384, 4096]) + permute_369 = torch.ops.aten.permute.default(view_1101, [1, 0]) + mm_233 = torch.ops.aten.mm.default(permute_369, view_1077); permute_369 = view_1077 = None + permute_371 = torch.ops.aten.permute.default(permute_348, [1, 0]); permute_348 = None + mm_234 = torch.ops.aten.mm.default(view_1101, permute_371); view_1101 = permute_371 = None + view_1102 = torch.ops.aten.view.default(mm_234, [2, 8192, 4096]); mm_234 = None + convert_element_type_1102 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); mm_233 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1102, 'avg', 64, '0'); convert_element_type_1102 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + view_1103 = torch.ops.aten.view.default(view_1102, [2, 8192, 32, 128]); view_1102 = None + permute_373 = torch.ops.aten.permute.default(view_1103, [0, 2, 1, 3]); view_1103 = None + view_16 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]); primals_3 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 64, '0'); convert_element_type_1024 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32); add_123 = None + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_280) + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + view_1057 = torch.ops.aten.view.default(convert_element_type_1026, [16384, 4096]); convert_element_type_1026 = None + view_1058 = torch.ops.aten.view.default(mm_217, [2, 8192, 4096]); mm_217 = None + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 64, '0'); convert_element_type_1030 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + mm_218 = torch.ops.aten.mm.default(view_1057, permute_342) + view_1061 = torch.ops.aten.view.default(mm_218, [2, 8192, 1024]); mm_218 = None + view_1064 = torch.ops.aten.view.default(mm_219, [2, 8192, 1024]); mm_219 = None + view_1065 = torch.ops.aten.view.default(view_1058, [2, 8192, -1, 128]); view_1058 = None + view_1066 = torch.ops.aten.view.default(view_1061, [2, 8192, -1, 128]); view_1061 = None + view_1067 = torch.ops.aten.view.default(view_1064, [2, 8192, -1, 128]); view_1064 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_1065, torch.float32); view_1065 = None + view_1068 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 32, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1068); view_1068 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_1066, torch.float32); view_1066 = None + view_1069 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 8, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1069); view_1069 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_16); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_1071 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 32, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_16); view_as_complex_63 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_1072 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 8, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_1071, torch.bfloat16); view_1071 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_1072, torch.bfloat16); view_1072 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 8, 4, 128]); unsqueeze_62 = None + clone_62 = torch.ops.aten.clone.default(expand_62, memory_format = torch.contiguous_format); expand_62 = None + view_1073 = torch.ops.aten.view.default(clone_62, [2, 8192, 32, 128]); clone_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1067, 3); view_1067 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 8, 4, 128]); unsqueeze_63 = None + clone_63 = torch.ops.aten.clone.default(expand_63, memory_format = torch.contiguous_format); expand_63 = None + view_1074 = torch.ops.aten.view.default(clone_63, [2, 8192, 32, 128]); clone_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_1073, [0, 2, 1, 3]); view_1073 = None + permute_346 = torch.ops.aten.permute.default(view_1074, [0, 2, 1, 3]); view_1074 = None + _scaled_dot_product_cudnn_attention_backward = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_373, permute_344, permute_345, permute_346, getitem_279, getitem_280, getitem_285, getitem_286, None, None, None, 8192, 8192, 0.0, True); permute_373 = permute_344 = permute_345 = permute_346 = getitem_279 = getitem_280 = getitem_285 = getitem_286 = None + getitem_288 = _scaled_dot_product_cudnn_attention_backward[0] + getitem_289 = _scaled_dot_product_cudnn_attention_backward[1] + getitem_290 = _scaled_dot_product_cudnn_attention_backward[2]; _scaled_dot_product_cudnn_attention_backward = None + permute_374 = torch.ops.aten.permute.default(getitem_290, [0, 2, 1, 3]); getitem_290 = None + permute_375 = torch.ops.aten.permute.default(getitem_289, [0, 2, 1, 3]); getitem_289 = None + permute_376 = torch.ops.aten.permute.default(getitem_288, [0, 2, 1, 3]); getitem_288 = None + view_1104 = torch.ops.aten.view.default(permute_374, [2, 8192, 8, 4, 128]); permute_374 = None + sum_5 = torch.ops.aten.sum.dim_IntList(view_1104, [3], True); view_1104 = None + squeeze = torch.ops.aten.squeeze.dim(sum_5, 3); sum_5 = None + view_1105 = torch.ops.aten.view.default(permute_375, [2, 8192, 8, 4, 128]); permute_375 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_1105, [3], True); view_1105 = None + squeeze_1 = torch.ops.aten.squeeze.dim(sum_6, 3); sum_6 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(squeeze_1, torch.float32); squeeze_1 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(permute_376, torch.float32); permute_376 = None + view_1106 = torch.ops.aten.view.default(convert_element_type_1103, [2, 8192, 8, 64, 2]); convert_element_type_1103 = None + view_as_complex_64 = torch.ops.aten.view_as_complex.default(view_1106); view_1106 = None + _conj = torch.ops.aten._conj.default(view_16) + mul_276 = torch.ops.aten.mul.Tensor(view_as_complex_64, _conj); view_as_complex_64 = None + view_1107 = torch.ops.aten.view.default(convert_element_type_1104, [2, 8192, 32, 64, 2]); convert_element_type_1104 = None + view_as_complex_65 = torch.ops.aten.view_as_complex.default(view_1107); view_1107 = None + mul_277 = torch.ops.aten.mul.Tensor(view_as_complex_65, _conj); view_as_complex_65 = None + view_as_real_64 = torch.ops.aten.view_as_real.default(mul_276); mul_276 = None + view_1108 = torch.ops.aten.view.default(view_as_real_64, [2, 8192, 8, 128]); view_as_real_64 = None + convert_element_type_1105 = torch.ops.prims.convert_element_type.default(view_1108, torch.bfloat16); view_1108 = None + view_as_real_65 = torch.ops.aten.view_as_real.default(mul_277); mul_277 = None + view_1109 = torch.ops.aten.view.default(view_as_real_65, [2, 8192, 32, 128]); view_as_real_65 = None + convert_element_type_1106 = torch.ops.prims.convert_element_type.default(view_1109, torch.bfloat16); view_1109 = None + view_1110 = torch.ops.aten.view.default(squeeze, [2, 8192, 1024]); squeeze = None + view_1111 = torch.ops.aten.view.default(convert_element_type_1105, [2, 8192, 1024]); convert_element_type_1105 = None + view_1112 = torch.ops.aten.view.default(convert_element_type_1106, [2, 8192, 4096]); convert_element_type_1106 = None + view_1113 = torch.ops.aten.view.default(view_1110, [16384, 1024]); view_1110 = None + permute_377 = torch.ops.aten.permute.default(view_1113, [1, 0]) + mm_235 = torch.ops.aten.mm.default(permute_377, view_1057); permute_377 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 64, '0'); convert_element_type_1033 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + permute_379 = torch.ops.aten.permute.default(permute_343, [1, 0]); permute_343 = None + mm_236 = torch.ops.aten.mm.default(view_1113, permute_379); view_1113 = permute_379 = None + view_1114 = torch.ops.aten.view.default(mm_236, [2, 8192, 4096]); mm_236 = None + convert_element_type_1111 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1111, 'avg', 64, '0'); convert_element_type_1111 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + view_1115 = torch.ops.aten.view.default(view_1111, [16384, 1024]); view_1111 = None + permute_381 = torch.ops.aten.permute.default(view_1115, [1, 0]) + mm_237 = torch.ops.aten.mm.default(permute_381, view_1057); permute_381 = None + permute_383 = torch.ops.aten.permute.default(permute_342, [1, 0]); permute_342 = None + mm_238 = torch.ops.aten.mm.default(view_1115, permute_383); view_1115 = permute_383 = None + view_1116 = torch.ops.aten.view.default(mm_238, [2, 8192, 4096]); mm_238 = None + add_133 = torch.ops.aten.add.Tensor(view_1114, view_1116); view_1114 = view_1116 = None + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(mm_237, torch.float32); mm_237 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1116, 'avg', 64, '0'); convert_element_type_1116 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + view_1117 = torch.ops.aten.view.default(view_1112, [16384, 4096]); view_1112 = None + permute_385 = torch.ops.aten.permute.default(view_1117, [1, 0]) + mm_239 = torch.ops.aten.mm.default(permute_385, view_1057); permute_385 = view_1057 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 64, '0'); convert_element_type_1027 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + permute_387 = torch.ops.aten.permute.default(permute_341, [1, 0]); permute_341 = None + mm_240 = torch.ops.aten.mm.default(view_1117, permute_387); view_1117 = permute_387 = None + view_1118 = torch.ops.aten.view.default(mm_240, [2, 8192, 4096]); mm_240 = None + add_134 = torch.ops.aten.add.Tensor(add_133, view_1118); add_133 = view_1118 = None + convert_element_type_1121 = torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1121, 'avg', 64, '0'); convert_element_type_1121 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + convert_element_type_1122 = torch.ops.prims.convert_element_type.default(add_134, torch.float32); add_134 = None + convert_element_type_1124 = torch.ops.prims.convert_element_type.default(wait_tensor_280, torch.float32); wait_tensor_280 = None + mul_278 = torch.ops.aten.mul.Tensor(convert_element_type_1122, convert_element_type_1124); convert_element_type_1124 = None + mul_280 = torch.ops.aten.mul.Tensor(mul_248, mul_278) + sum_7 = torch.ops.aten.sum.dim_IntList(mul_280, [2], True); mul_280 = None + div_2 = torch.ops.aten.div.Tensor(mul_248, 4096) + mul_281 = torch.ops.aten.mul.Tensor(div_2, sum_7); div_2 = sum_7 = None + sub_3 = torch.ops.aten.sub.Tensor(mul_278, mul_281); mul_278 = mul_281 = None + mul_282 = torch.ops.aten.mul.Tensor(sub_3, rsqrt_62); sub_3 = rsqrt_62 = None + mul_283 = torch.ops.aten.mul.Tensor(convert_element_type_1122, mul_248); convert_element_type_1122 = mul_248 = None + sum_8 = torch.ops.aten.sum.dim_IntList(mul_283, [0, 1]); mul_283 = None + convert_element_type_1125 = torch.ops.prims.convert_element_type.default(mul_282, torch.bfloat16); mul_282 = None + add_135 = torch.ops.aten.add.Tensor(add_132, convert_element_type_1125); add_132 = convert_element_type_1125 = None + convert_element_type_default_63 = torch.ops.prims.convert_element_type.default(sum_8, torch.float32); sum_8 = None + reduce_scatter_tensor_10 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_63, 'avg', 64, '0'); convert_element_type_default_63 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + view_1119 = torch.ops.aten.view.default(add_135, [16384, 4096]) + permute_389 = torch.ops.aten.permute.default(view_1119, [1, 0]) + permute_336 = torch.ops.aten.permute.default(getitem_270, [0, 2, 1, 3]) + view_1041 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16); primals_278 = None + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 64, '0'); convert_element_type_1007 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_275, [1, 0]); wait_tensor_275 = None + view_1043 = torch.ops.aten.view.default(view_1041, [16384, 4096]); view_1041 = None + mm_213 = torch.ops.aten.mm.default(view_1043, permute_337) + view_1044 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + add_121 = torch.ops.aten.add.Tensor(add_119, view_1044); view_1044 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16); primals_279 = None + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 64, '0'); convert_element_type_1010 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32); add_121 = None + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_276) + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + view_1047 = torch.ops.aten.view.default(convert_element_type_1012, [16384, 4096]); convert_element_type_1012 = None + view_1048 = torch.ops.aten.view.default(mm_214, [2, 8192, 14336]); mm_214 = None + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1048, torch.float32); view_1048 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 64, '0'); convert_element_type_1018 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_215 = torch.ops.aten.mm.default(view_1047, permute_339) + view_1051 = torch.ops.aten.view.default(mm_215, [2, 8192, 14336]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_1051) + view_1053 = torch.ops.aten.view.default(mul_247, [16384, 14336]); mul_247 = None + mm_241 = torch.ops.aten.mm.default(permute_389, view_1053); permute_389 = view_1053 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 64, '0'); convert_element_type_1021 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + permute_391 = torch.ops.aten.permute.default(permute_340, [1, 0]); permute_340 = None + mm_242 = torch.ops.aten.mm.default(view_1119, permute_391); view_1119 = permute_391 = None + view_1120 = torch.ops.aten.view.default(mm_242, [2, 8192, 14336]); mm_242 = None + convert_element_type_1132 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1132, 'avg', 64, '0'); convert_element_type_1132 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + mul_284 = torch.ops.aten.mul.Tensor(view_1120, convert_element_type_1017); convert_element_type_1017 = None + mul_285 = torch.ops.aten.mul.Tensor(view_1120, view_1051); view_1120 = view_1051 = None + view_1121 = torch.ops.aten.view.default(mul_284, [16384, 14336]); mul_284 = None + permute_393 = torch.ops.aten.permute.default(view_1121, [1, 0]) + mm_243 = torch.ops.aten.mm.default(permute_393, view_1047); permute_393 = None + permute_395 = torch.ops.aten.permute.default(permute_339, [1, 0]); permute_339 = None + mm_244 = torch.ops.aten.mm.default(view_1121, permute_395); view_1121 = permute_395 = None + view_1122 = torch.ops.aten.view.default(mm_244, [2, 8192, 4096]); mm_244 = None + convert_element_type_1137 = torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None + reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1137, 'avg', 64, '0'); convert_element_type_1137 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + convert_element_type_1138 = torch.ops.prims.convert_element_type.default(mul_285, torch.float32); mul_285 = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_1016) + exp_1 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_136 = torch.ops.aten.add.Tensor(exp_1, 1); exp_1 = None + reciprocal_1 = torch.ops.aten.reciprocal.default(add_136); add_136 = None + mul_286 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None + mul_287 = torch.ops.aten.mul.Tensor(convert_element_type_1138, mul_286); convert_element_type_1138 = None + sub_4 = torch.ops.aten.sub.Tensor(1, mul_286); mul_286 = None + mul_288 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sub_4); convert_element_type_1016 = sub_4 = None + add_137 = torch.ops.aten.add.Tensor(mul_288, 1); mul_288 = None + mul_289 = torch.ops.aten.mul.Tensor(mul_287, add_137); mul_287 = add_137 = None + convert_element_type_1140 = torch.ops.prims.convert_element_type.default(mul_289, torch.bfloat16); mul_289 = None + view_1123 = torch.ops.aten.view.default(convert_element_type_1140, [16384, 14336]); convert_element_type_1140 = None + permute_397 = torch.ops.aten.permute.default(view_1123, [1, 0]) + mm_245 = torch.ops.aten.mm.default(permute_397, view_1047); permute_397 = view_1047 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 64, '0'); convert_element_type_1013 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + permute_399 = torch.ops.aten.permute.default(permute_338, [1, 0]); permute_338 = None + mm_246 = torch.ops.aten.mm.default(view_1123, permute_399); view_1123 = permute_399 = None + view_1124 = torch.ops.aten.view.default(mm_246, [2, 8192, 4096]); mm_246 = None + add_138 = torch.ops.aten.add.Tensor(view_1122, view_1124); view_1122 = view_1124 = None + convert_element_type_1145 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1145, 'avg', 64, '0'); convert_element_type_1145 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + convert_element_type_1146 = torch.ops.prims.convert_element_type.default(add_138, torch.float32); add_138 = None + convert_element_type_1148 = torch.ops.prims.convert_element_type.default(wait_tensor_276, torch.float32); wait_tensor_276 = None + mul_290 = torch.ops.aten.mul.Tensor(convert_element_type_1146, convert_element_type_1148); convert_element_type_1148 = None + mul_292 = torch.ops.aten.mul.Tensor(mul_244, mul_290) + sum_9 = torch.ops.aten.sum.dim_IntList(mul_292, [2], True); mul_292 = None + div_3 = torch.ops.aten.div.Tensor(mul_244, 4096) + mul_293 = torch.ops.aten.mul.Tensor(div_3, sum_9); div_3 = sum_9 = None + sub_5 = torch.ops.aten.sub.Tensor(mul_290, mul_293); mul_290 = mul_293 = None + mul_294 = torch.ops.aten.mul.Tensor(sub_5, rsqrt_61); sub_5 = rsqrt_61 = None + mul_295 = torch.ops.aten.mul.Tensor(convert_element_type_1146, mul_244); convert_element_type_1146 = mul_244 = None + sum_10 = torch.ops.aten.sum.dim_IntList(mul_295, [0, 1]); mul_295 = None + convert_element_type_1149 = torch.ops.prims.convert_element_type.default(mul_294, torch.bfloat16); mul_294 = None + add_139 = torch.ops.aten.add.Tensor(add_135, convert_element_type_1149); add_135 = convert_element_type_1149 = None + convert_element_type_default_62 = torch.ops.prims.convert_element_type.default(sum_10, torch.float32); sum_10 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_62, 'avg', 64, '0'); convert_element_type_default_62 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + view_1125 = torch.ops.aten.view.default(add_139, [16384, 4096]) + permute_401 = torch.ops.aten.permute.default(view_1125, [1, 0]) + mm_247 = torch.ops.aten.mm.default(permute_401, view_1043); permute_401 = view_1043 = None + permute_403 = torch.ops.aten.permute.default(permute_337, [1, 0]); permute_337 = None + mm_248 = torch.ops.aten.mm.default(view_1125, permute_403); view_1125 = permute_403 = None + view_1126 = torch.ops.aten.view.default(mm_248, [2, 8192, 4096]); mm_248 = None + convert_element_type_1156 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1156, 'avg', 64, '0'); convert_element_type_1156 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + view_1127 = torch.ops.aten.view.default(view_1126, [2, 8192, 32, 128]); view_1126 = None + permute_405 = torch.ops.aten.permute.default(view_1127, [0, 2, 1, 3]); view_1127 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16); primals_274 = None + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 64, '0'); convert_element_type_991 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32); add_119 = None + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_271) + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + view_1023 = torch.ops.aten.view.default(convert_element_type_993, [16384, 4096]); convert_element_type_993 = None + view_1024 = torch.ops.aten.view.default(mm_210, [2, 8192, 4096]); mm_210 = None + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16); primals_276 = None + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 64, '0'); convert_element_type_997 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + mm_211 = torch.ops.aten.mm.default(view_1023, permute_331) + view_1027 = torch.ops.aten.view.default(mm_211, [2, 8192, 1024]); mm_211 = None + view_1030 = torch.ops.aten.view.default(mm_212, [2, 8192, 1024]); mm_212 = None + view_1031 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1032 = torch.ops.aten.view.default(view_1027, [2, 8192, -1, 128]); view_1027 = None + view_1033 = torch.ops.aten.view.default(view_1030, [2, 8192, -1, 128]); view_1030 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_1031, torch.float32); view_1031 = None + view_1034 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 32, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1034); view_1034 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_1032, torch.float32); view_1032 = None + view_1035 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 8, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1035); view_1035 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_16); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_1037 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 32, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_16); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_1038 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 8, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_1037, torch.bfloat16); view_1037 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_1038, torch.bfloat16); view_1038 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 8, 4, 128]); unsqueeze_60 = None + clone_60 = torch.ops.aten.clone.default(expand_60, memory_format = torch.contiguous_format); expand_60 = None + view_1039 = torch.ops.aten.view.default(clone_60, [2, 8192, 32, 128]); clone_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1033, 3); view_1033 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 8, 4, 128]); unsqueeze_61 = None + clone_61 = torch.ops.aten.clone.default(expand_61, memory_format = torch.contiguous_format); expand_61 = None + view_1040 = torch.ops.aten.view.default(clone_61, [2, 8192, 32, 128]); clone_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_1039, [0, 2, 1, 3]); view_1039 = None + permute_335 = torch.ops.aten.permute.default(view_1040, [0, 2, 1, 3]); view_1040 = None + _scaled_dot_product_cudnn_attention_backward_1 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_405, permute_333, permute_334, permute_335, getitem_270, getitem_271, getitem_276, getitem_277, None, None, None, 8192, 8192, 0.0, True); permute_405 = permute_333 = permute_334 = permute_335 = getitem_270 = getitem_271 = getitem_276 = getitem_277 = None + getitem_291 = _scaled_dot_product_cudnn_attention_backward_1[0] + getitem_292 = _scaled_dot_product_cudnn_attention_backward_1[1] + getitem_293 = _scaled_dot_product_cudnn_attention_backward_1[2]; _scaled_dot_product_cudnn_attention_backward_1 = None + permute_406 = torch.ops.aten.permute.default(getitem_293, [0, 2, 1, 3]); getitem_293 = None + permute_407 = torch.ops.aten.permute.default(getitem_292, [0, 2, 1, 3]); getitem_292 = None + permute_408 = torch.ops.aten.permute.default(getitem_291, [0, 2, 1, 3]); getitem_291 = None + view_1128 = torch.ops.aten.view.default(permute_406, [2, 8192, 8, 4, 128]); permute_406 = None + sum_11 = torch.ops.aten.sum.dim_IntList(view_1128, [3], True); view_1128 = None + squeeze_2 = torch.ops.aten.squeeze.dim(sum_11, 3); sum_11 = None + view_1129 = torch.ops.aten.view.default(permute_407, [2, 8192, 8, 4, 128]); permute_407 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_1129, [3], True); view_1129 = None + squeeze_3 = torch.ops.aten.squeeze.dim(sum_12, 3); sum_12 = None + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(squeeze_3, torch.float32); squeeze_3 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(permute_408, torch.float32); permute_408 = None + view_1130 = torch.ops.aten.view.default(convert_element_type_1157, [2, 8192, 8, 64, 2]); convert_element_type_1157 = None + view_as_complex_66 = torch.ops.aten.view_as_complex.default(view_1130); view_1130 = None + mul_296 = torch.ops.aten.mul.Tensor(view_as_complex_66, _conj); view_as_complex_66 = None + view_1131 = torch.ops.aten.view.default(convert_element_type_1158, [2, 8192, 32, 64, 2]); convert_element_type_1158 = None + view_as_complex_67 = torch.ops.aten.view_as_complex.default(view_1131); view_1131 = None + mul_297 = torch.ops.aten.mul.Tensor(view_as_complex_67, _conj); view_as_complex_67 = None + view_as_real_66 = torch.ops.aten.view_as_real.default(mul_296); mul_296 = None + view_1132 = torch.ops.aten.view.default(view_as_real_66, [2, 8192, 8, 128]); view_as_real_66 = None + convert_element_type_1159 = torch.ops.prims.convert_element_type.default(view_1132, torch.bfloat16); view_1132 = None + view_as_real_67 = torch.ops.aten.view_as_real.default(mul_297); mul_297 = None + view_1133 = torch.ops.aten.view.default(view_as_real_67, [2, 8192, 32, 128]); view_as_real_67 = None + convert_element_type_1160 = torch.ops.prims.convert_element_type.default(view_1133, torch.bfloat16); view_1133 = None + view_1134 = torch.ops.aten.view.default(squeeze_2, [2, 8192, 1024]); squeeze_2 = None + view_1135 = torch.ops.aten.view.default(convert_element_type_1159, [2, 8192, 1024]); convert_element_type_1159 = None + view_1136 = torch.ops.aten.view.default(convert_element_type_1160, [2, 8192, 4096]); convert_element_type_1160 = None + view_1137 = torch.ops.aten.view.default(view_1134, [16384, 1024]); view_1134 = None + permute_409 = torch.ops.aten.permute.default(view_1137, [1, 0]) + mm_249 = torch.ops.aten.mm.default(permute_409, view_1023); permute_409 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16); primals_277 = None + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 64, '0'); convert_element_type_1000 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_274, [1, 0]); wait_tensor_274 = None + permute_411 = torch.ops.aten.permute.default(permute_332, [1, 0]); permute_332 = None + mm_250 = torch.ops.aten.mm.default(view_1137, permute_411); view_1137 = permute_411 = None + view_1138 = torch.ops.aten.view.default(mm_250, [2, 8192, 4096]); mm_250 = None + convert_element_type_1165 = torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1165, 'avg', 64, '0'); convert_element_type_1165 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + view_1139 = torch.ops.aten.view.default(view_1135, [16384, 1024]); view_1135 = None + permute_413 = torch.ops.aten.permute.default(view_1139, [1, 0]) + mm_251 = torch.ops.aten.mm.default(permute_413, view_1023); permute_413 = None + permute_415 = torch.ops.aten.permute.default(permute_331, [1, 0]); permute_331 = None + mm_252 = torch.ops.aten.mm.default(view_1139, permute_415); view_1139 = permute_415 = None + view_1140 = torch.ops.aten.view.default(mm_252, [2, 8192, 4096]); mm_252 = None + add_140 = torch.ops.aten.add.Tensor(view_1138, view_1140); view_1138 = view_1140 = None + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None + reduce_scatter_tensor_17 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1170, 'avg', 64, '0'); convert_element_type_1170 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + view_1141 = torch.ops.aten.view.default(view_1136, [16384, 4096]); view_1136 = None + permute_417 = torch.ops.aten.permute.default(view_1141, [1, 0]) + mm_253 = torch.ops.aten.mm.default(permute_417, view_1023); permute_417 = view_1023 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16); primals_275 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 64, '0'); convert_element_type_994 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + permute_419 = torch.ops.aten.permute.default(permute_330, [1, 0]); permute_330 = None + mm_254 = torch.ops.aten.mm.default(view_1141, permute_419); view_1141 = permute_419 = None + view_1142 = torch.ops.aten.view.default(mm_254, [2, 8192, 4096]); mm_254 = None + add_141 = torch.ops.aten.add.Tensor(add_140, view_1142); add_140 = view_1142 = None + convert_element_type_1175 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1175, 'avg', 64, '0'); convert_element_type_1175 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + convert_element_type_1176 = torch.ops.prims.convert_element_type.default(add_141, torch.float32); add_141 = None + convert_element_type_1178 = torch.ops.prims.convert_element_type.default(wait_tensor_271, torch.float32); wait_tensor_271 = None + mul_298 = torch.ops.aten.mul.Tensor(convert_element_type_1176, convert_element_type_1178); convert_element_type_1178 = None + mul_300 = torch.ops.aten.mul.Tensor(mul_240, mul_298) + sum_13 = torch.ops.aten.sum.dim_IntList(mul_300, [2], True); mul_300 = None + div_4 = torch.ops.aten.div.Tensor(mul_240, 4096) + mul_301 = torch.ops.aten.mul.Tensor(div_4, sum_13); div_4 = sum_13 = None + sub_6 = torch.ops.aten.sub.Tensor(mul_298, mul_301); mul_298 = mul_301 = None + mul_302 = torch.ops.aten.mul.Tensor(sub_6, rsqrt_60); sub_6 = rsqrt_60 = None + mul_303 = torch.ops.aten.mul.Tensor(convert_element_type_1176, mul_240); convert_element_type_1176 = mul_240 = None + sum_14 = torch.ops.aten.sum.dim_IntList(mul_303, [0, 1]); mul_303 = None + convert_element_type_1179 = torch.ops.prims.convert_element_type.default(mul_302, torch.bfloat16); mul_302 = None + add_142 = torch.ops.aten.add.Tensor(add_139, convert_element_type_1179); add_139 = convert_element_type_1179 = None + convert_element_type_default_61 = torch.ops.prims.convert_element_type.default(sum_14, torch.float32); sum_14 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_61, 'avg', 64, '0'); convert_element_type_default_61 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + view_1143 = torch.ops.aten.view.default(add_142, [16384, 4096]) + permute_421 = torch.ops.aten.permute.default(view_1143, [1, 0]) + permute_325 = torch.ops.aten.permute.default(getitem_261, [0, 2, 1, 3]) + view_1007 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 64, '0'); convert_element_type_974 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + view_1009 = torch.ops.aten.view.default(view_1007, [16384, 4096]); view_1007 = None + mm_206 = torch.ops.aten.mm.default(view_1009, permute_326) + view_1010 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + add_117 = torch.ops.aten.add.Tensor(add_115, view_1010); view_1010 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16); primals_270 = None + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 64, '0'); convert_element_type_977 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32); add_117 = None + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_267) + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + view_1013 = torch.ops.aten.view.default(convert_element_type_979, [16384, 4096]); convert_element_type_979 = None + view_1014 = torch.ops.aten.view.default(mm_207, [2, 8192, 14336]); mm_207 = None + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_1014, torch.float32); view_1014 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16); primals_272 = None + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 64, '0'); convert_element_type_985 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_269, [1, 0]); wait_tensor_269 = None + mm_208 = torch.ops.aten.mm.default(view_1013, permute_328) + view_1017 = torch.ops.aten.view.default(mm_208, [2, 8192, 14336]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_1017) + view_1019 = torch.ops.aten.view.default(mul_239, [16384, 14336]); mul_239 = None + mm_255 = torch.ops.aten.mm.default(permute_421, view_1019); permute_421 = view_1019 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16); primals_273 = None + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 64, '0'); convert_element_type_988 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + permute_423 = torch.ops.aten.permute.default(permute_329, [1, 0]); permute_329 = None + mm_256 = torch.ops.aten.mm.default(view_1143, permute_423); view_1143 = permute_423 = None + view_1144 = torch.ops.aten.view.default(mm_256, [2, 8192, 14336]); mm_256 = None + convert_element_type_1186 = torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1186, 'avg', 64, '0'); convert_element_type_1186 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + mul_304 = torch.ops.aten.mul.Tensor(view_1144, convert_element_type_984); convert_element_type_984 = None + mul_305 = torch.ops.aten.mul.Tensor(view_1144, view_1017); view_1144 = view_1017 = None + view_1145 = torch.ops.aten.view.default(mul_304, [16384, 14336]); mul_304 = None + permute_425 = torch.ops.aten.permute.default(view_1145, [1, 0]) + mm_257 = torch.ops.aten.mm.default(permute_425, view_1013); permute_425 = None + permute_427 = torch.ops.aten.permute.default(permute_328, [1, 0]); permute_328 = None + mm_258 = torch.ops.aten.mm.default(view_1145, permute_427); view_1145 = permute_427 = None + view_1146 = torch.ops.aten.view.default(mm_258, [2, 8192, 4096]); mm_258 = None + convert_element_type_1191 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1191, 'avg', 64, '0'); convert_element_type_1191 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + convert_element_type_1192 = torch.ops.prims.convert_element_type.default(mul_305, torch.float32); mul_305 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_983) + exp_2 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_143 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_143); add_143 = None + mul_306 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_1192, mul_306); convert_element_type_1192 = None + sub_7 = torch.ops.aten.sub.Tensor(1, mul_306); mul_306 = None + mul_308 = torch.ops.aten.mul.Tensor(convert_element_type_983, sub_7); convert_element_type_983 = sub_7 = None + add_144 = torch.ops.aten.add.Tensor(mul_308, 1); mul_308 = None + mul_309 = torch.ops.aten.mul.Tensor(mul_307, add_144); mul_307 = add_144 = None + convert_element_type_1194 = torch.ops.prims.convert_element_type.default(mul_309, torch.bfloat16); mul_309 = None + view_1147 = torch.ops.aten.view.default(convert_element_type_1194, [16384, 14336]); convert_element_type_1194 = None + permute_429 = torch.ops.aten.permute.default(view_1147, [1, 0]) + mm_259 = torch.ops.aten.mm.default(permute_429, view_1013); permute_429 = view_1013 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16); primals_271 = None + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 64, '0'); convert_element_type_980 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + permute_431 = torch.ops.aten.permute.default(permute_327, [1, 0]); permute_327 = None + mm_260 = torch.ops.aten.mm.default(view_1147, permute_431); view_1147 = permute_431 = None + view_1148 = torch.ops.aten.view.default(mm_260, [2, 8192, 4096]); mm_260 = None + add_145 = torch.ops.aten.add.Tensor(view_1146, view_1148); view_1146 = view_1148 = None + convert_element_type_1199 = torch.ops.prims.convert_element_type.default(mm_259, torch.float32); mm_259 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1199, 'avg', 64, '0'); convert_element_type_1199 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); reduce_scatter_tensor_22 = None + convert_element_type_1200 = torch.ops.prims.convert_element_type.default(add_145, torch.float32); add_145 = None + convert_element_type_1202 = torch.ops.prims.convert_element_type.default(wait_tensor_267, torch.float32); wait_tensor_267 = None + mul_310 = torch.ops.aten.mul.Tensor(convert_element_type_1200, convert_element_type_1202); convert_element_type_1202 = None + mul_312 = torch.ops.aten.mul.Tensor(mul_236, mul_310) + sum_15 = torch.ops.aten.sum.dim_IntList(mul_312, [2], True); mul_312 = None + div_5 = torch.ops.aten.div.Tensor(mul_236, 4096) + mul_313 = torch.ops.aten.mul.Tensor(div_5, sum_15); div_5 = sum_15 = None + sub_8 = torch.ops.aten.sub.Tensor(mul_310, mul_313); mul_310 = mul_313 = None + mul_314 = torch.ops.aten.mul.Tensor(sub_8, rsqrt_59); sub_8 = rsqrt_59 = None + mul_315 = torch.ops.aten.mul.Tensor(convert_element_type_1200, mul_236); convert_element_type_1200 = mul_236 = None + sum_16 = torch.ops.aten.sum.dim_IntList(mul_315, [0, 1]); mul_315 = None + convert_element_type_1203 = torch.ops.prims.convert_element_type.default(mul_314, torch.bfloat16); mul_314 = None + add_146 = torch.ops.aten.add.Tensor(add_142, convert_element_type_1203); add_142 = convert_element_type_1203 = None + convert_element_type_default_60 = torch.ops.prims.convert_element_type.default(sum_16, torch.float32); sum_16 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_60, 'avg', 64, '0'); convert_element_type_default_60 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = None + view_1149 = torch.ops.aten.view.default(add_146, [16384, 4096]) + permute_433 = torch.ops.aten.permute.default(view_1149, [1, 0]) + mm_261 = torch.ops.aten.mm.default(permute_433, view_1009); permute_433 = view_1009 = None + permute_435 = torch.ops.aten.permute.default(permute_326, [1, 0]); permute_326 = None + mm_262 = torch.ops.aten.mm.default(view_1149, permute_435); view_1149 = permute_435 = None + view_1150 = torch.ops.aten.view.default(mm_262, [2, 8192, 4096]); mm_262 = None + convert_element_type_1210 = torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1210, 'avg', 64, '0'); convert_element_type_1210 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + view_1151 = torch.ops.aten.view.default(view_1150, [2, 8192, 32, 128]); view_1150 = None + permute_437 = torch.ops.aten.permute.default(view_1151, [0, 2, 1, 3]); view_1151 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 64, '0'); convert_element_type_958 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32); add_115 = None + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_262) + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + view_989 = torch.ops.aten.view.default(convert_element_type_960, [16384, 4096]); convert_element_type_960 = None + view_990 = torch.ops.aten.view.default(mm_203, [2, 8192, 4096]); mm_203 = None + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 64, '0'); convert_element_type_964 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + mm_204 = torch.ops.aten.mm.default(view_989, permute_320) + view_993 = torch.ops.aten.view.default(mm_204, [2, 8192, 1024]); mm_204 = None + view_996 = torch.ops.aten.view.default(mm_205, [2, 8192, 1024]); mm_205 = None + view_997 = torch.ops.aten.view.default(view_990, [2, 8192, -1, 128]); view_990 = None + view_998 = torch.ops.aten.view.default(view_993, [2, 8192, -1, 128]); view_993 = None + view_999 = torch.ops.aten.view.default(view_996, [2, 8192, -1, 128]); view_996 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + view_1000 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 32, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1000); view_1000 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_998, torch.float32); view_998 = None + view_1001 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 8, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1001); view_1001 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_16); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_1003 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 32, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_16); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_1004 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 8, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_1003, torch.bfloat16); view_1003 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_1004, torch.bfloat16); view_1004 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 8, 4, 128]); unsqueeze_58 = None + clone_58 = torch.ops.aten.clone.default(expand_58, memory_format = torch.contiguous_format); expand_58 = None + view_1005 = torch.ops.aten.view.default(clone_58, [2, 8192, 32, 128]); clone_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_999, 3); view_999 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 8, 4, 128]); unsqueeze_59 = None + clone_59 = torch.ops.aten.clone.default(expand_59, memory_format = torch.contiguous_format); expand_59 = None + view_1006 = torch.ops.aten.view.default(clone_59, [2, 8192, 32, 128]); clone_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_1005, [0, 2, 1, 3]); view_1005 = None + permute_324 = torch.ops.aten.permute.default(view_1006, [0, 2, 1, 3]); view_1006 = None + _scaled_dot_product_cudnn_attention_backward_2 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_437, permute_322, permute_323, permute_324, getitem_261, getitem_262, getitem_267, getitem_268, None, None, None, 8192, 8192, 0.0, True); permute_437 = permute_322 = permute_323 = permute_324 = getitem_261 = getitem_262 = getitem_267 = getitem_268 = None + getitem_294 = _scaled_dot_product_cudnn_attention_backward_2[0] + getitem_295 = _scaled_dot_product_cudnn_attention_backward_2[1] + getitem_296 = _scaled_dot_product_cudnn_attention_backward_2[2]; _scaled_dot_product_cudnn_attention_backward_2 = None + permute_438 = torch.ops.aten.permute.default(getitem_296, [0, 2, 1, 3]); getitem_296 = None + permute_439 = torch.ops.aten.permute.default(getitem_295, [0, 2, 1, 3]); getitem_295 = None + permute_440 = torch.ops.aten.permute.default(getitem_294, [0, 2, 1, 3]); getitem_294 = None + view_1152 = torch.ops.aten.view.default(permute_438, [2, 8192, 8, 4, 128]); permute_438 = None + sum_17 = torch.ops.aten.sum.dim_IntList(view_1152, [3], True); view_1152 = None + squeeze_4 = torch.ops.aten.squeeze.dim(sum_17, 3); sum_17 = None + view_1153 = torch.ops.aten.view.default(permute_439, [2, 8192, 8, 4, 128]); permute_439 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_1153, [3], True); view_1153 = None + squeeze_5 = torch.ops.aten.squeeze.dim(sum_18, 3); sum_18 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(squeeze_5, torch.float32); squeeze_5 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(permute_440, torch.float32); permute_440 = None + view_1154 = torch.ops.aten.view.default(convert_element_type_1211, [2, 8192, 8, 64, 2]); convert_element_type_1211 = None + view_as_complex_68 = torch.ops.aten.view_as_complex.default(view_1154); view_1154 = None + mul_316 = torch.ops.aten.mul.Tensor(view_as_complex_68, _conj); view_as_complex_68 = None + view_1155 = torch.ops.aten.view.default(convert_element_type_1212, [2, 8192, 32, 64, 2]); convert_element_type_1212 = None + view_as_complex_69 = torch.ops.aten.view_as_complex.default(view_1155); view_1155 = None + mul_317 = torch.ops.aten.mul.Tensor(view_as_complex_69, _conj); view_as_complex_69 = None + view_as_real_68 = torch.ops.aten.view_as_real.default(mul_316); mul_316 = None + view_1156 = torch.ops.aten.view.default(view_as_real_68, [2, 8192, 8, 128]); view_as_real_68 = None + convert_element_type_1213 = torch.ops.prims.convert_element_type.default(view_1156, torch.bfloat16); view_1156 = None + view_as_real_69 = torch.ops.aten.view_as_real.default(mul_317); mul_317 = None + view_1157 = torch.ops.aten.view.default(view_as_real_69, [2, 8192, 32, 128]); view_as_real_69 = None + convert_element_type_1214 = torch.ops.prims.convert_element_type.default(view_1157, torch.bfloat16); view_1157 = None + view_1158 = torch.ops.aten.view.default(squeeze_4, [2, 8192, 1024]); squeeze_4 = None + view_1159 = torch.ops.aten.view.default(convert_element_type_1213, [2, 8192, 1024]); convert_element_type_1213 = None + view_1160 = torch.ops.aten.view.default(convert_element_type_1214, [2, 8192, 4096]); convert_element_type_1214 = None + view_1161 = torch.ops.aten.view.default(view_1158, [16384, 1024]); view_1158 = None + permute_441 = torch.ops.aten.permute.default(view_1161, [1, 0]) + mm_263 = torch.ops.aten.mm.default(permute_441, view_989); permute_441 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 64, '0'); convert_element_type_967 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + permute_443 = torch.ops.aten.permute.default(permute_321, [1, 0]); permute_321 = None + mm_264 = torch.ops.aten.mm.default(view_1161, permute_443); view_1161 = permute_443 = None + view_1162 = torch.ops.aten.view.default(mm_264, [2, 8192, 4096]); mm_264 = None + convert_element_type_1219 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None + reduce_scatter_tensor_25 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1219, 'avg', 64, '0'); convert_element_type_1219 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + view_1163 = torch.ops.aten.view.default(view_1159, [16384, 1024]); view_1159 = None + permute_445 = torch.ops.aten.permute.default(view_1163, [1, 0]) + mm_265 = torch.ops.aten.mm.default(permute_445, view_989); permute_445 = None + permute_447 = torch.ops.aten.permute.default(permute_320, [1, 0]); permute_320 = None + mm_266 = torch.ops.aten.mm.default(view_1163, permute_447); view_1163 = permute_447 = None + view_1164 = torch.ops.aten.view.default(mm_266, [2, 8192, 4096]); mm_266 = None + add_147 = torch.ops.aten.add.Tensor(view_1162, view_1164); view_1162 = view_1164 = None + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1224, 'avg', 64, '0'); convert_element_type_1224 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + view_1165 = torch.ops.aten.view.default(view_1160, [16384, 4096]); view_1160 = None + permute_449 = torch.ops.aten.permute.default(view_1165, [1, 0]) + mm_267 = torch.ops.aten.mm.default(permute_449, view_989); permute_449 = view_989 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 64, '0'); convert_element_type_961 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_263, [1, 0]); wait_tensor_263 = None + permute_451 = torch.ops.aten.permute.default(permute_319, [1, 0]); permute_319 = None + mm_268 = torch.ops.aten.mm.default(view_1165, permute_451); view_1165 = permute_451 = None + view_1166 = torch.ops.aten.view.default(mm_268, [2, 8192, 4096]); mm_268 = None + add_148 = torch.ops.aten.add.Tensor(add_147, view_1166); add_147 = view_1166 = None + convert_element_type_1229 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1229, 'avg', 64, '0'); convert_element_type_1229 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + convert_element_type_1230 = torch.ops.prims.convert_element_type.default(add_148, torch.float32); add_148 = None + convert_element_type_1232 = torch.ops.prims.convert_element_type.default(wait_tensor_262, torch.float32); wait_tensor_262 = None + mul_318 = torch.ops.aten.mul.Tensor(convert_element_type_1230, convert_element_type_1232); convert_element_type_1232 = None + mul_320 = torch.ops.aten.mul.Tensor(mul_232, mul_318) + sum_19 = torch.ops.aten.sum.dim_IntList(mul_320, [2], True); mul_320 = None + div_6 = torch.ops.aten.div.Tensor(mul_232, 4096) + mul_321 = torch.ops.aten.mul.Tensor(div_6, sum_19); div_6 = sum_19 = None + sub_9 = torch.ops.aten.sub.Tensor(mul_318, mul_321); mul_318 = mul_321 = None + mul_322 = torch.ops.aten.mul.Tensor(sub_9, rsqrt_58); sub_9 = rsqrt_58 = None + mul_323 = torch.ops.aten.mul.Tensor(convert_element_type_1230, mul_232); convert_element_type_1230 = mul_232 = None + sum_20 = torch.ops.aten.sum.dim_IntList(mul_323, [0, 1]); mul_323 = None + convert_element_type_1233 = torch.ops.prims.convert_element_type.default(mul_322, torch.bfloat16); mul_322 = None + add_149 = torch.ops.aten.add.Tensor(add_146, convert_element_type_1233); add_146 = convert_element_type_1233 = None + convert_element_type_default_59 = torch.ops.prims.convert_element_type.default(sum_20, torch.float32); sum_20 = None + reduce_scatter_tensor_28 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_59, 'avg', 64, '0'); convert_element_type_default_59 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + view_1167 = torch.ops.aten.view.default(add_149, [16384, 4096]) + permute_453 = torch.ops.aten.permute.default(view_1167, [1, 0]) + permute_314 = torch.ops.aten.permute.default(getitem_252, [0, 2, 1, 3]) + view_973 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16); primals_260 = None + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 64, '0'); convert_element_type_941 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_257, [1, 0]); wait_tensor_257 = None + view_975 = torch.ops.aten.view.default(view_973, [16384, 4096]); view_973 = None + mm_199 = torch.ops.aten.mm.default(view_975, permute_315) + view_976 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + add_113 = torch.ops.aten.add.Tensor(add_111, view_976); view_976 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16); primals_261 = None + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 64, '0'); convert_element_type_944 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32); add_113 = None + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_258) + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + view_979 = torch.ops.aten.view.default(convert_element_type_946, [16384, 4096]); convert_element_type_946 = None + view_980 = torch.ops.aten.view.default(mm_200, [2, 8192, 14336]); mm_200 = None + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_980, torch.float32); view_980 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 64, '0'); convert_element_type_952 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_201 = torch.ops.aten.mm.default(view_979, permute_317) + view_983 = torch.ops.aten.view.default(mm_201, [2, 8192, 14336]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_983) + view_985 = torch.ops.aten.view.default(mul_231, [16384, 14336]); mul_231 = None + mm_269 = torch.ops.aten.mm.default(permute_453, view_985); permute_453 = view_985 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 64, '0'); convert_element_type_955 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + permute_455 = torch.ops.aten.permute.default(permute_318, [1, 0]); permute_318 = None + mm_270 = torch.ops.aten.mm.default(view_1167, permute_455); view_1167 = permute_455 = None + view_1168 = torch.ops.aten.view.default(mm_270, [2, 8192, 14336]); mm_270 = None + convert_element_type_1240 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1240, 'avg', 64, '0'); convert_element_type_1240 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + mul_324 = torch.ops.aten.mul.Tensor(view_1168, convert_element_type_951); convert_element_type_951 = None + mul_325 = torch.ops.aten.mul.Tensor(view_1168, view_983); view_1168 = view_983 = None + view_1169 = torch.ops.aten.view.default(mul_324, [16384, 14336]); mul_324 = None + permute_457 = torch.ops.aten.permute.default(view_1169, [1, 0]) + mm_271 = torch.ops.aten.mm.default(permute_457, view_979); permute_457 = None + permute_459 = torch.ops.aten.permute.default(permute_317, [1, 0]); permute_317 = None + mm_272 = torch.ops.aten.mm.default(view_1169, permute_459); view_1169 = permute_459 = None + view_1170 = torch.ops.aten.view.default(mm_272, [2, 8192, 4096]); mm_272 = None + convert_element_type_1245 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None + reduce_scatter_tensor_30 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1245, 'avg', 64, '0'); convert_element_type_1245 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + convert_element_type_1246 = torch.ops.prims.convert_element_type.default(mul_325, torch.float32); mul_325 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_950) + exp_3 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_150 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_150); add_150 = None + mul_326 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_327 = torch.ops.aten.mul.Tensor(convert_element_type_1246, mul_326); convert_element_type_1246 = None + sub_10 = torch.ops.aten.sub.Tensor(1, mul_326); mul_326 = None + mul_328 = torch.ops.aten.mul.Tensor(convert_element_type_950, sub_10); convert_element_type_950 = sub_10 = None + add_151 = torch.ops.aten.add.Tensor(mul_328, 1); mul_328 = None + mul_329 = torch.ops.aten.mul.Tensor(mul_327, add_151); mul_327 = add_151 = None + convert_element_type_1248 = torch.ops.prims.convert_element_type.default(mul_329, torch.bfloat16); mul_329 = None + view_1171 = torch.ops.aten.view.default(convert_element_type_1248, [16384, 14336]); convert_element_type_1248 = None + permute_461 = torch.ops.aten.permute.default(view_1171, [1, 0]) + mm_273 = torch.ops.aten.mm.default(permute_461, view_979); permute_461 = view_979 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 64, '0'); convert_element_type_947 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + permute_463 = torch.ops.aten.permute.default(permute_316, [1, 0]); permute_316 = None + mm_274 = torch.ops.aten.mm.default(view_1171, permute_463); view_1171 = permute_463 = None + view_1172 = torch.ops.aten.view.default(mm_274, [2, 8192, 4096]); mm_274 = None + add_152 = torch.ops.aten.add.Tensor(view_1170, view_1172); view_1170 = view_1172 = None + convert_element_type_1253 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1253, 'avg', 64, '0'); convert_element_type_1253 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + convert_element_type_1254 = torch.ops.prims.convert_element_type.default(add_152, torch.float32); add_152 = None + convert_element_type_1256 = torch.ops.prims.convert_element_type.default(wait_tensor_258, torch.float32); wait_tensor_258 = None + mul_330 = torch.ops.aten.mul.Tensor(convert_element_type_1254, convert_element_type_1256); convert_element_type_1256 = None + mul_332 = torch.ops.aten.mul.Tensor(mul_228, mul_330) + sum_21 = torch.ops.aten.sum.dim_IntList(mul_332, [2], True); mul_332 = None + div_7 = torch.ops.aten.div.Tensor(mul_228, 4096) + mul_333 = torch.ops.aten.mul.Tensor(div_7, sum_21); div_7 = sum_21 = None + sub_11 = torch.ops.aten.sub.Tensor(mul_330, mul_333); mul_330 = mul_333 = None + mul_334 = torch.ops.aten.mul.Tensor(sub_11, rsqrt_57); sub_11 = rsqrt_57 = None + mul_335 = torch.ops.aten.mul.Tensor(convert_element_type_1254, mul_228); convert_element_type_1254 = mul_228 = None + sum_22 = torch.ops.aten.sum.dim_IntList(mul_335, [0, 1]); mul_335 = None + convert_element_type_1257 = torch.ops.prims.convert_element_type.default(mul_334, torch.bfloat16); mul_334 = None + add_153 = torch.ops.aten.add.Tensor(add_149, convert_element_type_1257); add_149 = convert_element_type_1257 = None + convert_element_type_default_58 = torch.ops.prims.convert_element_type.default(sum_22, torch.float32); sum_22 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_58, 'avg', 64, '0'); convert_element_type_default_58 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + view_1173 = torch.ops.aten.view.default(add_153, [16384, 4096]) + permute_465 = torch.ops.aten.permute.default(view_1173, [1, 0]) + mm_275 = torch.ops.aten.mm.default(permute_465, view_975); permute_465 = view_975 = None + permute_467 = torch.ops.aten.permute.default(permute_315, [1, 0]); permute_315 = None + mm_276 = torch.ops.aten.mm.default(view_1173, permute_467); view_1173 = permute_467 = None + view_1174 = torch.ops.aten.view.default(mm_276, [2, 8192, 4096]); mm_276 = None + convert_element_type_1264 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1264, 'avg', 64, '0'); convert_element_type_1264 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + view_1175 = torch.ops.aten.view.default(view_1174, [2, 8192, 32, 128]); view_1174 = None + permute_469 = torch.ops.aten.permute.default(view_1175, [0, 2, 1, 3]); view_1175 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16); primals_256 = None + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 64, '0'); convert_element_type_925 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32); add_111 = None + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_253) + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + view_955 = torch.ops.aten.view.default(convert_element_type_927, [16384, 4096]); convert_element_type_927 = None + view_956 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]); mm_196 = None + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16); primals_258 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 64, '0'); convert_element_type_931 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_255, [1, 0]); wait_tensor_255 = None + mm_197 = torch.ops.aten.mm.default(view_955, permute_309) + view_959 = torch.ops.aten.view.default(mm_197, [2, 8192, 1024]); mm_197 = None + view_962 = torch.ops.aten.view.default(mm_198, [2, 8192, 1024]); mm_198 = None + view_963 = torch.ops.aten.view.default(view_956, [2, 8192, -1, 128]); view_956 = None + view_964 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_965 = torch.ops.aten.view.default(view_962, [2, 8192, -1, 128]); view_962 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_963, torch.float32); view_963 = None + view_966 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 32, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_966); view_966 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_964, torch.float32); view_964 = None + view_967 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 8, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_967); view_967 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_16); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_969 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 32, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_16); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_970 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 8, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_969, torch.bfloat16); view_969 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_970, torch.bfloat16); view_970 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 8, 4, 128]); unsqueeze_56 = None + clone_56 = torch.ops.aten.clone.default(expand_56, memory_format = torch.contiguous_format); expand_56 = None + view_971 = torch.ops.aten.view.default(clone_56, [2, 8192, 32, 128]); clone_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_965, 3); view_965 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 8, 4, 128]); unsqueeze_57 = None + clone_57 = torch.ops.aten.clone.default(expand_57, memory_format = torch.contiguous_format); expand_57 = None + view_972 = torch.ops.aten.view.default(clone_57, [2, 8192, 32, 128]); clone_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_971, [0, 2, 1, 3]); view_971 = None + permute_313 = torch.ops.aten.permute.default(view_972, [0, 2, 1, 3]); view_972 = None + _scaled_dot_product_cudnn_attention_backward_3 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_469, permute_311, permute_312, permute_313, getitem_252, getitem_253, getitem_258, getitem_259, None, None, None, 8192, 8192, 0.0, True); permute_469 = permute_311 = permute_312 = permute_313 = getitem_252 = getitem_253 = getitem_258 = getitem_259 = None + getitem_297 = _scaled_dot_product_cudnn_attention_backward_3[0] + getitem_298 = _scaled_dot_product_cudnn_attention_backward_3[1] + getitem_299 = _scaled_dot_product_cudnn_attention_backward_3[2]; _scaled_dot_product_cudnn_attention_backward_3 = None + permute_470 = torch.ops.aten.permute.default(getitem_299, [0, 2, 1, 3]); getitem_299 = None + permute_471 = torch.ops.aten.permute.default(getitem_298, [0, 2, 1, 3]); getitem_298 = None + permute_472 = torch.ops.aten.permute.default(getitem_297, [0, 2, 1, 3]); getitem_297 = None + view_1176 = torch.ops.aten.view.default(permute_470, [2, 8192, 8, 4, 128]); permute_470 = None + sum_23 = torch.ops.aten.sum.dim_IntList(view_1176, [3], True); view_1176 = None + squeeze_6 = torch.ops.aten.squeeze.dim(sum_23, 3); sum_23 = None + view_1177 = torch.ops.aten.view.default(permute_471, [2, 8192, 8, 4, 128]); permute_471 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_1177, [3], True); view_1177 = None + squeeze_7 = torch.ops.aten.squeeze.dim(sum_24, 3); sum_24 = None + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(squeeze_7, torch.float32); squeeze_7 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(permute_472, torch.float32); permute_472 = None + view_1178 = torch.ops.aten.view.default(convert_element_type_1265, [2, 8192, 8, 64, 2]); convert_element_type_1265 = None + view_as_complex_70 = torch.ops.aten.view_as_complex.default(view_1178); view_1178 = None + mul_336 = torch.ops.aten.mul.Tensor(view_as_complex_70, _conj); view_as_complex_70 = None + view_1179 = torch.ops.aten.view.default(convert_element_type_1266, [2, 8192, 32, 64, 2]); convert_element_type_1266 = None + view_as_complex_71 = torch.ops.aten.view_as_complex.default(view_1179); view_1179 = None + mul_337 = torch.ops.aten.mul.Tensor(view_as_complex_71, _conj); view_as_complex_71 = None + view_as_real_70 = torch.ops.aten.view_as_real.default(mul_336); mul_336 = None + view_1180 = torch.ops.aten.view.default(view_as_real_70, [2, 8192, 8, 128]); view_as_real_70 = None + convert_element_type_1267 = torch.ops.prims.convert_element_type.default(view_1180, torch.bfloat16); view_1180 = None + view_as_real_71 = torch.ops.aten.view_as_real.default(mul_337); mul_337 = None + view_1181 = torch.ops.aten.view.default(view_as_real_71, [2, 8192, 32, 128]); view_as_real_71 = None + convert_element_type_1268 = torch.ops.prims.convert_element_type.default(view_1181, torch.bfloat16); view_1181 = None + view_1182 = torch.ops.aten.view.default(squeeze_6, [2, 8192, 1024]); squeeze_6 = None + view_1183 = torch.ops.aten.view.default(convert_element_type_1267, [2, 8192, 1024]); convert_element_type_1267 = None + view_1184 = torch.ops.aten.view.default(convert_element_type_1268, [2, 8192, 4096]); convert_element_type_1268 = None + view_1185 = torch.ops.aten.view.default(view_1182, [16384, 1024]); view_1182 = None + permute_473 = torch.ops.aten.permute.default(view_1185, [1, 0]) + mm_277 = torch.ops.aten.mm.default(permute_473, view_955); permute_473 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16); primals_259 = None + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 64, '0'); convert_element_type_934 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_256, [1, 0]); wait_tensor_256 = None + permute_475 = torch.ops.aten.permute.default(permute_310, [1, 0]); permute_310 = None + mm_278 = torch.ops.aten.mm.default(view_1185, permute_475); view_1185 = permute_475 = None + view_1186 = torch.ops.aten.view.default(mm_278, [2, 8192, 4096]); mm_278 = None + convert_element_type_1273 = torch.ops.prims.convert_element_type.default(mm_277, torch.float32); mm_277 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1273, 'avg', 64, '0'); convert_element_type_1273 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + view_1187 = torch.ops.aten.view.default(view_1183, [16384, 1024]); view_1183 = None + permute_477 = torch.ops.aten.permute.default(view_1187, [1, 0]) + mm_279 = torch.ops.aten.mm.default(permute_477, view_955); permute_477 = None + permute_479 = torch.ops.aten.permute.default(permute_309, [1, 0]); permute_309 = None + mm_280 = torch.ops.aten.mm.default(view_1187, permute_479); view_1187 = permute_479 = None + view_1188 = torch.ops.aten.view.default(mm_280, [2, 8192, 4096]); mm_280 = None + add_154 = torch.ops.aten.add.Tensor(view_1186, view_1188); view_1186 = view_1188 = None + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(mm_279, torch.float32); mm_279 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1278, 'avg', 64, '0'); convert_element_type_1278 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + view_1189 = torch.ops.aten.view.default(view_1184, [16384, 4096]); view_1184 = None + permute_481 = torch.ops.aten.permute.default(view_1189, [1, 0]) + mm_281 = torch.ops.aten.mm.default(permute_481, view_955); permute_481 = view_955 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16); primals_257 = None + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 64, '0'); convert_element_type_928 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + permute_483 = torch.ops.aten.permute.default(permute_308, [1, 0]); permute_308 = None + mm_282 = torch.ops.aten.mm.default(view_1189, permute_483); view_1189 = permute_483 = None + view_1190 = torch.ops.aten.view.default(mm_282, [2, 8192, 4096]); mm_282 = None + add_155 = torch.ops.aten.add.Tensor(add_154, view_1190); add_154 = view_1190 = None + convert_element_type_1283 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1283, 'avg', 64, '0'); convert_element_type_1283 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + convert_element_type_1284 = torch.ops.prims.convert_element_type.default(add_155, torch.float32); add_155 = None + convert_element_type_1286 = torch.ops.prims.convert_element_type.default(wait_tensor_253, torch.float32); wait_tensor_253 = None + mul_338 = torch.ops.aten.mul.Tensor(convert_element_type_1284, convert_element_type_1286); convert_element_type_1286 = None + mul_340 = torch.ops.aten.mul.Tensor(mul_224, mul_338) + sum_25 = torch.ops.aten.sum.dim_IntList(mul_340, [2], True); mul_340 = None + div_8 = torch.ops.aten.div.Tensor(mul_224, 4096) + mul_341 = torch.ops.aten.mul.Tensor(div_8, sum_25); div_8 = sum_25 = None + sub_12 = torch.ops.aten.sub.Tensor(mul_338, mul_341); mul_338 = mul_341 = None + mul_342 = torch.ops.aten.mul.Tensor(sub_12, rsqrt_56); sub_12 = rsqrt_56 = None + mul_343 = torch.ops.aten.mul.Tensor(convert_element_type_1284, mul_224); convert_element_type_1284 = mul_224 = None + sum_26 = torch.ops.aten.sum.dim_IntList(mul_343, [0, 1]); mul_343 = None + convert_element_type_1287 = torch.ops.prims.convert_element_type.default(mul_342, torch.bfloat16); mul_342 = None + add_156 = torch.ops.aten.add.Tensor(add_153, convert_element_type_1287); add_153 = convert_element_type_1287 = None + convert_element_type_default_57 = torch.ops.prims.convert_element_type.default(sum_26, torch.float32); sum_26 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_57, 'avg', 64, '0'); convert_element_type_default_57 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + view_1191 = torch.ops.aten.view.default(add_156, [16384, 4096]) + permute_485 = torch.ops.aten.permute.default(view_1191, [1, 0]) + permute_303 = torch.ops.aten.permute.default(getitem_243, [0, 2, 1, 3]) + view_939 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 64, '0'); convert_element_type_908 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_248, [1, 0]); wait_tensor_248 = None + view_941 = torch.ops.aten.view.default(view_939, [16384, 4096]); view_939 = None + mm_192 = torch.ops.aten.mm.default(view_941, permute_304) + view_942 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + add_109 = torch.ops.aten.add.Tensor(add_107, view_942); view_942 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 64, '0'); convert_element_type_911 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32); add_109 = None + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_249) + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + view_945 = torch.ops.aten.view.default(convert_element_type_913, [16384, 4096]); convert_element_type_913 = None + view_946 = torch.ops.aten.view.default(mm_193, [2, 8192, 14336]); mm_193 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_946, torch.float32); view_946 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16); primals_254 = None + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 64, '0'); convert_element_type_919 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + mm_194 = torch.ops.aten.mm.default(view_945, permute_306) + view_949 = torch.ops.aten.view.default(mm_194, [2, 8192, 14336]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_949) + view_951 = torch.ops.aten.view.default(mul_223, [16384, 14336]); mul_223 = None + mm_283 = torch.ops.aten.mm.default(permute_485, view_951); permute_485 = view_951 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16); primals_255 = None + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 64, '0'); convert_element_type_922 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + permute_487 = torch.ops.aten.permute.default(permute_307, [1, 0]); permute_307 = None + mm_284 = torch.ops.aten.mm.default(view_1191, permute_487); view_1191 = permute_487 = None + view_1192 = torch.ops.aten.view.default(mm_284, [2, 8192, 14336]); mm_284 = None + convert_element_type_1294 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1294, 'avg', 64, '0'); convert_element_type_1294 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + mul_344 = torch.ops.aten.mul.Tensor(view_1192, convert_element_type_918); convert_element_type_918 = None + mul_345 = torch.ops.aten.mul.Tensor(view_1192, view_949); view_1192 = view_949 = None + view_1193 = torch.ops.aten.view.default(mul_344, [16384, 14336]); mul_344 = None + permute_489 = torch.ops.aten.permute.default(view_1193, [1, 0]) + mm_285 = torch.ops.aten.mm.default(permute_489, view_945); permute_489 = None + permute_491 = torch.ops.aten.permute.default(permute_306, [1, 0]); permute_306 = None + mm_286 = torch.ops.aten.mm.default(view_1193, permute_491); view_1193 = permute_491 = None + view_1194 = torch.ops.aten.view.default(mm_286, [2, 8192, 4096]); mm_286 = None + convert_element_type_1299 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1299, 'avg', 64, '0'); convert_element_type_1299 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + convert_element_type_1300 = torch.ops.prims.convert_element_type.default(mul_345, torch.float32); mul_345 = None + neg_4 = torch.ops.aten.neg.default(convert_element_type_917) + exp_4 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_157 = torch.ops.aten.add.Tensor(exp_4, 1); exp_4 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_157); add_157 = None + mul_346 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_347 = torch.ops.aten.mul.Tensor(convert_element_type_1300, mul_346); convert_element_type_1300 = None + sub_13 = torch.ops.aten.sub.Tensor(1, mul_346); mul_346 = None + mul_348 = torch.ops.aten.mul.Tensor(convert_element_type_917, sub_13); convert_element_type_917 = sub_13 = None + add_158 = torch.ops.aten.add.Tensor(mul_348, 1); mul_348 = None + mul_349 = torch.ops.aten.mul.Tensor(mul_347, add_158); mul_347 = add_158 = None + convert_element_type_1302 = torch.ops.prims.convert_element_type.default(mul_349, torch.bfloat16); mul_349 = None + view_1195 = torch.ops.aten.view.default(convert_element_type_1302, [16384, 14336]); convert_element_type_1302 = None + permute_493 = torch.ops.aten.permute.default(view_1195, [1, 0]) + mm_287 = torch.ops.aten.mm.default(permute_493, view_945); permute_493 = view_945 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16); primals_253 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 64, '0'); convert_element_type_914 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_250, [1, 0]); wait_tensor_250 = None + permute_495 = torch.ops.aten.permute.default(permute_305, [1, 0]); permute_305 = None + mm_288 = torch.ops.aten.mm.default(view_1195, permute_495); view_1195 = permute_495 = None + view_1196 = torch.ops.aten.view.default(mm_288, [2, 8192, 4096]); mm_288 = None + add_159 = torch.ops.aten.add.Tensor(view_1194, view_1196); view_1194 = view_1196 = None + convert_element_type_1307 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1307, 'avg', 64, '0'); convert_element_type_1307 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + convert_element_type_1308 = torch.ops.prims.convert_element_type.default(add_159, torch.float32); add_159 = None + convert_element_type_1310 = torch.ops.prims.convert_element_type.default(wait_tensor_249, torch.float32); wait_tensor_249 = None + mul_350 = torch.ops.aten.mul.Tensor(convert_element_type_1308, convert_element_type_1310); convert_element_type_1310 = None + mul_352 = torch.ops.aten.mul.Tensor(mul_220, mul_350) + sum_27 = torch.ops.aten.sum.dim_IntList(mul_352, [2], True); mul_352 = None + div_9 = torch.ops.aten.div.Tensor(mul_220, 4096) + mul_353 = torch.ops.aten.mul.Tensor(div_9, sum_27); div_9 = sum_27 = None + sub_14 = torch.ops.aten.sub.Tensor(mul_350, mul_353); mul_350 = mul_353 = None + mul_354 = torch.ops.aten.mul.Tensor(sub_14, rsqrt_55); sub_14 = rsqrt_55 = None + mul_355 = torch.ops.aten.mul.Tensor(convert_element_type_1308, mul_220); convert_element_type_1308 = mul_220 = None + sum_28 = torch.ops.aten.sum.dim_IntList(mul_355, [0, 1]); mul_355 = None + convert_element_type_1311 = torch.ops.prims.convert_element_type.default(mul_354, torch.bfloat16); mul_354 = None + add_160 = torch.ops.aten.add.Tensor(add_156, convert_element_type_1311); add_156 = convert_element_type_1311 = None + convert_element_type_default_56 = torch.ops.prims.convert_element_type.default(sum_28, torch.float32); sum_28 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_56, 'avg', 64, '0'); convert_element_type_default_56 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + view_1197 = torch.ops.aten.view.default(add_160, [16384, 4096]) + permute_497 = torch.ops.aten.permute.default(view_1197, [1, 0]) + mm_289 = torch.ops.aten.mm.default(permute_497, view_941); permute_497 = view_941 = None + permute_499 = torch.ops.aten.permute.default(permute_304, [1, 0]); permute_304 = None + mm_290 = torch.ops.aten.mm.default(view_1197, permute_499); view_1197 = permute_499 = None + view_1198 = torch.ops.aten.view.default(mm_290, [2, 8192, 4096]); mm_290 = None + convert_element_type_1318 = torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1318, 'avg', 64, '0'); convert_element_type_1318 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + view_1199 = torch.ops.aten.view.default(view_1198, [2, 8192, 32, 128]); view_1198 = None + permute_501 = torch.ops.aten.permute.default(view_1199, [0, 2, 1, 3]); view_1199 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 64, '0'); convert_element_type_892 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32); add_107 = None + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_244) + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + view_921 = torch.ops.aten.view.default(convert_element_type_894, [16384, 4096]); convert_element_type_894 = None + view_922 = torch.ops.aten.view.default(mm_189, [2, 8192, 4096]); mm_189 = None + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 64, '0'); convert_element_type_898 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_190 = torch.ops.aten.mm.default(view_921, permute_298) + view_925 = torch.ops.aten.view.default(mm_190, [2, 8192, 1024]); mm_190 = None + view_928 = torch.ops.aten.view.default(mm_191, [2, 8192, 1024]); mm_191 = None + view_929 = torch.ops.aten.view.default(view_922, [2, 8192, -1, 128]); view_922 = None + view_930 = torch.ops.aten.view.default(view_925, [2, 8192, -1, 128]); view_925 = None + view_931 = torch.ops.aten.view.default(view_928, [2, 8192, -1, 128]); view_928 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_929, torch.float32); view_929 = None + view_932 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 32, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_932); view_932 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_930, torch.float32); view_930 = None + view_933 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 8, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_933); view_933 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_16); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_935 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 32, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_16); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_936 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 8, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_935, torch.bfloat16); view_935 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_936, torch.bfloat16); view_936 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 8, 4, 128]); unsqueeze_54 = None + clone_54 = torch.ops.aten.clone.default(expand_54, memory_format = torch.contiguous_format); expand_54 = None + view_937 = torch.ops.aten.view.default(clone_54, [2, 8192, 32, 128]); clone_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_931, 3); view_931 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 8, 4, 128]); unsqueeze_55 = None + clone_55 = torch.ops.aten.clone.default(expand_55, memory_format = torch.contiguous_format); expand_55 = None + view_938 = torch.ops.aten.view.default(clone_55, [2, 8192, 32, 128]); clone_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_937, [0, 2, 1, 3]); view_937 = None + permute_302 = torch.ops.aten.permute.default(view_938, [0, 2, 1, 3]); view_938 = None + _scaled_dot_product_cudnn_attention_backward_4 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_501, permute_300, permute_301, permute_302, getitem_243, getitem_244, getitem_249, getitem_250, None, None, None, 8192, 8192, 0.0, True); permute_501 = permute_300 = permute_301 = permute_302 = getitem_243 = getitem_244 = getitem_249 = getitem_250 = None + getitem_300 = _scaled_dot_product_cudnn_attention_backward_4[0] + getitem_301 = _scaled_dot_product_cudnn_attention_backward_4[1] + getitem_302 = _scaled_dot_product_cudnn_attention_backward_4[2]; _scaled_dot_product_cudnn_attention_backward_4 = None + permute_502 = torch.ops.aten.permute.default(getitem_302, [0, 2, 1, 3]); getitem_302 = None + permute_503 = torch.ops.aten.permute.default(getitem_301, [0, 2, 1, 3]); getitem_301 = None + permute_504 = torch.ops.aten.permute.default(getitem_300, [0, 2, 1, 3]); getitem_300 = None + view_1200 = torch.ops.aten.view.default(permute_502, [2, 8192, 8, 4, 128]); permute_502 = None + sum_29 = torch.ops.aten.sum.dim_IntList(view_1200, [3], True); view_1200 = None + squeeze_8 = torch.ops.aten.squeeze.dim(sum_29, 3); sum_29 = None + view_1201 = torch.ops.aten.view.default(permute_503, [2, 8192, 8, 4, 128]); permute_503 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_1201, [3], True); view_1201 = None + squeeze_9 = torch.ops.aten.squeeze.dim(sum_30, 3); sum_30 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(squeeze_9, torch.float32); squeeze_9 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(permute_504, torch.float32); permute_504 = None + view_1202 = torch.ops.aten.view.default(convert_element_type_1319, [2, 8192, 8, 64, 2]); convert_element_type_1319 = None + view_as_complex_72 = torch.ops.aten.view_as_complex.default(view_1202); view_1202 = None + mul_356 = torch.ops.aten.mul.Tensor(view_as_complex_72, _conj); view_as_complex_72 = None + view_1203 = torch.ops.aten.view.default(convert_element_type_1320, [2, 8192, 32, 64, 2]); convert_element_type_1320 = None + view_as_complex_73 = torch.ops.aten.view_as_complex.default(view_1203); view_1203 = None + mul_357 = torch.ops.aten.mul.Tensor(view_as_complex_73, _conj); view_as_complex_73 = None + view_as_real_72 = torch.ops.aten.view_as_real.default(mul_356); mul_356 = None + view_1204 = torch.ops.aten.view.default(view_as_real_72, [2, 8192, 8, 128]); view_as_real_72 = None + convert_element_type_1321 = torch.ops.prims.convert_element_type.default(view_1204, torch.bfloat16); view_1204 = None + view_as_real_73 = torch.ops.aten.view_as_real.default(mul_357); mul_357 = None + view_1205 = torch.ops.aten.view.default(view_as_real_73, [2, 8192, 32, 128]); view_as_real_73 = None + convert_element_type_1322 = torch.ops.prims.convert_element_type.default(view_1205, torch.bfloat16); view_1205 = None + view_1206 = torch.ops.aten.view.default(squeeze_8, [2, 8192, 1024]); squeeze_8 = None + view_1207 = torch.ops.aten.view.default(convert_element_type_1321, [2, 8192, 1024]); convert_element_type_1321 = None + view_1208 = torch.ops.aten.view.default(convert_element_type_1322, [2, 8192, 4096]); convert_element_type_1322 = None + view_1209 = torch.ops.aten.view.default(view_1206, [16384, 1024]); view_1206 = None + permute_505 = torch.ops.aten.permute.default(view_1209, [1, 0]) + mm_291 = torch.ops.aten.mm.default(permute_505, view_921); permute_505 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 64, '0'); convert_element_type_901 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + permute_507 = torch.ops.aten.permute.default(permute_299, [1, 0]); permute_299 = None + mm_292 = torch.ops.aten.mm.default(view_1209, permute_507); view_1209 = permute_507 = None + view_1210 = torch.ops.aten.view.default(mm_292, [2, 8192, 4096]); mm_292 = None + convert_element_type_1327 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1327, 'avg', 64, '0'); convert_element_type_1327 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + view_1211 = torch.ops.aten.view.default(view_1207, [16384, 1024]); view_1207 = None + permute_509 = torch.ops.aten.permute.default(view_1211, [1, 0]) + mm_293 = torch.ops.aten.mm.default(permute_509, view_921); permute_509 = None + permute_511 = torch.ops.aten.permute.default(permute_298, [1, 0]); permute_298 = None + mm_294 = torch.ops.aten.mm.default(view_1211, permute_511); view_1211 = permute_511 = None + view_1212 = torch.ops.aten.view.default(mm_294, [2, 8192, 4096]); mm_294 = None + add_161 = torch.ops.aten.add.Tensor(view_1210, view_1212); view_1210 = view_1212 = None + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1332, 'avg', 64, '0'); convert_element_type_1332 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + view_1213 = torch.ops.aten.view.default(view_1208, [16384, 4096]); view_1208 = None + permute_513 = torch.ops.aten.permute.default(view_1213, [1, 0]) + mm_295 = torch.ops.aten.mm.default(permute_513, view_921); permute_513 = view_921 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 64, '0'); convert_element_type_895 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + permute_515 = torch.ops.aten.permute.default(permute_297, [1, 0]); permute_297 = None + mm_296 = torch.ops.aten.mm.default(view_1213, permute_515); view_1213 = permute_515 = None + view_1214 = torch.ops.aten.view.default(mm_296, [2, 8192, 4096]); mm_296 = None + add_162 = torch.ops.aten.add.Tensor(add_161, view_1214); add_161 = view_1214 = None + convert_element_type_1337 = torch.ops.prims.convert_element_type.default(mm_295, torch.float32); mm_295 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1337, 'avg', 64, '0'); convert_element_type_1337 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); reduce_scatter_tensor_45 = None + convert_element_type_1338 = torch.ops.prims.convert_element_type.default(add_162, torch.float32); add_162 = None + convert_element_type_1340 = torch.ops.prims.convert_element_type.default(wait_tensor_244, torch.float32); wait_tensor_244 = None + mul_358 = torch.ops.aten.mul.Tensor(convert_element_type_1338, convert_element_type_1340); convert_element_type_1340 = None + mul_360 = torch.ops.aten.mul.Tensor(mul_216, mul_358) + sum_31 = torch.ops.aten.sum.dim_IntList(mul_360, [2], True); mul_360 = None + div_10 = torch.ops.aten.div.Tensor(mul_216, 4096) + mul_361 = torch.ops.aten.mul.Tensor(div_10, sum_31); div_10 = sum_31 = None + sub_15 = torch.ops.aten.sub.Tensor(mul_358, mul_361); mul_358 = mul_361 = None + mul_362 = torch.ops.aten.mul.Tensor(sub_15, rsqrt_54); sub_15 = rsqrt_54 = None + mul_363 = torch.ops.aten.mul.Tensor(convert_element_type_1338, mul_216); convert_element_type_1338 = mul_216 = None + sum_32 = torch.ops.aten.sum.dim_IntList(mul_363, [0, 1]); mul_363 = None + convert_element_type_1341 = torch.ops.prims.convert_element_type.default(mul_362, torch.bfloat16); mul_362 = None + add_163 = torch.ops.aten.add.Tensor(add_160, convert_element_type_1341); add_160 = convert_element_type_1341 = None + convert_element_type_default_55 = torch.ops.prims.convert_element_type.default(sum_32, torch.float32); sum_32 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_55, 'avg', 64, '0'); convert_element_type_default_55 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + view_1215 = torch.ops.aten.view.default(add_163, [16384, 4096]) + permute_517 = torch.ops.aten.permute.default(view_1215, [1, 0]) + permute_292 = torch.ops.aten.permute.default(getitem_234, [0, 2, 1, 3]) + view_905 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16); primals_242 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 64, '0'); convert_element_type_875 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + view_907 = torch.ops.aten.view.default(view_905, [16384, 4096]); view_905 = None + mm_185 = torch.ops.aten.mm.default(view_907, permute_293) + view_908 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + add_105 = torch.ops.aten.add.Tensor(add_103, view_908); view_908 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16); primals_243 = None + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 64, '0'); convert_element_type_878 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32); add_105 = None + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_240) + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + view_911 = torch.ops.aten.view.default(convert_element_type_880, [16384, 4096]); convert_element_type_880 = None + view_912 = torch.ops.aten.view.default(mm_186, [2, 8192, 14336]); mm_186 = None + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_912, torch.float32); view_912 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 64, '0'); convert_element_type_886 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_242, [1, 0]); wait_tensor_242 = None + mm_187 = torch.ops.aten.mm.default(view_911, permute_295) + view_915 = torch.ops.aten.view.default(mm_187, [2, 8192, 14336]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_915) + view_917 = torch.ops.aten.view.default(mul_215, [16384, 14336]); mul_215 = None + mm_297 = torch.ops.aten.mm.default(permute_517, view_917); permute_517 = view_917 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 64, '0'); convert_element_type_889 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + permute_519 = torch.ops.aten.permute.default(permute_296, [1, 0]); permute_296 = None + mm_298 = torch.ops.aten.mm.default(view_1215, permute_519); view_1215 = permute_519 = None + view_1216 = torch.ops.aten.view.default(mm_298, [2, 8192, 14336]); mm_298 = None + convert_element_type_1348 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1348, 'avg', 64, '0'); convert_element_type_1348 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + mul_364 = torch.ops.aten.mul.Tensor(view_1216, convert_element_type_885); convert_element_type_885 = None + mul_365 = torch.ops.aten.mul.Tensor(view_1216, view_915); view_1216 = view_915 = None + view_1217 = torch.ops.aten.view.default(mul_364, [16384, 14336]); mul_364 = None + permute_521 = torch.ops.aten.permute.default(view_1217, [1, 0]) + mm_299 = torch.ops.aten.mm.default(permute_521, view_911); permute_521 = None + permute_523 = torch.ops.aten.permute.default(permute_295, [1, 0]); permute_295 = None + mm_300 = torch.ops.aten.mm.default(view_1217, permute_523); view_1217 = permute_523 = None + view_1218 = torch.ops.aten.view.default(mm_300, [2, 8192, 4096]); mm_300 = None + convert_element_type_1353 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None + reduce_scatter_tensor_48 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1353, 'avg', 64, '0'); convert_element_type_1353 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + convert_element_type_1354 = torch.ops.prims.convert_element_type.default(mul_365, torch.float32); mul_365 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_884) + exp_5 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_164 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_164); add_164 = None + mul_366 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_367 = torch.ops.aten.mul.Tensor(convert_element_type_1354, mul_366); convert_element_type_1354 = None + sub_16 = torch.ops.aten.sub.Tensor(1, mul_366); mul_366 = None + mul_368 = torch.ops.aten.mul.Tensor(convert_element_type_884, sub_16); convert_element_type_884 = sub_16 = None + add_165 = torch.ops.aten.add.Tensor(mul_368, 1); mul_368 = None + mul_369 = torch.ops.aten.mul.Tensor(mul_367, add_165); mul_367 = add_165 = None + convert_element_type_1356 = torch.ops.prims.convert_element_type.default(mul_369, torch.bfloat16); mul_369 = None + view_1219 = torch.ops.aten.view.default(convert_element_type_1356, [16384, 14336]); convert_element_type_1356 = None + permute_525 = torch.ops.aten.permute.default(view_1219, [1, 0]) + mm_301 = torch.ops.aten.mm.default(permute_525, view_911); permute_525 = view_911 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 64, '0'); convert_element_type_881 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + permute_527 = torch.ops.aten.permute.default(permute_294, [1, 0]); permute_294 = None + mm_302 = torch.ops.aten.mm.default(view_1219, permute_527); view_1219 = permute_527 = None + view_1220 = torch.ops.aten.view.default(mm_302, [2, 8192, 4096]); mm_302 = None + add_166 = torch.ops.aten.add.Tensor(view_1218, view_1220); view_1218 = view_1220 = None + convert_element_type_1361 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1361, 'avg', 64, '0'); convert_element_type_1361 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + convert_element_type_1362 = torch.ops.prims.convert_element_type.default(add_166, torch.float32); add_166 = None + convert_element_type_1364 = torch.ops.prims.convert_element_type.default(wait_tensor_240, torch.float32); wait_tensor_240 = None + mul_370 = torch.ops.aten.mul.Tensor(convert_element_type_1362, convert_element_type_1364); convert_element_type_1364 = None + mul_372 = torch.ops.aten.mul.Tensor(mul_212, mul_370) + sum_33 = torch.ops.aten.sum.dim_IntList(mul_372, [2], True); mul_372 = None + div_11 = torch.ops.aten.div.Tensor(mul_212, 4096) + mul_373 = torch.ops.aten.mul.Tensor(div_11, sum_33); div_11 = sum_33 = None + sub_17 = torch.ops.aten.sub.Tensor(mul_370, mul_373); mul_370 = mul_373 = None + mul_374 = torch.ops.aten.mul.Tensor(sub_17, rsqrt_53); sub_17 = rsqrt_53 = None + mul_375 = torch.ops.aten.mul.Tensor(convert_element_type_1362, mul_212); convert_element_type_1362 = mul_212 = None + sum_34 = torch.ops.aten.sum.dim_IntList(mul_375, [0, 1]); mul_375 = None + convert_element_type_1365 = torch.ops.prims.convert_element_type.default(mul_374, torch.bfloat16); mul_374 = None + add_167 = torch.ops.aten.add.Tensor(add_163, convert_element_type_1365); add_163 = convert_element_type_1365 = None + convert_element_type_default_54 = torch.ops.prims.convert_element_type.default(sum_34, torch.float32); sum_34 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_54, 'avg', 64, '0'); convert_element_type_default_54 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + view_1221 = torch.ops.aten.view.default(add_167, [16384, 4096]) + permute_529 = torch.ops.aten.permute.default(view_1221, [1, 0]) + mm_303 = torch.ops.aten.mm.default(permute_529, view_907); permute_529 = view_907 = None + permute_531 = torch.ops.aten.permute.default(permute_293, [1, 0]); permute_293 = None + mm_304 = torch.ops.aten.mm.default(view_1221, permute_531); view_1221 = permute_531 = None + view_1222 = torch.ops.aten.view.default(mm_304, [2, 8192, 4096]); mm_304 = None + convert_element_type_1372 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1372, 'avg', 64, '0'); convert_element_type_1372 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + view_1223 = torch.ops.aten.view.default(view_1222, [2, 8192, 32, 128]); view_1222 = None + permute_533 = torch.ops.aten.permute.default(view_1223, [0, 2, 1, 3]); view_1223 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16); primals_238 = None + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 64, '0'); convert_element_type_859 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32); add_103 = None + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_235) + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + view_887 = torch.ops.aten.view.default(convert_element_type_861, [16384, 4096]); convert_element_type_861 = None + view_888 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]); mm_182 = None + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16); primals_240 = None + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 64, '0'); convert_element_type_865 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_237, [1, 0]); wait_tensor_237 = None + mm_183 = torch.ops.aten.mm.default(view_887, permute_287) + view_891 = torch.ops.aten.view.default(mm_183, [2, 8192, 1024]); mm_183 = None + view_894 = torch.ops.aten.view.default(mm_184, [2, 8192, 1024]); mm_184 = None + view_895 = torch.ops.aten.view.default(view_888, [2, 8192, -1, 128]); view_888 = None + view_896 = torch.ops.aten.view.default(view_891, [2, 8192, -1, 128]); view_891 = None + view_897 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_895, torch.float32); view_895 = None + view_898 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 32, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_898); view_898 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 8, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_16); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_901 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 32, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_16); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_902 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 8, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_901, torch.bfloat16); view_901 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 8, 4, 128]); unsqueeze_52 = None + clone_52 = torch.ops.aten.clone.default(expand_52, memory_format = torch.contiguous_format); expand_52 = None + view_903 = torch.ops.aten.view.default(clone_52, [2, 8192, 32, 128]); clone_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_897, 3); view_897 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 8, 4, 128]); unsqueeze_53 = None + clone_53 = torch.ops.aten.clone.default(expand_53, memory_format = torch.contiguous_format); expand_53 = None + view_904 = torch.ops.aten.view.default(clone_53, [2, 8192, 32, 128]); clone_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_903, [0, 2, 1, 3]); view_903 = None + permute_291 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + _scaled_dot_product_cudnn_attention_backward_5 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_533, permute_289, permute_290, permute_291, getitem_234, getitem_235, getitem_240, getitem_241, None, None, None, 8192, 8192, 0.0, True); permute_533 = permute_289 = permute_290 = permute_291 = getitem_234 = getitem_235 = getitem_240 = getitem_241 = None + getitem_303 = _scaled_dot_product_cudnn_attention_backward_5[0] + getitem_304 = _scaled_dot_product_cudnn_attention_backward_5[1] + getitem_305 = _scaled_dot_product_cudnn_attention_backward_5[2]; _scaled_dot_product_cudnn_attention_backward_5 = None + permute_534 = torch.ops.aten.permute.default(getitem_305, [0, 2, 1, 3]); getitem_305 = None + permute_535 = torch.ops.aten.permute.default(getitem_304, [0, 2, 1, 3]); getitem_304 = None + permute_536 = torch.ops.aten.permute.default(getitem_303, [0, 2, 1, 3]); getitem_303 = None + view_1224 = torch.ops.aten.view.default(permute_534, [2, 8192, 8, 4, 128]); permute_534 = None + sum_35 = torch.ops.aten.sum.dim_IntList(view_1224, [3], True); view_1224 = None + squeeze_10 = torch.ops.aten.squeeze.dim(sum_35, 3); sum_35 = None + view_1225 = torch.ops.aten.view.default(permute_535, [2, 8192, 8, 4, 128]); permute_535 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_1225, [3], True); view_1225 = None + squeeze_11 = torch.ops.aten.squeeze.dim(sum_36, 3); sum_36 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(squeeze_11, torch.float32); squeeze_11 = None + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(permute_536, torch.float32); permute_536 = None + view_1226 = torch.ops.aten.view.default(convert_element_type_1373, [2, 8192, 8, 64, 2]); convert_element_type_1373 = None + view_as_complex_74 = torch.ops.aten.view_as_complex.default(view_1226); view_1226 = None + mul_376 = torch.ops.aten.mul.Tensor(view_as_complex_74, _conj); view_as_complex_74 = None + view_1227 = torch.ops.aten.view.default(convert_element_type_1374, [2, 8192, 32, 64, 2]); convert_element_type_1374 = None + view_as_complex_75 = torch.ops.aten.view_as_complex.default(view_1227); view_1227 = None + mul_377 = torch.ops.aten.mul.Tensor(view_as_complex_75, _conj); view_as_complex_75 = None + view_as_real_74 = torch.ops.aten.view_as_real.default(mul_376); mul_376 = None + view_1228 = torch.ops.aten.view.default(view_as_real_74, [2, 8192, 8, 128]); view_as_real_74 = None + convert_element_type_1375 = torch.ops.prims.convert_element_type.default(view_1228, torch.bfloat16); view_1228 = None + view_as_real_75 = torch.ops.aten.view_as_real.default(mul_377); mul_377 = None + view_1229 = torch.ops.aten.view.default(view_as_real_75, [2, 8192, 32, 128]); view_as_real_75 = None + convert_element_type_1376 = torch.ops.prims.convert_element_type.default(view_1229, torch.bfloat16); view_1229 = None + view_1230 = torch.ops.aten.view.default(squeeze_10, [2, 8192, 1024]); squeeze_10 = None + view_1231 = torch.ops.aten.view.default(convert_element_type_1375, [2, 8192, 1024]); convert_element_type_1375 = None + view_1232 = torch.ops.aten.view.default(convert_element_type_1376, [2, 8192, 4096]); convert_element_type_1376 = None + view_1233 = torch.ops.aten.view.default(view_1230, [16384, 1024]); view_1230 = None + permute_537 = torch.ops.aten.permute.default(view_1233, [1, 0]) + mm_305 = torch.ops.aten.mm.default(permute_537, view_887); permute_537 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16); primals_241 = None + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 64, '0'); convert_element_type_868 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + permute_539 = torch.ops.aten.permute.default(permute_288, [1, 0]); permute_288 = None + mm_306 = torch.ops.aten.mm.default(view_1233, permute_539); view_1233 = permute_539 = None + view_1234 = torch.ops.aten.view.default(mm_306, [2, 8192, 4096]); mm_306 = None + convert_element_type_1381 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1381, 'avg', 64, '0'); convert_element_type_1381 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + view_1235 = torch.ops.aten.view.default(view_1231, [16384, 1024]); view_1231 = None + permute_541 = torch.ops.aten.permute.default(view_1235, [1, 0]) + mm_307 = torch.ops.aten.mm.default(permute_541, view_887); permute_541 = None + permute_543 = torch.ops.aten.permute.default(permute_287, [1, 0]); permute_287 = None + mm_308 = torch.ops.aten.mm.default(view_1235, permute_543); view_1235 = permute_543 = None + view_1236 = torch.ops.aten.view.default(mm_308, [2, 8192, 4096]); mm_308 = None + add_168 = torch.ops.aten.add.Tensor(view_1234, view_1236); view_1234 = view_1236 = None + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(mm_307, torch.float32); mm_307 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1386, 'avg', 64, '0'); convert_element_type_1386 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + view_1237 = torch.ops.aten.view.default(view_1232, [16384, 4096]); view_1232 = None + permute_545 = torch.ops.aten.permute.default(view_1237, [1, 0]) + mm_309 = torch.ops.aten.mm.default(permute_545, view_887); permute_545 = view_887 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16); primals_239 = None + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 64, '0'); convert_element_type_862 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_236, [1, 0]); wait_tensor_236 = None + permute_547 = torch.ops.aten.permute.default(permute_286, [1, 0]); permute_286 = None + mm_310 = torch.ops.aten.mm.default(view_1237, permute_547); view_1237 = permute_547 = None + view_1238 = torch.ops.aten.view.default(mm_310, [2, 8192, 4096]); mm_310 = None + add_169 = torch.ops.aten.add.Tensor(add_168, view_1238); add_168 = view_1238 = None + convert_element_type_1391 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1391, 'avg', 64, '0'); convert_element_type_1391 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + convert_element_type_1392 = torch.ops.prims.convert_element_type.default(add_169, torch.float32); add_169 = None + convert_element_type_1394 = torch.ops.prims.convert_element_type.default(wait_tensor_235, torch.float32); wait_tensor_235 = None + mul_378 = torch.ops.aten.mul.Tensor(convert_element_type_1392, convert_element_type_1394); convert_element_type_1394 = None + mul_380 = torch.ops.aten.mul.Tensor(mul_208, mul_378) + sum_37 = torch.ops.aten.sum.dim_IntList(mul_380, [2], True); mul_380 = None + div_12 = torch.ops.aten.div.Tensor(mul_208, 4096) + mul_381 = torch.ops.aten.mul.Tensor(div_12, sum_37); div_12 = sum_37 = None + sub_18 = torch.ops.aten.sub.Tensor(mul_378, mul_381); mul_378 = mul_381 = None + mul_382 = torch.ops.aten.mul.Tensor(sub_18, rsqrt_52); sub_18 = rsqrt_52 = None + mul_383 = torch.ops.aten.mul.Tensor(convert_element_type_1392, mul_208); convert_element_type_1392 = mul_208 = None + sum_38 = torch.ops.aten.sum.dim_IntList(mul_383, [0, 1]); mul_383 = None + convert_element_type_1395 = torch.ops.prims.convert_element_type.default(mul_382, torch.bfloat16); mul_382 = None + add_170 = torch.ops.aten.add.Tensor(add_167, convert_element_type_1395); add_167 = convert_element_type_1395 = None + convert_element_type_default_53 = torch.ops.prims.convert_element_type.default(sum_38, torch.float32); sum_38 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_53, 'avg', 64, '0'); convert_element_type_default_53 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); reduce_scatter_tensor_55 = None + view_1239 = torch.ops.aten.view.default(add_170, [16384, 4096]) + permute_549 = torch.ops.aten.permute.default(view_1239, [1, 0]) + permute_281 = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]) + view_871 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 64, '0'); convert_element_type_842 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_230, [1, 0]); wait_tensor_230 = None + view_873 = torch.ops.aten.view.default(view_871, [16384, 4096]); view_871 = None + mm_178 = torch.ops.aten.mm.default(view_873, permute_282) + view_874 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + add_101 = torch.ops.aten.add.Tensor(add_99, view_874); view_874 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 64, '0'); convert_element_type_845 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32); add_101 = None + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_231) + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + view_877 = torch.ops.aten.view.default(convert_element_type_847, [16384, 4096]); convert_element_type_847 = None + view_878 = torch.ops.aten.view.default(mm_179, [2, 8192, 14336]); mm_179 = None + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_878, torch.float32); view_878 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16); primals_236 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 64, '0'); convert_element_type_853 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_180 = torch.ops.aten.mm.default(view_877, permute_284) + view_881 = torch.ops.aten.view.default(mm_180, [2, 8192, 14336]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_881) + view_883 = torch.ops.aten.view.default(mul_207, [16384, 14336]); mul_207 = None + mm_311 = torch.ops.aten.mm.default(permute_549, view_883); permute_549 = view_883 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16); primals_237 = None + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 64, '0'); convert_element_type_856 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + permute_551 = torch.ops.aten.permute.default(permute_285, [1, 0]); permute_285 = None + mm_312 = torch.ops.aten.mm.default(view_1239, permute_551); view_1239 = permute_551 = None + view_1240 = torch.ops.aten.view.default(mm_312, [2, 8192, 14336]); mm_312 = None + convert_element_type_1402 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1402, 'avg', 64, '0'); convert_element_type_1402 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + mul_384 = torch.ops.aten.mul.Tensor(view_1240, convert_element_type_852); convert_element_type_852 = None + mul_385 = torch.ops.aten.mul.Tensor(view_1240, view_881); view_1240 = view_881 = None + view_1241 = torch.ops.aten.view.default(mul_384, [16384, 14336]); mul_384 = None + permute_553 = torch.ops.aten.permute.default(view_1241, [1, 0]) + mm_313 = torch.ops.aten.mm.default(permute_553, view_877); permute_553 = None + permute_555 = torch.ops.aten.permute.default(permute_284, [1, 0]); permute_284 = None + mm_314 = torch.ops.aten.mm.default(view_1241, permute_555); view_1241 = permute_555 = None + view_1242 = torch.ops.aten.view.default(mm_314, [2, 8192, 4096]); mm_314 = None + convert_element_type_1407 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1407, 'avg', 64, '0'); convert_element_type_1407 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + convert_element_type_1408 = torch.ops.prims.convert_element_type.default(mul_385, torch.float32); mul_385 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_851) + exp_6 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_171 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_171); add_171 = None + mul_386 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); reciprocal_6 = None + mul_387 = torch.ops.aten.mul.Tensor(convert_element_type_1408, mul_386); convert_element_type_1408 = None + sub_19 = torch.ops.aten.sub.Tensor(1, mul_386); mul_386 = None + mul_388 = torch.ops.aten.mul.Tensor(convert_element_type_851, sub_19); convert_element_type_851 = sub_19 = None + add_172 = torch.ops.aten.add.Tensor(mul_388, 1); mul_388 = None + mul_389 = torch.ops.aten.mul.Tensor(mul_387, add_172); mul_387 = add_172 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(mul_389, torch.bfloat16); mul_389 = None + view_1243 = torch.ops.aten.view.default(convert_element_type_1410, [16384, 14336]); convert_element_type_1410 = None + permute_557 = torch.ops.aten.permute.default(view_1243, [1, 0]) + mm_315 = torch.ops.aten.mm.default(permute_557, view_877); permute_557 = view_877 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16); primals_235 = None + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 64, '0'); convert_element_type_848 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + permute_559 = torch.ops.aten.permute.default(permute_283, [1, 0]); permute_283 = None + mm_316 = torch.ops.aten.mm.default(view_1243, permute_559); view_1243 = permute_559 = None + view_1244 = torch.ops.aten.view.default(mm_316, [2, 8192, 4096]); mm_316 = None + add_173 = torch.ops.aten.add.Tensor(view_1242, view_1244); view_1242 = view_1244 = None + convert_element_type_1415 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1415, 'avg', 64, '0'); convert_element_type_1415 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + convert_element_type_1416 = torch.ops.prims.convert_element_type.default(add_173, torch.float32); add_173 = None + convert_element_type_1418 = torch.ops.prims.convert_element_type.default(wait_tensor_231, torch.float32); wait_tensor_231 = None + mul_390 = torch.ops.aten.mul.Tensor(convert_element_type_1416, convert_element_type_1418); convert_element_type_1418 = None + mul_392 = torch.ops.aten.mul.Tensor(mul_204, mul_390) + sum_39 = torch.ops.aten.sum.dim_IntList(mul_392, [2], True); mul_392 = None + div_13 = torch.ops.aten.div.Tensor(mul_204, 4096) + mul_393 = torch.ops.aten.mul.Tensor(div_13, sum_39); div_13 = sum_39 = None + sub_20 = torch.ops.aten.sub.Tensor(mul_390, mul_393); mul_390 = mul_393 = None + mul_394 = torch.ops.aten.mul.Tensor(sub_20, rsqrt_51); sub_20 = rsqrt_51 = None + mul_395 = torch.ops.aten.mul.Tensor(convert_element_type_1416, mul_204); convert_element_type_1416 = mul_204 = None + sum_40 = torch.ops.aten.sum.dim_IntList(mul_395, [0, 1]); mul_395 = None + convert_element_type_1419 = torch.ops.prims.convert_element_type.default(mul_394, torch.bfloat16); mul_394 = None + add_174 = torch.ops.aten.add.Tensor(add_170, convert_element_type_1419); add_170 = convert_element_type_1419 = None + convert_element_type_default_52 = torch.ops.prims.convert_element_type.default(sum_40, torch.float32); sum_40 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_52, 'avg', 64, '0'); convert_element_type_default_52 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + view_1245 = torch.ops.aten.view.default(add_174, [16384, 4096]) + permute_561 = torch.ops.aten.permute.default(view_1245, [1, 0]) + mm_317 = torch.ops.aten.mm.default(permute_561, view_873); permute_561 = view_873 = None + permute_563 = torch.ops.aten.permute.default(permute_282, [1, 0]); permute_282 = None + mm_318 = torch.ops.aten.mm.default(view_1245, permute_563); view_1245 = permute_563 = None + view_1246 = torch.ops.aten.view.default(mm_318, [2, 8192, 4096]); mm_318 = None + convert_element_type_1426 = torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1426, 'avg', 64, '0'); convert_element_type_1426 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + view_1247 = torch.ops.aten.view.default(view_1246, [2, 8192, 32, 128]); view_1246 = None + permute_565 = torch.ops.aten.permute.default(view_1247, [0, 2, 1, 3]); view_1247 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 64, '0'); convert_element_type_826 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32); add_99 = None + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_226) + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + view_853 = torch.ops.aten.view.default(convert_element_type_828, [16384, 4096]); convert_element_type_828 = None + view_854 = torch.ops.aten.view.default(mm_175, [2, 8192, 4096]); mm_175 = None + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 64, '0'); convert_element_type_832 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + mm_176 = torch.ops.aten.mm.default(view_853, permute_276) + view_857 = torch.ops.aten.view.default(mm_176, [2, 8192, 1024]); mm_176 = None + view_860 = torch.ops.aten.view.default(mm_177, [2, 8192, 1024]); mm_177 = None + view_861 = torch.ops.aten.view.default(view_854, [2, 8192, -1, 128]); view_854 = None + view_862 = torch.ops.aten.view.default(view_857, [2, 8192, -1, 128]); view_857 = None + view_863 = torch.ops.aten.view.default(view_860, [2, 8192, -1, 128]); view_860 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_861, torch.float32); view_861 = None + view_864 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 32, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_864); view_864 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_862, torch.float32); view_862 = None + view_865 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 8, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_865); view_865 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_16); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_867 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 32, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_16); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_868 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 8, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_867, torch.bfloat16); view_867 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_868, torch.bfloat16); view_868 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 8, 4, 128]); unsqueeze_50 = None + clone_50 = torch.ops.aten.clone.default(expand_50, memory_format = torch.contiguous_format); expand_50 = None + view_869 = torch.ops.aten.view.default(clone_50, [2, 8192, 32, 128]); clone_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_863, 3); view_863 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 8, 4, 128]); unsqueeze_51 = None + clone_51 = torch.ops.aten.clone.default(expand_51, memory_format = torch.contiguous_format); expand_51 = None + view_870 = torch.ops.aten.view.default(clone_51, [2, 8192, 32, 128]); clone_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_869, [0, 2, 1, 3]); view_869 = None + permute_280 = torch.ops.aten.permute.default(view_870, [0, 2, 1, 3]); view_870 = None + _scaled_dot_product_cudnn_attention_backward_6 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_565, permute_278, permute_279, permute_280, getitem_225, getitem_226, getitem_231, getitem_232, None, None, None, 8192, 8192, 0.0, True); permute_565 = permute_278 = permute_279 = permute_280 = getitem_225 = getitem_226 = getitem_231 = getitem_232 = None + getitem_306 = _scaled_dot_product_cudnn_attention_backward_6[0] + getitem_307 = _scaled_dot_product_cudnn_attention_backward_6[1] + getitem_308 = _scaled_dot_product_cudnn_attention_backward_6[2]; _scaled_dot_product_cudnn_attention_backward_6 = None + permute_566 = torch.ops.aten.permute.default(getitem_308, [0, 2, 1, 3]); getitem_308 = None + permute_567 = torch.ops.aten.permute.default(getitem_307, [0, 2, 1, 3]); getitem_307 = None + permute_568 = torch.ops.aten.permute.default(getitem_306, [0, 2, 1, 3]); getitem_306 = None + view_1248 = torch.ops.aten.view.default(permute_566, [2, 8192, 8, 4, 128]); permute_566 = None + sum_41 = torch.ops.aten.sum.dim_IntList(view_1248, [3], True); view_1248 = None + squeeze_12 = torch.ops.aten.squeeze.dim(sum_41, 3); sum_41 = None + view_1249 = torch.ops.aten.view.default(permute_567, [2, 8192, 8, 4, 128]); permute_567 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_1249, [3], True); view_1249 = None + squeeze_13 = torch.ops.aten.squeeze.dim(sum_42, 3); sum_42 = None + convert_element_type_1427 = torch.ops.prims.convert_element_type.default(squeeze_13, torch.float32); squeeze_13 = None + convert_element_type_1428 = torch.ops.prims.convert_element_type.default(permute_568, torch.float32); permute_568 = None + view_1250 = torch.ops.aten.view.default(convert_element_type_1427, [2, 8192, 8, 64, 2]); convert_element_type_1427 = None + view_as_complex_76 = torch.ops.aten.view_as_complex.default(view_1250); view_1250 = None + mul_396 = torch.ops.aten.mul.Tensor(view_as_complex_76, _conj); view_as_complex_76 = None + view_1251 = torch.ops.aten.view.default(convert_element_type_1428, [2, 8192, 32, 64, 2]); convert_element_type_1428 = None + view_as_complex_77 = torch.ops.aten.view_as_complex.default(view_1251); view_1251 = None + mul_397 = torch.ops.aten.mul.Tensor(view_as_complex_77, _conj); view_as_complex_77 = None + view_as_real_76 = torch.ops.aten.view_as_real.default(mul_396); mul_396 = None + view_1252 = torch.ops.aten.view.default(view_as_real_76, [2, 8192, 8, 128]); view_as_real_76 = None + convert_element_type_1429 = torch.ops.prims.convert_element_type.default(view_1252, torch.bfloat16); view_1252 = None + view_as_real_77 = torch.ops.aten.view_as_real.default(mul_397); mul_397 = None + view_1253 = torch.ops.aten.view.default(view_as_real_77, [2, 8192, 32, 128]); view_as_real_77 = None + convert_element_type_1430 = torch.ops.prims.convert_element_type.default(view_1253, torch.bfloat16); view_1253 = None + view_1254 = torch.ops.aten.view.default(squeeze_12, [2, 8192, 1024]); squeeze_12 = None + view_1255 = torch.ops.aten.view.default(convert_element_type_1429, [2, 8192, 1024]); convert_element_type_1429 = None + view_1256 = torch.ops.aten.view.default(convert_element_type_1430, [2, 8192, 4096]); convert_element_type_1430 = None + view_1257 = torch.ops.aten.view.default(view_1254, [16384, 1024]); view_1254 = None + permute_569 = torch.ops.aten.permute.default(view_1257, [1, 0]) + mm_319 = torch.ops.aten.mm.default(permute_569, view_853); permute_569 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 64, '0'); convert_element_type_835 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_229, [1, 0]); wait_tensor_229 = None + permute_571 = torch.ops.aten.permute.default(permute_277, [1, 0]); permute_277 = None + mm_320 = torch.ops.aten.mm.default(view_1257, permute_571); view_1257 = permute_571 = None + view_1258 = torch.ops.aten.view.default(mm_320, [2, 8192, 4096]); mm_320 = None + convert_element_type_1435 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1435, 'avg', 64, '0'); convert_element_type_1435 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + view_1259 = torch.ops.aten.view.default(view_1255, [16384, 1024]); view_1255 = None + permute_573 = torch.ops.aten.permute.default(view_1259, [1, 0]) + mm_321 = torch.ops.aten.mm.default(permute_573, view_853); permute_573 = None + permute_575 = torch.ops.aten.permute.default(permute_276, [1, 0]); permute_276 = None + mm_322 = torch.ops.aten.mm.default(view_1259, permute_575); view_1259 = permute_575 = None + view_1260 = torch.ops.aten.view.default(mm_322, [2, 8192, 4096]); mm_322 = None + add_175 = torch.ops.aten.add.Tensor(view_1258, view_1260); view_1258 = view_1260 = None + convert_element_type_1440 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1440, 'avg', 64, '0'); convert_element_type_1440 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + view_1261 = torch.ops.aten.view.default(view_1256, [16384, 4096]); view_1256 = None + permute_577 = torch.ops.aten.permute.default(view_1261, [1, 0]) + mm_323 = torch.ops.aten.mm.default(permute_577, view_853); permute_577 = view_853 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 64, '0'); convert_element_type_829 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + permute_579 = torch.ops.aten.permute.default(permute_275, [1, 0]); permute_275 = None + mm_324 = torch.ops.aten.mm.default(view_1261, permute_579); view_1261 = permute_579 = None + view_1262 = torch.ops.aten.view.default(mm_324, [2, 8192, 4096]); mm_324 = None + add_176 = torch.ops.aten.add.Tensor(add_175, view_1262); add_175 = view_1262 = None + convert_element_type_1445 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1445, 'avg', 64, '0'); convert_element_type_1445 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + convert_element_type_1446 = torch.ops.prims.convert_element_type.default(add_176, torch.float32); add_176 = None + convert_element_type_1448 = torch.ops.prims.convert_element_type.default(wait_tensor_226, torch.float32); wait_tensor_226 = None + mul_398 = torch.ops.aten.mul.Tensor(convert_element_type_1446, convert_element_type_1448); convert_element_type_1448 = None + mul_400 = torch.ops.aten.mul.Tensor(mul_200, mul_398) + sum_43 = torch.ops.aten.sum.dim_IntList(mul_400, [2], True); mul_400 = None + div_14 = torch.ops.aten.div.Tensor(mul_200, 4096) + mul_401 = torch.ops.aten.mul.Tensor(div_14, sum_43); div_14 = sum_43 = None + sub_21 = torch.ops.aten.sub.Tensor(mul_398, mul_401); mul_398 = mul_401 = None + mul_402 = torch.ops.aten.mul.Tensor(sub_21, rsqrt_50); sub_21 = rsqrt_50 = None + mul_403 = torch.ops.aten.mul.Tensor(convert_element_type_1446, mul_200); convert_element_type_1446 = mul_200 = None + sum_44 = torch.ops.aten.sum.dim_IntList(mul_403, [0, 1]); mul_403 = None + convert_element_type_1449 = torch.ops.prims.convert_element_type.default(mul_402, torch.bfloat16); mul_402 = None + add_177 = torch.ops.aten.add.Tensor(add_174, convert_element_type_1449); add_174 = convert_element_type_1449 = None + convert_element_type_default_51 = torch.ops.prims.convert_element_type.default(sum_44, torch.float32); sum_44 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_51, 'avg', 64, '0'); convert_element_type_default_51 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + view_1263 = torch.ops.aten.view.default(add_177, [16384, 4096]) + permute_581 = torch.ops.aten.permute.default(view_1263, [1, 0]) + permute_270 = torch.ops.aten.permute.default(getitem_216, [0, 2, 1, 3]) + view_837 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16); primals_224 = None + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 64, '0'); convert_element_type_809 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_839 = torch.ops.aten.view.default(view_837, [16384, 4096]); view_837 = None + mm_171 = torch.ops.aten.mm.default(view_839, permute_271) + view_840 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + add_97 = torch.ops.aten.add.Tensor(add_95, view_840); view_840 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16); primals_225 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 64, '0'); convert_element_type_812 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32); add_97 = None + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_222) + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + view_843 = torch.ops.aten.view.default(convert_element_type_814, [16384, 4096]); convert_element_type_814 = None + view_844 = torch.ops.aten.view.default(mm_172, [2, 8192, 14336]); mm_172 = None + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_844, torch.float32); view_844 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16); primals_227 = None + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 64, '0'); convert_element_type_820 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_224, [1, 0]); wait_tensor_224 = None + mm_173 = torch.ops.aten.mm.default(view_843, permute_273) + view_847 = torch.ops.aten.view.default(mm_173, [2, 8192, 14336]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_847) + view_849 = torch.ops.aten.view.default(mul_199, [16384, 14336]); mul_199 = None + mm_325 = torch.ops.aten.mm.default(permute_581, view_849); permute_581 = view_849 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 64, '0'); convert_element_type_823 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + permute_583 = torch.ops.aten.permute.default(permute_274, [1, 0]); permute_274 = None + mm_326 = torch.ops.aten.mm.default(view_1263, permute_583); view_1263 = permute_583 = None + view_1264 = torch.ops.aten.view.default(mm_326, [2, 8192, 14336]); mm_326 = None + convert_element_type_1456 = torch.ops.prims.convert_element_type.default(mm_325, torch.float32); mm_325 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1456, 'avg', 64, '0'); convert_element_type_1456 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); reduce_scatter_tensor_65 = None + mul_404 = torch.ops.aten.mul.Tensor(view_1264, convert_element_type_819); convert_element_type_819 = None + mul_405 = torch.ops.aten.mul.Tensor(view_1264, view_847); view_1264 = view_847 = None + view_1265 = torch.ops.aten.view.default(mul_404, [16384, 14336]); mul_404 = None + permute_585 = torch.ops.aten.permute.default(view_1265, [1, 0]) + mm_327 = torch.ops.aten.mm.default(permute_585, view_843); permute_585 = None + permute_587 = torch.ops.aten.permute.default(permute_273, [1, 0]); permute_273 = None + mm_328 = torch.ops.aten.mm.default(view_1265, permute_587); view_1265 = permute_587 = None + view_1266 = torch.ops.aten.view.default(mm_328, [2, 8192, 4096]); mm_328 = None + convert_element_type_1461 = torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None + reduce_scatter_tensor_66 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1461, 'avg', 64, '0'); convert_element_type_1461 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + convert_element_type_1462 = torch.ops.prims.convert_element_type.default(mul_405, torch.float32); mul_405 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_818) + exp_7 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_178 = torch.ops.aten.add.Tensor(exp_7, 1); exp_7 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_178); add_178 = None + mul_406 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_1462, mul_406); convert_element_type_1462 = None + sub_22 = torch.ops.aten.sub.Tensor(1, mul_406); mul_406 = None + mul_408 = torch.ops.aten.mul.Tensor(convert_element_type_818, sub_22); convert_element_type_818 = sub_22 = None + add_179 = torch.ops.aten.add.Tensor(mul_408, 1); mul_408 = None + mul_409 = torch.ops.aten.mul.Tensor(mul_407, add_179); mul_407 = add_179 = None + convert_element_type_1464 = torch.ops.prims.convert_element_type.default(mul_409, torch.bfloat16); mul_409 = None + view_1267 = torch.ops.aten.view.default(convert_element_type_1464, [16384, 14336]); convert_element_type_1464 = None + permute_589 = torch.ops.aten.permute.default(view_1267, [1, 0]) + mm_329 = torch.ops.aten.mm.default(permute_589, view_843); permute_589 = view_843 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16); primals_226 = None + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 64, '0'); convert_element_type_815 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + permute_591 = torch.ops.aten.permute.default(permute_272, [1, 0]); permute_272 = None + mm_330 = torch.ops.aten.mm.default(view_1267, permute_591); view_1267 = permute_591 = None + view_1268 = torch.ops.aten.view.default(mm_330, [2, 8192, 4096]); mm_330 = None + add_180 = torch.ops.aten.add.Tensor(view_1266, view_1268); view_1266 = view_1268 = None + convert_element_type_1469 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1469, 'avg', 64, '0'); convert_element_type_1469 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + convert_element_type_1470 = torch.ops.prims.convert_element_type.default(add_180, torch.float32); add_180 = None + convert_element_type_1472 = torch.ops.prims.convert_element_type.default(wait_tensor_222, torch.float32); wait_tensor_222 = None + mul_410 = torch.ops.aten.mul.Tensor(convert_element_type_1470, convert_element_type_1472); convert_element_type_1472 = None + mul_412 = torch.ops.aten.mul.Tensor(mul_196, mul_410) + sum_45 = torch.ops.aten.sum.dim_IntList(mul_412, [2], True); mul_412 = None + div_15 = torch.ops.aten.div.Tensor(mul_196, 4096) + mul_413 = torch.ops.aten.mul.Tensor(div_15, sum_45); div_15 = sum_45 = None + sub_23 = torch.ops.aten.sub.Tensor(mul_410, mul_413); mul_410 = mul_413 = None + mul_414 = torch.ops.aten.mul.Tensor(sub_23, rsqrt_49); sub_23 = rsqrt_49 = None + mul_415 = torch.ops.aten.mul.Tensor(convert_element_type_1470, mul_196); convert_element_type_1470 = mul_196 = None + sum_46 = torch.ops.aten.sum.dim_IntList(mul_415, [0, 1]); mul_415 = None + convert_element_type_1473 = torch.ops.prims.convert_element_type.default(mul_414, torch.bfloat16); mul_414 = None + add_181 = torch.ops.aten.add.Tensor(add_177, convert_element_type_1473); add_177 = convert_element_type_1473 = None + convert_element_type_default_50 = torch.ops.prims.convert_element_type.default(sum_46, torch.float32); sum_46 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_50, 'avg', 64, '0'); convert_element_type_default_50 = None + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + view_1269 = torch.ops.aten.view.default(add_181, [16384, 4096]) + permute_593 = torch.ops.aten.permute.default(view_1269, [1, 0]) + mm_331 = torch.ops.aten.mm.default(permute_593, view_839); permute_593 = view_839 = None + permute_595 = torch.ops.aten.permute.default(permute_271, [1, 0]); permute_271 = None + mm_332 = torch.ops.aten.mm.default(view_1269, permute_595); view_1269 = permute_595 = None + view_1270 = torch.ops.aten.view.default(mm_332, [2, 8192, 4096]); mm_332 = None + convert_element_type_1480 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1480, 'avg', 64, '0'); convert_element_type_1480 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + view_1271 = torch.ops.aten.view.default(view_1270, [2, 8192, 32, 128]); view_1270 = None + permute_597 = torch.ops.aten.permute.default(view_1271, [0, 2, 1, 3]); view_1271 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16); primals_220 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 64, '0'); convert_element_type_793 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32); add_95 = None + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_217) + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + view_819 = torch.ops.aten.view.default(convert_element_type_795, [16384, 4096]); convert_element_type_795 = None + view_820 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]); mm_168 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16); primals_222 = None + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 64, '0'); convert_element_type_799 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_169 = torch.ops.aten.mm.default(view_819, permute_265) + view_823 = torch.ops.aten.view.default(mm_169, [2, 8192, 1024]); mm_169 = None + view_826 = torch.ops.aten.view.default(mm_170, [2, 8192, 1024]); mm_170 = None + view_827 = torch.ops.aten.view.default(view_820, [2, 8192, -1, 128]); view_820 = None + view_828 = torch.ops.aten.view.default(view_823, [2, 8192, -1, 128]); view_823 = None + view_829 = torch.ops.aten.view.default(view_826, [2, 8192, -1, 128]); view_826 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_827, torch.float32); view_827 = None + view_830 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 32, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_830); view_830 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_828, torch.float32); view_828 = None + view_831 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 8, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_831); view_831 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_16); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_833 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 32, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_16); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_834 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 8, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_833, torch.bfloat16); view_833 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_834, torch.bfloat16); view_834 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 8, 4, 128]); unsqueeze_48 = None + clone_48 = torch.ops.aten.clone.default(expand_48, memory_format = torch.contiguous_format); expand_48 = None + view_835 = torch.ops.aten.view.default(clone_48, [2, 8192, 32, 128]); clone_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_829, 3); view_829 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 8, 4, 128]); unsqueeze_49 = None + clone_49 = torch.ops.aten.clone.default(expand_49, memory_format = torch.contiguous_format); expand_49 = None + view_836 = torch.ops.aten.view.default(clone_49, [2, 8192, 32, 128]); clone_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_835, [0, 2, 1, 3]); view_835 = None + permute_269 = torch.ops.aten.permute.default(view_836, [0, 2, 1, 3]); view_836 = None + _scaled_dot_product_cudnn_attention_backward_7 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_597, permute_267, permute_268, permute_269, getitem_216, getitem_217, getitem_222, getitem_223, None, None, None, 8192, 8192, 0.0, True); permute_597 = permute_267 = permute_268 = permute_269 = getitem_216 = getitem_217 = getitem_222 = getitem_223 = None + getitem_309 = _scaled_dot_product_cudnn_attention_backward_7[0] + getitem_310 = _scaled_dot_product_cudnn_attention_backward_7[1] + getitem_311 = _scaled_dot_product_cudnn_attention_backward_7[2]; _scaled_dot_product_cudnn_attention_backward_7 = None + permute_598 = torch.ops.aten.permute.default(getitem_311, [0, 2, 1, 3]); getitem_311 = None + permute_599 = torch.ops.aten.permute.default(getitem_310, [0, 2, 1, 3]); getitem_310 = None + permute_600 = torch.ops.aten.permute.default(getitem_309, [0, 2, 1, 3]); getitem_309 = None + view_1272 = torch.ops.aten.view.default(permute_598, [2, 8192, 8, 4, 128]); permute_598 = None + sum_47 = torch.ops.aten.sum.dim_IntList(view_1272, [3], True); view_1272 = None + squeeze_14 = torch.ops.aten.squeeze.dim(sum_47, 3); sum_47 = None + view_1273 = torch.ops.aten.view.default(permute_599, [2, 8192, 8, 4, 128]); permute_599 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_1273, [3], True); view_1273 = None + squeeze_15 = torch.ops.aten.squeeze.dim(sum_48, 3); sum_48 = None + convert_element_type_1481 = torch.ops.prims.convert_element_type.default(squeeze_15, torch.float32); squeeze_15 = None + convert_element_type_1482 = torch.ops.prims.convert_element_type.default(permute_600, torch.float32); permute_600 = None + view_1274 = torch.ops.aten.view.default(convert_element_type_1481, [2, 8192, 8, 64, 2]); convert_element_type_1481 = None + view_as_complex_78 = torch.ops.aten.view_as_complex.default(view_1274); view_1274 = None + mul_416 = torch.ops.aten.mul.Tensor(view_as_complex_78, _conj); view_as_complex_78 = None + view_1275 = torch.ops.aten.view.default(convert_element_type_1482, [2, 8192, 32, 64, 2]); convert_element_type_1482 = None + view_as_complex_79 = torch.ops.aten.view_as_complex.default(view_1275); view_1275 = None + mul_417 = torch.ops.aten.mul.Tensor(view_as_complex_79, _conj); view_as_complex_79 = None + view_as_real_78 = torch.ops.aten.view_as_real.default(mul_416); mul_416 = None + view_1276 = torch.ops.aten.view.default(view_as_real_78, [2, 8192, 8, 128]); view_as_real_78 = None + convert_element_type_1483 = torch.ops.prims.convert_element_type.default(view_1276, torch.bfloat16); view_1276 = None + view_as_real_79 = torch.ops.aten.view_as_real.default(mul_417); mul_417 = None + view_1277 = torch.ops.aten.view.default(view_as_real_79, [2, 8192, 32, 128]); view_as_real_79 = None + convert_element_type_1484 = torch.ops.prims.convert_element_type.default(view_1277, torch.bfloat16); view_1277 = None + view_1278 = torch.ops.aten.view.default(squeeze_14, [2, 8192, 1024]); squeeze_14 = None + view_1279 = torch.ops.aten.view.default(convert_element_type_1483, [2, 8192, 1024]); convert_element_type_1483 = None + view_1280 = torch.ops.aten.view.default(convert_element_type_1484, [2, 8192, 4096]); convert_element_type_1484 = None + view_1281 = torch.ops.aten.view.default(view_1278, [16384, 1024]); view_1278 = None + permute_601 = torch.ops.aten.permute.default(view_1281, [1, 0]) + mm_333 = torch.ops.aten.mm.default(permute_601, view_819); permute_601 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16); primals_223 = None + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 64, '0'); convert_element_type_802 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + permute_603 = torch.ops.aten.permute.default(permute_266, [1, 0]); permute_266 = None + mm_334 = torch.ops.aten.mm.default(view_1281, permute_603); view_1281 = permute_603 = None + view_1282 = torch.ops.aten.view.default(mm_334, [2, 8192, 4096]); mm_334 = None + convert_element_type_1489 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None + reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1489, 'avg', 64, '0'); convert_element_type_1489 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + view_1283 = torch.ops.aten.view.default(view_1279, [16384, 1024]); view_1279 = None + permute_605 = torch.ops.aten.permute.default(view_1283, [1, 0]) + mm_335 = torch.ops.aten.mm.default(permute_605, view_819); permute_605 = None + permute_607 = torch.ops.aten.permute.default(permute_265, [1, 0]); permute_265 = None + mm_336 = torch.ops.aten.mm.default(view_1283, permute_607); view_1283 = permute_607 = None + view_1284 = torch.ops.aten.view.default(mm_336, [2, 8192, 4096]); mm_336 = None + add_182 = torch.ops.aten.add.Tensor(view_1282, view_1284); view_1282 = view_1284 = None + convert_element_type_1494 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None + reduce_scatter_tensor_71 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1494, 'avg', 64, '0'); convert_element_type_1494 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + view_1285 = torch.ops.aten.view.default(view_1280, [16384, 4096]); view_1280 = None + permute_609 = torch.ops.aten.permute.default(view_1285, [1, 0]) + mm_337 = torch.ops.aten.mm.default(permute_609, view_819); permute_609 = view_819 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16); primals_221 = None + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 64, '0'); convert_element_type_796 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + permute_611 = torch.ops.aten.permute.default(permute_264, [1, 0]); permute_264 = None + mm_338 = torch.ops.aten.mm.default(view_1285, permute_611); view_1285 = permute_611 = None + view_1286 = torch.ops.aten.view.default(mm_338, [2, 8192, 4096]); mm_338 = None + add_183 = torch.ops.aten.add.Tensor(add_182, view_1286); add_182 = view_1286 = None + convert_element_type_1499 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1499, 'avg', 64, '0'); convert_element_type_1499 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + convert_element_type_1500 = torch.ops.prims.convert_element_type.default(add_183, torch.float32); add_183 = None + convert_element_type_1502 = torch.ops.prims.convert_element_type.default(wait_tensor_217, torch.float32); wait_tensor_217 = None + mul_418 = torch.ops.aten.mul.Tensor(convert_element_type_1500, convert_element_type_1502); convert_element_type_1502 = None + mul_420 = torch.ops.aten.mul.Tensor(mul_192, mul_418) + sum_49 = torch.ops.aten.sum.dim_IntList(mul_420, [2], True); mul_420 = None + div_16 = torch.ops.aten.div.Tensor(mul_192, 4096) + mul_421 = torch.ops.aten.mul.Tensor(div_16, sum_49); div_16 = sum_49 = None + sub_24 = torch.ops.aten.sub.Tensor(mul_418, mul_421); mul_418 = mul_421 = None + mul_422 = torch.ops.aten.mul.Tensor(sub_24, rsqrt_48); sub_24 = rsqrt_48 = None + mul_423 = torch.ops.aten.mul.Tensor(convert_element_type_1500, mul_192); convert_element_type_1500 = mul_192 = None + sum_50 = torch.ops.aten.sum.dim_IntList(mul_423, [0, 1]); mul_423 = None + convert_element_type_1503 = torch.ops.prims.convert_element_type.default(mul_422, torch.bfloat16); mul_422 = None + add_184 = torch.ops.aten.add.Tensor(add_181, convert_element_type_1503); add_181 = convert_element_type_1503 = None + convert_element_type_default_49 = torch.ops.prims.convert_element_type.default(sum_50, torch.float32); sum_50 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_49, 'avg', 64, '0'); convert_element_type_default_49 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + view_1287 = torch.ops.aten.view.default(add_184, [16384, 4096]) + permute_613 = torch.ops.aten.permute.default(view_1287, [1, 0]) + permute_259 = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]) + view_803 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 64, '0'); convert_element_type_776 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_805 = torch.ops.aten.view.default(view_803, [16384, 4096]); view_803 = None + mm_164 = torch.ops.aten.mm.default(view_805, permute_260) + view_806 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + add_93 = torch.ops.aten.add.Tensor(add_91, view_806); view_806 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 64, '0'); convert_element_type_779 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32); add_93 = None + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_213) + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + view_809 = torch.ops.aten.view.default(convert_element_type_781, [16384, 4096]); convert_element_type_781 = None + view_810 = torch.ops.aten.view.default(mm_165, [2, 8192, 14336]); mm_165 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_810, torch.float32); view_810 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16); primals_218 = None + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 64, '0'); convert_element_type_787 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + mm_166 = torch.ops.aten.mm.default(view_809, permute_262) + view_813 = torch.ops.aten.view.default(mm_166, [2, 8192, 14336]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_813) + view_815 = torch.ops.aten.view.default(mul_191, [16384, 14336]); mul_191 = None + mm_339 = torch.ops.aten.mm.default(permute_613, view_815); permute_613 = view_815 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16); primals_219 = None + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 64, '0'); convert_element_type_790 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_216, [1, 0]); wait_tensor_216 = None + permute_615 = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None + mm_340 = torch.ops.aten.mm.default(view_1287, permute_615); view_1287 = permute_615 = None + view_1288 = torch.ops.aten.view.default(mm_340, [2, 8192, 14336]); mm_340 = None + convert_element_type_1510 = torch.ops.prims.convert_element_type.default(mm_339, torch.float32); mm_339 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1510, 'avg', 64, '0'); convert_element_type_1510 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + mul_424 = torch.ops.aten.mul.Tensor(view_1288, convert_element_type_786); convert_element_type_786 = None + mul_425 = torch.ops.aten.mul.Tensor(view_1288, view_813); view_1288 = view_813 = None + view_1289 = torch.ops.aten.view.default(mul_424, [16384, 14336]); mul_424 = None + permute_617 = torch.ops.aten.permute.default(view_1289, [1, 0]) + mm_341 = torch.ops.aten.mm.default(permute_617, view_809); permute_617 = None + permute_619 = torch.ops.aten.permute.default(permute_262, [1, 0]); permute_262 = None + mm_342 = torch.ops.aten.mm.default(view_1289, permute_619); view_1289 = permute_619 = None + view_1290 = torch.ops.aten.view.default(mm_342, [2, 8192, 4096]); mm_342 = None + convert_element_type_1515 = torch.ops.prims.convert_element_type.default(mm_341, torch.float32); mm_341 = None + reduce_scatter_tensor_75 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1515, 'avg', 64, '0'); convert_element_type_1515 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + convert_element_type_1516 = torch.ops.prims.convert_element_type.default(mul_425, torch.float32); mul_425 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_785) + exp_8 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_185 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_185); add_185 = None + mul_426 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_1516, mul_426); convert_element_type_1516 = None + sub_25 = torch.ops.aten.sub.Tensor(1, mul_426); mul_426 = None + mul_428 = torch.ops.aten.mul.Tensor(convert_element_type_785, sub_25); convert_element_type_785 = sub_25 = None + add_186 = torch.ops.aten.add.Tensor(mul_428, 1); mul_428 = None + mul_429 = torch.ops.aten.mul.Tensor(mul_427, add_186); mul_427 = add_186 = None + convert_element_type_1518 = torch.ops.prims.convert_element_type.default(mul_429, torch.bfloat16); mul_429 = None + view_1291 = torch.ops.aten.view.default(convert_element_type_1518, [16384, 14336]); convert_element_type_1518 = None + permute_621 = torch.ops.aten.permute.default(view_1291, [1, 0]) + mm_343 = torch.ops.aten.mm.default(permute_621, view_809); permute_621 = view_809 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16); primals_217 = None + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 64, '0'); convert_element_type_782 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + permute_623 = torch.ops.aten.permute.default(permute_261, [1, 0]); permute_261 = None + mm_344 = torch.ops.aten.mm.default(view_1291, permute_623); view_1291 = permute_623 = None + view_1292 = torch.ops.aten.view.default(mm_344, [2, 8192, 4096]); mm_344 = None + add_187 = torch.ops.aten.add.Tensor(view_1290, view_1292); view_1290 = view_1292 = None + convert_element_type_1523 = torch.ops.prims.convert_element_type.default(mm_343, torch.float32); mm_343 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1523, 'avg', 64, '0'); convert_element_type_1523 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + convert_element_type_1524 = torch.ops.prims.convert_element_type.default(add_187, torch.float32); add_187 = None + convert_element_type_1526 = torch.ops.prims.convert_element_type.default(wait_tensor_213, torch.float32); wait_tensor_213 = None + mul_430 = torch.ops.aten.mul.Tensor(convert_element_type_1524, convert_element_type_1526); convert_element_type_1526 = None + mul_432 = torch.ops.aten.mul.Tensor(mul_188, mul_430) + sum_51 = torch.ops.aten.sum.dim_IntList(mul_432, [2], True); mul_432 = None + div_17 = torch.ops.aten.div.Tensor(mul_188, 4096) + mul_433 = torch.ops.aten.mul.Tensor(div_17, sum_51); div_17 = sum_51 = None + sub_26 = torch.ops.aten.sub.Tensor(mul_430, mul_433); mul_430 = mul_433 = None + mul_434 = torch.ops.aten.mul.Tensor(sub_26, rsqrt_47); sub_26 = rsqrt_47 = None + mul_435 = torch.ops.aten.mul.Tensor(convert_element_type_1524, mul_188); convert_element_type_1524 = mul_188 = None + sum_52 = torch.ops.aten.sum.dim_IntList(mul_435, [0, 1]); mul_435 = None + convert_element_type_1527 = torch.ops.prims.convert_element_type.default(mul_434, torch.bfloat16); mul_434 = None + add_188 = torch.ops.aten.add.Tensor(add_184, convert_element_type_1527); add_184 = convert_element_type_1527 = None + convert_element_type_default_48 = torch.ops.prims.convert_element_type.default(sum_52, torch.float32); sum_52 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_48, 'avg', 64, '0'); convert_element_type_default_48 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + view_1293 = torch.ops.aten.view.default(add_188, [16384, 4096]) + permute_625 = torch.ops.aten.permute.default(view_1293, [1, 0]) + mm_345 = torch.ops.aten.mm.default(permute_625, view_805); permute_625 = view_805 = None + permute_627 = torch.ops.aten.permute.default(permute_260, [1, 0]); permute_260 = None + mm_346 = torch.ops.aten.mm.default(view_1293, permute_627); view_1293 = permute_627 = None + view_1294 = torch.ops.aten.view.default(mm_346, [2, 8192, 4096]); mm_346 = None + convert_element_type_1534 = torch.ops.prims.convert_element_type.default(mm_345, torch.float32); mm_345 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1534, 'avg', 64, '0'); convert_element_type_1534 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + view_1295 = torch.ops.aten.view.default(view_1294, [2, 8192, 32, 128]); view_1294 = None + permute_629 = torch.ops.aten.permute.default(view_1295, [0, 2, 1, 3]); view_1295 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16); primals_211 = None + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 64, '0'); convert_element_type_760 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32); add_91 = None + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_208) + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + view_785 = torch.ops.aten.view.default(convert_element_type_762, [16384, 4096]); convert_element_type_762 = None + view_786 = torch.ops.aten.view.default(mm_161, [2, 8192, 4096]); mm_161 = None + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 64, '0'); convert_element_type_766 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_210, [1, 0]); wait_tensor_210 = None + mm_162 = torch.ops.aten.mm.default(view_785, permute_254) + view_789 = torch.ops.aten.view.default(mm_162, [2, 8192, 1024]); mm_162 = None + view_792 = torch.ops.aten.view.default(mm_163, [2, 8192, 1024]); mm_163 = None + view_793 = torch.ops.aten.view.default(view_786, [2, 8192, -1, 128]); view_786 = None + view_794 = torch.ops.aten.view.default(view_789, [2, 8192, -1, 128]); view_789 = None + view_795 = torch.ops.aten.view.default(view_792, [2, 8192, -1, 128]); view_792 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_793, torch.float32); view_793 = None + view_796 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 32, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_796); view_796 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_794, torch.float32); view_794 = None + view_797 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 8, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_797); view_797 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_16); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_799 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 32, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_16); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_800 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 8, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_799, torch.bfloat16); view_799 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_800, torch.bfloat16); view_800 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 8, 4, 128]); unsqueeze_46 = None + clone_46 = torch.ops.aten.clone.default(expand_46, memory_format = torch.contiguous_format); expand_46 = None + view_801 = torch.ops.aten.view.default(clone_46, [2, 8192, 32, 128]); clone_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_795, 3); view_795 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 8, 4, 128]); unsqueeze_47 = None + clone_47 = torch.ops.aten.clone.default(expand_47, memory_format = torch.contiguous_format); expand_47 = None + view_802 = torch.ops.aten.view.default(clone_47, [2, 8192, 32, 128]); clone_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_801, [0, 2, 1, 3]); view_801 = None + permute_258 = torch.ops.aten.permute.default(view_802, [0, 2, 1, 3]); view_802 = None + _scaled_dot_product_cudnn_attention_backward_8 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_629, permute_256, permute_257, permute_258, getitem_207, getitem_208, getitem_213, getitem_214, None, None, None, 8192, 8192, 0.0, True); permute_629 = permute_256 = permute_257 = permute_258 = getitem_207 = getitem_208 = getitem_213 = getitem_214 = None + getitem_312 = _scaled_dot_product_cudnn_attention_backward_8[0] + getitem_313 = _scaled_dot_product_cudnn_attention_backward_8[1] + getitem_314 = _scaled_dot_product_cudnn_attention_backward_8[2]; _scaled_dot_product_cudnn_attention_backward_8 = None + permute_630 = torch.ops.aten.permute.default(getitem_314, [0, 2, 1, 3]); getitem_314 = None + permute_631 = torch.ops.aten.permute.default(getitem_313, [0, 2, 1, 3]); getitem_313 = None + permute_632 = torch.ops.aten.permute.default(getitem_312, [0, 2, 1, 3]); getitem_312 = None + view_1296 = torch.ops.aten.view.default(permute_630, [2, 8192, 8, 4, 128]); permute_630 = None + sum_53 = torch.ops.aten.sum.dim_IntList(view_1296, [3], True); view_1296 = None + squeeze_16 = torch.ops.aten.squeeze.dim(sum_53, 3); sum_53 = None + view_1297 = torch.ops.aten.view.default(permute_631, [2, 8192, 8, 4, 128]); permute_631 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_1297, [3], True); view_1297 = None + squeeze_17 = torch.ops.aten.squeeze.dim(sum_54, 3); sum_54 = None + convert_element_type_1535 = torch.ops.prims.convert_element_type.default(squeeze_17, torch.float32); squeeze_17 = None + convert_element_type_1536 = torch.ops.prims.convert_element_type.default(permute_632, torch.float32); permute_632 = None + view_1298 = torch.ops.aten.view.default(convert_element_type_1535, [2, 8192, 8, 64, 2]); convert_element_type_1535 = None + view_as_complex_80 = torch.ops.aten.view_as_complex.default(view_1298); view_1298 = None + mul_436 = torch.ops.aten.mul.Tensor(view_as_complex_80, _conj); view_as_complex_80 = None + view_1299 = torch.ops.aten.view.default(convert_element_type_1536, [2, 8192, 32, 64, 2]); convert_element_type_1536 = None + view_as_complex_81 = torch.ops.aten.view_as_complex.default(view_1299); view_1299 = None + mul_437 = torch.ops.aten.mul.Tensor(view_as_complex_81, _conj); view_as_complex_81 = None + view_as_real_80 = torch.ops.aten.view_as_real.default(mul_436); mul_436 = None + view_1300 = torch.ops.aten.view.default(view_as_real_80, [2, 8192, 8, 128]); view_as_real_80 = None + convert_element_type_1537 = torch.ops.prims.convert_element_type.default(view_1300, torch.bfloat16); view_1300 = None + view_as_real_81 = torch.ops.aten.view_as_real.default(mul_437); mul_437 = None + view_1301 = torch.ops.aten.view.default(view_as_real_81, [2, 8192, 32, 128]); view_as_real_81 = None + convert_element_type_1538 = torch.ops.prims.convert_element_type.default(view_1301, torch.bfloat16); view_1301 = None + view_1302 = torch.ops.aten.view.default(squeeze_16, [2, 8192, 1024]); squeeze_16 = None + view_1303 = torch.ops.aten.view.default(convert_element_type_1537, [2, 8192, 1024]); convert_element_type_1537 = None + view_1304 = torch.ops.aten.view.default(convert_element_type_1538, [2, 8192, 4096]); convert_element_type_1538 = None + view_1305 = torch.ops.aten.view.default(view_1302, [16384, 1024]); view_1302 = None + permute_633 = torch.ops.aten.permute.default(view_1305, [1, 0]) + mm_347 = torch.ops.aten.mm.default(permute_633, view_785); permute_633 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 64, '0'); convert_element_type_769 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_211, [1, 0]); wait_tensor_211 = None + permute_635 = torch.ops.aten.permute.default(permute_255, [1, 0]); permute_255 = None + mm_348 = torch.ops.aten.mm.default(view_1305, permute_635); view_1305 = permute_635 = None + view_1306 = torch.ops.aten.view.default(mm_348, [2, 8192, 4096]); mm_348 = None + convert_element_type_1543 = torch.ops.prims.convert_element_type.default(mm_347, torch.float32); mm_347 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1543, 'avg', 64, '0'); convert_element_type_1543 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + view_1307 = torch.ops.aten.view.default(view_1303, [16384, 1024]); view_1303 = None + permute_637 = torch.ops.aten.permute.default(view_1307, [1, 0]) + mm_349 = torch.ops.aten.mm.default(permute_637, view_785); permute_637 = None + permute_639 = torch.ops.aten.permute.default(permute_254, [1, 0]); permute_254 = None + mm_350 = torch.ops.aten.mm.default(view_1307, permute_639); view_1307 = permute_639 = None + view_1308 = torch.ops.aten.view.default(mm_350, [2, 8192, 4096]); mm_350 = None + add_189 = torch.ops.aten.add.Tensor(view_1306, view_1308); view_1306 = view_1308 = None + convert_element_type_1548 = torch.ops.prims.convert_element_type.default(mm_349, torch.float32); mm_349 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1548, 'avg', 64, '0'); convert_element_type_1548 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + view_1309 = torch.ops.aten.view.default(view_1304, [16384, 4096]); view_1304 = None + permute_641 = torch.ops.aten.permute.default(view_1309, [1, 0]) + mm_351 = torch.ops.aten.mm.default(permute_641, view_785); permute_641 = view_785 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 64, '0'); convert_element_type_763 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_209, [1, 0]); wait_tensor_209 = None + permute_643 = torch.ops.aten.permute.default(permute_253, [1, 0]); permute_253 = None + mm_352 = torch.ops.aten.mm.default(view_1309, permute_643); view_1309 = permute_643 = None + view_1310 = torch.ops.aten.view.default(mm_352, [2, 8192, 4096]); mm_352 = None + add_190 = torch.ops.aten.add.Tensor(add_189, view_1310); add_189 = view_1310 = None + convert_element_type_1553 = torch.ops.prims.convert_element_type.default(mm_351, torch.float32); mm_351 = None + reduce_scatter_tensor_81 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1553, 'avg', 64, '0'); convert_element_type_1553 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + convert_element_type_1554 = torch.ops.prims.convert_element_type.default(add_190, torch.float32); add_190 = None + convert_element_type_1556 = torch.ops.prims.convert_element_type.default(wait_tensor_208, torch.float32); wait_tensor_208 = None + mul_438 = torch.ops.aten.mul.Tensor(convert_element_type_1554, convert_element_type_1556); convert_element_type_1556 = None + mul_440 = torch.ops.aten.mul.Tensor(mul_184, mul_438) + sum_55 = torch.ops.aten.sum.dim_IntList(mul_440, [2], True); mul_440 = None + div_18 = torch.ops.aten.div.Tensor(mul_184, 4096) + mul_441 = torch.ops.aten.mul.Tensor(div_18, sum_55); div_18 = sum_55 = None + sub_27 = torch.ops.aten.sub.Tensor(mul_438, mul_441); mul_438 = mul_441 = None + mul_442 = torch.ops.aten.mul.Tensor(sub_27, rsqrt_46); sub_27 = rsqrt_46 = None + mul_443 = torch.ops.aten.mul.Tensor(convert_element_type_1554, mul_184); convert_element_type_1554 = mul_184 = None + sum_56 = torch.ops.aten.sum.dim_IntList(mul_443, [0, 1]); mul_443 = None + convert_element_type_1557 = torch.ops.prims.convert_element_type.default(mul_442, torch.bfloat16); mul_442 = None + add_191 = torch.ops.aten.add.Tensor(add_188, convert_element_type_1557); add_188 = convert_element_type_1557 = None + convert_element_type_default_47 = torch.ops.prims.convert_element_type.default(sum_56, torch.float32); sum_56 = None + reduce_scatter_tensor_82 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_47, 'avg', 64, '0'); convert_element_type_default_47 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + view_1311 = torch.ops.aten.view.default(add_191, [16384, 4096]) + permute_645 = torch.ops.aten.permute.default(view_1311, [1, 0]) + permute_248 = torch.ops.aten.permute.default(getitem_198, [0, 2, 1, 3]) + view_769 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16); primals_206 = None + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 64, '0'); convert_element_type_743 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_203, [1, 0]); wait_tensor_203 = None + view_771 = torch.ops.aten.view.default(view_769, [16384, 4096]); view_769 = None + mm_157 = torch.ops.aten.mm.default(view_771, permute_249) + view_772 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + add_89 = torch.ops.aten.add.Tensor(add_87, view_772); view_772 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16); primals_207 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 64, '0'); convert_element_type_746 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32); add_89 = None + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_204) + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + view_775 = torch.ops.aten.view.default(convert_element_type_748, [16384, 4096]); convert_element_type_748 = None + view_776 = torch.ops.aten.view.default(mm_158, [2, 8192, 14336]); mm_158 = None + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_776, torch.float32); view_776 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16); primals_209 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 64, '0'); convert_element_type_754 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + mm_159 = torch.ops.aten.mm.default(view_775, permute_251) + view_779 = torch.ops.aten.view.default(mm_159, [2, 8192, 14336]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_779) + view_781 = torch.ops.aten.view.default(mul_183, [16384, 14336]); mul_183 = None + mm_353 = torch.ops.aten.mm.default(permute_645, view_781); permute_645 = view_781 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16); primals_210 = None + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 64, '0'); convert_element_type_757 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + permute_647 = torch.ops.aten.permute.default(permute_252, [1, 0]); permute_252 = None + mm_354 = torch.ops.aten.mm.default(view_1311, permute_647); view_1311 = permute_647 = None + view_1312 = torch.ops.aten.view.default(mm_354, [2, 8192, 14336]); mm_354 = None + convert_element_type_1564 = torch.ops.prims.convert_element_type.default(mm_353, torch.float32); mm_353 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1564, 'avg', 64, '0'); convert_element_type_1564 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + mul_444 = torch.ops.aten.mul.Tensor(view_1312, convert_element_type_753); convert_element_type_753 = None + mul_445 = torch.ops.aten.mul.Tensor(view_1312, view_779); view_1312 = view_779 = None + view_1313 = torch.ops.aten.view.default(mul_444, [16384, 14336]); mul_444 = None + permute_649 = torch.ops.aten.permute.default(view_1313, [1, 0]) + mm_355 = torch.ops.aten.mm.default(permute_649, view_775); permute_649 = None + permute_651 = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None + mm_356 = torch.ops.aten.mm.default(view_1313, permute_651); view_1313 = permute_651 = None + view_1314 = torch.ops.aten.view.default(mm_356, [2, 8192, 4096]); mm_356 = None + convert_element_type_1569 = torch.ops.prims.convert_element_type.default(mm_355, torch.float32); mm_355 = None + reduce_scatter_tensor_84 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1569, 'avg', 64, '0'); convert_element_type_1569 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + convert_element_type_1570 = torch.ops.prims.convert_element_type.default(mul_445, torch.float32); mul_445 = None + neg_9 = torch.ops.aten.neg.default(convert_element_type_752) + exp_9 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_192 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_192); add_192 = None + mul_446 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_1570, mul_446); convert_element_type_1570 = None + sub_28 = torch.ops.aten.sub.Tensor(1, mul_446); mul_446 = None + mul_448 = torch.ops.aten.mul.Tensor(convert_element_type_752, sub_28); convert_element_type_752 = sub_28 = None + add_193 = torch.ops.aten.add.Tensor(mul_448, 1); mul_448 = None + mul_449 = torch.ops.aten.mul.Tensor(mul_447, add_193); mul_447 = add_193 = None + convert_element_type_1572 = torch.ops.prims.convert_element_type.default(mul_449, torch.bfloat16); mul_449 = None + view_1315 = torch.ops.aten.view.default(convert_element_type_1572, [16384, 14336]); convert_element_type_1572 = None + permute_653 = torch.ops.aten.permute.default(view_1315, [1, 0]) + mm_357 = torch.ops.aten.mm.default(permute_653, view_775); permute_653 = view_775 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16); primals_208 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 64, '0'); convert_element_type_749 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + permute_655 = torch.ops.aten.permute.default(permute_250, [1, 0]); permute_250 = None + mm_358 = torch.ops.aten.mm.default(view_1315, permute_655); view_1315 = permute_655 = None + view_1316 = torch.ops.aten.view.default(mm_358, [2, 8192, 4096]); mm_358 = None + add_194 = torch.ops.aten.add.Tensor(view_1314, view_1316); view_1314 = view_1316 = None + convert_element_type_1577 = torch.ops.prims.convert_element_type.default(mm_357, torch.float32); mm_357 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1577, 'avg', 64, '0'); convert_element_type_1577 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); reduce_scatter_tensor_85 = None + convert_element_type_1578 = torch.ops.prims.convert_element_type.default(add_194, torch.float32); add_194 = None + convert_element_type_1580 = torch.ops.prims.convert_element_type.default(wait_tensor_204, torch.float32); wait_tensor_204 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_1578, convert_element_type_1580); convert_element_type_1580 = None + mul_452 = torch.ops.aten.mul.Tensor(mul_180, mul_450) + sum_57 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True); mul_452 = None + div_19 = torch.ops.aten.div.Tensor(mul_180, 4096) + mul_453 = torch.ops.aten.mul.Tensor(div_19, sum_57); div_19 = sum_57 = None + sub_29 = torch.ops.aten.sub.Tensor(mul_450, mul_453); mul_450 = mul_453 = None + mul_454 = torch.ops.aten.mul.Tensor(sub_29, rsqrt_45); sub_29 = rsqrt_45 = None + mul_455 = torch.ops.aten.mul.Tensor(convert_element_type_1578, mul_180); convert_element_type_1578 = mul_180 = None + sum_58 = torch.ops.aten.sum.dim_IntList(mul_455, [0, 1]); mul_455 = None + convert_element_type_1581 = torch.ops.prims.convert_element_type.default(mul_454, torch.bfloat16); mul_454 = None + add_195 = torch.ops.aten.add.Tensor(add_191, convert_element_type_1581); add_191 = convert_element_type_1581 = None + convert_element_type_default_46 = torch.ops.prims.convert_element_type.default(sum_58, torch.float32); sum_58 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_46, 'avg', 64, '0'); convert_element_type_default_46 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + view_1317 = torch.ops.aten.view.default(add_195, [16384, 4096]) + permute_657 = torch.ops.aten.permute.default(view_1317, [1, 0]) + mm_359 = torch.ops.aten.mm.default(permute_657, view_771); permute_657 = view_771 = None + permute_659 = torch.ops.aten.permute.default(permute_249, [1, 0]); permute_249 = None + mm_360 = torch.ops.aten.mm.default(view_1317, permute_659); view_1317 = permute_659 = None + view_1318 = torch.ops.aten.view.default(mm_360, [2, 8192, 4096]); mm_360 = None + convert_element_type_1588 = torch.ops.prims.convert_element_type.default(mm_359, torch.float32); mm_359 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1588, 'avg', 64, '0'); convert_element_type_1588 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + view_1319 = torch.ops.aten.view.default(view_1318, [2, 8192, 32, 128]); view_1318 = None + permute_661 = torch.ops.aten.permute.default(view_1319, [0, 2, 1, 3]); view_1319 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16); primals_202 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 64, '0'); convert_element_type_727 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32); add_87 = None + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_199) + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + view_751 = torch.ops.aten.view.default(convert_element_type_729, [16384, 4096]); convert_element_type_729 = None + view_752 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]); mm_154 = None + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16); primals_204 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 64, '0'); convert_element_type_733 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_155 = torch.ops.aten.mm.default(view_751, permute_243) + view_755 = torch.ops.aten.view.default(mm_155, [2, 8192, 1024]); mm_155 = None + view_758 = torch.ops.aten.view.default(mm_156, [2, 8192, 1024]); mm_156 = None + view_759 = torch.ops.aten.view.default(view_752, [2, 8192, -1, 128]); view_752 = None + view_760 = torch.ops.aten.view.default(view_755, [2, 8192, -1, 128]); view_755 = None + view_761 = torch.ops.aten.view.default(view_758, [2, 8192, -1, 128]); view_758 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_759, torch.float32); view_759 = None + view_762 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 32, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_762); view_762 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_760, torch.float32); view_760 = None + view_763 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 8, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_763); view_763 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_16); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_765 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 32, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_16); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_766 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 8, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_765, torch.bfloat16); view_765 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_766, torch.bfloat16); view_766 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 8, 4, 128]); unsqueeze_44 = None + clone_44 = torch.ops.aten.clone.default(expand_44, memory_format = torch.contiguous_format); expand_44 = None + view_767 = torch.ops.aten.view.default(clone_44, [2, 8192, 32, 128]); clone_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_761, 3); view_761 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 8, 4, 128]); unsqueeze_45 = None + clone_45 = torch.ops.aten.clone.default(expand_45, memory_format = torch.contiguous_format); expand_45 = None + view_768 = torch.ops.aten.view.default(clone_45, [2, 8192, 32, 128]); clone_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_767, [0, 2, 1, 3]); view_767 = None + permute_247 = torch.ops.aten.permute.default(view_768, [0, 2, 1, 3]); view_768 = None + _scaled_dot_product_cudnn_attention_backward_9 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_661, permute_245, permute_246, permute_247, getitem_198, getitem_199, getitem_204, getitem_205, None, None, None, 8192, 8192, 0.0, True); permute_661 = permute_245 = permute_246 = permute_247 = getitem_198 = getitem_199 = getitem_204 = getitem_205 = None + getitem_315 = _scaled_dot_product_cudnn_attention_backward_9[0] + getitem_316 = _scaled_dot_product_cudnn_attention_backward_9[1] + getitem_317 = _scaled_dot_product_cudnn_attention_backward_9[2]; _scaled_dot_product_cudnn_attention_backward_9 = None + permute_662 = torch.ops.aten.permute.default(getitem_317, [0, 2, 1, 3]); getitem_317 = None + permute_663 = torch.ops.aten.permute.default(getitem_316, [0, 2, 1, 3]); getitem_316 = None + permute_664 = torch.ops.aten.permute.default(getitem_315, [0, 2, 1, 3]); getitem_315 = None + view_1320 = torch.ops.aten.view.default(permute_662, [2, 8192, 8, 4, 128]); permute_662 = None + sum_59 = torch.ops.aten.sum.dim_IntList(view_1320, [3], True); view_1320 = None + squeeze_18 = torch.ops.aten.squeeze.dim(sum_59, 3); sum_59 = None + view_1321 = torch.ops.aten.view.default(permute_663, [2, 8192, 8, 4, 128]); permute_663 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_1321, [3], True); view_1321 = None + squeeze_19 = torch.ops.aten.squeeze.dim(sum_60, 3); sum_60 = None + convert_element_type_1589 = torch.ops.prims.convert_element_type.default(squeeze_19, torch.float32); squeeze_19 = None + convert_element_type_1590 = torch.ops.prims.convert_element_type.default(permute_664, torch.float32); permute_664 = None + view_1322 = torch.ops.aten.view.default(convert_element_type_1589, [2, 8192, 8, 64, 2]); convert_element_type_1589 = None + view_as_complex_82 = torch.ops.aten.view_as_complex.default(view_1322); view_1322 = None + mul_456 = torch.ops.aten.mul.Tensor(view_as_complex_82, _conj); view_as_complex_82 = None + view_1323 = torch.ops.aten.view.default(convert_element_type_1590, [2, 8192, 32, 64, 2]); convert_element_type_1590 = None + view_as_complex_83 = torch.ops.aten.view_as_complex.default(view_1323); view_1323 = None + mul_457 = torch.ops.aten.mul.Tensor(view_as_complex_83, _conj); view_as_complex_83 = None + view_as_real_82 = torch.ops.aten.view_as_real.default(mul_456); mul_456 = None + view_1324 = torch.ops.aten.view.default(view_as_real_82, [2, 8192, 8, 128]); view_as_real_82 = None + convert_element_type_1591 = torch.ops.prims.convert_element_type.default(view_1324, torch.bfloat16); view_1324 = None + view_as_real_83 = torch.ops.aten.view_as_real.default(mul_457); mul_457 = None + view_1325 = torch.ops.aten.view.default(view_as_real_83, [2, 8192, 32, 128]); view_as_real_83 = None + convert_element_type_1592 = torch.ops.prims.convert_element_type.default(view_1325, torch.bfloat16); view_1325 = None + view_1326 = torch.ops.aten.view.default(squeeze_18, [2, 8192, 1024]); squeeze_18 = None + view_1327 = torch.ops.aten.view.default(convert_element_type_1591, [2, 8192, 1024]); convert_element_type_1591 = None + view_1328 = torch.ops.aten.view.default(convert_element_type_1592, [2, 8192, 4096]); convert_element_type_1592 = None + view_1329 = torch.ops.aten.view.default(view_1326, [16384, 1024]); view_1326 = None + permute_665 = torch.ops.aten.permute.default(view_1329, [1, 0]) + mm_361 = torch.ops.aten.mm.default(permute_665, view_751); permute_665 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16); primals_205 = None + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 64, '0'); convert_element_type_736 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + permute_667 = torch.ops.aten.permute.default(permute_244, [1, 0]); permute_244 = None + mm_362 = torch.ops.aten.mm.default(view_1329, permute_667); view_1329 = permute_667 = None + view_1330 = torch.ops.aten.view.default(mm_362, [2, 8192, 4096]); mm_362 = None + convert_element_type_1597 = torch.ops.prims.convert_element_type.default(mm_361, torch.float32); mm_361 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1597, 'avg', 64, '0'); convert_element_type_1597 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + view_1331 = torch.ops.aten.view.default(view_1327, [16384, 1024]); view_1327 = None + permute_669 = torch.ops.aten.permute.default(view_1331, [1, 0]) + mm_363 = torch.ops.aten.mm.default(permute_669, view_751); permute_669 = None + permute_671 = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None + mm_364 = torch.ops.aten.mm.default(view_1331, permute_671); view_1331 = permute_671 = None + view_1332 = torch.ops.aten.view.default(mm_364, [2, 8192, 4096]); mm_364 = None + add_196 = torch.ops.aten.add.Tensor(view_1330, view_1332); view_1330 = view_1332 = None + convert_element_type_1602 = torch.ops.prims.convert_element_type.default(mm_363, torch.float32); mm_363 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1602, 'avg', 64, '0'); convert_element_type_1602 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + view_1333 = torch.ops.aten.view.default(view_1328, [16384, 4096]); view_1328 = None + permute_673 = torch.ops.aten.permute.default(view_1333, [1, 0]) + mm_365 = torch.ops.aten.mm.default(permute_673, view_751); permute_673 = view_751 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16); primals_203 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 64, '0'); convert_element_type_730 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + permute_675 = torch.ops.aten.permute.default(permute_242, [1, 0]); permute_242 = None + mm_366 = torch.ops.aten.mm.default(view_1333, permute_675); view_1333 = permute_675 = None + view_1334 = torch.ops.aten.view.default(mm_366, [2, 8192, 4096]); mm_366 = None + add_197 = torch.ops.aten.add.Tensor(add_196, view_1334); add_196 = view_1334 = None + convert_element_type_1607 = torch.ops.prims.convert_element_type.default(mm_365, torch.float32); mm_365 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1607, 'avg', 64, '0'); convert_element_type_1607 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + convert_element_type_1608 = torch.ops.prims.convert_element_type.default(add_197, torch.float32); add_197 = None + convert_element_type_1610 = torch.ops.prims.convert_element_type.default(wait_tensor_199, torch.float32); wait_tensor_199 = None + mul_458 = torch.ops.aten.mul.Tensor(convert_element_type_1608, convert_element_type_1610); convert_element_type_1610 = None + mul_460 = torch.ops.aten.mul.Tensor(mul_176, mul_458) + sum_61 = torch.ops.aten.sum.dim_IntList(mul_460, [2], True); mul_460 = None + div_20 = torch.ops.aten.div.Tensor(mul_176, 4096) + mul_461 = torch.ops.aten.mul.Tensor(div_20, sum_61); div_20 = sum_61 = None + sub_30 = torch.ops.aten.sub.Tensor(mul_458, mul_461); mul_458 = mul_461 = None + mul_462 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_44); sub_30 = rsqrt_44 = None + mul_463 = torch.ops.aten.mul.Tensor(convert_element_type_1608, mul_176); convert_element_type_1608 = mul_176 = None + sum_62 = torch.ops.aten.sum.dim_IntList(mul_463, [0, 1]); mul_463 = None + convert_element_type_1611 = torch.ops.prims.convert_element_type.default(mul_462, torch.bfloat16); mul_462 = None + add_198 = torch.ops.aten.add.Tensor(add_195, convert_element_type_1611); add_195 = convert_element_type_1611 = None + convert_element_type_default_45 = torch.ops.prims.convert_element_type.default(sum_62, torch.float32); sum_62 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_45, 'avg', 64, '0'); convert_element_type_default_45 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + view_1335 = torch.ops.aten.view.default(add_198, [16384, 4096]) + permute_677 = torch.ops.aten.permute.default(view_1335, [1, 0]) + permute_237 = torch.ops.aten.permute.default(getitem_189, [0, 2, 1, 3]) + view_735 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 64, '0'); convert_element_type_710 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + view_737 = torch.ops.aten.view.default(view_735, [16384, 4096]); view_735 = None + mm_150 = torch.ops.aten.mm.default(view_737, permute_238) + view_738 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + add_85 = torch.ops.aten.add.Tensor(add_83, view_738); view_738 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 64, '0'); convert_element_type_713 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32); add_85 = None + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_195) + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + view_741 = torch.ops.aten.view.default(convert_element_type_715, [16384, 4096]); convert_element_type_715 = None + view_742 = torch.ops.aten.view.default(mm_151, [2, 8192, 14336]); mm_151 = None + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_742, torch.float32); view_742 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16); primals_200 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 64, '0'); convert_element_type_721 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_152 = torch.ops.aten.mm.default(view_741, permute_240) + view_745 = torch.ops.aten.view.default(mm_152, [2, 8192, 14336]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_745) + view_747 = torch.ops.aten.view.default(mul_175, [16384, 14336]); mul_175 = None + mm_367 = torch.ops.aten.mm.default(permute_677, view_747); permute_677 = view_747 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16); primals_201 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 64, '0'); convert_element_type_724 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + permute_679 = torch.ops.aten.permute.default(permute_241, [1, 0]); permute_241 = None + mm_368 = torch.ops.aten.mm.default(view_1335, permute_679); view_1335 = permute_679 = None + view_1336 = torch.ops.aten.view.default(mm_368, [2, 8192, 14336]); mm_368 = None + convert_element_type_1618 = torch.ops.prims.convert_element_type.default(mm_367, torch.float32); mm_367 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1618, 'avg', 64, '0'); convert_element_type_1618 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + mul_464 = torch.ops.aten.mul.Tensor(view_1336, convert_element_type_720); convert_element_type_720 = None + mul_465 = torch.ops.aten.mul.Tensor(view_1336, view_745); view_1336 = view_745 = None + view_1337 = torch.ops.aten.view.default(mul_464, [16384, 14336]); mul_464 = None + permute_681 = torch.ops.aten.permute.default(view_1337, [1, 0]) + mm_369 = torch.ops.aten.mm.default(permute_681, view_741); permute_681 = None + permute_683 = torch.ops.aten.permute.default(permute_240, [1, 0]); permute_240 = None + mm_370 = torch.ops.aten.mm.default(view_1337, permute_683); view_1337 = permute_683 = None + view_1338 = torch.ops.aten.view.default(mm_370, [2, 8192, 4096]); mm_370 = None + convert_element_type_1623 = torch.ops.prims.convert_element_type.default(mm_369, torch.float32); mm_369 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1623, 'avg', 64, '0'); convert_element_type_1623 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + convert_element_type_1624 = torch.ops.prims.convert_element_type.default(mul_465, torch.float32); mul_465 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_719) + exp_10 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_199 = torch.ops.aten.add.Tensor(exp_10, 1); exp_10 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_199); add_199 = None + mul_466 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_467 = torch.ops.aten.mul.Tensor(convert_element_type_1624, mul_466); convert_element_type_1624 = None + sub_31 = torch.ops.aten.sub.Tensor(1, mul_466); mul_466 = None + mul_468 = torch.ops.aten.mul.Tensor(convert_element_type_719, sub_31); convert_element_type_719 = sub_31 = None + add_200 = torch.ops.aten.add.Tensor(mul_468, 1); mul_468 = None + mul_469 = torch.ops.aten.mul.Tensor(mul_467, add_200); mul_467 = add_200 = None + convert_element_type_1626 = torch.ops.prims.convert_element_type.default(mul_469, torch.bfloat16); mul_469 = None + view_1339 = torch.ops.aten.view.default(convert_element_type_1626, [16384, 14336]); convert_element_type_1626 = None + permute_685 = torch.ops.aten.permute.default(view_1339, [1, 0]) + mm_371 = torch.ops.aten.mm.default(permute_685, view_741); permute_685 = view_741 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16); primals_199 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 64, '0'); convert_element_type_716 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_196, [1, 0]); wait_tensor_196 = None + permute_687 = torch.ops.aten.permute.default(permute_239, [1, 0]); permute_239 = None + mm_372 = torch.ops.aten.mm.default(view_1339, permute_687); view_1339 = permute_687 = None + view_1340 = torch.ops.aten.view.default(mm_372, [2, 8192, 4096]); mm_372 = None + add_201 = torch.ops.aten.add.Tensor(view_1338, view_1340); view_1338 = view_1340 = None + convert_element_type_1631 = torch.ops.prims.convert_element_type.default(mm_371, torch.float32); mm_371 = None + reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1631, 'avg', 64, '0'); convert_element_type_1631 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + convert_element_type_1632 = torch.ops.prims.convert_element_type.default(add_201, torch.float32); add_201 = None + convert_element_type_1634 = torch.ops.prims.convert_element_type.default(wait_tensor_195, torch.float32); wait_tensor_195 = None + mul_470 = torch.ops.aten.mul.Tensor(convert_element_type_1632, convert_element_type_1634); convert_element_type_1634 = None + mul_472 = torch.ops.aten.mul.Tensor(mul_172, mul_470) + sum_63 = torch.ops.aten.sum.dim_IntList(mul_472, [2], True); mul_472 = None + div_21 = torch.ops.aten.div.Tensor(mul_172, 4096) + mul_473 = torch.ops.aten.mul.Tensor(div_21, sum_63); div_21 = sum_63 = None + sub_32 = torch.ops.aten.sub.Tensor(mul_470, mul_473); mul_470 = mul_473 = None + mul_474 = torch.ops.aten.mul.Tensor(sub_32, rsqrt_43); sub_32 = rsqrt_43 = None + mul_475 = torch.ops.aten.mul.Tensor(convert_element_type_1632, mul_172); convert_element_type_1632 = mul_172 = None + sum_64 = torch.ops.aten.sum.dim_IntList(mul_475, [0, 1]); mul_475 = None + convert_element_type_1635 = torch.ops.prims.convert_element_type.default(mul_474, torch.bfloat16); mul_474 = None + add_202 = torch.ops.aten.add.Tensor(add_198, convert_element_type_1635); add_198 = convert_element_type_1635 = None + convert_element_type_default_44 = torch.ops.prims.convert_element_type.default(sum_64, torch.float32); sum_64 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_44, 'avg', 64, '0'); convert_element_type_default_44 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + view_1341 = torch.ops.aten.view.default(add_202, [16384, 4096]) + permute_689 = torch.ops.aten.permute.default(view_1341, [1, 0]) + mm_373 = torch.ops.aten.mm.default(permute_689, view_737); permute_689 = view_737 = None + permute_691 = torch.ops.aten.permute.default(permute_238, [1, 0]); permute_238 = None + mm_374 = torch.ops.aten.mm.default(view_1341, permute_691); view_1341 = permute_691 = None + view_1342 = torch.ops.aten.view.default(mm_374, [2, 8192, 4096]); mm_374 = None + convert_element_type_1642 = torch.ops.prims.convert_element_type.default(mm_373, torch.float32); mm_373 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1642, 'avg', 64, '0'); convert_element_type_1642 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + view_1343 = torch.ops.aten.view.default(view_1342, [2, 8192, 32, 128]); view_1342 = None + permute_693 = torch.ops.aten.permute.default(view_1343, [0, 2, 1, 3]); view_1343 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16); primals_193 = None + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 64, '0'); convert_element_type_694 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32); add_83 = None + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_190) + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + view_717 = torch.ops.aten.view.default(convert_element_type_696, [16384, 4096]); convert_element_type_696 = None + view_718 = torch.ops.aten.view.default(mm_147, [2, 8192, 4096]); mm_147 = None + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 64, '0'); convert_element_type_700 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_192, [1, 0]); wait_tensor_192 = None + mm_148 = torch.ops.aten.mm.default(view_717, permute_232) + view_721 = torch.ops.aten.view.default(mm_148, [2, 8192, 1024]); mm_148 = None + view_724 = torch.ops.aten.view.default(mm_149, [2, 8192, 1024]); mm_149 = None + view_725 = torch.ops.aten.view.default(view_718, [2, 8192, -1, 128]); view_718 = None + view_726 = torch.ops.aten.view.default(view_721, [2, 8192, -1, 128]); view_721 = None + view_727 = torch.ops.aten.view.default(view_724, [2, 8192, -1, 128]); view_724 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_725, torch.float32); view_725 = None + view_728 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 32, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_728); view_728 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_726, torch.float32); view_726 = None + view_729 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 8, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_729); view_729 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_16); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_731 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 32, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_16); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_732 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 8, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_731, torch.bfloat16); view_731 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_732, torch.bfloat16); view_732 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 8, 4, 128]); unsqueeze_42 = None + clone_42 = torch.ops.aten.clone.default(expand_42, memory_format = torch.contiguous_format); expand_42 = None + view_733 = torch.ops.aten.view.default(clone_42, [2, 8192, 32, 128]); clone_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_727, 3); view_727 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 8, 4, 128]); unsqueeze_43 = None + clone_43 = torch.ops.aten.clone.default(expand_43, memory_format = torch.contiguous_format); expand_43 = None + view_734 = torch.ops.aten.view.default(clone_43, [2, 8192, 32, 128]); clone_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_733, [0, 2, 1, 3]); view_733 = None + permute_236 = torch.ops.aten.permute.default(view_734, [0, 2, 1, 3]); view_734 = None + _scaled_dot_product_cudnn_attention_backward_10 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_693, permute_234, permute_235, permute_236, getitem_189, getitem_190, getitem_195, getitem_196, None, None, None, 8192, 8192, 0.0, True); permute_693 = permute_234 = permute_235 = permute_236 = getitem_189 = getitem_190 = getitem_195 = getitem_196 = None + getitem_318 = _scaled_dot_product_cudnn_attention_backward_10[0] + getitem_319 = _scaled_dot_product_cudnn_attention_backward_10[1] + getitem_320 = _scaled_dot_product_cudnn_attention_backward_10[2]; _scaled_dot_product_cudnn_attention_backward_10 = None + permute_694 = torch.ops.aten.permute.default(getitem_320, [0, 2, 1, 3]); getitem_320 = None + permute_695 = torch.ops.aten.permute.default(getitem_319, [0, 2, 1, 3]); getitem_319 = None + permute_696 = torch.ops.aten.permute.default(getitem_318, [0, 2, 1, 3]); getitem_318 = None + view_1344 = torch.ops.aten.view.default(permute_694, [2, 8192, 8, 4, 128]); permute_694 = None + sum_65 = torch.ops.aten.sum.dim_IntList(view_1344, [3], True); view_1344 = None + squeeze_20 = torch.ops.aten.squeeze.dim(sum_65, 3); sum_65 = None + view_1345 = torch.ops.aten.view.default(permute_695, [2, 8192, 8, 4, 128]); permute_695 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_1345, [3], True); view_1345 = None + squeeze_21 = torch.ops.aten.squeeze.dim(sum_66, 3); sum_66 = None + convert_element_type_1643 = torch.ops.prims.convert_element_type.default(squeeze_21, torch.float32); squeeze_21 = None + convert_element_type_1644 = torch.ops.prims.convert_element_type.default(permute_696, torch.float32); permute_696 = None + view_1346 = torch.ops.aten.view.default(convert_element_type_1643, [2, 8192, 8, 64, 2]); convert_element_type_1643 = None + view_as_complex_84 = torch.ops.aten.view_as_complex.default(view_1346); view_1346 = None + mul_476 = torch.ops.aten.mul.Tensor(view_as_complex_84, _conj); view_as_complex_84 = None + view_1347 = torch.ops.aten.view.default(convert_element_type_1644, [2, 8192, 32, 64, 2]); convert_element_type_1644 = None + view_as_complex_85 = torch.ops.aten.view_as_complex.default(view_1347); view_1347 = None + mul_477 = torch.ops.aten.mul.Tensor(view_as_complex_85, _conj); view_as_complex_85 = None + view_as_real_84 = torch.ops.aten.view_as_real.default(mul_476); mul_476 = None + view_1348 = torch.ops.aten.view.default(view_as_real_84, [2, 8192, 8, 128]); view_as_real_84 = None + convert_element_type_1645 = torch.ops.prims.convert_element_type.default(view_1348, torch.bfloat16); view_1348 = None + view_as_real_85 = torch.ops.aten.view_as_real.default(mul_477); mul_477 = None + view_1349 = torch.ops.aten.view.default(view_as_real_85, [2, 8192, 32, 128]); view_as_real_85 = None + convert_element_type_1646 = torch.ops.prims.convert_element_type.default(view_1349, torch.bfloat16); view_1349 = None + view_1350 = torch.ops.aten.view.default(squeeze_20, [2, 8192, 1024]); squeeze_20 = None + view_1351 = torch.ops.aten.view.default(convert_element_type_1645, [2, 8192, 1024]); convert_element_type_1645 = None + view_1352 = torch.ops.aten.view.default(convert_element_type_1646, [2, 8192, 4096]); convert_element_type_1646 = None + view_1353 = torch.ops.aten.view.default(view_1350, [16384, 1024]); view_1350 = None + permute_697 = torch.ops.aten.permute.default(view_1353, [1, 0]) + mm_375 = torch.ops.aten.mm.default(permute_697, view_717); permute_697 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 64, '0'); convert_element_type_703 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + permute_699 = torch.ops.aten.permute.default(permute_233, [1, 0]); permute_233 = None + mm_376 = torch.ops.aten.mm.default(view_1353, permute_699); view_1353 = permute_699 = None + view_1354 = torch.ops.aten.view.default(mm_376, [2, 8192, 4096]); mm_376 = None + convert_element_type_1651 = torch.ops.prims.convert_element_type.default(mm_375, torch.float32); mm_375 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1651, 'avg', 64, '0'); convert_element_type_1651 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + view_1355 = torch.ops.aten.view.default(view_1351, [16384, 1024]); view_1351 = None + permute_701 = torch.ops.aten.permute.default(view_1355, [1, 0]) + mm_377 = torch.ops.aten.mm.default(permute_701, view_717); permute_701 = None + permute_703 = torch.ops.aten.permute.default(permute_232, [1, 0]); permute_232 = None + mm_378 = torch.ops.aten.mm.default(view_1355, permute_703); view_1355 = permute_703 = None + view_1356 = torch.ops.aten.view.default(mm_378, [2, 8192, 4096]); mm_378 = None + add_203 = torch.ops.aten.add.Tensor(view_1354, view_1356); view_1354 = view_1356 = None + convert_element_type_1656 = torch.ops.prims.convert_element_type.default(mm_377, torch.float32); mm_377 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1656, 'avg', 64, '0'); convert_element_type_1656 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + view_1357 = torch.ops.aten.view.default(view_1352, [16384, 4096]); view_1352 = None + permute_705 = torch.ops.aten.permute.default(view_1357, [1, 0]) + mm_379 = torch.ops.aten.mm.default(permute_705, view_717); permute_705 = view_717 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16); primals_194 = None + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 64, '0'); convert_element_type_697 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_191, [1, 0]); wait_tensor_191 = None + permute_707 = torch.ops.aten.permute.default(permute_231, [1, 0]); permute_231 = None + mm_380 = torch.ops.aten.mm.default(view_1357, permute_707); view_1357 = permute_707 = None + view_1358 = torch.ops.aten.view.default(mm_380, [2, 8192, 4096]); mm_380 = None + add_204 = torch.ops.aten.add.Tensor(add_203, view_1358); add_203 = view_1358 = None + convert_element_type_1661 = torch.ops.prims.convert_element_type.default(mm_379, torch.float32); mm_379 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1661, 'avg', 64, '0'); convert_element_type_1661 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + convert_element_type_1662 = torch.ops.prims.convert_element_type.default(add_204, torch.float32); add_204 = None + convert_element_type_1664 = torch.ops.prims.convert_element_type.default(wait_tensor_190, torch.float32); wait_tensor_190 = None + mul_478 = torch.ops.aten.mul.Tensor(convert_element_type_1662, convert_element_type_1664); convert_element_type_1664 = None + mul_480 = torch.ops.aten.mul.Tensor(mul_168, mul_478) + sum_67 = torch.ops.aten.sum.dim_IntList(mul_480, [2], True); mul_480 = None + div_22 = torch.ops.aten.div.Tensor(mul_168, 4096) + mul_481 = torch.ops.aten.mul.Tensor(div_22, sum_67); div_22 = sum_67 = None + sub_33 = torch.ops.aten.sub.Tensor(mul_478, mul_481); mul_478 = mul_481 = None + mul_482 = torch.ops.aten.mul.Tensor(sub_33, rsqrt_42); sub_33 = rsqrt_42 = None + mul_483 = torch.ops.aten.mul.Tensor(convert_element_type_1662, mul_168); convert_element_type_1662 = mul_168 = None + sum_68 = torch.ops.aten.sum.dim_IntList(mul_483, [0, 1]); mul_483 = None + convert_element_type_1665 = torch.ops.prims.convert_element_type.default(mul_482, torch.bfloat16); mul_482 = None + add_205 = torch.ops.aten.add.Tensor(add_202, convert_element_type_1665); add_202 = convert_element_type_1665 = None + convert_element_type_default_43 = torch.ops.prims.convert_element_type.default(sum_68, torch.float32); sum_68 = None + reduce_scatter_tensor_100 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_43, 'avg', 64, '0'); convert_element_type_default_43 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); reduce_scatter_tensor_100 = None + view_1359 = torch.ops.aten.view.default(add_205, [16384, 4096]) + permute_709 = torch.ops.aten.permute.default(view_1359, [1, 0]) + permute_226 = torch.ops.aten.permute.default(getitem_180, [0, 2, 1, 3]) + view_701 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16); primals_188 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 64, '0'); convert_element_type_677 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_185, [1, 0]); wait_tensor_185 = None + view_703 = torch.ops.aten.view.default(view_701, [16384, 4096]); view_701 = None + mm_143 = torch.ops.aten.mm.default(view_703, permute_227) + view_704 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + add_81 = torch.ops.aten.add.Tensor(add_79, view_704); view_704 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16); primals_189 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 64, '0'); convert_element_type_680 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32); add_81 = None + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_186) + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + view_707 = torch.ops.aten.view.default(convert_element_type_682, [16384, 4096]); convert_element_type_682 = None + view_708 = torch.ops.aten.view.default(mm_144, [2, 8192, 14336]); mm_144 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_708, torch.float32); view_708 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 64, '0'); convert_element_type_688 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_145 = torch.ops.aten.mm.default(view_707, permute_229) + view_711 = torch.ops.aten.view.default(mm_145, [2, 8192, 14336]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_711) + view_713 = torch.ops.aten.view.default(mul_167, [16384, 14336]); mul_167 = None + mm_381 = torch.ops.aten.mm.default(permute_709, view_713); permute_709 = view_713 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16); primals_192 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 64, '0'); convert_element_type_691 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + permute_711 = torch.ops.aten.permute.default(permute_230, [1, 0]); permute_230 = None + mm_382 = torch.ops.aten.mm.default(view_1359, permute_711); view_1359 = permute_711 = None + view_1360 = torch.ops.aten.view.default(mm_382, [2, 8192, 14336]); mm_382 = None + convert_element_type_1672 = torch.ops.prims.convert_element_type.default(mm_381, torch.float32); mm_381 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1672, 'avg', 64, '0'); convert_element_type_1672 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + mul_484 = torch.ops.aten.mul.Tensor(view_1360, convert_element_type_687); convert_element_type_687 = None + mul_485 = torch.ops.aten.mul.Tensor(view_1360, view_711); view_1360 = view_711 = None + view_1361 = torch.ops.aten.view.default(mul_484, [16384, 14336]); mul_484 = None + permute_713 = torch.ops.aten.permute.default(view_1361, [1, 0]) + mm_383 = torch.ops.aten.mm.default(permute_713, view_707); permute_713 = None + permute_715 = torch.ops.aten.permute.default(permute_229, [1, 0]); permute_229 = None + mm_384 = torch.ops.aten.mm.default(view_1361, permute_715); view_1361 = permute_715 = None + view_1362 = torch.ops.aten.view.default(mm_384, [2, 8192, 4096]); mm_384 = None + convert_element_type_1677 = torch.ops.prims.convert_element_type.default(mm_383, torch.float32); mm_383 = None + reduce_scatter_tensor_102 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1677, 'avg', 64, '0'); convert_element_type_1677 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); reduce_scatter_tensor_102 = None + convert_element_type_1678 = torch.ops.prims.convert_element_type.default(mul_485, torch.float32); mul_485 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_686) + exp_11 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_206 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_206); add_206 = None + mul_486 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_487 = torch.ops.aten.mul.Tensor(convert_element_type_1678, mul_486); convert_element_type_1678 = None + sub_34 = torch.ops.aten.sub.Tensor(1, mul_486); mul_486 = None + mul_488 = torch.ops.aten.mul.Tensor(convert_element_type_686, sub_34); convert_element_type_686 = sub_34 = None + add_207 = torch.ops.aten.add.Tensor(mul_488, 1); mul_488 = None + mul_489 = torch.ops.aten.mul.Tensor(mul_487, add_207); mul_487 = add_207 = None + convert_element_type_1680 = torch.ops.prims.convert_element_type.default(mul_489, torch.bfloat16); mul_489 = None + view_1363 = torch.ops.aten.view.default(convert_element_type_1680, [16384, 14336]); convert_element_type_1680 = None + permute_717 = torch.ops.aten.permute.default(view_1363, [1, 0]) + mm_385 = torch.ops.aten.mm.default(permute_717, view_707); permute_717 = view_707 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16); primals_190 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 64, '0'); convert_element_type_683 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + permute_719 = torch.ops.aten.permute.default(permute_228, [1, 0]); permute_228 = None + mm_386 = torch.ops.aten.mm.default(view_1363, permute_719); view_1363 = permute_719 = None + view_1364 = torch.ops.aten.view.default(mm_386, [2, 8192, 4096]); mm_386 = None + add_208 = torch.ops.aten.add.Tensor(view_1362, view_1364); view_1362 = view_1364 = None + convert_element_type_1685 = torch.ops.prims.convert_element_type.default(mm_385, torch.float32); mm_385 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1685, 'avg', 64, '0'); convert_element_type_1685 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + convert_element_type_1686 = torch.ops.prims.convert_element_type.default(add_208, torch.float32); add_208 = None + convert_element_type_1688 = torch.ops.prims.convert_element_type.default(wait_tensor_186, torch.float32); wait_tensor_186 = None + mul_490 = torch.ops.aten.mul.Tensor(convert_element_type_1686, convert_element_type_1688); convert_element_type_1688 = None + mul_492 = torch.ops.aten.mul.Tensor(mul_164, mul_490) + sum_69 = torch.ops.aten.sum.dim_IntList(mul_492, [2], True); mul_492 = None + div_23 = torch.ops.aten.div.Tensor(mul_164, 4096) + mul_493 = torch.ops.aten.mul.Tensor(div_23, sum_69); div_23 = sum_69 = None + sub_35 = torch.ops.aten.sub.Tensor(mul_490, mul_493); mul_490 = mul_493 = None + mul_494 = torch.ops.aten.mul.Tensor(sub_35, rsqrt_41); sub_35 = rsqrt_41 = None + mul_495 = torch.ops.aten.mul.Tensor(convert_element_type_1686, mul_164); convert_element_type_1686 = mul_164 = None + sum_70 = torch.ops.aten.sum.dim_IntList(mul_495, [0, 1]); mul_495 = None + convert_element_type_1689 = torch.ops.prims.convert_element_type.default(mul_494, torch.bfloat16); mul_494 = None + add_209 = torch.ops.aten.add.Tensor(add_205, convert_element_type_1689); add_205 = convert_element_type_1689 = None + convert_element_type_default_42 = torch.ops.prims.convert_element_type.default(sum_70, torch.float32); sum_70 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_42, 'avg', 64, '0'); convert_element_type_default_42 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + view_1365 = torch.ops.aten.view.default(add_209, [16384, 4096]) + permute_721 = torch.ops.aten.permute.default(view_1365, [1, 0]) + mm_387 = torch.ops.aten.mm.default(permute_721, view_703); permute_721 = view_703 = None + permute_723 = torch.ops.aten.permute.default(permute_227, [1, 0]); permute_227 = None + mm_388 = torch.ops.aten.mm.default(view_1365, permute_723); view_1365 = permute_723 = None + view_1366 = torch.ops.aten.view.default(mm_388, [2, 8192, 4096]); mm_388 = None + convert_element_type_1696 = torch.ops.prims.convert_element_type.default(mm_387, torch.float32); mm_387 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1696, 'avg', 64, '0'); convert_element_type_1696 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + view_1367 = torch.ops.aten.view.default(view_1366, [2, 8192, 32, 128]); view_1366 = None + permute_725 = torch.ops.aten.permute.default(view_1367, [0, 2, 1, 3]); view_1367 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16); primals_184 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 64, '0'); convert_element_type_661 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32); add_79 = None + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_181) + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + view_683 = torch.ops.aten.view.default(convert_element_type_663, [16384, 4096]); convert_element_type_663 = None + view_684 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]); mm_140 = None + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16); primals_186 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 64, '0'); convert_element_type_667 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + mm_141 = torch.ops.aten.mm.default(view_683, permute_221) + view_687 = torch.ops.aten.view.default(mm_141, [2, 8192, 1024]); mm_141 = None + view_690 = torch.ops.aten.view.default(mm_142, [2, 8192, 1024]); mm_142 = None + view_691 = torch.ops.aten.view.default(view_684, [2, 8192, -1, 128]); view_684 = None + view_692 = torch.ops.aten.view.default(view_687, [2, 8192, -1, 128]); view_687 = None + view_693 = torch.ops.aten.view.default(view_690, [2, 8192, -1, 128]); view_690 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_691, torch.float32); view_691 = None + view_694 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 32, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_694); view_694 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_692, torch.float32); view_692 = None + view_695 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 8, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_695); view_695 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_16); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_697 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 32, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_16); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_698 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 8, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_697, torch.bfloat16); view_697 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_698, torch.bfloat16); view_698 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 8, 4, 128]); unsqueeze_40 = None + clone_40 = torch.ops.aten.clone.default(expand_40, memory_format = torch.contiguous_format); expand_40 = None + view_699 = torch.ops.aten.view.default(clone_40, [2, 8192, 32, 128]); clone_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_693, 3); view_693 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 8, 4, 128]); unsqueeze_41 = None + clone_41 = torch.ops.aten.clone.default(expand_41, memory_format = torch.contiguous_format); expand_41 = None + view_700 = torch.ops.aten.view.default(clone_41, [2, 8192, 32, 128]); clone_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_699, [0, 2, 1, 3]); view_699 = None + permute_225 = torch.ops.aten.permute.default(view_700, [0, 2, 1, 3]); view_700 = None + _scaled_dot_product_cudnn_attention_backward_11 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_725, permute_223, permute_224, permute_225, getitem_180, getitem_181, getitem_186, getitem_187, None, None, None, 8192, 8192, 0.0, True); permute_725 = permute_223 = permute_224 = permute_225 = getitem_180 = getitem_181 = getitem_186 = getitem_187 = None + getitem_321 = _scaled_dot_product_cudnn_attention_backward_11[0] + getitem_322 = _scaled_dot_product_cudnn_attention_backward_11[1] + getitem_323 = _scaled_dot_product_cudnn_attention_backward_11[2]; _scaled_dot_product_cudnn_attention_backward_11 = None + permute_726 = torch.ops.aten.permute.default(getitem_323, [0, 2, 1, 3]); getitem_323 = None + permute_727 = torch.ops.aten.permute.default(getitem_322, [0, 2, 1, 3]); getitem_322 = None + permute_728 = torch.ops.aten.permute.default(getitem_321, [0, 2, 1, 3]); getitem_321 = None + view_1368 = torch.ops.aten.view.default(permute_726, [2, 8192, 8, 4, 128]); permute_726 = None + sum_71 = torch.ops.aten.sum.dim_IntList(view_1368, [3], True); view_1368 = None + squeeze_22 = torch.ops.aten.squeeze.dim(sum_71, 3); sum_71 = None + view_1369 = torch.ops.aten.view.default(permute_727, [2, 8192, 8, 4, 128]); permute_727 = None + sum_72 = torch.ops.aten.sum.dim_IntList(view_1369, [3], True); view_1369 = None + squeeze_23 = torch.ops.aten.squeeze.dim(sum_72, 3); sum_72 = None + convert_element_type_1697 = torch.ops.prims.convert_element_type.default(squeeze_23, torch.float32); squeeze_23 = None + convert_element_type_1698 = torch.ops.prims.convert_element_type.default(permute_728, torch.float32); permute_728 = None + view_1370 = torch.ops.aten.view.default(convert_element_type_1697, [2, 8192, 8, 64, 2]); convert_element_type_1697 = None + view_as_complex_86 = torch.ops.aten.view_as_complex.default(view_1370); view_1370 = None + mul_496 = torch.ops.aten.mul.Tensor(view_as_complex_86, _conj); view_as_complex_86 = None + view_1371 = torch.ops.aten.view.default(convert_element_type_1698, [2, 8192, 32, 64, 2]); convert_element_type_1698 = None + view_as_complex_87 = torch.ops.aten.view_as_complex.default(view_1371); view_1371 = None + mul_497 = torch.ops.aten.mul.Tensor(view_as_complex_87, _conj); view_as_complex_87 = None + view_as_real_86 = torch.ops.aten.view_as_real.default(mul_496); mul_496 = None + view_1372 = torch.ops.aten.view.default(view_as_real_86, [2, 8192, 8, 128]); view_as_real_86 = None + convert_element_type_1699 = torch.ops.prims.convert_element_type.default(view_1372, torch.bfloat16); view_1372 = None + view_as_real_87 = torch.ops.aten.view_as_real.default(mul_497); mul_497 = None + view_1373 = torch.ops.aten.view.default(view_as_real_87, [2, 8192, 32, 128]); view_as_real_87 = None + convert_element_type_1700 = torch.ops.prims.convert_element_type.default(view_1373, torch.bfloat16); view_1373 = None + view_1374 = torch.ops.aten.view.default(squeeze_22, [2, 8192, 1024]); squeeze_22 = None + view_1375 = torch.ops.aten.view.default(convert_element_type_1699, [2, 8192, 1024]); convert_element_type_1699 = None + view_1376 = torch.ops.aten.view.default(convert_element_type_1700, [2, 8192, 4096]); convert_element_type_1700 = None + view_1377 = torch.ops.aten.view.default(view_1374, [16384, 1024]); view_1374 = None + permute_729 = torch.ops.aten.permute.default(view_1377, [1, 0]) + mm_389 = torch.ops.aten.mm.default(permute_729, view_683); permute_729 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16); primals_187 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 64, '0'); convert_element_type_670 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + permute_731 = torch.ops.aten.permute.default(permute_222, [1, 0]); permute_222 = None + mm_390 = torch.ops.aten.mm.default(view_1377, permute_731); view_1377 = permute_731 = None + view_1378 = torch.ops.aten.view.default(mm_390, [2, 8192, 4096]); mm_390 = None + convert_element_type_1705 = torch.ops.prims.convert_element_type.default(mm_389, torch.float32); mm_389 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1705, 'avg', 64, '0'); convert_element_type_1705 = None + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + view_1379 = torch.ops.aten.view.default(view_1375, [16384, 1024]); view_1375 = None + permute_733 = torch.ops.aten.permute.default(view_1379, [1, 0]) + mm_391 = torch.ops.aten.mm.default(permute_733, view_683); permute_733 = None + permute_735 = torch.ops.aten.permute.default(permute_221, [1, 0]); permute_221 = None + mm_392 = torch.ops.aten.mm.default(view_1379, permute_735); view_1379 = permute_735 = None + view_1380 = torch.ops.aten.view.default(mm_392, [2, 8192, 4096]); mm_392 = None + add_210 = torch.ops.aten.add.Tensor(view_1378, view_1380); view_1378 = view_1380 = None + convert_element_type_1710 = torch.ops.prims.convert_element_type.default(mm_391, torch.float32); mm_391 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1710, 'avg', 64, '0'); convert_element_type_1710 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_1381 = torch.ops.aten.view.default(view_1376, [16384, 4096]); view_1376 = None + permute_737 = torch.ops.aten.permute.default(view_1381, [1, 0]) + mm_393 = torch.ops.aten.mm.default(permute_737, view_683); permute_737 = view_683 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16); primals_185 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 64, '0'); convert_element_type_664 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + permute_739 = torch.ops.aten.permute.default(permute_220, [1, 0]); permute_220 = None + mm_394 = torch.ops.aten.mm.default(view_1381, permute_739); view_1381 = permute_739 = None + view_1382 = torch.ops.aten.view.default(mm_394, [2, 8192, 4096]); mm_394 = None + add_211 = torch.ops.aten.add.Tensor(add_210, view_1382); add_210 = view_1382 = None + convert_element_type_1715 = torch.ops.prims.convert_element_type.default(mm_393, torch.float32); mm_393 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1715, 'avg', 64, '0'); convert_element_type_1715 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + convert_element_type_1716 = torch.ops.prims.convert_element_type.default(add_211, torch.float32); add_211 = None + convert_element_type_1718 = torch.ops.prims.convert_element_type.default(wait_tensor_181, torch.float32); wait_tensor_181 = None + mul_498 = torch.ops.aten.mul.Tensor(convert_element_type_1716, convert_element_type_1718); convert_element_type_1718 = None + mul_500 = torch.ops.aten.mul.Tensor(mul_160, mul_498) + sum_73 = torch.ops.aten.sum.dim_IntList(mul_500, [2], True); mul_500 = None + div_24 = torch.ops.aten.div.Tensor(mul_160, 4096) + mul_501 = torch.ops.aten.mul.Tensor(div_24, sum_73); div_24 = sum_73 = None + sub_36 = torch.ops.aten.sub.Tensor(mul_498, mul_501); mul_498 = mul_501 = None + mul_502 = torch.ops.aten.mul.Tensor(sub_36, rsqrt_40); sub_36 = rsqrt_40 = None + mul_503 = torch.ops.aten.mul.Tensor(convert_element_type_1716, mul_160); convert_element_type_1716 = mul_160 = None + sum_74 = torch.ops.aten.sum.dim_IntList(mul_503, [0, 1]); mul_503 = None + convert_element_type_1719 = torch.ops.prims.convert_element_type.default(mul_502, torch.bfloat16); mul_502 = None + add_212 = torch.ops.aten.add.Tensor(add_209, convert_element_type_1719); add_209 = convert_element_type_1719 = None + convert_element_type_default_41 = torch.ops.prims.convert_element_type.default(sum_74, torch.float32); sum_74 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_41, 'avg', 64, '0'); convert_element_type_default_41 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 = None + view_1383 = torch.ops.aten.view.default(add_212, [16384, 4096]) + permute_741 = torch.ops.aten.permute.default(view_1383, [1, 0]) + permute_215 = torch.ops.aten.permute.default(getitem_171, [0, 2, 1, 3]) + view_667 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 64, '0'); convert_element_type_644 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_669 = torch.ops.aten.view.default(view_667, [16384, 4096]); view_667 = None + mm_136 = torch.ops.aten.mm.default(view_669, permute_216) + view_670 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + add_77 = torch.ops.aten.add.Tensor(add_75, view_670); view_670 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 64, '0'); convert_element_type_647 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32); add_77 = None + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_177) + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + view_673 = torch.ops.aten.view.default(convert_element_type_649, [16384, 4096]); convert_element_type_649 = None + view_674 = torch.ops.aten.view.default(mm_137, [2, 8192, 14336]); mm_137 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_674, torch.float32); view_674 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16); primals_182 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 64, '0'); convert_element_type_655 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_179, [1, 0]); wait_tensor_179 = None + mm_138 = torch.ops.aten.mm.default(view_673, permute_218) + view_677 = torch.ops.aten.view.default(mm_138, [2, 8192, 14336]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_677) + view_679 = torch.ops.aten.view.default(mul_159, [16384, 14336]); mul_159 = None + mm_395 = torch.ops.aten.mm.default(permute_741, view_679); permute_741 = view_679 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16); primals_183 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 64, '0'); convert_element_type_658 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_743 = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None + mm_396 = torch.ops.aten.mm.default(view_1383, permute_743); view_1383 = permute_743 = None + view_1384 = torch.ops.aten.view.default(mm_396, [2, 8192, 14336]); mm_396 = None + convert_element_type_1726 = torch.ops.prims.convert_element_type.default(mm_395, torch.float32); mm_395 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1726, 'avg', 64, '0'); convert_element_type_1726 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + mul_504 = torch.ops.aten.mul.Tensor(view_1384, convert_element_type_654); convert_element_type_654 = None + mul_505 = torch.ops.aten.mul.Tensor(view_1384, view_677); view_1384 = view_677 = None + view_1385 = torch.ops.aten.view.default(mul_504, [16384, 14336]); mul_504 = None + permute_745 = torch.ops.aten.permute.default(view_1385, [1, 0]) + mm_397 = torch.ops.aten.mm.default(permute_745, view_673); permute_745 = None + permute_747 = torch.ops.aten.permute.default(permute_218, [1, 0]); permute_218 = None + mm_398 = torch.ops.aten.mm.default(view_1385, permute_747); view_1385 = permute_747 = None + view_1386 = torch.ops.aten.view.default(mm_398, [2, 8192, 4096]); mm_398 = None + convert_element_type_1731 = torch.ops.prims.convert_element_type.default(mm_397, torch.float32); mm_397 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1731, 'avg', 64, '0'); convert_element_type_1731 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + convert_element_type_1732 = torch.ops.prims.convert_element_type.default(mul_505, torch.float32); mul_505 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_653) + exp_12 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_213 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_213); add_213 = None + mul_506 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_507 = torch.ops.aten.mul.Tensor(convert_element_type_1732, mul_506); convert_element_type_1732 = None + sub_37 = torch.ops.aten.sub.Tensor(1, mul_506); mul_506 = None + mul_508 = torch.ops.aten.mul.Tensor(convert_element_type_653, sub_37); convert_element_type_653 = sub_37 = None + add_214 = torch.ops.aten.add.Tensor(mul_508, 1); mul_508 = None + mul_509 = torch.ops.aten.mul.Tensor(mul_507, add_214); mul_507 = add_214 = None + convert_element_type_1734 = torch.ops.prims.convert_element_type.default(mul_509, torch.bfloat16); mul_509 = None + view_1387 = torch.ops.aten.view.default(convert_element_type_1734, [16384, 14336]); convert_element_type_1734 = None + permute_749 = torch.ops.aten.permute.default(view_1387, [1, 0]) + mm_399 = torch.ops.aten.mm.default(permute_749, view_673); permute_749 = view_673 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16); primals_181 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 64, '0'); convert_element_type_650 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + permute_751 = torch.ops.aten.permute.default(permute_217, [1, 0]); permute_217 = None + mm_400 = torch.ops.aten.mm.default(view_1387, permute_751); view_1387 = permute_751 = None + view_1388 = torch.ops.aten.view.default(mm_400, [2, 8192, 4096]); mm_400 = None + add_215 = torch.ops.aten.add.Tensor(view_1386, view_1388); view_1386 = view_1388 = None + convert_element_type_1739 = torch.ops.prims.convert_element_type.default(mm_399, torch.float32); mm_399 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1739, 'avg', 64, '0'); convert_element_type_1739 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + convert_element_type_1740 = torch.ops.prims.convert_element_type.default(add_215, torch.float32); add_215 = None + convert_element_type_1742 = torch.ops.prims.convert_element_type.default(wait_tensor_177, torch.float32); wait_tensor_177 = None + mul_510 = torch.ops.aten.mul.Tensor(convert_element_type_1740, convert_element_type_1742); convert_element_type_1742 = None + mul_512 = torch.ops.aten.mul.Tensor(mul_156, mul_510) + sum_75 = torch.ops.aten.sum.dim_IntList(mul_512, [2], True); mul_512 = None + div_25 = torch.ops.aten.div.Tensor(mul_156, 4096) + mul_513 = torch.ops.aten.mul.Tensor(div_25, sum_75); div_25 = sum_75 = None + sub_38 = torch.ops.aten.sub.Tensor(mul_510, mul_513); mul_510 = mul_513 = None + mul_514 = torch.ops.aten.mul.Tensor(sub_38, rsqrt_39); sub_38 = rsqrt_39 = None + mul_515 = torch.ops.aten.mul.Tensor(convert_element_type_1740, mul_156); convert_element_type_1740 = mul_156 = None + sum_76 = torch.ops.aten.sum.dim_IntList(mul_515, [0, 1]); mul_515 = None + convert_element_type_1743 = torch.ops.prims.convert_element_type.default(mul_514, torch.bfloat16); mul_514 = None + add_216 = torch.ops.aten.add.Tensor(add_212, convert_element_type_1743); add_212 = convert_element_type_1743 = None + convert_element_type_default_40 = torch.ops.prims.convert_element_type.default(sum_76, torch.float32); sum_76 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_40, 'avg', 64, '0'); convert_element_type_default_40 = None + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + view_1389 = torch.ops.aten.view.default(add_216, [16384, 4096]) + permute_753 = torch.ops.aten.permute.default(view_1389, [1, 0]) + mm_401 = torch.ops.aten.mm.default(permute_753, view_669); permute_753 = view_669 = None + permute_755 = torch.ops.aten.permute.default(permute_216, [1, 0]); permute_216 = None + mm_402 = torch.ops.aten.mm.default(view_1389, permute_755); view_1389 = permute_755 = None + view_1390 = torch.ops.aten.view.default(mm_402, [2, 8192, 4096]); mm_402 = None + convert_element_type_1750 = torch.ops.prims.convert_element_type.default(mm_401, torch.float32); mm_401 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1750, 'avg', 64, '0'); convert_element_type_1750 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + view_1391 = torch.ops.aten.view.default(view_1390, [2, 8192, 32, 128]); view_1390 = None + permute_757 = torch.ops.aten.permute.default(view_1391, [0, 2, 1, 3]); view_1391 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 64, '0'); convert_element_type_628 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32); add_75 = None + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_172) + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + view_649 = torch.ops.aten.view.default(convert_element_type_630, [16384, 4096]); convert_element_type_630 = None + view_650 = torch.ops.aten.view.default(mm_133, [2, 8192, 4096]); mm_133 = None + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16); primals_177 = None + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 64, '0'); convert_element_type_634 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_134 = torch.ops.aten.mm.default(view_649, permute_210) + view_653 = torch.ops.aten.view.default(mm_134, [2, 8192, 1024]); mm_134 = None + view_656 = torch.ops.aten.view.default(mm_135, [2, 8192, 1024]); mm_135 = None + view_657 = torch.ops.aten.view.default(view_650, [2, 8192, -1, 128]); view_650 = None + view_658 = torch.ops.aten.view.default(view_653, [2, 8192, -1, 128]); view_653 = None + view_659 = torch.ops.aten.view.default(view_656, [2, 8192, -1, 128]); view_656 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_657, torch.float32); view_657 = None + view_660 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 32, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_660); view_660 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_658, torch.float32); view_658 = None + view_661 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 8, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_661); view_661 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_16); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_663 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 32, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_16); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_664 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 8, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_663, torch.bfloat16); view_663 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_664, torch.bfloat16); view_664 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 8, 4, 128]); unsqueeze_38 = None + clone_38 = torch.ops.aten.clone.default(expand_38, memory_format = torch.contiguous_format); expand_38 = None + view_665 = torch.ops.aten.view.default(clone_38, [2, 8192, 32, 128]); clone_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_659, 3); view_659 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 8, 4, 128]); unsqueeze_39 = None + clone_39 = torch.ops.aten.clone.default(expand_39, memory_format = torch.contiguous_format); expand_39 = None + view_666 = torch.ops.aten.view.default(clone_39, [2, 8192, 32, 128]); clone_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_665, [0, 2, 1, 3]); view_665 = None + permute_214 = torch.ops.aten.permute.default(view_666, [0, 2, 1, 3]); view_666 = None + _scaled_dot_product_cudnn_attention_backward_12 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_757, permute_212, permute_213, permute_214, getitem_171, getitem_172, getitem_177, getitem_178, None, None, None, 8192, 8192, 0.0, True); permute_757 = permute_212 = permute_213 = permute_214 = getitem_171 = getitem_172 = getitem_177 = getitem_178 = None + getitem_324 = _scaled_dot_product_cudnn_attention_backward_12[0] + getitem_325 = _scaled_dot_product_cudnn_attention_backward_12[1] + getitem_326 = _scaled_dot_product_cudnn_attention_backward_12[2]; _scaled_dot_product_cudnn_attention_backward_12 = None + permute_758 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]); getitem_326 = None + permute_759 = torch.ops.aten.permute.default(getitem_325, [0, 2, 1, 3]); getitem_325 = None + permute_760 = torch.ops.aten.permute.default(getitem_324, [0, 2, 1, 3]); getitem_324 = None + view_1392 = torch.ops.aten.view.default(permute_758, [2, 8192, 8, 4, 128]); permute_758 = None + sum_77 = torch.ops.aten.sum.dim_IntList(view_1392, [3], True); view_1392 = None + squeeze_24 = torch.ops.aten.squeeze.dim(sum_77, 3); sum_77 = None + view_1393 = torch.ops.aten.view.default(permute_759, [2, 8192, 8, 4, 128]); permute_759 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_1393, [3], True); view_1393 = None + squeeze_25 = torch.ops.aten.squeeze.dim(sum_78, 3); sum_78 = None + convert_element_type_1751 = torch.ops.prims.convert_element_type.default(squeeze_25, torch.float32); squeeze_25 = None + convert_element_type_1752 = torch.ops.prims.convert_element_type.default(permute_760, torch.float32); permute_760 = None + view_1394 = torch.ops.aten.view.default(convert_element_type_1751, [2, 8192, 8, 64, 2]); convert_element_type_1751 = None + view_as_complex_88 = torch.ops.aten.view_as_complex.default(view_1394); view_1394 = None + mul_516 = torch.ops.aten.mul.Tensor(view_as_complex_88, _conj); view_as_complex_88 = None + view_1395 = torch.ops.aten.view.default(convert_element_type_1752, [2, 8192, 32, 64, 2]); convert_element_type_1752 = None + view_as_complex_89 = torch.ops.aten.view_as_complex.default(view_1395); view_1395 = None + mul_517 = torch.ops.aten.mul.Tensor(view_as_complex_89, _conj); view_as_complex_89 = None + view_as_real_88 = torch.ops.aten.view_as_real.default(mul_516); mul_516 = None + view_1396 = torch.ops.aten.view.default(view_as_real_88, [2, 8192, 8, 128]); view_as_real_88 = None + convert_element_type_1753 = torch.ops.prims.convert_element_type.default(view_1396, torch.bfloat16); view_1396 = None + view_as_real_89 = torch.ops.aten.view_as_real.default(mul_517); mul_517 = None + view_1397 = torch.ops.aten.view.default(view_as_real_89, [2, 8192, 32, 128]); view_as_real_89 = None + convert_element_type_1754 = torch.ops.prims.convert_element_type.default(view_1397, torch.bfloat16); view_1397 = None + view_1398 = torch.ops.aten.view.default(squeeze_24, [2, 8192, 1024]); squeeze_24 = None + view_1399 = torch.ops.aten.view.default(convert_element_type_1753, [2, 8192, 1024]); convert_element_type_1753 = None + view_1400 = torch.ops.aten.view.default(convert_element_type_1754, [2, 8192, 4096]); convert_element_type_1754 = None + view_1401 = torch.ops.aten.view.default(view_1398, [16384, 1024]); view_1398 = None + permute_761 = torch.ops.aten.permute.default(view_1401, [1, 0]) + mm_403 = torch.ops.aten.mm.default(permute_761, view_649); permute_761 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16); primals_178 = None + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 64, '0'); convert_element_type_637 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + permute_763 = torch.ops.aten.permute.default(permute_211, [1, 0]); permute_211 = None + mm_404 = torch.ops.aten.mm.default(view_1401, permute_763); view_1401 = permute_763 = None + view_1402 = torch.ops.aten.view.default(mm_404, [2, 8192, 4096]); mm_404 = None + convert_element_type_1759 = torch.ops.prims.convert_element_type.default(mm_403, torch.float32); mm_403 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1759, 'avg', 64, '0'); convert_element_type_1759 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + view_1403 = torch.ops.aten.view.default(view_1399, [16384, 1024]); view_1399 = None + permute_765 = torch.ops.aten.permute.default(view_1403, [1, 0]) + mm_405 = torch.ops.aten.mm.default(permute_765, view_649); permute_765 = None + permute_767 = torch.ops.aten.permute.default(permute_210, [1, 0]); permute_210 = None + mm_406 = torch.ops.aten.mm.default(view_1403, permute_767); view_1403 = permute_767 = None + view_1404 = torch.ops.aten.view.default(mm_406, [2, 8192, 4096]); mm_406 = None + add_217 = torch.ops.aten.add.Tensor(view_1402, view_1404); view_1402 = view_1404 = None + convert_element_type_1764 = torch.ops.prims.convert_element_type.default(mm_405, torch.float32); mm_405 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1764, 'avg', 64, '0'); convert_element_type_1764 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + view_1405 = torch.ops.aten.view.default(view_1400, [16384, 4096]); view_1400 = None + permute_769 = torch.ops.aten.permute.default(view_1405, [1, 0]) + mm_407 = torch.ops.aten.mm.default(permute_769, view_649); permute_769 = view_649 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16); primals_176 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 64, '0'); convert_element_type_631 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + permute_771 = torch.ops.aten.permute.default(permute_209, [1, 0]); permute_209 = None + mm_408 = torch.ops.aten.mm.default(view_1405, permute_771); view_1405 = permute_771 = None + view_1406 = torch.ops.aten.view.default(mm_408, [2, 8192, 4096]); mm_408 = None + add_218 = torch.ops.aten.add.Tensor(add_217, view_1406); add_217 = view_1406 = None + convert_element_type_1769 = torch.ops.prims.convert_element_type.default(mm_407, torch.float32); mm_407 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1769, 'avg', 64, '0'); convert_element_type_1769 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + convert_element_type_1770 = torch.ops.prims.convert_element_type.default(add_218, torch.float32); add_218 = None + convert_element_type_1772 = torch.ops.prims.convert_element_type.default(wait_tensor_172, torch.float32); wait_tensor_172 = None + mul_518 = torch.ops.aten.mul.Tensor(convert_element_type_1770, convert_element_type_1772); convert_element_type_1772 = None + mul_520 = torch.ops.aten.mul.Tensor(mul_152, mul_518) + sum_79 = torch.ops.aten.sum.dim_IntList(mul_520, [2], True); mul_520 = None + div_26 = torch.ops.aten.div.Tensor(mul_152, 4096) + mul_521 = torch.ops.aten.mul.Tensor(div_26, sum_79); div_26 = sum_79 = None + sub_39 = torch.ops.aten.sub.Tensor(mul_518, mul_521); mul_518 = mul_521 = None + mul_522 = torch.ops.aten.mul.Tensor(sub_39, rsqrt_38); sub_39 = rsqrt_38 = None + mul_523 = torch.ops.aten.mul.Tensor(convert_element_type_1770, mul_152); convert_element_type_1770 = mul_152 = None + sum_80 = torch.ops.aten.sum.dim_IntList(mul_523, [0, 1]); mul_523 = None + convert_element_type_1773 = torch.ops.prims.convert_element_type.default(mul_522, torch.bfloat16); mul_522 = None + add_219 = torch.ops.aten.add.Tensor(add_216, convert_element_type_1773); add_216 = convert_element_type_1773 = None + convert_element_type_default_39 = torch.ops.prims.convert_element_type.default(sum_80, torch.float32); sum_80 = None + reduce_scatter_tensor_118 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_39, 'avg', 64, '0'); convert_element_type_default_39 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + view_1407 = torch.ops.aten.view.default(add_219, [16384, 4096]) + permute_773 = torch.ops.aten.permute.default(view_1407, [1, 0]) + permute_204 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_633 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16); primals_170 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 64, '0'); convert_element_type_611 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_635 = torch.ops.aten.view.default(view_633, [16384, 4096]); view_633 = None + mm_129 = torch.ops.aten.mm.default(view_635, permute_205) + view_636 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + add_73 = torch.ops.aten.add.Tensor(add_71, view_636); view_636 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16); primals_171 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 64, '0'); convert_element_type_614 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32); add_73 = None + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_168) + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + view_639 = torch.ops.aten.view.default(convert_element_type_616, [16384, 4096]); convert_element_type_616 = None + view_640 = torch.ops.aten.view.default(mm_130, [2, 8192, 14336]); mm_130 = None + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_640, torch.float32); view_640 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 64, '0'); convert_element_type_622 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_170, [1, 0]); wait_tensor_170 = None + mm_131 = torch.ops.aten.mm.default(view_639, permute_207) + view_643 = torch.ops.aten.view.default(mm_131, [2, 8192, 14336]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_643) + view_645 = torch.ops.aten.view.default(mul_151, [16384, 14336]); mul_151 = None + mm_409 = torch.ops.aten.mm.default(permute_773, view_645); permute_773 = view_645 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16); primals_174 = None + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 64, '0'); convert_element_type_625 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_171, [1, 0]); wait_tensor_171 = None + permute_775 = torch.ops.aten.permute.default(permute_208, [1, 0]); permute_208 = None + mm_410 = torch.ops.aten.mm.default(view_1407, permute_775); view_1407 = permute_775 = None + view_1408 = torch.ops.aten.view.default(mm_410, [2, 8192, 14336]); mm_410 = None + convert_element_type_1780 = torch.ops.prims.convert_element_type.default(mm_409, torch.float32); mm_409 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1780, 'avg', 64, '0'); convert_element_type_1780 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + mul_524 = torch.ops.aten.mul.Tensor(view_1408, convert_element_type_621); convert_element_type_621 = None + mul_525 = torch.ops.aten.mul.Tensor(view_1408, view_643); view_1408 = view_643 = None + view_1409 = torch.ops.aten.view.default(mul_524, [16384, 14336]); mul_524 = None + permute_777 = torch.ops.aten.permute.default(view_1409, [1, 0]) + mm_411 = torch.ops.aten.mm.default(permute_777, view_639); permute_777 = None + permute_779 = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None + mm_412 = torch.ops.aten.mm.default(view_1409, permute_779); view_1409 = permute_779 = None + view_1410 = torch.ops.aten.view.default(mm_412, [2, 8192, 4096]); mm_412 = None + convert_element_type_1785 = torch.ops.prims.convert_element_type.default(mm_411, torch.float32); mm_411 = None + reduce_scatter_tensor_120 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1785, 'avg', 64, '0'); convert_element_type_1785 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + convert_element_type_1786 = torch.ops.prims.convert_element_type.default(mul_525, torch.float32); mul_525 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_620) + exp_13 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_220 = torch.ops.aten.add.Tensor(exp_13, 1); exp_13 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_220); add_220 = None + mul_526 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_527 = torch.ops.aten.mul.Tensor(convert_element_type_1786, mul_526); convert_element_type_1786 = None + sub_40 = torch.ops.aten.sub.Tensor(1, mul_526); mul_526 = None + mul_528 = torch.ops.aten.mul.Tensor(convert_element_type_620, sub_40); convert_element_type_620 = sub_40 = None + add_221 = torch.ops.aten.add.Tensor(mul_528, 1); mul_528 = None + mul_529 = torch.ops.aten.mul.Tensor(mul_527, add_221); mul_527 = add_221 = None + convert_element_type_1788 = torch.ops.prims.convert_element_type.default(mul_529, torch.bfloat16); mul_529 = None + view_1411 = torch.ops.aten.view.default(convert_element_type_1788, [16384, 14336]); convert_element_type_1788 = None + permute_781 = torch.ops.aten.permute.default(view_1411, [1, 0]) + mm_413 = torch.ops.aten.mm.default(permute_781, view_639); permute_781 = view_639 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 64, '0'); convert_element_type_617 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + permute_783 = torch.ops.aten.permute.default(permute_206, [1, 0]); permute_206 = None + mm_414 = torch.ops.aten.mm.default(view_1411, permute_783); view_1411 = permute_783 = None + view_1412 = torch.ops.aten.view.default(mm_414, [2, 8192, 4096]); mm_414 = None + add_222 = torch.ops.aten.add.Tensor(view_1410, view_1412); view_1410 = view_1412 = None + convert_element_type_1793 = torch.ops.prims.convert_element_type.default(mm_413, torch.float32); mm_413 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1793, 'avg', 64, '0'); convert_element_type_1793 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + convert_element_type_1794 = torch.ops.prims.convert_element_type.default(add_222, torch.float32); add_222 = None + convert_element_type_1796 = torch.ops.prims.convert_element_type.default(wait_tensor_168, torch.float32); wait_tensor_168 = None + mul_530 = torch.ops.aten.mul.Tensor(convert_element_type_1794, convert_element_type_1796); convert_element_type_1796 = None + mul_532 = torch.ops.aten.mul.Tensor(mul_148, mul_530) + sum_81 = torch.ops.aten.sum.dim_IntList(mul_532, [2], True); mul_532 = None + div_27 = torch.ops.aten.div.Tensor(mul_148, 4096) + mul_533 = torch.ops.aten.mul.Tensor(div_27, sum_81); div_27 = sum_81 = None + sub_41 = torch.ops.aten.sub.Tensor(mul_530, mul_533); mul_530 = mul_533 = None + mul_534 = torch.ops.aten.mul.Tensor(sub_41, rsqrt_37); sub_41 = rsqrt_37 = None + mul_535 = torch.ops.aten.mul.Tensor(convert_element_type_1794, mul_148); convert_element_type_1794 = mul_148 = None + sum_82 = torch.ops.aten.sum.dim_IntList(mul_535, [0, 1]); mul_535 = None + convert_element_type_1797 = torch.ops.prims.convert_element_type.default(mul_534, torch.bfloat16); mul_534 = None + add_223 = torch.ops.aten.add.Tensor(add_219, convert_element_type_1797); add_219 = convert_element_type_1797 = None + convert_element_type_default_38 = torch.ops.prims.convert_element_type.default(sum_82, torch.float32); sum_82 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_38, 'avg', 64, '0'); convert_element_type_default_38 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + view_1413 = torch.ops.aten.view.default(add_223, [16384, 4096]) + permute_785 = torch.ops.aten.permute.default(view_1413, [1, 0]) + mm_415 = torch.ops.aten.mm.default(permute_785, view_635); permute_785 = view_635 = None + permute_787 = torch.ops.aten.permute.default(permute_205, [1, 0]); permute_205 = None + mm_416 = torch.ops.aten.mm.default(view_1413, permute_787); view_1413 = permute_787 = None + view_1414 = torch.ops.aten.view.default(mm_416, [2, 8192, 4096]); mm_416 = None + convert_element_type_1804 = torch.ops.prims.convert_element_type.default(mm_415, torch.float32); mm_415 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1804, 'avg', 64, '0'); convert_element_type_1804 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + view_1415 = torch.ops.aten.view.default(view_1414, [2, 8192, 32, 128]); view_1414 = None + permute_789 = torch.ops.aten.permute.default(view_1415, [0, 2, 1, 3]); view_1415 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16); primals_166 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 64, '0'); convert_element_type_595 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32); add_71 = None + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_163) + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + view_615 = torch.ops.aten.view.default(convert_element_type_597, [16384, 4096]); convert_element_type_597 = None + view_616 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]); mm_126 = None + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16); primals_168 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 64, '0'); convert_element_type_601 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + mm_127 = torch.ops.aten.mm.default(view_615, permute_199) + view_619 = torch.ops.aten.view.default(mm_127, [2, 8192, 1024]); mm_127 = None + view_622 = torch.ops.aten.view.default(mm_128, [2, 8192, 1024]); mm_128 = None + view_623 = torch.ops.aten.view.default(view_616, [2, 8192, -1, 128]); view_616 = None + view_624 = torch.ops.aten.view.default(view_619, [2, 8192, -1, 128]); view_619 = None + view_625 = torch.ops.aten.view.default(view_622, [2, 8192, -1, 128]); view_622 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_623, torch.float32); view_623 = None + view_626 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 32, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_626); view_626 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_624, torch.float32); view_624 = None + view_627 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 8, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_627); view_627 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_16); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_629 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 32, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_16); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_630 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 8, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_629, torch.bfloat16); view_629 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_630, torch.bfloat16); view_630 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 8, 4, 128]); unsqueeze_36 = None + clone_36 = torch.ops.aten.clone.default(expand_36, memory_format = torch.contiguous_format); expand_36 = None + view_631 = torch.ops.aten.view.default(clone_36, [2, 8192, 32, 128]); clone_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_625, 3); view_625 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 8, 4, 128]); unsqueeze_37 = None + clone_37 = torch.ops.aten.clone.default(expand_37, memory_format = torch.contiguous_format); expand_37 = None + view_632 = torch.ops.aten.view.default(clone_37, [2, 8192, 32, 128]); clone_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_631, [0, 2, 1, 3]); view_631 = None + permute_203 = torch.ops.aten.permute.default(view_632, [0, 2, 1, 3]); view_632 = None + _scaled_dot_product_cudnn_attention_backward_13 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_789, permute_201, permute_202, permute_203, getitem_162, getitem_163, getitem_168, getitem_169, None, None, None, 8192, 8192, 0.0, True); permute_789 = permute_201 = permute_202 = permute_203 = getitem_162 = getitem_163 = getitem_168 = getitem_169 = None + getitem_327 = _scaled_dot_product_cudnn_attention_backward_13[0] + getitem_328 = _scaled_dot_product_cudnn_attention_backward_13[1] + getitem_329 = _scaled_dot_product_cudnn_attention_backward_13[2]; _scaled_dot_product_cudnn_attention_backward_13 = None + permute_790 = torch.ops.aten.permute.default(getitem_329, [0, 2, 1, 3]); getitem_329 = None + permute_791 = torch.ops.aten.permute.default(getitem_328, [0, 2, 1, 3]); getitem_328 = None + permute_792 = torch.ops.aten.permute.default(getitem_327, [0, 2, 1, 3]); getitem_327 = None + view_1416 = torch.ops.aten.view.default(permute_790, [2, 8192, 8, 4, 128]); permute_790 = None + sum_83 = torch.ops.aten.sum.dim_IntList(view_1416, [3], True); view_1416 = None + squeeze_26 = torch.ops.aten.squeeze.dim(sum_83, 3); sum_83 = None + view_1417 = torch.ops.aten.view.default(permute_791, [2, 8192, 8, 4, 128]); permute_791 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_1417, [3], True); view_1417 = None + squeeze_27 = torch.ops.aten.squeeze.dim(sum_84, 3); sum_84 = None + convert_element_type_1805 = torch.ops.prims.convert_element_type.default(squeeze_27, torch.float32); squeeze_27 = None + convert_element_type_1806 = torch.ops.prims.convert_element_type.default(permute_792, torch.float32); permute_792 = None + view_1418 = torch.ops.aten.view.default(convert_element_type_1805, [2, 8192, 8, 64, 2]); convert_element_type_1805 = None + view_as_complex_90 = torch.ops.aten.view_as_complex.default(view_1418); view_1418 = None + mul_536 = torch.ops.aten.mul.Tensor(view_as_complex_90, _conj); view_as_complex_90 = None + view_1419 = torch.ops.aten.view.default(convert_element_type_1806, [2, 8192, 32, 64, 2]); convert_element_type_1806 = None + view_as_complex_91 = torch.ops.aten.view_as_complex.default(view_1419); view_1419 = None + mul_537 = torch.ops.aten.mul.Tensor(view_as_complex_91, _conj); view_as_complex_91 = None + view_as_real_90 = torch.ops.aten.view_as_real.default(mul_536); mul_536 = None + view_1420 = torch.ops.aten.view.default(view_as_real_90, [2, 8192, 8, 128]); view_as_real_90 = None + convert_element_type_1807 = torch.ops.prims.convert_element_type.default(view_1420, torch.bfloat16); view_1420 = None + view_as_real_91 = torch.ops.aten.view_as_real.default(mul_537); mul_537 = None + view_1421 = torch.ops.aten.view.default(view_as_real_91, [2, 8192, 32, 128]); view_as_real_91 = None + convert_element_type_1808 = torch.ops.prims.convert_element_type.default(view_1421, torch.bfloat16); view_1421 = None + view_1422 = torch.ops.aten.view.default(squeeze_26, [2, 8192, 1024]); squeeze_26 = None + view_1423 = torch.ops.aten.view.default(convert_element_type_1807, [2, 8192, 1024]); convert_element_type_1807 = None + view_1424 = torch.ops.aten.view.default(convert_element_type_1808, [2, 8192, 4096]); convert_element_type_1808 = None + view_1425 = torch.ops.aten.view.default(view_1422, [16384, 1024]); view_1422 = None + permute_793 = torch.ops.aten.permute.default(view_1425, [1, 0]) + mm_417 = torch.ops.aten.mm.default(permute_793, view_615); permute_793 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16); primals_169 = None + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 64, '0'); convert_element_type_604 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_166, [1, 0]); wait_tensor_166 = None + permute_795 = torch.ops.aten.permute.default(permute_200, [1, 0]); permute_200 = None + mm_418 = torch.ops.aten.mm.default(view_1425, permute_795); view_1425 = permute_795 = None + view_1426 = torch.ops.aten.view.default(mm_418, [2, 8192, 4096]); mm_418 = None + convert_element_type_1813 = torch.ops.prims.convert_element_type.default(mm_417, torch.float32); mm_417 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1813, 'avg', 64, '0'); convert_element_type_1813 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + view_1427 = torch.ops.aten.view.default(view_1423, [16384, 1024]); view_1423 = None + permute_797 = torch.ops.aten.permute.default(view_1427, [1, 0]) + mm_419 = torch.ops.aten.mm.default(permute_797, view_615); permute_797 = None + permute_799 = torch.ops.aten.permute.default(permute_199, [1, 0]); permute_199 = None + mm_420 = torch.ops.aten.mm.default(view_1427, permute_799); view_1427 = permute_799 = None + view_1428 = torch.ops.aten.view.default(mm_420, [2, 8192, 4096]); mm_420 = None + add_224 = torch.ops.aten.add.Tensor(view_1426, view_1428); view_1426 = view_1428 = None + convert_element_type_1818 = torch.ops.prims.convert_element_type.default(mm_419, torch.float32); mm_419 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1818, 'avg', 64, '0'); convert_element_type_1818 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + view_1429 = torch.ops.aten.view.default(view_1424, [16384, 4096]); view_1424 = None + permute_801 = torch.ops.aten.permute.default(view_1429, [1, 0]) + mm_421 = torch.ops.aten.mm.default(permute_801, view_615); permute_801 = view_615 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16); primals_167 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 64, '0'); convert_element_type_598 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_164, [1, 0]); wait_tensor_164 = None + permute_803 = torch.ops.aten.permute.default(permute_198, [1, 0]); permute_198 = None + mm_422 = torch.ops.aten.mm.default(view_1429, permute_803); view_1429 = permute_803 = None + view_1430 = torch.ops.aten.view.default(mm_422, [2, 8192, 4096]); mm_422 = None + add_225 = torch.ops.aten.add.Tensor(add_224, view_1430); add_224 = view_1430 = None + convert_element_type_1823 = torch.ops.prims.convert_element_type.default(mm_421, torch.float32); mm_421 = None + reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1823, 'avg', 64, '0'); convert_element_type_1823 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + convert_element_type_1824 = torch.ops.prims.convert_element_type.default(add_225, torch.float32); add_225 = None + convert_element_type_1826 = torch.ops.prims.convert_element_type.default(wait_tensor_163, torch.float32); wait_tensor_163 = None + mul_538 = torch.ops.aten.mul.Tensor(convert_element_type_1824, convert_element_type_1826); convert_element_type_1826 = None + mul_540 = torch.ops.aten.mul.Tensor(mul_144, mul_538) + sum_85 = torch.ops.aten.sum.dim_IntList(mul_540, [2], True); mul_540 = None + div_28 = torch.ops.aten.div.Tensor(mul_144, 4096) + mul_541 = torch.ops.aten.mul.Tensor(div_28, sum_85); div_28 = sum_85 = None + sub_42 = torch.ops.aten.sub.Tensor(mul_538, mul_541); mul_538 = mul_541 = None + mul_542 = torch.ops.aten.mul.Tensor(sub_42, rsqrt_36); sub_42 = rsqrt_36 = None + mul_543 = torch.ops.aten.mul.Tensor(convert_element_type_1824, mul_144); convert_element_type_1824 = mul_144 = None + sum_86 = torch.ops.aten.sum.dim_IntList(mul_543, [0, 1]); mul_543 = None + convert_element_type_1827 = torch.ops.prims.convert_element_type.default(mul_542, torch.bfloat16); mul_542 = None + add_226 = torch.ops.aten.add.Tensor(add_223, convert_element_type_1827); add_223 = convert_element_type_1827 = None + convert_element_type_default_37 = torch.ops.prims.convert_element_type.default(sum_86, torch.float32); sum_86 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_37, 'avg', 64, '0'); convert_element_type_default_37 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); reduce_scatter_tensor_127 = None + view_1431 = torch.ops.aten.view.default(add_226, [16384, 4096]) + permute_805 = torch.ops.aten.permute.default(view_1431, [1, 0]) + permute_193 = torch.ops.aten.permute.default(getitem_153, [0, 2, 1, 3]) + view_599 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16); primals_161 = None + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 64, '0'); convert_element_type_578 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_158, [1, 0]); wait_tensor_158 = None + view_601 = torch.ops.aten.view.default(view_599, [16384, 4096]); view_599 = None + mm_122 = torch.ops.aten.mm.default(view_601, permute_194) + view_602 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + add_69 = torch.ops.aten.add.Tensor(add_67, view_602); view_602 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16); primals_162 = None + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 64, '0'); convert_element_type_581 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32); add_69 = None + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_159) + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + view_605 = torch.ops.aten.view.default(convert_element_type_583, [16384, 4096]); convert_element_type_583 = None + view_606 = torch.ops.aten.view.default(mm_123, [2, 8192, 14336]); mm_123 = None + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_606, torch.float32); view_606 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16); primals_164 = None + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 64, '0'); convert_element_type_589 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_124 = torch.ops.aten.mm.default(view_605, permute_196) + view_609 = torch.ops.aten.view.default(mm_124, [2, 8192, 14336]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_609) + view_611 = torch.ops.aten.view.default(mul_143, [16384, 14336]); mul_143 = None + mm_423 = torch.ops.aten.mm.default(permute_805, view_611); permute_805 = view_611 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16); primals_165 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 64, '0'); convert_element_type_592 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_807 = torch.ops.aten.permute.default(permute_197, [1, 0]); permute_197 = None + mm_424 = torch.ops.aten.mm.default(view_1431, permute_807); view_1431 = permute_807 = None + view_1432 = torch.ops.aten.view.default(mm_424, [2, 8192, 14336]); mm_424 = None + convert_element_type_1834 = torch.ops.prims.convert_element_type.default(mm_423, torch.float32); mm_423 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1834, 'avg', 64, '0'); convert_element_type_1834 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + mul_544 = torch.ops.aten.mul.Tensor(view_1432, convert_element_type_588); convert_element_type_588 = None + mul_545 = torch.ops.aten.mul.Tensor(view_1432, view_609); view_1432 = view_609 = None + view_1433 = torch.ops.aten.view.default(mul_544, [16384, 14336]); mul_544 = None + permute_809 = torch.ops.aten.permute.default(view_1433, [1, 0]) + mm_425 = torch.ops.aten.mm.default(permute_809, view_605); permute_809 = None + permute_811 = torch.ops.aten.permute.default(permute_196, [1, 0]); permute_196 = None + mm_426 = torch.ops.aten.mm.default(view_1433, permute_811); view_1433 = permute_811 = None + view_1434 = torch.ops.aten.view.default(mm_426, [2, 8192, 4096]); mm_426 = None + convert_element_type_1839 = torch.ops.prims.convert_element_type.default(mm_425, torch.float32); mm_425 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1839, 'avg', 64, '0'); convert_element_type_1839 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + convert_element_type_1840 = torch.ops.prims.convert_element_type.default(mul_545, torch.float32); mul_545 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_587) + exp_14 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_227 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_227); add_227 = None + mul_546 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_547 = torch.ops.aten.mul.Tensor(convert_element_type_1840, mul_546); convert_element_type_1840 = None + sub_43 = torch.ops.aten.sub.Tensor(1, mul_546); mul_546 = None + mul_548 = torch.ops.aten.mul.Tensor(convert_element_type_587, sub_43); convert_element_type_587 = sub_43 = None + add_228 = torch.ops.aten.add.Tensor(mul_548, 1); mul_548 = None + mul_549 = torch.ops.aten.mul.Tensor(mul_547, add_228); mul_547 = add_228 = None + convert_element_type_1842 = torch.ops.prims.convert_element_type.default(mul_549, torch.bfloat16); mul_549 = None + view_1435 = torch.ops.aten.view.default(convert_element_type_1842, [16384, 14336]); convert_element_type_1842 = None + permute_813 = torch.ops.aten.permute.default(view_1435, [1, 0]) + mm_427 = torch.ops.aten.mm.default(permute_813, view_605); permute_813 = view_605 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16); primals_163 = None + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 64, '0'); convert_element_type_584 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + permute_815 = torch.ops.aten.permute.default(permute_195, [1, 0]); permute_195 = None + mm_428 = torch.ops.aten.mm.default(view_1435, permute_815); view_1435 = permute_815 = None + view_1436 = torch.ops.aten.view.default(mm_428, [2, 8192, 4096]); mm_428 = None + add_229 = torch.ops.aten.add.Tensor(view_1434, view_1436); view_1434 = view_1436 = None + convert_element_type_1847 = torch.ops.prims.convert_element_type.default(mm_427, torch.float32); mm_427 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1847, 'avg', 64, '0'); convert_element_type_1847 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + convert_element_type_1848 = torch.ops.prims.convert_element_type.default(add_229, torch.float32); add_229 = None + convert_element_type_1850 = torch.ops.prims.convert_element_type.default(wait_tensor_159, torch.float32); wait_tensor_159 = None + mul_550 = torch.ops.aten.mul.Tensor(convert_element_type_1848, convert_element_type_1850); convert_element_type_1850 = None + mul_552 = torch.ops.aten.mul.Tensor(mul_140, mul_550) + sum_87 = torch.ops.aten.sum.dim_IntList(mul_552, [2], True); mul_552 = None + div_29 = torch.ops.aten.div.Tensor(mul_140, 4096) + mul_553 = torch.ops.aten.mul.Tensor(div_29, sum_87); div_29 = sum_87 = None + sub_44 = torch.ops.aten.sub.Tensor(mul_550, mul_553); mul_550 = mul_553 = None + mul_554 = torch.ops.aten.mul.Tensor(sub_44, rsqrt_35); sub_44 = rsqrt_35 = None + mul_555 = torch.ops.aten.mul.Tensor(convert_element_type_1848, mul_140); convert_element_type_1848 = mul_140 = None + sum_88 = torch.ops.aten.sum.dim_IntList(mul_555, [0, 1]); mul_555 = None + convert_element_type_1851 = torch.ops.prims.convert_element_type.default(mul_554, torch.bfloat16); mul_554 = None + add_230 = torch.ops.aten.add.Tensor(add_226, convert_element_type_1851); add_226 = convert_element_type_1851 = None + convert_element_type_default_36 = torch.ops.prims.convert_element_type.default(sum_88, torch.float32); sum_88 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_36, 'avg', 64, '0'); convert_element_type_default_36 = None + wait_tensor_422 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + view_1437 = torch.ops.aten.view.default(add_230, [16384, 4096]) + permute_817 = torch.ops.aten.permute.default(view_1437, [1, 0]) + mm_429 = torch.ops.aten.mm.default(permute_817, view_601); permute_817 = view_601 = None + permute_819 = torch.ops.aten.permute.default(permute_194, [1, 0]); permute_194 = None + mm_430 = torch.ops.aten.mm.default(view_1437, permute_819); view_1437 = permute_819 = None + view_1438 = torch.ops.aten.view.default(mm_430, [2, 8192, 4096]); mm_430 = None + convert_element_type_1858 = torch.ops.prims.convert_element_type.default(mm_429, torch.float32); mm_429 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1858, 'avg', 64, '0'); convert_element_type_1858 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + view_1439 = torch.ops.aten.view.default(view_1438, [2, 8192, 32, 128]); view_1438 = None + permute_821 = torch.ops.aten.permute.default(view_1439, [0, 2, 1, 3]); view_1439 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 64, '0'); convert_element_type_562 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32); add_67 = None + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_154) + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + view_581 = torch.ops.aten.view.default(convert_element_type_564, [16384, 4096]); convert_element_type_564 = None + view_582 = torch.ops.aten.view.default(mm_119, [2, 8192, 4096]); mm_119 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 64, '0'); convert_element_type_568 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_120 = torch.ops.aten.mm.default(view_581, permute_188) + view_585 = torch.ops.aten.view.default(mm_120, [2, 8192, 1024]); mm_120 = None + view_588 = torch.ops.aten.view.default(mm_121, [2, 8192, 1024]); mm_121 = None + view_589 = torch.ops.aten.view.default(view_582, [2, 8192, -1, 128]); view_582 = None + view_590 = torch.ops.aten.view.default(view_585, [2, 8192, -1, 128]); view_585 = None + view_591 = torch.ops.aten.view.default(view_588, [2, 8192, -1, 128]); view_588 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_589, torch.float32); view_589 = None + view_592 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 32, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_592); view_592 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_590, torch.float32); view_590 = None + view_593 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 8, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_593); view_593 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_16); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_595 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 32, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_16); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_596 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 8, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_595, torch.bfloat16); view_595 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_596, torch.bfloat16); view_596 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 8, 4, 128]); unsqueeze_34 = None + clone_34 = torch.ops.aten.clone.default(expand_34, memory_format = torch.contiguous_format); expand_34 = None + view_597 = torch.ops.aten.view.default(clone_34, [2, 8192, 32, 128]); clone_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_591, 3); view_591 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 8, 4, 128]); unsqueeze_35 = None + clone_35 = torch.ops.aten.clone.default(expand_35, memory_format = torch.contiguous_format); expand_35 = None + view_598 = torch.ops.aten.view.default(clone_35, [2, 8192, 32, 128]); clone_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_597, [0, 2, 1, 3]); view_597 = None + permute_192 = torch.ops.aten.permute.default(view_598, [0, 2, 1, 3]); view_598 = None + _scaled_dot_product_cudnn_attention_backward_14 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_821, permute_190, permute_191, permute_192, getitem_153, getitem_154, getitem_159, getitem_160, None, None, None, 8192, 8192, 0.0, True); permute_821 = permute_190 = permute_191 = permute_192 = getitem_153 = getitem_154 = getitem_159 = getitem_160 = None + getitem_330 = _scaled_dot_product_cudnn_attention_backward_14[0] + getitem_331 = _scaled_dot_product_cudnn_attention_backward_14[1] + getitem_332 = _scaled_dot_product_cudnn_attention_backward_14[2]; _scaled_dot_product_cudnn_attention_backward_14 = None + permute_822 = torch.ops.aten.permute.default(getitem_332, [0, 2, 1, 3]); getitem_332 = None + permute_823 = torch.ops.aten.permute.default(getitem_331, [0, 2, 1, 3]); getitem_331 = None + permute_824 = torch.ops.aten.permute.default(getitem_330, [0, 2, 1, 3]); getitem_330 = None + view_1440 = torch.ops.aten.view.default(permute_822, [2, 8192, 8, 4, 128]); permute_822 = None + sum_89 = torch.ops.aten.sum.dim_IntList(view_1440, [3], True); view_1440 = None + squeeze_28 = torch.ops.aten.squeeze.dim(sum_89, 3); sum_89 = None + view_1441 = torch.ops.aten.view.default(permute_823, [2, 8192, 8, 4, 128]); permute_823 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_1441, [3], True); view_1441 = None + squeeze_29 = torch.ops.aten.squeeze.dim(sum_90, 3); sum_90 = None + convert_element_type_1859 = torch.ops.prims.convert_element_type.default(squeeze_29, torch.float32); squeeze_29 = None + convert_element_type_1860 = torch.ops.prims.convert_element_type.default(permute_824, torch.float32); permute_824 = None + view_1442 = torch.ops.aten.view.default(convert_element_type_1859, [2, 8192, 8, 64, 2]); convert_element_type_1859 = None + view_as_complex_92 = torch.ops.aten.view_as_complex.default(view_1442); view_1442 = None + mul_556 = torch.ops.aten.mul.Tensor(view_as_complex_92, _conj); view_as_complex_92 = None + view_1443 = torch.ops.aten.view.default(convert_element_type_1860, [2, 8192, 32, 64, 2]); convert_element_type_1860 = None + view_as_complex_93 = torch.ops.aten.view_as_complex.default(view_1443); view_1443 = None + mul_557 = torch.ops.aten.mul.Tensor(view_as_complex_93, _conj); view_as_complex_93 = None + view_as_real_92 = torch.ops.aten.view_as_real.default(mul_556); mul_556 = None + view_1444 = torch.ops.aten.view.default(view_as_real_92, [2, 8192, 8, 128]); view_as_real_92 = None + convert_element_type_1861 = torch.ops.prims.convert_element_type.default(view_1444, torch.bfloat16); view_1444 = None + view_as_real_93 = torch.ops.aten.view_as_real.default(mul_557); mul_557 = None + view_1445 = torch.ops.aten.view.default(view_as_real_93, [2, 8192, 32, 128]); view_as_real_93 = None + convert_element_type_1862 = torch.ops.prims.convert_element_type.default(view_1445, torch.bfloat16); view_1445 = None + view_1446 = torch.ops.aten.view.default(squeeze_28, [2, 8192, 1024]); squeeze_28 = None + view_1447 = torch.ops.aten.view.default(convert_element_type_1861, [2, 8192, 1024]); convert_element_type_1861 = None + view_1448 = torch.ops.aten.view.default(convert_element_type_1862, [2, 8192, 4096]); convert_element_type_1862 = None + view_1449 = torch.ops.aten.view.default(view_1446, [16384, 1024]); view_1446 = None + permute_825 = torch.ops.aten.permute.default(view_1449, [1, 0]) + mm_431 = torch.ops.aten.mm.default(permute_825, view_581); permute_825 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16); primals_160 = None + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 64, '0'); convert_element_type_571 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + permute_827 = torch.ops.aten.permute.default(permute_189, [1, 0]); permute_189 = None + mm_432 = torch.ops.aten.mm.default(view_1449, permute_827); view_1449 = permute_827 = None + view_1450 = torch.ops.aten.view.default(mm_432, [2, 8192, 4096]); mm_432 = None + convert_element_type_1867 = torch.ops.prims.convert_element_type.default(mm_431, torch.float32); mm_431 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1867, 'avg', 64, '0'); convert_element_type_1867 = None + wait_tensor_424 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + view_1451 = torch.ops.aten.view.default(view_1447, [16384, 1024]); view_1447 = None + permute_829 = torch.ops.aten.permute.default(view_1451, [1, 0]) + mm_433 = torch.ops.aten.mm.default(permute_829, view_581); permute_829 = None + permute_831 = torch.ops.aten.permute.default(permute_188, [1, 0]); permute_188 = None + mm_434 = torch.ops.aten.mm.default(view_1451, permute_831); view_1451 = permute_831 = None + view_1452 = torch.ops.aten.view.default(mm_434, [2, 8192, 4096]); mm_434 = None + add_231 = torch.ops.aten.add.Tensor(view_1450, view_1452); view_1450 = view_1452 = None + convert_element_type_1872 = torch.ops.prims.convert_element_type.default(mm_433, torch.float32); mm_433 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1872, 'avg', 64, '0'); convert_element_type_1872 = None + wait_tensor_425 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + view_1453 = torch.ops.aten.view.default(view_1448, [16384, 4096]); view_1448 = None + permute_833 = torch.ops.aten.permute.default(view_1453, [1, 0]) + mm_435 = torch.ops.aten.mm.default(permute_833, view_581); permute_833 = view_581 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16); primals_158 = None + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 64, '0'); convert_element_type_565 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + permute_835 = torch.ops.aten.permute.default(permute_187, [1, 0]); permute_187 = None + mm_436 = torch.ops.aten.mm.default(view_1453, permute_835); view_1453 = permute_835 = None + view_1454 = torch.ops.aten.view.default(mm_436, [2, 8192, 4096]); mm_436 = None + add_232 = torch.ops.aten.add.Tensor(add_231, view_1454); add_231 = view_1454 = None + convert_element_type_1877 = torch.ops.prims.convert_element_type.default(mm_435, torch.float32); mm_435 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1877, 'avg', 64, '0'); convert_element_type_1877 = None + wait_tensor_426 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + convert_element_type_1878 = torch.ops.prims.convert_element_type.default(add_232, torch.float32); add_232 = None + convert_element_type_1880 = torch.ops.prims.convert_element_type.default(wait_tensor_154, torch.float32); wait_tensor_154 = None + mul_558 = torch.ops.aten.mul.Tensor(convert_element_type_1878, convert_element_type_1880); convert_element_type_1880 = None + mul_560 = torch.ops.aten.mul.Tensor(mul_136, mul_558) + sum_91 = torch.ops.aten.sum.dim_IntList(mul_560, [2], True); mul_560 = None + div_30 = torch.ops.aten.div.Tensor(mul_136, 4096) + mul_561 = torch.ops.aten.mul.Tensor(div_30, sum_91); div_30 = sum_91 = None + sub_45 = torch.ops.aten.sub.Tensor(mul_558, mul_561); mul_558 = mul_561 = None + mul_562 = torch.ops.aten.mul.Tensor(sub_45, rsqrt_34); sub_45 = rsqrt_34 = None + mul_563 = torch.ops.aten.mul.Tensor(convert_element_type_1878, mul_136); convert_element_type_1878 = mul_136 = None + sum_92 = torch.ops.aten.sum.dim_IntList(mul_563, [0, 1]); mul_563 = None + convert_element_type_1881 = torch.ops.prims.convert_element_type.default(mul_562, torch.bfloat16); mul_562 = None + add_233 = torch.ops.aten.add.Tensor(add_230, convert_element_type_1881); add_230 = convert_element_type_1881 = None + convert_element_type_default_35 = torch.ops.prims.convert_element_type.default(sum_92, torch.float32); sum_92 = None + reduce_scatter_tensor_136 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_35, 'avg', 64, '0'); convert_element_type_default_35 = None + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + view_1455 = torch.ops.aten.view.default(add_233, [16384, 4096]) + permute_837 = torch.ops.aten.permute.default(view_1455, [1, 0]) + permute_182 = torch.ops.aten.permute.default(getitem_144, [0, 2, 1, 3]) + view_565 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16); primals_152 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 64, '0'); convert_element_type_545 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + view_567 = torch.ops.aten.view.default(view_565, [16384, 4096]); view_565 = None + mm_115 = torch.ops.aten.mm.default(view_567, permute_183) + view_568 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + add_65 = torch.ops.aten.add.Tensor(add_63, view_568); view_568 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16); primals_153 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 64, '0'); convert_element_type_548 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32); add_65 = None + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_150) + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + view_571 = torch.ops.aten.view.default(convert_element_type_550, [16384, 4096]); convert_element_type_550 = None + view_572 = torch.ops.aten.view.default(mm_116, [2, 8192, 14336]); mm_116 = None + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_572, torch.float32); view_572 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 64, '0'); convert_element_type_556 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_152, [1, 0]); wait_tensor_152 = None + mm_117 = torch.ops.aten.mm.default(view_571, permute_185) + view_575 = torch.ops.aten.view.default(mm_117, [2, 8192, 14336]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_575) + view_577 = torch.ops.aten.view.default(mul_135, [16384, 14336]); mul_135 = None + mm_437 = torch.ops.aten.mm.default(permute_837, view_577); permute_837 = view_577 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 64, '0'); convert_element_type_559 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_153, [1, 0]); wait_tensor_153 = None + permute_839 = torch.ops.aten.permute.default(permute_186, [1, 0]); permute_186 = None + mm_438 = torch.ops.aten.mm.default(view_1455, permute_839); view_1455 = permute_839 = None + view_1456 = torch.ops.aten.view.default(mm_438, [2, 8192, 14336]); mm_438 = None + convert_element_type_1888 = torch.ops.prims.convert_element_type.default(mm_437, torch.float32); mm_437 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1888, 'avg', 64, '0'); convert_element_type_1888 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + mul_564 = torch.ops.aten.mul.Tensor(view_1456, convert_element_type_555); convert_element_type_555 = None + mul_565 = torch.ops.aten.mul.Tensor(view_1456, view_575); view_1456 = view_575 = None + view_1457 = torch.ops.aten.view.default(mul_564, [16384, 14336]); mul_564 = None + permute_841 = torch.ops.aten.permute.default(view_1457, [1, 0]) + mm_439 = torch.ops.aten.mm.default(permute_841, view_571); permute_841 = None + permute_843 = torch.ops.aten.permute.default(permute_185, [1, 0]); permute_185 = None + mm_440 = torch.ops.aten.mm.default(view_1457, permute_843); view_1457 = permute_843 = None + view_1458 = torch.ops.aten.view.default(mm_440, [2, 8192, 4096]); mm_440 = None + convert_element_type_1893 = torch.ops.prims.convert_element_type.default(mm_439, torch.float32); mm_439 = None + reduce_scatter_tensor_138 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1893, 'avg', 64, '0'); convert_element_type_1893 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + convert_element_type_1894 = torch.ops.prims.convert_element_type.default(mul_565, torch.float32); mul_565 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_554) + exp_15 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_234 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_234); add_234 = None + mul_566 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_567 = torch.ops.aten.mul.Tensor(convert_element_type_1894, mul_566); convert_element_type_1894 = None + sub_46 = torch.ops.aten.sub.Tensor(1, mul_566); mul_566 = None + mul_568 = torch.ops.aten.mul.Tensor(convert_element_type_554, sub_46); convert_element_type_554 = sub_46 = None + add_235 = torch.ops.aten.add.Tensor(mul_568, 1); mul_568 = None + mul_569 = torch.ops.aten.mul.Tensor(mul_567, add_235); mul_567 = add_235 = None + convert_element_type_1896 = torch.ops.prims.convert_element_type.default(mul_569, torch.bfloat16); mul_569 = None + view_1459 = torch.ops.aten.view.default(convert_element_type_1896, [16384, 14336]); convert_element_type_1896 = None + permute_845 = torch.ops.aten.permute.default(view_1459, [1, 0]) + mm_441 = torch.ops.aten.mm.default(permute_845, view_571); permute_845 = view_571 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 64, '0'); convert_element_type_551 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_151, [1, 0]); wait_tensor_151 = None + permute_847 = torch.ops.aten.permute.default(permute_184, [1, 0]); permute_184 = None + mm_442 = torch.ops.aten.mm.default(view_1459, permute_847); view_1459 = permute_847 = None + view_1460 = torch.ops.aten.view.default(mm_442, [2, 8192, 4096]); mm_442 = None + add_236 = torch.ops.aten.add.Tensor(view_1458, view_1460); view_1458 = view_1460 = None + convert_element_type_1901 = torch.ops.prims.convert_element_type.default(mm_441, torch.float32); mm_441 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1901, 'avg', 64, '0'); convert_element_type_1901 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + convert_element_type_1902 = torch.ops.prims.convert_element_type.default(add_236, torch.float32); add_236 = None + convert_element_type_1904 = torch.ops.prims.convert_element_type.default(wait_tensor_150, torch.float32); wait_tensor_150 = None + mul_570 = torch.ops.aten.mul.Tensor(convert_element_type_1902, convert_element_type_1904); convert_element_type_1904 = None + mul_572 = torch.ops.aten.mul.Tensor(mul_132, mul_570) + sum_93 = torch.ops.aten.sum.dim_IntList(mul_572, [2], True); mul_572 = None + div_31 = torch.ops.aten.div.Tensor(mul_132, 4096) + mul_573 = torch.ops.aten.mul.Tensor(div_31, sum_93); div_31 = sum_93 = None + sub_47 = torch.ops.aten.sub.Tensor(mul_570, mul_573); mul_570 = mul_573 = None + mul_574 = torch.ops.aten.mul.Tensor(sub_47, rsqrt_33); sub_47 = rsqrt_33 = None + mul_575 = torch.ops.aten.mul.Tensor(convert_element_type_1902, mul_132); convert_element_type_1902 = mul_132 = None + sum_94 = torch.ops.aten.sum.dim_IntList(mul_575, [0, 1]); mul_575 = None + convert_element_type_1905 = torch.ops.prims.convert_element_type.default(mul_574, torch.bfloat16); mul_574 = None + add_237 = torch.ops.aten.add.Tensor(add_233, convert_element_type_1905); add_233 = convert_element_type_1905 = None + convert_element_type_default_34 = torch.ops.prims.convert_element_type.default(sum_94, torch.float32); sum_94 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_34, 'avg', 64, '0'); convert_element_type_default_34 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + view_1461 = torch.ops.aten.view.default(add_237, [16384, 4096]) + permute_849 = torch.ops.aten.permute.default(view_1461, [1, 0]) + mm_443 = torch.ops.aten.mm.default(permute_849, view_567); permute_849 = view_567 = None + permute_851 = torch.ops.aten.permute.default(permute_183, [1, 0]); permute_183 = None + mm_444 = torch.ops.aten.mm.default(view_1461, permute_851); view_1461 = permute_851 = None + view_1462 = torch.ops.aten.view.default(mm_444, [2, 8192, 4096]); mm_444 = None + convert_element_type_1912 = torch.ops.prims.convert_element_type.default(mm_443, torch.float32); mm_443 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1912, 'avg', 64, '0'); convert_element_type_1912 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_1463 = torch.ops.aten.view.default(view_1462, [2, 8192, 32, 128]); view_1462 = None + permute_853 = torch.ops.aten.permute.default(view_1463, [0, 2, 1, 3]); view_1463 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 64, '0'); convert_element_type_529 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32); add_63 = None + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_145) + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + view_547 = torch.ops.aten.view.default(convert_element_type_531, [16384, 4096]); convert_element_type_531 = None + view_548 = torch.ops.aten.view.default(mm_112, [2, 8192, 4096]); mm_112 = None + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16); primals_150 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 64, '0'); convert_element_type_535 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + mm_113 = torch.ops.aten.mm.default(view_547, permute_177) + view_551 = torch.ops.aten.view.default(mm_113, [2, 8192, 1024]); mm_113 = None + view_554 = torch.ops.aten.view.default(mm_114, [2, 8192, 1024]); mm_114 = None + view_555 = torch.ops.aten.view.default(view_548, [2, 8192, -1, 128]); view_548 = None + view_556 = torch.ops.aten.view.default(view_551, [2, 8192, -1, 128]); view_551 = None + view_557 = torch.ops.aten.view.default(view_554, [2, 8192, -1, 128]); view_554 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_555, torch.float32); view_555 = None + view_558 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 32, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_558); view_558 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_556, torch.float32); view_556 = None + view_559 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 8, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_559); view_559 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_16); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_561 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 32, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_16); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_562 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 8, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_561, torch.bfloat16); view_561 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_562, torch.bfloat16); view_562 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 8, 4, 128]); unsqueeze_32 = None + clone_32 = torch.ops.aten.clone.default(expand_32, memory_format = torch.contiguous_format); expand_32 = None + view_563 = torch.ops.aten.view.default(clone_32, [2, 8192, 32, 128]); clone_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_557, 3); view_557 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 8, 4, 128]); unsqueeze_33 = None + clone_33 = torch.ops.aten.clone.default(expand_33, memory_format = torch.contiguous_format); expand_33 = None + view_564 = torch.ops.aten.view.default(clone_33, [2, 8192, 32, 128]); clone_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_563, [0, 2, 1, 3]); view_563 = None + permute_181 = torch.ops.aten.permute.default(view_564, [0, 2, 1, 3]); view_564 = None + _scaled_dot_product_cudnn_attention_backward_15 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_853, permute_179, permute_180, permute_181, getitem_144, getitem_145, getitem_150, getitem_151, None, None, None, 8192, 8192, 0.0, True); permute_853 = permute_179 = permute_180 = permute_181 = getitem_144 = getitem_145 = getitem_150 = getitem_151 = None + getitem_333 = _scaled_dot_product_cudnn_attention_backward_15[0] + getitem_334 = _scaled_dot_product_cudnn_attention_backward_15[1] + getitem_335 = _scaled_dot_product_cudnn_attention_backward_15[2]; _scaled_dot_product_cudnn_attention_backward_15 = None + permute_854 = torch.ops.aten.permute.default(getitem_335, [0, 2, 1, 3]); getitem_335 = None + permute_855 = torch.ops.aten.permute.default(getitem_334, [0, 2, 1, 3]); getitem_334 = None + permute_856 = torch.ops.aten.permute.default(getitem_333, [0, 2, 1, 3]); getitem_333 = None + view_1464 = torch.ops.aten.view.default(permute_854, [2, 8192, 8, 4, 128]); permute_854 = None + sum_95 = torch.ops.aten.sum.dim_IntList(view_1464, [3], True); view_1464 = None + squeeze_30 = torch.ops.aten.squeeze.dim(sum_95, 3); sum_95 = None + view_1465 = torch.ops.aten.view.default(permute_855, [2, 8192, 8, 4, 128]); permute_855 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_1465, [3], True); view_1465 = None + squeeze_31 = torch.ops.aten.squeeze.dim(sum_96, 3); sum_96 = None + convert_element_type_1913 = torch.ops.prims.convert_element_type.default(squeeze_31, torch.float32); squeeze_31 = None + convert_element_type_1914 = torch.ops.prims.convert_element_type.default(permute_856, torch.float32); permute_856 = None + view_1466 = torch.ops.aten.view.default(convert_element_type_1913, [2, 8192, 8, 64, 2]); convert_element_type_1913 = None + view_as_complex_94 = torch.ops.aten.view_as_complex.default(view_1466); view_1466 = None + mul_576 = torch.ops.aten.mul.Tensor(view_as_complex_94, _conj); view_as_complex_94 = None + view_1467 = torch.ops.aten.view.default(convert_element_type_1914, [2, 8192, 32, 64, 2]); convert_element_type_1914 = None + view_as_complex_95 = torch.ops.aten.view_as_complex.default(view_1467); view_1467 = None + mul_577 = torch.ops.aten.mul.Tensor(view_as_complex_95, _conj); view_as_complex_95 = None + view_as_real_94 = torch.ops.aten.view_as_real.default(mul_576); mul_576 = None + view_1468 = torch.ops.aten.view.default(view_as_real_94, [2, 8192, 8, 128]); view_as_real_94 = None + convert_element_type_1915 = torch.ops.prims.convert_element_type.default(view_1468, torch.bfloat16); view_1468 = None + view_as_real_95 = torch.ops.aten.view_as_real.default(mul_577); mul_577 = None + view_1469 = torch.ops.aten.view.default(view_as_real_95, [2, 8192, 32, 128]); view_as_real_95 = None + convert_element_type_1916 = torch.ops.prims.convert_element_type.default(view_1469, torch.bfloat16); view_1469 = None + view_1470 = torch.ops.aten.view.default(squeeze_30, [2, 8192, 1024]); squeeze_30 = None + view_1471 = torch.ops.aten.view.default(convert_element_type_1915, [2, 8192, 1024]); convert_element_type_1915 = None + view_1472 = torch.ops.aten.view.default(convert_element_type_1916, [2, 8192, 4096]); convert_element_type_1916 = None + view_1473 = torch.ops.aten.view.default(view_1470, [16384, 1024]); view_1470 = None + permute_857 = torch.ops.aten.permute.default(view_1473, [1, 0]) + mm_445 = torch.ops.aten.mm.default(permute_857, view_547); permute_857 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16); primals_151 = None + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 64, '0'); convert_element_type_538 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + permute_859 = torch.ops.aten.permute.default(permute_178, [1, 0]); permute_178 = None + mm_446 = torch.ops.aten.mm.default(view_1473, permute_859); view_1473 = permute_859 = None + view_1474 = torch.ops.aten.view.default(mm_446, [2, 8192, 4096]); mm_446 = None + convert_element_type_1921 = torch.ops.prims.convert_element_type.default(mm_445, torch.float32); mm_445 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1921, 'avg', 64, '0'); convert_element_type_1921 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + view_1475 = torch.ops.aten.view.default(view_1471, [16384, 1024]); view_1471 = None + permute_861 = torch.ops.aten.permute.default(view_1475, [1, 0]) + mm_447 = torch.ops.aten.mm.default(permute_861, view_547); permute_861 = None + permute_863 = torch.ops.aten.permute.default(permute_177, [1, 0]); permute_177 = None + mm_448 = torch.ops.aten.mm.default(view_1475, permute_863); view_1475 = permute_863 = None + view_1476 = torch.ops.aten.view.default(mm_448, [2, 8192, 4096]); mm_448 = None + add_238 = torch.ops.aten.add.Tensor(view_1474, view_1476); view_1474 = view_1476 = None + convert_element_type_1926 = torch.ops.prims.convert_element_type.default(mm_447, torch.float32); mm_447 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1926, 'avg', 64, '0'); convert_element_type_1926 = None + wait_tensor_434 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + view_1477 = torch.ops.aten.view.default(view_1472, [16384, 4096]); view_1472 = None + permute_865 = torch.ops.aten.permute.default(view_1477, [1, 0]) + mm_449 = torch.ops.aten.mm.default(permute_865, view_547); permute_865 = view_547 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); primals_149 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 64, '0'); convert_element_type_532 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_146, [1, 0]); wait_tensor_146 = None + permute_867 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_450 = torch.ops.aten.mm.default(view_1477, permute_867); view_1477 = permute_867 = None + view_1478 = torch.ops.aten.view.default(mm_450, [2, 8192, 4096]); mm_450 = None + add_239 = torch.ops.aten.add.Tensor(add_238, view_1478); add_238 = view_1478 = None + convert_element_type_1931 = torch.ops.prims.convert_element_type.default(mm_449, torch.float32); mm_449 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1931, 'avg', 64, '0'); convert_element_type_1931 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + convert_element_type_1932 = torch.ops.prims.convert_element_type.default(add_239, torch.float32); add_239 = None + convert_element_type_1934 = torch.ops.prims.convert_element_type.default(wait_tensor_145, torch.float32); wait_tensor_145 = None + mul_578 = torch.ops.aten.mul.Tensor(convert_element_type_1932, convert_element_type_1934); convert_element_type_1934 = None + mul_580 = torch.ops.aten.mul.Tensor(mul_128, mul_578) + sum_97 = torch.ops.aten.sum.dim_IntList(mul_580, [2], True); mul_580 = None + div_32 = torch.ops.aten.div.Tensor(mul_128, 4096) + mul_581 = torch.ops.aten.mul.Tensor(div_32, sum_97); div_32 = sum_97 = None + sub_48 = torch.ops.aten.sub.Tensor(mul_578, mul_581); mul_578 = mul_581 = None + mul_582 = torch.ops.aten.mul.Tensor(sub_48, rsqrt_32); sub_48 = rsqrt_32 = None + mul_583 = torch.ops.aten.mul.Tensor(convert_element_type_1932, mul_128); convert_element_type_1932 = mul_128 = None + sum_98 = torch.ops.aten.sum.dim_IntList(mul_583, [0, 1]); mul_583 = None + convert_element_type_1935 = torch.ops.prims.convert_element_type.default(mul_582, torch.bfloat16); mul_582 = None + add_240 = torch.ops.aten.add.Tensor(add_237, convert_element_type_1935); add_237 = convert_element_type_1935 = None + convert_element_type_default_33 = torch.ops.prims.convert_element_type.default(sum_98, torch.float32); sum_98 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_33, 'avg', 64, '0'); convert_element_type_default_33 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + view_1479 = torch.ops.aten.view.default(add_240, [16384, 4096]) + permute_869 = torch.ops.aten.permute.default(view_1479, [1, 0]) + permute_171 = torch.ops.aten.permute.default(getitem_135, [0, 2, 1, 3]) + view_531 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 64, '0'); convert_element_type_512 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_140, [1, 0]); wait_tensor_140 = None + view_533 = torch.ops.aten.view.default(view_531, [16384, 4096]); view_531 = None + mm_108 = torch.ops.aten.mm.default(view_533, permute_172) + view_534 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + add_61 = torch.ops.aten.add.Tensor(add_59, view_534); view_534 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 64, '0'); convert_element_type_515 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32); add_61 = None + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_141) + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + view_537 = torch.ops.aten.view.default(convert_element_type_517, [16384, 4096]); convert_element_type_517 = None + view_538 = torch.ops.aten.view.default(mm_109, [2, 8192, 14336]); mm_109 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_538, torch.float32); view_538 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16); primals_146 = None + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 64, '0'); convert_element_type_523 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + mm_110 = torch.ops.aten.mm.default(view_537, permute_174) + view_541 = torch.ops.aten.view.default(mm_110, [2, 8192, 14336]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_541) + view_543 = torch.ops.aten.view.default(mul_127, [16384, 14336]); mul_127 = None + mm_451 = torch.ops.aten.mm.default(permute_869, view_543); permute_869 = view_543 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16); primals_147 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 64, '0'); convert_element_type_526 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + permute_871 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_452 = torch.ops.aten.mm.default(view_1479, permute_871); view_1479 = permute_871 = None + view_1480 = torch.ops.aten.view.default(mm_452, [2, 8192, 14336]); mm_452 = None + convert_element_type_1942 = torch.ops.prims.convert_element_type.default(mm_451, torch.float32); mm_451 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1942, 'avg', 64, '0'); convert_element_type_1942 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + mul_584 = torch.ops.aten.mul.Tensor(view_1480, convert_element_type_522); convert_element_type_522 = None + mul_585 = torch.ops.aten.mul.Tensor(view_1480, view_541); view_1480 = view_541 = None + view_1481 = torch.ops.aten.view.default(mul_584, [16384, 14336]); mul_584 = None + permute_873 = torch.ops.aten.permute.default(view_1481, [1, 0]) + mm_453 = torch.ops.aten.mm.default(permute_873, view_537); permute_873 = None + permute_875 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_454 = torch.ops.aten.mm.default(view_1481, permute_875); view_1481 = permute_875 = None + view_1482 = torch.ops.aten.view.default(mm_454, [2, 8192, 4096]); mm_454 = None + convert_element_type_1947 = torch.ops.prims.convert_element_type.default(mm_453, torch.float32); mm_453 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1947, 'avg', 64, '0'); convert_element_type_1947 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + convert_element_type_1948 = torch.ops.prims.convert_element_type.default(mul_585, torch.float32); mul_585 = None + neg_16 = torch.ops.aten.neg.default(convert_element_type_521) + exp_16 = torch.ops.aten.exp.default(neg_16); neg_16 = None + add_241 = torch.ops.aten.add.Tensor(exp_16, 1); exp_16 = None + reciprocal_16 = torch.ops.aten.reciprocal.default(add_241); add_241 = None + mul_586 = torch.ops.aten.mul.Tensor(reciprocal_16, 1); reciprocal_16 = None + mul_587 = torch.ops.aten.mul.Tensor(convert_element_type_1948, mul_586); convert_element_type_1948 = None + sub_49 = torch.ops.aten.sub.Tensor(1, mul_586); mul_586 = None + mul_588 = torch.ops.aten.mul.Tensor(convert_element_type_521, sub_49); convert_element_type_521 = sub_49 = None + add_242 = torch.ops.aten.add.Tensor(mul_588, 1); mul_588 = None + mul_589 = torch.ops.aten.mul.Tensor(mul_587, add_242); mul_587 = add_242 = None + convert_element_type_1950 = torch.ops.prims.convert_element_type.default(mul_589, torch.bfloat16); mul_589 = None + view_1483 = torch.ops.aten.view.default(convert_element_type_1950, [16384, 14336]); convert_element_type_1950 = None + permute_877 = torch.ops.aten.permute.default(view_1483, [1, 0]) + mm_455 = torch.ops.aten.mm.default(permute_877, view_537); permute_877 = view_537 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16); primals_145 = None + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 64, '0'); convert_element_type_518 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + permute_879 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_456 = torch.ops.aten.mm.default(view_1483, permute_879); view_1483 = permute_879 = None + view_1484 = torch.ops.aten.view.default(mm_456, [2, 8192, 4096]); mm_456 = None + add_243 = torch.ops.aten.add.Tensor(view_1482, view_1484); view_1482 = view_1484 = None + convert_element_type_1955 = torch.ops.prims.convert_element_type.default(mm_455, torch.float32); mm_455 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1955, 'avg', 64, '0'); convert_element_type_1955 = None + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + convert_element_type_1956 = torch.ops.prims.convert_element_type.default(add_243, torch.float32); add_243 = None + convert_element_type_1958 = torch.ops.prims.convert_element_type.default(wait_tensor_141, torch.float32); wait_tensor_141 = None + mul_590 = torch.ops.aten.mul.Tensor(convert_element_type_1956, convert_element_type_1958); convert_element_type_1958 = None + mul_592 = torch.ops.aten.mul.Tensor(mul_124, mul_590) + sum_99 = torch.ops.aten.sum.dim_IntList(mul_592, [2], True); mul_592 = None + div_33 = torch.ops.aten.div.Tensor(mul_124, 4096) + mul_593 = torch.ops.aten.mul.Tensor(div_33, sum_99); div_33 = sum_99 = None + sub_50 = torch.ops.aten.sub.Tensor(mul_590, mul_593); mul_590 = mul_593 = None + mul_594 = torch.ops.aten.mul.Tensor(sub_50, rsqrt_31); sub_50 = rsqrt_31 = None + mul_595 = torch.ops.aten.mul.Tensor(convert_element_type_1956, mul_124); convert_element_type_1956 = mul_124 = None + sum_100 = torch.ops.aten.sum.dim_IntList(mul_595, [0, 1]); mul_595 = None + convert_element_type_1959 = torch.ops.prims.convert_element_type.default(mul_594, torch.bfloat16); mul_594 = None + add_244 = torch.ops.aten.add.Tensor(add_240, convert_element_type_1959); add_240 = convert_element_type_1959 = None + convert_element_type_default_32 = torch.ops.prims.convert_element_type.default(sum_100, torch.float32); sum_100 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_32, 'avg', 64, '0'); convert_element_type_default_32 = None + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + view_1485 = torch.ops.aten.view.default(add_244, [16384, 4096]) + permute_881 = torch.ops.aten.permute.default(view_1485, [1, 0]) + mm_457 = torch.ops.aten.mm.default(permute_881, view_533); permute_881 = view_533 = None + permute_883 = torch.ops.aten.permute.default(permute_172, [1, 0]); permute_172 = None + mm_458 = torch.ops.aten.mm.default(view_1485, permute_883); view_1485 = permute_883 = None + view_1486 = torch.ops.aten.view.default(mm_458, [2, 8192, 4096]); mm_458 = None + convert_element_type_1966 = torch.ops.prims.convert_element_type.default(mm_457, torch.float32); mm_457 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1966, 'avg', 64, '0'); convert_element_type_1966 = None + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = None + view_1487 = torch.ops.aten.view.default(view_1486, [2, 8192, 32, 128]); view_1486 = None + permute_885 = torch.ops.aten.permute.default(view_1487, [0, 2, 1, 3]); view_1487 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 64, '0'); convert_element_type_496 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32); add_59 = None + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_136) + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + view_513 = torch.ops.aten.view.default(convert_element_type_498, [16384, 4096]); convert_element_type_498 = None + view_514 = torch.ops.aten.view.default(mm_105, [2, 8192, 4096]); mm_105 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 64, '0'); convert_element_type_502 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + mm_106 = torch.ops.aten.mm.default(view_513, permute_166) + view_517 = torch.ops.aten.view.default(mm_106, [2, 8192, 1024]); mm_106 = None + view_520 = torch.ops.aten.view.default(mm_107, [2, 8192, 1024]); mm_107 = None + view_521 = torch.ops.aten.view.default(view_514, [2, 8192, -1, 128]); view_514 = None + view_522 = torch.ops.aten.view.default(view_517, [2, 8192, -1, 128]); view_517 = None + view_523 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_521, torch.float32); view_521 = None + view_524 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 32, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_524); view_524 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_522, torch.float32); view_522 = None + view_525 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 8, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_525); view_525 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_16); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_527 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 32, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_16); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_528 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 8, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_527, torch.bfloat16); view_527 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_528, torch.bfloat16); view_528 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 8, 4, 128]); unsqueeze_30 = None + clone_30 = torch.ops.aten.clone.default(expand_30, memory_format = torch.contiguous_format); expand_30 = None + view_529 = torch.ops.aten.view.default(clone_30, [2, 8192, 32, 128]); clone_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_523, 3); view_523 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 8, 4, 128]); unsqueeze_31 = None + clone_31 = torch.ops.aten.clone.default(expand_31, memory_format = torch.contiguous_format); expand_31 = None + view_530 = torch.ops.aten.view.default(clone_31, [2, 8192, 32, 128]); clone_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_529, [0, 2, 1, 3]); view_529 = None + permute_170 = torch.ops.aten.permute.default(view_530, [0, 2, 1, 3]); view_530 = None + _scaled_dot_product_cudnn_attention_backward_16 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_885, permute_168, permute_169, permute_170, getitem_135, getitem_136, getitem_141, getitem_142, None, None, None, 8192, 8192, 0.0, True); permute_885 = permute_168 = permute_169 = permute_170 = getitem_135 = getitem_136 = getitem_141 = getitem_142 = None + getitem_336 = _scaled_dot_product_cudnn_attention_backward_16[0] + getitem_337 = _scaled_dot_product_cudnn_attention_backward_16[1] + getitem_338 = _scaled_dot_product_cudnn_attention_backward_16[2]; _scaled_dot_product_cudnn_attention_backward_16 = None + permute_886 = torch.ops.aten.permute.default(getitem_338, [0, 2, 1, 3]); getitem_338 = None + permute_887 = torch.ops.aten.permute.default(getitem_337, [0, 2, 1, 3]); getitem_337 = None + permute_888 = torch.ops.aten.permute.default(getitem_336, [0, 2, 1, 3]); getitem_336 = None + view_1488 = torch.ops.aten.view.default(permute_886, [2, 8192, 8, 4, 128]); permute_886 = None + sum_101 = torch.ops.aten.sum.dim_IntList(view_1488, [3], True); view_1488 = None + squeeze_32 = torch.ops.aten.squeeze.dim(sum_101, 3); sum_101 = None + view_1489 = torch.ops.aten.view.default(permute_887, [2, 8192, 8, 4, 128]); permute_887 = None + sum_102 = torch.ops.aten.sum.dim_IntList(view_1489, [3], True); view_1489 = None + squeeze_33 = torch.ops.aten.squeeze.dim(sum_102, 3); sum_102 = None + convert_element_type_1967 = torch.ops.prims.convert_element_type.default(squeeze_33, torch.float32); squeeze_33 = None + convert_element_type_1968 = torch.ops.prims.convert_element_type.default(permute_888, torch.float32); permute_888 = None + view_1490 = torch.ops.aten.view.default(convert_element_type_1967, [2, 8192, 8, 64, 2]); convert_element_type_1967 = None + view_as_complex_96 = torch.ops.aten.view_as_complex.default(view_1490); view_1490 = None + mul_596 = torch.ops.aten.mul.Tensor(view_as_complex_96, _conj); view_as_complex_96 = None + view_1491 = torch.ops.aten.view.default(convert_element_type_1968, [2, 8192, 32, 64, 2]); convert_element_type_1968 = None + view_as_complex_97 = torch.ops.aten.view_as_complex.default(view_1491); view_1491 = None + mul_597 = torch.ops.aten.mul.Tensor(view_as_complex_97, _conj); view_as_complex_97 = None + view_as_real_96 = torch.ops.aten.view_as_real.default(mul_596); mul_596 = None + view_1492 = torch.ops.aten.view.default(view_as_real_96, [2, 8192, 8, 128]); view_as_real_96 = None + convert_element_type_1969 = torch.ops.prims.convert_element_type.default(view_1492, torch.bfloat16); view_1492 = None + view_as_real_97 = torch.ops.aten.view_as_real.default(mul_597); mul_597 = None + view_1493 = torch.ops.aten.view.default(view_as_real_97, [2, 8192, 32, 128]); view_as_real_97 = None + convert_element_type_1970 = torch.ops.prims.convert_element_type.default(view_1493, torch.bfloat16); view_1493 = None + view_1494 = torch.ops.aten.view.default(squeeze_32, [2, 8192, 1024]); squeeze_32 = None + view_1495 = torch.ops.aten.view.default(convert_element_type_1969, [2, 8192, 1024]); convert_element_type_1969 = None + view_1496 = torch.ops.aten.view.default(convert_element_type_1970, [2, 8192, 4096]); convert_element_type_1970 = None + view_1497 = torch.ops.aten.view.default(view_1494, [16384, 1024]); view_1494 = None + permute_889 = torch.ops.aten.permute.default(view_1497, [1, 0]) + mm_459 = torch.ops.aten.mm.default(permute_889, view_513); permute_889 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 64, '0'); convert_element_type_505 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + permute_891 = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None + mm_460 = torch.ops.aten.mm.default(view_1497, permute_891); view_1497 = permute_891 = None + view_1498 = torch.ops.aten.view.default(mm_460, [2, 8192, 4096]); mm_460 = None + convert_element_type_1975 = torch.ops.prims.convert_element_type.default(mm_459, torch.float32); mm_459 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1975, 'avg', 64, '0'); convert_element_type_1975 = None + wait_tensor_442 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + view_1499 = torch.ops.aten.view.default(view_1495, [16384, 1024]); view_1495 = None + permute_893 = torch.ops.aten.permute.default(view_1499, [1, 0]) + mm_461 = torch.ops.aten.mm.default(permute_893, view_513); permute_893 = None + permute_895 = torch.ops.aten.permute.default(permute_166, [1, 0]); permute_166 = None + mm_462 = torch.ops.aten.mm.default(view_1499, permute_895); view_1499 = permute_895 = None + view_1500 = torch.ops.aten.view.default(mm_462, [2, 8192, 4096]); mm_462 = None + add_245 = torch.ops.aten.add.Tensor(view_1498, view_1500); view_1498 = view_1500 = None + convert_element_type_1980 = torch.ops.prims.convert_element_type.default(mm_461, torch.float32); mm_461 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1980, 'avg', 64, '0'); convert_element_type_1980 = None + wait_tensor_443 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + view_1501 = torch.ops.aten.view.default(view_1496, [16384, 4096]); view_1496 = None + permute_897 = torch.ops.aten.permute.default(view_1501, [1, 0]) + mm_463 = torch.ops.aten.mm.default(permute_897, view_513); permute_897 = view_513 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 64, '0'); convert_element_type_499 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + permute_899 = torch.ops.aten.permute.default(permute_165, [1, 0]); permute_165 = None + mm_464 = torch.ops.aten.mm.default(view_1501, permute_899); view_1501 = permute_899 = None + view_1502 = torch.ops.aten.view.default(mm_464, [2, 8192, 4096]); mm_464 = None + add_246 = torch.ops.aten.add.Tensor(add_245, view_1502); add_245 = view_1502 = None + convert_element_type_1985 = torch.ops.prims.convert_element_type.default(mm_463, torch.float32); mm_463 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1985, 'avg', 64, '0'); convert_element_type_1985 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + convert_element_type_1986 = torch.ops.prims.convert_element_type.default(add_246, torch.float32); add_246 = None + convert_element_type_1988 = torch.ops.prims.convert_element_type.default(wait_tensor_136, torch.float32); wait_tensor_136 = None + mul_598 = torch.ops.aten.mul.Tensor(convert_element_type_1986, convert_element_type_1988); convert_element_type_1988 = None + mul_600 = torch.ops.aten.mul.Tensor(mul_120, mul_598) + sum_103 = torch.ops.aten.sum.dim_IntList(mul_600, [2], True); mul_600 = None + div_34 = torch.ops.aten.div.Tensor(mul_120, 4096) + mul_601 = torch.ops.aten.mul.Tensor(div_34, sum_103); div_34 = sum_103 = None + sub_51 = torch.ops.aten.sub.Tensor(mul_598, mul_601); mul_598 = mul_601 = None + mul_602 = torch.ops.aten.mul.Tensor(sub_51, rsqrt_30); sub_51 = rsqrt_30 = None + mul_603 = torch.ops.aten.mul.Tensor(convert_element_type_1986, mul_120); convert_element_type_1986 = mul_120 = None + sum_104 = torch.ops.aten.sum.dim_IntList(mul_603, [0, 1]); mul_603 = None + convert_element_type_1989 = torch.ops.prims.convert_element_type.default(mul_602, torch.bfloat16); mul_602 = None + add_247 = torch.ops.aten.add.Tensor(add_244, convert_element_type_1989); add_244 = convert_element_type_1989 = None + convert_element_type_default_31 = torch.ops.prims.convert_element_type.default(sum_104, torch.float32); sum_104 = None + reduce_scatter_tensor_154 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_31, 'avg', 64, '0'); convert_element_type_default_31 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + view_1503 = torch.ops.aten.view.default(add_247, [16384, 4096]) + permute_901 = torch.ops.aten.permute.default(view_1503, [1, 0]) + permute_160 = torch.ops.aten.permute.default(getitem_126, [0, 2, 1, 3]) + view_497 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 64, '0'); convert_element_type_479 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_131, [1, 0]); wait_tensor_131 = None + view_499 = torch.ops.aten.view.default(view_497, [16384, 4096]); view_497 = None + mm_101 = torch.ops.aten.mm.default(view_499, permute_161) + view_500 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + add_57 = torch.ops.aten.add.Tensor(add_55, view_500); view_500 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 64, '0'); convert_element_type_482 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32); add_57 = None + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_132) + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + view_503 = torch.ops.aten.view.default(convert_element_type_484, [16384, 4096]); convert_element_type_484 = None + view_504 = torch.ops.aten.view.default(mm_102, [2, 8192, 14336]); mm_102 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_504, torch.float32); view_504 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 64, '0'); convert_element_type_490 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_103 = torch.ops.aten.mm.default(view_503, permute_163) + view_507 = torch.ops.aten.view.default(mm_103, [2, 8192, 14336]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_507) + view_509 = torch.ops.aten.view.default(mul_119, [16384, 14336]); mul_119 = None + mm_465 = torch.ops.aten.mm.default(permute_901, view_509); permute_901 = view_509 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 64, '0'); convert_element_type_493 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + permute_903 = torch.ops.aten.permute.default(permute_164, [1, 0]); permute_164 = None + mm_466 = torch.ops.aten.mm.default(view_1503, permute_903); view_1503 = permute_903 = None + view_1504 = torch.ops.aten.view.default(mm_466, [2, 8192, 14336]); mm_466 = None + convert_element_type_1996 = torch.ops.prims.convert_element_type.default(mm_465, torch.float32); mm_465 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1996, 'avg', 64, '0'); convert_element_type_1996 = None + wait_tensor_446 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); reduce_scatter_tensor_155 = None + mul_604 = torch.ops.aten.mul.Tensor(view_1504, convert_element_type_489); convert_element_type_489 = None + mul_605 = torch.ops.aten.mul.Tensor(view_1504, view_507); view_1504 = view_507 = None + view_1505 = torch.ops.aten.view.default(mul_604, [16384, 14336]); mul_604 = None + permute_905 = torch.ops.aten.permute.default(view_1505, [1, 0]) + mm_467 = torch.ops.aten.mm.default(permute_905, view_503); permute_905 = None + permute_907 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_468 = torch.ops.aten.mm.default(view_1505, permute_907); view_1505 = permute_907 = None + view_1506 = torch.ops.aten.view.default(mm_468, [2, 8192, 4096]); mm_468 = None + convert_element_type_2001 = torch.ops.prims.convert_element_type.default(mm_467, torch.float32); mm_467 = None + reduce_scatter_tensor_156 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2001, 'avg', 64, '0'); convert_element_type_2001 = None + wait_tensor_447 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + convert_element_type_2002 = torch.ops.prims.convert_element_type.default(mul_605, torch.float32); mul_605 = None + neg_17 = torch.ops.aten.neg.default(convert_element_type_488) + exp_17 = torch.ops.aten.exp.default(neg_17); neg_17 = None + add_248 = torch.ops.aten.add.Tensor(exp_17, 1); exp_17 = None + reciprocal_17 = torch.ops.aten.reciprocal.default(add_248); add_248 = None + mul_606 = torch.ops.aten.mul.Tensor(reciprocal_17, 1); reciprocal_17 = None + mul_607 = torch.ops.aten.mul.Tensor(convert_element_type_2002, mul_606); convert_element_type_2002 = None + sub_52 = torch.ops.aten.sub.Tensor(1, mul_606); mul_606 = None + mul_608 = torch.ops.aten.mul.Tensor(convert_element_type_488, sub_52); convert_element_type_488 = sub_52 = None + add_249 = torch.ops.aten.add.Tensor(mul_608, 1); mul_608 = None + mul_609 = torch.ops.aten.mul.Tensor(mul_607, add_249); mul_607 = add_249 = None + convert_element_type_2004 = torch.ops.prims.convert_element_type.default(mul_609, torch.bfloat16); mul_609 = None + view_1507 = torch.ops.aten.view.default(convert_element_type_2004, [16384, 14336]); convert_element_type_2004 = None + permute_909 = torch.ops.aten.permute.default(view_1507, [1, 0]) + mm_469 = torch.ops.aten.mm.default(permute_909, view_503); permute_909 = view_503 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 64, '0'); convert_element_type_485 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_133, [1, 0]); wait_tensor_133 = None + permute_911 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_470 = torch.ops.aten.mm.default(view_1507, permute_911); view_1507 = permute_911 = None + view_1508 = torch.ops.aten.view.default(mm_470, [2, 8192, 4096]); mm_470 = None + add_250 = torch.ops.aten.add.Tensor(view_1506, view_1508); view_1506 = view_1508 = None + convert_element_type_2009 = torch.ops.prims.convert_element_type.default(mm_469, torch.float32); mm_469 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2009, 'avg', 64, '0'); convert_element_type_2009 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + convert_element_type_2010 = torch.ops.prims.convert_element_type.default(add_250, torch.float32); add_250 = None + convert_element_type_2012 = torch.ops.prims.convert_element_type.default(wait_tensor_132, torch.float32); wait_tensor_132 = None + mul_610 = torch.ops.aten.mul.Tensor(convert_element_type_2010, convert_element_type_2012); convert_element_type_2012 = None + mul_612 = torch.ops.aten.mul.Tensor(mul_116, mul_610) + sum_105 = torch.ops.aten.sum.dim_IntList(mul_612, [2], True); mul_612 = None + div_35 = torch.ops.aten.div.Tensor(mul_116, 4096) + mul_613 = torch.ops.aten.mul.Tensor(div_35, sum_105); div_35 = sum_105 = None + sub_53 = torch.ops.aten.sub.Tensor(mul_610, mul_613); mul_610 = mul_613 = None + mul_614 = torch.ops.aten.mul.Tensor(sub_53, rsqrt_29); sub_53 = rsqrt_29 = None + mul_615 = torch.ops.aten.mul.Tensor(convert_element_type_2010, mul_116); convert_element_type_2010 = mul_116 = None + sum_106 = torch.ops.aten.sum.dim_IntList(mul_615, [0, 1]); mul_615 = None + convert_element_type_2013 = torch.ops.prims.convert_element_type.default(mul_614, torch.bfloat16); mul_614 = None + add_251 = torch.ops.aten.add.Tensor(add_247, convert_element_type_2013); add_247 = convert_element_type_2013 = None + convert_element_type_default_30 = torch.ops.prims.convert_element_type.default(sum_106, torch.float32); sum_106 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_30, 'avg', 64, '0'); convert_element_type_default_30 = None + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + view_1509 = torch.ops.aten.view.default(add_251, [16384, 4096]) + permute_913 = torch.ops.aten.permute.default(view_1509, [1, 0]) + mm_471 = torch.ops.aten.mm.default(permute_913, view_499); permute_913 = view_499 = None + permute_915 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_472 = torch.ops.aten.mm.default(view_1509, permute_915); view_1509 = permute_915 = None + view_1510 = torch.ops.aten.view.default(mm_472, [2, 8192, 4096]); mm_472 = None + convert_element_type_2020 = torch.ops.prims.convert_element_type.default(mm_471, torch.float32); mm_471 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2020, 'avg', 64, '0'); convert_element_type_2020 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + view_1511 = torch.ops.aten.view.default(view_1510, [2, 8192, 32, 128]); view_1510 = None + permute_917 = torch.ops.aten.permute.default(view_1511, [0, 2, 1, 3]); view_1511 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16); primals_130 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 64, '0'); convert_element_type_463 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32); add_55 = None + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_127) + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + view_479 = torch.ops.aten.view.default(convert_element_type_465, [16384, 4096]); convert_element_type_465 = None + view_480 = torch.ops.aten.view.default(mm_98, [2, 8192, 4096]); mm_98 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 64, '0'); convert_element_type_469 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_99 = torch.ops.aten.mm.default(view_479, permute_155) + view_483 = torch.ops.aten.view.default(mm_99, [2, 8192, 1024]); mm_99 = None + view_486 = torch.ops.aten.view.default(mm_100, [2, 8192, 1024]); mm_100 = None + view_487 = torch.ops.aten.view.default(view_480, [2, 8192, -1, 128]); view_480 = None + view_488 = torch.ops.aten.view.default(view_483, [2, 8192, -1, 128]); view_483 = None + view_489 = torch.ops.aten.view.default(view_486, [2, 8192, -1, 128]); view_486 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_487, torch.float32); view_487 = None + view_490 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 32, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_490); view_490 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_488, torch.float32); view_488 = None + view_491 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 8, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_491); view_491 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_16); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_493 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 32, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_16); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_494 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 8, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_493, torch.bfloat16); view_493 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_494, torch.bfloat16); view_494 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 8, 4, 128]); unsqueeze_28 = None + clone_28 = torch.ops.aten.clone.default(expand_28, memory_format = torch.contiguous_format); expand_28 = None + view_495 = torch.ops.aten.view.default(clone_28, [2, 8192, 32, 128]); clone_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_489, 3); view_489 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 8, 4, 128]); unsqueeze_29 = None + clone_29 = torch.ops.aten.clone.default(expand_29, memory_format = torch.contiguous_format); expand_29 = None + view_496 = torch.ops.aten.view.default(clone_29, [2, 8192, 32, 128]); clone_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_495, [0, 2, 1, 3]); view_495 = None + permute_159 = torch.ops.aten.permute.default(view_496, [0, 2, 1, 3]); view_496 = None + _scaled_dot_product_cudnn_attention_backward_17 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_917, permute_157, permute_158, permute_159, getitem_126, getitem_127, getitem_132, getitem_133, None, None, None, 8192, 8192, 0.0, True); permute_917 = permute_157 = permute_158 = permute_159 = getitem_126 = getitem_127 = getitem_132 = getitem_133 = None + getitem_339 = _scaled_dot_product_cudnn_attention_backward_17[0] + getitem_340 = _scaled_dot_product_cudnn_attention_backward_17[1] + getitem_341 = _scaled_dot_product_cudnn_attention_backward_17[2]; _scaled_dot_product_cudnn_attention_backward_17 = None + permute_918 = torch.ops.aten.permute.default(getitem_341, [0, 2, 1, 3]); getitem_341 = None + permute_919 = torch.ops.aten.permute.default(getitem_340, [0, 2, 1, 3]); getitem_340 = None + permute_920 = torch.ops.aten.permute.default(getitem_339, [0, 2, 1, 3]); getitem_339 = None + view_1512 = torch.ops.aten.view.default(permute_918, [2, 8192, 8, 4, 128]); permute_918 = None + sum_107 = torch.ops.aten.sum.dim_IntList(view_1512, [3], True); view_1512 = None + squeeze_34 = torch.ops.aten.squeeze.dim(sum_107, 3); sum_107 = None + view_1513 = torch.ops.aten.view.default(permute_919, [2, 8192, 8, 4, 128]); permute_919 = None + sum_108 = torch.ops.aten.sum.dim_IntList(view_1513, [3], True); view_1513 = None + squeeze_35 = torch.ops.aten.squeeze.dim(sum_108, 3); sum_108 = None + convert_element_type_2021 = torch.ops.prims.convert_element_type.default(squeeze_35, torch.float32); squeeze_35 = None + convert_element_type_2022 = torch.ops.prims.convert_element_type.default(permute_920, torch.float32); permute_920 = None + view_1514 = torch.ops.aten.view.default(convert_element_type_2021, [2, 8192, 8, 64, 2]); convert_element_type_2021 = None + view_as_complex_98 = torch.ops.aten.view_as_complex.default(view_1514); view_1514 = None + mul_616 = torch.ops.aten.mul.Tensor(view_as_complex_98, _conj); view_as_complex_98 = None + view_1515 = torch.ops.aten.view.default(convert_element_type_2022, [2, 8192, 32, 64, 2]); convert_element_type_2022 = None + view_as_complex_99 = torch.ops.aten.view_as_complex.default(view_1515); view_1515 = None + mul_617 = torch.ops.aten.mul.Tensor(view_as_complex_99, _conj); view_as_complex_99 = None + view_as_real_98 = torch.ops.aten.view_as_real.default(mul_616); mul_616 = None + view_1516 = torch.ops.aten.view.default(view_as_real_98, [2, 8192, 8, 128]); view_as_real_98 = None + convert_element_type_2023 = torch.ops.prims.convert_element_type.default(view_1516, torch.bfloat16); view_1516 = None + view_as_real_99 = torch.ops.aten.view_as_real.default(mul_617); mul_617 = None + view_1517 = torch.ops.aten.view.default(view_as_real_99, [2, 8192, 32, 128]); view_as_real_99 = None + convert_element_type_2024 = torch.ops.prims.convert_element_type.default(view_1517, torch.bfloat16); view_1517 = None + view_1518 = torch.ops.aten.view.default(squeeze_34, [2, 8192, 1024]); squeeze_34 = None + view_1519 = torch.ops.aten.view.default(convert_element_type_2023, [2, 8192, 1024]); convert_element_type_2023 = None + view_1520 = torch.ops.aten.view.default(convert_element_type_2024, [2, 8192, 4096]); convert_element_type_2024 = None + view_1521 = torch.ops.aten.view.default(view_1518, [16384, 1024]); view_1518 = None + permute_921 = torch.ops.aten.permute.default(view_1521, [1, 0]) + mm_473 = torch.ops.aten.mm.default(permute_921, view_479); permute_921 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 64, '0'); convert_element_type_472 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + permute_923 = torch.ops.aten.permute.default(permute_156, [1, 0]); permute_156 = None + mm_474 = torch.ops.aten.mm.default(view_1521, permute_923); view_1521 = permute_923 = None + view_1522 = torch.ops.aten.view.default(mm_474, [2, 8192, 4096]); mm_474 = None + convert_element_type_2029 = torch.ops.prims.convert_element_type.default(mm_473, torch.float32); mm_473 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2029, 'avg', 64, '0'); convert_element_type_2029 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + view_1523 = torch.ops.aten.view.default(view_1519, [16384, 1024]); view_1519 = None + permute_925 = torch.ops.aten.permute.default(view_1523, [1, 0]) + mm_475 = torch.ops.aten.mm.default(permute_925, view_479); permute_925 = None + permute_927 = torch.ops.aten.permute.default(permute_155, [1, 0]); permute_155 = None + mm_476 = torch.ops.aten.mm.default(view_1523, permute_927); view_1523 = permute_927 = None + view_1524 = torch.ops.aten.view.default(mm_476, [2, 8192, 4096]); mm_476 = None + add_252 = torch.ops.aten.add.Tensor(view_1522, view_1524); view_1522 = view_1524 = None + convert_element_type_2034 = torch.ops.prims.convert_element_type.default(mm_475, torch.float32); mm_475 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2034, 'avg', 64, '0'); convert_element_type_2034 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + view_1525 = torch.ops.aten.view.default(view_1520, [16384, 4096]); view_1520 = None + permute_929 = torch.ops.aten.permute.default(view_1525, [1, 0]) + mm_477 = torch.ops.aten.mm.default(permute_929, view_479); permute_929 = view_479 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16); primals_131 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 64, '0'); convert_element_type_466 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + permute_931 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_478 = torch.ops.aten.mm.default(view_1525, permute_931); view_1525 = permute_931 = None + view_1526 = torch.ops.aten.view.default(mm_478, [2, 8192, 4096]); mm_478 = None + add_253 = torch.ops.aten.add.Tensor(add_252, view_1526); add_252 = view_1526 = None + convert_element_type_2039 = torch.ops.prims.convert_element_type.default(mm_477, torch.float32); mm_477 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2039, 'avg', 64, '0'); convert_element_type_2039 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + convert_element_type_2040 = torch.ops.prims.convert_element_type.default(add_253, torch.float32); add_253 = None + convert_element_type_2042 = torch.ops.prims.convert_element_type.default(wait_tensor_127, torch.float32); wait_tensor_127 = None + mul_618 = torch.ops.aten.mul.Tensor(convert_element_type_2040, convert_element_type_2042); convert_element_type_2042 = None + mul_620 = torch.ops.aten.mul.Tensor(mul_112, mul_618) + sum_109 = torch.ops.aten.sum.dim_IntList(mul_620, [2], True); mul_620 = None + div_36 = torch.ops.aten.div.Tensor(mul_112, 4096) + mul_621 = torch.ops.aten.mul.Tensor(div_36, sum_109); div_36 = sum_109 = None + sub_54 = torch.ops.aten.sub.Tensor(mul_618, mul_621); mul_618 = mul_621 = None + mul_622 = torch.ops.aten.mul.Tensor(sub_54, rsqrt_28); sub_54 = rsqrt_28 = None + mul_623 = torch.ops.aten.mul.Tensor(convert_element_type_2040, mul_112); convert_element_type_2040 = mul_112 = None + sum_110 = torch.ops.aten.sum.dim_IntList(mul_623, [0, 1]); mul_623 = None + convert_element_type_2043 = torch.ops.prims.convert_element_type.default(mul_622, torch.bfloat16); mul_622 = None + add_254 = torch.ops.aten.add.Tensor(add_251, convert_element_type_2043); add_251 = convert_element_type_2043 = None + convert_element_type_default_29 = torch.ops.prims.convert_element_type.default(sum_110, torch.float32); sum_110 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_29, 'avg', 64, '0'); convert_element_type_default_29 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_1527 = torch.ops.aten.view.default(add_254, [16384, 4096]) + permute_933 = torch.ops.aten.permute.default(view_1527, [1, 0]) + permute_149 = torch.ops.aten.permute.default(getitem_117, [0, 2, 1, 3]) + view_463 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 64, '0'); convert_element_type_446 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + view_465 = torch.ops.aten.view.default(view_463, [16384, 4096]); view_463 = None + mm_94 = torch.ops.aten.mm.default(view_465, permute_150) + view_466 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + add_53 = torch.ops.aten.add.Tensor(add_51, view_466); view_466 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 64, '0'); convert_element_type_449 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32); add_53 = None + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_123) + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + view_469 = torch.ops.aten.view.default(convert_element_type_451, [16384, 4096]); convert_element_type_451 = None + view_470 = torch.ops.aten.view.default(mm_95, [2, 8192, 14336]); mm_95 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_470, torch.float32); view_470 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16); primals_128 = None + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 64, '0'); convert_element_type_457 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_125, [1, 0]); wait_tensor_125 = None + mm_96 = torch.ops.aten.mm.default(view_469, permute_152) + view_473 = torch.ops.aten.view.default(mm_96, [2, 8192, 14336]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_473) + view_475 = torch.ops.aten.view.default(mul_111, [16384, 14336]); mul_111 = None + mm_479 = torch.ops.aten.mm.default(permute_933, view_475); permute_933 = view_475 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16); primals_129 = None + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 64, '0'); convert_element_type_460 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_126, [1, 0]); wait_tensor_126 = None + permute_935 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_480 = torch.ops.aten.mm.default(view_1527, permute_935); view_1527 = permute_935 = None + view_1528 = torch.ops.aten.view.default(mm_480, [2, 8192, 14336]); mm_480 = None + convert_element_type_2050 = torch.ops.prims.convert_element_type.default(mm_479, torch.float32); mm_479 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2050, 'avg', 64, '0'); convert_element_type_2050 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 = None + mul_624 = torch.ops.aten.mul.Tensor(view_1528, convert_element_type_456); convert_element_type_456 = None + mul_625 = torch.ops.aten.mul.Tensor(view_1528, view_473); view_1528 = view_473 = None + view_1529 = torch.ops.aten.view.default(mul_624, [16384, 14336]); mul_624 = None + permute_937 = torch.ops.aten.permute.default(view_1529, [1, 0]) + mm_481 = torch.ops.aten.mm.default(permute_937, view_469); permute_937 = None + permute_939 = torch.ops.aten.permute.default(permute_152, [1, 0]); permute_152 = None + mm_482 = torch.ops.aten.mm.default(view_1529, permute_939); view_1529 = permute_939 = None + view_1530 = torch.ops.aten.view.default(mm_482, [2, 8192, 4096]); mm_482 = None + convert_element_type_2055 = torch.ops.prims.convert_element_type.default(mm_481, torch.float32); mm_481 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2055, 'avg', 64, '0'); convert_element_type_2055 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + convert_element_type_2056 = torch.ops.prims.convert_element_type.default(mul_625, torch.float32); mul_625 = None + neg_18 = torch.ops.aten.neg.default(convert_element_type_455) + exp_18 = torch.ops.aten.exp.default(neg_18); neg_18 = None + add_255 = torch.ops.aten.add.Tensor(exp_18, 1); exp_18 = None + reciprocal_18 = torch.ops.aten.reciprocal.default(add_255); add_255 = None + mul_626 = torch.ops.aten.mul.Tensor(reciprocal_18, 1); reciprocal_18 = None + mul_627 = torch.ops.aten.mul.Tensor(convert_element_type_2056, mul_626); convert_element_type_2056 = None + sub_55 = torch.ops.aten.sub.Tensor(1, mul_626); mul_626 = None + mul_628 = torch.ops.aten.mul.Tensor(convert_element_type_455, sub_55); convert_element_type_455 = sub_55 = None + add_256 = torch.ops.aten.add.Tensor(mul_628, 1); mul_628 = None + mul_629 = torch.ops.aten.mul.Tensor(mul_627, add_256); mul_627 = add_256 = None + convert_element_type_2058 = torch.ops.prims.convert_element_type.default(mul_629, torch.bfloat16); mul_629 = None + view_1531 = torch.ops.aten.view.default(convert_element_type_2058, [16384, 14336]); convert_element_type_2058 = None + permute_941 = torch.ops.aten.permute.default(view_1531, [1, 0]) + mm_483 = torch.ops.aten.mm.default(permute_941, view_469); permute_941 = view_469 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 64, '0'); convert_element_type_452 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + permute_943 = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None + mm_484 = torch.ops.aten.mm.default(view_1531, permute_943); view_1531 = permute_943 = None + view_1532 = torch.ops.aten.view.default(mm_484, [2, 8192, 4096]); mm_484 = None + add_257 = torch.ops.aten.add.Tensor(view_1530, view_1532); view_1530 = view_1532 = None + convert_element_type_2063 = torch.ops.prims.convert_element_type.default(mm_483, torch.float32); mm_483 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2063, 'avg', 64, '0'); convert_element_type_2063 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + convert_element_type_2064 = torch.ops.prims.convert_element_type.default(add_257, torch.float32); add_257 = None + convert_element_type_2066 = torch.ops.prims.convert_element_type.default(wait_tensor_123, torch.float32); wait_tensor_123 = None + mul_630 = torch.ops.aten.mul.Tensor(convert_element_type_2064, convert_element_type_2066); convert_element_type_2066 = None + mul_632 = torch.ops.aten.mul.Tensor(mul_108, mul_630) + sum_111 = torch.ops.aten.sum.dim_IntList(mul_632, [2], True); mul_632 = None + div_37 = torch.ops.aten.div.Tensor(mul_108, 4096) + mul_633 = torch.ops.aten.mul.Tensor(div_37, sum_111); div_37 = sum_111 = None + sub_56 = torch.ops.aten.sub.Tensor(mul_630, mul_633); mul_630 = mul_633 = None + mul_634 = torch.ops.aten.mul.Tensor(sub_56, rsqrt_27); sub_56 = rsqrt_27 = None + mul_635 = torch.ops.aten.mul.Tensor(convert_element_type_2064, mul_108); convert_element_type_2064 = mul_108 = None + sum_112 = torch.ops.aten.sum.dim_IntList(mul_635, [0, 1]); mul_635 = None + convert_element_type_2067 = torch.ops.prims.convert_element_type.default(mul_634, torch.bfloat16); mul_634 = None + add_258 = torch.ops.aten.add.Tensor(add_254, convert_element_type_2067); add_254 = convert_element_type_2067 = None + convert_element_type_default_28 = torch.ops.prims.convert_element_type.default(sum_112, torch.float32); sum_112 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_28, 'avg', 64, '0'); convert_element_type_default_28 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + view_1533 = torch.ops.aten.view.default(add_258, [16384, 4096]) + permute_945 = torch.ops.aten.permute.default(view_1533, [1, 0]) + mm_485 = torch.ops.aten.mm.default(permute_945, view_465); permute_945 = view_465 = None + permute_947 = torch.ops.aten.permute.default(permute_150, [1, 0]); permute_150 = None + mm_486 = torch.ops.aten.mm.default(view_1533, permute_947); view_1533 = permute_947 = None + view_1534 = torch.ops.aten.view.default(mm_486, [2, 8192, 4096]); mm_486 = None + convert_element_type_2074 = torch.ops.prims.convert_element_type.default(mm_485, torch.float32); mm_485 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2074, 'avg', 64, '0'); convert_element_type_2074 = None + wait_tensor_459 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + view_1535 = torch.ops.aten.view.default(view_1534, [2, 8192, 32, 128]); view_1534 = None + permute_949 = torch.ops.aten.permute.default(view_1535, [0, 2, 1, 3]); view_1535 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 64, '0'); convert_element_type_430 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32); add_51 = None + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_118) + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + view_445 = torch.ops.aten.view.default(convert_element_type_432, [16384, 4096]); convert_element_type_432 = None + view_446 = torch.ops.aten.view.default(mm_91, [2, 8192, 4096]); mm_91 = None + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 64, '0'); convert_element_type_436 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + mm_92 = torch.ops.aten.mm.default(view_445, permute_144) + view_449 = torch.ops.aten.view.default(mm_92, [2, 8192, 1024]); mm_92 = None + view_452 = torch.ops.aten.view.default(mm_93, [2, 8192, 1024]); mm_93 = None + view_453 = torch.ops.aten.view.default(view_446, [2, 8192, -1, 128]); view_446 = None + view_454 = torch.ops.aten.view.default(view_449, [2, 8192, -1, 128]); view_449 = None + view_455 = torch.ops.aten.view.default(view_452, [2, 8192, -1, 128]); view_452 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_453, torch.float32); view_453 = None + view_456 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 32, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_456); view_456 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_454, torch.float32); view_454 = None + view_457 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 8, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_457); view_457 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_16); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_459 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 32, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_16); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_460 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 8, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_459, torch.bfloat16); view_459 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_460, torch.bfloat16); view_460 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 8, 4, 128]); unsqueeze_26 = None + clone_26 = torch.ops.aten.clone.default(expand_26, memory_format = torch.contiguous_format); expand_26 = None + view_461 = torch.ops.aten.view.default(clone_26, [2, 8192, 32, 128]); clone_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_455, 3); view_455 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 8, 4, 128]); unsqueeze_27 = None + clone_27 = torch.ops.aten.clone.default(expand_27, memory_format = torch.contiguous_format); expand_27 = None + view_462 = torch.ops.aten.view.default(clone_27, [2, 8192, 32, 128]); clone_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_461, [0, 2, 1, 3]); view_461 = None + permute_148 = torch.ops.aten.permute.default(view_462, [0, 2, 1, 3]); view_462 = None + _scaled_dot_product_cudnn_attention_backward_18 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_949, permute_146, permute_147, permute_148, getitem_117, getitem_118, getitem_123, getitem_124, None, None, None, 8192, 8192, 0.0, True); permute_949 = permute_146 = permute_147 = permute_148 = getitem_117 = getitem_118 = getitem_123 = getitem_124 = None + getitem_342 = _scaled_dot_product_cudnn_attention_backward_18[0] + getitem_343 = _scaled_dot_product_cudnn_attention_backward_18[1] + getitem_344 = _scaled_dot_product_cudnn_attention_backward_18[2]; _scaled_dot_product_cudnn_attention_backward_18 = None + permute_950 = torch.ops.aten.permute.default(getitem_344, [0, 2, 1, 3]); getitem_344 = None + permute_951 = torch.ops.aten.permute.default(getitem_343, [0, 2, 1, 3]); getitem_343 = None + permute_952 = torch.ops.aten.permute.default(getitem_342, [0, 2, 1, 3]); getitem_342 = None + view_1536 = torch.ops.aten.view.default(permute_950, [2, 8192, 8, 4, 128]); permute_950 = None + sum_113 = torch.ops.aten.sum.dim_IntList(view_1536, [3], True); view_1536 = None + squeeze_36 = torch.ops.aten.squeeze.dim(sum_113, 3); sum_113 = None + view_1537 = torch.ops.aten.view.default(permute_951, [2, 8192, 8, 4, 128]); permute_951 = None + sum_114 = torch.ops.aten.sum.dim_IntList(view_1537, [3], True); view_1537 = None + squeeze_37 = torch.ops.aten.squeeze.dim(sum_114, 3); sum_114 = None + convert_element_type_2075 = torch.ops.prims.convert_element_type.default(squeeze_37, torch.float32); squeeze_37 = None + convert_element_type_2076 = torch.ops.prims.convert_element_type.default(permute_952, torch.float32); permute_952 = None + view_1538 = torch.ops.aten.view.default(convert_element_type_2075, [2, 8192, 8, 64, 2]); convert_element_type_2075 = None + view_as_complex_100 = torch.ops.aten.view_as_complex.default(view_1538); view_1538 = None + mul_636 = torch.ops.aten.mul.Tensor(view_as_complex_100, _conj); view_as_complex_100 = None + view_1539 = torch.ops.aten.view.default(convert_element_type_2076, [2, 8192, 32, 64, 2]); convert_element_type_2076 = None + view_as_complex_101 = torch.ops.aten.view_as_complex.default(view_1539); view_1539 = None + mul_637 = torch.ops.aten.mul.Tensor(view_as_complex_101, _conj); view_as_complex_101 = None + view_as_real_100 = torch.ops.aten.view_as_real.default(mul_636); mul_636 = None + view_1540 = torch.ops.aten.view.default(view_as_real_100, [2, 8192, 8, 128]); view_as_real_100 = None + convert_element_type_2077 = torch.ops.prims.convert_element_type.default(view_1540, torch.bfloat16); view_1540 = None + view_as_real_101 = torch.ops.aten.view_as_real.default(mul_637); mul_637 = None + view_1541 = torch.ops.aten.view.default(view_as_real_101, [2, 8192, 32, 128]); view_as_real_101 = None + convert_element_type_2078 = torch.ops.prims.convert_element_type.default(view_1541, torch.bfloat16); view_1541 = None + view_1542 = torch.ops.aten.view.default(squeeze_36, [2, 8192, 1024]); squeeze_36 = None + view_1543 = torch.ops.aten.view.default(convert_element_type_2077, [2, 8192, 1024]); convert_element_type_2077 = None + view_1544 = torch.ops.aten.view.default(convert_element_type_2078, [2, 8192, 4096]); convert_element_type_2078 = None + view_1545 = torch.ops.aten.view.default(view_1542, [16384, 1024]); view_1542 = None + permute_953 = torch.ops.aten.permute.default(view_1545, [1, 0]) + mm_487 = torch.ops.aten.mm.default(permute_953, view_445); permute_953 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 64, '0'); convert_element_type_439 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_955 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_488 = torch.ops.aten.mm.default(view_1545, permute_955); view_1545 = permute_955 = None + view_1546 = torch.ops.aten.view.default(mm_488, [2, 8192, 4096]); mm_488 = None + convert_element_type_2083 = torch.ops.prims.convert_element_type.default(mm_487, torch.float32); mm_487 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2083, 'avg', 64, '0'); convert_element_type_2083 = None + wait_tensor_460 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + view_1547 = torch.ops.aten.view.default(view_1543, [16384, 1024]); view_1543 = None + permute_957 = torch.ops.aten.permute.default(view_1547, [1, 0]) + mm_489 = torch.ops.aten.mm.default(permute_957, view_445); permute_957 = None + permute_959 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_490 = torch.ops.aten.mm.default(view_1547, permute_959); view_1547 = permute_959 = None + view_1548 = torch.ops.aten.view.default(mm_490, [2, 8192, 4096]); mm_490 = None + add_259 = torch.ops.aten.add.Tensor(view_1546, view_1548); view_1546 = view_1548 = None + convert_element_type_2088 = torch.ops.prims.convert_element_type.default(mm_489, torch.float32); mm_489 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2088, 'avg', 64, '0'); convert_element_type_2088 = None + wait_tensor_461 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + view_1549 = torch.ops.aten.view.default(view_1544, [16384, 4096]); view_1544 = None + permute_961 = torch.ops.aten.permute.default(view_1549, [1, 0]) + mm_491 = torch.ops.aten.mm.default(permute_961, view_445); permute_961 = view_445 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 64, '0'); convert_element_type_433 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_119, [1, 0]); wait_tensor_119 = None + permute_963 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_492 = torch.ops.aten.mm.default(view_1549, permute_963); view_1549 = permute_963 = None + view_1550 = torch.ops.aten.view.default(mm_492, [2, 8192, 4096]); mm_492 = None + add_260 = torch.ops.aten.add.Tensor(add_259, view_1550); add_259 = view_1550 = None + convert_element_type_2093 = torch.ops.prims.convert_element_type.default(mm_491, torch.float32); mm_491 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2093, 'avg', 64, '0'); convert_element_type_2093 = None + wait_tensor_462 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + convert_element_type_2094 = torch.ops.prims.convert_element_type.default(add_260, torch.float32); add_260 = None + convert_element_type_2096 = torch.ops.prims.convert_element_type.default(wait_tensor_118, torch.float32); wait_tensor_118 = None + mul_638 = torch.ops.aten.mul.Tensor(convert_element_type_2094, convert_element_type_2096); convert_element_type_2096 = None + mul_640 = torch.ops.aten.mul.Tensor(mul_104, mul_638) + sum_115 = torch.ops.aten.sum.dim_IntList(mul_640, [2], True); mul_640 = None + div_38 = torch.ops.aten.div.Tensor(mul_104, 4096) + mul_641 = torch.ops.aten.mul.Tensor(div_38, sum_115); div_38 = sum_115 = None + sub_57 = torch.ops.aten.sub.Tensor(mul_638, mul_641); mul_638 = mul_641 = None + mul_642 = torch.ops.aten.mul.Tensor(sub_57, rsqrt_26); sub_57 = rsqrt_26 = None + mul_643 = torch.ops.aten.mul.Tensor(convert_element_type_2094, mul_104); convert_element_type_2094 = mul_104 = None + sum_116 = torch.ops.aten.sum.dim_IntList(mul_643, [0, 1]); mul_643 = None + convert_element_type_2097 = torch.ops.prims.convert_element_type.default(mul_642, torch.bfloat16); mul_642 = None + add_261 = torch.ops.aten.add.Tensor(add_258, convert_element_type_2097); add_258 = convert_element_type_2097 = None + convert_element_type_default_27 = torch.ops.prims.convert_element_type.default(sum_116, torch.float32); sum_116 = None + reduce_scatter_tensor_172 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_27, 'avg', 64, '0'); convert_element_type_default_27 = None + wait_tensor_463 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + view_1551 = torch.ops.aten.view.default(add_261, [16384, 4096]) + permute_965 = torch.ops.aten.permute.default(view_1551, [1, 0]) + permute_138 = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]) + view_429 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 64, '0'); convert_element_type_413 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + view_431 = torch.ops.aten.view.default(view_429, [16384, 4096]); view_429 = None + mm_87 = torch.ops.aten.mm.default(view_431, permute_139) + view_432 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + add_49 = torch.ops.aten.add.Tensor(add_47, view_432); view_432 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 64, '0'); convert_element_type_416 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32); add_49 = None + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_114) + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + view_435 = torch.ops.aten.view.default(convert_element_type_418, [16384, 4096]); convert_element_type_418 = None + view_436 = torch.ops.aten.view.default(mm_88, [2, 8192, 14336]); mm_88 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_436, torch.float32); view_436 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 64, '0'); convert_element_type_424 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_89 = torch.ops.aten.mm.default(view_435, permute_141) + view_439 = torch.ops.aten.view.default(mm_89, [2, 8192, 14336]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_439) + view_441 = torch.ops.aten.view.default(mul_103, [16384, 14336]); mul_103 = None + mm_493 = torch.ops.aten.mm.default(permute_965, view_441); permute_965 = view_441 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 64, '0'); convert_element_type_427 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_967 = torch.ops.aten.permute.default(permute_142, [1, 0]); permute_142 = None + mm_494 = torch.ops.aten.mm.default(view_1551, permute_967); view_1551 = permute_967 = None + view_1552 = torch.ops.aten.view.default(mm_494, [2, 8192, 14336]); mm_494 = None + convert_element_type_2104 = torch.ops.prims.convert_element_type.default(mm_493, torch.float32); mm_493 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2104, 'avg', 64, '0'); convert_element_type_2104 = None + wait_tensor_464 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + mul_644 = torch.ops.aten.mul.Tensor(view_1552, convert_element_type_423); convert_element_type_423 = None + mul_645 = torch.ops.aten.mul.Tensor(view_1552, view_439); view_1552 = view_439 = None + view_1553 = torch.ops.aten.view.default(mul_644, [16384, 14336]); mul_644 = None + permute_969 = torch.ops.aten.permute.default(view_1553, [1, 0]) + mm_495 = torch.ops.aten.mm.default(permute_969, view_435); permute_969 = None + permute_971 = torch.ops.aten.permute.default(permute_141, [1, 0]); permute_141 = None + mm_496 = torch.ops.aten.mm.default(view_1553, permute_971); view_1553 = permute_971 = None + view_1554 = torch.ops.aten.view.default(mm_496, [2, 8192, 4096]); mm_496 = None + convert_element_type_2109 = torch.ops.prims.convert_element_type.default(mm_495, torch.float32); mm_495 = None + reduce_scatter_tensor_174 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2109, 'avg', 64, '0'); convert_element_type_2109 = None + wait_tensor_465 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + convert_element_type_2110 = torch.ops.prims.convert_element_type.default(mul_645, torch.float32); mul_645 = None + neg_19 = torch.ops.aten.neg.default(convert_element_type_422) + exp_19 = torch.ops.aten.exp.default(neg_19); neg_19 = None + add_262 = torch.ops.aten.add.Tensor(exp_19, 1); exp_19 = None + reciprocal_19 = torch.ops.aten.reciprocal.default(add_262); add_262 = None + mul_646 = torch.ops.aten.mul.Tensor(reciprocal_19, 1); reciprocal_19 = None + mul_647 = torch.ops.aten.mul.Tensor(convert_element_type_2110, mul_646); convert_element_type_2110 = None + sub_58 = torch.ops.aten.sub.Tensor(1, mul_646); mul_646 = None + mul_648 = torch.ops.aten.mul.Tensor(convert_element_type_422, sub_58); convert_element_type_422 = sub_58 = None + add_263 = torch.ops.aten.add.Tensor(mul_648, 1); mul_648 = None + mul_649 = torch.ops.aten.mul.Tensor(mul_647, add_263); mul_647 = add_263 = None + convert_element_type_2112 = torch.ops.prims.convert_element_type.default(mul_649, torch.bfloat16); mul_649 = None + view_1555 = torch.ops.aten.view.default(convert_element_type_2112, [16384, 14336]); convert_element_type_2112 = None + permute_973 = torch.ops.aten.permute.default(view_1555, [1, 0]) + mm_497 = torch.ops.aten.mm.default(permute_973, view_435); permute_973 = view_435 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 64, '0'); convert_element_type_419 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_975 = torch.ops.aten.permute.default(permute_140, [1, 0]); permute_140 = None + mm_498 = torch.ops.aten.mm.default(view_1555, permute_975); view_1555 = permute_975 = None + view_1556 = torch.ops.aten.view.default(mm_498, [2, 8192, 4096]); mm_498 = None + add_264 = torch.ops.aten.add.Tensor(view_1554, view_1556); view_1554 = view_1556 = None + convert_element_type_2117 = torch.ops.prims.convert_element_type.default(mm_497, torch.float32); mm_497 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2117, 'avg', 64, '0'); convert_element_type_2117 = None + wait_tensor_466 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + convert_element_type_2118 = torch.ops.prims.convert_element_type.default(add_264, torch.float32); add_264 = None + convert_element_type_2120 = torch.ops.prims.convert_element_type.default(wait_tensor_114, torch.float32); wait_tensor_114 = None + mul_650 = torch.ops.aten.mul.Tensor(convert_element_type_2118, convert_element_type_2120); convert_element_type_2120 = None + mul_652 = torch.ops.aten.mul.Tensor(mul_100, mul_650) + sum_117 = torch.ops.aten.sum.dim_IntList(mul_652, [2], True); mul_652 = None + div_39 = torch.ops.aten.div.Tensor(mul_100, 4096) + mul_653 = torch.ops.aten.mul.Tensor(div_39, sum_117); div_39 = sum_117 = None + sub_59 = torch.ops.aten.sub.Tensor(mul_650, mul_653); mul_650 = mul_653 = None + mul_654 = torch.ops.aten.mul.Tensor(sub_59, rsqrt_25); sub_59 = rsqrt_25 = None + mul_655 = torch.ops.aten.mul.Tensor(convert_element_type_2118, mul_100); convert_element_type_2118 = mul_100 = None + sum_118 = torch.ops.aten.sum.dim_IntList(mul_655, [0, 1]); mul_655 = None + convert_element_type_2121 = torch.ops.prims.convert_element_type.default(mul_654, torch.bfloat16); mul_654 = None + add_265 = torch.ops.aten.add.Tensor(add_261, convert_element_type_2121); add_261 = convert_element_type_2121 = None + convert_element_type_default_26 = torch.ops.prims.convert_element_type.default(sum_118, torch.float32); sum_118 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_26, 'avg', 64, '0'); convert_element_type_default_26 = None + wait_tensor_467 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + view_1557 = torch.ops.aten.view.default(add_265, [16384, 4096]) + permute_977 = torch.ops.aten.permute.default(view_1557, [1, 0]) + mm_499 = torch.ops.aten.mm.default(permute_977, view_431); permute_977 = view_431 = None + permute_979 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_500 = torch.ops.aten.mm.default(view_1557, permute_979); view_1557 = permute_979 = None + view_1558 = torch.ops.aten.view.default(mm_500, [2, 8192, 4096]); mm_500 = None + convert_element_type_2128 = torch.ops.prims.convert_element_type.default(mm_499, torch.float32); mm_499 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2128, 'avg', 64, '0'); convert_element_type_2128 = None + wait_tensor_468 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + view_1559 = torch.ops.aten.view.default(view_1558, [2, 8192, 32, 128]); view_1558 = None + permute_981 = torch.ops.aten.permute.default(view_1559, [0, 2, 1, 3]); view_1559 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16); primals_112 = None + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 64, '0'); convert_element_type_397 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32); add_47 = None + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_109) + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + view_411 = torch.ops.aten.view.default(convert_element_type_399, [16384, 4096]); convert_element_type_399 = None + view_412 = torch.ops.aten.view.default(mm_84, [2, 8192, 4096]); mm_84 = None + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16); primals_114 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 64, '0'); convert_element_type_403 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + mm_85 = torch.ops.aten.mm.default(view_411, permute_133) + view_415 = torch.ops.aten.view.default(mm_85, [2, 8192, 1024]); mm_85 = None + view_418 = torch.ops.aten.view.default(mm_86, [2, 8192, 1024]); mm_86 = None + view_419 = torch.ops.aten.view.default(view_412, [2, 8192, -1, 128]); view_412 = None + view_420 = torch.ops.aten.view.default(view_415, [2, 8192, -1, 128]); view_415 = None + view_421 = torch.ops.aten.view.default(view_418, [2, 8192, -1, 128]); view_418 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_419, torch.float32); view_419 = None + view_422 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 32, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_422); view_422 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_420, torch.float32); view_420 = None + view_423 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 8, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_423); view_423 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_16); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_425 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 32, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_16); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_426 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 8, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_425, torch.bfloat16); view_425 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_426, torch.bfloat16); view_426 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 8, 4, 128]); unsqueeze_24 = None + clone_24 = torch.ops.aten.clone.default(expand_24, memory_format = torch.contiguous_format); expand_24 = None + view_427 = torch.ops.aten.view.default(clone_24, [2, 8192, 32, 128]); clone_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_421, 3); view_421 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 8, 4, 128]); unsqueeze_25 = None + clone_25 = torch.ops.aten.clone.default(expand_25, memory_format = torch.contiguous_format); expand_25 = None + view_428 = torch.ops.aten.view.default(clone_25, [2, 8192, 32, 128]); clone_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_427, [0, 2, 1, 3]); view_427 = None + permute_137 = torch.ops.aten.permute.default(view_428, [0, 2, 1, 3]); view_428 = None + _scaled_dot_product_cudnn_attention_backward_19 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_981, permute_135, permute_136, permute_137, getitem_108, getitem_109, getitem_114, getitem_115, None, None, None, 8192, 8192, 0.0, True); permute_981 = permute_135 = permute_136 = permute_137 = getitem_108 = getitem_109 = getitem_114 = getitem_115 = None + getitem_345 = _scaled_dot_product_cudnn_attention_backward_19[0] + getitem_346 = _scaled_dot_product_cudnn_attention_backward_19[1] + getitem_347 = _scaled_dot_product_cudnn_attention_backward_19[2]; _scaled_dot_product_cudnn_attention_backward_19 = None + permute_982 = torch.ops.aten.permute.default(getitem_347, [0, 2, 1, 3]); getitem_347 = None + permute_983 = torch.ops.aten.permute.default(getitem_346, [0, 2, 1, 3]); getitem_346 = None + permute_984 = torch.ops.aten.permute.default(getitem_345, [0, 2, 1, 3]); getitem_345 = None + view_1560 = torch.ops.aten.view.default(permute_982, [2, 8192, 8, 4, 128]); permute_982 = None + sum_119 = torch.ops.aten.sum.dim_IntList(view_1560, [3], True); view_1560 = None + squeeze_38 = torch.ops.aten.squeeze.dim(sum_119, 3); sum_119 = None + view_1561 = torch.ops.aten.view.default(permute_983, [2, 8192, 8, 4, 128]); permute_983 = None + sum_120 = torch.ops.aten.sum.dim_IntList(view_1561, [3], True); view_1561 = None + squeeze_39 = torch.ops.aten.squeeze.dim(sum_120, 3); sum_120 = None + convert_element_type_2129 = torch.ops.prims.convert_element_type.default(squeeze_39, torch.float32); squeeze_39 = None + convert_element_type_2130 = torch.ops.prims.convert_element_type.default(permute_984, torch.float32); permute_984 = None + view_1562 = torch.ops.aten.view.default(convert_element_type_2129, [2, 8192, 8, 64, 2]); convert_element_type_2129 = None + view_as_complex_102 = torch.ops.aten.view_as_complex.default(view_1562); view_1562 = None + mul_656 = torch.ops.aten.mul.Tensor(view_as_complex_102, _conj); view_as_complex_102 = None + view_1563 = torch.ops.aten.view.default(convert_element_type_2130, [2, 8192, 32, 64, 2]); convert_element_type_2130 = None + view_as_complex_103 = torch.ops.aten.view_as_complex.default(view_1563); view_1563 = None + mul_657 = torch.ops.aten.mul.Tensor(view_as_complex_103, _conj); view_as_complex_103 = None + view_as_real_102 = torch.ops.aten.view_as_real.default(mul_656); mul_656 = None + view_1564 = torch.ops.aten.view.default(view_as_real_102, [2, 8192, 8, 128]); view_as_real_102 = None + convert_element_type_2131 = torch.ops.prims.convert_element_type.default(view_1564, torch.bfloat16); view_1564 = None + view_as_real_103 = torch.ops.aten.view_as_real.default(mul_657); mul_657 = None + view_1565 = torch.ops.aten.view.default(view_as_real_103, [2, 8192, 32, 128]); view_as_real_103 = None + convert_element_type_2132 = torch.ops.prims.convert_element_type.default(view_1565, torch.bfloat16); view_1565 = None + view_1566 = torch.ops.aten.view.default(squeeze_38, [2, 8192, 1024]); squeeze_38 = None + view_1567 = torch.ops.aten.view.default(convert_element_type_2131, [2, 8192, 1024]); convert_element_type_2131 = None + view_1568 = torch.ops.aten.view.default(convert_element_type_2132, [2, 8192, 4096]); convert_element_type_2132 = None + view_1569 = torch.ops.aten.view.default(view_1566, [16384, 1024]); view_1566 = None + permute_985 = torch.ops.aten.permute.default(view_1569, [1, 0]) + mm_501 = torch.ops.aten.mm.default(permute_985, view_411); permute_985 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16); primals_115 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 64, '0'); convert_element_type_406 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_112, [1, 0]); wait_tensor_112 = None + permute_987 = torch.ops.aten.permute.default(permute_134, [1, 0]); permute_134 = None + mm_502 = torch.ops.aten.mm.default(view_1569, permute_987); view_1569 = permute_987 = None + view_1570 = torch.ops.aten.view.default(mm_502, [2, 8192, 4096]); mm_502 = None + convert_element_type_2137 = torch.ops.prims.convert_element_type.default(mm_501, torch.float32); mm_501 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2137, 'avg', 64, '0'); convert_element_type_2137 = None + wait_tensor_469 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + view_1571 = torch.ops.aten.view.default(view_1567, [16384, 1024]); view_1567 = None + permute_989 = torch.ops.aten.permute.default(view_1571, [1, 0]) + mm_503 = torch.ops.aten.mm.default(permute_989, view_411); permute_989 = None + permute_991 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_504 = torch.ops.aten.mm.default(view_1571, permute_991); view_1571 = permute_991 = None + view_1572 = torch.ops.aten.view.default(mm_504, [2, 8192, 4096]); mm_504 = None + add_266 = torch.ops.aten.add.Tensor(view_1570, view_1572); view_1570 = view_1572 = None + convert_element_type_2142 = torch.ops.prims.convert_element_type.default(mm_503, torch.float32); mm_503 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2142, 'avg', 64, '0'); convert_element_type_2142 = None + wait_tensor_470 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + view_1573 = torch.ops.aten.view.default(view_1568, [16384, 4096]); view_1568 = None + permute_993 = torch.ops.aten.permute.default(view_1573, [1, 0]) + mm_505 = torch.ops.aten.mm.default(permute_993, view_411); permute_993 = view_411 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16); primals_113 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 64, '0'); convert_element_type_400 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + permute_995 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_506 = torch.ops.aten.mm.default(view_1573, permute_995); view_1573 = permute_995 = None + view_1574 = torch.ops.aten.view.default(mm_506, [2, 8192, 4096]); mm_506 = None + add_267 = torch.ops.aten.add.Tensor(add_266, view_1574); add_266 = view_1574 = None + convert_element_type_2147 = torch.ops.prims.convert_element_type.default(mm_505, torch.float32); mm_505 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2147, 'avg', 64, '0'); convert_element_type_2147 = None + wait_tensor_471 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + convert_element_type_2148 = torch.ops.prims.convert_element_type.default(add_267, torch.float32); add_267 = None + convert_element_type_2150 = torch.ops.prims.convert_element_type.default(wait_tensor_109, torch.float32); wait_tensor_109 = None + mul_658 = torch.ops.aten.mul.Tensor(convert_element_type_2148, convert_element_type_2150); convert_element_type_2150 = None + mul_660 = torch.ops.aten.mul.Tensor(mul_96, mul_658) + sum_121 = torch.ops.aten.sum.dim_IntList(mul_660, [2], True); mul_660 = None + div_40 = torch.ops.aten.div.Tensor(mul_96, 4096) + mul_661 = torch.ops.aten.mul.Tensor(div_40, sum_121); div_40 = sum_121 = None + sub_60 = torch.ops.aten.sub.Tensor(mul_658, mul_661); mul_658 = mul_661 = None + mul_662 = torch.ops.aten.mul.Tensor(sub_60, rsqrt_24); sub_60 = rsqrt_24 = None + mul_663 = torch.ops.aten.mul.Tensor(convert_element_type_2148, mul_96); convert_element_type_2148 = mul_96 = None + sum_122 = torch.ops.aten.sum.dim_IntList(mul_663, [0, 1]); mul_663 = None + convert_element_type_2151 = torch.ops.prims.convert_element_type.default(mul_662, torch.bfloat16); mul_662 = None + add_268 = torch.ops.aten.add.Tensor(add_265, convert_element_type_2151); add_265 = convert_element_type_2151 = None + convert_element_type_default_25 = torch.ops.prims.convert_element_type.default(sum_122, torch.float32); sum_122 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_25, 'avg', 64, '0'); convert_element_type_default_25 = None + wait_tensor_472 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + view_1575 = torch.ops.aten.view.default(add_268, [16384, 4096]) + permute_997 = torch.ops.aten.permute.default(view_1575, [1, 0]) + permute_127 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_395 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 64, '0'); convert_element_type_380 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_397 = torch.ops.aten.view.default(view_395, [16384, 4096]); view_395 = None + mm_80 = torch.ops.aten.mm.default(view_397, permute_128) + view_398 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + add_45 = torch.ops.aten.add.Tensor(add_43, view_398); view_398 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 64, '0'); convert_element_type_383 = None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32); add_45 = None + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_105) + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + view_401 = torch.ops.aten.view.default(convert_element_type_385, [16384, 4096]); convert_element_type_385 = None + view_402 = torch.ops.aten.view.default(mm_81, [2, 8192, 14336]); mm_81 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_402, torch.float32); view_402 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16); primals_110 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 64, '0'); convert_element_type_391 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_107, [1, 0]); wait_tensor_107 = None + mm_82 = torch.ops.aten.mm.default(view_401, permute_130) + view_405 = torch.ops.aten.view.default(mm_82, [2, 8192, 14336]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_405) + view_407 = torch.ops.aten.view.default(mul_95, [16384, 14336]); mul_95 = None + mm_507 = torch.ops.aten.mm.default(permute_997, view_407); permute_997 = view_407 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 64, '0'); convert_element_type_394 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + permute_999 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_508 = torch.ops.aten.mm.default(view_1575, permute_999); view_1575 = permute_999 = None + view_1576 = torch.ops.aten.view.default(mm_508, [2, 8192, 14336]); mm_508 = None + convert_element_type_2158 = torch.ops.prims.convert_element_type.default(mm_507, torch.float32); mm_507 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2158, 'avg', 64, '0'); convert_element_type_2158 = None + wait_tensor_473 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + mul_664 = torch.ops.aten.mul.Tensor(view_1576, convert_element_type_390); convert_element_type_390 = None + mul_665 = torch.ops.aten.mul.Tensor(view_1576, view_405); view_1576 = view_405 = None + view_1577 = torch.ops.aten.view.default(mul_664, [16384, 14336]); mul_664 = None + permute_1001 = torch.ops.aten.permute.default(view_1577, [1, 0]) + mm_509 = torch.ops.aten.mm.default(permute_1001, view_401); permute_1001 = None + permute_1003 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_510 = torch.ops.aten.mm.default(view_1577, permute_1003); view_1577 = permute_1003 = None + view_1578 = torch.ops.aten.view.default(mm_510, [2, 8192, 4096]); mm_510 = None + convert_element_type_2163 = torch.ops.prims.convert_element_type.default(mm_509, torch.float32); mm_509 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2163, 'avg', 64, '0'); convert_element_type_2163 = None + wait_tensor_474 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + convert_element_type_2164 = torch.ops.prims.convert_element_type.default(mul_665, torch.float32); mul_665 = None + neg_20 = torch.ops.aten.neg.default(convert_element_type_389) + exp_20 = torch.ops.aten.exp.default(neg_20); neg_20 = None + add_269 = torch.ops.aten.add.Tensor(exp_20, 1); exp_20 = None + reciprocal_20 = torch.ops.aten.reciprocal.default(add_269); add_269 = None + mul_666 = torch.ops.aten.mul.Tensor(reciprocal_20, 1); reciprocal_20 = None + mul_667 = torch.ops.aten.mul.Tensor(convert_element_type_2164, mul_666); convert_element_type_2164 = None + sub_61 = torch.ops.aten.sub.Tensor(1, mul_666); mul_666 = None + mul_668 = torch.ops.aten.mul.Tensor(convert_element_type_389, sub_61); convert_element_type_389 = sub_61 = None + add_270 = torch.ops.aten.add.Tensor(mul_668, 1); mul_668 = None + mul_669 = torch.ops.aten.mul.Tensor(mul_667, add_270); mul_667 = add_270 = None + convert_element_type_2166 = torch.ops.prims.convert_element_type.default(mul_669, torch.bfloat16); mul_669 = None + view_1579 = torch.ops.aten.view.default(convert_element_type_2166, [16384, 14336]); convert_element_type_2166 = None + permute_1005 = torch.ops.aten.permute.default(view_1579, [1, 0]) + mm_511 = torch.ops.aten.mm.default(permute_1005, view_401); permute_1005 = view_401 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 64, '0'); convert_element_type_386 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_106, [1, 0]); wait_tensor_106 = None + permute_1007 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_512 = torch.ops.aten.mm.default(view_1579, permute_1007); view_1579 = permute_1007 = None + view_1580 = torch.ops.aten.view.default(mm_512, [2, 8192, 4096]); mm_512 = None + add_271 = torch.ops.aten.add.Tensor(view_1578, view_1580); view_1578 = view_1580 = None + convert_element_type_2171 = torch.ops.prims.convert_element_type.default(mm_511, torch.float32); mm_511 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2171, 'avg', 64, '0'); convert_element_type_2171 = None + wait_tensor_475 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + convert_element_type_2172 = torch.ops.prims.convert_element_type.default(add_271, torch.float32); add_271 = None + convert_element_type_2174 = torch.ops.prims.convert_element_type.default(wait_tensor_105, torch.float32); wait_tensor_105 = None + mul_670 = torch.ops.aten.mul.Tensor(convert_element_type_2172, convert_element_type_2174); convert_element_type_2174 = None + mul_672 = torch.ops.aten.mul.Tensor(mul_92, mul_670) + sum_123 = torch.ops.aten.sum.dim_IntList(mul_672, [2], True); mul_672 = None + div_41 = torch.ops.aten.div.Tensor(mul_92, 4096) + mul_673 = torch.ops.aten.mul.Tensor(div_41, sum_123); div_41 = sum_123 = None + sub_62 = torch.ops.aten.sub.Tensor(mul_670, mul_673); mul_670 = mul_673 = None + mul_674 = torch.ops.aten.mul.Tensor(sub_62, rsqrt_23); sub_62 = rsqrt_23 = None + mul_675 = torch.ops.aten.mul.Tensor(convert_element_type_2172, mul_92); convert_element_type_2172 = mul_92 = None + sum_124 = torch.ops.aten.sum.dim_IntList(mul_675, [0, 1]); mul_675 = None + convert_element_type_2175 = torch.ops.prims.convert_element_type.default(mul_674, torch.bfloat16); mul_674 = None + add_272 = torch.ops.aten.add.Tensor(add_268, convert_element_type_2175); add_268 = convert_element_type_2175 = None + convert_element_type_default_24 = torch.ops.prims.convert_element_type.default(sum_124, torch.float32); sum_124 = None + reduce_scatter_tensor_185 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_24, 'avg', 64, '0'); convert_element_type_default_24 = None + wait_tensor_476 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + view_1581 = torch.ops.aten.view.default(add_272, [16384, 4096]) + permute_1009 = torch.ops.aten.permute.default(view_1581, [1, 0]) + mm_513 = torch.ops.aten.mm.default(permute_1009, view_397); permute_1009 = view_397 = None + permute_1011 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_514 = torch.ops.aten.mm.default(view_1581, permute_1011); view_1581 = permute_1011 = None + view_1582 = torch.ops.aten.view.default(mm_514, [2, 8192, 4096]); mm_514 = None + convert_element_type_2182 = torch.ops.prims.convert_element_type.default(mm_513, torch.float32); mm_513 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2182, 'avg', 64, '0'); convert_element_type_2182 = None + wait_tensor_477 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + view_1583 = torch.ops.aten.view.default(view_1582, [2, 8192, 32, 128]); view_1582 = None + permute_1013 = torch.ops.aten.permute.default(view_1583, [0, 2, 1, 3]); view_1583 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 64, '0'); convert_element_type_364 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32); add_43 = None + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_100) + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + view_377 = torch.ops.aten.view.default(convert_element_type_366, [16384, 4096]); convert_element_type_366 = None + view_378 = torch.ops.aten.view.default(mm_77, [2, 8192, 4096]); mm_77 = None + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 64, '0'); convert_element_type_370 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + mm_78 = torch.ops.aten.mm.default(view_377, permute_122) + view_381 = torch.ops.aten.view.default(mm_78, [2, 8192, 1024]); mm_78 = None + view_384 = torch.ops.aten.view.default(mm_79, [2, 8192, 1024]); mm_79 = None + view_385 = torch.ops.aten.view.default(view_378, [2, 8192, -1, 128]); view_378 = None + view_386 = torch.ops.aten.view.default(view_381, [2, 8192, -1, 128]); view_381 = None + view_387 = torch.ops.aten.view.default(view_384, [2, 8192, -1, 128]); view_384 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_385, torch.float32); view_385 = None + view_388 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 32, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_388); view_388 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_386, torch.float32); view_386 = None + view_389 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 8, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_389); view_389 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_16); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_391 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 32, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_16); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_392 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 8, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_391, torch.bfloat16); view_391 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_392, torch.bfloat16); view_392 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 8, 4, 128]); unsqueeze_22 = None + clone_22 = torch.ops.aten.clone.default(expand_22, memory_format = torch.contiguous_format); expand_22 = None + view_393 = torch.ops.aten.view.default(clone_22, [2, 8192, 32, 128]); clone_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_387, 3); view_387 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 8, 4, 128]); unsqueeze_23 = None + clone_23 = torch.ops.aten.clone.default(expand_23, memory_format = torch.contiguous_format); expand_23 = None + view_394 = torch.ops.aten.view.default(clone_23, [2, 8192, 32, 128]); clone_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_393, [0, 2, 1, 3]); view_393 = None + permute_126 = torch.ops.aten.permute.default(view_394, [0, 2, 1, 3]); view_394 = None + _scaled_dot_product_cudnn_attention_backward_20 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1013, permute_124, permute_125, permute_126, getitem_99, getitem_100, getitem_105, getitem_106, None, None, None, 8192, 8192, 0.0, True); permute_1013 = permute_124 = permute_125 = permute_126 = getitem_99 = getitem_100 = getitem_105 = getitem_106 = None + getitem_348 = _scaled_dot_product_cudnn_attention_backward_20[0] + getitem_349 = _scaled_dot_product_cudnn_attention_backward_20[1] + getitem_350 = _scaled_dot_product_cudnn_attention_backward_20[2]; _scaled_dot_product_cudnn_attention_backward_20 = None + permute_1014 = torch.ops.aten.permute.default(getitem_350, [0, 2, 1, 3]); getitem_350 = None + permute_1015 = torch.ops.aten.permute.default(getitem_349, [0, 2, 1, 3]); getitem_349 = None + permute_1016 = torch.ops.aten.permute.default(getitem_348, [0, 2, 1, 3]); getitem_348 = None + view_1584 = torch.ops.aten.view.default(permute_1014, [2, 8192, 8, 4, 128]); permute_1014 = None + sum_125 = torch.ops.aten.sum.dim_IntList(view_1584, [3], True); view_1584 = None + squeeze_40 = torch.ops.aten.squeeze.dim(sum_125, 3); sum_125 = None + view_1585 = torch.ops.aten.view.default(permute_1015, [2, 8192, 8, 4, 128]); permute_1015 = None + sum_126 = torch.ops.aten.sum.dim_IntList(view_1585, [3], True); view_1585 = None + squeeze_41 = torch.ops.aten.squeeze.dim(sum_126, 3); sum_126 = None + convert_element_type_2183 = torch.ops.prims.convert_element_type.default(squeeze_41, torch.float32); squeeze_41 = None + convert_element_type_2184 = torch.ops.prims.convert_element_type.default(permute_1016, torch.float32); permute_1016 = None + view_1586 = torch.ops.aten.view.default(convert_element_type_2183, [2, 8192, 8, 64, 2]); convert_element_type_2183 = None + view_as_complex_104 = torch.ops.aten.view_as_complex.default(view_1586); view_1586 = None + mul_676 = torch.ops.aten.mul.Tensor(view_as_complex_104, _conj); view_as_complex_104 = None + view_1587 = torch.ops.aten.view.default(convert_element_type_2184, [2, 8192, 32, 64, 2]); convert_element_type_2184 = None + view_as_complex_105 = torch.ops.aten.view_as_complex.default(view_1587); view_1587 = None + mul_677 = torch.ops.aten.mul.Tensor(view_as_complex_105, _conj); view_as_complex_105 = None + view_as_real_104 = torch.ops.aten.view_as_real.default(mul_676); mul_676 = None + view_1588 = torch.ops.aten.view.default(view_as_real_104, [2, 8192, 8, 128]); view_as_real_104 = None + convert_element_type_2185 = torch.ops.prims.convert_element_type.default(view_1588, torch.bfloat16); view_1588 = None + view_as_real_105 = torch.ops.aten.view_as_real.default(mul_677); mul_677 = None + view_1589 = torch.ops.aten.view.default(view_as_real_105, [2, 8192, 32, 128]); view_as_real_105 = None + convert_element_type_2186 = torch.ops.prims.convert_element_type.default(view_1589, torch.bfloat16); view_1589 = None + view_1590 = torch.ops.aten.view.default(squeeze_40, [2, 8192, 1024]); squeeze_40 = None + view_1591 = torch.ops.aten.view.default(convert_element_type_2185, [2, 8192, 1024]); convert_element_type_2185 = None + view_1592 = torch.ops.aten.view.default(convert_element_type_2186, [2, 8192, 4096]); convert_element_type_2186 = None + view_1593 = torch.ops.aten.view.default(view_1590, [16384, 1024]); view_1590 = None + permute_1017 = torch.ops.aten.permute.default(view_1593, [1, 0]) + mm_515 = torch.ops.aten.mm.default(permute_1017, view_377); permute_1017 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 64, '0'); convert_element_type_373 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + permute_1019 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_516 = torch.ops.aten.mm.default(view_1593, permute_1019); view_1593 = permute_1019 = None + view_1594 = torch.ops.aten.view.default(mm_516, [2, 8192, 4096]); mm_516 = None + convert_element_type_2191 = torch.ops.prims.convert_element_type.default(mm_515, torch.float32); mm_515 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2191, 'avg', 64, '0'); convert_element_type_2191 = None + wait_tensor_478 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + view_1595 = torch.ops.aten.view.default(view_1591, [16384, 1024]); view_1591 = None + permute_1021 = torch.ops.aten.permute.default(view_1595, [1, 0]) + mm_517 = torch.ops.aten.mm.default(permute_1021, view_377); permute_1021 = None + permute_1023 = torch.ops.aten.permute.default(permute_122, [1, 0]); permute_122 = None + mm_518 = torch.ops.aten.mm.default(view_1595, permute_1023); view_1595 = permute_1023 = None + view_1596 = torch.ops.aten.view.default(mm_518, [2, 8192, 4096]); mm_518 = None + add_273 = torch.ops.aten.add.Tensor(view_1594, view_1596); view_1594 = view_1596 = None + convert_element_type_2196 = torch.ops.prims.convert_element_type.default(mm_517, torch.float32); mm_517 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2196, 'avg', 64, '0'); convert_element_type_2196 = None + wait_tensor_479 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + view_1597 = torch.ops.aten.view.default(view_1592, [16384, 4096]); view_1592 = None + permute_1025 = torch.ops.aten.permute.default(view_1597, [1, 0]) + mm_519 = torch.ops.aten.mm.default(permute_1025, view_377); permute_1025 = view_377 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 64, '0'); convert_element_type_367 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_101, [1, 0]); wait_tensor_101 = None + permute_1027 = torch.ops.aten.permute.default(permute_121, [1, 0]); permute_121 = None + mm_520 = torch.ops.aten.mm.default(view_1597, permute_1027); view_1597 = permute_1027 = None + view_1598 = torch.ops.aten.view.default(mm_520, [2, 8192, 4096]); mm_520 = None + add_274 = torch.ops.aten.add.Tensor(add_273, view_1598); add_273 = view_1598 = None + convert_element_type_2201 = torch.ops.prims.convert_element_type.default(mm_519, torch.float32); mm_519 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2201, 'avg', 64, '0'); convert_element_type_2201 = None + wait_tensor_480 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + convert_element_type_2202 = torch.ops.prims.convert_element_type.default(add_274, torch.float32); add_274 = None + convert_element_type_2204 = torch.ops.prims.convert_element_type.default(wait_tensor_100, torch.float32); wait_tensor_100 = None + mul_678 = torch.ops.aten.mul.Tensor(convert_element_type_2202, convert_element_type_2204); convert_element_type_2204 = None + mul_680 = torch.ops.aten.mul.Tensor(mul_88, mul_678) + sum_127 = torch.ops.aten.sum.dim_IntList(mul_680, [2], True); mul_680 = None + div_42 = torch.ops.aten.div.Tensor(mul_88, 4096) + mul_681 = torch.ops.aten.mul.Tensor(div_42, sum_127); div_42 = sum_127 = None + sub_63 = torch.ops.aten.sub.Tensor(mul_678, mul_681); mul_678 = mul_681 = None + mul_682 = torch.ops.aten.mul.Tensor(sub_63, rsqrt_22); sub_63 = rsqrt_22 = None + mul_683 = torch.ops.aten.mul.Tensor(convert_element_type_2202, mul_88); convert_element_type_2202 = mul_88 = None + sum_128 = torch.ops.aten.sum.dim_IntList(mul_683, [0, 1]); mul_683 = None + convert_element_type_2205 = torch.ops.prims.convert_element_type.default(mul_682, torch.bfloat16); mul_682 = None + add_275 = torch.ops.aten.add.Tensor(add_272, convert_element_type_2205); add_272 = convert_element_type_2205 = None + convert_element_type_default_23 = torch.ops.prims.convert_element_type.default(sum_128, torch.float32); sum_128 = None + reduce_scatter_tensor_190 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_23, 'avg', 64, '0'); convert_element_type_default_23 = None + wait_tensor_481 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + view_1599 = torch.ops.aten.view.default(add_275, [16384, 4096]) + permute_1029 = torch.ops.aten.permute.default(view_1599, [1, 0]) + permute_116 = torch.ops.aten.permute.default(getitem_90, [0, 2, 1, 3]) + view_361 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16); primals_98 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 64, '0'); convert_element_type_347 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_363 = torch.ops.aten.view.default(view_361, [16384, 4096]); view_361 = None + mm_73 = torch.ops.aten.mm.default(view_363, permute_117) + view_364 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + add_41 = torch.ops.aten.add.Tensor(add_39, view_364); view_364 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16); primals_99 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 64, '0'); convert_element_type_350 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32); add_41 = None + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_96) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + view_367 = torch.ops.aten.view.default(convert_element_type_352, [16384, 4096]); convert_element_type_352 = None + view_368 = torch.ops.aten.view.default(mm_74, [2, 8192, 14336]); mm_74 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_368, torch.float32); view_368 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 64, '0'); convert_element_type_358 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + mm_75 = torch.ops.aten.mm.default(view_367, permute_119) + view_371 = torch.ops.aten.view.default(mm_75, [2, 8192, 14336]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_371) + view_373 = torch.ops.aten.view.default(mul_87, [16384, 14336]); mul_87 = None + mm_521 = torch.ops.aten.mm.default(permute_1029, view_373); permute_1029 = view_373 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 64, '0'); convert_element_type_361 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + permute_1031 = torch.ops.aten.permute.default(permute_120, [1, 0]); permute_120 = None + mm_522 = torch.ops.aten.mm.default(view_1599, permute_1031); view_1599 = permute_1031 = None + view_1600 = torch.ops.aten.view.default(mm_522, [2, 8192, 14336]); mm_522 = None + convert_element_type_2212 = torch.ops.prims.convert_element_type.default(mm_521, torch.float32); mm_521 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2212, 'avg', 64, '0'); convert_element_type_2212 = None + wait_tensor_482 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + mul_684 = torch.ops.aten.mul.Tensor(view_1600, convert_element_type_357); convert_element_type_357 = None + mul_685 = torch.ops.aten.mul.Tensor(view_1600, view_371); view_1600 = view_371 = None + view_1601 = torch.ops.aten.view.default(mul_684, [16384, 14336]); mul_684 = None + permute_1033 = torch.ops.aten.permute.default(view_1601, [1, 0]) + mm_523 = torch.ops.aten.mm.default(permute_1033, view_367); permute_1033 = None + permute_1035 = torch.ops.aten.permute.default(permute_119, [1, 0]); permute_119 = None + mm_524 = torch.ops.aten.mm.default(view_1601, permute_1035); view_1601 = permute_1035 = None + view_1602 = torch.ops.aten.view.default(mm_524, [2, 8192, 4096]); mm_524 = None + convert_element_type_2217 = torch.ops.prims.convert_element_type.default(mm_523, torch.float32); mm_523 = None + reduce_scatter_tensor_192 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2217, 'avg', 64, '0'); convert_element_type_2217 = None + wait_tensor_483 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + convert_element_type_2218 = torch.ops.prims.convert_element_type.default(mul_685, torch.float32); mul_685 = None + neg_21 = torch.ops.aten.neg.default(convert_element_type_356) + exp_21 = torch.ops.aten.exp.default(neg_21); neg_21 = None + add_276 = torch.ops.aten.add.Tensor(exp_21, 1); exp_21 = None + reciprocal_21 = torch.ops.aten.reciprocal.default(add_276); add_276 = None + mul_686 = torch.ops.aten.mul.Tensor(reciprocal_21, 1); reciprocal_21 = None + mul_687 = torch.ops.aten.mul.Tensor(convert_element_type_2218, mul_686); convert_element_type_2218 = None + sub_64 = torch.ops.aten.sub.Tensor(1, mul_686); mul_686 = None + mul_688 = torch.ops.aten.mul.Tensor(convert_element_type_356, sub_64); convert_element_type_356 = sub_64 = None + add_277 = torch.ops.aten.add.Tensor(mul_688, 1); mul_688 = None + mul_689 = torch.ops.aten.mul.Tensor(mul_687, add_277); mul_687 = add_277 = None + convert_element_type_2220 = torch.ops.prims.convert_element_type.default(mul_689, torch.bfloat16); mul_689 = None + view_1603 = torch.ops.aten.view.default(convert_element_type_2220, [16384, 14336]); convert_element_type_2220 = None + permute_1037 = torch.ops.aten.permute.default(view_1603, [1, 0]) + mm_525 = torch.ops.aten.mm.default(permute_1037, view_367); permute_1037 = view_367 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 64, '0'); convert_element_type_353 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + permute_1039 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_526 = torch.ops.aten.mm.default(view_1603, permute_1039); view_1603 = permute_1039 = None + view_1604 = torch.ops.aten.view.default(mm_526, [2, 8192, 4096]); mm_526 = None + add_278 = torch.ops.aten.add.Tensor(view_1602, view_1604); view_1602 = view_1604 = None + convert_element_type_2225 = torch.ops.prims.convert_element_type.default(mm_525, torch.float32); mm_525 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2225, 'avg', 64, '0'); convert_element_type_2225 = None + wait_tensor_484 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + convert_element_type_2226 = torch.ops.prims.convert_element_type.default(add_278, torch.float32); add_278 = None + convert_element_type_2228 = torch.ops.prims.convert_element_type.default(wait_tensor_96, torch.float32); wait_tensor_96 = None + mul_690 = torch.ops.aten.mul.Tensor(convert_element_type_2226, convert_element_type_2228); convert_element_type_2228 = None + mul_692 = torch.ops.aten.mul.Tensor(mul_84, mul_690) + sum_129 = torch.ops.aten.sum.dim_IntList(mul_692, [2], True); mul_692 = None + div_43 = torch.ops.aten.div.Tensor(mul_84, 4096) + mul_693 = torch.ops.aten.mul.Tensor(div_43, sum_129); div_43 = sum_129 = None + sub_65 = torch.ops.aten.sub.Tensor(mul_690, mul_693); mul_690 = mul_693 = None + mul_694 = torch.ops.aten.mul.Tensor(sub_65, rsqrt_21); sub_65 = rsqrt_21 = None + mul_695 = torch.ops.aten.mul.Tensor(convert_element_type_2226, mul_84); convert_element_type_2226 = mul_84 = None + sum_130 = torch.ops.aten.sum.dim_IntList(mul_695, [0, 1]); mul_695 = None + convert_element_type_2229 = torch.ops.prims.convert_element_type.default(mul_694, torch.bfloat16); mul_694 = None + add_279 = torch.ops.aten.add.Tensor(add_275, convert_element_type_2229); add_275 = convert_element_type_2229 = None + convert_element_type_default_22 = torch.ops.prims.convert_element_type.default(sum_130, torch.float32); sum_130 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_22, 'avg', 64, '0'); convert_element_type_default_22 = None + wait_tensor_485 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + view_1605 = torch.ops.aten.view.default(add_279, [16384, 4096]) + permute_1041 = torch.ops.aten.permute.default(view_1605, [1, 0]) + mm_527 = torch.ops.aten.mm.default(permute_1041, view_363); permute_1041 = view_363 = None + permute_1043 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_528 = torch.ops.aten.mm.default(view_1605, permute_1043); view_1605 = permute_1043 = None + view_1606 = torch.ops.aten.view.default(mm_528, [2, 8192, 4096]); mm_528 = None + convert_element_type_2236 = torch.ops.prims.convert_element_type.default(mm_527, torch.float32); mm_527 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2236, 'avg', 64, '0'); convert_element_type_2236 = None + wait_tensor_486 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + view_1607 = torch.ops.aten.view.default(view_1606, [2, 8192, 32, 128]); view_1606 = None + permute_1045 = torch.ops.aten.permute.default(view_1607, [0, 2, 1, 3]); view_1607 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16); primals_94 = None + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 64, '0'); convert_element_type_331 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32); add_39 = None + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_91) + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + view_343 = torch.ops.aten.view.default(convert_element_type_333, [16384, 4096]); convert_element_type_333 = None + view_344 = torch.ops.aten.view.default(mm_70, [2, 8192, 4096]); mm_70 = None + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16); primals_96 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 64, '0'); convert_element_type_337 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_71 = torch.ops.aten.mm.default(view_343, permute_111) + view_347 = torch.ops.aten.view.default(mm_71, [2, 8192, 1024]); mm_71 = None + view_350 = torch.ops.aten.view.default(mm_72, [2, 8192, 1024]); mm_72 = None + view_351 = torch.ops.aten.view.default(view_344, [2, 8192, -1, 128]); view_344 = None + view_352 = torch.ops.aten.view.default(view_347, [2, 8192, -1, 128]); view_347 = None + view_353 = torch.ops.aten.view.default(view_350, [2, 8192, -1, 128]); view_350 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_351, torch.float32); view_351 = None + view_354 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 32, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_354); view_354 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_352, torch.float32); view_352 = None + view_355 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 8, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_355); view_355 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_16); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_357 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 32, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_16); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_358 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 8, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_357, torch.bfloat16); view_357 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_358, torch.bfloat16); view_358 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 8, 4, 128]); unsqueeze_20 = None + clone_20 = torch.ops.aten.clone.default(expand_20, memory_format = torch.contiguous_format); expand_20 = None + view_359 = torch.ops.aten.view.default(clone_20, [2, 8192, 32, 128]); clone_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_353, 3); view_353 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 8, 4, 128]); unsqueeze_21 = None + clone_21 = torch.ops.aten.clone.default(expand_21, memory_format = torch.contiguous_format); expand_21 = None + view_360 = torch.ops.aten.view.default(clone_21, [2, 8192, 32, 128]); clone_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_359, [0, 2, 1, 3]); view_359 = None + permute_115 = torch.ops.aten.permute.default(view_360, [0, 2, 1, 3]); view_360 = None + _scaled_dot_product_cudnn_attention_backward_21 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1045, permute_113, permute_114, permute_115, getitem_90, getitem_91, getitem_96, getitem_97, None, None, None, 8192, 8192, 0.0, True); permute_1045 = permute_113 = permute_114 = permute_115 = getitem_90 = getitem_91 = getitem_96 = getitem_97 = None + getitem_351 = _scaled_dot_product_cudnn_attention_backward_21[0] + getitem_352 = _scaled_dot_product_cudnn_attention_backward_21[1] + getitem_353 = _scaled_dot_product_cudnn_attention_backward_21[2]; _scaled_dot_product_cudnn_attention_backward_21 = None + permute_1046 = torch.ops.aten.permute.default(getitem_353, [0, 2, 1, 3]); getitem_353 = None + permute_1047 = torch.ops.aten.permute.default(getitem_352, [0, 2, 1, 3]); getitem_352 = None + permute_1048 = torch.ops.aten.permute.default(getitem_351, [0, 2, 1, 3]); getitem_351 = None + view_1608 = torch.ops.aten.view.default(permute_1046, [2, 8192, 8, 4, 128]); permute_1046 = None + sum_131 = torch.ops.aten.sum.dim_IntList(view_1608, [3], True); view_1608 = None + squeeze_42 = torch.ops.aten.squeeze.dim(sum_131, 3); sum_131 = None + view_1609 = torch.ops.aten.view.default(permute_1047, [2, 8192, 8, 4, 128]); permute_1047 = None + sum_132 = torch.ops.aten.sum.dim_IntList(view_1609, [3], True); view_1609 = None + squeeze_43 = torch.ops.aten.squeeze.dim(sum_132, 3); sum_132 = None + convert_element_type_2237 = torch.ops.prims.convert_element_type.default(squeeze_43, torch.float32); squeeze_43 = None + convert_element_type_2238 = torch.ops.prims.convert_element_type.default(permute_1048, torch.float32); permute_1048 = None + view_1610 = torch.ops.aten.view.default(convert_element_type_2237, [2, 8192, 8, 64, 2]); convert_element_type_2237 = None + view_as_complex_106 = torch.ops.aten.view_as_complex.default(view_1610); view_1610 = None + mul_696 = torch.ops.aten.mul.Tensor(view_as_complex_106, _conj); view_as_complex_106 = None + view_1611 = torch.ops.aten.view.default(convert_element_type_2238, [2, 8192, 32, 64, 2]); convert_element_type_2238 = None + view_as_complex_107 = torch.ops.aten.view_as_complex.default(view_1611); view_1611 = None + mul_697 = torch.ops.aten.mul.Tensor(view_as_complex_107, _conj); view_as_complex_107 = None + view_as_real_106 = torch.ops.aten.view_as_real.default(mul_696); mul_696 = None + view_1612 = torch.ops.aten.view.default(view_as_real_106, [2, 8192, 8, 128]); view_as_real_106 = None + convert_element_type_2239 = torch.ops.prims.convert_element_type.default(view_1612, torch.bfloat16); view_1612 = None + view_as_real_107 = torch.ops.aten.view_as_real.default(mul_697); mul_697 = None + view_1613 = torch.ops.aten.view.default(view_as_real_107, [2, 8192, 32, 128]); view_as_real_107 = None + convert_element_type_2240 = torch.ops.prims.convert_element_type.default(view_1613, torch.bfloat16); view_1613 = None + view_1614 = torch.ops.aten.view.default(squeeze_42, [2, 8192, 1024]); squeeze_42 = None + view_1615 = torch.ops.aten.view.default(convert_element_type_2239, [2, 8192, 1024]); convert_element_type_2239 = None + view_1616 = torch.ops.aten.view.default(convert_element_type_2240, [2, 8192, 4096]); convert_element_type_2240 = None + view_1617 = torch.ops.aten.view.default(view_1614, [16384, 1024]); view_1614 = None + permute_1049 = torch.ops.aten.permute.default(view_1617, [1, 0]) + mm_529 = torch.ops.aten.mm.default(permute_1049, view_343); permute_1049 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16); primals_97 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 64, '0'); convert_element_type_340 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + permute_1051 = torch.ops.aten.permute.default(permute_112, [1, 0]); permute_112 = None + mm_530 = torch.ops.aten.mm.default(view_1617, permute_1051); view_1617 = permute_1051 = None + view_1618 = torch.ops.aten.view.default(mm_530, [2, 8192, 4096]); mm_530 = None + convert_element_type_2245 = torch.ops.prims.convert_element_type.default(mm_529, torch.float32); mm_529 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2245, 'avg', 64, '0'); convert_element_type_2245 = None + wait_tensor_487 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + view_1619 = torch.ops.aten.view.default(view_1615, [16384, 1024]); view_1615 = None + permute_1053 = torch.ops.aten.permute.default(view_1619, [1, 0]) + mm_531 = torch.ops.aten.mm.default(permute_1053, view_343); permute_1053 = None + permute_1055 = torch.ops.aten.permute.default(permute_111, [1, 0]); permute_111 = None + mm_532 = torch.ops.aten.mm.default(view_1619, permute_1055); view_1619 = permute_1055 = None + view_1620 = torch.ops.aten.view.default(mm_532, [2, 8192, 4096]); mm_532 = None + add_280 = torch.ops.aten.add.Tensor(view_1618, view_1620); view_1618 = view_1620 = None + convert_element_type_2250 = torch.ops.prims.convert_element_type.default(mm_531, torch.float32); mm_531 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2250, 'avg', 64, '0'); convert_element_type_2250 = None + wait_tensor_488 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + view_1621 = torch.ops.aten.view.default(view_1616, [16384, 4096]); view_1616 = None + permute_1057 = torch.ops.aten.permute.default(view_1621, [1, 0]) + mm_533 = torch.ops.aten.mm.default(permute_1057, view_343); permute_1057 = view_343 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 64, '0'); convert_element_type_334 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + permute_1059 = torch.ops.aten.permute.default(permute_110, [1, 0]); permute_110 = None + mm_534 = torch.ops.aten.mm.default(view_1621, permute_1059); view_1621 = permute_1059 = None + view_1622 = torch.ops.aten.view.default(mm_534, [2, 8192, 4096]); mm_534 = None + add_281 = torch.ops.aten.add.Tensor(add_280, view_1622); add_280 = view_1622 = None + convert_element_type_2255 = torch.ops.prims.convert_element_type.default(mm_533, torch.float32); mm_533 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2255, 'avg', 64, '0'); convert_element_type_2255 = None + wait_tensor_489 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + convert_element_type_2256 = torch.ops.prims.convert_element_type.default(add_281, torch.float32); add_281 = None + convert_element_type_2258 = torch.ops.prims.convert_element_type.default(wait_tensor_91, torch.float32); wait_tensor_91 = None + mul_698 = torch.ops.aten.mul.Tensor(convert_element_type_2256, convert_element_type_2258); convert_element_type_2258 = None + mul_700 = torch.ops.aten.mul.Tensor(mul_80, mul_698) + sum_133 = torch.ops.aten.sum.dim_IntList(mul_700, [2], True); mul_700 = None + div_44 = torch.ops.aten.div.Tensor(mul_80, 4096) + mul_701 = torch.ops.aten.mul.Tensor(div_44, sum_133); div_44 = sum_133 = None + sub_66 = torch.ops.aten.sub.Tensor(mul_698, mul_701); mul_698 = mul_701 = None + mul_702 = torch.ops.aten.mul.Tensor(sub_66, rsqrt_20); sub_66 = rsqrt_20 = None + mul_703 = torch.ops.aten.mul.Tensor(convert_element_type_2256, mul_80); convert_element_type_2256 = mul_80 = None + sum_134 = torch.ops.aten.sum.dim_IntList(mul_703, [0, 1]); mul_703 = None + convert_element_type_2259 = torch.ops.prims.convert_element_type.default(mul_702, torch.bfloat16); mul_702 = None + add_282 = torch.ops.aten.add.Tensor(add_279, convert_element_type_2259); add_279 = convert_element_type_2259 = None + convert_element_type_default_21 = torch.ops.prims.convert_element_type.default(sum_134, torch.float32); sum_134 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_21, 'avg', 64, '0'); convert_element_type_default_21 = None + wait_tensor_490 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + view_1623 = torch.ops.aten.view.default(add_282, [16384, 4096]) + permute_1061 = torch.ops.aten.permute.default(view_1623, [1, 0]) + permute_105 = torch.ops.aten.permute.default(getitem_81, [0, 2, 1, 3]) + view_327 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 64, '0'); convert_element_type_314 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_86, [1, 0]); wait_tensor_86 = None + view_329 = torch.ops.aten.view.default(view_327, [16384, 4096]); view_327 = None + mm_66 = torch.ops.aten.mm.default(view_329, permute_106) + view_330 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + add_37 = torch.ops.aten.add.Tensor(add_35, view_330); view_330 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 64, '0'); convert_element_type_317 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32); add_37 = None + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_87) + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + view_333 = torch.ops.aten.view.default(convert_element_type_319, [16384, 4096]); convert_element_type_319 = None + view_334 = torch.ops.aten.view.default(mm_67, [2, 8192, 14336]); mm_67 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_334, torch.float32); view_334 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16); primals_92 = None + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 64, '0'); convert_element_type_325 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + mm_68 = torch.ops.aten.mm.default(view_333, permute_108) + view_337 = torch.ops.aten.view.default(mm_68, [2, 8192, 14336]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_337) + view_339 = torch.ops.aten.view.default(mul_79, [16384, 14336]); mul_79 = None + mm_535 = torch.ops.aten.mm.default(permute_1061, view_339); permute_1061 = view_339 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 64, '0'); convert_element_type_328 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + permute_1063 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_536 = torch.ops.aten.mm.default(view_1623, permute_1063); view_1623 = permute_1063 = None + view_1624 = torch.ops.aten.view.default(mm_536, [2, 8192, 14336]); mm_536 = None + convert_element_type_2266 = torch.ops.prims.convert_element_type.default(mm_535, torch.float32); mm_535 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2266, 'avg', 64, '0'); convert_element_type_2266 = None + wait_tensor_491 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + mul_704 = torch.ops.aten.mul.Tensor(view_1624, convert_element_type_324); convert_element_type_324 = None + mul_705 = torch.ops.aten.mul.Tensor(view_1624, view_337); view_1624 = view_337 = None + view_1625 = torch.ops.aten.view.default(mul_704, [16384, 14336]); mul_704 = None + permute_1065 = torch.ops.aten.permute.default(view_1625, [1, 0]) + mm_537 = torch.ops.aten.mm.default(permute_1065, view_333); permute_1065 = None + permute_1067 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_538 = torch.ops.aten.mm.default(view_1625, permute_1067); view_1625 = permute_1067 = None + view_1626 = torch.ops.aten.view.default(mm_538, [2, 8192, 4096]); mm_538 = None + convert_element_type_2271 = torch.ops.prims.convert_element_type.default(mm_537, torch.float32); mm_537 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2271, 'avg', 64, '0'); convert_element_type_2271 = None + wait_tensor_492 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + convert_element_type_2272 = torch.ops.prims.convert_element_type.default(mul_705, torch.float32); mul_705 = None + neg_22 = torch.ops.aten.neg.default(convert_element_type_323) + exp_22 = torch.ops.aten.exp.default(neg_22); neg_22 = None + add_283 = torch.ops.aten.add.Tensor(exp_22, 1); exp_22 = None + reciprocal_22 = torch.ops.aten.reciprocal.default(add_283); add_283 = None + mul_706 = torch.ops.aten.mul.Tensor(reciprocal_22, 1); reciprocal_22 = None + mul_707 = torch.ops.aten.mul.Tensor(convert_element_type_2272, mul_706); convert_element_type_2272 = None + sub_67 = torch.ops.aten.sub.Tensor(1, mul_706); mul_706 = None + mul_708 = torch.ops.aten.mul.Tensor(convert_element_type_323, sub_67); convert_element_type_323 = sub_67 = None + add_284 = torch.ops.aten.add.Tensor(mul_708, 1); mul_708 = None + mul_709 = torch.ops.aten.mul.Tensor(mul_707, add_284); mul_707 = add_284 = None + convert_element_type_2274 = torch.ops.prims.convert_element_type.default(mul_709, torch.bfloat16); mul_709 = None + view_1627 = torch.ops.aten.view.default(convert_element_type_2274, [16384, 14336]); convert_element_type_2274 = None + permute_1069 = torch.ops.aten.permute.default(view_1627, [1, 0]) + mm_539 = torch.ops.aten.mm.default(permute_1069, view_333); permute_1069 = view_333 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 64, '0'); convert_element_type_320 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_88, [1, 0]); wait_tensor_88 = None + permute_1071 = torch.ops.aten.permute.default(permute_107, [1, 0]); permute_107 = None + mm_540 = torch.ops.aten.mm.default(view_1627, permute_1071); view_1627 = permute_1071 = None + view_1628 = torch.ops.aten.view.default(mm_540, [2, 8192, 4096]); mm_540 = None + add_285 = torch.ops.aten.add.Tensor(view_1626, view_1628); view_1626 = view_1628 = None + convert_element_type_2279 = torch.ops.prims.convert_element_type.default(mm_539, torch.float32); mm_539 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2279, 'avg', 64, '0'); convert_element_type_2279 = None + wait_tensor_493 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + convert_element_type_2280 = torch.ops.prims.convert_element_type.default(add_285, torch.float32); add_285 = None + convert_element_type_2282 = torch.ops.prims.convert_element_type.default(wait_tensor_87, torch.float32); wait_tensor_87 = None + mul_710 = torch.ops.aten.mul.Tensor(convert_element_type_2280, convert_element_type_2282); convert_element_type_2282 = None + mul_712 = torch.ops.aten.mul.Tensor(mul_76, mul_710) + sum_135 = torch.ops.aten.sum.dim_IntList(mul_712, [2], True); mul_712 = None + div_45 = torch.ops.aten.div.Tensor(mul_76, 4096) + mul_713 = torch.ops.aten.mul.Tensor(div_45, sum_135); div_45 = sum_135 = None + sub_68 = torch.ops.aten.sub.Tensor(mul_710, mul_713); mul_710 = mul_713 = None + mul_714 = torch.ops.aten.mul.Tensor(sub_68, rsqrt_19); sub_68 = rsqrt_19 = None + mul_715 = torch.ops.aten.mul.Tensor(convert_element_type_2280, mul_76); convert_element_type_2280 = mul_76 = None + sum_136 = torch.ops.aten.sum.dim_IntList(mul_715, [0, 1]); mul_715 = None + convert_element_type_2283 = torch.ops.prims.convert_element_type.default(mul_714, torch.bfloat16); mul_714 = None + add_286 = torch.ops.aten.add.Tensor(add_282, convert_element_type_2283); add_282 = convert_element_type_2283 = None + convert_element_type_default_20 = torch.ops.prims.convert_element_type.default(sum_136, torch.float32); sum_136 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_20, 'avg', 64, '0'); convert_element_type_default_20 = None + wait_tensor_494 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + view_1629 = torch.ops.aten.view.default(add_286, [16384, 4096]) + permute_1073 = torch.ops.aten.permute.default(view_1629, [1, 0]) + mm_541 = torch.ops.aten.mm.default(permute_1073, view_329); permute_1073 = view_329 = None + permute_1075 = torch.ops.aten.permute.default(permute_106, [1, 0]); permute_106 = None + mm_542 = torch.ops.aten.mm.default(view_1629, permute_1075); view_1629 = permute_1075 = None + view_1630 = torch.ops.aten.view.default(mm_542, [2, 8192, 4096]); mm_542 = None + convert_element_type_2290 = torch.ops.prims.convert_element_type.default(mm_541, torch.float32); mm_541 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2290, 'avg', 64, '0'); convert_element_type_2290 = None + wait_tensor_495 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + view_1631 = torch.ops.aten.view.default(view_1630, [2, 8192, 32, 128]); view_1630 = None + permute_1077 = torch.ops.aten.permute.default(view_1631, [0, 2, 1, 3]); view_1631 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 64, '0'); convert_element_type_298 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_82) + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + view_309 = torch.ops.aten.view.default(convert_element_type_300, [16384, 4096]); convert_element_type_300 = None + view_310 = torch.ops.aten.view.default(mm_63, [2, 8192, 4096]); mm_63 = None + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 64, '0'); convert_element_type_304 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_64 = torch.ops.aten.mm.default(view_309, permute_100) + view_313 = torch.ops.aten.view.default(mm_64, [2, 8192, 1024]); mm_64 = None + view_316 = torch.ops.aten.view.default(mm_65, [2, 8192, 1024]); mm_65 = None + view_317 = torch.ops.aten.view.default(view_310, [2, 8192, -1, 128]); view_310 = None + view_318 = torch.ops.aten.view.default(view_313, [2, 8192, -1, 128]); view_313 = None + view_319 = torch.ops.aten.view.default(view_316, [2, 8192, -1, 128]); view_316 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_317, torch.float32); view_317 = None + view_320 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 32, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_320); view_320 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_318, torch.float32); view_318 = None + view_321 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 8, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_321); view_321 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_16); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_323 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 32, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_16); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_324 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 8, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_323, torch.bfloat16); view_323 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_324, torch.bfloat16); view_324 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 8, 4, 128]); unsqueeze_18 = None + clone_18 = torch.ops.aten.clone.default(expand_18, memory_format = torch.contiguous_format); expand_18 = None + view_325 = torch.ops.aten.view.default(clone_18, [2, 8192, 32, 128]); clone_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_319, 3); view_319 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 8, 4, 128]); unsqueeze_19 = None + clone_19 = torch.ops.aten.clone.default(expand_19, memory_format = torch.contiguous_format); expand_19 = None + view_326 = torch.ops.aten.view.default(clone_19, [2, 8192, 32, 128]); clone_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_325, [0, 2, 1, 3]); view_325 = None + permute_104 = torch.ops.aten.permute.default(view_326, [0, 2, 1, 3]); view_326 = None + _scaled_dot_product_cudnn_attention_backward_22 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1077, permute_102, permute_103, permute_104, getitem_81, getitem_82, getitem_87, getitem_88, None, None, None, 8192, 8192, 0.0, True); permute_1077 = permute_102 = permute_103 = permute_104 = getitem_81 = getitem_82 = getitem_87 = getitem_88 = None + getitem_354 = _scaled_dot_product_cudnn_attention_backward_22[0] + getitem_355 = _scaled_dot_product_cudnn_attention_backward_22[1] + getitem_356 = _scaled_dot_product_cudnn_attention_backward_22[2]; _scaled_dot_product_cudnn_attention_backward_22 = None + permute_1078 = torch.ops.aten.permute.default(getitem_356, [0, 2, 1, 3]); getitem_356 = None + permute_1079 = torch.ops.aten.permute.default(getitem_355, [0, 2, 1, 3]); getitem_355 = None + permute_1080 = torch.ops.aten.permute.default(getitem_354, [0, 2, 1, 3]); getitem_354 = None + view_1632 = torch.ops.aten.view.default(permute_1078, [2, 8192, 8, 4, 128]); permute_1078 = None + sum_137 = torch.ops.aten.sum.dim_IntList(view_1632, [3], True); view_1632 = None + squeeze_44 = torch.ops.aten.squeeze.dim(sum_137, 3); sum_137 = None + view_1633 = torch.ops.aten.view.default(permute_1079, [2, 8192, 8, 4, 128]); permute_1079 = None + sum_138 = torch.ops.aten.sum.dim_IntList(view_1633, [3], True); view_1633 = None + squeeze_45 = torch.ops.aten.squeeze.dim(sum_138, 3); sum_138 = None + convert_element_type_2291 = torch.ops.prims.convert_element_type.default(squeeze_45, torch.float32); squeeze_45 = None + convert_element_type_2292 = torch.ops.prims.convert_element_type.default(permute_1080, torch.float32); permute_1080 = None + view_1634 = torch.ops.aten.view.default(convert_element_type_2291, [2, 8192, 8, 64, 2]); convert_element_type_2291 = None + view_as_complex_108 = torch.ops.aten.view_as_complex.default(view_1634); view_1634 = None + mul_716 = torch.ops.aten.mul.Tensor(view_as_complex_108, _conj); view_as_complex_108 = None + view_1635 = torch.ops.aten.view.default(convert_element_type_2292, [2, 8192, 32, 64, 2]); convert_element_type_2292 = None + view_as_complex_109 = torch.ops.aten.view_as_complex.default(view_1635); view_1635 = None + mul_717 = torch.ops.aten.mul.Tensor(view_as_complex_109, _conj); view_as_complex_109 = None + view_as_real_108 = torch.ops.aten.view_as_real.default(mul_716); mul_716 = None + view_1636 = torch.ops.aten.view.default(view_as_real_108, [2, 8192, 8, 128]); view_as_real_108 = None + convert_element_type_2293 = torch.ops.prims.convert_element_type.default(view_1636, torch.bfloat16); view_1636 = None + view_as_real_109 = torch.ops.aten.view_as_real.default(mul_717); mul_717 = None + view_1637 = torch.ops.aten.view.default(view_as_real_109, [2, 8192, 32, 128]); view_as_real_109 = None + convert_element_type_2294 = torch.ops.prims.convert_element_type.default(view_1637, torch.bfloat16); view_1637 = None + view_1638 = torch.ops.aten.view.default(squeeze_44, [2, 8192, 1024]); squeeze_44 = None + view_1639 = torch.ops.aten.view.default(convert_element_type_2293, [2, 8192, 1024]); convert_element_type_2293 = None + view_1640 = torch.ops.aten.view.default(convert_element_type_2294, [2, 8192, 4096]); convert_element_type_2294 = None + view_1641 = torch.ops.aten.view.default(view_1638, [16384, 1024]); view_1638 = None + permute_1081 = torch.ops.aten.permute.default(view_1641, [1, 0]) + mm_543 = torch.ops.aten.mm.default(permute_1081, view_309); permute_1081 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 64, '0'); convert_element_type_307 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + permute_1083 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_544 = torch.ops.aten.mm.default(view_1641, permute_1083); view_1641 = permute_1083 = None + view_1642 = torch.ops.aten.view.default(mm_544, [2, 8192, 4096]); mm_544 = None + convert_element_type_2299 = torch.ops.prims.convert_element_type.default(mm_543, torch.float32); mm_543 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2299, 'avg', 64, '0'); convert_element_type_2299 = None + wait_tensor_496 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + view_1643 = torch.ops.aten.view.default(view_1639, [16384, 1024]); view_1639 = None + permute_1085 = torch.ops.aten.permute.default(view_1643, [1, 0]) + mm_545 = torch.ops.aten.mm.default(permute_1085, view_309); permute_1085 = None + permute_1087 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_546 = torch.ops.aten.mm.default(view_1643, permute_1087); view_1643 = permute_1087 = None + view_1644 = torch.ops.aten.view.default(mm_546, [2, 8192, 4096]); mm_546 = None + add_287 = torch.ops.aten.add.Tensor(view_1642, view_1644); view_1642 = view_1644 = None + convert_element_type_2304 = torch.ops.prims.convert_element_type.default(mm_545, torch.float32); mm_545 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2304, 'avg', 64, '0'); convert_element_type_2304 = None + wait_tensor_497 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_1645 = torch.ops.aten.view.default(view_1640, [16384, 4096]); view_1640 = None + permute_1089 = torch.ops.aten.permute.default(view_1645, [1, 0]) + mm_547 = torch.ops.aten.mm.default(permute_1089, view_309); permute_1089 = view_309 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 64, '0'); convert_element_type_301 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + permute_1091 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_548 = torch.ops.aten.mm.default(view_1645, permute_1091); view_1645 = permute_1091 = None + view_1646 = torch.ops.aten.view.default(mm_548, [2, 8192, 4096]); mm_548 = None + add_288 = torch.ops.aten.add.Tensor(add_287, view_1646); add_287 = view_1646 = None + convert_element_type_2309 = torch.ops.prims.convert_element_type.default(mm_547, torch.float32); mm_547 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2309, 'avg', 64, '0'); convert_element_type_2309 = None + wait_tensor_498 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + convert_element_type_2310 = torch.ops.prims.convert_element_type.default(add_288, torch.float32); add_288 = None + convert_element_type_2312 = torch.ops.prims.convert_element_type.default(wait_tensor_82, torch.float32); wait_tensor_82 = None + mul_718 = torch.ops.aten.mul.Tensor(convert_element_type_2310, convert_element_type_2312); convert_element_type_2312 = None + mul_720 = torch.ops.aten.mul.Tensor(mul_72, mul_718) + sum_139 = torch.ops.aten.sum.dim_IntList(mul_720, [2], True); mul_720 = None + div_46 = torch.ops.aten.div.Tensor(mul_72, 4096) + mul_721 = torch.ops.aten.mul.Tensor(div_46, sum_139); div_46 = sum_139 = None + sub_69 = torch.ops.aten.sub.Tensor(mul_718, mul_721); mul_718 = mul_721 = None + mul_722 = torch.ops.aten.mul.Tensor(sub_69, rsqrt_18); sub_69 = rsqrt_18 = None + mul_723 = torch.ops.aten.mul.Tensor(convert_element_type_2310, mul_72); convert_element_type_2310 = mul_72 = None + sum_140 = torch.ops.aten.sum.dim_IntList(mul_723, [0, 1]); mul_723 = None + convert_element_type_2313 = torch.ops.prims.convert_element_type.default(mul_722, torch.bfloat16); mul_722 = None + add_289 = torch.ops.aten.add.Tensor(add_286, convert_element_type_2313); add_286 = convert_element_type_2313 = None + convert_element_type_default_19 = torch.ops.prims.convert_element_type.default(sum_140, torch.float32); sum_140 = None + reduce_scatter_tensor_208 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_19, 'avg', 64, '0'); convert_element_type_default_19 = None + wait_tensor_499 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + view_1647 = torch.ops.aten.view.default(add_289, [16384, 4096]) + permute_1093 = torch.ops.aten.permute.default(view_1647, [1, 0]) + permute_94 = torch.ops.aten.permute.default(getitem_72, [0, 2, 1, 3]) + view_293 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16); primals_80 = None + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 64, '0'); convert_element_type_281 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + view_295 = torch.ops.aten.view.default(view_293, [16384, 4096]); view_293 = None + mm_59 = torch.ops.aten.mm.default(view_295, permute_95) + view_296 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + add_33 = torch.ops.aten.add.Tensor(add_31, view_296); view_296 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16); primals_81 = None + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 64, '0'); convert_element_type_284 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_78) + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + view_299 = torch.ops.aten.view.default(convert_element_type_286, [16384, 4096]); convert_element_type_286 = None + view_300 = torch.ops.aten.view.default(mm_60, [2, 8192, 14336]); mm_60 = None + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 64, '0'); convert_element_type_292 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_80, [1, 0]); wait_tensor_80 = None + mm_61 = torch.ops.aten.mm.default(view_299, permute_97) + view_303 = torch.ops.aten.view.default(mm_61, [2, 8192, 14336]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_303) + view_305 = torch.ops.aten.view.default(mul_71, [16384, 14336]); mul_71 = None + mm_549 = torch.ops.aten.mm.default(permute_1093, view_305); permute_1093 = view_305 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 64, '0'); convert_element_type_295 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + permute_1095 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None + mm_550 = torch.ops.aten.mm.default(view_1647, permute_1095); view_1647 = permute_1095 = None + view_1648 = torch.ops.aten.view.default(mm_550, [2, 8192, 14336]); mm_550 = None + convert_element_type_2320 = torch.ops.prims.convert_element_type.default(mm_549, torch.float32); mm_549 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2320, 'avg', 64, '0'); convert_element_type_2320 = None + wait_tensor_500 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + mul_724 = torch.ops.aten.mul.Tensor(view_1648, convert_element_type_291); convert_element_type_291 = None + mul_725 = torch.ops.aten.mul.Tensor(view_1648, view_303); view_1648 = view_303 = None + view_1649 = torch.ops.aten.view.default(mul_724, [16384, 14336]); mul_724 = None + permute_1097 = torch.ops.aten.permute.default(view_1649, [1, 0]) + mm_551 = torch.ops.aten.mm.default(permute_1097, view_299); permute_1097 = None + permute_1099 = torch.ops.aten.permute.default(permute_97, [1, 0]); permute_97 = None + mm_552 = torch.ops.aten.mm.default(view_1649, permute_1099); view_1649 = permute_1099 = None + view_1650 = torch.ops.aten.view.default(mm_552, [2, 8192, 4096]); mm_552 = None + convert_element_type_2325 = torch.ops.prims.convert_element_type.default(mm_551, torch.float32); mm_551 = None + reduce_scatter_tensor_210 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2325, 'avg', 64, '0'); convert_element_type_2325 = None + wait_tensor_501 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); reduce_scatter_tensor_210 = None + convert_element_type_2326 = torch.ops.prims.convert_element_type.default(mul_725, torch.float32); mul_725 = None + neg_23 = torch.ops.aten.neg.default(convert_element_type_290) + exp_23 = torch.ops.aten.exp.default(neg_23); neg_23 = None + add_290 = torch.ops.aten.add.Tensor(exp_23, 1); exp_23 = None + reciprocal_23 = torch.ops.aten.reciprocal.default(add_290); add_290 = None + mul_726 = torch.ops.aten.mul.Tensor(reciprocal_23, 1); reciprocal_23 = None + mul_727 = torch.ops.aten.mul.Tensor(convert_element_type_2326, mul_726); convert_element_type_2326 = None + sub_70 = torch.ops.aten.sub.Tensor(1, mul_726); mul_726 = None + mul_728 = torch.ops.aten.mul.Tensor(convert_element_type_290, sub_70); convert_element_type_290 = sub_70 = None + add_291 = torch.ops.aten.add.Tensor(mul_728, 1); mul_728 = None + mul_729 = torch.ops.aten.mul.Tensor(mul_727, add_291); mul_727 = add_291 = None + convert_element_type_2328 = torch.ops.prims.convert_element_type.default(mul_729, torch.bfloat16); mul_729 = None + view_1651 = torch.ops.aten.view.default(convert_element_type_2328, [16384, 14336]); convert_element_type_2328 = None + permute_1101 = torch.ops.aten.permute.default(view_1651, [1, 0]) + mm_553 = torch.ops.aten.mm.default(permute_1101, view_299); permute_1101 = view_299 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 64, '0'); convert_element_type_287 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + permute_1103 = torch.ops.aten.permute.default(permute_96, [1, 0]); permute_96 = None + mm_554 = torch.ops.aten.mm.default(view_1651, permute_1103); view_1651 = permute_1103 = None + view_1652 = torch.ops.aten.view.default(mm_554, [2, 8192, 4096]); mm_554 = None + add_292 = torch.ops.aten.add.Tensor(view_1650, view_1652); view_1650 = view_1652 = None + convert_element_type_2333 = torch.ops.prims.convert_element_type.default(mm_553, torch.float32); mm_553 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2333, 'avg', 64, '0'); convert_element_type_2333 = None + wait_tensor_502 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + convert_element_type_2334 = torch.ops.prims.convert_element_type.default(add_292, torch.float32); add_292 = None + convert_element_type_2336 = torch.ops.prims.convert_element_type.default(wait_tensor_78, torch.float32); wait_tensor_78 = None + mul_730 = torch.ops.aten.mul.Tensor(convert_element_type_2334, convert_element_type_2336); convert_element_type_2336 = None + mul_732 = torch.ops.aten.mul.Tensor(mul_68, mul_730) + sum_141 = torch.ops.aten.sum.dim_IntList(mul_732, [2], True); mul_732 = None + div_47 = torch.ops.aten.div.Tensor(mul_68, 4096) + mul_733 = torch.ops.aten.mul.Tensor(div_47, sum_141); div_47 = sum_141 = None + sub_71 = torch.ops.aten.sub.Tensor(mul_730, mul_733); mul_730 = mul_733 = None + mul_734 = torch.ops.aten.mul.Tensor(sub_71, rsqrt_17); sub_71 = rsqrt_17 = None + mul_735 = torch.ops.aten.mul.Tensor(convert_element_type_2334, mul_68); convert_element_type_2334 = mul_68 = None + sum_142 = torch.ops.aten.sum.dim_IntList(mul_735, [0, 1]); mul_735 = None + convert_element_type_2337 = torch.ops.prims.convert_element_type.default(mul_734, torch.bfloat16); mul_734 = None + add_293 = torch.ops.aten.add.Tensor(add_289, convert_element_type_2337); add_289 = convert_element_type_2337 = None + convert_element_type_default_18 = torch.ops.prims.convert_element_type.default(sum_142, torch.float32); sum_142 = None + reduce_scatter_tensor_212 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_18, 'avg', 64, '0'); convert_element_type_default_18 = None + wait_tensor_503 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + view_1653 = torch.ops.aten.view.default(add_293, [16384, 4096]) + permute_1105 = torch.ops.aten.permute.default(view_1653, [1, 0]) + mm_555 = torch.ops.aten.mm.default(permute_1105, view_295); permute_1105 = view_295 = None + permute_1107 = torch.ops.aten.permute.default(permute_95, [1, 0]); permute_95 = None + mm_556 = torch.ops.aten.mm.default(view_1653, permute_1107); view_1653 = permute_1107 = None + view_1654 = torch.ops.aten.view.default(mm_556, [2, 8192, 4096]); mm_556 = None + convert_element_type_2344 = torch.ops.prims.convert_element_type.default(mm_555, torch.float32); mm_555 = None + reduce_scatter_tensor_213 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2344, 'avg', 64, '0'); convert_element_type_2344 = None + wait_tensor_504 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_213); reduce_scatter_tensor_213 = None + view_1655 = torch.ops.aten.view.default(view_1654, [2, 8192, 32, 128]); view_1654 = None + permute_1109 = torch.ops.aten.permute.default(view_1655, [0, 2, 1, 3]); view_1655 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16); primals_76 = None + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 64, '0'); convert_element_type_265 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32); add_31 = None + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_73) + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + view_275 = torch.ops.aten.view.default(convert_element_type_267, [16384, 4096]); convert_element_type_267 = None + view_276 = torch.ops.aten.view.default(mm_56, [2, 8192, 4096]); mm_56 = None + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16); primals_78 = None + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 64, '0'); convert_element_type_271 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + mm_57 = torch.ops.aten.mm.default(view_275, permute_89) + view_279 = torch.ops.aten.view.default(mm_57, [2, 8192, 1024]); mm_57 = None + view_282 = torch.ops.aten.view.default(mm_58, [2, 8192, 1024]); mm_58 = None + view_283 = torch.ops.aten.view.default(view_276, [2, 8192, -1, 128]); view_276 = None + view_284 = torch.ops.aten.view.default(view_279, [2, 8192, -1, 128]); view_279 = None + view_285 = torch.ops.aten.view.default(view_282, [2, 8192, -1, 128]); view_282 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_283, torch.float32); view_283 = None + view_286 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 32, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_286); view_286 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_284, torch.float32); view_284 = None + view_287 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 8, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_287); view_287 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_16); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_289 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 32, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_16); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_290 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 8, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_289, torch.bfloat16); view_289 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_290, torch.bfloat16); view_290 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 8, 4, 128]); unsqueeze_16 = None + clone_16 = torch.ops.aten.clone.default(expand_16, memory_format = torch.contiguous_format); expand_16 = None + view_291 = torch.ops.aten.view.default(clone_16, [2, 8192, 32, 128]); clone_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_285, 3); view_285 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 8, 4, 128]); unsqueeze_17 = None + clone_17 = torch.ops.aten.clone.default(expand_17, memory_format = torch.contiguous_format); expand_17 = None + view_292 = torch.ops.aten.view.default(clone_17, [2, 8192, 32, 128]); clone_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_291, [0, 2, 1, 3]); view_291 = None + permute_93 = torch.ops.aten.permute.default(view_292, [0, 2, 1, 3]); view_292 = None + _scaled_dot_product_cudnn_attention_backward_23 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1109, permute_91, permute_92, permute_93, getitem_72, getitem_73, getitem_78, getitem_79, None, None, None, 8192, 8192, 0.0, True); permute_1109 = permute_91 = permute_92 = permute_93 = getitem_72 = getitem_73 = getitem_78 = getitem_79 = None + getitem_357 = _scaled_dot_product_cudnn_attention_backward_23[0] + getitem_358 = _scaled_dot_product_cudnn_attention_backward_23[1] + getitem_359 = _scaled_dot_product_cudnn_attention_backward_23[2]; _scaled_dot_product_cudnn_attention_backward_23 = None + permute_1110 = torch.ops.aten.permute.default(getitem_359, [0, 2, 1, 3]); getitem_359 = None + permute_1111 = torch.ops.aten.permute.default(getitem_358, [0, 2, 1, 3]); getitem_358 = None + permute_1112 = torch.ops.aten.permute.default(getitem_357, [0, 2, 1, 3]); getitem_357 = None + view_1656 = torch.ops.aten.view.default(permute_1110, [2, 8192, 8, 4, 128]); permute_1110 = None + sum_143 = torch.ops.aten.sum.dim_IntList(view_1656, [3], True); view_1656 = None + squeeze_46 = torch.ops.aten.squeeze.dim(sum_143, 3); sum_143 = None + view_1657 = torch.ops.aten.view.default(permute_1111, [2, 8192, 8, 4, 128]); permute_1111 = None + sum_144 = torch.ops.aten.sum.dim_IntList(view_1657, [3], True); view_1657 = None + squeeze_47 = torch.ops.aten.squeeze.dim(sum_144, 3); sum_144 = None + convert_element_type_2345 = torch.ops.prims.convert_element_type.default(squeeze_47, torch.float32); squeeze_47 = None + convert_element_type_2346 = torch.ops.prims.convert_element_type.default(permute_1112, torch.float32); permute_1112 = None + view_1658 = torch.ops.aten.view.default(convert_element_type_2345, [2, 8192, 8, 64, 2]); convert_element_type_2345 = None + view_as_complex_110 = torch.ops.aten.view_as_complex.default(view_1658); view_1658 = None + mul_736 = torch.ops.aten.mul.Tensor(view_as_complex_110, _conj); view_as_complex_110 = None + view_1659 = torch.ops.aten.view.default(convert_element_type_2346, [2, 8192, 32, 64, 2]); convert_element_type_2346 = None + view_as_complex_111 = torch.ops.aten.view_as_complex.default(view_1659); view_1659 = None + mul_737 = torch.ops.aten.mul.Tensor(view_as_complex_111, _conj); view_as_complex_111 = None + view_as_real_110 = torch.ops.aten.view_as_real.default(mul_736); mul_736 = None + view_1660 = torch.ops.aten.view.default(view_as_real_110, [2, 8192, 8, 128]); view_as_real_110 = None + convert_element_type_2347 = torch.ops.prims.convert_element_type.default(view_1660, torch.bfloat16); view_1660 = None + view_as_real_111 = torch.ops.aten.view_as_real.default(mul_737); mul_737 = None + view_1661 = torch.ops.aten.view.default(view_as_real_111, [2, 8192, 32, 128]); view_as_real_111 = None + convert_element_type_2348 = torch.ops.prims.convert_element_type.default(view_1661, torch.bfloat16); view_1661 = None + view_1662 = torch.ops.aten.view.default(squeeze_46, [2, 8192, 1024]); squeeze_46 = None + view_1663 = torch.ops.aten.view.default(convert_element_type_2347, [2, 8192, 1024]); convert_element_type_2347 = None + view_1664 = torch.ops.aten.view.default(convert_element_type_2348, [2, 8192, 4096]); convert_element_type_2348 = None + view_1665 = torch.ops.aten.view.default(view_1662, [16384, 1024]); view_1662 = None + permute_1113 = torch.ops.aten.permute.default(view_1665, [1, 0]) + mm_557 = torch.ops.aten.mm.default(permute_1113, view_275); permute_1113 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); primals_79 = None + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 64, '0'); convert_element_type_274 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + permute_1115 = torch.ops.aten.permute.default(permute_90, [1, 0]); permute_90 = None + mm_558 = torch.ops.aten.mm.default(view_1665, permute_1115); view_1665 = permute_1115 = None + view_1666 = torch.ops.aten.view.default(mm_558, [2, 8192, 4096]); mm_558 = None + convert_element_type_2353 = torch.ops.prims.convert_element_type.default(mm_557, torch.float32); mm_557 = None + reduce_scatter_tensor_214 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2353, 'avg', 64, '0'); convert_element_type_2353 = None + wait_tensor_505 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_214); reduce_scatter_tensor_214 = None + view_1667 = torch.ops.aten.view.default(view_1663, [16384, 1024]); view_1663 = None + permute_1117 = torch.ops.aten.permute.default(view_1667, [1, 0]) + mm_559 = torch.ops.aten.mm.default(permute_1117, view_275); permute_1117 = None + permute_1119 = torch.ops.aten.permute.default(permute_89, [1, 0]); permute_89 = None + mm_560 = torch.ops.aten.mm.default(view_1667, permute_1119); view_1667 = permute_1119 = None + view_1668 = torch.ops.aten.view.default(mm_560, [2, 8192, 4096]); mm_560 = None + add_294 = torch.ops.aten.add.Tensor(view_1666, view_1668); view_1666 = view_1668 = None + convert_element_type_2358 = torch.ops.prims.convert_element_type.default(mm_559, torch.float32); mm_559 = None + reduce_scatter_tensor_215 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2358, 'avg', 64, '0'); convert_element_type_2358 = None + wait_tensor_506 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_215); reduce_scatter_tensor_215 = None + view_1669 = torch.ops.aten.view.default(view_1664, [16384, 4096]); view_1664 = None + permute_1121 = torch.ops.aten.permute.default(view_1669, [1, 0]) + mm_561 = torch.ops.aten.mm.default(permute_1121, view_275); permute_1121 = view_275 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 64, '0'); convert_element_type_268 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_74, [1, 0]); wait_tensor_74 = None + permute_1123 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_562 = torch.ops.aten.mm.default(view_1669, permute_1123); view_1669 = permute_1123 = None + view_1670 = torch.ops.aten.view.default(mm_562, [2, 8192, 4096]); mm_562 = None + add_295 = torch.ops.aten.add.Tensor(add_294, view_1670); add_294 = view_1670 = None + convert_element_type_2363 = torch.ops.prims.convert_element_type.default(mm_561, torch.float32); mm_561 = None + reduce_scatter_tensor_216 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2363, 'avg', 64, '0'); convert_element_type_2363 = None + wait_tensor_507 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_216); reduce_scatter_tensor_216 = None + convert_element_type_2364 = torch.ops.prims.convert_element_type.default(add_295, torch.float32); add_295 = None + convert_element_type_2366 = torch.ops.prims.convert_element_type.default(wait_tensor_73, torch.float32); wait_tensor_73 = None + mul_738 = torch.ops.aten.mul.Tensor(convert_element_type_2364, convert_element_type_2366); convert_element_type_2366 = None + mul_740 = torch.ops.aten.mul.Tensor(mul_64, mul_738) + sum_145 = torch.ops.aten.sum.dim_IntList(mul_740, [2], True); mul_740 = None + div_48 = torch.ops.aten.div.Tensor(mul_64, 4096) + mul_741 = torch.ops.aten.mul.Tensor(div_48, sum_145); div_48 = sum_145 = None + sub_72 = torch.ops.aten.sub.Tensor(mul_738, mul_741); mul_738 = mul_741 = None + mul_742 = torch.ops.aten.mul.Tensor(sub_72, rsqrt_16); sub_72 = rsqrt_16 = None + mul_743 = torch.ops.aten.mul.Tensor(convert_element_type_2364, mul_64); convert_element_type_2364 = mul_64 = None + sum_146 = torch.ops.aten.sum.dim_IntList(mul_743, [0, 1]); mul_743 = None + convert_element_type_2367 = torch.ops.prims.convert_element_type.default(mul_742, torch.bfloat16); mul_742 = None + add_296 = torch.ops.aten.add.Tensor(add_293, convert_element_type_2367); add_293 = convert_element_type_2367 = None + convert_element_type_default_17 = torch.ops.prims.convert_element_type.default(sum_146, torch.float32); sum_146 = None + reduce_scatter_tensor_217 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_17, 'avg', 64, '0'); convert_element_type_default_17 = None + wait_tensor_508 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_217); reduce_scatter_tensor_217 = None + view_1671 = torch.ops.aten.view.default(add_296, [16384, 4096]) + permute_1125 = torch.ops.aten.permute.default(view_1671, [1, 0]) + permute_83 = torch.ops.aten.permute.default(getitem_63, [0, 2, 1, 3]) + view_259 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 64, '0'); convert_element_type_248 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_68, [1, 0]); wait_tensor_68 = None + view_261 = torch.ops.aten.view.default(view_259, [16384, 4096]); view_259 = None + mm_52 = torch.ops.aten.mm.default(view_261, permute_84) + view_262 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + add_29 = torch.ops.aten.add.Tensor(add_27, view_262); view_262 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 64, '0'); convert_element_type_251 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32); add_29 = None + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_69) + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + view_265 = torch.ops.aten.view.default(convert_element_type_253, [16384, 4096]); convert_element_type_253 = None + view_266 = torch.ops.aten.view.default(mm_53, [2, 8192, 14336]); mm_53 = None + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_266, torch.float32); view_266 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 64, '0'); convert_element_type_259 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_54 = torch.ops.aten.mm.default(view_265, permute_86) + view_269 = torch.ops.aten.view.default(mm_54, [2, 8192, 14336]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_269) + view_271 = torch.ops.aten.view.default(mul_63, [16384, 14336]); mul_63 = None + mm_563 = torch.ops.aten.mm.default(permute_1125, view_271); permute_1125 = view_271 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 64, '0'); convert_element_type_262 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_1127 = torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_564 = torch.ops.aten.mm.default(view_1671, permute_1127); view_1671 = permute_1127 = None + view_1672 = torch.ops.aten.view.default(mm_564, [2, 8192, 14336]); mm_564 = None + convert_element_type_2374 = torch.ops.prims.convert_element_type.default(mm_563, torch.float32); mm_563 = None + reduce_scatter_tensor_218 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2374, 'avg', 64, '0'); convert_element_type_2374 = None + wait_tensor_509 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_218); reduce_scatter_tensor_218 = None + mul_744 = torch.ops.aten.mul.Tensor(view_1672, convert_element_type_258); convert_element_type_258 = None + mul_745 = torch.ops.aten.mul.Tensor(view_1672, view_269); view_1672 = view_269 = None + view_1673 = torch.ops.aten.view.default(mul_744, [16384, 14336]); mul_744 = None + permute_1129 = torch.ops.aten.permute.default(view_1673, [1, 0]) + mm_565 = torch.ops.aten.mm.default(permute_1129, view_265); permute_1129 = None + permute_1131 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_566 = torch.ops.aten.mm.default(view_1673, permute_1131); view_1673 = permute_1131 = None + view_1674 = torch.ops.aten.view.default(mm_566, [2, 8192, 4096]); mm_566 = None + convert_element_type_2379 = torch.ops.prims.convert_element_type.default(mm_565, torch.float32); mm_565 = None + reduce_scatter_tensor_219 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2379, 'avg', 64, '0'); convert_element_type_2379 = None + wait_tensor_510 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_219); reduce_scatter_tensor_219 = None + convert_element_type_2380 = torch.ops.prims.convert_element_type.default(mul_745, torch.float32); mul_745 = None + neg_24 = torch.ops.aten.neg.default(convert_element_type_257) + exp_24 = torch.ops.aten.exp.default(neg_24); neg_24 = None + add_297 = torch.ops.aten.add.Tensor(exp_24, 1); exp_24 = None + reciprocal_24 = torch.ops.aten.reciprocal.default(add_297); add_297 = None + mul_746 = torch.ops.aten.mul.Tensor(reciprocal_24, 1); reciprocal_24 = None + mul_747 = torch.ops.aten.mul.Tensor(convert_element_type_2380, mul_746); convert_element_type_2380 = None + sub_73 = torch.ops.aten.sub.Tensor(1, mul_746); mul_746 = None + mul_748 = torch.ops.aten.mul.Tensor(convert_element_type_257, sub_73); convert_element_type_257 = sub_73 = None + add_298 = torch.ops.aten.add.Tensor(mul_748, 1); mul_748 = None + mul_749 = torch.ops.aten.mul.Tensor(mul_747, add_298); mul_747 = add_298 = None + convert_element_type_2382 = torch.ops.prims.convert_element_type.default(mul_749, torch.bfloat16); mul_749 = None + view_1675 = torch.ops.aten.view.default(convert_element_type_2382, [16384, 14336]); convert_element_type_2382 = None + permute_1133 = torch.ops.aten.permute.default(view_1675, [1, 0]) + mm_567 = torch.ops.aten.mm.default(permute_1133, view_265); permute_1133 = view_265 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 64, '0'); convert_element_type_254 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + permute_1135 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_568 = torch.ops.aten.mm.default(view_1675, permute_1135); view_1675 = permute_1135 = None + view_1676 = torch.ops.aten.view.default(mm_568, [2, 8192, 4096]); mm_568 = None + add_299 = torch.ops.aten.add.Tensor(view_1674, view_1676); view_1674 = view_1676 = None + convert_element_type_2387 = torch.ops.prims.convert_element_type.default(mm_567, torch.float32); mm_567 = None + reduce_scatter_tensor_220 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2387, 'avg', 64, '0'); convert_element_type_2387 = None + wait_tensor_511 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_220); reduce_scatter_tensor_220 = None + convert_element_type_2388 = torch.ops.prims.convert_element_type.default(add_299, torch.float32); add_299 = None + convert_element_type_2390 = torch.ops.prims.convert_element_type.default(wait_tensor_69, torch.float32); wait_tensor_69 = None + mul_750 = torch.ops.aten.mul.Tensor(convert_element_type_2388, convert_element_type_2390); convert_element_type_2390 = None + mul_752 = torch.ops.aten.mul.Tensor(mul_60, mul_750) + sum_147 = torch.ops.aten.sum.dim_IntList(mul_752, [2], True); mul_752 = None + div_49 = torch.ops.aten.div.Tensor(mul_60, 4096) + mul_753 = torch.ops.aten.mul.Tensor(div_49, sum_147); div_49 = sum_147 = None + sub_74 = torch.ops.aten.sub.Tensor(mul_750, mul_753); mul_750 = mul_753 = None + mul_754 = torch.ops.aten.mul.Tensor(sub_74, rsqrt_15); sub_74 = rsqrt_15 = None + mul_755 = torch.ops.aten.mul.Tensor(convert_element_type_2388, mul_60); convert_element_type_2388 = mul_60 = None + sum_148 = torch.ops.aten.sum.dim_IntList(mul_755, [0, 1]); mul_755 = None + convert_element_type_2391 = torch.ops.prims.convert_element_type.default(mul_754, torch.bfloat16); mul_754 = None + add_300 = torch.ops.aten.add.Tensor(add_296, convert_element_type_2391); add_296 = convert_element_type_2391 = None + convert_element_type_default_16 = torch.ops.prims.convert_element_type.default(sum_148, torch.float32); sum_148 = None + reduce_scatter_tensor_221 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_16, 'avg', 64, '0'); convert_element_type_default_16 = None + wait_tensor_512 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_221); reduce_scatter_tensor_221 = None + view_1677 = torch.ops.aten.view.default(add_300, [16384, 4096]) + permute_1137 = torch.ops.aten.permute.default(view_1677, [1, 0]) + mm_569 = torch.ops.aten.mm.default(permute_1137, view_261); permute_1137 = view_261 = None + permute_1139 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_570 = torch.ops.aten.mm.default(view_1677, permute_1139); view_1677 = permute_1139 = None + view_1678 = torch.ops.aten.view.default(mm_570, [2, 8192, 4096]); mm_570 = None + convert_element_type_2398 = torch.ops.prims.convert_element_type.default(mm_569, torch.float32); mm_569 = None + reduce_scatter_tensor_222 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2398, 'avg', 64, '0'); convert_element_type_2398 = None + wait_tensor_513 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_222); reduce_scatter_tensor_222 = None + view_1679 = torch.ops.aten.view.default(view_1678, [2, 8192, 32, 128]); view_1678 = None + permute_1141 = torch.ops.aten.permute.default(view_1679, [0, 2, 1, 3]); view_1679 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 64, '0'); convert_element_type_232 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32); add_27 = None + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_64) + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + view_241 = torch.ops.aten.view.default(convert_element_type_234, [16384, 4096]); convert_element_type_234 = None + view_242 = torch.ops.aten.view.default(mm_49, [2, 8192, 4096]); mm_49 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 64, '0'); convert_element_type_238 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_66, [1, 0]); wait_tensor_66 = None + mm_50 = torch.ops.aten.mm.default(view_241, permute_78) + view_245 = torch.ops.aten.view.default(mm_50, [2, 8192, 1024]); mm_50 = None + view_248 = torch.ops.aten.view.default(mm_51, [2, 8192, 1024]); mm_51 = None + view_249 = torch.ops.aten.view.default(view_242, [2, 8192, -1, 128]); view_242 = None + view_250 = torch.ops.aten.view.default(view_245, [2, 8192, -1, 128]); view_245 = None + view_251 = torch.ops.aten.view.default(view_248, [2, 8192, -1, 128]); view_248 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 32, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_250, torch.float32); view_250 = None + view_253 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 8, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_253); view_253 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_16); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_255 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 32, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_16); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_256 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 8, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_256, torch.bfloat16); view_256 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 8, 4, 128]); unsqueeze_14 = None + clone_14 = torch.ops.aten.clone.default(expand_14, memory_format = torch.contiguous_format); expand_14 = None + view_257 = torch.ops.aten.view.default(clone_14, [2, 8192, 32, 128]); clone_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_251, 3); view_251 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 8, 4, 128]); unsqueeze_15 = None + clone_15 = torch.ops.aten.clone.default(expand_15, memory_format = torch.contiguous_format); expand_15 = None + view_258 = torch.ops.aten.view.default(clone_15, [2, 8192, 32, 128]); clone_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + permute_82 = torch.ops.aten.permute.default(view_258, [0, 2, 1, 3]); view_258 = None + _scaled_dot_product_cudnn_attention_backward_24 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1141, permute_80, permute_81, permute_82, getitem_63, getitem_64, getitem_69, getitem_70, None, None, None, 8192, 8192, 0.0, True); permute_1141 = permute_80 = permute_81 = permute_82 = getitem_63 = getitem_64 = getitem_69 = getitem_70 = None + getitem_360 = _scaled_dot_product_cudnn_attention_backward_24[0] + getitem_361 = _scaled_dot_product_cudnn_attention_backward_24[1] + getitem_362 = _scaled_dot_product_cudnn_attention_backward_24[2]; _scaled_dot_product_cudnn_attention_backward_24 = None + permute_1142 = torch.ops.aten.permute.default(getitem_362, [0, 2, 1, 3]); getitem_362 = None + permute_1143 = torch.ops.aten.permute.default(getitem_361, [0, 2, 1, 3]); getitem_361 = None + permute_1144 = torch.ops.aten.permute.default(getitem_360, [0, 2, 1, 3]); getitem_360 = None + view_1680 = torch.ops.aten.view.default(permute_1142, [2, 8192, 8, 4, 128]); permute_1142 = None + sum_149 = torch.ops.aten.sum.dim_IntList(view_1680, [3], True); view_1680 = None + squeeze_48 = torch.ops.aten.squeeze.dim(sum_149, 3); sum_149 = None + view_1681 = torch.ops.aten.view.default(permute_1143, [2, 8192, 8, 4, 128]); permute_1143 = None + sum_150 = torch.ops.aten.sum.dim_IntList(view_1681, [3], True); view_1681 = None + squeeze_49 = torch.ops.aten.squeeze.dim(sum_150, 3); sum_150 = None + convert_element_type_2399 = torch.ops.prims.convert_element_type.default(squeeze_49, torch.float32); squeeze_49 = None + convert_element_type_2400 = torch.ops.prims.convert_element_type.default(permute_1144, torch.float32); permute_1144 = None + view_1682 = torch.ops.aten.view.default(convert_element_type_2399, [2, 8192, 8, 64, 2]); convert_element_type_2399 = None + view_as_complex_112 = torch.ops.aten.view_as_complex.default(view_1682); view_1682 = None + mul_756 = torch.ops.aten.mul.Tensor(view_as_complex_112, _conj); view_as_complex_112 = None + view_1683 = torch.ops.aten.view.default(convert_element_type_2400, [2, 8192, 32, 64, 2]); convert_element_type_2400 = None + view_as_complex_113 = torch.ops.aten.view_as_complex.default(view_1683); view_1683 = None + mul_757 = torch.ops.aten.mul.Tensor(view_as_complex_113, _conj); view_as_complex_113 = None + view_as_real_112 = torch.ops.aten.view_as_real.default(mul_756); mul_756 = None + view_1684 = torch.ops.aten.view.default(view_as_real_112, [2, 8192, 8, 128]); view_as_real_112 = None + convert_element_type_2401 = torch.ops.prims.convert_element_type.default(view_1684, torch.bfloat16); view_1684 = None + view_as_real_113 = torch.ops.aten.view_as_real.default(mul_757); mul_757 = None + view_1685 = torch.ops.aten.view.default(view_as_real_113, [2, 8192, 32, 128]); view_as_real_113 = None + convert_element_type_2402 = torch.ops.prims.convert_element_type.default(view_1685, torch.bfloat16); view_1685 = None + view_1686 = torch.ops.aten.view.default(squeeze_48, [2, 8192, 1024]); squeeze_48 = None + view_1687 = torch.ops.aten.view.default(convert_element_type_2401, [2, 8192, 1024]); convert_element_type_2401 = None + view_1688 = torch.ops.aten.view.default(convert_element_type_2402, [2, 8192, 4096]); convert_element_type_2402 = None + view_1689 = torch.ops.aten.view.default(view_1686, [16384, 1024]); view_1686 = None + permute_1145 = torch.ops.aten.permute.default(view_1689, [1, 0]) + mm_571 = torch.ops.aten.mm.default(permute_1145, view_241); permute_1145 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 64, '0'); convert_element_type_241 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_67, [1, 0]); wait_tensor_67 = None + permute_1147 = torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_572 = torch.ops.aten.mm.default(view_1689, permute_1147); view_1689 = permute_1147 = None + view_1690 = torch.ops.aten.view.default(mm_572, [2, 8192, 4096]); mm_572 = None + convert_element_type_2407 = torch.ops.prims.convert_element_type.default(mm_571, torch.float32); mm_571 = None + reduce_scatter_tensor_223 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2407, 'avg', 64, '0'); convert_element_type_2407 = None + wait_tensor_514 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_223); reduce_scatter_tensor_223 = None + view_1691 = torch.ops.aten.view.default(view_1687, [16384, 1024]); view_1687 = None + permute_1149 = torch.ops.aten.permute.default(view_1691, [1, 0]) + mm_573 = torch.ops.aten.mm.default(permute_1149, view_241); permute_1149 = None + permute_1151 = torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_574 = torch.ops.aten.mm.default(view_1691, permute_1151); view_1691 = permute_1151 = None + view_1692 = torch.ops.aten.view.default(mm_574, [2, 8192, 4096]); mm_574 = None + add_301 = torch.ops.aten.add.Tensor(view_1690, view_1692); view_1690 = view_1692 = None + convert_element_type_2412 = torch.ops.prims.convert_element_type.default(mm_573, torch.float32); mm_573 = None + reduce_scatter_tensor_224 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2412, 'avg', 64, '0'); convert_element_type_2412 = None + wait_tensor_515 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_224); reduce_scatter_tensor_224 = None + view_1693 = torch.ops.aten.view.default(view_1688, [16384, 4096]); view_1688 = None + permute_1153 = torch.ops.aten.permute.default(view_1693, [1, 0]) + mm_575 = torch.ops.aten.mm.default(permute_1153, view_241); permute_1153 = view_241 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 64, '0'); convert_element_type_235 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + permute_1155 = torch.ops.aten.permute.default(permute_77, [1, 0]); permute_77 = None + mm_576 = torch.ops.aten.mm.default(view_1693, permute_1155); view_1693 = permute_1155 = None + view_1694 = torch.ops.aten.view.default(mm_576, [2, 8192, 4096]); mm_576 = None + add_302 = torch.ops.aten.add.Tensor(add_301, view_1694); add_301 = view_1694 = None + convert_element_type_2417 = torch.ops.prims.convert_element_type.default(mm_575, torch.float32); mm_575 = None + reduce_scatter_tensor_225 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2417, 'avg', 64, '0'); convert_element_type_2417 = None + wait_tensor_516 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_225); reduce_scatter_tensor_225 = None + convert_element_type_2418 = torch.ops.prims.convert_element_type.default(add_302, torch.float32); add_302 = None + convert_element_type_2420 = torch.ops.prims.convert_element_type.default(wait_tensor_64, torch.float32); wait_tensor_64 = None + mul_758 = torch.ops.aten.mul.Tensor(convert_element_type_2418, convert_element_type_2420); convert_element_type_2420 = None + mul_760 = torch.ops.aten.mul.Tensor(mul_56, mul_758) + sum_151 = torch.ops.aten.sum.dim_IntList(mul_760, [2], True); mul_760 = None + div_50 = torch.ops.aten.div.Tensor(mul_56, 4096) + mul_761 = torch.ops.aten.mul.Tensor(div_50, sum_151); div_50 = sum_151 = None + sub_75 = torch.ops.aten.sub.Tensor(mul_758, mul_761); mul_758 = mul_761 = None + mul_762 = torch.ops.aten.mul.Tensor(sub_75, rsqrt_14); sub_75 = rsqrt_14 = None + mul_763 = torch.ops.aten.mul.Tensor(convert_element_type_2418, mul_56); convert_element_type_2418 = mul_56 = None + sum_152 = torch.ops.aten.sum.dim_IntList(mul_763, [0, 1]); mul_763 = None + convert_element_type_2421 = torch.ops.prims.convert_element_type.default(mul_762, torch.bfloat16); mul_762 = None + add_303 = torch.ops.aten.add.Tensor(add_300, convert_element_type_2421); add_300 = convert_element_type_2421 = None + convert_element_type_default_15 = torch.ops.prims.convert_element_type.default(sum_152, torch.float32); sum_152 = None + reduce_scatter_tensor_226 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_15, 'avg', 64, '0'); convert_element_type_default_15 = None + wait_tensor_517 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_226); reduce_scatter_tensor_226 = None + view_1695 = torch.ops.aten.view.default(add_303, [16384, 4096]) + permute_1157 = torch.ops.aten.permute.default(view_1695, [1, 0]) + permute_72 = torch.ops.aten.permute.default(getitem_54, [0, 2, 1, 3]) + view_225 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16); primals_62 = None + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 64, '0'); convert_element_type_215 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_227 = torch.ops.aten.view.default(view_225, [16384, 4096]); view_225 = None + mm_45 = torch.ops.aten.mm.default(view_227, permute_73) + view_228 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + add_25 = torch.ops.aten.add.Tensor(add_23, view_228); view_228 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 64, '0'); convert_element_type_218 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32); add_25 = None + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_60) + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + view_231 = torch.ops.aten.view.default(convert_element_type_220, [16384, 4096]); convert_element_type_220 = None + view_232 = torch.ops.aten.view.default(mm_46, [2, 8192, 14336]); mm_46 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_232, torch.float32); view_232 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 64, '0'); convert_element_type_226 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_62, [1, 0]); wait_tensor_62 = None + mm_47 = torch.ops.aten.mm.default(view_231, permute_75) + view_235 = torch.ops.aten.view.default(mm_47, [2, 8192, 14336]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_235) + view_237 = torch.ops.aten.view.default(mul_55, [16384, 14336]); mul_55 = None + mm_577 = torch.ops.aten.mm.default(permute_1157, view_237); permute_1157 = view_237 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 64, '0'); convert_element_type_229 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + permute_1159 = torch.ops.aten.permute.default(permute_76, [1, 0]); permute_76 = None + mm_578 = torch.ops.aten.mm.default(view_1695, permute_1159); view_1695 = permute_1159 = None + view_1696 = torch.ops.aten.view.default(mm_578, [2, 8192, 14336]); mm_578 = None + convert_element_type_2428 = torch.ops.prims.convert_element_type.default(mm_577, torch.float32); mm_577 = None + reduce_scatter_tensor_227 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2428, 'avg', 64, '0'); convert_element_type_2428 = None + wait_tensor_518 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_227); reduce_scatter_tensor_227 = None + mul_764 = torch.ops.aten.mul.Tensor(view_1696, convert_element_type_225); convert_element_type_225 = None + mul_765 = torch.ops.aten.mul.Tensor(view_1696, view_235); view_1696 = view_235 = None + view_1697 = torch.ops.aten.view.default(mul_764, [16384, 14336]); mul_764 = None + permute_1161 = torch.ops.aten.permute.default(view_1697, [1, 0]) + mm_579 = torch.ops.aten.mm.default(permute_1161, view_231); permute_1161 = None + permute_1163 = torch.ops.aten.permute.default(permute_75, [1, 0]); permute_75 = None + mm_580 = torch.ops.aten.mm.default(view_1697, permute_1163); view_1697 = permute_1163 = None + view_1698 = torch.ops.aten.view.default(mm_580, [2, 8192, 4096]); mm_580 = None + convert_element_type_2433 = torch.ops.prims.convert_element_type.default(mm_579, torch.float32); mm_579 = None + reduce_scatter_tensor_228 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2433, 'avg', 64, '0'); convert_element_type_2433 = None + wait_tensor_519 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_228); reduce_scatter_tensor_228 = None + convert_element_type_2434 = torch.ops.prims.convert_element_type.default(mul_765, torch.float32); mul_765 = None + neg_25 = torch.ops.aten.neg.default(convert_element_type_224) + exp_25 = torch.ops.aten.exp.default(neg_25); neg_25 = None + add_304 = torch.ops.aten.add.Tensor(exp_25, 1); exp_25 = None + reciprocal_25 = torch.ops.aten.reciprocal.default(add_304); add_304 = None + mul_766 = torch.ops.aten.mul.Tensor(reciprocal_25, 1); reciprocal_25 = None + mul_767 = torch.ops.aten.mul.Tensor(convert_element_type_2434, mul_766); convert_element_type_2434 = None + sub_76 = torch.ops.aten.sub.Tensor(1, mul_766); mul_766 = None + mul_768 = torch.ops.aten.mul.Tensor(convert_element_type_224, sub_76); convert_element_type_224 = sub_76 = None + add_305 = torch.ops.aten.add.Tensor(mul_768, 1); mul_768 = None + mul_769 = torch.ops.aten.mul.Tensor(mul_767, add_305); mul_767 = add_305 = None + convert_element_type_2436 = torch.ops.prims.convert_element_type.default(mul_769, torch.bfloat16); mul_769 = None + view_1699 = torch.ops.aten.view.default(convert_element_type_2436, [16384, 14336]); convert_element_type_2436 = None + permute_1165 = torch.ops.aten.permute.default(view_1699, [1, 0]) + mm_581 = torch.ops.aten.mm.default(permute_1165, view_231); permute_1165 = view_231 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 64, '0'); convert_element_type_221 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_61, [1, 0]); wait_tensor_61 = None + permute_1167 = torch.ops.aten.permute.default(permute_74, [1, 0]); permute_74 = None + mm_582 = torch.ops.aten.mm.default(view_1699, permute_1167); view_1699 = permute_1167 = None + view_1700 = torch.ops.aten.view.default(mm_582, [2, 8192, 4096]); mm_582 = None + add_306 = torch.ops.aten.add.Tensor(view_1698, view_1700); view_1698 = view_1700 = None + convert_element_type_2441 = torch.ops.prims.convert_element_type.default(mm_581, torch.float32); mm_581 = None + reduce_scatter_tensor_229 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2441, 'avg', 64, '0'); convert_element_type_2441 = None + wait_tensor_520 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_229); reduce_scatter_tensor_229 = None + convert_element_type_2442 = torch.ops.prims.convert_element_type.default(add_306, torch.float32); add_306 = None + convert_element_type_2444 = torch.ops.prims.convert_element_type.default(wait_tensor_60, torch.float32); wait_tensor_60 = None + mul_770 = torch.ops.aten.mul.Tensor(convert_element_type_2442, convert_element_type_2444); convert_element_type_2444 = None + mul_772 = torch.ops.aten.mul.Tensor(mul_52, mul_770) + sum_153 = torch.ops.aten.sum.dim_IntList(mul_772, [2], True); mul_772 = None + div_51 = torch.ops.aten.div.Tensor(mul_52, 4096) + mul_773 = torch.ops.aten.mul.Tensor(div_51, sum_153); div_51 = sum_153 = None + sub_77 = torch.ops.aten.sub.Tensor(mul_770, mul_773); mul_770 = mul_773 = None + mul_774 = torch.ops.aten.mul.Tensor(sub_77, rsqrt_13); sub_77 = rsqrt_13 = None + mul_775 = torch.ops.aten.mul.Tensor(convert_element_type_2442, mul_52); convert_element_type_2442 = mul_52 = None + sum_154 = torch.ops.aten.sum.dim_IntList(mul_775, [0, 1]); mul_775 = None + convert_element_type_2445 = torch.ops.prims.convert_element_type.default(mul_774, torch.bfloat16); mul_774 = None + add_307 = torch.ops.aten.add.Tensor(add_303, convert_element_type_2445); add_303 = convert_element_type_2445 = None + convert_element_type_default_14 = torch.ops.prims.convert_element_type.default(sum_154, torch.float32); sum_154 = None + reduce_scatter_tensor_230 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_14, 'avg', 64, '0'); convert_element_type_default_14 = None + wait_tensor_521 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_230); reduce_scatter_tensor_230 = None + view_1701 = torch.ops.aten.view.default(add_307, [16384, 4096]) + permute_1169 = torch.ops.aten.permute.default(view_1701, [1, 0]) + mm_583 = torch.ops.aten.mm.default(permute_1169, view_227); permute_1169 = view_227 = None + permute_1171 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_584 = torch.ops.aten.mm.default(view_1701, permute_1171); view_1701 = permute_1171 = None + view_1702 = torch.ops.aten.view.default(mm_584, [2, 8192, 4096]); mm_584 = None + convert_element_type_2452 = torch.ops.prims.convert_element_type.default(mm_583, torch.float32); mm_583 = None + reduce_scatter_tensor_231 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2452, 'avg', 64, '0'); convert_element_type_2452 = None + wait_tensor_522 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_231); reduce_scatter_tensor_231 = None + view_1703 = torch.ops.aten.view.default(view_1702, [2, 8192, 32, 128]); view_1702 = None + permute_1173 = torch.ops.aten.permute.default(view_1703, [0, 2, 1, 3]); view_1703 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 64, '0'); convert_element_type_199 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32); add_23 = None + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_55) + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + view_207 = torch.ops.aten.view.default(convert_element_type_201, [16384, 4096]); convert_element_type_201 = None + view_208 = torch.ops.aten.view.default(mm_42, [2, 8192, 4096]); mm_42 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 64, '0'); convert_element_type_205 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_43 = torch.ops.aten.mm.default(view_207, permute_67) + view_211 = torch.ops.aten.view.default(mm_43, [2, 8192, 1024]); mm_43 = None + view_214 = torch.ops.aten.view.default(mm_44, [2, 8192, 1024]); mm_44 = None + view_215 = torch.ops.aten.view.default(view_208, [2, 8192, -1, 128]); view_208 = None + view_216 = torch.ops.aten.view.default(view_211, [2, 8192, -1, 128]); view_211 = None + view_217 = torch.ops.aten.view.default(view_214, [2, 8192, -1, 128]); view_214 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_215, torch.float32); view_215 = None + view_218 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 32, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_218); view_218 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_216, torch.float32); view_216 = None + view_219 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 8, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_219); view_219 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_16); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_221 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 32, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_16); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_222 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 8, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_221, torch.bfloat16); view_221 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_222, torch.bfloat16); view_222 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 8, 4, 128]); unsqueeze_12 = None + clone_12 = torch.ops.aten.clone.default(expand_12, memory_format = torch.contiguous_format); expand_12 = None + view_223 = torch.ops.aten.view.default(clone_12, [2, 8192, 32, 128]); clone_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_217, 3); view_217 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 8, 4, 128]); unsqueeze_13 = None + clone_13 = torch.ops.aten.clone.default(expand_13, memory_format = torch.contiguous_format); expand_13 = None + view_224 = torch.ops.aten.view.default(clone_13, [2, 8192, 32, 128]); clone_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_223, [0, 2, 1, 3]); view_223 = None + permute_71 = torch.ops.aten.permute.default(view_224, [0, 2, 1, 3]); view_224 = None + _scaled_dot_product_cudnn_attention_backward_25 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1173, permute_69, permute_70, permute_71, getitem_54, getitem_55, getitem_60, getitem_61, None, None, None, 8192, 8192, 0.0, True); permute_1173 = permute_69 = permute_70 = permute_71 = getitem_54 = getitem_55 = getitem_60 = getitem_61 = None + getitem_363 = _scaled_dot_product_cudnn_attention_backward_25[0] + getitem_364 = _scaled_dot_product_cudnn_attention_backward_25[1] + getitem_365 = _scaled_dot_product_cudnn_attention_backward_25[2]; _scaled_dot_product_cudnn_attention_backward_25 = None + permute_1174 = torch.ops.aten.permute.default(getitem_365, [0, 2, 1, 3]); getitem_365 = None + permute_1175 = torch.ops.aten.permute.default(getitem_364, [0, 2, 1, 3]); getitem_364 = None + permute_1176 = torch.ops.aten.permute.default(getitem_363, [0, 2, 1, 3]); getitem_363 = None + view_1704 = torch.ops.aten.view.default(permute_1174, [2, 8192, 8, 4, 128]); permute_1174 = None + sum_155 = torch.ops.aten.sum.dim_IntList(view_1704, [3], True); view_1704 = None + squeeze_50 = torch.ops.aten.squeeze.dim(sum_155, 3); sum_155 = None + view_1705 = torch.ops.aten.view.default(permute_1175, [2, 8192, 8, 4, 128]); permute_1175 = None + sum_156 = torch.ops.aten.sum.dim_IntList(view_1705, [3], True); view_1705 = None + squeeze_51 = torch.ops.aten.squeeze.dim(sum_156, 3); sum_156 = None + convert_element_type_2453 = torch.ops.prims.convert_element_type.default(squeeze_51, torch.float32); squeeze_51 = None + convert_element_type_2454 = torch.ops.prims.convert_element_type.default(permute_1176, torch.float32); permute_1176 = None + view_1706 = torch.ops.aten.view.default(convert_element_type_2453, [2, 8192, 8, 64, 2]); convert_element_type_2453 = None + view_as_complex_114 = torch.ops.aten.view_as_complex.default(view_1706); view_1706 = None + mul_776 = torch.ops.aten.mul.Tensor(view_as_complex_114, _conj); view_as_complex_114 = None + view_1707 = torch.ops.aten.view.default(convert_element_type_2454, [2, 8192, 32, 64, 2]); convert_element_type_2454 = None + view_as_complex_115 = torch.ops.aten.view_as_complex.default(view_1707); view_1707 = None + mul_777 = torch.ops.aten.mul.Tensor(view_as_complex_115, _conj); view_as_complex_115 = None + view_as_real_114 = torch.ops.aten.view_as_real.default(mul_776); mul_776 = None + view_1708 = torch.ops.aten.view.default(view_as_real_114, [2, 8192, 8, 128]); view_as_real_114 = None + convert_element_type_2455 = torch.ops.prims.convert_element_type.default(view_1708, torch.bfloat16); view_1708 = None + view_as_real_115 = torch.ops.aten.view_as_real.default(mul_777); mul_777 = None + view_1709 = torch.ops.aten.view.default(view_as_real_115, [2, 8192, 32, 128]); view_as_real_115 = None + convert_element_type_2456 = torch.ops.prims.convert_element_type.default(view_1709, torch.bfloat16); view_1709 = None + view_1710 = torch.ops.aten.view.default(squeeze_50, [2, 8192, 1024]); squeeze_50 = None + view_1711 = torch.ops.aten.view.default(convert_element_type_2455, [2, 8192, 1024]); convert_element_type_2455 = None + view_1712 = torch.ops.aten.view.default(convert_element_type_2456, [2, 8192, 4096]); convert_element_type_2456 = None + view_1713 = torch.ops.aten.view.default(view_1710, [16384, 1024]); view_1710 = None + permute_1177 = torch.ops.aten.permute.default(view_1713, [1, 0]) + mm_585 = torch.ops.aten.mm.default(permute_1177, view_207); permute_1177 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 64, '0'); convert_element_type_208 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_1179 = torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_586 = torch.ops.aten.mm.default(view_1713, permute_1179); view_1713 = permute_1179 = None + view_1714 = torch.ops.aten.view.default(mm_586, [2, 8192, 4096]); mm_586 = None + convert_element_type_2461 = torch.ops.prims.convert_element_type.default(mm_585, torch.float32); mm_585 = None + reduce_scatter_tensor_232 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2461, 'avg', 64, '0'); convert_element_type_2461 = None + wait_tensor_523 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_232); reduce_scatter_tensor_232 = None + view_1715 = torch.ops.aten.view.default(view_1711, [16384, 1024]); view_1711 = None + permute_1181 = torch.ops.aten.permute.default(view_1715, [1, 0]) + mm_587 = torch.ops.aten.mm.default(permute_1181, view_207); permute_1181 = None + permute_1183 = torch.ops.aten.permute.default(permute_67, [1, 0]); permute_67 = None + mm_588 = torch.ops.aten.mm.default(view_1715, permute_1183); view_1715 = permute_1183 = None + view_1716 = torch.ops.aten.view.default(mm_588, [2, 8192, 4096]); mm_588 = None + add_308 = torch.ops.aten.add.Tensor(view_1714, view_1716); view_1714 = view_1716 = None + convert_element_type_2466 = torch.ops.prims.convert_element_type.default(mm_587, torch.float32); mm_587 = None + reduce_scatter_tensor_233 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2466, 'avg', 64, '0'); convert_element_type_2466 = None + wait_tensor_524 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_233); reduce_scatter_tensor_233 = None + view_1717 = torch.ops.aten.view.default(view_1712, [16384, 4096]); view_1712 = None + permute_1185 = torch.ops.aten.permute.default(view_1717, [1, 0]) + mm_589 = torch.ops.aten.mm.default(permute_1185, view_207); permute_1185 = view_207 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); primals_59 = None + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 64, '0'); convert_element_type_202 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + permute_1187 = torch.ops.aten.permute.default(permute_66, [1, 0]); permute_66 = None + mm_590 = torch.ops.aten.mm.default(view_1717, permute_1187); view_1717 = permute_1187 = None + view_1718 = torch.ops.aten.view.default(mm_590, [2, 8192, 4096]); mm_590 = None + add_309 = torch.ops.aten.add.Tensor(add_308, view_1718); add_308 = view_1718 = None + convert_element_type_2471 = torch.ops.prims.convert_element_type.default(mm_589, torch.float32); mm_589 = None + reduce_scatter_tensor_234 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2471, 'avg', 64, '0'); convert_element_type_2471 = None + wait_tensor_525 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_234); reduce_scatter_tensor_234 = None + convert_element_type_2472 = torch.ops.prims.convert_element_type.default(add_309, torch.float32); add_309 = None + convert_element_type_2474 = torch.ops.prims.convert_element_type.default(wait_tensor_55, torch.float32); wait_tensor_55 = None + mul_778 = torch.ops.aten.mul.Tensor(convert_element_type_2472, convert_element_type_2474); convert_element_type_2474 = None + mul_780 = torch.ops.aten.mul.Tensor(mul_48, mul_778) + sum_157 = torch.ops.aten.sum.dim_IntList(mul_780, [2], True); mul_780 = None + div_52 = torch.ops.aten.div.Tensor(mul_48, 4096) + mul_781 = torch.ops.aten.mul.Tensor(div_52, sum_157); div_52 = sum_157 = None + sub_78 = torch.ops.aten.sub.Tensor(mul_778, mul_781); mul_778 = mul_781 = None + mul_782 = torch.ops.aten.mul.Tensor(sub_78, rsqrt_12); sub_78 = rsqrt_12 = None + mul_783 = torch.ops.aten.mul.Tensor(convert_element_type_2472, mul_48); convert_element_type_2472 = mul_48 = None + sum_158 = torch.ops.aten.sum.dim_IntList(mul_783, [0, 1]); mul_783 = None + convert_element_type_2475 = torch.ops.prims.convert_element_type.default(mul_782, torch.bfloat16); mul_782 = None + add_310 = torch.ops.aten.add.Tensor(add_307, convert_element_type_2475); add_307 = convert_element_type_2475 = None + convert_element_type_default_13 = torch.ops.prims.convert_element_type.default(sum_158, torch.float32); sum_158 = None + reduce_scatter_tensor_235 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_13, 'avg', 64, '0'); convert_element_type_default_13 = None + wait_tensor_526 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_235); reduce_scatter_tensor_235 = None + view_1719 = torch.ops.aten.view.default(add_310, [16384, 4096]) + permute_1189 = torch.ops.aten.permute.default(view_1719, [1, 0]) + permute_61 = torch.ops.aten.permute.default(getitem_45, [0, 2, 1, 3]) + view_191 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 64, '0'); convert_element_type_182 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_193 = torch.ops.aten.view.default(view_191, [16384, 4096]); view_191 = None + mm_38 = torch.ops.aten.mm.default(view_193, permute_62) + view_194 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + add_21 = torch.ops.aten.add.Tensor(add_19, view_194); view_194 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 64, '0'); convert_element_type_185 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32); add_21 = None + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_51) + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + view_197 = torch.ops.aten.view.default(convert_element_type_187, [16384, 4096]); convert_element_type_187 = None + view_198 = torch.ops.aten.view.default(mm_39, [2, 8192, 14336]); mm_39 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_198, torch.float32); view_198 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 64, '0'); convert_element_type_193 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_53, [1, 0]); wait_tensor_53 = None + mm_40 = torch.ops.aten.mm.default(view_197, permute_64) + view_201 = torch.ops.aten.view.default(mm_40, [2, 8192, 14336]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_201) + view_203 = torch.ops.aten.view.default(mul_47, [16384, 14336]); mul_47 = None + mm_591 = torch.ops.aten.mm.default(permute_1189, view_203); permute_1189 = view_203 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 64, '0'); convert_element_type_196 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + permute_1191 = torch.ops.aten.permute.default(permute_65, [1, 0]); permute_65 = None + mm_592 = torch.ops.aten.mm.default(view_1719, permute_1191); view_1719 = permute_1191 = None + view_1720 = torch.ops.aten.view.default(mm_592, [2, 8192, 14336]); mm_592 = None + convert_element_type_2482 = torch.ops.prims.convert_element_type.default(mm_591, torch.float32); mm_591 = None + reduce_scatter_tensor_236 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2482, 'avg', 64, '0'); convert_element_type_2482 = None + wait_tensor_527 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_236); reduce_scatter_tensor_236 = None + mul_784 = torch.ops.aten.mul.Tensor(view_1720, convert_element_type_192); convert_element_type_192 = None + mul_785 = torch.ops.aten.mul.Tensor(view_1720, view_201); view_1720 = view_201 = None + view_1721 = torch.ops.aten.view.default(mul_784, [16384, 14336]); mul_784 = None + permute_1193 = torch.ops.aten.permute.default(view_1721, [1, 0]) + mm_593 = torch.ops.aten.mm.default(permute_1193, view_197); permute_1193 = None + permute_1195 = torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_594 = torch.ops.aten.mm.default(view_1721, permute_1195); view_1721 = permute_1195 = None + view_1722 = torch.ops.aten.view.default(mm_594, [2, 8192, 4096]); mm_594 = None + convert_element_type_2487 = torch.ops.prims.convert_element_type.default(mm_593, torch.float32); mm_593 = None + reduce_scatter_tensor_237 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2487, 'avg', 64, '0'); convert_element_type_2487 = None + wait_tensor_528 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_237); reduce_scatter_tensor_237 = None + convert_element_type_2488 = torch.ops.prims.convert_element_type.default(mul_785, torch.float32); mul_785 = None + neg_26 = torch.ops.aten.neg.default(convert_element_type_191) + exp_26 = torch.ops.aten.exp.default(neg_26); neg_26 = None + add_311 = torch.ops.aten.add.Tensor(exp_26, 1); exp_26 = None + reciprocal_26 = torch.ops.aten.reciprocal.default(add_311); add_311 = None + mul_786 = torch.ops.aten.mul.Tensor(reciprocal_26, 1); reciprocal_26 = None + mul_787 = torch.ops.aten.mul.Tensor(convert_element_type_2488, mul_786); convert_element_type_2488 = None + sub_79 = torch.ops.aten.sub.Tensor(1, mul_786); mul_786 = None + mul_788 = torch.ops.aten.mul.Tensor(convert_element_type_191, sub_79); convert_element_type_191 = sub_79 = None + add_312 = torch.ops.aten.add.Tensor(mul_788, 1); mul_788 = None + mul_789 = torch.ops.aten.mul.Tensor(mul_787, add_312); mul_787 = add_312 = None + convert_element_type_2490 = torch.ops.prims.convert_element_type.default(mul_789, torch.bfloat16); mul_789 = None + view_1723 = torch.ops.aten.view.default(convert_element_type_2490, [16384, 14336]); convert_element_type_2490 = None + permute_1197 = torch.ops.aten.permute.default(view_1723, [1, 0]) + mm_595 = torch.ops.aten.mm.default(permute_1197, view_197); permute_1197 = view_197 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 64, '0'); convert_element_type_188 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_1199 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_596 = torch.ops.aten.mm.default(view_1723, permute_1199); view_1723 = permute_1199 = None + view_1724 = torch.ops.aten.view.default(mm_596, [2, 8192, 4096]); mm_596 = None + add_313 = torch.ops.aten.add.Tensor(view_1722, view_1724); view_1722 = view_1724 = None + convert_element_type_2495 = torch.ops.prims.convert_element_type.default(mm_595, torch.float32); mm_595 = None + reduce_scatter_tensor_238 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2495, 'avg', 64, '0'); convert_element_type_2495 = None + wait_tensor_529 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_238); reduce_scatter_tensor_238 = None + convert_element_type_2496 = torch.ops.prims.convert_element_type.default(add_313, torch.float32); add_313 = None + convert_element_type_2498 = torch.ops.prims.convert_element_type.default(wait_tensor_51, torch.float32); wait_tensor_51 = None + mul_790 = torch.ops.aten.mul.Tensor(convert_element_type_2496, convert_element_type_2498); convert_element_type_2498 = None + mul_792 = torch.ops.aten.mul.Tensor(mul_44, mul_790) + sum_159 = torch.ops.aten.sum.dim_IntList(mul_792, [2], True); mul_792 = None + div_53 = torch.ops.aten.div.Tensor(mul_44, 4096) + mul_793 = torch.ops.aten.mul.Tensor(div_53, sum_159); div_53 = sum_159 = None + sub_80 = torch.ops.aten.sub.Tensor(mul_790, mul_793); mul_790 = mul_793 = None + mul_794 = torch.ops.aten.mul.Tensor(sub_80, rsqrt_11); sub_80 = rsqrt_11 = None + mul_795 = torch.ops.aten.mul.Tensor(convert_element_type_2496, mul_44); convert_element_type_2496 = mul_44 = None + sum_160 = torch.ops.aten.sum.dim_IntList(mul_795, [0, 1]); mul_795 = None + convert_element_type_2499 = torch.ops.prims.convert_element_type.default(mul_794, torch.bfloat16); mul_794 = None + add_314 = torch.ops.aten.add.Tensor(add_310, convert_element_type_2499); add_310 = convert_element_type_2499 = None + convert_element_type_default_12 = torch.ops.prims.convert_element_type.default(sum_160, torch.float32); sum_160 = None + reduce_scatter_tensor_239 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_12, 'avg', 64, '0'); convert_element_type_default_12 = None + wait_tensor_530 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_239); reduce_scatter_tensor_239 = None + view_1725 = torch.ops.aten.view.default(add_314, [16384, 4096]) + permute_1201 = torch.ops.aten.permute.default(view_1725, [1, 0]) + mm_597 = torch.ops.aten.mm.default(permute_1201, view_193); permute_1201 = view_193 = None + permute_1203 = torch.ops.aten.permute.default(permute_62, [1, 0]); permute_62 = None + mm_598 = torch.ops.aten.mm.default(view_1725, permute_1203); view_1725 = permute_1203 = None + view_1726 = torch.ops.aten.view.default(mm_598, [2, 8192, 4096]); mm_598 = None + convert_element_type_2506 = torch.ops.prims.convert_element_type.default(mm_597, torch.float32); mm_597 = None + reduce_scatter_tensor_240 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2506, 'avg', 64, '0'); convert_element_type_2506 = None + wait_tensor_531 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_240); reduce_scatter_tensor_240 = None + view_1727 = torch.ops.aten.view.default(view_1726, [2, 8192, 32, 128]); view_1726 = None + permute_1205 = torch.ops.aten.permute.default(view_1727, [0, 2, 1, 3]); view_1727 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 64, '0'); convert_element_type_166 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32); add_19 = None + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_46) + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + view_173 = torch.ops.aten.view.default(convert_element_type_168, [16384, 4096]); convert_element_type_168 = None + view_174 = torch.ops.aten.view.default(mm_35, [2, 8192, 4096]); mm_35 = None + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 64, '0'); convert_element_type_172 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_48, [1, 0]); wait_tensor_48 = None + mm_36 = torch.ops.aten.mm.default(view_173, permute_56) + view_177 = torch.ops.aten.view.default(mm_36, [2, 8192, 1024]); mm_36 = None + view_180 = torch.ops.aten.view.default(mm_37, [2, 8192, 1024]); mm_37 = None + view_181 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + view_182 = torch.ops.aten.view.default(view_177, [2, 8192, -1, 128]); view_177 = None + view_183 = torch.ops.aten.view.default(view_180, [2, 8192, -1, 128]); view_180 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_181, torch.float32); view_181 = None + view_184 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 32, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_184); view_184 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_182, torch.float32); view_182 = None + view_185 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 8, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_185); view_185 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_16); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_187 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 32, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_16); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_188 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 8, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_187, torch.bfloat16); view_187 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_188, torch.bfloat16); view_188 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 8, 4, 128]); unsqueeze_10 = None + clone_10 = torch.ops.aten.clone.default(expand_10, memory_format = torch.contiguous_format); expand_10 = None + view_189 = torch.ops.aten.view.default(clone_10, [2, 8192, 32, 128]); clone_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_183, 3); view_183 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 8, 4, 128]); unsqueeze_11 = None + clone_11 = torch.ops.aten.clone.default(expand_11, memory_format = torch.contiguous_format); expand_11 = None + view_190 = torch.ops.aten.view.default(clone_11, [2, 8192, 32, 128]); clone_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_189, [0, 2, 1, 3]); view_189 = None + permute_60 = torch.ops.aten.permute.default(view_190, [0, 2, 1, 3]); view_190 = None + _scaled_dot_product_cudnn_attention_backward_26 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1205, permute_58, permute_59, permute_60, getitem_45, getitem_46, getitem_51, getitem_52, None, None, None, 8192, 8192, 0.0, True); permute_1205 = permute_58 = permute_59 = permute_60 = getitem_45 = getitem_46 = getitem_51 = getitem_52 = None + getitem_366 = _scaled_dot_product_cudnn_attention_backward_26[0] + getitem_367 = _scaled_dot_product_cudnn_attention_backward_26[1] + getitem_368 = _scaled_dot_product_cudnn_attention_backward_26[2]; _scaled_dot_product_cudnn_attention_backward_26 = None + permute_1206 = torch.ops.aten.permute.default(getitem_368, [0, 2, 1, 3]); getitem_368 = None + permute_1207 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]); getitem_367 = None + permute_1208 = torch.ops.aten.permute.default(getitem_366, [0, 2, 1, 3]); getitem_366 = None + view_1728 = torch.ops.aten.view.default(permute_1206, [2, 8192, 8, 4, 128]); permute_1206 = None + sum_161 = torch.ops.aten.sum.dim_IntList(view_1728, [3], True); view_1728 = None + squeeze_52 = torch.ops.aten.squeeze.dim(sum_161, 3); sum_161 = None + view_1729 = torch.ops.aten.view.default(permute_1207, [2, 8192, 8, 4, 128]); permute_1207 = None + sum_162 = torch.ops.aten.sum.dim_IntList(view_1729, [3], True); view_1729 = None + squeeze_53 = torch.ops.aten.squeeze.dim(sum_162, 3); sum_162 = None + convert_element_type_2507 = torch.ops.prims.convert_element_type.default(squeeze_53, torch.float32); squeeze_53 = None + convert_element_type_2508 = torch.ops.prims.convert_element_type.default(permute_1208, torch.float32); permute_1208 = None + view_1730 = torch.ops.aten.view.default(convert_element_type_2507, [2, 8192, 8, 64, 2]); convert_element_type_2507 = None + view_as_complex_116 = torch.ops.aten.view_as_complex.default(view_1730); view_1730 = None + mul_796 = torch.ops.aten.mul.Tensor(view_as_complex_116, _conj); view_as_complex_116 = None + view_1731 = torch.ops.aten.view.default(convert_element_type_2508, [2, 8192, 32, 64, 2]); convert_element_type_2508 = None + view_as_complex_117 = torch.ops.aten.view_as_complex.default(view_1731); view_1731 = None + mul_797 = torch.ops.aten.mul.Tensor(view_as_complex_117, _conj); view_as_complex_117 = None + view_as_real_116 = torch.ops.aten.view_as_real.default(mul_796); mul_796 = None + view_1732 = torch.ops.aten.view.default(view_as_real_116, [2, 8192, 8, 128]); view_as_real_116 = None + convert_element_type_2509 = torch.ops.prims.convert_element_type.default(view_1732, torch.bfloat16); view_1732 = None + view_as_real_117 = torch.ops.aten.view_as_real.default(mul_797); mul_797 = None + view_1733 = torch.ops.aten.view.default(view_as_real_117, [2, 8192, 32, 128]); view_as_real_117 = None + convert_element_type_2510 = torch.ops.prims.convert_element_type.default(view_1733, torch.bfloat16); view_1733 = None + view_1734 = torch.ops.aten.view.default(squeeze_52, [2, 8192, 1024]); squeeze_52 = None + view_1735 = torch.ops.aten.view.default(convert_element_type_2509, [2, 8192, 1024]); convert_element_type_2509 = None + view_1736 = torch.ops.aten.view.default(convert_element_type_2510, [2, 8192, 4096]); convert_element_type_2510 = None + view_1737 = torch.ops.aten.view.default(view_1734, [16384, 1024]); view_1734 = None + permute_1209 = torch.ops.aten.permute.default(view_1737, [1, 0]) + mm_599 = torch.ops.aten.mm.default(permute_1209, view_173); permute_1209 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 64, '0'); convert_element_type_175 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_49, [1, 0]); wait_tensor_49 = None + permute_1211 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_600 = torch.ops.aten.mm.default(view_1737, permute_1211); view_1737 = permute_1211 = None + view_1738 = torch.ops.aten.view.default(mm_600, [2, 8192, 4096]); mm_600 = None + convert_element_type_2515 = torch.ops.prims.convert_element_type.default(mm_599, torch.float32); mm_599 = None + reduce_scatter_tensor_241 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2515, 'avg', 64, '0'); convert_element_type_2515 = None + wait_tensor_532 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_241); reduce_scatter_tensor_241 = None + view_1739 = torch.ops.aten.view.default(view_1735, [16384, 1024]); view_1735 = None + permute_1213 = torch.ops.aten.permute.default(view_1739, [1, 0]) + mm_601 = torch.ops.aten.mm.default(permute_1213, view_173); permute_1213 = None + permute_1215 = torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_602 = torch.ops.aten.mm.default(view_1739, permute_1215); view_1739 = permute_1215 = None + view_1740 = torch.ops.aten.view.default(mm_602, [2, 8192, 4096]); mm_602 = None + add_315 = torch.ops.aten.add.Tensor(view_1738, view_1740); view_1738 = view_1740 = None + convert_element_type_2520 = torch.ops.prims.convert_element_type.default(mm_601, torch.float32); mm_601 = None + reduce_scatter_tensor_242 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2520, 'avg', 64, '0'); convert_element_type_2520 = None + wait_tensor_533 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_242); reduce_scatter_tensor_242 = None + view_1741 = torch.ops.aten.view.default(view_1736, [16384, 4096]); view_1736 = None + permute_1217 = torch.ops.aten.permute.default(view_1741, [1, 0]) + mm_603 = torch.ops.aten.mm.default(permute_1217, view_173); permute_1217 = view_173 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 64, '0'); convert_element_type_169 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_47, [1, 0]); wait_tensor_47 = None + permute_1219 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_604 = torch.ops.aten.mm.default(view_1741, permute_1219); view_1741 = permute_1219 = None + view_1742 = torch.ops.aten.view.default(mm_604, [2, 8192, 4096]); mm_604 = None + add_316 = torch.ops.aten.add.Tensor(add_315, view_1742); add_315 = view_1742 = None + convert_element_type_2525 = torch.ops.prims.convert_element_type.default(mm_603, torch.float32); mm_603 = None + reduce_scatter_tensor_243 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2525, 'avg', 64, '0'); convert_element_type_2525 = None + wait_tensor_534 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_243); reduce_scatter_tensor_243 = None + convert_element_type_2526 = torch.ops.prims.convert_element_type.default(add_316, torch.float32); add_316 = None + convert_element_type_2528 = torch.ops.prims.convert_element_type.default(wait_tensor_46, torch.float32); wait_tensor_46 = None + mul_798 = torch.ops.aten.mul.Tensor(convert_element_type_2526, convert_element_type_2528); convert_element_type_2528 = None + mul_800 = torch.ops.aten.mul.Tensor(mul_40, mul_798) + sum_163 = torch.ops.aten.sum.dim_IntList(mul_800, [2], True); mul_800 = None + div_54 = torch.ops.aten.div.Tensor(mul_40, 4096) + mul_801 = torch.ops.aten.mul.Tensor(div_54, sum_163); div_54 = sum_163 = None + sub_81 = torch.ops.aten.sub.Tensor(mul_798, mul_801); mul_798 = mul_801 = None + mul_802 = torch.ops.aten.mul.Tensor(sub_81, rsqrt_10); sub_81 = rsqrt_10 = None + mul_803 = torch.ops.aten.mul.Tensor(convert_element_type_2526, mul_40); convert_element_type_2526 = mul_40 = None + sum_164 = torch.ops.aten.sum.dim_IntList(mul_803, [0, 1]); mul_803 = None + convert_element_type_2529 = torch.ops.prims.convert_element_type.default(mul_802, torch.bfloat16); mul_802 = None + add_317 = torch.ops.aten.add.Tensor(add_314, convert_element_type_2529); add_314 = convert_element_type_2529 = None + convert_element_type_default_11 = torch.ops.prims.convert_element_type.default(sum_164, torch.float32); sum_164 = None + reduce_scatter_tensor_244 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_11, 'avg', 64, '0'); convert_element_type_default_11 = None + wait_tensor_535 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_244); reduce_scatter_tensor_244 = None + view_1743 = torch.ops.aten.view.default(add_317, [16384, 4096]) + permute_1221 = torch.ops.aten.permute.default(view_1743, [1, 0]) + permute_50 = torch.ops.aten.permute.default(getitem_36, [0, 2, 1, 3]) + view_157 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 64, '0'); convert_element_type_149 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_41, [1, 0]); wait_tensor_41 = None + view_159 = torch.ops.aten.view.default(view_157, [16384, 4096]); view_157 = None + mm_31 = torch.ops.aten.mm.default(view_159, permute_51) + view_160 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + add_17 = torch.ops.aten.add.Tensor(add_15, view_160); view_160 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 64, '0'); convert_element_type_152 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32); add_17 = None + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_42) + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + view_163 = torch.ops.aten.view.default(convert_element_type_154, [16384, 4096]); convert_element_type_154 = None + view_164 = torch.ops.aten.view.default(mm_32, [2, 8192, 14336]); mm_32 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_164, torch.float32); view_164 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 64, '0'); convert_element_type_160 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_33 = torch.ops.aten.mm.default(view_163, permute_53) + view_167 = torch.ops.aten.view.default(mm_33, [2, 8192, 14336]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_167) + view_169 = torch.ops.aten.view.default(mul_39, [16384, 14336]); mul_39 = None + mm_605 = torch.ops.aten.mm.default(permute_1221, view_169); permute_1221 = view_169 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 64, '0'); convert_element_type_163 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + permute_1223 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_606 = torch.ops.aten.mm.default(view_1743, permute_1223); view_1743 = permute_1223 = None + view_1744 = torch.ops.aten.view.default(mm_606, [2, 8192, 14336]); mm_606 = None + convert_element_type_2536 = torch.ops.prims.convert_element_type.default(mm_605, torch.float32); mm_605 = None + reduce_scatter_tensor_245 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2536, 'avg', 64, '0'); convert_element_type_2536 = None + wait_tensor_536 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_245); reduce_scatter_tensor_245 = None + mul_804 = torch.ops.aten.mul.Tensor(view_1744, convert_element_type_159); convert_element_type_159 = None + mul_805 = torch.ops.aten.mul.Tensor(view_1744, view_167); view_1744 = view_167 = None + view_1745 = torch.ops.aten.view.default(mul_804, [16384, 14336]); mul_804 = None + permute_1225 = torch.ops.aten.permute.default(view_1745, [1, 0]) + mm_607 = torch.ops.aten.mm.default(permute_1225, view_163); permute_1225 = None + permute_1227 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_608 = torch.ops.aten.mm.default(view_1745, permute_1227); view_1745 = permute_1227 = None + view_1746 = torch.ops.aten.view.default(mm_608, [2, 8192, 4096]); mm_608 = None + convert_element_type_2541 = torch.ops.prims.convert_element_type.default(mm_607, torch.float32); mm_607 = None + reduce_scatter_tensor_246 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2541, 'avg', 64, '0'); convert_element_type_2541 = None + wait_tensor_537 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_246); reduce_scatter_tensor_246 = None + convert_element_type_2542 = torch.ops.prims.convert_element_type.default(mul_805, torch.float32); mul_805 = None + neg_27 = torch.ops.aten.neg.default(convert_element_type_158) + exp_27 = torch.ops.aten.exp.default(neg_27); neg_27 = None + add_318 = torch.ops.aten.add.Tensor(exp_27, 1); exp_27 = None + reciprocal_27 = torch.ops.aten.reciprocal.default(add_318); add_318 = None + mul_806 = torch.ops.aten.mul.Tensor(reciprocal_27, 1); reciprocal_27 = None + mul_807 = torch.ops.aten.mul.Tensor(convert_element_type_2542, mul_806); convert_element_type_2542 = None + sub_82 = torch.ops.aten.sub.Tensor(1, mul_806); mul_806 = None + mul_808 = torch.ops.aten.mul.Tensor(convert_element_type_158, sub_82); convert_element_type_158 = sub_82 = None + add_319 = torch.ops.aten.add.Tensor(mul_808, 1); mul_808 = None + mul_809 = torch.ops.aten.mul.Tensor(mul_807, add_319); mul_807 = add_319 = None + convert_element_type_2544 = torch.ops.prims.convert_element_type.default(mul_809, torch.bfloat16); mul_809 = None + view_1747 = torch.ops.aten.view.default(convert_element_type_2544, [16384, 14336]); convert_element_type_2544 = None + permute_1229 = torch.ops.aten.permute.default(view_1747, [1, 0]) + mm_609 = torch.ops.aten.mm.default(permute_1229, view_163); permute_1229 = view_163 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 64, '0'); convert_element_type_155 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + permute_1231 = torch.ops.aten.permute.default(permute_52, [1, 0]); permute_52 = None + mm_610 = torch.ops.aten.mm.default(view_1747, permute_1231); view_1747 = permute_1231 = None + view_1748 = torch.ops.aten.view.default(mm_610, [2, 8192, 4096]); mm_610 = None + add_320 = torch.ops.aten.add.Tensor(view_1746, view_1748); view_1746 = view_1748 = None + convert_element_type_2549 = torch.ops.prims.convert_element_type.default(mm_609, torch.float32); mm_609 = None + reduce_scatter_tensor_247 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2549, 'avg', 64, '0'); convert_element_type_2549 = None + wait_tensor_538 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_247); reduce_scatter_tensor_247 = None + convert_element_type_2550 = torch.ops.prims.convert_element_type.default(add_320, torch.float32); add_320 = None + convert_element_type_2552 = torch.ops.prims.convert_element_type.default(wait_tensor_42, torch.float32); wait_tensor_42 = None + mul_810 = torch.ops.aten.mul.Tensor(convert_element_type_2550, convert_element_type_2552); convert_element_type_2552 = None + mul_812 = torch.ops.aten.mul.Tensor(mul_36, mul_810) + sum_165 = torch.ops.aten.sum.dim_IntList(mul_812, [2], True); mul_812 = None + div_55 = torch.ops.aten.div.Tensor(mul_36, 4096) + mul_813 = torch.ops.aten.mul.Tensor(div_55, sum_165); div_55 = sum_165 = None + sub_83 = torch.ops.aten.sub.Tensor(mul_810, mul_813); mul_810 = mul_813 = None + mul_814 = torch.ops.aten.mul.Tensor(sub_83, rsqrt_9); sub_83 = rsqrt_9 = None + mul_815 = torch.ops.aten.mul.Tensor(convert_element_type_2550, mul_36); convert_element_type_2550 = mul_36 = None + sum_166 = torch.ops.aten.sum.dim_IntList(mul_815, [0, 1]); mul_815 = None + convert_element_type_2553 = torch.ops.prims.convert_element_type.default(mul_814, torch.bfloat16); mul_814 = None + add_321 = torch.ops.aten.add.Tensor(add_317, convert_element_type_2553); add_317 = convert_element_type_2553 = None + convert_element_type_default_10 = torch.ops.prims.convert_element_type.default(sum_166, torch.float32); sum_166 = None + reduce_scatter_tensor_248 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_10, 'avg', 64, '0'); convert_element_type_default_10 = None + wait_tensor_539 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_248); reduce_scatter_tensor_248 = None + view_1749 = torch.ops.aten.view.default(add_321, [16384, 4096]) + permute_1233 = torch.ops.aten.permute.default(view_1749, [1, 0]) + mm_611 = torch.ops.aten.mm.default(permute_1233, view_159); permute_1233 = view_159 = None + permute_1235 = torch.ops.aten.permute.default(permute_51, [1, 0]); permute_51 = None + mm_612 = torch.ops.aten.mm.default(view_1749, permute_1235); view_1749 = permute_1235 = None + view_1750 = torch.ops.aten.view.default(mm_612, [2, 8192, 4096]); mm_612 = None + convert_element_type_2560 = torch.ops.prims.convert_element_type.default(mm_611, torch.float32); mm_611 = None + reduce_scatter_tensor_249 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2560, 'avg', 64, '0'); convert_element_type_2560 = None + wait_tensor_540 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_249); reduce_scatter_tensor_249 = None + view_1751 = torch.ops.aten.view.default(view_1750, [2, 8192, 32, 128]); view_1750 = None + permute_1237 = torch.ops.aten.permute.default(view_1751, [0, 2, 1, 3]); view_1751 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 64, '0'); convert_element_type_133 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32); add_15 = None + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_37) + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + view_139 = torch.ops.aten.view.default(convert_element_type_135, [16384, 4096]); convert_element_type_135 = None + view_140 = torch.ops.aten.view.default(mm_28, [2, 8192, 4096]); mm_28 = None + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 64, '0'); convert_element_type_139 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + mm_29 = torch.ops.aten.mm.default(view_139, permute_45) + view_143 = torch.ops.aten.view.default(mm_29, [2, 8192, 1024]); mm_29 = None + view_146 = torch.ops.aten.view.default(mm_30, [2, 8192, 1024]); mm_30 = None + view_147 = torch.ops.aten.view.default(view_140, [2, 8192, -1, 128]); view_140 = None + view_148 = torch.ops.aten.view.default(view_143, [2, 8192, -1, 128]); view_143 = None + view_149 = torch.ops.aten.view.default(view_146, [2, 8192, -1, 128]); view_146 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_147, torch.float32); view_147 = None + view_150 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 32, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_150); view_150 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None + view_151 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 8, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_151); view_151 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_16); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_153 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 32, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_16); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_154 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 8, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_153, torch.bfloat16); view_153 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 8, 4, 128]); unsqueeze_8 = None + clone_8 = torch.ops.aten.clone.default(expand_8, memory_format = torch.contiguous_format); expand_8 = None + view_155 = torch.ops.aten.view.default(clone_8, [2, 8192, 32, 128]); clone_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_149, 3); view_149 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 8, 4, 128]); unsqueeze_9 = None + clone_9 = torch.ops.aten.clone.default(expand_9, memory_format = torch.contiguous_format); expand_9 = None + view_156 = torch.ops.aten.view.default(clone_9, [2, 8192, 32, 128]); clone_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_155, [0, 2, 1, 3]); view_155 = None + permute_49 = torch.ops.aten.permute.default(view_156, [0, 2, 1, 3]); view_156 = None + _scaled_dot_product_cudnn_attention_backward_27 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1237, permute_47, permute_48, permute_49, getitem_36, getitem_37, getitem_42, getitem_43, None, None, None, 8192, 8192, 0.0, True); permute_1237 = permute_47 = permute_48 = permute_49 = getitem_36 = getitem_37 = getitem_42 = getitem_43 = None + getitem_369 = _scaled_dot_product_cudnn_attention_backward_27[0] + getitem_370 = _scaled_dot_product_cudnn_attention_backward_27[1] + getitem_371 = _scaled_dot_product_cudnn_attention_backward_27[2]; _scaled_dot_product_cudnn_attention_backward_27 = None + permute_1238 = torch.ops.aten.permute.default(getitem_371, [0, 2, 1, 3]); getitem_371 = None + permute_1239 = torch.ops.aten.permute.default(getitem_370, [0, 2, 1, 3]); getitem_370 = None + permute_1240 = torch.ops.aten.permute.default(getitem_369, [0, 2, 1, 3]); getitem_369 = None + view_1752 = torch.ops.aten.view.default(permute_1238, [2, 8192, 8, 4, 128]); permute_1238 = None + sum_167 = torch.ops.aten.sum.dim_IntList(view_1752, [3], True); view_1752 = None + squeeze_54 = torch.ops.aten.squeeze.dim(sum_167, 3); sum_167 = None + view_1753 = torch.ops.aten.view.default(permute_1239, [2, 8192, 8, 4, 128]); permute_1239 = None + sum_168 = torch.ops.aten.sum.dim_IntList(view_1753, [3], True); view_1753 = None + squeeze_55 = torch.ops.aten.squeeze.dim(sum_168, 3); sum_168 = None + convert_element_type_2561 = torch.ops.prims.convert_element_type.default(squeeze_55, torch.float32); squeeze_55 = None + convert_element_type_2562 = torch.ops.prims.convert_element_type.default(permute_1240, torch.float32); permute_1240 = None + view_1754 = torch.ops.aten.view.default(convert_element_type_2561, [2, 8192, 8, 64, 2]); convert_element_type_2561 = None + view_as_complex_118 = torch.ops.aten.view_as_complex.default(view_1754); view_1754 = None + mul_816 = torch.ops.aten.mul.Tensor(view_as_complex_118, _conj); view_as_complex_118 = None + view_1755 = torch.ops.aten.view.default(convert_element_type_2562, [2, 8192, 32, 64, 2]); convert_element_type_2562 = None + view_as_complex_119 = torch.ops.aten.view_as_complex.default(view_1755); view_1755 = None + mul_817 = torch.ops.aten.mul.Tensor(view_as_complex_119, _conj); view_as_complex_119 = None + view_as_real_118 = torch.ops.aten.view_as_real.default(mul_816); mul_816 = None + view_1756 = torch.ops.aten.view.default(view_as_real_118, [2, 8192, 8, 128]); view_as_real_118 = None + convert_element_type_2563 = torch.ops.prims.convert_element_type.default(view_1756, torch.bfloat16); view_1756 = None + view_as_real_119 = torch.ops.aten.view_as_real.default(mul_817); mul_817 = None + view_1757 = torch.ops.aten.view.default(view_as_real_119, [2, 8192, 32, 128]); view_as_real_119 = None + convert_element_type_2564 = torch.ops.prims.convert_element_type.default(view_1757, torch.bfloat16); view_1757 = None + view_1758 = torch.ops.aten.view.default(squeeze_54, [2, 8192, 1024]); squeeze_54 = None + view_1759 = torch.ops.aten.view.default(convert_element_type_2563, [2, 8192, 1024]); convert_element_type_2563 = None + view_1760 = torch.ops.aten.view.default(convert_element_type_2564, [2, 8192, 4096]); convert_element_type_2564 = None + view_1761 = torch.ops.aten.view.default(view_1758, [16384, 1024]); view_1758 = None + permute_1241 = torch.ops.aten.permute.default(view_1761, [1, 0]) + mm_613 = torch.ops.aten.mm.default(permute_1241, view_139); permute_1241 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 64, '0'); convert_element_type_142 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_40, [1, 0]); wait_tensor_40 = None + permute_1243 = torch.ops.aten.permute.default(permute_46, [1, 0]); permute_46 = None + mm_614 = torch.ops.aten.mm.default(view_1761, permute_1243); view_1761 = permute_1243 = None + view_1762 = torch.ops.aten.view.default(mm_614, [2, 8192, 4096]); mm_614 = None + convert_element_type_2569 = torch.ops.prims.convert_element_type.default(mm_613, torch.float32); mm_613 = None + reduce_scatter_tensor_250 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2569, 'avg', 64, '0'); convert_element_type_2569 = None + wait_tensor_541 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_250); reduce_scatter_tensor_250 = None + view_1763 = torch.ops.aten.view.default(view_1759, [16384, 1024]); view_1759 = None + permute_1245 = torch.ops.aten.permute.default(view_1763, [1, 0]) + mm_615 = torch.ops.aten.mm.default(permute_1245, view_139); permute_1245 = None + permute_1247 = torch.ops.aten.permute.default(permute_45, [1, 0]); permute_45 = None + mm_616 = torch.ops.aten.mm.default(view_1763, permute_1247); view_1763 = permute_1247 = None + view_1764 = torch.ops.aten.view.default(mm_616, [2, 8192, 4096]); mm_616 = None + add_322 = torch.ops.aten.add.Tensor(view_1762, view_1764); view_1762 = view_1764 = None + convert_element_type_2574 = torch.ops.prims.convert_element_type.default(mm_615, torch.float32); mm_615 = None + reduce_scatter_tensor_251 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2574, 'avg', 64, '0'); convert_element_type_2574 = None + wait_tensor_542 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_251); reduce_scatter_tensor_251 = None + view_1765 = torch.ops.aten.view.default(view_1760, [16384, 4096]); view_1760 = None + permute_1249 = torch.ops.aten.permute.default(view_1765, [1, 0]) + mm_617 = torch.ops.aten.mm.default(permute_1249, view_139); permute_1249 = view_139 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 64, '0'); convert_element_type_136 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + permute_1251 = torch.ops.aten.permute.default(permute_44, [1, 0]); permute_44 = None + mm_618 = torch.ops.aten.mm.default(view_1765, permute_1251); view_1765 = permute_1251 = None + view_1766 = torch.ops.aten.view.default(mm_618, [2, 8192, 4096]); mm_618 = None + add_323 = torch.ops.aten.add.Tensor(add_322, view_1766); add_322 = view_1766 = None + convert_element_type_2579 = torch.ops.prims.convert_element_type.default(mm_617, torch.float32); mm_617 = None + reduce_scatter_tensor_252 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2579, 'avg', 64, '0'); convert_element_type_2579 = None + wait_tensor_543 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_252); reduce_scatter_tensor_252 = None + convert_element_type_2580 = torch.ops.prims.convert_element_type.default(add_323, torch.float32); add_323 = None + convert_element_type_2582 = torch.ops.prims.convert_element_type.default(wait_tensor_37, torch.float32); wait_tensor_37 = None + mul_818 = torch.ops.aten.mul.Tensor(convert_element_type_2580, convert_element_type_2582); convert_element_type_2582 = None + mul_820 = torch.ops.aten.mul.Tensor(mul_32, mul_818) + sum_169 = torch.ops.aten.sum.dim_IntList(mul_820, [2], True); mul_820 = None + div_56 = torch.ops.aten.div.Tensor(mul_32, 4096) + mul_821 = torch.ops.aten.mul.Tensor(div_56, sum_169); div_56 = sum_169 = None + sub_84 = torch.ops.aten.sub.Tensor(mul_818, mul_821); mul_818 = mul_821 = None + mul_822 = torch.ops.aten.mul.Tensor(sub_84, rsqrt_8); sub_84 = rsqrt_8 = None + mul_823 = torch.ops.aten.mul.Tensor(convert_element_type_2580, mul_32); convert_element_type_2580 = mul_32 = None + sum_170 = torch.ops.aten.sum.dim_IntList(mul_823, [0, 1]); mul_823 = None + convert_element_type_2583 = torch.ops.prims.convert_element_type.default(mul_822, torch.bfloat16); mul_822 = None + add_324 = torch.ops.aten.add.Tensor(add_321, convert_element_type_2583); add_321 = convert_element_type_2583 = None + convert_element_type_default_9 = torch.ops.prims.convert_element_type.default(sum_170, torch.float32); sum_170 = None + reduce_scatter_tensor_253 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_9, 'avg', 64, '0'); convert_element_type_default_9 = None + wait_tensor_544 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_253); reduce_scatter_tensor_253 = None + view_1767 = torch.ops.aten.view.default(add_324, [16384, 4096]) + permute_1253 = torch.ops.aten.permute.default(view_1767, [1, 0]) + permute_39 = torch.ops.aten.permute.default(getitem_27, [0, 2, 1, 3]) + view_123 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 64, '0'); convert_element_type_116 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + view_125 = torch.ops.aten.view.default(view_123, [16384, 4096]); view_123 = None + mm_24 = torch.ops.aten.mm.default(view_125, permute_40) + view_126 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + add_13 = torch.ops.aten.add.Tensor(add_11, view_126); view_126 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 64, '0'); convert_element_type_119 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32); add_13 = None + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_33) + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + view_129 = torch.ops.aten.view.default(convert_element_type_121, [16384, 4096]); convert_element_type_121 = None + view_130 = torch.ops.aten.view.default(mm_25, [2, 8192, 14336]); mm_25 = None + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 64, '0'); convert_element_type_127 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_35, [1, 0]); wait_tensor_35 = None + mm_26 = torch.ops.aten.mm.default(view_129, permute_42) + view_133 = torch.ops.aten.view.default(mm_26, [2, 8192, 14336]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_133) + view_135 = torch.ops.aten.view.default(mul_31, [16384, 14336]); mul_31 = None + mm_619 = torch.ops.aten.mm.default(permute_1253, view_135); permute_1253 = view_135 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 64, '0'); convert_element_type_130 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + permute_1255 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_620 = torch.ops.aten.mm.default(view_1767, permute_1255); view_1767 = permute_1255 = None + view_1768 = torch.ops.aten.view.default(mm_620, [2, 8192, 14336]); mm_620 = None + convert_element_type_2590 = torch.ops.prims.convert_element_type.default(mm_619, torch.float32); mm_619 = None + reduce_scatter_tensor_254 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2590, 'avg', 64, '0'); convert_element_type_2590 = None + wait_tensor_545 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_254); reduce_scatter_tensor_254 = None + mul_824 = torch.ops.aten.mul.Tensor(view_1768, convert_element_type_126); convert_element_type_126 = None + mul_825 = torch.ops.aten.mul.Tensor(view_1768, view_133); view_1768 = view_133 = None + view_1769 = torch.ops.aten.view.default(mul_824, [16384, 14336]); mul_824 = None + permute_1257 = torch.ops.aten.permute.default(view_1769, [1, 0]) + mm_621 = torch.ops.aten.mm.default(permute_1257, view_129); permute_1257 = None + permute_1259 = torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_622 = torch.ops.aten.mm.default(view_1769, permute_1259); view_1769 = permute_1259 = None + view_1770 = torch.ops.aten.view.default(mm_622, [2, 8192, 4096]); mm_622 = None + convert_element_type_2595 = torch.ops.prims.convert_element_type.default(mm_621, torch.float32); mm_621 = None + reduce_scatter_tensor_255 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2595, 'avg', 64, '0'); convert_element_type_2595 = None + wait_tensor_546 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_255); reduce_scatter_tensor_255 = None + convert_element_type_2596 = torch.ops.prims.convert_element_type.default(mul_825, torch.float32); mul_825 = None + neg_28 = torch.ops.aten.neg.default(convert_element_type_125) + exp_28 = torch.ops.aten.exp.default(neg_28); neg_28 = None + add_325 = torch.ops.aten.add.Tensor(exp_28, 1); exp_28 = None + reciprocal_28 = torch.ops.aten.reciprocal.default(add_325); add_325 = None + mul_826 = torch.ops.aten.mul.Tensor(reciprocal_28, 1); reciprocal_28 = None + mul_827 = torch.ops.aten.mul.Tensor(convert_element_type_2596, mul_826); convert_element_type_2596 = None + sub_85 = torch.ops.aten.sub.Tensor(1, mul_826); mul_826 = None + mul_828 = torch.ops.aten.mul.Tensor(convert_element_type_125, sub_85); convert_element_type_125 = sub_85 = None + add_326 = torch.ops.aten.add.Tensor(mul_828, 1); mul_828 = None + mul_829 = torch.ops.aten.mul.Tensor(mul_827, add_326); mul_827 = add_326 = None + convert_element_type_2598 = torch.ops.prims.convert_element_type.default(mul_829, torch.bfloat16); mul_829 = None + view_1771 = torch.ops.aten.view.default(convert_element_type_2598, [16384, 14336]); convert_element_type_2598 = None + permute_1261 = torch.ops.aten.permute.default(view_1771, [1, 0]) + mm_623 = torch.ops.aten.mm.default(permute_1261, view_129); permute_1261 = view_129 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 64, '0'); convert_element_type_122 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + permute_1263 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_624 = torch.ops.aten.mm.default(view_1771, permute_1263); view_1771 = permute_1263 = None + view_1772 = torch.ops.aten.view.default(mm_624, [2, 8192, 4096]); mm_624 = None + add_327 = torch.ops.aten.add.Tensor(view_1770, view_1772); view_1770 = view_1772 = None + convert_element_type_2603 = torch.ops.prims.convert_element_type.default(mm_623, torch.float32); mm_623 = None + reduce_scatter_tensor_256 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2603, 'avg', 64, '0'); convert_element_type_2603 = None + wait_tensor_547 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_256); reduce_scatter_tensor_256 = None + convert_element_type_2604 = torch.ops.prims.convert_element_type.default(add_327, torch.float32); add_327 = None + convert_element_type_2606 = torch.ops.prims.convert_element_type.default(wait_tensor_33, torch.float32); wait_tensor_33 = None + mul_830 = torch.ops.aten.mul.Tensor(convert_element_type_2604, convert_element_type_2606); convert_element_type_2606 = None + mul_832 = torch.ops.aten.mul.Tensor(mul_28, mul_830) + sum_171 = torch.ops.aten.sum.dim_IntList(mul_832, [2], True); mul_832 = None + div_57 = torch.ops.aten.div.Tensor(mul_28, 4096) + mul_833 = torch.ops.aten.mul.Tensor(div_57, sum_171); div_57 = sum_171 = None + sub_86 = torch.ops.aten.sub.Tensor(mul_830, mul_833); mul_830 = mul_833 = None + mul_834 = torch.ops.aten.mul.Tensor(sub_86, rsqrt_7); sub_86 = rsqrt_7 = None + mul_835 = torch.ops.aten.mul.Tensor(convert_element_type_2604, mul_28); convert_element_type_2604 = mul_28 = None + sum_172 = torch.ops.aten.sum.dim_IntList(mul_835, [0, 1]); mul_835 = None + convert_element_type_2607 = torch.ops.prims.convert_element_type.default(mul_834, torch.bfloat16); mul_834 = None + add_328 = torch.ops.aten.add.Tensor(add_324, convert_element_type_2607); add_324 = convert_element_type_2607 = None + convert_element_type_default_8 = torch.ops.prims.convert_element_type.default(sum_172, torch.float32); sum_172 = None + reduce_scatter_tensor_257 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_8, 'avg', 64, '0'); convert_element_type_default_8 = None + wait_tensor_548 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_257); reduce_scatter_tensor_257 = None + view_1773 = torch.ops.aten.view.default(add_328, [16384, 4096]) + permute_1265 = torch.ops.aten.permute.default(view_1773, [1, 0]) + mm_625 = torch.ops.aten.mm.default(permute_1265, view_125); permute_1265 = view_125 = None + permute_1267 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_626 = torch.ops.aten.mm.default(view_1773, permute_1267); view_1773 = permute_1267 = None + view_1774 = torch.ops.aten.view.default(mm_626, [2, 8192, 4096]); mm_626 = None + convert_element_type_2614 = torch.ops.prims.convert_element_type.default(mm_625, torch.float32); mm_625 = None + reduce_scatter_tensor_258 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2614, 'avg', 64, '0'); convert_element_type_2614 = None + wait_tensor_549 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_258); reduce_scatter_tensor_258 = None + view_1775 = torch.ops.aten.view.default(view_1774, [2, 8192, 32, 128]); view_1774 = None + permute_1269 = torch.ops.aten.permute.default(view_1775, [0, 2, 1, 3]); view_1775 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 64, '0'); convert_element_type_100 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32); add_11 = None + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_28) + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + view_105 = torch.ops.aten.view.default(convert_element_type_102, [16384, 4096]); convert_element_type_102 = None + view_106 = torch.ops.aten.view.default(mm_21, [2, 8192, 4096]); mm_21 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 64, '0'); convert_element_type_106 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_22 = torch.ops.aten.mm.default(view_105, permute_34) + view_109 = torch.ops.aten.view.default(mm_22, [2, 8192, 1024]); mm_22 = None + view_112 = torch.ops.aten.view.default(mm_23, [2, 8192, 1024]); mm_23 = None + view_113 = torch.ops.aten.view.default(view_106, [2, 8192, -1, 128]); view_106 = None + view_114 = torch.ops.aten.view.default(view_109, [2, 8192, -1, 128]); view_109 = None + view_115 = torch.ops.aten.view.default(view_112, [2, 8192, -1, 128]); view_112 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_113, torch.float32); view_113 = None + view_116 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 32, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_116); view_116 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_114, torch.float32); view_114 = None + view_117 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 8, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_117); view_117 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_16); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_119 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 32, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_16); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_120 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 8, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_119, torch.bfloat16); view_119 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_120, torch.bfloat16); view_120 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 8, 4, 128]); unsqueeze_6 = None + clone_6 = torch.ops.aten.clone.default(expand_6, memory_format = torch.contiguous_format); expand_6 = None + view_121 = torch.ops.aten.view.default(clone_6, [2, 8192, 32, 128]); clone_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_115, 3); view_115 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 8, 4, 128]); unsqueeze_7 = None + clone_7 = torch.ops.aten.clone.default(expand_7, memory_format = torch.contiguous_format); expand_7 = None + view_122 = torch.ops.aten.view.default(clone_7, [2, 8192, 32, 128]); clone_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_121, [0, 2, 1, 3]); view_121 = None + permute_38 = torch.ops.aten.permute.default(view_122, [0, 2, 1, 3]); view_122 = None + _scaled_dot_product_cudnn_attention_backward_28 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1269, permute_36, permute_37, permute_38, getitem_27, getitem_28, getitem_33, getitem_34, None, None, None, 8192, 8192, 0.0, True); permute_1269 = permute_36 = permute_37 = permute_38 = getitem_27 = getitem_28 = getitem_33 = getitem_34 = None + getitem_372 = _scaled_dot_product_cudnn_attention_backward_28[0] + getitem_373 = _scaled_dot_product_cudnn_attention_backward_28[1] + getitem_374 = _scaled_dot_product_cudnn_attention_backward_28[2]; _scaled_dot_product_cudnn_attention_backward_28 = None + permute_1270 = torch.ops.aten.permute.default(getitem_374, [0, 2, 1, 3]); getitem_374 = None + permute_1271 = torch.ops.aten.permute.default(getitem_373, [0, 2, 1, 3]); getitem_373 = None + permute_1272 = torch.ops.aten.permute.default(getitem_372, [0, 2, 1, 3]); getitem_372 = None + view_1776 = torch.ops.aten.view.default(permute_1270, [2, 8192, 8, 4, 128]); permute_1270 = None + sum_173 = torch.ops.aten.sum.dim_IntList(view_1776, [3], True); view_1776 = None + squeeze_56 = torch.ops.aten.squeeze.dim(sum_173, 3); sum_173 = None + view_1777 = torch.ops.aten.view.default(permute_1271, [2, 8192, 8, 4, 128]); permute_1271 = None + sum_174 = torch.ops.aten.sum.dim_IntList(view_1777, [3], True); view_1777 = None + squeeze_57 = torch.ops.aten.squeeze.dim(sum_174, 3); sum_174 = None + convert_element_type_2615 = torch.ops.prims.convert_element_type.default(squeeze_57, torch.float32); squeeze_57 = None + convert_element_type_2616 = torch.ops.prims.convert_element_type.default(permute_1272, torch.float32); permute_1272 = None + view_1778 = torch.ops.aten.view.default(convert_element_type_2615, [2, 8192, 8, 64, 2]); convert_element_type_2615 = None + view_as_complex_120 = torch.ops.aten.view_as_complex.default(view_1778); view_1778 = None + mul_836 = torch.ops.aten.mul.Tensor(view_as_complex_120, _conj); view_as_complex_120 = None + view_1779 = torch.ops.aten.view.default(convert_element_type_2616, [2, 8192, 32, 64, 2]); convert_element_type_2616 = None + view_as_complex_121 = torch.ops.aten.view_as_complex.default(view_1779); view_1779 = None + mul_837 = torch.ops.aten.mul.Tensor(view_as_complex_121, _conj); view_as_complex_121 = None + view_as_real_120 = torch.ops.aten.view_as_real.default(mul_836); mul_836 = None + view_1780 = torch.ops.aten.view.default(view_as_real_120, [2, 8192, 8, 128]); view_as_real_120 = None + convert_element_type_2617 = torch.ops.prims.convert_element_type.default(view_1780, torch.bfloat16); view_1780 = None + view_as_real_121 = torch.ops.aten.view_as_real.default(mul_837); mul_837 = None + view_1781 = torch.ops.aten.view.default(view_as_real_121, [2, 8192, 32, 128]); view_as_real_121 = None + convert_element_type_2618 = torch.ops.prims.convert_element_type.default(view_1781, torch.bfloat16); view_1781 = None + view_1782 = torch.ops.aten.view.default(squeeze_56, [2, 8192, 1024]); squeeze_56 = None + view_1783 = torch.ops.aten.view.default(convert_element_type_2617, [2, 8192, 1024]); convert_element_type_2617 = None + view_1784 = torch.ops.aten.view.default(convert_element_type_2618, [2, 8192, 4096]); convert_element_type_2618 = None + view_1785 = torch.ops.aten.view.default(view_1782, [16384, 1024]); view_1782 = None + permute_1273 = torch.ops.aten.permute.default(view_1785, [1, 0]) + mm_627 = torch.ops.aten.mm.default(permute_1273, view_105); permute_1273 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 64, '0'); convert_element_type_109 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + permute_1275 = torch.ops.aten.permute.default(permute_35, [1, 0]); permute_35 = None + mm_628 = torch.ops.aten.mm.default(view_1785, permute_1275); view_1785 = permute_1275 = None + view_1786 = torch.ops.aten.view.default(mm_628, [2, 8192, 4096]); mm_628 = None + convert_element_type_2623 = torch.ops.prims.convert_element_type.default(mm_627, torch.float32); mm_627 = None + reduce_scatter_tensor_259 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2623, 'avg', 64, '0'); convert_element_type_2623 = None + wait_tensor_550 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_259); reduce_scatter_tensor_259 = None + view_1787 = torch.ops.aten.view.default(view_1783, [16384, 1024]); view_1783 = None + permute_1277 = torch.ops.aten.permute.default(view_1787, [1, 0]) + mm_629 = torch.ops.aten.mm.default(permute_1277, view_105); permute_1277 = None + permute_1279 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_630 = torch.ops.aten.mm.default(view_1787, permute_1279); view_1787 = permute_1279 = None + view_1788 = torch.ops.aten.view.default(mm_630, [2, 8192, 4096]); mm_630 = None + add_329 = torch.ops.aten.add.Tensor(view_1786, view_1788); view_1786 = view_1788 = None + convert_element_type_2628 = torch.ops.prims.convert_element_type.default(mm_629, torch.float32); mm_629 = None + reduce_scatter_tensor_260 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2628, 'avg', 64, '0'); convert_element_type_2628 = None + wait_tensor_551 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_260); reduce_scatter_tensor_260 = None + view_1789 = torch.ops.aten.view.default(view_1784, [16384, 4096]); view_1784 = None + permute_1281 = torch.ops.aten.permute.default(view_1789, [1, 0]) + mm_631 = torch.ops.aten.mm.default(permute_1281, view_105); permute_1281 = view_105 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 64, '0'); convert_element_type_103 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + permute_1283 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_632 = torch.ops.aten.mm.default(view_1789, permute_1283); view_1789 = permute_1283 = None + view_1790 = torch.ops.aten.view.default(mm_632, [2, 8192, 4096]); mm_632 = None + add_330 = torch.ops.aten.add.Tensor(add_329, view_1790); add_329 = view_1790 = None + convert_element_type_2633 = torch.ops.prims.convert_element_type.default(mm_631, torch.float32); mm_631 = None + reduce_scatter_tensor_261 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2633, 'avg', 64, '0'); convert_element_type_2633 = None + wait_tensor_552 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_261); reduce_scatter_tensor_261 = None + convert_element_type_2634 = torch.ops.prims.convert_element_type.default(add_330, torch.float32); add_330 = None + convert_element_type_2636 = torch.ops.prims.convert_element_type.default(wait_tensor_28, torch.float32); wait_tensor_28 = None + mul_838 = torch.ops.aten.mul.Tensor(convert_element_type_2634, convert_element_type_2636); convert_element_type_2636 = None + mul_840 = torch.ops.aten.mul.Tensor(mul_24, mul_838) + sum_175 = torch.ops.aten.sum.dim_IntList(mul_840, [2], True); mul_840 = None + div_58 = torch.ops.aten.div.Tensor(mul_24, 4096) + mul_841 = torch.ops.aten.mul.Tensor(div_58, sum_175); div_58 = sum_175 = None + sub_87 = torch.ops.aten.sub.Tensor(mul_838, mul_841); mul_838 = mul_841 = None + mul_842 = torch.ops.aten.mul.Tensor(sub_87, rsqrt_6); sub_87 = rsqrt_6 = None + mul_843 = torch.ops.aten.mul.Tensor(convert_element_type_2634, mul_24); convert_element_type_2634 = mul_24 = None + sum_176 = torch.ops.aten.sum.dim_IntList(mul_843, [0, 1]); mul_843 = None + convert_element_type_2637 = torch.ops.prims.convert_element_type.default(mul_842, torch.bfloat16); mul_842 = None + add_331 = torch.ops.aten.add.Tensor(add_328, convert_element_type_2637); add_328 = convert_element_type_2637 = None + convert_element_type_default_7 = torch.ops.prims.convert_element_type.default(sum_176, torch.float32); sum_176 = None + reduce_scatter_tensor_262 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_7, 'avg', 64, '0'); convert_element_type_default_7 = None + wait_tensor_553 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_262); reduce_scatter_tensor_262 = None + view_1791 = torch.ops.aten.view.default(add_331, [16384, 4096]) + permute_1285 = torch.ops.aten.permute.default(view_1791, [1, 0]) + permute_28 = torch.ops.aten.permute.default(getitem_18, [0, 2, 1, 3]) + view_89 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 64, '0'); convert_element_type_83 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_23, [1, 0]); wait_tensor_23 = None + view_91 = torch.ops.aten.view.default(view_89, [16384, 4096]); view_89 = None + mm_17 = torch.ops.aten.mm.default(view_91, permute_29) + view_92 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + add_9 = torch.ops.aten.add.Tensor(add_7, view_92); view_92 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 64, '0'); convert_element_type_86 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32); add_9 = None + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_24) + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + view_95 = torch.ops.aten.view.default(convert_element_type_88, [16384, 4096]); convert_element_type_88 = None + view_96 = torch.ops.aten.view.default(mm_18, [2, 8192, 14336]); mm_18 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_96, torch.float32); view_96 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 64, '0'); convert_element_type_94 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + mm_19 = torch.ops.aten.mm.default(view_95, permute_31) + view_99 = torch.ops.aten.view.default(mm_19, [2, 8192, 14336]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_99) + view_101 = torch.ops.aten.view.default(mul_23, [16384, 14336]); mul_23 = None + mm_633 = torch.ops.aten.mm.default(permute_1285, view_101); permute_1285 = view_101 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 64, '0'); convert_element_type_97 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_27, [1, 0]); wait_tensor_27 = None + permute_1287 = torch.ops.aten.permute.default(permute_32, [1, 0]); permute_32 = None + mm_634 = torch.ops.aten.mm.default(view_1791, permute_1287); view_1791 = permute_1287 = None + view_1792 = torch.ops.aten.view.default(mm_634, [2, 8192, 14336]); mm_634 = None + convert_element_type_2644 = torch.ops.prims.convert_element_type.default(mm_633, torch.float32); mm_633 = None + reduce_scatter_tensor_263 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2644, 'avg', 64, '0'); convert_element_type_2644 = None + wait_tensor_554 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_263); reduce_scatter_tensor_263 = None + mul_844 = torch.ops.aten.mul.Tensor(view_1792, convert_element_type_93); convert_element_type_93 = None + mul_845 = torch.ops.aten.mul.Tensor(view_1792, view_99); view_1792 = view_99 = None + view_1793 = torch.ops.aten.view.default(mul_844, [16384, 14336]); mul_844 = None + permute_1289 = torch.ops.aten.permute.default(view_1793, [1, 0]) + mm_635 = torch.ops.aten.mm.default(permute_1289, view_95); permute_1289 = None + permute_1291 = torch.ops.aten.permute.default(permute_31, [1, 0]); permute_31 = None + mm_636 = torch.ops.aten.mm.default(view_1793, permute_1291); view_1793 = permute_1291 = None + view_1794 = torch.ops.aten.view.default(mm_636, [2, 8192, 4096]); mm_636 = None + convert_element_type_2649 = torch.ops.prims.convert_element_type.default(mm_635, torch.float32); mm_635 = None + reduce_scatter_tensor_264 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2649, 'avg', 64, '0'); convert_element_type_2649 = None + wait_tensor_555 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_264); reduce_scatter_tensor_264 = None + convert_element_type_2650 = torch.ops.prims.convert_element_type.default(mul_845, torch.float32); mul_845 = None + neg_29 = torch.ops.aten.neg.default(convert_element_type_92) + exp_29 = torch.ops.aten.exp.default(neg_29); neg_29 = None + add_332 = torch.ops.aten.add.Tensor(exp_29, 1); exp_29 = None + reciprocal_29 = torch.ops.aten.reciprocal.default(add_332); add_332 = None + mul_846 = torch.ops.aten.mul.Tensor(reciprocal_29, 1); reciprocal_29 = None + mul_847 = torch.ops.aten.mul.Tensor(convert_element_type_2650, mul_846); convert_element_type_2650 = None + sub_88 = torch.ops.aten.sub.Tensor(1, mul_846); mul_846 = None + mul_848 = torch.ops.aten.mul.Tensor(convert_element_type_92, sub_88); convert_element_type_92 = sub_88 = None + add_333 = torch.ops.aten.add.Tensor(mul_848, 1); mul_848 = None + mul_849 = torch.ops.aten.mul.Tensor(mul_847, add_333); mul_847 = add_333 = None + convert_element_type_2652 = torch.ops.prims.convert_element_type.default(mul_849, torch.bfloat16); mul_849 = None + view_1795 = torch.ops.aten.view.default(convert_element_type_2652, [16384, 14336]); convert_element_type_2652 = None + permute_1293 = torch.ops.aten.permute.default(view_1795, [1, 0]) + mm_637 = torch.ops.aten.mm.default(permute_1293, view_95); permute_1293 = view_95 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 64, '0'); convert_element_type_89 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + permute_1295 = torch.ops.aten.permute.default(permute_30, [1, 0]); permute_30 = None + mm_638 = torch.ops.aten.mm.default(view_1795, permute_1295); view_1795 = permute_1295 = None + view_1796 = torch.ops.aten.view.default(mm_638, [2, 8192, 4096]); mm_638 = None + add_334 = torch.ops.aten.add.Tensor(view_1794, view_1796); view_1794 = view_1796 = None + convert_element_type_2657 = torch.ops.prims.convert_element_type.default(mm_637, torch.float32); mm_637 = None + reduce_scatter_tensor_265 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2657, 'avg', 64, '0'); convert_element_type_2657 = None + wait_tensor_556 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_265); reduce_scatter_tensor_265 = None + convert_element_type_2658 = torch.ops.prims.convert_element_type.default(add_334, torch.float32); add_334 = None + convert_element_type_2660 = torch.ops.prims.convert_element_type.default(wait_tensor_24, torch.float32); wait_tensor_24 = None + mul_850 = torch.ops.aten.mul.Tensor(convert_element_type_2658, convert_element_type_2660); convert_element_type_2660 = None + mul_852 = torch.ops.aten.mul.Tensor(mul_20, mul_850) + sum_177 = torch.ops.aten.sum.dim_IntList(mul_852, [2], True); mul_852 = None + div_59 = torch.ops.aten.div.Tensor(mul_20, 4096) + mul_853 = torch.ops.aten.mul.Tensor(div_59, sum_177); div_59 = sum_177 = None + sub_89 = torch.ops.aten.sub.Tensor(mul_850, mul_853); mul_850 = mul_853 = None + mul_854 = torch.ops.aten.mul.Tensor(sub_89, rsqrt_5); sub_89 = rsqrt_5 = None + mul_855 = torch.ops.aten.mul.Tensor(convert_element_type_2658, mul_20); convert_element_type_2658 = mul_20 = None + sum_178 = torch.ops.aten.sum.dim_IntList(mul_855, [0, 1]); mul_855 = None + convert_element_type_2661 = torch.ops.prims.convert_element_type.default(mul_854, torch.bfloat16); mul_854 = None + add_335 = torch.ops.aten.add.Tensor(add_331, convert_element_type_2661); add_331 = convert_element_type_2661 = None + convert_element_type_default_6 = torch.ops.prims.convert_element_type.default(sum_178, torch.float32); sum_178 = None + reduce_scatter_tensor_266 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_6, 'avg', 64, '0'); convert_element_type_default_6 = None + wait_tensor_557 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_266); reduce_scatter_tensor_266 = None + view_1797 = torch.ops.aten.view.default(add_335, [16384, 4096]) + permute_1297 = torch.ops.aten.permute.default(view_1797, [1, 0]) + mm_639 = torch.ops.aten.mm.default(permute_1297, view_91); permute_1297 = view_91 = None + permute_1299 = torch.ops.aten.permute.default(permute_29, [1, 0]); permute_29 = None + mm_640 = torch.ops.aten.mm.default(view_1797, permute_1299); view_1797 = permute_1299 = None + view_1798 = torch.ops.aten.view.default(mm_640, [2, 8192, 4096]); mm_640 = None + convert_element_type_2668 = torch.ops.prims.convert_element_type.default(mm_639, torch.float32); mm_639 = None + reduce_scatter_tensor_267 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2668, 'avg', 64, '0'); convert_element_type_2668 = None + wait_tensor_558 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_267); reduce_scatter_tensor_267 = None + view_1799 = torch.ops.aten.view.default(view_1798, [2, 8192, 32, 128]); view_1798 = None + permute_1301 = torch.ops.aten.permute.default(view_1799, [0, 2, 1, 3]); view_1799 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 64, '0'); convert_element_type_67 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32); add_7 = None + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_19) + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + view_71 = torch.ops.aten.view.default(convert_element_type_69, [16384, 4096]); convert_element_type_69 = None + view_72 = torch.ops.aten.view.default(mm_14, [2, 8192, 4096]); mm_14 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 64, '0'); convert_element_type_73 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_21, [1, 0]); wait_tensor_21 = None + mm_15 = torch.ops.aten.mm.default(view_71, permute_23) + view_75 = torch.ops.aten.view.default(mm_15, [2, 8192, 1024]); mm_15 = None + view_78 = torch.ops.aten.view.default(mm_16, [2, 8192, 1024]); mm_16 = None + view_79 = torch.ops.aten.view.default(view_72, [2, 8192, -1, 128]); view_72 = None + view_80 = torch.ops.aten.view.default(view_75, [2, 8192, -1, 128]); view_75 = None + view_81 = torch.ops.aten.view.default(view_78, [2, 8192, -1, 128]); view_78 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_79, torch.float32); view_79 = None + view_82 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 32, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_82); view_82 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_80, torch.float32); view_80 = None + view_83 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 8, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_83); view_83 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_16); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_85 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 32, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_16); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_86 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 8, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_85, torch.bfloat16); view_85 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_86, torch.bfloat16); view_86 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 8, 4, 128]); unsqueeze_4 = None + clone_4 = torch.ops.aten.clone.default(expand_4, memory_format = torch.contiguous_format); expand_4 = None + view_87 = torch.ops.aten.view.default(clone_4, [2, 8192, 32, 128]); clone_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_81, 3); view_81 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 8, 4, 128]); unsqueeze_5 = None + clone_5 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format); expand_5 = None + view_88 = torch.ops.aten.view.default(clone_5, [2, 8192, 32, 128]); clone_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_87, [0, 2, 1, 3]); view_87 = None + permute_27 = torch.ops.aten.permute.default(view_88, [0, 2, 1, 3]); view_88 = None + _scaled_dot_product_cudnn_attention_backward_29 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1301, permute_25, permute_26, permute_27, getitem_18, getitem_19, getitem_24, getitem_25, None, None, None, 8192, 8192, 0.0, True); permute_1301 = permute_25 = permute_26 = permute_27 = getitem_18 = getitem_19 = getitem_24 = getitem_25 = None + getitem_375 = _scaled_dot_product_cudnn_attention_backward_29[0] + getitem_376 = _scaled_dot_product_cudnn_attention_backward_29[1] + getitem_377 = _scaled_dot_product_cudnn_attention_backward_29[2]; _scaled_dot_product_cudnn_attention_backward_29 = None + permute_1302 = torch.ops.aten.permute.default(getitem_377, [0, 2, 1, 3]); getitem_377 = None + permute_1303 = torch.ops.aten.permute.default(getitem_376, [0, 2, 1, 3]); getitem_376 = None + permute_1304 = torch.ops.aten.permute.default(getitem_375, [0, 2, 1, 3]); getitem_375 = None + view_1800 = torch.ops.aten.view.default(permute_1302, [2, 8192, 8, 4, 128]); permute_1302 = None + sum_179 = torch.ops.aten.sum.dim_IntList(view_1800, [3], True); view_1800 = None + squeeze_58 = torch.ops.aten.squeeze.dim(sum_179, 3); sum_179 = None + view_1801 = torch.ops.aten.view.default(permute_1303, [2, 8192, 8, 4, 128]); permute_1303 = None + sum_180 = torch.ops.aten.sum.dim_IntList(view_1801, [3], True); view_1801 = None + squeeze_59 = torch.ops.aten.squeeze.dim(sum_180, 3); sum_180 = None + convert_element_type_2669 = torch.ops.prims.convert_element_type.default(squeeze_59, torch.float32); squeeze_59 = None + convert_element_type_2670 = torch.ops.prims.convert_element_type.default(permute_1304, torch.float32); permute_1304 = None + view_1802 = torch.ops.aten.view.default(convert_element_type_2669, [2, 8192, 8, 64, 2]); convert_element_type_2669 = None + view_as_complex_122 = torch.ops.aten.view_as_complex.default(view_1802); view_1802 = None + mul_856 = torch.ops.aten.mul.Tensor(view_as_complex_122, _conj); view_as_complex_122 = None + view_1803 = torch.ops.aten.view.default(convert_element_type_2670, [2, 8192, 32, 64, 2]); convert_element_type_2670 = None + view_as_complex_123 = torch.ops.aten.view_as_complex.default(view_1803); view_1803 = None + mul_857 = torch.ops.aten.mul.Tensor(view_as_complex_123, _conj); view_as_complex_123 = None + view_as_real_122 = torch.ops.aten.view_as_real.default(mul_856); mul_856 = None + view_1804 = torch.ops.aten.view.default(view_as_real_122, [2, 8192, 8, 128]); view_as_real_122 = None + convert_element_type_2671 = torch.ops.prims.convert_element_type.default(view_1804, torch.bfloat16); view_1804 = None + view_as_real_123 = torch.ops.aten.view_as_real.default(mul_857); mul_857 = None + view_1805 = torch.ops.aten.view.default(view_as_real_123, [2, 8192, 32, 128]); view_as_real_123 = None + convert_element_type_2672 = torch.ops.prims.convert_element_type.default(view_1805, torch.bfloat16); view_1805 = None + view_1806 = torch.ops.aten.view.default(squeeze_58, [2, 8192, 1024]); squeeze_58 = None + view_1807 = torch.ops.aten.view.default(convert_element_type_2671, [2, 8192, 1024]); convert_element_type_2671 = None + view_1808 = torch.ops.aten.view.default(convert_element_type_2672, [2, 8192, 4096]); convert_element_type_2672 = None + view_1809 = torch.ops.aten.view.default(view_1806, [16384, 1024]); view_1806 = None + permute_1305 = torch.ops.aten.permute.default(view_1809, [1, 0]) + mm_641 = torch.ops.aten.mm.default(permute_1305, view_71); permute_1305 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 64, '0'); convert_element_type_76 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_22, [1, 0]); wait_tensor_22 = None + permute_1307 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_642 = torch.ops.aten.mm.default(view_1809, permute_1307); view_1809 = permute_1307 = None + view_1810 = torch.ops.aten.view.default(mm_642, [2, 8192, 4096]); mm_642 = None + convert_element_type_2677 = torch.ops.prims.convert_element_type.default(mm_641, torch.float32); mm_641 = None + reduce_scatter_tensor_268 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2677, 'avg', 64, '0'); convert_element_type_2677 = None + wait_tensor_559 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_268); reduce_scatter_tensor_268 = None + view_1811 = torch.ops.aten.view.default(view_1807, [16384, 1024]); view_1807 = None + permute_1309 = torch.ops.aten.permute.default(view_1811, [1, 0]) + mm_643 = torch.ops.aten.mm.default(permute_1309, view_71); permute_1309 = None + permute_1311 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_644 = torch.ops.aten.mm.default(view_1811, permute_1311); view_1811 = permute_1311 = None + view_1812 = torch.ops.aten.view.default(mm_644, [2, 8192, 4096]); mm_644 = None + add_336 = torch.ops.aten.add.Tensor(view_1810, view_1812); view_1810 = view_1812 = None + convert_element_type_2682 = torch.ops.prims.convert_element_type.default(mm_643, torch.float32); mm_643 = None + reduce_scatter_tensor_269 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2682, 'avg', 64, '0'); convert_element_type_2682 = None + wait_tensor_560 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_269); reduce_scatter_tensor_269 = None + view_1813 = torch.ops.aten.view.default(view_1808, [16384, 4096]); view_1808 = None + permute_1313 = torch.ops.aten.permute.default(view_1813, [1, 0]) + mm_645 = torch.ops.aten.mm.default(permute_1313, view_71); permute_1313 = view_71 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 64, '0'); convert_element_type_70 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + permute_1315 = torch.ops.aten.permute.default(permute_22, [1, 0]); permute_22 = None + mm_646 = torch.ops.aten.mm.default(view_1813, permute_1315); view_1813 = permute_1315 = None + view_1814 = torch.ops.aten.view.default(mm_646, [2, 8192, 4096]); mm_646 = None + add_337 = torch.ops.aten.add.Tensor(add_336, view_1814); add_336 = view_1814 = None + convert_element_type_2687 = torch.ops.prims.convert_element_type.default(mm_645, torch.float32); mm_645 = None + reduce_scatter_tensor_270 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2687, 'avg', 64, '0'); convert_element_type_2687 = None + wait_tensor_561 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_270); reduce_scatter_tensor_270 = None + convert_element_type_2688 = torch.ops.prims.convert_element_type.default(add_337, torch.float32); add_337 = None + convert_element_type_2690 = torch.ops.prims.convert_element_type.default(wait_tensor_19, torch.float32); wait_tensor_19 = None + mul_858 = torch.ops.aten.mul.Tensor(convert_element_type_2688, convert_element_type_2690); convert_element_type_2690 = None + mul_860 = torch.ops.aten.mul.Tensor(mul_16, mul_858) + sum_181 = torch.ops.aten.sum.dim_IntList(mul_860, [2], True); mul_860 = None + div_60 = torch.ops.aten.div.Tensor(mul_16, 4096) + mul_861 = torch.ops.aten.mul.Tensor(div_60, sum_181); div_60 = sum_181 = None + sub_90 = torch.ops.aten.sub.Tensor(mul_858, mul_861); mul_858 = mul_861 = None + mul_862 = torch.ops.aten.mul.Tensor(sub_90, rsqrt_4); sub_90 = rsqrt_4 = None + mul_863 = torch.ops.aten.mul.Tensor(convert_element_type_2688, mul_16); convert_element_type_2688 = mul_16 = None + sum_182 = torch.ops.aten.sum.dim_IntList(mul_863, [0, 1]); mul_863 = None + convert_element_type_2691 = torch.ops.prims.convert_element_type.default(mul_862, torch.bfloat16); mul_862 = None + add_338 = torch.ops.aten.add.Tensor(add_335, convert_element_type_2691); add_335 = convert_element_type_2691 = None + convert_element_type_default_5 = torch.ops.prims.convert_element_type.default(sum_182, torch.float32); sum_182 = None + reduce_scatter_tensor_271 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_5, 'avg', 64, '0'); convert_element_type_default_5 = None + wait_tensor_562 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_271); reduce_scatter_tensor_271 = None + view_1815 = torch.ops.aten.view.default(add_338, [16384, 4096]) + permute_1317 = torch.ops.aten.permute.default(view_1815, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_9, [0, 2, 1, 3]) + view_55 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 64, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_14, [1, 0]); wait_tensor_14 = None + view_57 = torch.ops.aten.view.default(view_55, [16384, 4096]); view_55 = None + mm_10 = torch.ops.aten.mm.default(view_57, permute_18) + view_58 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + add_5 = torch.ops.aten.add.Tensor(add_3, view_58); view_58 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 64, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_15) + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + view_61 = torch.ops.aten.view.default(convert_element_type_55, [16384, 4096]); convert_element_type_55 = None + view_62 = torch.ops.aten.view.default(mm_11, [2, 8192, 14336]); mm_11 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_62, torch.float32); view_62 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 64, '0'); convert_element_type_61 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + mm_12 = torch.ops.aten.mm.default(view_61, permute_20) + view_65 = torch.ops.aten.view.default(mm_12, [2, 8192, 14336]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_65) + view_67 = torch.ops.aten.view.default(mul_15, [16384, 14336]); mul_15 = None + mm_647 = torch.ops.aten.mm.default(permute_1317, view_67); permute_1317 = view_67 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 64, '0'); convert_element_type_64 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + permute_1319 = torch.ops.aten.permute.default(permute_21, [1, 0]); permute_21 = None + mm_648 = torch.ops.aten.mm.default(view_1815, permute_1319); view_1815 = permute_1319 = None + view_1816 = torch.ops.aten.view.default(mm_648, [2, 8192, 14336]); mm_648 = None + convert_element_type_2698 = torch.ops.prims.convert_element_type.default(mm_647, torch.float32); mm_647 = None + reduce_scatter_tensor_272 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2698, 'avg', 64, '0'); convert_element_type_2698 = None + wait_tensor_563 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_272); reduce_scatter_tensor_272 = None + mul_864 = torch.ops.aten.mul.Tensor(view_1816, convert_element_type_60); convert_element_type_60 = None + mul_865 = torch.ops.aten.mul.Tensor(view_1816, view_65); view_1816 = view_65 = None + view_1817 = torch.ops.aten.view.default(mul_864, [16384, 14336]); mul_864 = None + permute_1321 = torch.ops.aten.permute.default(view_1817, [1, 0]) + mm_649 = torch.ops.aten.mm.default(permute_1321, view_61); permute_1321 = None + permute_1323 = torch.ops.aten.permute.default(permute_20, [1, 0]); permute_20 = None + mm_650 = torch.ops.aten.mm.default(view_1817, permute_1323); view_1817 = permute_1323 = None + view_1818 = torch.ops.aten.view.default(mm_650, [2, 8192, 4096]); mm_650 = None + convert_element_type_2703 = torch.ops.prims.convert_element_type.default(mm_649, torch.float32); mm_649 = None + reduce_scatter_tensor_273 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2703, 'avg', 64, '0'); convert_element_type_2703 = None + wait_tensor_564 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_273); reduce_scatter_tensor_273 = None + convert_element_type_2704 = torch.ops.prims.convert_element_type.default(mul_865, torch.float32); mul_865 = None + neg_30 = torch.ops.aten.neg.default(convert_element_type_59) + exp_30 = torch.ops.aten.exp.default(neg_30); neg_30 = None + add_339 = torch.ops.aten.add.Tensor(exp_30, 1); exp_30 = None + reciprocal_30 = torch.ops.aten.reciprocal.default(add_339); add_339 = None + mul_866 = torch.ops.aten.mul.Tensor(reciprocal_30, 1); reciprocal_30 = None + mul_867 = torch.ops.aten.mul.Tensor(convert_element_type_2704, mul_866); convert_element_type_2704 = None + sub_91 = torch.ops.aten.sub.Tensor(1, mul_866); mul_866 = None + mul_868 = torch.ops.aten.mul.Tensor(convert_element_type_59, sub_91); convert_element_type_59 = sub_91 = None + add_340 = torch.ops.aten.add.Tensor(mul_868, 1); mul_868 = None + mul_869 = torch.ops.aten.mul.Tensor(mul_867, add_340); mul_867 = add_340 = None + convert_element_type_2706 = torch.ops.prims.convert_element_type.default(mul_869, torch.bfloat16); mul_869 = None + view_1819 = torch.ops.aten.view.default(convert_element_type_2706, [16384, 14336]); convert_element_type_2706 = None + permute_1325 = torch.ops.aten.permute.default(view_1819, [1, 0]) + mm_651 = torch.ops.aten.mm.default(permute_1325, view_61); permute_1325 = view_61 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 64, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + permute_1327 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_652 = torch.ops.aten.mm.default(view_1819, permute_1327); view_1819 = permute_1327 = None + view_1820 = torch.ops.aten.view.default(mm_652, [2, 8192, 4096]); mm_652 = None + add_341 = torch.ops.aten.add.Tensor(view_1818, view_1820); view_1818 = view_1820 = None + convert_element_type_2711 = torch.ops.prims.convert_element_type.default(mm_651, torch.float32); mm_651 = None + reduce_scatter_tensor_274 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2711, 'avg', 64, '0'); convert_element_type_2711 = None + wait_tensor_565 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_274); reduce_scatter_tensor_274 = None + convert_element_type_2712 = torch.ops.prims.convert_element_type.default(add_341, torch.float32); add_341 = None + convert_element_type_2714 = torch.ops.prims.convert_element_type.default(wait_tensor_15, torch.float32); wait_tensor_15 = None + mul_870 = torch.ops.aten.mul.Tensor(convert_element_type_2712, convert_element_type_2714); convert_element_type_2714 = None + mul_872 = torch.ops.aten.mul.Tensor(mul_12, mul_870) + sum_183 = torch.ops.aten.sum.dim_IntList(mul_872, [2], True); mul_872 = None + div_61 = torch.ops.aten.div.Tensor(mul_12, 4096) + mul_873 = torch.ops.aten.mul.Tensor(div_61, sum_183); div_61 = sum_183 = None + sub_92 = torch.ops.aten.sub.Tensor(mul_870, mul_873); mul_870 = mul_873 = None + mul_874 = torch.ops.aten.mul.Tensor(sub_92, rsqrt_3); sub_92 = rsqrt_3 = None + mul_875 = torch.ops.aten.mul.Tensor(convert_element_type_2712, mul_12); convert_element_type_2712 = mul_12 = None + sum_184 = torch.ops.aten.sum.dim_IntList(mul_875, [0, 1]); mul_875 = None + convert_element_type_2715 = torch.ops.prims.convert_element_type.default(mul_874, torch.bfloat16); mul_874 = None + add_342 = torch.ops.aten.add.Tensor(add_338, convert_element_type_2715); add_338 = convert_element_type_2715 = None + convert_element_type_default_4 = torch.ops.prims.convert_element_type.default(sum_184, torch.float32); sum_184 = None + reduce_scatter_tensor_275 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_4, 'avg', 64, '0'); convert_element_type_default_4 = None + wait_tensor_566 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_275); reduce_scatter_tensor_275 = None + view_1821 = torch.ops.aten.view.default(add_342, [16384, 4096]) + permute_1329 = torch.ops.aten.permute.default(view_1821, [1, 0]) + mm_653 = torch.ops.aten.mm.default(permute_1329, view_57); permute_1329 = view_57 = None + permute_1331 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_654 = torch.ops.aten.mm.default(view_1821, permute_1331); view_1821 = permute_1331 = None + view_1822 = torch.ops.aten.view.default(mm_654, [2, 8192, 4096]); mm_654 = None + convert_element_type_2722 = torch.ops.prims.convert_element_type.default(mm_653, torch.float32); mm_653 = None + reduce_scatter_tensor_276 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2722, 'avg', 64, '0'); convert_element_type_2722 = None + wait_tensor_567 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_276); reduce_scatter_tensor_276 = None + view_1823 = torch.ops.aten.view.default(view_1822, [2, 8192, 32, 128]); view_1822 = None + permute_1333 = torch.ops.aten.permute.default(view_1823, [0, 2, 1, 3]); view_1823 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); primals_13 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 64, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32); add_3 = None + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_10) + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + view_37 = torch.ops.aten.view.default(convert_element_type_36, [16384, 4096]); convert_element_type_36 = None + view_38 = torch.ops.aten.view.default(mm_7, [2, 8192, 4096]); mm_7 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 64, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_8 = torch.ops.aten.mm.default(view_37, permute_12) + view_41 = torch.ops.aten.view.default(mm_8, [2, 8192, 1024]); mm_8 = None + view_44 = torch.ops.aten.view.default(mm_9, [2, 8192, 1024]); mm_9 = None + view_45 = torch.ops.aten.view.default(view_38, [2, 8192, -1, 128]); view_38 = None + view_46 = torch.ops.aten.view.default(view_41, [2, 8192, -1, 128]); view_41 = None + view_47 = torch.ops.aten.view.default(view_44, [2, 8192, -1, 128]); view_44 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_45, torch.float32); view_45 = None + view_48 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 32, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_48); view_48 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_46, torch.float32); view_46 = None + view_49 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 8, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_49); view_49 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_16); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_51 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 32, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_16); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_52 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 8, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_51, torch.bfloat16); view_51 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_52, torch.bfloat16); view_52 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 8, 4, 128]); unsqueeze_2 = None + clone_2 = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format); expand_2 = None + view_53 = torch.ops.aten.view.default(clone_2, [2, 8192, 32, 128]); clone_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_47, 3); view_47 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 8, 4, 128]); unsqueeze_3 = None + clone_3 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None + view_54 = torch.ops.aten.view.default(clone_3, [2, 8192, 32, 128]); clone_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_53, [0, 2, 1, 3]); view_53 = None + permute_16 = torch.ops.aten.permute.default(view_54, [0, 2, 1, 3]); view_54 = None + _scaled_dot_product_cudnn_attention_backward_30 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1333, permute_14, permute_15, permute_16, getitem_9, getitem_10, getitem_15, getitem_16, None, None, None, 8192, 8192, 0.0, True); permute_1333 = permute_14 = permute_15 = permute_16 = getitem_9 = getitem_10 = getitem_15 = getitem_16 = None + getitem_378 = _scaled_dot_product_cudnn_attention_backward_30[0] + getitem_379 = _scaled_dot_product_cudnn_attention_backward_30[1] + getitem_380 = _scaled_dot_product_cudnn_attention_backward_30[2]; _scaled_dot_product_cudnn_attention_backward_30 = None + permute_1334 = torch.ops.aten.permute.default(getitem_380, [0, 2, 1, 3]); getitem_380 = None + permute_1335 = torch.ops.aten.permute.default(getitem_379, [0, 2, 1, 3]); getitem_379 = None + permute_1336 = torch.ops.aten.permute.default(getitem_378, [0, 2, 1, 3]); getitem_378 = None + view_1824 = torch.ops.aten.view.default(permute_1334, [2, 8192, 8, 4, 128]); permute_1334 = None + sum_185 = torch.ops.aten.sum.dim_IntList(view_1824, [3], True); view_1824 = None + squeeze_60 = torch.ops.aten.squeeze.dim(sum_185, 3); sum_185 = None + view_1825 = torch.ops.aten.view.default(permute_1335, [2, 8192, 8, 4, 128]); permute_1335 = None + sum_186 = torch.ops.aten.sum.dim_IntList(view_1825, [3], True); view_1825 = None + squeeze_61 = torch.ops.aten.squeeze.dim(sum_186, 3); sum_186 = None + convert_element_type_2723 = torch.ops.prims.convert_element_type.default(squeeze_61, torch.float32); squeeze_61 = None + convert_element_type_2724 = torch.ops.prims.convert_element_type.default(permute_1336, torch.float32); permute_1336 = None + view_1826 = torch.ops.aten.view.default(convert_element_type_2723, [2, 8192, 8, 64, 2]); convert_element_type_2723 = None + view_as_complex_124 = torch.ops.aten.view_as_complex.default(view_1826); view_1826 = None + mul_876 = torch.ops.aten.mul.Tensor(view_as_complex_124, _conj); view_as_complex_124 = None + view_1827 = torch.ops.aten.view.default(convert_element_type_2724, [2, 8192, 32, 64, 2]); convert_element_type_2724 = None + view_as_complex_125 = torch.ops.aten.view_as_complex.default(view_1827); view_1827 = None + mul_877 = torch.ops.aten.mul.Tensor(view_as_complex_125, _conj); view_as_complex_125 = None + view_as_real_124 = torch.ops.aten.view_as_real.default(mul_876); mul_876 = None + view_1828 = torch.ops.aten.view.default(view_as_real_124, [2, 8192, 8, 128]); view_as_real_124 = None + convert_element_type_2725 = torch.ops.prims.convert_element_type.default(view_1828, torch.bfloat16); view_1828 = None + view_as_real_125 = torch.ops.aten.view_as_real.default(mul_877); mul_877 = None + view_1829 = torch.ops.aten.view.default(view_as_real_125, [2, 8192, 32, 128]); view_as_real_125 = None + convert_element_type_2726 = torch.ops.prims.convert_element_type.default(view_1829, torch.bfloat16); view_1829 = None + view_1830 = torch.ops.aten.view.default(squeeze_60, [2, 8192, 1024]); squeeze_60 = None + view_1831 = torch.ops.aten.view.default(convert_element_type_2725, [2, 8192, 1024]); convert_element_type_2725 = None + view_1832 = torch.ops.aten.view.default(convert_element_type_2726, [2, 8192, 4096]); convert_element_type_2726 = None + view_1833 = torch.ops.aten.view.default(view_1830, [16384, 1024]); view_1830 = None + permute_1337 = torch.ops.aten.permute.default(view_1833, [1, 0]) + mm_655 = torch.ops.aten.mm.default(permute_1337, view_37); permute_1337 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 64, '0'); convert_element_type_43 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + permute_1339 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_656 = torch.ops.aten.mm.default(view_1833, permute_1339); view_1833 = permute_1339 = None + view_1834 = torch.ops.aten.view.default(mm_656, [2, 8192, 4096]); mm_656 = None + convert_element_type_2731 = torch.ops.prims.convert_element_type.default(mm_655, torch.float32); mm_655 = None + reduce_scatter_tensor_277 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2731, 'avg', 64, '0'); convert_element_type_2731 = None + wait_tensor_568 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_277); reduce_scatter_tensor_277 = None + view_1835 = torch.ops.aten.view.default(view_1831, [16384, 1024]); view_1831 = None + permute_1341 = torch.ops.aten.permute.default(view_1835, [1, 0]) + mm_657 = torch.ops.aten.mm.default(permute_1341, view_37); permute_1341 = None + permute_1343 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_658 = torch.ops.aten.mm.default(view_1835, permute_1343); view_1835 = permute_1343 = None + view_1836 = torch.ops.aten.view.default(mm_658, [2, 8192, 4096]); mm_658 = None + add_343 = torch.ops.aten.add.Tensor(view_1834, view_1836); view_1834 = view_1836 = None + convert_element_type_2736 = torch.ops.prims.convert_element_type.default(mm_657, torch.float32); mm_657 = None + reduce_scatter_tensor_278 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2736, 'avg', 64, '0'); convert_element_type_2736 = None + wait_tensor_569 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_278); reduce_scatter_tensor_278 = None + view_1837 = torch.ops.aten.view.default(view_1832, [16384, 4096]); view_1832 = None + permute_1345 = torch.ops.aten.permute.default(view_1837, [1, 0]) + mm_659 = torch.ops.aten.mm.default(permute_1345, view_37); permute_1345 = view_37 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); primals_14 = None + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 64, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + permute_1347 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_660 = torch.ops.aten.mm.default(view_1837, permute_1347); view_1837 = permute_1347 = None + view_1838 = torch.ops.aten.view.default(mm_660, [2, 8192, 4096]); mm_660 = None + add_344 = torch.ops.aten.add.Tensor(add_343, view_1838); add_343 = view_1838 = None + convert_element_type_2741 = torch.ops.prims.convert_element_type.default(mm_659, torch.float32); mm_659 = None + reduce_scatter_tensor_279 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2741, 'avg', 64, '0'); convert_element_type_2741 = None + wait_tensor_570 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_279); reduce_scatter_tensor_279 = None + convert_element_type_2742 = torch.ops.prims.convert_element_type.default(add_344, torch.float32); add_344 = None + convert_element_type_2744 = torch.ops.prims.convert_element_type.default(wait_tensor_10, torch.float32); wait_tensor_10 = None + mul_878 = torch.ops.aten.mul.Tensor(convert_element_type_2742, convert_element_type_2744); convert_element_type_2744 = None + mul_880 = torch.ops.aten.mul.Tensor(mul_8, mul_878) + sum_187 = torch.ops.aten.sum.dim_IntList(mul_880, [2], True); mul_880 = None + div_62 = torch.ops.aten.div.Tensor(mul_8, 4096) + mul_881 = torch.ops.aten.mul.Tensor(div_62, sum_187); div_62 = sum_187 = None + sub_93 = torch.ops.aten.sub.Tensor(mul_878, mul_881); mul_878 = mul_881 = None + mul_882 = torch.ops.aten.mul.Tensor(sub_93, rsqrt_2); sub_93 = rsqrt_2 = None + mul_883 = torch.ops.aten.mul.Tensor(convert_element_type_2742, mul_8); convert_element_type_2742 = mul_8 = None + sum_188 = torch.ops.aten.sum.dim_IntList(mul_883, [0, 1]); mul_883 = None + convert_element_type_2745 = torch.ops.prims.convert_element_type.default(mul_882, torch.bfloat16); mul_882 = None + add_345 = torch.ops.aten.add.Tensor(add_342, convert_element_type_2745); add_342 = convert_element_type_2745 = None + convert_element_type_default_3 = torch.ops.prims.convert_element_type.default(sum_188, torch.float32); sum_188 = None + reduce_scatter_tensor_280 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_3, 'avg', 64, '0'); convert_element_type_default_3 = None + wait_tensor_571 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_280); reduce_scatter_tensor_280 = None + view_1839 = torch.ops.aten.view.default(add_345, [16384, 4096]) + permute_1349 = torch.ops.aten.permute.default(view_1839, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem, [0, 2, 1, 3]) + view_21 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 64, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_23 = torch.ops.aten.view.default(view_21, [16384, 4096]); view_21 = None + mm_3 = torch.ops.aten.mm.default(view_23, permute_7) + view_24 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + add_1 = torch.ops.aten.add.Tensor(embedding, view_24); view_24 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 64, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32); add_1 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_6) + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + view_27 = torch.ops.aten.view.default(convert_element_type_22, [16384, 4096]); convert_element_type_22 = None + view_28 = torch.ops.aten.view.default(mm_4, [2, 8192, 14336]); mm_4 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_28, torch.float32); view_28 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 64, '0'); convert_element_type_28 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + mm_5 = torch.ops.aten.mm.default(view_27, permute_9) + view_31 = torch.ops.aten.view.default(mm_5, [2, 8192, 14336]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_31) + view_33 = torch.ops.aten.view.default(mul_7, [16384, 14336]); mul_7 = None + mm_661 = torch.ops.aten.mm.default(permute_1349, view_33); permute_1349 = view_33 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 64, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + permute_1351 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_662 = torch.ops.aten.mm.default(view_1839, permute_1351); view_1839 = permute_1351 = None + view_1840 = torch.ops.aten.view.default(mm_662, [2, 8192, 14336]); mm_662 = None + convert_element_type_2752 = torch.ops.prims.convert_element_type.default(mm_661, torch.float32); mm_661 = None + reduce_scatter_tensor_281 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2752, 'avg', 64, '0'); convert_element_type_2752 = None + wait_tensor_572 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_281); reduce_scatter_tensor_281 = None + mul_884 = torch.ops.aten.mul.Tensor(view_1840, convert_element_type_27); convert_element_type_27 = None + mul_885 = torch.ops.aten.mul.Tensor(view_1840, view_31); view_1840 = view_31 = None + view_1841 = torch.ops.aten.view.default(mul_884, [16384, 14336]); mul_884 = None + permute_1353 = torch.ops.aten.permute.default(view_1841, [1, 0]) + mm_663 = torch.ops.aten.mm.default(permute_1353, view_27); permute_1353 = None + permute_1355 = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_664 = torch.ops.aten.mm.default(view_1841, permute_1355); view_1841 = permute_1355 = None + view_1842 = torch.ops.aten.view.default(mm_664, [2, 8192, 4096]); mm_664 = None + convert_element_type_2757 = torch.ops.prims.convert_element_type.default(mm_663, torch.float32); mm_663 = None + reduce_scatter_tensor_282 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2757, 'avg', 64, '0'); convert_element_type_2757 = None + wait_tensor_573 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_282); reduce_scatter_tensor_282 = None + convert_element_type_2758 = torch.ops.prims.convert_element_type.default(mul_885, torch.float32); mul_885 = None + neg_31 = torch.ops.aten.neg.default(convert_element_type_26) + exp_31 = torch.ops.aten.exp.default(neg_31); neg_31 = None + add_346 = torch.ops.aten.add.Tensor(exp_31, 1); exp_31 = None + reciprocal_31 = torch.ops.aten.reciprocal.default(add_346); add_346 = None + mul_886 = torch.ops.aten.mul.Tensor(reciprocal_31, 1); reciprocal_31 = None + mul_887 = torch.ops.aten.mul.Tensor(convert_element_type_2758, mul_886); convert_element_type_2758 = None + sub_94 = torch.ops.aten.sub.Tensor(1, mul_886); mul_886 = None + mul_888 = torch.ops.aten.mul.Tensor(convert_element_type_26, sub_94); convert_element_type_26 = sub_94 = None + add_347 = torch.ops.aten.add.Tensor(mul_888, 1); mul_888 = None + mul_889 = torch.ops.aten.mul.Tensor(mul_887, add_347); mul_887 = add_347 = None + convert_element_type_2760 = torch.ops.prims.convert_element_type.default(mul_889, torch.bfloat16); mul_889 = None + view_1843 = torch.ops.aten.view.default(convert_element_type_2760, [16384, 14336]); convert_element_type_2760 = None + permute_1357 = torch.ops.aten.permute.default(view_1843, [1, 0]) + mm_665 = torch.ops.aten.mm.default(permute_1357, view_27); permute_1357 = view_27 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 64, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + permute_1359 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_666 = torch.ops.aten.mm.default(view_1843, permute_1359); view_1843 = permute_1359 = None + view_1844 = torch.ops.aten.view.default(mm_666, [2, 8192, 4096]); mm_666 = None + add_348 = torch.ops.aten.add.Tensor(view_1842, view_1844); view_1842 = view_1844 = None + convert_element_type_2765 = torch.ops.prims.convert_element_type.default(mm_665, torch.float32); mm_665 = None + reduce_scatter_tensor_283 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2765, 'avg', 64, '0'); convert_element_type_2765 = None + wait_tensor_574 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_283); reduce_scatter_tensor_283 = None + convert_element_type_2766 = torch.ops.prims.convert_element_type.default(add_348, torch.float32); add_348 = None + convert_element_type_2768 = torch.ops.prims.convert_element_type.default(wait_tensor_6, torch.float32); wait_tensor_6 = None + mul_890 = torch.ops.aten.mul.Tensor(convert_element_type_2766, convert_element_type_2768); convert_element_type_2768 = None + mul_892 = torch.ops.aten.mul.Tensor(mul_4, mul_890) + sum_189 = torch.ops.aten.sum.dim_IntList(mul_892, [2], True); mul_892 = None + div_63 = torch.ops.aten.div.Tensor(mul_4, 4096) + mul_893 = torch.ops.aten.mul.Tensor(div_63, sum_189); div_63 = sum_189 = None + sub_95 = torch.ops.aten.sub.Tensor(mul_890, mul_893); mul_890 = mul_893 = None + mul_894 = torch.ops.aten.mul.Tensor(sub_95, rsqrt_1); sub_95 = rsqrt_1 = None + mul_895 = torch.ops.aten.mul.Tensor(convert_element_type_2766, mul_4); convert_element_type_2766 = mul_4 = None + sum_190 = torch.ops.aten.sum.dim_IntList(mul_895, [0, 1]); mul_895 = None + convert_element_type_2769 = torch.ops.prims.convert_element_type.default(mul_894, torch.bfloat16); mul_894 = None + add_349 = torch.ops.aten.add.Tensor(add_345, convert_element_type_2769); add_345 = convert_element_type_2769 = None + convert_element_type_default_2 = torch.ops.prims.convert_element_type.default(sum_190, torch.float32); sum_190 = None + reduce_scatter_tensor_284 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_2, 'avg', 64, '0'); convert_element_type_default_2 = None + wait_tensor_575 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_284); reduce_scatter_tensor_284 = None + view_1845 = torch.ops.aten.view.default(add_349, [16384, 4096]) + permute_1361 = torch.ops.aten.permute.default(view_1845, [1, 0]) + mm_667 = torch.ops.aten.mm.default(permute_1361, view_23); permute_1361 = view_23 = None + permute_1363 = torch.ops.aten.permute.default(permute_7, [1, 0]); permute_7 = None + mm_668 = torch.ops.aten.mm.default(view_1845, permute_1363); view_1845 = permute_1363 = None + view_1846 = torch.ops.aten.view.default(mm_668, [2, 8192, 4096]); mm_668 = None + convert_element_type_2776 = torch.ops.prims.convert_element_type.default(mm_667, torch.float32); mm_667 = None + reduce_scatter_tensor_285 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2776, 'avg', 64, '0'); convert_element_type_2776 = None + wait_tensor_576 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_285); reduce_scatter_tensor_285 = None + view_1847 = torch.ops.aten.view.default(view_1846, [2, 8192, 32, 128]); view_1846 = None + permute_1365 = torch.ops.aten.permute.default(view_1847, [0, 2, 1, 3]); view_1847 = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 64, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32); embedding = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1) + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [16384, 4096]); convert_element_type_3 = None + view_4 = torch.ops.aten.view.default(mm, [2, 8192, 4096]); mm = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 64, '0'); convert_element_type_7 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1) + view_7 = torch.ops.aten.view.default(mm_1, [2, 8192, 1024]); mm_1 = None + view_10 = torch.ops.aten.view.default(mm_2, [2, 8192, 1024]); mm_2 = None + view_11 = torch.ops.aten.view.default(view_4, [2, 8192, -1, 128]); view_4 = None + view_12 = torch.ops.aten.view.default(view_7, [2, 8192, -1, 128]); view_7 = None + view_13 = torch.ops.aten.view.default(view_10, [2, 8192, -1, 128]); view_10 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None + view_14 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 32, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_14); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_12, torch.float32); view_12 = None + view_15 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 8, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_15); view_15 = None + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_16); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_17 = torch.ops.aten.view.default(view_as_real, [2, 8192, 32, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_16); view_as_complex_1 = view_16 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_18 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 8, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_17, torch.bfloat16); view_17 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_18, torch.bfloat16); view_18 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 8, 4, 128]); unsqueeze = None + clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None + view_19 = torch.ops.aten.view.default(clone, [2, 8192, 32, 128]); clone = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_13, 3); view_13 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 8, 4, 128]); unsqueeze_1 = None + clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None + view_20 = torch.ops.aten.view.default(clone_1, [2, 8192, 32, 128]); clone_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]); view_19 = None + permute_5 = torch.ops.aten.permute.default(view_20, [0, 2, 1, 3]); view_20 = None + _scaled_dot_product_cudnn_attention_backward_31 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_1365, permute_3, permute_4, permute_5, getitem, getitem_1, getitem_6, getitem_7, None, None, None, 8192, 8192, 0.0, True); permute_1365 = permute_3 = permute_4 = permute_5 = getitem = getitem_1 = getitem_6 = getitem_7 = None + getitem_381 = _scaled_dot_product_cudnn_attention_backward_31[0] + getitem_382 = _scaled_dot_product_cudnn_attention_backward_31[1] + getitem_383 = _scaled_dot_product_cudnn_attention_backward_31[2]; _scaled_dot_product_cudnn_attention_backward_31 = None + permute_1366 = torch.ops.aten.permute.default(getitem_383, [0, 2, 1, 3]); getitem_383 = None + permute_1367 = torch.ops.aten.permute.default(getitem_382, [0, 2, 1, 3]); getitem_382 = None + permute_1368 = torch.ops.aten.permute.default(getitem_381, [0, 2, 1, 3]); getitem_381 = None + view_1848 = torch.ops.aten.view.default(permute_1366, [2, 8192, 8, 4, 128]); permute_1366 = None + sum_191 = torch.ops.aten.sum.dim_IntList(view_1848, [3], True); view_1848 = None + squeeze_62 = torch.ops.aten.squeeze.dim(sum_191, 3); sum_191 = None + view_1849 = torch.ops.aten.view.default(permute_1367, [2, 8192, 8, 4, 128]); permute_1367 = None + sum_192 = torch.ops.aten.sum.dim_IntList(view_1849, [3], True); view_1849 = None + squeeze_63 = torch.ops.aten.squeeze.dim(sum_192, 3); sum_192 = None + convert_element_type_2777 = torch.ops.prims.convert_element_type.default(squeeze_63, torch.float32); squeeze_63 = None + convert_element_type_2778 = torch.ops.prims.convert_element_type.default(permute_1368, torch.float32); permute_1368 = None + view_1850 = torch.ops.aten.view.default(convert_element_type_2777, [2, 8192, 8, 64, 2]); convert_element_type_2777 = None + view_as_complex_126 = torch.ops.aten.view_as_complex.default(view_1850); view_1850 = None + mul_896 = torch.ops.aten.mul.Tensor(view_as_complex_126, _conj); view_as_complex_126 = None + view_1851 = torch.ops.aten.view.default(convert_element_type_2778, [2, 8192, 32, 64, 2]); convert_element_type_2778 = None + view_as_complex_127 = torch.ops.aten.view_as_complex.default(view_1851); view_1851 = None + mul_897 = torch.ops.aten.mul.Tensor(view_as_complex_127, _conj); view_as_complex_127 = _conj = None + view_as_real_126 = torch.ops.aten.view_as_real.default(mul_896); mul_896 = None + view_1852 = torch.ops.aten.view.default(view_as_real_126, [2, 8192, 8, 128]); view_as_real_126 = None + convert_element_type_2779 = torch.ops.prims.convert_element_type.default(view_1852, torch.bfloat16); view_1852 = None + view_as_real_127 = torch.ops.aten.view_as_real.default(mul_897); mul_897 = None + view_1853 = torch.ops.aten.view.default(view_as_real_127, [2, 8192, 32, 128]); view_as_real_127 = None + convert_element_type_2780 = torch.ops.prims.convert_element_type.default(view_1853, torch.bfloat16); view_1853 = None + view_1854 = torch.ops.aten.view.default(squeeze_62, [2, 8192, 1024]); squeeze_62 = None + view_1855 = torch.ops.aten.view.default(convert_element_type_2779, [2, 8192, 1024]); convert_element_type_2779 = None + view_1856 = torch.ops.aten.view.default(convert_element_type_2780, [2, 8192, 4096]); convert_element_type_2780 = None + view_1857 = torch.ops.aten.view.default(view_1854, [16384, 1024]); view_1854 = None + permute_1369 = torch.ops.aten.permute.default(view_1857, [1, 0]) + mm_669 = torch.ops.aten.mm.default(permute_1369, view_3); permute_1369 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 64, '0'); convert_element_type_10 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + permute_1371 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_670 = torch.ops.aten.mm.default(view_1857, permute_1371); view_1857 = permute_1371 = None + view_1858 = torch.ops.aten.view.default(mm_670, [2, 8192, 4096]); mm_670 = None + convert_element_type_2785 = torch.ops.prims.convert_element_type.default(mm_669, torch.float32); mm_669 = None + reduce_scatter_tensor_286 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2785, 'avg', 64, '0'); convert_element_type_2785 = None + wait_tensor_577 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_286); reduce_scatter_tensor_286 = None + view_1859 = torch.ops.aten.view.default(view_1855, [16384, 1024]); view_1855 = None + permute_1373 = torch.ops.aten.permute.default(view_1859, [1, 0]) + mm_671 = torch.ops.aten.mm.default(permute_1373, view_3); permute_1373 = None + permute_1375 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_672 = torch.ops.aten.mm.default(view_1859, permute_1375); view_1859 = permute_1375 = None + view_1860 = torch.ops.aten.view.default(mm_672, [2, 8192, 4096]); mm_672 = None + add_350 = torch.ops.aten.add.Tensor(view_1858, view_1860); view_1858 = view_1860 = None + convert_element_type_2790 = torch.ops.prims.convert_element_type.default(mm_671, torch.float32); mm_671 = None + reduce_scatter_tensor_287 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2790, 'avg', 64, '0'); convert_element_type_2790 = None + wait_tensor_578 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_287); reduce_scatter_tensor_287 = None + view_1861 = torch.ops.aten.view.default(view_1856, [16384, 4096]); view_1856 = None + permute_1377 = torch.ops.aten.permute.default(view_1861, [1, 0]) + mm_673 = torch.ops.aten.mm.default(permute_1377, view_3); permute_1377 = view_3 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 64, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + permute_1379 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_674 = torch.ops.aten.mm.default(view_1861, permute_1379); view_1861 = permute_1379 = None + view_1862 = torch.ops.aten.view.default(mm_674, [2, 8192, 4096]); mm_674 = None + add_351 = torch.ops.aten.add.Tensor(add_350, view_1862); add_350 = view_1862 = None + convert_element_type_2795 = torch.ops.prims.convert_element_type.default(mm_673, torch.float32); mm_673 = None + reduce_scatter_tensor_288 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_2795, 'avg', 64, '0'); convert_element_type_2795 = None + wait_tensor_579 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_288); reduce_scatter_tensor_288 = None + convert_element_type_2796 = torch.ops.prims.convert_element_type.default(add_351, torch.float32); add_351 = None + convert_element_type_2798 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + mul_898 = torch.ops.aten.mul.Tensor(convert_element_type_2796, convert_element_type_2798); convert_element_type_2798 = None + mul_900 = torch.ops.aten.mul.Tensor(mul, mul_898) + sum_193 = torch.ops.aten.sum.dim_IntList(mul_900, [2], True); mul_900 = None + div_64 = torch.ops.aten.div.Tensor(mul, 4096) + mul_901 = torch.ops.aten.mul.Tensor(div_64, sum_193); div_64 = sum_193 = None + sub_96 = torch.ops.aten.sub.Tensor(mul_898, mul_901); mul_898 = mul_901 = None + mul_902 = torch.ops.aten.mul.Tensor(sub_96, rsqrt); sub_96 = rsqrt = None + mul_903 = torch.ops.aten.mul.Tensor(convert_element_type_2796, mul); convert_element_type_2796 = mul = None + sum_194 = torch.ops.aten.sum.dim_IntList(mul_903, [0, 1]); mul_903 = None + convert_element_type_2799 = torch.ops.prims.convert_element_type.default(mul_902, torch.bfloat16); mul_902 = None + add_352 = torch.ops.aten.add.Tensor(add_349, convert_element_type_2799); add_349 = convert_element_type_2799 = None + convert_element_type_default_1 = torch.ops.prims.convert_element_type.default(sum_194, torch.float32); sum_194 = None + reduce_scatter_tensor_289 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default_1, 'avg', 64, '0'); convert_element_type_default_1 = None + wait_tensor_580 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_289); reduce_scatter_tensor_289 = None + convert_element_type_2802 = torch.ops.prims.convert_element_type.default(add_352, torch.float32); add_352 = None + eq = torch.ops.aten.eq.Scalar(primals_2, -1) + unsqueeze_64 = torch.ops.aten.unsqueeze.default(eq, -1); eq = None + full_default = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_64, full_default, convert_element_type_2802); unsqueeze_64 = full_default = convert_element_type_2802 = None + full_default_1 = torch.ops.aten.full.default([128256, 4096], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put = torch.ops.aten.index_put.default(full_default_1, [primals_2], where, True); full_default_1 = primals_2 = where = None + convert_element_type_default = torch.ops.prims.convert_element_type.default(index_put, torch.float32); index_put = None + reduce_scatter_tensor_290 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_default, 'avg', 64, '0'); convert_element_type_default = None + wait_tensor_581 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_290); reduce_scatter_tensor_290 = None + return (wait_tensor_581, None, None, wait_tensor_580, wait_tensor_579, wait_tensor_578, wait_tensor_577, wait_tensor_576, wait_tensor_575, wait_tensor_574, wait_tensor_573, wait_tensor_572, wait_tensor_571, wait_tensor_570, wait_tensor_569, wait_tensor_568, wait_tensor_567, wait_tensor_566, wait_tensor_565, wait_tensor_564, wait_tensor_563, wait_tensor_562, wait_tensor_561, wait_tensor_560, wait_tensor_559, wait_tensor_558, wait_tensor_557, wait_tensor_556, wait_tensor_555, wait_tensor_554, wait_tensor_553, wait_tensor_552, wait_tensor_551, wait_tensor_550, wait_tensor_549, wait_tensor_548, wait_tensor_547, wait_tensor_546, wait_tensor_545, wait_tensor_544, wait_tensor_543, wait_tensor_542, wait_tensor_541, wait_tensor_540, wait_tensor_539, wait_tensor_538, wait_tensor_537, wait_tensor_536, wait_tensor_535, wait_tensor_534, wait_tensor_533, wait_tensor_532, wait_tensor_531, wait_tensor_530, wait_tensor_529, wait_tensor_528, wait_tensor_527, wait_tensor_526, wait_tensor_525, wait_tensor_524, wait_tensor_523, wait_tensor_522, wait_tensor_521, wait_tensor_520, wait_tensor_519, wait_tensor_518, wait_tensor_517, wait_tensor_516, wait_tensor_515, wait_tensor_514, wait_tensor_513, wait_tensor_512, wait_tensor_511, wait_tensor_510, wait_tensor_509, wait_tensor_508, wait_tensor_507, wait_tensor_506, wait_tensor_505, wait_tensor_504, wait_tensor_503, wait_tensor_502, wait_tensor_501, wait_tensor_500, wait_tensor_499, wait_tensor_498, wait_tensor_497, wait_tensor_496, wait_tensor_495, wait_tensor_494, wait_tensor_493, wait_tensor_492, wait_tensor_491, wait_tensor_490, wait_tensor_489, wait_tensor_488, wait_tensor_487, wait_tensor_486, wait_tensor_485, wait_tensor_484, wait_tensor_483, wait_tensor_482, wait_tensor_481, wait_tensor_480, wait_tensor_479, wait_tensor_478, wait_tensor_477, wait_tensor_476, wait_tensor_475, wait_tensor_474, wait_tensor_473, wait_tensor_472, wait_tensor_471, wait_tensor_470, wait_tensor_469, wait_tensor_468, wait_tensor_467, wait_tensor_466, wait_tensor_465, wait_tensor_464, wait_tensor_463, wait_tensor_462, wait_tensor_461, wait_tensor_460, wait_tensor_459, wait_tensor_458, wait_tensor_457, wait_tensor_456, wait_tensor_455, wait_tensor_454, wait_tensor_453, wait_tensor_452, wait_tensor_451, wait_tensor_450, wait_tensor_449, wait_tensor_448, wait_tensor_447, wait_tensor_446, wait_tensor_445, wait_tensor_444, wait_tensor_443, wait_tensor_442, wait_tensor_441, wait_tensor_440, wait_tensor_439, wait_tensor_438, wait_tensor_437, wait_tensor_436, wait_tensor_435, wait_tensor_434, wait_tensor_433, wait_tensor_432, wait_tensor_431, wait_tensor_430, wait_tensor_429, wait_tensor_428, wait_tensor_427, wait_tensor_426, wait_tensor_425, wait_tensor_424, wait_tensor_423, wait_tensor_422, wait_tensor_421, wait_tensor_420, wait_tensor_419, wait_tensor_418, wait_tensor_417, wait_tensor_416, wait_tensor_415, wait_tensor_414, wait_tensor_413, wait_tensor_412, wait_tensor_411, wait_tensor_410, wait_tensor_409, wait_tensor_408, wait_tensor_407, wait_tensor_406, wait_tensor_405, wait_tensor_404, wait_tensor_403, wait_tensor_402, wait_tensor_401, wait_tensor_400, wait_tensor_399, wait_tensor_398, wait_tensor_397, wait_tensor_396, wait_tensor_395, wait_tensor_394, wait_tensor_393, wait_tensor_392, wait_tensor_391, wait_tensor_390, wait_tensor_389, wait_tensor_388, wait_tensor_387, wait_tensor_386, wait_tensor_385, wait_tensor_384, wait_tensor_383, wait_tensor_382, wait_tensor_381, wait_tensor_380, wait_tensor_379, wait_tensor_378, wait_tensor_377, wait_tensor_376, wait_tensor_375, wait_tensor_374, wait_tensor_373, wait_tensor_372, wait_tensor_371, wait_tensor_370, wait_tensor_369, wait_tensor_368, wait_tensor_367, wait_tensor_366, wait_tensor_365, wait_tensor_364, wait_tensor_363, wait_tensor_362, wait_tensor_361, wait_tensor_360, wait_tensor_359, wait_tensor_358, wait_tensor_357, wait_tensor_356, wait_tensor_355, wait_tensor_354, wait_tensor_353, wait_tensor_352, wait_tensor_351, wait_tensor_350, wait_tensor_349, wait_tensor_348, wait_tensor_347, wait_tensor_346, wait_tensor_345, wait_tensor_344, wait_tensor_343, wait_tensor_342, wait_tensor_341, wait_tensor_340, wait_tensor_339, wait_tensor_338, wait_tensor_337, wait_tensor_336, wait_tensor_335, wait_tensor_334, wait_tensor_333, wait_tensor_332, wait_tensor_331, wait_tensor_330, wait_tensor_329, wait_tensor_328, wait_tensor_327, wait_tensor_326, wait_tensor_325, wait_tensor_324, wait_tensor_323, wait_tensor_322, wait_tensor_321, wait_tensor_320, wait_tensor_319, wait_tensor_318, wait_tensor_317, wait_tensor_316, wait_tensor_315, wait_tensor_314, wait_tensor_313, wait_tensor_312, wait_tensor_311, wait_tensor_310, wait_tensor_309, wait_tensor_308, wait_tensor_307, wait_tensor_306, wait_tensor_305, wait_tensor_304, wait_tensor_303, wait_tensor_302, wait_tensor_301, wait_tensor_300, wait_tensor_299, wait_tensor_298, wait_tensor_297, wait_tensor_296, wait_tensor_295, wait_tensor_294, wait_tensor_293, wait_tensor_292, wait_tensor_291) + +def load_args(reader): + buf0 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf0, (2004, 4096), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf3, (64,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf4, (64, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf5, (16, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf6, (16, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf7, (64, 4096), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf8, (64,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf9, (224, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf10, (224, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf11, (64, 14336), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf12, (64,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf13, (64, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf14, (16, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf15, (16, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf16, (64, 4096), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf17, (64,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf18, (224, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf19, (224, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf20, (64, 14336), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf21, (64,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf22, (64, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf23, (16, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf24, (16, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf25, (64, 4096), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf26, (64,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf27, (224, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf28, (224, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf29, (64, 14336), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf30, (64,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf31, (64, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf32, (16, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf33, (16, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf34, (64, 4096), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf35, (64,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf36, (224, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf37, (224, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf38, (64, 14336), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf39, (64,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf40, (64, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf41, (16, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf42, (16, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf43, (64, 4096), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf44, (64,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf45, (224, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf46, (224, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf47, (64, 14336), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf48, (64,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf49, (64, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf50, (16, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf51, (16, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf52, (64, 4096), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf53, (64,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf54, (224, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf55, (224, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf56, (64, 14336), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf57, (64,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf58, (64, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf59, (16, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf60, (16, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf61, (64, 4096), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf62, (64,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf63, (224, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf64, (224, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf65, (64, 14336), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf66, (64,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf67, (64, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf68, (16, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf70, (64, 4096), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf71, (64,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf72, (224, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf73, (224, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf74, (64, 14336), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf75, (64,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf76, (64, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf77, (16, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf78, (16, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf79, (64, 4096), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf80, (64,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf81, (224, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf82, (224, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf83, (64, 14336), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf84, (64,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf85, (64, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf86, (16, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf87, (16, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf88, (64, 4096), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf89, (64,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf90, (224, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf91, (224, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf92, (64, 14336), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf93, (64,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf94, (64, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf95, (16, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf96, (16, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf97, (64, 4096), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf98, (64,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf99, (224, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf100, (224, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf101, (64, 14336), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf102, (64,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf103, (64, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf104, (16, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf105, (16, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf106, (64, 4096), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf107, (64,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf108, (224, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf109, (224, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf110, (64, 14336), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf111, (64,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf112, (64, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf113, (16, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf114, (16, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf115, (64, 4096), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf116, (64,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf117, (224, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf118, (224, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf119, (64, 14336), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf120, (64,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf121, (64, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf122, (16, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf123, (16, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf124, (64, 4096), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf125, (64,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf126, (224, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf127, (224, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf128, (64, 14336), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf129, (64,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf130, (64, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf131, (16, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf132, (16, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf133, (64, 4096), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf134, (64,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf135, (224, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf136, (224, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf137, (64, 14336), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf138, (64,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf139, (64, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf141, (16, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf142, (64, 4096), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf143, (64,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf144, (224, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf145, (224, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf146, (64, 14336), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf147, (64,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf148, (64, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf149, (16, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf150, (16, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf151, (64, 4096), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf152, (64,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf153, (224, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf154, (224, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf155, (64, 14336), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf156, (64,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf157, (64, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf158, (16, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf159, (16, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf160, (64, 4096), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf161, (64,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf162, (224, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf163, (224, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf164, (64, 14336), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf165, (64,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf166, (64, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf167, (16, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf168, (16, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf169, (64, 4096), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf170, (64,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf171, (224, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf172, (224, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf173, (64, 14336), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf174, (64,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf175, (64, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf176, (16, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf177, (16, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf178, (64, 4096), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf179, (64,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf180, (224, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf181, (224, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf182, (64, 14336), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf183, (64,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf184, (64, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf185, (16, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf186, (16, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf187, (64, 4096), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf188, (64,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf189, (224, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf190, (224, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf191, (64, 14336), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf192, (64,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf193, (64, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf194, (16, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf195, (16, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf196, (64, 4096), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf197, (64,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf198, (224, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf199, (224, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf200, (64, 14336), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf201, (64,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf202, (64, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf203, (16, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf204, (16, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf205, (64, 4096), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf206, (64,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf207, (224, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf208, (224, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf209, (64, 14336), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf210, (64,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf211, (64, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf212, (16, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf213, (16, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf214, (64, 4096), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf215, (64,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf216, (224, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf217, (224, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf218, (64, 14336), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf219, (64,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf220, (64, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf221, (16, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf222, (16, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf223, (64, 4096), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf224, (64,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf225, (224, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf226, (224, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf227, (64, 14336), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf228, (64,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf229, (64, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf230, (16, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf231, (16, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf232, (64, 4096), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf233, (64,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf234, (224, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf235, (224, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf236, (64, 14336), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf237, (64,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf238, (64, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf239, (16, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf240, (16, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf241, (64, 4096), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf242, (64,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf243, (224, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf244, (224, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf245, (64, 14336), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf246, (64,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf247, (64, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf248, (16, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf249, (16, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf250, (64, 4096), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf251, (64,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf252, (224, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf253, (224, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf254, (64, 14336), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf255, (64,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf256, (64, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf257, (16, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf258, (16, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf259, (64, 4096), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf260, (64,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf261, (224, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf262, (224, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf263, (64, 14336), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf264, (64,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf265, (64, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf266, (16, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf267, (16, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf268, (64, 4096), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf269, (64,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf270, (224, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf271, (224, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf272, (64, 14336), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf273, (64,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf274, (64, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf275, (16, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf276, (16, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf277, (64, 4096), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf278, (64,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf279, (224, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf280, (224, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf281, (64, 14336), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf282, (64,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf283, (64, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf284, (16, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf285, (16, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf286, (64, 4096), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf287, (64,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf288, (224, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf289, (224, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf290, (64, 14336), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf291, (64,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf292, (2004, 4096), is_leaf=True) # primals_293 + buf293 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf293, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # embedding + buf294 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf294, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm + buf295 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf295, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_2 + buf296 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf296, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem + buf297 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf297, (2, 32, 8192, 1), is_leaf=True) # getitem_1 + buf298 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf298, (), dtype=torch.int64, is_leaf=True) # getitem_6 + buf299 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf299, (), dtype=torch.int64, is_leaf=True) # getitem_7 + buf300 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf300, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf301 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf301, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_3 + buf302 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf302, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_7 + buf303 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf303, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_9 + buf304 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf304, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_9 + buf305 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf305, (2, 32, 8192, 1), is_leaf=True) # getitem_10 + buf306 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf306, (), dtype=torch.int64, is_leaf=True) # getitem_15 + buf307 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf307, (), dtype=torch.int64, is_leaf=True) # getitem_16 + buf308 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf308, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf309 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf309, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_7 + buf310 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf310, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_14 + buf311 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf311, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_16 + buf312 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf312, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_18 + buf313 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf313, (2, 32, 8192, 1), is_leaf=True) # getitem_19 + buf314 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf314, (), dtype=torch.int64, is_leaf=True) # getitem_24 + buf315 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf315, (), dtype=torch.int64, is_leaf=True) # getitem_25 + buf316 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf316, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_18 + buf317 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf317, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_11 + buf318 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf318, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf319 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf319, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_23 + buf320 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf320, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_27 + buf321 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf321, (2, 32, 8192, 1), is_leaf=True) # getitem_28 + buf322 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf322, (), dtype=torch.int64, is_leaf=True) # getitem_33 + buf323 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf323, (), dtype=torch.int64, is_leaf=True) # getitem_34 + buf324 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf324, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_25 + buf325 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf325, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_15 + buf326 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf326, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf327 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf327, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_30 + buf328 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf328, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_36 + buf329 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf329, (2, 32, 8192, 1), is_leaf=True) # getitem_37 + buf330 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf330, (), dtype=torch.int64, is_leaf=True) # getitem_42 + buf331 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf331, (), dtype=torch.int64, is_leaf=True) # getitem_43 + buf332 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf332, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_32 + buf333 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf333, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_19 + buf334 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf334, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf335 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf335, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf336 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf336, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_45 + buf337 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf337, (2, 32, 8192, 1), is_leaf=True) # getitem_46 + buf338 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf338, (), dtype=torch.int64, is_leaf=True) # getitem_51 + buf339 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf339, (), dtype=torch.int64, is_leaf=True) # getitem_52 + buf340 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf340, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_39 + buf341 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf341, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_23 + buf342 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf342, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_42 + buf343 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf343, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf344 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf344, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_54 + buf345 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf345, (2, 32, 8192, 1), is_leaf=True) # getitem_55 + buf346 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf346, (), dtype=torch.int64, is_leaf=True) # getitem_60 + buf347 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf347, (), dtype=torch.int64, is_leaf=True) # getitem_61 + buf348 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf348, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_46 + buf349 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf349, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_27 + buf350 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf350, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_49 + buf351 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf351, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf352 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf352, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_63 + buf353 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf353, (2, 32, 8192, 1), is_leaf=True) # getitem_64 + buf354 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf354, (), dtype=torch.int64, is_leaf=True) # getitem_69 + buf355 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf355, (), dtype=torch.int64, is_leaf=True) # getitem_70 + buf356 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf356, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf357 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf357, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_31 + buf358 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf358, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_56 + buf359 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf359, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_58 + buf360 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf360, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_72 + buf361 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf361, (2, 32, 8192, 1), is_leaf=True) # getitem_73 + buf362 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf362, (), dtype=torch.int64, is_leaf=True) # getitem_78 + buf363 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf363, (), dtype=torch.int64, is_leaf=True) # getitem_79 + buf364 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf364, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf365 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf365, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_35 + buf366 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf366, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_63 + buf367 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf367, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_65 + buf368 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf368, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_81 + buf369 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf369, (2, 32, 8192, 1), is_leaf=True) # getitem_82 + buf370 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf370, (), dtype=torch.int64, is_leaf=True) # getitem_87 + buf371 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf371, (), dtype=torch.int64, is_leaf=True) # getitem_88 + buf372 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf372, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf373 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf373, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_39 + buf374 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf374, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_70 + buf375 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf375, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_72 + buf376 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf376, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_90 + buf377 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf377, (2, 32, 8192, 1), is_leaf=True) # getitem_91 + buf378 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf378, (), dtype=torch.int64, is_leaf=True) # getitem_96 + buf379 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf379, (), dtype=torch.int64, is_leaf=True) # getitem_97 + buf380 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf380, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_74 + buf381 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf381, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_43 + buf382 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf382, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf383 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf383, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_79 + buf384 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf384, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_99 + buf385 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf385, (2, 32, 8192, 1), is_leaf=True) # getitem_100 + buf386 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf386, (), dtype=torch.int64, is_leaf=True) # getitem_105 + buf387 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf387, (), dtype=torch.int64, is_leaf=True) # getitem_106 + buf388 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf388, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_81 + buf389 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf389, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_47 + buf390 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf390, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf391 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf391, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_86 + buf392 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf392, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_108 + buf393 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf393, (2, 32, 8192, 1), is_leaf=True) # getitem_109 + buf394 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf394, (), dtype=torch.int64, is_leaf=True) # getitem_114 + buf395 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf395, (), dtype=torch.int64, is_leaf=True) # getitem_115 + buf396 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf396, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_88 + buf397 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf397, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_51 + buf398 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf398, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_91 + buf399 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf399, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf400 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf400, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_117 + buf401 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf401, (2, 32, 8192, 1), is_leaf=True) # getitem_118 + buf402 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf402, (), dtype=torch.int64, is_leaf=True) # getitem_123 + buf403 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf403, (), dtype=torch.int64, is_leaf=True) # getitem_124 + buf404 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf404, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_95 + buf405 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf405, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_55 + buf406 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf406, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_98 + buf407 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf407, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf408 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf408, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_126 + buf409 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf409, (2, 32, 8192, 1), is_leaf=True) # getitem_127 + buf410 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf410, (), dtype=torch.int64, is_leaf=True) # getitem_132 + buf411 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf411, (), dtype=torch.int64, is_leaf=True) # getitem_133 + buf412 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf412, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_102 + buf413 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf413, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_59 + buf414 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf414, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_105 + buf415 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf415, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf416 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf416, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_135 + buf417 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf417, (2, 32, 8192, 1), is_leaf=True) # getitem_136 + buf418 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf418, (), dtype=torch.int64, is_leaf=True) # getitem_141 + buf419 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf419, (), dtype=torch.int64, is_leaf=True) # getitem_142 + buf420 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf420, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf421 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf421, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_63 + buf422 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf422, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_112 + buf423 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf423, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_114 + buf424 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf424, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_144 + buf425 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf425, (2, 32, 8192, 1), is_leaf=True) # getitem_145 + buf426 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf426, (), dtype=torch.int64, is_leaf=True) # getitem_150 + buf427 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf427, (), dtype=torch.int64, is_leaf=True) # getitem_151 + buf428 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf428, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_116 + buf429 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf429, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_67 + buf430 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf430, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_119 + buf431 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf431, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_121 + buf432 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf432, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_153 + buf433 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf433, (2, 32, 8192, 1), is_leaf=True) # getitem_154 + buf434 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf434, (), dtype=torch.int64, is_leaf=True) # getitem_159 + buf435 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf435, (), dtype=torch.int64, is_leaf=True) # getitem_160 + buf436 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf436, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_123 + buf437 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf437, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_71 + buf438 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf438, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_126 + buf439 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf439, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_128 + buf440 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf440, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_162 + buf441 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf441, (2, 32, 8192, 1), is_leaf=True) # getitem_163 + buf442 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf442, (), dtype=torch.int64, is_leaf=True) # getitem_168 + buf443 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf443, (), dtype=torch.int64, is_leaf=True) # getitem_169 + buf444 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf444, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_130 + buf445 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf445, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_75 + buf446 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf446, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_133 + buf447 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf447, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_135 + buf448 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf448, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_171 + buf449 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf449, (2, 32, 8192, 1), is_leaf=True) # getitem_172 + buf450 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf450, (), dtype=torch.int64, is_leaf=True) # getitem_177 + buf451 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf451, (), dtype=torch.int64, is_leaf=True) # getitem_178 + buf452 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf452, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_137 + buf453 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf453, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_79 + buf454 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf454, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_140 + buf455 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf455, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_142 + buf456 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf456, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_180 + buf457 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf457, (2, 32, 8192, 1), is_leaf=True) # getitem_181 + buf458 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf458, (), dtype=torch.int64, is_leaf=True) # getitem_186 + buf459 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf459, (), dtype=torch.int64, is_leaf=True) # getitem_187 + buf460 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf460, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_144 + buf461 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf461, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_83 + buf462 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf462, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_147 + buf463 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf463, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_149 + buf464 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf464, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_189 + buf465 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf465, (2, 32, 8192, 1), is_leaf=True) # getitem_190 + buf466 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf466, (), dtype=torch.int64, is_leaf=True) # getitem_195 + buf467 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf467, (), dtype=torch.int64, is_leaf=True) # getitem_196 + buf468 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf468, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_151 + buf469 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf469, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_87 + buf470 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf470, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_154 + buf471 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf471, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_156 + buf472 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf472, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_198 + buf473 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf473, (2, 32, 8192, 1), is_leaf=True) # getitem_199 + buf474 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf474, (), dtype=torch.int64, is_leaf=True) # getitem_204 + buf475 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf475, (), dtype=torch.int64, is_leaf=True) # getitem_205 + buf476 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf476, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_158 + buf477 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf477, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_91 + buf478 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf478, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_161 + buf479 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf479, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_163 + buf480 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf480, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_207 + buf481 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf481, (2, 32, 8192, 1), is_leaf=True) # getitem_208 + buf482 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf482, (), dtype=torch.int64, is_leaf=True) # getitem_213 + buf483 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf483, (), dtype=torch.int64, is_leaf=True) # getitem_214 + buf484 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf484, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_165 + buf485 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf485, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_95 + buf486 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf486, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_168 + buf487 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf487, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_170 + buf488 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf488, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_216 + buf489 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf489, (2, 32, 8192, 1), is_leaf=True) # getitem_217 + buf490 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf490, (), dtype=torch.int64, is_leaf=True) # getitem_222 + buf491 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf491, (), dtype=torch.int64, is_leaf=True) # getitem_223 + buf492 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf492, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_172 + buf493 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf493, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_99 + buf494 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf494, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_175 + buf495 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf495, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_177 + buf496 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf496, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_225 + buf497 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf497, (2, 32, 8192, 1), is_leaf=True) # getitem_226 + buf498 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf498, (), dtype=torch.int64, is_leaf=True) # getitem_231 + buf499 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf499, (), dtype=torch.int64, is_leaf=True) # getitem_232 + buf500 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf500, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_179 + buf501 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf501, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_103 + buf502 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf502, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_182 + buf503 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf503, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_184 + buf504 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf504, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_234 + buf505 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf505, (2, 32, 8192, 1), is_leaf=True) # getitem_235 + buf506 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf506, (), dtype=torch.int64, is_leaf=True) # getitem_240 + buf507 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf507, (), dtype=torch.int64, is_leaf=True) # getitem_241 + buf508 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf508, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_186 + buf509 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf509, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_107 + buf510 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf510, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_189 + buf511 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf511, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_191 + buf512 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf512, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_243 + buf513 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf513, (2, 32, 8192, 1), is_leaf=True) # getitem_244 + buf514 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf514, (), dtype=torch.int64, is_leaf=True) # getitem_249 + buf515 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf515, (), dtype=torch.int64, is_leaf=True) # getitem_250 + buf516 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf516, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_193 + buf517 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf517, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_111 + buf518 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf518, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_196 + buf519 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf519, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_198 + buf520 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf520, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_252 + buf521 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf521, (2, 32, 8192, 1), is_leaf=True) # getitem_253 + buf522 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf522, (), dtype=torch.int64, is_leaf=True) # getitem_258 + buf523 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf523, (), dtype=torch.int64, is_leaf=True) # getitem_259 + buf524 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf524, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_200 + buf525 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf525, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_115 + buf526 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf526, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_203 + buf527 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf527, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_205 + buf528 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf528, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_261 + buf529 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf529, (2, 32, 8192, 1), is_leaf=True) # getitem_262 + buf530 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf530, (), dtype=torch.int64, is_leaf=True) # getitem_267 + buf531 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf531, (), dtype=torch.int64, is_leaf=True) # getitem_268 + buf532 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf532, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_207 + buf533 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf533, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_119 + buf534 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf534, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_210 + buf535 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf535, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_212 + buf536 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf536, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_270 + buf537 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf537, (2, 32, 8192, 1), is_leaf=True) # getitem_271 + buf538 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf538, (), dtype=torch.int64, is_leaf=True) # getitem_276 + buf539 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf539, (), dtype=torch.int64, is_leaf=True) # getitem_277 + buf540 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf540, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_214 + buf541 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf541, (2, 8192, 4096), dtype=torch.bfloat16, is_leaf=True) # add_123 + buf542 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf542, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_217 + buf543 = reader.storage(None, 33554432, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf543, (16384, 1024), dtype=torch.bfloat16, is_leaf=True) # mm_219 + buf544 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf544, (2, 32, 8192, 128), (33554432, 128, 4096, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_279 + buf545 = reader.storage(None, 2097152, device=device(type='cuda', index=0)) + reader.tensor(buf545, (2, 32, 8192, 1), is_leaf=True) # getitem_280 + buf546 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf546, (), dtype=torch.int64, is_leaf=True) # getitem_285 + buf547 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf547, (), dtype=torch.int64, is_leaf=True) # getitem_286 + buf548 = reader.storage(None, 469762048, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf548, (16384, 14336), dtype=torch.bfloat16, is_leaf=True) # mm_221 + buf549 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf549, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # mm_223 + buf550 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf550, (2, 8192, 1), is_leaf=True) # rsqrt_64 + buf551 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf551, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # view_1091 + buf552 = reader.storage(None, 4202692608, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf552, (2, 8192, 128256), dtype=torch.bfloat16, is_leaf=True) # tangents_1 + +load_args._version = 0 + +def get_pg_config(): + return {'0': {'size': 64, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls8_8.table" diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_2d.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_2d.py new file mode 100644 index 00000000..db8c8b77 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_bw_64_2d.py @@ -0,0 +1,5783 @@ +# fmt: off +# flake8: noqa +# isort: skip_file +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, wait_tensor_1, mm, mm_2, getitem_80, getitem_81, getitem_86, getitem_87, reduce_scatter_tensor_1, mm_4, add_3, mm_7, mm_9, getitem_121, getitem_122, getitem_127, getitem_128, reduce_scatter_tensor_3, mm_11, add_7, mm_14, mm_16, getitem_162, getitem_163, getitem_168, getitem_169, reduce_scatter_tensor_5, mm_18, add_11, mm_21, mm_23, getitem_203, getitem_204, getitem_209, getitem_210, reduce_scatter_tensor_7, mm_25, add_15, mm_28, mm_30, getitem_244, getitem_245, getitem_250, getitem_251, reduce_scatter_tensor_9, mm_32, add_19, mm_35, mm_37, getitem_285, getitem_286, getitem_291, getitem_292, reduce_scatter_tensor_11, mm_39, add_23, mm_42, mm_44, getitem_326, getitem_327, getitem_332, getitem_333, reduce_scatter_tensor_13, mm_46, add_27, mm_49, mm_51, getitem_367, getitem_368, getitem_373, getitem_374, reduce_scatter_tensor_15, mm_53, add_31, mm_56, mm_58, getitem_408, getitem_409, getitem_414, getitem_415, reduce_scatter_tensor_17, mm_60, add_35, mm_63, mm_65, getitem_449, getitem_450, getitem_455, getitem_456, reduce_scatter_tensor_19, mm_67, add_39, mm_70, mm_72, getitem_490, getitem_491, getitem_496, getitem_497, reduce_scatter_tensor_21, mm_74, add_43, mm_77, mm_79, getitem_531, getitem_532, getitem_537, getitem_538, reduce_scatter_tensor_23, mm_81, add_47, mm_84, mm_86, getitem_572, getitem_573, getitem_578, getitem_579, reduce_scatter_tensor_25, mm_88, add_51, mm_91, mm_93, getitem_613, getitem_614, getitem_619, getitem_620, reduce_scatter_tensor_27, mm_95, add_55, mm_98, mm_100, getitem_654, getitem_655, getitem_660, getitem_661, reduce_scatter_tensor_29, mm_102, add_59, mm_105, mm_107, getitem_695, getitem_696, getitem_701, getitem_702, reduce_scatter_tensor_31, mm_109, reduce_scatter_tensor_32, rsqrt_32, view_1167, tangents_1): + view_1169 = torch.ops.aten.view.default(tangents_1, [16384, 16032]); tangents_1 = None + permute_177 = torch.ops.aten.permute.default(view_1169, [1, 0]) + mm_113 = torch.ops.aten.mm.default(permute_177, view_1167); permute_177 = view_1167 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16); primals_149 = None + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 8, '0'); convert_element_type_532 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + permute_179 = torch.ops.aten.permute.default(permute_176, [1, 0]); permute_176 = None + mm_114 = torch.ops.aten.mm.default(view_1169, permute_179); view_1169 = permute_179 = None + view_1170 = torch.ops.aten.view.default(mm_114, [2, 8192, 4096]); mm_114 = None + convert_element_type_539 = torch.ops.prims.convert_element_type.default(mm_113, torch.float32); mm_113 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_539, 'avg', 8, '0'); convert_element_type_539 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33); reduce_scatter_tensor_33 = None + split_74 = torch.ops.aten.split.Tensor(view_1170, 1024, 1); view_1170 = None + getitem_736 = split_74[0] + getitem_737 = split_74[1] + getitem_738 = split_74[2] + getitem_739 = split_74[3] + getitem_740 = split_74[4] + getitem_741 = split_74[5] + getitem_742 = split_74[6] + getitem_743 = split_74[7]; split_74 = None + cat_66 = torch.ops.aten.cat.default([getitem_736, getitem_737, getitem_738, getitem_739, getitem_740, getitem_741, getitem_742, getitem_743]); getitem_736 = getitem_737 = getitem_738 = getitem_739 = getitem_740 = getitem_741 = getitem_742 = getitem_743 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_66, 'sum', 8, '1'); cat_66 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + convert_element_type_540 = torch.ops.prims.convert_element_type.default(wait_tensor_214, torch.float32); wait_tensor_214 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16); primals_148 = None + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 8, '0'); convert_element_type_529 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(wait_tensor_210, torch.float32); wait_tensor_210 = None + mul_130 = torch.ops.aten.mul.Tensor(convert_element_type_540, convert_element_type_542); convert_element_type_542 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31); reduce_scatter_tensor_31 = None + add_61 = torch.ops.aten.add.Tensor(add_59, wait_tensor_203); wait_tensor_203 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + add_63 = torch.ops.aten.add.Tensor(add_61, wait_tensor_209); wait_tensor_209 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32); add_63 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = None + mul_132 = torch.ops.aten.mul.Tensor(mul_128, mul_130) + sum_1 = torch.ops.aten.sum.dim_IntList(mul_132, [2], True); mul_132 = None + div = torch.ops.aten.div.Tensor(mul_128, 4096) + mul_133 = torch.ops.aten.mul.Tensor(div, sum_1); div = sum_1 = None + sub_1 = torch.ops.aten.sub.Tensor(mul_130, mul_133); mul_130 = mul_133 = None + mul_134 = torch.ops.aten.mul.Tensor(sub_1, rsqrt_32); sub_1 = rsqrt_32 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_540, mul_128); convert_element_type_540 = mul_128 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_135, [0, 1]); mul_135 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(sum_2, torch.bfloat16); sum_2 = None + all_reduce = torch.ops._c10d_functional.all_reduce.default(convert_element_type_544, 'sum', '1'); convert_element_type_544 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_reduce); all_reduce = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(wait_tensor_215, torch.float32); wait_tensor_215 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_545, 'avg', 8, '0'); convert_element_type_545 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35); reduce_scatter_tensor_35 = None + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_543, 8, '1') + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + split_75 = torch.ops.aten.split.Tensor(wait_tensor_217, 2); wait_tensor_217 = None + getitem_744 = split_75[0] + getitem_745 = split_75[1] + getitem_746 = split_75[2] + getitem_747 = split_75[3] + getitem_748 = split_75[4] + getitem_749 = split_75[5] + getitem_750 = split_75[6] + getitem_751 = split_75[7]; split_75 = None + cat_67 = torch.ops.aten.cat.default([getitem_744, getitem_745, getitem_746, getitem_747, getitem_748, getitem_749, getitem_750, getitem_751], 1); getitem_744 = getitem_745 = getitem_746 = getitem_747 = getitem_748 = getitem_749 = getitem_750 = getitem_751 = None + view_1171 = torch.ops.aten.view.default(cat_67, [16384, 4096]); cat_67 = None + permute_181 = torch.ops.aten.permute.default(view_1171, [1, 0]) + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 8, '0'); convert_element_type_515 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32); add_61 = None + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_204) + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_517, 8, '1'); convert_element_type_517 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + split_71 = torch.ops.aten.split.Tensor(wait_tensor_205, 2); wait_tensor_205 = None + getitem_712 = split_71[0] + getitem_713 = split_71[1] + getitem_714 = split_71[2] + getitem_715 = split_71[3] + getitem_716 = split_71[4] + getitem_717 = split_71[5] + getitem_718 = split_71[6] + getitem_719 = split_71[7]; split_71 = None + cat_63 = torch.ops.aten.cat.default([getitem_712, getitem_713, getitem_714, getitem_715, getitem_716, getitem_717, getitem_718, getitem_719], 1); getitem_712 = getitem_713 = getitem_714 = getitem_715 = getitem_716 = getitem_717 = getitem_718 = getitem_719 = None + view_1140 = torch.ops.aten.view.default(cat_63, [16384, 4096]); cat_63 = None + view_1141 = torch.ops.aten.view.default(mm_109, [2, 8192, 1792]); mm_109 = None + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_1141, torch.float32); view_1141 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16); primals_146 = None + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 8, '0'); convert_element_type_523 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + mm_110 = torch.ops.aten.mm.default(view_1140, permute_174) + view_1148 = torch.ops.aten.view.default(mm_110, [2, 8192, 1792]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_1148) + view_1155 = torch.ops.aten.view.default(mul_127, [16384, 1792]); mul_127 = None + mm_115 = torch.ops.aten.mm.default(permute_181, view_1155); permute_181 = view_1155 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16); primals_147 = None + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 8, '0'); convert_element_type_526 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_208, [1, 0]); wait_tensor_208 = None + permute_183 = torch.ops.aten.permute.default(permute_175, [1, 0]); permute_175 = None + mm_116 = torch.ops.aten.mm.default(view_1171, permute_183); view_1171 = permute_183 = None + view_1172 = torch.ops.aten.view.default(mm_116, [2, 8192, 1792]); mm_116 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mm_115, torch.float32); mm_115 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_550, 'avg', 8, '0'); convert_element_type_550 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + mul_136 = torch.ops.aten.mul.Tensor(view_1172, convert_element_type_522); convert_element_type_522 = None + mul_137 = torch.ops.aten.mul.Tensor(view_1172, view_1148); view_1172 = view_1148 = None + view_1173 = torch.ops.aten.view.default(mul_136, [16384, 1792]); mul_136 = None + permute_185 = torch.ops.aten.permute.default(view_1173, [1, 0]) + mm_117 = torch.ops.aten.mm.default(permute_185, view_1140); permute_185 = None + permute_187 = torch.ops.aten.permute.default(permute_174, [1, 0]); permute_174 = None + mm_118 = torch.ops.aten.mm.default(view_1173, permute_187); view_1173 = permute_187 = None + view_1174 = torch.ops.aten.view.default(mm_118, [2, 8192, 4096]); mm_118 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mm_117, torch.float32); mm_117 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_555, 'avg', 8, '0'); convert_element_type_555 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37); reduce_scatter_tensor_37 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(mul_137, torch.float32); mul_137 = None + neg = torch.ops.aten.neg.default(convert_element_type_521) + exp = torch.ops.aten.exp.default(neg); neg = None + add_65 = torch.ops.aten.add.Tensor(exp, 1); exp = None + reciprocal = torch.ops.aten.reciprocal.default(add_65); add_65 = None + mul_138 = torch.ops.aten.mul.Tensor(reciprocal, 1); reciprocal = None + mul_139 = torch.ops.aten.mul.Tensor(convert_element_type_556, mul_138); convert_element_type_556 = None + sub_2 = torch.ops.aten.sub.Tensor(1, mul_138); mul_138 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_521, sub_2); convert_element_type_521 = sub_2 = None + add_66 = torch.ops.aten.add.Tensor(mul_140, 1); mul_140 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_139, add_66); mul_139 = add_66 = None + convert_element_type_558 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + view_1175 = torch.ops.aten.view.default(convert_element_type_558, [16384, 1792]); convert_element_type_558 = None + permute_189 = torch.ops.aten.permute.default(view_1175, [1, 0]) + mm_119 = torch.ops.aten.mm.default(permute_189, view_1140); permute_189 = view_1140 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16); primals_145 = None + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 8, '0'); convert_element_type_518 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + permute_191 = torch.ops.aten.permute.default(permute_173, [1, 0]); permute_173 = None + mm_120 = torch.ops.aten.mm.default(view_1175, permute_191); view_1175 = permute_191 = None + view_1176 = torch.ops.aten.view.default(mm_120, [2, 8192, 4096]); mm_120 = None + add_67 = torch.ops.aten.add.Tensor(view_1174, view_1176); view_1174 = view_1176 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(mm_119, torch.float32); mm_119 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_563, 'avg', 8, '0'); convert_element_type_563 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + split_76 = torch.ops.aten.split.Tensor(add_67, 1024, 1); add_67 = None + getitem_752 = split_76[0] + getitem_753 = split_76[1] + getitem_754 = split_76[2] + getitem_755 = split_76[3] + getitem_756 = split_76[4] + getitem_757 = split_76[5] + getitem_758 = split_76[6] + getitem_759 = split_76[7]; split_76 = None + cat_68 = torch.ops.aten.cat.default([getitem_752, getitem_753, getitem_754, getitem_755, getitem_756, getitem_757, getitem_758, getitem_759]); getitem_752 = getitem_753 = getitem_754 = getitem_755 = getitem_756 = getitem_757 = getitem_758 = getitem_759 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_68, 'sum', 8, '1'); cat_68 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39); reduce_scatter_tensor_39 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(wait_tensor_221, torch.float32); wait_tensor_221 = None + convert_element_type_566 = torch.ops.prims.convert_element_type.default(wait_tensor_204, torch.float32); wait_tensor_204 = None + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_564, convert_element_type_566); convert_element_type_566 = None + mul_144 = torch.ops.aten.mul.Tensor(mul_124, mul_142) + sum_3 = torch.ops.aten.sum.dim_IntList(mul_144, [2], True); mul_144 = None + div_1 = torch.ops.aten.div.Tensor(mul_124, 4096) + mul_145 = torch.ops.aten.mul.Tensor(div_1, sum_3); div_1 = sum_3 = None + sub_3 = torch.ops.aten.sub.Tensor(mul_142, mul_145); mul_142 = mul_145 = None + mul_146 = torch.ops.aten.mul.Tensor(sub_3, rsqrt_31); sub_3 = rsqrt_31 = None + mul_147 = torch.ops.aten.mul.Tensor(convert_element_type_564, mul_124); convert_element_type_564 = mul_124 = None + sum_4 = torch.ops.aten.sum.dim_IntList(mul_147, [0, 1]); mul_147 = None + convert_element_type_567 = torch.ops.prims.convert_element_type.default(mul_146, torch.bfloat16); mul_146 = None + convert_element_type_568 = torch.ops.prims.convert_element_type.default(sum_4, torch.bfloat16); sum_4 = None + all_reduce_1 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_568, 'sum', '1'); convert_element_type_568 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_1); all_reduce_1 = None + convert_element_type_569 = torch.ops.prims.convert_element_type.default(wait_tensor_222, torch.float32); wait_tensor_222 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_569, 'avg', 8, '0'); convert_element_type_569 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + add_68 = torch.ops.aten.add.Tensor(convert_element_type_543, convert_element_type_567); convert_element_type_543 = convert_element_type_567 = None + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_68, 8, '1') + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + split_77 = torch.ops.aten.split.Tensor(wait_tensor_224, 2); wait_tensor_224 = None + getitem_760 = split_77[0] + getitem_761 = split_77[1] + getitem_762 = split_77[2] + getitem_763 = split_77[3] + getitem_764 = split_77[4] + getitem_765 = split_77[5] + getitem_766 = split_77[6] + getitem_767 = split_77[7]; split_77 = None + cat_69 = torch.ops.aten.cat.default([getitem_760, getitem_761, getitem_762, getitem_763, getitem_764, getitem_765, getitem_766, getitem_767], 1); getitem_760 = getitem_761 = getitem_762 = getitem_763 = getitem_764 = getitem_765 = getitem_766 = getitem_767 = None + view_1177 = torch.ops.aten.view.default(cat_69, [16384, 4096]); cat_69 = None + permute_193 = torch.ops.aten.permute.default(view_1177, [1, 0]) + permute_171 = torch.ops.aten.permute.default(getitem_695, [0, 2, 1, 3]) + view_1122 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + view_1128 = torch.ops.aten.view.default(view_1122, [16384, 512]); view_1122 = None + mm_121 = torch.ops.aten.mm.default(permute_193, view_1128); permute_193 = view_1128 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 8, '0'); convert_element_type_512 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + permute_195 = torch.ops.aten.permute.default(permute_172, [1, 0]); permute_172 = None + mm_122 = torch.ops.aten.mm.default(view_1177, permute_195); view_1177 = permute_195 = None + view_1178 = torch.ops.aten.view.default(mm_122, [2, 8192, 512]); mm_122 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(mm_121, torch.float32); mm_121 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_574, 'avg', 8, '0'); convert_element_type_574 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41); reduce_scatter_tensor_41 = None + view_1179 = torch.ops.aten.view.default(view_1178, [2, 8192, 4, 128]); view_1178 = None + permute_197 = torch.ops.aten.permute.default(view_1179, [0, 2, 1, 3]); view_1179 = None + view_37 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]); primals_3 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 8, '0'); convert_element_type_496 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32); add_59 = None + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_197) + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_498, 8, '1'); convert_element_type_498 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + split_69 = torch.ops.aten.split.Tensor(wait_tensor_198, 2); wait_tensor_198 = None + getitem_687 = split_69[0] + getitem_688 = split_69[1] + getitem_689 = split_69[2] + getitem_690 = split_69[3] + getitem_691 = split_69[4] + getitem_692 = split_69[5] + getitem_693 = split_69[6] + getitem_694 = split_69[7]; split_69 = None + cat_61 = torch.ops.aten.cat.default([getitem_687, getitem_688, getitem_689, getitem_690, getitem_691, getitem_692, getitem_693, getitem_694], 1); getitem_687 = getitem_688 = getitem_689 = getitem_690 = getitem_691 = getitem_692 = getitem_693 = getitem_694 = None + view_1095 = torch.ops.aten.view.default(cat_61, [16384, 4096]); cat_61 = None + view_1096 = torch.ops.aten.view.default(mm_105, [2, 8192, 512]); mm_105 = None + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 8, '0'); convert_element_type_502 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + mm_106 = torch.ops.aten.mm.default(view_1095, permute_166) + view_1103 = torch.ops.aten.view.default(mm_106, [2, 8192, 128]); mm_106 = None + view_1110 = torch.ops.aten.view.default(mm_107, [2, 8192, 128]); mm_107 = None + view_1112 = torch.ops.aten.view.default(view_1096, [2, 8192, -1, 128]); view_1096 = None + view_1113 = torch.ops.aten.view.default(view_1103, [2, 8192, -1, 128]); view_1103 = None + view_1114 = torch.ops.aten.view.default(view_1110, [2, 8192, -1, 128]); view_1110 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_1112, torch.float32); view_1112 = None + view_1115 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 4, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_1115); view_1115 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_1113, torch.float32); view_1113 = None + view_1116 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 1, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_1116); view_1116 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_37); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_1118 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 4, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_37); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_1119 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 1, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_1118, torch.bfloat16); view_1118 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 1, 4, 128]); unsqueeze_30 = None + view_1120 = torch.ops.aten.view.default(expand_30, [2, 8192, 4, 128]); expand_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_1114, 3); view_1114 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 1, 4, 128]); unsqueeze_31 = None + view_1121 = torch.ops.aten.view.default(expand_31, [2, 8192, 4, 128]); expand_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_1120, [0, 2, 1, 3]); view_1120 = None + permute_170 = torch.ops.aten.permute.default(view_1121, [0, 2, 1, 3]); view_1121 = None + _scaled_dot_product_cudnn_attention_backward = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_197, permute_168, permute_169, permute_170, getitem_695, getitem_696, getitem_701, getitem_702, None, None, None, 8192, 8192, 0.0, True); permute_197 = permute_168 = permute_169 = permute_170 = getitem_695 = getitem_696 = getitem_701 = getitem_702 = None + getitem_768 = _scaled_dot_product_cudnn_attention_backward[0] + getitem_769 = _scaled_dot_product_cudnn_attention_backward[1] + getitem_770 = _scaled_dot_product_cudnn_attention_backward[2]; _scaled_dot_product_cudnn_attention_backward = None + permute_198 = torch.ops.aten.permute.default(getitem_770, [0, 2, 1, 3]); getitem_770 = None + permute_199 = torch.ops.aten.permute.default(getitem_769, [0, 2, 1, 3]); getitem_769 = None + permute_200 = torch.ops.aten.permute.default(getitem_768, [0, 2, 1, 3]); getitem_768 = None + view_1180 = torch.ops.aten.view.default(permute_198, [2, 8192, 1, 4, 128]); permute_198 = None + sum_5 = torch.ops.aten.sum.dim_IntList(view_1180, [3], True); view_1180 = None + squeeze = torch.ops.aten.squeeze.dim(sum_5, 3); sum_5 = None + view_1181 = torch.ops.aten.view.default(permute_199, [2, 8192, 1, 4, 128]); permute_199 = None + sum_6 = torch.ops.aten.sum.dim_IntList(view_1181, [3], True); view_1181 = None + squeeze_1 = torch.ops.aten.squeeze.dim(sum_6, 3); sum_6 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(squeeze_1, torch.float32); squeeze_1 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(permute_200, torch.float32); permute_200 = None + view_1182 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 1, 64, 2]); convert_element_type_575 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1182); view_1182 = None + _conj = torch.ops.aten._conj.default(view_37) + mul_148 = torch.ops.aten.mul.Tensor(view_as_complex_32, _conj); view_as_complex_32 = None + view_1183 = torch.ops.aten.view.default(convert_element_type_576, [2, 8192, 4, 64, 2]); convert_element_type_576 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1183); view_1183 = None + mul_149 = torch.ops.aten.mul.Tensor(view_as_complex_33, _conj); view_as_complex_33 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_148); mul_148 = None + view_1184 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 1, 128]); view_as_real_32 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_1184, torch.bfloat16); view_1184 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_149); mul_149 = None + view_1185 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 4, 128]); view_as_real_33 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(view_1185, torch.bfloat16); view_1185 = None + view_1186 = torch.ops.aten.view.default(squeeze, [2, 8192, 128]); squeeze = None + view_1187 = torch.ops.aten.view.default(convert_element_type_577, [2, 8192, 128]); convert_element_type_577 = None + view_1188 = torch.ops.aten.view.default(convert_element_type_578, [2, 8192, 512]); convert_element_type_578 = None + view_1189 = torch.ops.aten.view.default(view_1186, [16384, 128]); view_1186 = None + permute_201 = torch.ops.aten.permute.default(view_1189, [1, 0]) + mm_123 = torch.ops.aten.mm.default(permute_201, view_1095); permute_201 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 8, '0'); convert_element_type_505 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + permute_203 = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None + mm_124 = torch.ops.aten.mm.default(view_1189, permute_203); view_1189 = permute_203 = None + view_1190 = torch.ops.aten.view.default(mm_124, [2, 8192, 4096]); mm_124 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mm_123, torch.float32); mm_123 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_583, 'avg', 8, '0'); convert_element_type_583 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + view_1191 = torch.ops.aten.view.default(view_1187, [16384, 128]); view_1187 = None + permute_205 = torch.ops.aten.permute.default(view_1191, [1, 0]) + mm_125 = torch.ops.aten.mm.default(permute_205, view_1095); permute_205 = None + permute_207 = torch.ops.aten.permute.default(permute_166, [1, 0]); permute_166 = None + mm_126 = torch.ops.aten.mm.default(view_1191, permute_207); view_1191 = permute_207 = None + view_1192 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]); mm_126 = None + add_69 = torch.ops.aten.add.Tensor(view_1190, view_1192); view_1190 = view_1192 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mm_125, torch.float32); mm_125 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_588, 'avg', 8, '0'); convert_element_type_588 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43); reduce_scatter_tensor_43 = None + view_1193 = torch.ops.aten.view.default(view_1188, [16384, 512]); view_1188 = None + permute_209 = torch.ops.aten.permute.default(view_1193, [1, 0]) + mm_127 = torch.ops.aten.mm.default(permute_209, view_1095); permute_209 = view_1095 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 8, '0'); convert_element_type_499 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + permute_211 = torch.ops.aten.permute.default(permute_165, [1, 0]); permute_165 = None + mm_128 = torch.ops.aten.mm.default(view_1193, permute_211); view_1193 = permute_211 = None + view_1194 = torch.ops.aten.view.default(mm_128, [2, 8192, 4096]); mm_128 = None + add_70 = torch.ops.aten.add.Tensor(add_69, view_1194); add_69 = view_1194 = None + convert_element_type_593 = torch.ops.prims.convert_element_type.default(mm_127, torch.float32); mm_127 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_593, 'avg', 8, '0'); convert_element_type_593 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + split_78 = torch.ops.aten.split.Tensor(add_70, 1024, 1); add_70 = None + getitem_771 = split_78[0] + getitem_772 = split_78[1] + getitem_773 = split_78[2] + getitem_774 = split_78[3] + getitem_775 = split_78[4] + getitem_776 = split_78[5] + getitem_777 = split_78[6] + getitem_778 = split_78[7]; split_78 = None + cat_70 = torch.ops.aten.cat.default([getitem_771, getitem_772, getitem_773, getitem_774, getitem_775, getitem_776, getitem_777, getitem_778]); getitem_771 = getitem_772 = getitem_773 = getitem_774 = getitem_775 = getitem_776 = getitem_777 = getitem_778 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_70, 'sum', 8, '1'); cat_70 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45); reduce_scatter_tensor_45 = None + convert_element_type_594 = torch.ops.prims.convert_element_type.default(wait_tensor_229, torch.float32); wait_tensor_229 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(wait_tensor_197, torch.float32); wait_tensor_197 = None + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_594, convert_element_type_596); convert_element_type_596 = None + mul_152 = torch.ops.aten.mul.Tensor(mul_120, mul_150) + sum_7 = torch.ops.aten.sum.dim_IntList(mul_152, [2], True); mul_152 = None + div_2 = torch.ops.aten.div.Tensor(mul_120, 4096) + mul_153 = torch.ops.aten.mul.Tensor(div_2, sum_7); div_2 = sum_7 = None + sub_4 = torch.ops.aten.sub.Tensor(mul_150, mul_153); mul_150 = mul_153 = None + mul_154 = torch.ops.aten.mul.Tensor(sub_4, rsqrt_30); sub_4 = rsqrt_30 = None + mul_155 = torch.ops.aten.mul.Tensor(convert_element_type_594, mul_120); convert_element_type_594 = mul_120 = None + sum_8 = torch.ops.aten.sum.dim_IntList(mul_155, [0, 1]); mul_155 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_154, torch.bfloat16); mul_154 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(sum_8, torch.bfloat16); sum_8 = None + all_reduce_2 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_598, 'sum', '1'); convert_element_type_598 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_2); all_reduce_2 = None + convert_element_type_599 = torch.ops.prims.convert_element_type.default(wait_tensor_230, torch.float32); wait_tensor_230 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_599, 'avg', 8, '0'); convert_element_type_599 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + add_71 = torch.ops.aten.add.Tensor(add_68, convert_element_type_597); add_68 = convert_element_type_597 = None + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_71, 8, '1') + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + split_79 = torch.ops.aten.split.Tensor(wait_tensor_232, 2); wait_tensor_232 = None + getitem_779 = split_79[0] + getitem_780 = split_79[1] + getitem_781 = split_79[2] + getitem_782 = split_79[3] + getitem_783 = split_79[4] + getitem_784 = split_79[5] + getitem_785 = split_79[6] + getitem_786 = split_79[7]; split_79 = None + cat_71 = torch.ops.aten.cat.default([getitem_779, getitem_780, getitem_781, getitem_782, getitem_783, getitem_784, getitem_785, getitem_786], 1); getitem_779 = getitem_780 = getitem_781 = getitem_782 = getitem_783 = getitem_784 = getitem_785 = getitem_786 = None + view_1195 = torch.ops.aten.view.default(cat_71, [16384, 4096]); cat_71 = None + permute_213 = torch.ops.aten.permute.default(view_1195, [1, 0]) + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29); reduce_scatter_tensor_29 = None + add_57 = torch.ops.aten.add.Tensor(add_55, wait_tensor_190); wait_tensor_190 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16); primals_135 = None + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 8, '0'); convert_element_type_482 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32); add_57 = None + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_191) + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_484, 8, '1'); convert_element_type_484 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + split_67 = torch.ops.aten.split.Tensor(wait_tensor_192, 2); wait_tensor_192 = None + getitem_671 = split_67[0] + getitem_672 = split_67[1] + getitem_673 = split_67[2] + getitem_674 = split_67[3] + getitem_675 = split_67[4] + getitem_676 = split_67[5] + getitem_677 = split_67[6] + getitem_678 = split_67[7]; split_67 = None + cat_59 = torch.ops.aten.cat.default([getitem_671, getitem_672, getitem_673, getitem_674, getitem_675, getitem_676, getitem_677, getitem_678], 1); getitem_671 = getitem_672 = getitem_673 = getitem_674 = getitem_675 = getitem_676 = getitem_677 = getitem_678 = None + view_1068 = torch.ops.aten.view.default(cat_59, [16384, 4096]); cat_59 = None + view_1069 = torch.ops.aten.view.default(mm_102, [2, 8192, 1792]); mm_102 = None + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_1069, torch.float32); view_1069 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 8, '0'); convert_element_type_490 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + mm_103 = torch.ops.aten.mm.default(view_1068, permute_163) + view_1076 = torch.ops.aten.view.default(mm_103, [2, 8192, 1792]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_1076) + view_1083 = torch.ops.aten.view.default(mul_119, [16384, 1792]); mul_119 = None + mm_129 = torch.ops.aten.mm.default(permute_213, view_1083); permute_213 = view_1083 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 8, '0'); convert_element_type_493 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_195, [1, 0]); wait_tensor_195 = None + permute_215 = torch.ops.aten.permute.default(permute_164, [1, 0]); permute_164 = None + mm_130 = torch.ops.aten.mm.default(view_1195, permute_215); view_1195 = permute_215 = None + view_1196 = torch.ops.aten.view.default(mm_130, [2, 8192, 1792]); mm_130 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(mm_129, torch.float32); mm_129 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_604, 'avg', 8, '0'); convert_element_type_604 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47); reduce_scatter_tensor_47 = None + mul_156 = torch.ops.aten.mul.Tensor(view_1196, convert_element_type_489); convert_element_type_489 = None + mul_157 = torch.ops.aten.mul.Tensor(view_1196, view_1076); view_1196 = view_1076 = None + view_1197 = torch.ops.aten.view.default(mul_156, [16384, 1792]); mul_156 = None + permute_217 = torch.ops.aten.permute.default(view_1197, [1, 0]) + mm_131 = torch.ops.aten.mm.default(permute_217, view_1068); permute_217 = None + permute_219 = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None + mm_132 = torch.ops.aten.mm.default(view_1197, permute_219); view_1197 = permute_219 = None + view_1198 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(mm_131, torch.float32); mm_131 = None + reduce_scatter_tensor_48 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_609, 'avg', 8, '0'); convert_element_type_609 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(mul_157, torch.float32); mul_157 = None + neg_1 = torch.ops.aten.neg.default(convert_element_type_488) + exp_1 = torch.ops.aten.exp.default(neg_1); neg_1 = None + add_72 = torch.ops.aten.add.Tensor(exp_1, 1); exp_1 = None + reciprocal_1 = torch.ops.aten.reciprocal.default(add_72); add_72 = None + mul_158 = torch.ops.aten.mul.Tensor(reciprocal_1, 1); reciprocal_1 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_610, mul_158); convert_element_type_610 = None + sub_5 = torch.ops.aten.sub.Tensor(1, mul_158); mul_158 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_488, sub_5); convert_element_type_488 = sub_5 = None + add_73 = torch.ops.aten.add.Tensor(mul_160, 1); mul_160 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_159, add_73); mul_159 = add_73 = None + convert_element_type_612 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + view_1199 = torch.ops.aten.view.default(convert_element_type_612, [16384, 1792]); convert_element_type_612 = None + permute_221 = torch.ops.aten.permute.default(view_1199, [1, 0]) + mm_133 = torch.ops.aten.mm.default(permute_221, view_1068); permute_221 = view_1068 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 8, '0'); convert_element_type_485 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + permute_223 = torch.ops.aten.permute.default(permute_162, [1, 0]); permute_162 = None + mm_134 = torch.ops.aten.mm.default(view_1199, permute_223); view_1199 = permute_223 = None + view_1200 = torch.ops.aten.view.default(mm_134, [2, 8192, 4096]); mm_134 = None + add_74 = torch.ops.aten.add.Tensor(view_1198, view_1200); view_1198 = view_1200 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(mm_133, torch.float32); mm_133 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_617, 'avg', 8, '0'); convert_element_type_617 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49); reduce_scatter_tensor_49 = None + split_80 = torch.ops.aten.split.Tensor(add_74, 1024, 1); add_74 = None + getitem_787 = split_80[0] + getitem_788 = split_80[1] + getitem_789 = split_80[2] + getitem_790 = split_80[3] + getitem_791 = split_80[4] + getitem_792 = split_80[5] + getitem_793 = split_80[6] + getitem_794 = split_80[7]; split_80 = None + cat_72 = torch.ops.aten.cat.default([getitem_787, getitem_788, getitem_789, getitem_790, getitem_791, getitem_792, getitem_793, getitem_794]); getitem_787 = getitem_788 = getitem_789 = getitem_790 = getitem_791 = getitem_792 = getitem_793 = getitem_794 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_72, 'sum', 8, '1'); cat_72 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + convert_element_type_618 = torch.ops.prims.convert_element_type.default(wait_tensor_236, torch.float32); wait_tensor_236 = None + convert_element_type_620 = torch.ops.prims.convert_element_type.default(wait_tensor_191, torch.float32); wait_tensor_191 = None + mul_162 = torch.ops.aten.mul.Tensor(convert_element_type_618, convert_element_type_620); convert_element_type_620 = None + mul_164 = torch.ops.aten.mul.Tensor(mul_116, mul_162) + sum_9 = torch.ops.aten.sum.dim_IntList(mul_164, [2], True); mul_164 = None + div_3 = torch.ops.aten.div.Tensor(mul_116, 4096) + mul_165 = torch.ops.aten.mul.Tensor(div_3, sum_9); div_3 = sum_9 = None + sub_6 = torch.ops.aten.sub.Tensor(mul_162, mul_165); mul_162 = mul_165 = None + mul_166 = torch.ops.aten.mul.Tensor(sub_6, rsqrt_29); sub_6 = rsqrt_29 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_618, mul_116); convert_element_type_618 = mul_116 = None + sum_10 = torch.ops.aten.sum.dim_IntList(mul_167, [0, 1]); mul_167 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(sum_10, torch.bfloat16); sum_10 = None + all_reduce_3 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_622, 'sum', '1'); convert_element_type_622 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_3); all_reduce_3 = None + convert_element_type_623 = torch.ops.prims.convert_element_type.default(wait_tensor_237, torch.float32); wait_tensor_237 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_623, 'avg', 8, '0'); convert_element_type_623 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51); reduce_scatter_tensor_51 = None + add_75 = torch.ops.aten.add.Tensor(add_71, convert_element_type_621); add_71 = convert_element_type_621 = None + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_75, 8, '1') + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + split_81 = torch.ops.aten.split.Tensor(wait_tensor_239, 2); wait_tensor_239 = None + getitem_795 = split_81[0] + getitem_796 = split_81[1] + getitem_797 = split_81[2] + getitem_798 = split_81[3] + getitem_799 = split_81[4] + getitem_800 = split_81[5] + getitem_801 = split_81[6] + getitem_802 = split_81[7]; split_81 = None + cat_73 = torch.ops.aten.cat.default([getitem_795, getitem_796, getitem_797, getitem_798, getitem_799, getitem_800, getitem_801, getitem_802], 1); getitem_795 = getitem_796 = getitem_797 = getitem_798 = getitem_799 = getitem_800 = getitem_801 = getitem_802 = None + view_1201 = torch.ops.aten.view.default(cat_73, [16384, 4096]); cat_73 = None + permute_225 = torch.ops.aten.permute.default(view_1201, [1, 0]) + permute_160 = torch.ops.aten.permute.default(getitem_654, [0, 2, 1, 3]) + view_1050 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + view_1056 = torch.ops.aten.view.default(view_1050, [16384, 512]); view_1050 = None + mm_135 = torch.ops.aten.mm.default(permute_225, view_1056); permute_225 = view_1056 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16); primals_134 = None + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 8, '0'); convert_element_type_479 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + permute_227 = torch.ops.aten.permute.default(permute_161, [1, 0]); permute_161 = None + mm_136 = torch.ops.aten.mm.default(view_1201, permute_227); view_1201 = permute_227 = None + view_1202 = torch.ops.aten.view.default(mm_136, [2, 8192, 512]); mm_136 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(mm_135, torch.float32); mm_135 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_628, 'avg', 8, '0'); convert_element_type_628 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + view_1203 = torch.ops.aten.view.default(view_1202, [2, 8192, 4, 128]); view_1202 = None + permute_229 = torch.ops.aten.permute.default(view_1203, [0, 2, 1, 3]); view_1203 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16); primals_130 = None + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 8, '0'); convert_element_type_463 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32); add_55 = None + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_184) + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_465, 8, '1'); convert_element_type_465 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + split_65 = torch.ops.aten.split.Tensor(wait_tensor_185, 2); wait_tensor_185 = None + getitem_646 = split_65[0] + getitem_647 = split_65[1] + getitem_648 = split_65[2] + getitem_649 = split_65[3] + getitem_650 = split_65[4] + getitem_651 = split_65[5] + getitem_652 = split_65[6] + getitem_653 = split_65[7]; split_65 = None + cat_57 = torch.ops.aten.cat.default([getitem_646, getitem_647, getitem_648, getitem_649, getitem_650, getitem_651, getitem_652, getitem_653], 1); getitem_646 = getitem_647 = getitem_648 = getitem_649 = getitem_650 = getitem_651 = getitem_652 = getitem_653 = None + view_1023 = torch.ops.aten.view.default(cat_57, [16384, 4096]); cat_57 = None + view_1024 = torch.ops.aten.view.default(mm_98, [2, 8192, 512]); mm_98 = None + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16); primals_132 = None + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 8, '0'); convert_element_type_469 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + mm_99 = torch.ops.aten.mm.default(view_1023, permute_155) + view_1031 = torch.ops.aten.view.default(mm_99, [2, 8192, 128]); mm_99 = None + view_1038 = torch.ops.aten.view.default(mm_100, [2, 8192, 128]); mm_100 = None + view_1040 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1041 = torch.ops.aten.view.default(view_1031, [2, 8192, -1, 128]); view_1031 = None + view_1042 = torch.ops.aten.view.default(view_1038, [2, 8192, -1, 128]); view_1038 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_1040, torch.float32); view_1040 = None + view_1043 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 4, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_1043); view_1043 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_1041, torch.float32); view_1041 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 1, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_37); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_1046 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 4, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_37); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_1047 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 1, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_1047, torch.bfloat16); view_1047 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 1, 4, 128]); unsqueeze_28 = None + view_1048 = torch.ops.aten.view.default(expand_28, [2, 8192, 4, 128]); expand_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_1042, 3); view_1042 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 1, 4, 128]); unsqueeze_29 = None + view_1049 = torch.ops.aten.view.default(expand_29, [2, 8192, 4, 128]); expand_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_1048, [0, 2, 1, 3]); view_1048 = None + permute_159 = torch.ops.aten.permute.default(view_1049, [0, 2, 1, 3]); view_1049 = None + _scaled_dot_product_cudnn_attention_backward_1 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_229, permute_157, permute_158, permute_159, getitem_654, getitem_655, getitem_660, getitem_661, None, None, None, 8192, 8192, 0.0, True); permute_229 = permute_157 = permute_158 = permute_159 = getitem_654 = getitem_655 = getitem_660 = getitem_661 = None + getitem_803 = _scaled_dot_product_cudnn_attention_backward_1[0] + getitem_804 = _scaled_dot_product_cudnn_attention_backward_1[1] + getitem_805 = _scaled_dot_product_cudnn_attention_backward_1[2]; _scaled_dot_product_cudnn_attention_backward_1 = None + permute_230 = torch.ops.aten.permute.default(getitem_805, [0, 2, 1, 3]); getitem_805 = None + permute_231 = torch.ops.aten.permute.default(getitem_804, [0, 2, 1, 3]); getitem_804 = None + permute_232 = torch.ops.aten.permute.default(getitem_803, [0, 2, 1, 3]); getitem_803 = None + view_1204 = torch.ops.aten.view.default(permute_230, [2, 8192, 1, 4, 128]); permute_230 = None + sum_11 = torch.ops.aten.sum.dim_IntList(view_1204, [3], True); view_1204 = None + squeeze_2 = torch.ops.aten.squeeze.dim(sum_11, 3); sum_11 = None + view_1205 = torch.ops.aten.view.default(permute_231, [2, 8192, 1, 4, 128]); permute_231 = None + sum_12 = torch.ops.aten.sum.dim_IntList(view_1205, [3], True); view_1205 = None + squeeze_3 = torch.ops.aten.squeeze.dim(sum_12, 3); sum_12 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(squeeze_3, torch.float32); squeeze_3 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(permute_232, torch.float32); permute_232 = None + view_1206 = torch.ops.aten.view.default(convert_element_type_629, [2, 8192, 1, 64, 2]); convert_element_type_629 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1206); view_1206 = None + mul_168 = torch.ops.aten.mul.Tensor(view_as_complex_34, _conj); view_as_complex_34 = None + view_1207 = torch.ops.aten.view.default(convert_element_type_630, [2, 8192, 4, 64, 2]); convert_element_type_630 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1207); view_1207 = None + mul_169 = torch.ops.aten.mul.Tensor(view_as_complex_35, _conj); view_as_complex_35 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_168); mul_168 = None + view_1208 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 1, 128]); view_as_real_34 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(view_1208, torch.bfloat16); view_1208 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_169); mul_169 = None + view_1209 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 4, 128]); view_as_real_35 = None + convert_element_type_632 = torch.ops.prims.convert_element_type.default(view_1209, torch.bfloat16); view_1209 = None + view_1210 = torch.ops.aten.view.default(squeeze_2, [2, 8192, 128]); squeeze_2 = None + view_1211 = torch.ops.aten.view.default(convert_element_type_631, [2, 8192, 128]); convert_element_type_631 = None + view_1212 = torch.ops.aten.view.default(convert_element_type_632, [2, 8192, 512]); convert_element_type_632 = None + view_1213 = torch.ops.aten.view.default(view_1210, [16384, 128]); view_1210 = None + permute_233 = torch.ops.aten.permute.default(view_1213, [1, 0]) + mm_137 = torch.ops.aten.mm.default(permute_233, view_1023); permute_233 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16); primals_133 = None + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 8, '0'); convert_element_type_472 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + permute_235 = torch.ops.aten.permute.default(permute_156, [1, 0]); permute_156 = None + mm_138 = torch.ops.aten.mm.default(view_1213, permute_235); view_1213 = permute_235 = None + view_1214 = torch.ops.aten.view.default(mm_138, [2, 8192, 4096]); mm_138 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(mm_137, torch.float32); mm_137 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_637, 'avg', 8, '0'); convert_element_type_637 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53); reduce_scatter_tensor_53 = None + view_1215 = torch.ops.aten.view.default(view_1211, [16384, 128]); view_1211 = None + permute_237 = torch.ops.aten.permute.default(view_1215, [1, 0]) + mm_139 = torch.ops.aten.mm.default(permute_237, view_1023); permute_237 = None + permute_239 = torch.ops.aten.permute.default(permute_155, [1, 0]); permute_155 = None + mm_140 = torch.ops.aten.mm.default(view_1215, permute_239); view_1215 = permute_239 = None + view_1216 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]); mm_140 = None + add_76 = torch.ops.aten.add.Tensor(view_1214, view_1216); view_1214 = view_1216 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(mm_139, torch.float32); mm_139 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_642, 'avg', 8, '0'); convert_element_type_642 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + view_1217 = torch.ops.aten.view.default(view_1212, [16384, 512]); view_1212 = None + permute_241 = torch.ops.aten.permute.default(view_1217, [1, 0]) + mm_141 = torch.ops.aten.mm.default(permute_241, view_1023); permute_241 = view_1023 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16); primals_131 = None + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 8, '0'); convert_element_type_466 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + permute_243 = torch.ops.aten.permute.default(permute_154, [1, 0]); permute_154 = None + mm_142 = torch.ops.aten.mm.default(view_1217, permute_243); view_1217 = permute_243 = None + view_1218 = torch.ops.aten.view.default(mm_142, [2, 8192, 4096]); mm_142 = None + add_77 = torch.ops.aten.add.Tensor(add_76, view_1218); add_76 = view_1218 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(mm_141, torch.float32); mm_141 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_647, 'avg', 8, '0'); convert_element_type_647 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55); reduce_scatter_tensor_55 = None + split_82 = torch.ops.aten.split.Tensor(add_77, 1024, 1); add_77 = None + getitem_806 = split_82[0] + getitem_807 = split_82[1] + getitem_808 = split_82[2] + getitem_809 = split_82[3] + getitem_810 = split_82[4] + getitem_811 = split_82[5] + getitem_812 = split_82[6] + getitem_813 = split_82[7]; split_82 = None + cat_74 = torch.ops.aten.cat.default([getitem_806, getitem_807, getitem_808, getitem_809, getitem_810, getitem_811, getitem_812, getitem_813]); getitem_806 = getitem_807 = getitem_808 = getitem_809 = getitem_810 = getitem_811 = getitem_812 = getitem_813 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_74, 'sum', 8, '1'); cat_74 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(wait_tensor_244, torch.float32); wait_tensor_244 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(wait_tensor_184, torch.float32); wait_tensor_184 = None + mul_170 = torch.ops.aten.mul.Tensor(convert_element_type_648, convert_element_type_650); convert_element_type_650 = None + mul_172 = torch.ops.aten.mul.Tensor(mul_112, mul_170) + sum_13 = torch.ops.aten.sum.dim_IntList(mul_172, [2], True); mul_172 = None + div_4 = torch.ops.aten.div.Tensor(mul_112, 4096) + mul_173 = torch.ops.aten.mul.Tensor(div_4, sum_13); div_4 = sum_13 = None + sub_7 = torch.ops.aten.sub.Tensor(mul_170, mul_173); mul_170 = mul_173 = None + mul_174 = torch.ops.aten.mul.Tensor(sub_7, rsqrt_28); sub_7 = rsqrt_28 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_648, mul_112); convert_element_type_648 = mul_112 = None + sum_14 = torch.ops.aten.sum.dim_IntList(mul_175, [0, 1]); mul_175 = None + convert_element_type_651 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_652 = torch.ops.prims.convert_element_type.default(sum_14, torch.bfloat16); sum_14 = None + all_reduce_4 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_652, 'sum', '1'); convert_element_type_652 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_4); all_reduce_4 = None + convert_element_type_653 = torch.ops.prims.convert_element_type.default(wait_tensor_245, torch.float32); wait_tensor_245 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_653, 'avg', 8, '0'); convert_element_type_653 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57); reduce_scatter_tensor_57 = None + add_78 = torch.ops.aten.add.Tensor(add_75, convert_element_type_651); add_75 = convert_element_type_651 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_78, 8, '1') + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + split_83 = torch.ops.aten.split.Tensor(wait_tensor_247, 2); wait_tensor_247 = None + getitem_814 = split_83[0] + getitem_815 = split_83[1] + getitem_816 = split_83[2] + getitem_817 = split_83[3] + getitem_818 = split_83[4] + getitem_819 = split_83[5] + getitem_820 = split_83[6] + getitem_821 = split_83[7]; split_83 = None + cat_75 = torch.ops.aten.cat.default([getitem_814, getitem_815, getitem_816, getitem_817, getitem_818, getitem_819, getitem_820, getitem_821], 1); getitem_814 = getitem_815 = getitem_816 = getitem_817 = getitem_818 = getitem_819 = getitem_820 = getitem_821 = None + view_1219 = torch.ops.aten.view.default(cat_75, [16384, 4096]); cat_75 = None + permute_245 = torch.ops.aten.permute.default(view_1219, [1, 0]) + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27); reduce_scatter_tensor_27 = None + add_53 = torch.ops.aten.add.Tensor(add_51, wait_tensor_177); wait_tensor_177 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 8, '0'); convert_element_type_449 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32); add_53 = None + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_178) + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '1'); convert_element_type_451 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + split_63 = torch.ops.aten.split.Tensor(wait_tensor_179, 2); wait_tensor_179 = None + getitem_630 = split_63[0] + getitem_631 = split_63[1] + getitem_632 = split_63[2] + getitem_633 = split_63[3] + getitem_634 = split_63[4] + getitem_635 = split_63[5] + getitem_636 = split_63[6] + getitem_637 = split_63[7]; split_63 = None + cat_55 = torch.ops.aten.cat.default([getitem_630, getitem_631, getitem_632, getitem_633, getitem_634, getitem_635, getitem_636, getitem_637], 1); getitem_630 = getitem_631 = getitem_632 = getitem_633 = getitem_634 = getitem_635 = getitem_636 = getitem_637 = None + view_996 = torch.ops.aten.view.default(cat_55, [16384, 4096]); cat_55 = None + view_997 = torch.ops.aten.view.default(mm_95, [2, 8192, 1792]); mm_95 = None + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16); primals_128 = None + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 8, '0'); convert_element_type_457 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + mm_96 = torch.ops.aten.mm.default(view_996, permute_152) + view_1004 = torch.ops.aten.view.default(mm_96, [2, 8192, 1792]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_1004) + view_1011 = torch.ops.aten.view.default(mul_111, [16384, 1792]); mul_111 = None + mm_143 = torch.ops.aten.mm.default(permute_245, view_1011); permute_245 = view_1011 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16); primals_129 = None + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 8, '0'); convert_element_type_460 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + permute_247 = torch.ops.aten.permute.default(permute_153, [1, 0]); permute_153 = None + mm_144 = torch.ops.aten.mm.default(view_1219, permute_247); view_1219 = permute_247 = None + view_1220 = torch.ops.aten.view.default(mm_144, [2, 8192, 1792]); mm_144 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(mm_143, torch.float32); mm_143 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_658, 'avg', 8, '0'); convert_element_type_658 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + mul_176 = torch.ops.aten.mul.Tensor(view_1220, convert_element_type_456); convert_element_type_456 = None + mul_177 = torch.ops.aten.mul.Tensor(view_1220, view_1004); view_1220 = view_1004 = None + view_1221 = torch.ops.aten.view.default(mul_176, [16384, 1792]); mul_176 = None + permute_249 = torch.ops.aten.permute.default(view_1221, [1, 0]) + mm_145 = torch.ops.aten.mm.default(permute_249, view_996); permute_249 = None + permute_251 = torch.ops.aten.permute.default(permute_152, [1, 0]); permute_152 = None + mm_146 = torch.ops.aten.mm.default(view_1221, permute_251); view_1221 = permute_251 = None + view_1222 = torch.ops.aten.view.default(mm_146, [2, 8192, 4096]); mm_146 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mm_145, torch.float32); mm_145 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_663, 'avg', 8, '0'); convert_element_type_663 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59); reduce_scatter_tensor_59 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(mul_177, torch.float32); mul_177 = None + neg_2 = torch.ops.aten.neg.default(convert_element_type_455) + exp_2 = torch.ops.aten.exp.default(neg_2); neg_2 = None + add_79 = torch.ops.aten.add.Tensor(exp_2, 1); exp_2 = None + reciprocal_2 = torch.ops.aten.reciprocal.default(add_79); add_79 = None + mul_178 = torch.ops.aten.mul.Tensor(reciprocal_2, 1); reciprocal_2 = None + mul_179 = torch.ops.aten.mul.Tensor(convert_element_type_664, mul_178); convert_element_type_664 = None + sub_8 = torch.ops.aten.sub.Tensor(1, mul_178); mul_178 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_455, sub_8); convert_element_type_455 = sub_8 = None + add_80 = torch.ops.aten.add.Tensor(mul_180, 1); mul_180 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_179, add_80); mul_179 = add_80 = None + convert_element_type_666 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + view_1223 = torch.ops.aten.view.default(convert_element_type_666, [16384, 1792]); convert_element_type_666 = None + permute_253 = torch.ops.aten.permute.default(view_1223, [1, 0]) + mm_147 = torch.ops.aten.mm.default(permute_253, view_996); permute_253 = view_996 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16); primals_127 = None + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 8, '0'); convert_element_type_452 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + permute_255 = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None + mm_148 = torch.ops.aten.mm.default(view_1223, permute_255); view_1223 = permute_255 = None + view_1224 = torch.ops.aten.view.default(mm_148, [2, 8192, 4096]); mm_148 = None + add_81 = torch.ops.aten.add.Tensor(view_1222, view_1224); view_1222 = view_1224 = None + convert_element_type_671 = torch.ops.prims.convert_element_type.default(mm_147, torch.float32); mm_147 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_671, 'avg', 8, '0'); convert_element_type_671 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + split_84 = torch.ops.aten.split.Tensor(add_81, 1024, 1); add_81 = None + getitem_822 = split_84[0] + getitem_823 = split_84[1] + getitem_824 = split_84[2] + getitem_825 = split_84[3] + getitem_826 = split_84[4] + getitem_827 = split_84[5] + getitem_828 = split_84[6] + getitem_829 = split_84[7]; split_84 = None + cat_76 = torch.ops.aten.cat.default([getitem_822, getitem_823, getitem_824, getitem_825, getitem_826, getitem_827, getitem_828, getitem_829]); getitem_822 = getitem_823 = getitem_824 = getitem_825 = getitem_826 = getitem_827 = getitem_828 = getitem_829 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_76, 'sum', 8, '1'); cat_76 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61); reduce_scatter_tensor_61 = None + convert_element_type_672 = torch.ops.prims.convert_element_type.default(wait_tensor_251, torch.float32); wait_tensor_251 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(wait_tensor_178, torch.float32); wait_tensor_178 = None + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_672, convert_element_type_674); convert_element_type_674 = None + mul_184 = torch.ops.aten.mul.Tensor(mul_108, mul_182) + sum_15 = torch.ops.aten.sum.dim_IntList(mul_184, [2], True); mul_184 = None + div_5 = torch.ops.aten.div.Tensor(mul_108, 4096) + mul_185 = torch.ops.aten.mul.Tensor(div_5, sum_15); div_5 = sum_15 = None + sub_9 = torch.ops.aten.sub.Tensor(mul_182, mul_185); mul_182 = mul_185 = None + mul_186 = torch.ops.aten.mul.Tensor(sub_9, rsqrt_27); sub_9 = rsqrt_27 = None + mul_187 = torch.ops.aten.mul.Tensor(convert_element_type_672, mul_108); convert_element_type_672 = mul_108 = None + sum_16 = torch.ops.aten.sum.dim_IntList(mul_187, [0, 1]); mul_187 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(mul_186, torch.bfloat16); mul_186 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(sum_16, torch.bfloat16); sum_16 = None + all_reduce_5 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_676, 'sum', '1'); convert_element_type_676 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_5); all_reduce_5 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(wait_tensor_252, torch.float32); wait_tensor_252 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_677, 'avg', 8, '0'); convert_element_type_677 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + add_82 = torch.ops.aten.add.Tensor(add_78, convert_element_type_675); add_78 = convert_element_type_675 = None + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_82, 8, '1') + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + split_85 = torch.ops.aten.split.Tensor(wait_tensor_254, 2); wait_tensor_254 = None + getitem_830 = split_85[0] + getitem_831 = split_85[1] + getitem_832 = split_85[2] + getitem_833 = split_85[3] + getitem_834 = split_85[4] + getitem_835 = split_85[5] + getitem_836 = split_85[6] + getitem_837 = split_85[7]; split_85 = None + cat_77 = torch.ops.aten.cat.default([getitem_830, getitem_831, getitem_832, getitem_833, getitem_834, getitem_835, getitem_836, getitem_837], 1); getitem_830 = getitem_831 = getitem_832 = getitem_833 = getitem_834 = getitem_835 = getitem_836 = getitem_837 = None + view_1225 = torch.ops.aten.view.default(cat_77, [16384, 4096]); cat_77 = None + permute_257 = torch.ops.aten.permute.default(view_1225, [1, 0]) + permute_149 = torch.ops.aten.permute.default(getitem_613, [0, 2, 1, 3]) + view_978 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + view_984 = torch.ops.aten.view.default(view_978, [16384, 512]); view_978 = None + mm_149 = torch.ops.aten.mm.default(permute_257, view_984); permute_257 = view_984 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 8, '0'); convert_element_type_446 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + permute_259 = torch.ops.aten.permute.default(permute_150, [1, 0]); permute_150 = None + mm_150 = torch.ops.aten.mm.default(view_1225, permute_259); view_1225 = permute_259 = None + view_1226 = torch.ops.aten.view.default(mm_150, [2, 8192, 512]); mm_150 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mm_149, torch.float32); mm_149 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_682, 'avg', 8, '0'); convert_element_type_682 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63); reduce_scatter_tensor_63 = None + view_1227 = torch.ops.aten.view.default(view_1226, [2, 8192, 4, 128]); view_1226 = None + permute_261 = torch.ops.aten.permute.default(view_1227, [0, 2, 1, 3]); view_1227 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 8, '0'); convert_element_type_430 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32); add_51 = None + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_171) + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_432, 8, '1'); convert_element_type_432 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + split_61 = torch.ops.aten.split.Tensor(wait_tensor_172, 2); wait_tensor_172 = None + getitem_605 = split_61[0] + getitem_606 = split_61[1] + getitem_607 = split_61[2] + getitem_608 = split_61[3] + getitem_609 = split_61[4] + getitem_610 = split_61[5] + getitem_611 = split_61[6] + getitem_612 = split_61[7]; split_61 = None + cat_53 = torch.ops.aten.cat.default([getitem_605, getitem_606, getitem_607, getitem_608, getitem_609, getitem_610, getitem_611, getitem_612], 1); getitem_605 = getitem_606 = getitem_607 = getitem_608 = getitem_609 = getitem_610 = getitem_611 = getitem_612 = None + view_951 = torch.ops.aten.view.default(cat_53, [16384, 4096]); cat_53 = None + view_952 = torch.ops.aten.view.default(mm_91, [2, 8192, 512]); mm_91 = None + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 8, '0'); convert_element_type_436 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_92 = torch.ops.aten.mm.default(view_951, permute_144) + view_959 = torch.ops.aten.view.default(mm_92, [2, 8192, 128]); mm_92 = None + view_966 = torch.ops.aten.view.default(mm_93, [2, 8192, 128]); mm_93 = None + view_968 = torch.ops.aten.view.default(view_952, [2, 8192, -1, 128]); view_952 = None + view_969 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_970 = torch.ops.aten.view.default(view_966, [2, 8192, -1, 128]); view_966 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_968, torch.float32); view_968 = None + view_971 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 4, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_971); view_971 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_969, torch.float32); view_969 = None + view_972 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 1, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_972); view_972 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_37); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_974 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 4, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_37); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_975 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 1, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_974, torch.bfloat16); view_974 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_975, torch.bfloat16); view_975 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 1, 4, 128]); unsqueeze_26 = None + view_976 = torch.ops.aten.view.default(expand_26, [2, 8192, 4, 128]); expand_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_970, 3); view_970 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 1, 4, 128]); unsqueeze_27 = None + view_977 = torch.ops.aten.view.default(expand_27, [2, 8192, 4, 128]); expand_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_976, [0, 2, 1, 3]); view_976 = None + permute_148 = torch.ops.aten.permute.default(view_977, [0, 2, 1, 3]); view_977 = None + _scaled_dot_product_cudnn_attention_backward_2 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_261, permute_146, permute_147, permute_148, getitem_613, getitem_614, getitem_619, getitem_620, None, None, None, 8192, 8192, 0.0, True); permute_261 = permute_146 = permute_147 = permute_148 = getitem_613 = getitem_614 = getitem_619 = getitem_620 = None + getitem_838 = _scaled_dot_product_cudnn_attention_backward_2[0] + getitem_839 = _scaled_dot_product_cudnn_attention_backward_2[1] + getitem_840 = _scaled_dot_product_cudnn_attention_backward_2[2]; _scaled_dot_product_cudnn_attention_backward_2 = None + permute_262 = torch.ops.aten.permute.default(getitem_840, [0, 2, 1, 3]); getitem_840 = None + permute_263 = torch.ops.aten.permute.default(getitem_839, [0, 2, 1, 3]); getitem_839 = None + permute_264 = torch.ops.aten.permute.default(getitem_838, [0, 2, 1, 3]); getitem_838 = None + view_1228 = torch.ops.aten.view.default(permute_262, [2, 8192, 1, 4, 128]); permute_262 = None + sum_17 = torch.ops.aten.sum.dim_IntList(view_1228, [3], True); view_1228 = None + squeeze_4 = torch.ops.aten.squeeze.dim(sum_17, 3); sum_17 = None + view_1229 = torch.ops.aten.view.default(permute_263, [2, 8192, 1, 4, 128]); permute_263 = None + sum_18 = torch.ops.aten.sum.dim_IntList(view_1229, [3], True); view_1229 = None + squeeze_5 = torch.ops.aten.squeeze.dim(sum_18, 3); sum_18 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(squeeze_5, torch.float32); squeeze_5 = None + convert_element_type_684 = torch.ops.prims.convert_element_type.default(permute_264, torch.float32); permute_264 = None + view_1230 = torch.ops.aten.view.default(convert_element_type_683, [2, 8192, 1, 64, 2]); convert_element_type_683 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1230); view_1230 = None + mul_188 = torch.ops.aten.mul.Tensor(view_as_complex_36, _conj); view_as_complex_36 = None + view_1231 = torch.ops.aten.view.default(convert_element_type_684, [2, 8192, 4, 64, 2]); convert_element_type_684 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1231); view_1231 = None + mul_189 = torch.ops.aten.mul.Tensor(view_as_complex_37, _conj); view_as_complex_37 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_188); mul_188 = None + view_1232 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 1, 128]); view_as_real_36 = None + convert_element_type_685 = torch.ops.prims.convert_element_type.default(view_1232, torch.bfloat16); view_1232 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_189); mul_189 = None + view_1233 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 4, 128]); view_as_real_37 = None + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_1233, torch.bfloat16); view_1233 = None + view_1234 = torch.ops.aten.view.default(squeeze_4, [2, 8192, 128]); squeeze_4 = None + view_1235 = torch.ops.aten.view.default(convert_element_type_685, [2, 8192, 128]); convert_element_type_685 = None + view_1236 = torch.ops.aten.view.default(convert_element_type_686, [2, 8192, 512]); convert_element_type_686 = None + view_1237 = torch.ops.aten.view.default(view_1234, [16384, 128]); view_1234 = None + permute_265 = torch.ops.aten.permute.default(view_1237, [1, 0]) + mm_151 = torch.ops.aten.mm.default(permute_265, view_951); permute_265 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 8, '0'); convert_element_type_439 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + permute_267 = torch.ops.aten.permute.default(permute_145, [1, 0]); permute_145 = None + mm_152 = torch.ops.aten.mm.default(view_1237, permute_267); view_1237 = permute_267 = None + view_1238 = torch.ops.aten.view.default(mm_152, [2, 8192, 4096]); mm_152 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(mm_151, torch.float32); mm_151 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_691, 'avg', 8, '0'); convert_element_type_691 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64); reduce_scatter_tensor_64 = None + view_1239 = torch.ops.aten.view.default(view_1235, [16384, 128]); view_1235 = None + permute_269 = torch.ops.aten.permute.default(view_1239, [1, 0]) + mm_153 = torch.ops.aten.mm.default(permute_269, view_951); permute_269 = None + permute_271 = torch.ops.aten.permute.default(permute_144, [1, 0]); permute_144 = None + mm_154 = torch.ops.aten.mm.default(view_1239, permute_271); view_1239 = permute_271 = None + view_1240 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]); mm_154 = None + add_83 = torch.ops.aten.add.Tensor(view_1238, view_1240); view_1238 = view_1240 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mm_153, torch.float32); mm_153 = None + reduce_scatter_tensor_65 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_696, 'avg', 8, '0'); convert_element_type_696 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_65); reduce_scatter_tensor_65 = None + view_1241 = torch.ops.aten.view.default(view_1236, [16384, 512]); view_1236 = None + permute_273 = torch.ops.aten.permute.default(view_1241, [1, 0]) + mm_155 = torch.ops.aten.mm.default(permute_273, view_951); permute_273 = view_951 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 8, '0'); convert_element_type_433 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + permute_275 = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None + mm_156 = torch.ops.aten.mm.default(view_1241, permute_275); view_1241 = permute_275 = None + view_1242 = torch.ops.aten.view.default(mm_156, [2, 8192, 4096]); mm_156 = None + add_84 = torch.ops.aten.add.Tensor(add_83, view_1242); add_83 = view_1242 = None + convert_element_type_701 = torch.ops.prims.convert_element_type.default(mm_155, torch.float32); mm_155 = None + reduce_scatter_tensor_66 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_701, 'avg', 8, '0'); convert_element_type_701 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_66); reduce_scatter_tensor_66 = None + split_86 = torch.ops.aten.split.Tensor(add_84, 1024, 1); add_84 = None + getitem_841 = split_86[0] + getitem_842 = split_86[1] + getitem_843 = split_86[2] + getitem_844 = split_86[3] + getitem_845 = split_86[4] + getitem_846 = split_86[5] + getitem_847 = split_86[6] + getitem_848 = split_86[7]; split_86 = None + cat_78 = torch.ops.aten.cat.default([getitem_841, getitem_842, getitem_843, getitem_844, getitem_845, getitem_846, getitem_847, getitem_848]); getitem_841 = getitem_842 = getitem_843 = getitem_844 = getitem_845 = getitem_846 = getitem_847 = getitem_848 = None + reduce_scatter_tensor_67 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_78, 'sum', 8, '1'); cat_78 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_67); reduce_scatter_tensor_67 = None + convert_element_type_702 = torch.ops.prims.convert_element_type.default(wait_tensor_259, torch.float32); wait_tensor_259 = None + convert_element_type_704 = torch.ops.prims.convert_element_type.default(wait_tensor_171, torch.float32); wait_tensor_171 = None + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_702, convert_element_type_704); convert_element_type_704 = None + mul_192 = torch.ops.aten.mul.Tensor(mul_104, mul_190) + sum_19 = torch.ops.aten.sum.dim_IntList(mul_192, [2], True); mul_192 = None + div_6 = torch.ops.aten.div.Tensor(mul_104, 4096) + mul_193 = torch.ops.aten.mul.Tensor(div_6, sum_19); div_6 = sum_19 = None + sub_10 = torch.ops.aten.sub.Tensor(mul_190, mul_193); mul_190 = mul_193 = None + mul_194 = torch.ops.aten.mul.Tensor(sub_10, rsqrt_26); sub_10 = rsqrt_26 = None + mul_195 = torch.ops.aten.mul.Tensor(convert_element_type_702, mul_104); convert_element_type_702 = mul_104 = None + sum_20 = torch.ops.aten.sum.dim_IntList(mul_195, [0, 1]); mul_195 = None + convert_element_type_705 = torch.ops.prims.convert_element_type.default(mul_194, torch.bfloat16); mul_194 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(sum_20, torch.bfloat16); sum_20 = None + all_reduce_6 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_706, 'sum', '1'); convert_element_type_706 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_6); all_reduce_6 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(wait_tensor_260, torch.float32); wait_tensor_260 = None + reduce_scatter_tensor_68 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_707, 'avg', 8, '0'); convert_element_type_707 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_68); reduce_scatter_tensor_68 = None + add_85 = torch.ops.aten.add.Tensor(add_82, convert_element_type_705); add_82 = convert_element_type_705 = None + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_85, 8, '1') + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + split_87 = torch.ops.aten.split.Tensor(wait_tensor_262, 2); wait_tensor_262 = None + getitem_849 = split_87[0] + getitem_850 = split_87[1] + getitem_851 = split_87[2] + getitem_852 = split_87[3] + getitem_853 = split_87[4] + getitem_854 = split_87[5] + getitem_855 = split_87[6] + getitem_856 = split_87[7]; split_87 = None + cat_79 = torch.ops.aten.cat.default([getitem_849, getitem_850, getitem_851, getitem_852, getitem_853, getitem_854, getitem_855, getitem_856], 1); getitem_849 = getitem_850 = getitem_851 = getitem_852 = getitem_853 = getitem_854 = getitem_855 = getitem_856 = None + view_1243 = torch.ops.aten.view.default(cat_79, [16384, 4096]); cat_79 = None + permute_277 = torch.ops.aten.permute.default(view_1243, [1, 0]) + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25); reduce_scatter_tensor_25 = None + add_49 = torch.ops.aten.add.Tensor(add_47, wait_tensor_164); wait_tensor_164 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16); primals_117 = None + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 8, '0'); convert_element_type_416 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32); add_49 = None + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_165) + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 8, '1'); convert_element_type_418 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + split_59 = torch.ops.aten.split.Tensor(wait_tensor_166, 2); wait_tensor_166 = None + getitem_589 = split_59[0] + getitem_590 = split_59[1] + getitem_591 = split_59[2] + getitem_592 = split_59[3] + getitem_593 = split_59[4] + getitem_594 = split_59[5] + getitem_595 = split_59[6] + getitem_596 = split_59[7]; split_59 = None + cat_51 = torch.ops.aten.cat.default([getitem_589, getitem_590, getitem_591, getitem_592, getitem_593, getitem_594, getitem_595, getitem_596], 1); getitem_589 = getitem_590 = getitem_591 = getitem_592 = getitem_593 = getitem_594 = getitem_595 = getitem_596 = None + view_924 = torch.ops.aten.view.default(cat_51, [16384, 4096]); cat_51 = None + view_925 = torch.ops.aten.view.default(mm_88, [2, 8192, 1792]); mm_88 = None + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_925, torch.float32); view_925 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 8, '0'); convert_element_type_424 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_168, [1, 0]); wait_tensor_168 = None + mm_89 = torch.ops.aten.mm.default(view_924, permute_141) + view_932 = torch.ops.aten.view.default(mm_89, [2, 8192, 1792]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_932) + view_939 = torch.ops.aten.view.default(mul_103, [16384, 1792]); mul_103 = None + mm_157 = torch.ops.aten.mm.default(permute_277, view_939); permute_277 = view_939 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 8, '0'); convert_element_type_427 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + permute_279 = torch.ops.aten.permute.default(permute_142, [1, 0]); permute_142 = None + mm_158 = torch.ops.aten.mm.default(view_1243, permute_279); view_1243 = permute_279 = None + view_1244 = torch.ops.aten.view.default(mm_158, [2, 8192, 1792]); mm_158 = None + convert_element_type_712 = torch.ops.prims.convert_element_type.default(mm_157, torch.float32); mm_157 = None + reduce_scatter_tensor_69 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_712, 'avg', 8, '0'); convert_element_type_712 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_69); reduce_scatter_tensor_69 = None + mul_196 = torch.ops.aten.mul.Tensor(view_1244, convert_element_type_423); convert_element_type_423 = None + mul_197 = torch.ops.aten.mul.Tensor(view_1244, view_932); view_1244 = view_932 = None + view_1245 = torch.ops.aten.view.default(mul_196, [16384, 1792]); mul_196 = None + permute_281 = torch.ops.aten.permute.default(view_1245, [1, 0]) + mm_159 = torch.ops.aten.mm.default(permute_281, view_924); permute_281 = None + permute_283 = torch.ops.aten.permute.default(permute_141, [1, 0]); permute_141 = None + mm_160 = torch.ops.aten.mm.default(view_1245, permute_283); view_1245 = permute_283 = None + view_1246 = torch.ops.aten.view.default(mm_160, [2, 8192, 4096]); mm_160 = None + convert_element_type_717 = torch.ops.prims.convert_element_type.default(mm_159, torch.float32); mm_159 = None + reduce_scatter_tensor_70 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_717, 'avg', 8, '0'); convert_element_type_717 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_70); reduce_scatter_tensor_70 = None + convert_element_type_718 = torch.ops.prims.convert_element_type.default(mul_197, torch.float32); mul_197 = None + neg_3 = torch.ops.aten.neg.default(convert_element_type_422) + exp_3 = torch.ops.aten.exp.default(neg_3); neg_3 = None + add_86 = torch.ops.aten.add.Tensor(exp_3, 1); exp_3 = None + reciprocal_3 = torch.ops.aten.reciprocal.default(add_86); add_86 = None + mul_198 = torch.ops.aten.mul.Tensor(reciprocal_3, 1); reciprocal_3 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_718, mul_198); convert_element_type_718 = None + sub_11 = torch.ops.aten.sub.Tensor(1, mul_198); mul_198 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_422, sub_11); convert_element_type_422 = sub_11 = None + add_87 = torch.ops.aten.add.Tensor(mul_200, 1); mul_200 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_199, add_87); mul_199 = add_87 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + view_1247 = torch.ops.aten.view.default(convert_element_type_720, [16384, 1792]); convert_element_type_720 = None + permute_285 = torch.ops.aten.permute.default(view_1247, [1, 0]) + mm_161 = torch.ops.aten.mm.default(permute_285, view_924); permute_285 = view_924 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 8, '0'); convert_element_type_419 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + permute_287 = torch.ops.aten.permute.default(permute_140, [1, 0]); permute_140 = None + mm_162 = torch.ops.aten.mm.default(view_1247, permute_287); view_1247 = permute_287 = None + view_1248 = torch.ops.aten.view.default(mm_162, [2, 8192, 4096]); mm_162 = None + add_88 = torch.ops.aten.add.Tensor(view_1246, view_1248); view_1246 = view_1248 = None + convert_element_type_725 = torch.ops.prims.convert_element_type.default(mm_161, torch.float32); mm_161 = None + reduce_scatter_tensor_71 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_725, 'avg', 8, '0'); convert_element_type_725 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_71); reduce_scatter_tensor_71 = None + split_88 = torch.ops.aten.split.Tensor(add_88, 1024, 1); add_88 = None + getitem_857 = split_88[0] + getitem_858 = split_88[1] + getitem_859 = split_88[2] + getitem_860 = split_88[3] + getitem_861 = split_88[4] + getitem_862 = split_88[5] + getitem_863 = split_88[6] + getitem_864 = split_88[7]; split_88 = None + cat_80 = torch.ops.aten.cat.default([getitem_857, getitem_858, getitem_859, getitem_860, getitem_861, getitem_862, getitem_863, getitem_864]); getitem_857 = getitem_858 = getitem_859 = getitem_860 = getitem_861 = getitem_862 = getitem_863 = getitem_864 = None + reduce_scatter_tensor_72 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_80, 'sum', 8, '1'); cat_80 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_72); reduce_scatter_tensor_72 = None + convert_element_type_726 = torch.ops.prims.convert_element_type.default(wait_tensor_266, torch.float32); wait_tensor_266 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(wait_tensor_165, torch.float32); wait_tensor_165 = None + mul_202 = torch.ops.aten.mul.Tensor(convert_element_type_726, convert_element_type_728); convert_element_type_728 = None + mul_204 = torch.ops.aten.mul.Tensor(mul_100, mul_202) + sum_21 = torch.ops.aten.sum.dim_IntList(mul_204, [2], True); mul_204 = None + div_7 = torch.ops.aten.div.Tensor(mul_100, 4096) + mul_205 = torch.ops.aten.mul.Tensor(div_7, sum_21); div_7 = sum_21 = None + sub_12 = torch.ops.aten.sub.Tensor(mul_202, mul_205); mul_202 = mul_205 = None + mul_206 = torch.ops.aten.mul.Tensor(sub_12, rsqrt_25); sub_12 = rsqrt_25 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_726, mul_100); convert_element_type_726 = mul_100 = None + sum_22 = torch.ops.aten.sum.dim_IntList(mul_207, [0, 1]); mul_207 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(sum_22, torch.bfloat16); sum_22 = None + all_reduce_7 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_730, 'sum', '1'); convert_element_type_730 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_7); all_reduce_7 = None + convert_element_type_731 = torch.ops.prims.convert_element_type.default(wait_tensor_267, torch.float32); wait_tensor_267 = None + reduce_scatter_tensor_73 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_731, 'avg', 8, '0'); convert_element_type_731 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_73); reduce_scatter_tensor_73 = None + add_89 = torch.ops.aten.add.Tensor(add_85, convert_element_type_729); add_85 = convert_element_type_729 = None + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_89, 8, '1') + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + split_89 = torch.ops.aten.split.Tensor(wait_tensor_269, 2); wait_tensor_269 = None + getitem_865 = split_89[0] + getitem_866 = split_89[1] + getitem_867 = split_89[2] + getitem_868 = split_89[3] + getitem_869 = split_89[4] + getitem_870 = split_89[5] + getitem_871 = split_89[6] + getitem_872 = split_89[7]; split_89 = None + cat_81 = torch.ops.aten.cat.default([getitem_865, getitem_866, getitem_867, getitem_868, getitem_869, getitem_870, getitem_871, getitem_872], 1); getitem_865 = getitem_866 = getitem_867 = getitem_868 = getitem_869 = getitem_870 = getitem_871 = getitem_872 = None + view_1249 = torch.ops.aten.view.default(cat_81, [16384, 4096]); cat_81 = None + permute_289 = torch.ops.aten.permute.default(view_1249, [1, 0]) + permute_138 = torch.ops.aten.permute.default(getitem_572, [0, 2, 1, 3]) + view_906 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + view_912 = torch.ops.aten.view.default(view_906, [16384, 512]); view_906 = None + mm_163 = torch.ops.aten.mm.default(permute_289, view_912); permute_289 = view_912 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16); primals_116 = None + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 8, '0'); convert_element_type_413 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + permute_291 = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None + mm_164 = torch.ops.aten.mm.default(view_1249, permute_291); view_1249 = permute_291 = None + view_1250 = torch.ops.aten.view.default(mm_164, [2, 8192, 512]); mm_164 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(mm_163, torch.float32); mm_163 = None + reduce_scatter_tensor_74 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_736, 'avg', 8, '0'); convert_element_type_736 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_74); reduce_scatter_tensor_74 = None + view_1251 = torch.ops.aten.view.default(view_1250, [2, 8192, 4, 128]); view_1250 = None + permute_293 = torch.ops.aten.permute.default(view_1251, [0, 2, 1, 3]); view_1251 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16); primals_112 = None + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 8, '0'); convert_element_type_397 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32); add_47 = None + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_158) + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_399, 8, '1'); convert_element_type_399 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + split_57 = torch.ops.aten.split.Tensor(wait_tensor_159, 2); wait_tensor_159 = None + getitem_564 = split_57[0] + getitem_565 = split_57[1] + getitem_566 = split_57[2] + getitem_567 = split_57[3] + getitem_568 = split_57[4] + getitem_569 = split_57[5] + getitem_570 = split_57[6] + getitem_571 = split_57[7]; split_57 = None + cat_49 = torch.ops.aten.cat.default([getitem_564, getitem_565, getitem_566, getitem_567, getitem_568, getitem_569, getitem_570, getitem_571], 1); getitem_564 = getitem_565 = getitem_566 = getitem_567 = getitem_568 = getitem_569 = getitem_570 = getitem_571 = None + view_879 = torch.ops.aten.view.default(cat_49, [16384, 4096]); cat_49 = None + view_880 = torch.ops.aten.view.default(mm_84, [2, 8192, 512]); mm_84 = None + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16); primals_114 = None + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 8, '0'); convert_element_type_403 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_85 = torch.ops.aten.mm.default(view_879, permute_133) + view_887 = torch.ops.aten.view.default(mm_85, [2, 8192, 128]); mm_85 = None + view_894 = torch.ops.aten.view.default(mm_86, [2, 8192, 128]); mm_86 = None + view_896 = torch.ops.aten.view.default(view_880, [2, 8192, -1, 128]); view_880 = None + view_897 = torch.ops.aten.view.default(view_887, [2, 8192, -1, 128]); view_887 = None + view_898 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 4, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_897, torch.float32); view_897 = None + view_900 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 1, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_900); view_900 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_37); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_902 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 4, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_37); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_903 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 1, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_903, torch.bfloat16); view_903 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 1, 4, 128]); unsqueeze_24 = None + view_904 = torch.ops.aten.view.default(expand_24, [2, 8192, 4, 128]); expand_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_898, 3); view_898 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 1, 4, 128]); unsqueeze_25 = None + view_905 = torch.ops.aten.view.default(expand_25, [2, 8192, 4, 128]); expand_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + permute_137 = torch.ops.aten.permute.default(view_905, [0, 2, 1, 3]); view_905 = None + _scaled_dot_product_cudnn_attention_backward_3 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_293, permute_135, permute_136, permute_137, getitem_572, getitem_573, getitem_578, getitem_579, None, None, None, 8192, 8192, 0.0, True); permute_293 = permute_135 = permute_136 = permute_137 = getitem_572 = getitem_573 = getitem_578 = getitem_579 = None + getitem_873 = _scaled_dot_product_cudnn_attention_backward_3[0] + getitem_874 = _scaled_dot_product_cudnn_attention_backward_3[1] + getitem_875 = _scaled_dot_product_cudnn_attention_backward_3[2]; _scaled_dot_product_cudnn_attention_backward_3 = None + permute_294 = torch.ops.aten.permute.default(getitem_875, [0, 2, 1, 3]); getitem_875 = None + permute_295 = torch.ops.aten.permute.default(getitem_874, [0, 2, 1, 3]); getitem_874 = None + permute_296 = torch.ops.aten.permute.default(getitem_873, [0, 2, 1, 3]); getitem_873 = None + view_1252 = torch.ops.aten.view.default(permute_294, [2, 8192, 1, 4, 128]); permute_294 = None + sum_23 = torch.ops.aten.sum.dim_IntList(view_1252, [3], True); view_1252 = None + squeeze_6 = torch.ops.aten.squeeze.dim(sum_23, 3); sum_23 = None + view_1253 = torch.ops.aten.view.default(permute_295, [2, 8192, 1, 4, 128]); permute_295 = None + sum_24 = torch.ops.aten.sum.dim_IntList(view_1253, [3], True); view_1253 = None + squeeze_7 = torch.ops.aten.squeeze.dim(sum_24, 3); sum_24 = None + convert_element_type_737 = torch.ops.prims.convert_element_type.default(squeeze_7, torch.float32); squeeze_7 = None + convert_element_type_738 = torch.ops.prims.convert_element_type.default(permute_296, torch.float32); permute_296 = None + view_1254 = torch.ops.aten.view.default(convert_element_type_737, [2, 8192, 1, 64, 2]); convert_element_type_737 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1254); view_1254 = None + mul_208 = torch.ops.aten.mul.Tensor(view_as_complex_38, _conj); view_as_complex_38 = None + view_1255 = torch.ops.aten.view.default(convert_element_type_738, [2, 8192, 4, 64, 2]); convert_element_type_738 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1255); view_1255 = None + mul_209 = torch.ops.aten.mul.Tensor(view_as_complex_39, _conj); view_as_complex_39 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_208); mul_208 = None + view_1256 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 1, 128]); view_as_real_38 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_1256, torch.bfloat16); view_1256 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_209); mul_209 = None + view_1257 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 4, 128]); view_as_real_39 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_1257, torch.bfloat16); view_1257 = None + view_1258 = torch.ops.aten.view.default(squeeze_6, [2, 8192, 128]); squeeze_6 = None + view_1259 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 128]); convert_element_type_739 = None + view_1260 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 512]); convert_element_type_740 = None + view_1261 = torch.ops.aten.view.default(view_1258, [16384, 128]); view_1258 = None + permute_297 = torch.ops.aten.permute.default(view_1261, [1, 0]) + mm_165 = torch.ops.aten.mm.default(permute_297, view_879); permute_297 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16); primals_115 = None + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 8, '0'); convert_element_type_406 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + permute_299 = torch.ops.aten.permute.default(permute_134, [1, 0]); permute_134 = None + mm_166 = torch.ops.aten.mm.default(view_1261, permute_299); view_1261 = permute_299 = None + view_1262 = torch.ops.aten.view.default(mm_166, [2, 8192, 4096]); mm_166 = None + convert_element_type_745 = torch.ops.prims.convert_element_type.default(mm_165, torch.float32); mm_165 = None + reduce_scatter_tensor_75 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_745, 'avg', 8, '0'); convert_element_type_745 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_75); reduce_scatter_tensor_75 = None + view_1263 = torch.ops.aten.view.default(view_1259, [16384, 128]); view_1259 = None + permute_301 = torch.ops.aten.permute.default(view_1263, [1, 0]) + mm_167 = torch.ops.aten.mm.default(permute_301, view_879); permute_301 = None + permute_303 = torch.ops.aten.permute.default(permute_133, [1, 0]); permute_133 = None + mm_168 = torch.ops.aten.mm.default(view_1263, permute_303); view_1263 = permute_303 = None + view_1264 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]); mm_168 = None + add_90 = torch.ops.aten.add.Tensor(view_1262, view_1264); view_1262 = view_1264 = None + convert_element_type_750 = torch.ops.prims.convert_element_type.default(mm_167, torch.float32); mm_167 = None + reduce_scatter_tensor_76 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_750, 'avg', 8, '0'); convert_element_type_750 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_76); reduce_scatter_tensor_76 = None + view_1265 = torch.ops.aten.view.default(view_1260, [16384, 512]); view_1260 = None + permute_305 = torch.ops.aten.permute.default(view_1265, [1, 0]) + mm_169 = torch.ops.aten.mm.default(permute_305, view_879); permute_305 = view_879 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16); primals_113 = None + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 8, '0'); convert_element_type_400 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + permute_307 = torch.ops.aten.permute.default(permute_132, [1, 0]); permute_132 = None + mm_170 = torch.ops.aten.mm.default(view_1265, permute_307); view_1265 = permute_307 = None + view_1266 = torch.ops.aten.view.default(mm_170, [2, 8192, 4096]); mm_170 = None + add_91 = torch.ops.aten.add.Tensor(add_90, view_1266); add_90 = view_1266 = None + convert_element_type_755 = torch.ops.prims.convert_element_type.default(mm_169, torch.float32); mm_169 = None + reduce_scatter_tensor_77 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_755, 'avg', 8, '0'); convert_element_type_755 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_77); reduce_scatter_tensor_77 = None + split_90 = torch.ops.aten.split.Tensor(add_91, 1024, 1); add_91 = None + getitem_876 = split_90[0] + getitem_877 = split_90[1] + getitem_878 = split_90[2] + getitem_879 = split_90[3] + getitem_880 = split_90[4] + getitem_881 = split_90[5] + getitem_882 = split_90[6] + getitem_883 = split_90[7]; split_90 = None + cat_82 = torch.ops.aten.cat.default([getitem_876, getitem_877, getitem_878, getitem_879, getitem_880, getitem_881, getitem_882, getitem_883]); getitem_876 = getitem_877 = getitem_878 = getitem_879 = getitem_880 = getitem_881 = getitem_882 = getitem_883 = None + reduce_scatter_tensor_78 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_82, 'sum', 8, '1'); cat_82 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_78); reduce_scatter_tensor_78 = None + convert_element_type_756 = torch.ops.prims.convert_element_type.default(wait_tensor_274, torch.float32); wait_tensor_274 = None + convert_element_type_758 = torch.ops.prims.convert_element_type.default(wait_tensor_158, torch.float32); wait_tensor_158 = None + mul_210 = torch.ops.aten.mul.Tensor(convert_element_type_756, convert_element_type_758); convert_element_type_758 = None + mul_212 = torch.ops.aten.mul.Tensor(mul_96, mul_210) + sum_25 = torch.ops.aten.sum.dim_IntList(mul_212, [2], True); mul_212 = None + div_8 = torch.ops.aten.div.Tensor(mul_96, 4096) + mul_213 = torch.ops.aten.mul.Tensor(div_8, sum_25); div_8 = sum_25 = None + sub_13 = torch.ops.aten.sub.Tensor(mul_210, mul_213); mul_210 = mul_213 = None + mul_214 = torch.ops.aten.mul.Tensor(sub_13, rsqrt_24); sub_13 = rsqrt_24 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_756, mul_96); convert_element_type_756 = mul_96 = None + sum_26 = torch.ops.aten.sum.dim_IntList(mul_215, [0, 1]); mul_215 = None + convert_element_type_759 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(sum_26, torch.bfloat16); sum_26 = None + all_reduce_8 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_760, 'sum', '1'); convert_element_type_760 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_8); all_reduce_8 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(wait_tensor_275, torch.float32); wait_tensor_275 = None + reduce_scatter_tensor_79 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_761, 'avg', 8, '0'); convert_element_type_761 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_79); reduce_scatter_tensor_79 = None + add_92 = torch.ops.aten.add.Tensor(add_89, convert_element_type_759); add_89 = convert_element_type_759 = None + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_92, 8, '1') + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + split_91 = torch.ops.aten.split.Tensor(wait_tensor_277, 2); wait_tensor_277 = None + getitem_884 = split_91[0] + getitem_885 = split_91[1] + getitem_886 = split_91[2] + getitem_887 = split_91[3] + getitem_888 = split_91[4] + getitem_889 = split_91[5] + getitem_890 = split_91[6] + getitem_891 = split_91[7]; split_91 = None + cat_83 = torch.ops.aten.cat.default([getitem_884, getitem_885, getitem_886, getitem_887, getitem_888, getitem_889, getitem_890, getitem_891], 1); getitem_884 = getitem_885 = getitem_886 = getitem_887 = getitem_888 = getitem_889 = getitem_890 = getitem_891 = None + view_1267 = torch.ops.aten.view.default(cat_83, [16384, 4096]); cat_83 = None + permute_309 = torch.ops.aten.permute.default(view_1267, [1, 0]) + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23); reduce_scatter_tensor_23 = None + add_45 = torch.ops.aten.add.Tensor(add_43, wait_tensor_151); wait_tensor_151 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 8, '0'); convert_element_type_383 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32); add_45 = None + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_152) + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_385, 8, '1'); convert_element_type_385 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + split_55 = torch.ops.aten.split.Tensor(wait_tensor_153, 2); wait_tensor_153 = None + getitem_548 = split_55[0] + getitem_549 = split_55[1] + getitem_550 = split_55[2] + getitem_551 = split_55[3] + getitem_552 = split_55[4] + getitem_553 = split_55[5] + getitem_554 = split_55[6] + getitem_555 = split_55[7]; split_55 = None + cat_47 = torch.ops.aten.cat.default([getitem_548, getitem_549, getitem_550, getitem_551, getitem_552, getitem_553, getitem_554, getitem_555], 1); getitem_548 = getitem_549 = getitem_550 = getitem_551 = getitem_552 = getitem_553 = getitem_554 = getitem_555 = None + view_852 = torch.ops.aten.view.default(cat_47, [16384, 4096]); cat_47 = None + view_853 = torch.ops.aten.view.default(mm_81, [2, 8192, 1792]); mm_81 = None + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_853, torch.float32); view_853 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16); primals_110 = None + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 8, '0'); convert_element_type_391 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_82 = torch.ops.aten.mm.default(view_852, permute_130) + view_860 = torch.ops.aten.view.default(mm_82, [2, 8192, 1792]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_860) + view_867 = torch.ops.aten.view.default(mul_95, [16384, 1792]); mul_95 = None + mm_171 = torch.ops.aten.mm.default(permute_309, view_867); permute_309 = view_867 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16); primals_111 = None + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 8, '0'); convert_element_type_394 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + permute_311 = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None + mm_172 = torch.ops.aten.mm.default(view_1267, permute_311); view_1267 = permute_311 = None + view_1268 = torch.ops.aten.view.default(mm_172, [2, 8192, 1792]); mm_172 = None + convert_element_type_766 = torch.ops.prims.convert_element_type.default(mm_171, torch.float32); mm_171 = None + reduce_scatter_tensor_80 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_766, 'avg', 8, '0'); convert_element_type_766 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_80); reduce_scatter_tensor_80 = None + mul_216 = torch.ops.aten.mul.Tensor(view_1268, convert_element_type_390); convert_element_type_390 = None + mul_217 = torch.ops.aten.mul.Tensor(view_1268, view_860); view_1268 = view_860 = None + view_1269 = torch.ops.aten.view.default(mul_216, [16384, 1792]); mul_216 = None + permute_313 = torch.ops.aten.permute.default(view_1269, [1, 0]) + mm_173 = torch.ops.aten.mm.default(permute_313, view_852); permute_313 = None + permute_315 = torch.ops.aten.permute.default(permute_130, [1, 0]); permute_130 = None + mm_174 = torch.ops.aten.mm.default(view_1269, permute_315); view_1269 = permute_315 = None + view_1270 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = None + convert_element_type_771 = torch.ops.prims.convert_element_type.default(mm_173, torch.float32); mm_173 = None + reduce_scatter_tensor_81 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_771, 'avg', 8, '0'); convert_element_type_771 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_81); reduce_scatter_tensor_81 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(mul_217, torch.float32); mul_217 = None + neg_4 = torch.ops.aten.neg.default(convert_element_type_389) + exp_4 = torch.ops.aten.exp.default(neg_4); neg_4 = None + add_93 = torch.ops.aten.add.Tensor(exp_4, 1); exp_4 = None + reciprocal_4 = torch.ops.aten.reciprocal.default(add_93); add_93 = None + mul_218 = torch.ops.aten.mul.Tensor(reciprocal_4, 1); reciprocal_4 = None + mul_219 = torch.ops.aten.mul.Tensor(convert_element_type_772, mul_218); convert_element_type_772 = None + sub_14 = torch.ops.aten.sub.Tensor(1, mul_218); mul_218 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_389, sub_14); convert_element_type_389 = sub_14 = None + add_94 = torch.ops.aten.add.Tensor(mul_220, 1); mul_220 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_219, add_94); mul_219 = add_94 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + view_1271 = torch.ops.aten.view.default(convert_element_type_774, [16384, 1792]); convert_element_type_774 = None + permute_317 = torch.ops.aten.permute.default(view_1271, [1, 0]) + mm_175 = torch.ops.aten.mm.default(permute_317, view_852); permute_317 = view_852 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16); primals_109 = None + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 8, '0'); convert_element_type_386 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_154, [1, 0]); wait_tensor_154 = None + permute_319 = torch.ops.aten.permute.default(permute_129, [1, 0]); permute_129 = None + mm_176 = torch.ops.aten.mm.default(view_1271, permute_319); view_1271 = permute_319 = None + view_1272 = torch.ops.aten.view.default(mm_176, [2, 8192, 4096]); mm_176 = None + add_95 = torch.ops.aten.add.Tensor(view_1270, view_1272); view_1270 = view_1272 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(mm_175, torch.float32); mm_175 = None + reduce_scatter_tensor_82 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_779, 'avg', 8, '0'); convert_element_type_779 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_82); reduce_scatter_tensor_82 = None + split_92 = torch.ops.aten.split.Tensor(add_95, 1024, 1); add_95 = None + getitem_892 = split_92[0] + getitem_893 = split_92[1] + getitem_894 = split_92[2] + getitem_895 = split_92[3] + getitem_896 = split_92[4] + getitem_897 = split_92[5] + getitem_898 = split_92[6] + getitem_899 = split_92[7]; split_92 = None + cat_84 = torch.ops.aten.cat.default([getitem_892, getitem_893, getitem_894, getitem_895, getitem_896, getitem_897, getitem_898, getitem_899]); getitem_892 = getitem_893 = getitem_894 = getitem_895 = getitem_896 = getitem_897 = getitem_898 = getitem_899 = None + reduce_scatter_tensor_83 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_84, 'sum', 8, '1'); cat_84 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_83); reduce_scatter_tensor_83 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(wait_tensor_281, torch.float32); wait_tensor_281 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(wait_tensor_152, torch.float32); wait_tensor_152 = None + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_780, convert_element_type_782); convert_element_type_782 = None + mul_224 = torch.ops.aten.mul.Tensor(mul_92, mul_222) + sum_27 = torch.ops.aten.sum.dim_IntList(mul_224, [2], True); mul_224 = None + div_9 = torch.ops.aten.div.Tensor(mul_92, 4096) + mul_225 = torch.ops.aten.mul.Tensor(div_9, sum_27); div_9 = sum_27 = None + sub_15 = torch.ops.aten.sub.Tensor(mul_222, mul_225); mul_222 = mul_225 = None + mul_226 = torch.ops.aten.mul.Tensor(sub_15, rsqrt_23); sub_15 = rsqrt_23 = None + mul_227 = torch.ops.aten.mul.Tensor(convert_element_type_780, mul_92); convert_element_type_780 = mul_92 = None + sum_28 = torch.ops.aten.sum.dim_IntList(mul_227, [0, 1]); mul_227 = None + convert_element_type_783 = torch.ops.prims.convert_element_type.default(mul_226, torch.bfloat16); mul_226 = None + convert_element_type_784 = torch.ops.prims.convert_element_type.default(sum_28, torch.bfloat16); sum_28 = None + all_reduce_9 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_784, 'sum', '1'); convert_element_type_784 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_9); all_reduce_9 = None + convert_element_type_785 = torch.ops.prims.convert_element_type.default(wait_tensor_282, torch.float32); wait_tensor_282 = None + reduce_scatter_tensor_84 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_785, 'avg', 8, '0'); convert_element_type_785 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_84); reduce_scatter_tensor_84 = None + add_96 = torch.ops.aten.add.Tensor(add_92, convert_element_type_783); add_92 = convert_element_type_783 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_96, 8, '1') + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + split_93 = torch.ops.aten.split.Tensor(wait_tensor_284, 2); wait_tensor_284 = None + getitem_900 = split_93[0] + getitem_901 = split_93[1] + getitem_902 = split_93[2] + getitem_903 = split_93[3] + getitem_904 = split_93[4] + getitem_905 = split_93[5] + getitem_906 = split_93[6] + getitem_907 = split_93[7]; split_93 = None + cat_85 = torch.ops.aten.cat.default([getitem_900, getitem_901, getitem_902, getitem_903, getitem_904, getitem_905, getitem_906, getitem_907], 1); getitem_900 = getitem_901 = getitem_902 = getitem_903 = getitem_904 = getitem_905 = getitem_906 = getitem_907 = None + view_1273 = torch.ops.aten.view.default(cat_85, [16384, 4096]); cat_85 = None + permute_321 = torch.ops.aten.permute.default(view_1273, [1, 0]) + permute_127 = torch.ops.aten.permute.default(getitem_531, [0, 2, 1, 3]) + view_834 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + view_840 = torch.ops.aten.view.default(view_834, [16384, 512]); view_834 = None + mm_177 = torch.ops.aten.mm.default(permute_321, view_840); permute_321 = view_840 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 8, '0'); convert_element_type_380 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_150, [1, 0]); wait_tensor_150 = None + permute_323 = torch.ops.aten.permute.default(permute_128, [1, 0]); permute_128 = None + mm_178 = torch.ops.aten.mm.default(view_1273, permute_323); view_1273 = permute_323 = None + view_1274 = torch.ops.aten.view.default(mm_178, [2, 8192, 512]); mm_178 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(mm_177, torch.float32); mm_177 = None + reduce_scatter_tensor_85 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_790, 'avg', 8, '0'); convert_element_type_790 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_85); reduce_scatter_tensor_85 = None + view_1275 = torch.ops.aten.view.default(view_1274, [2, 8192, 4, 128]); view_1274 = None + permute_325 = torch.ops.aten.permute.default(view_1275, [0, 2, 1, 3]); view_1275 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 8, '0'); convert_element_type_364 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32); add_43 = None + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_145) + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_366, 8, '1'); convert_element_type_366 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + split_53 = torch.ops.aten.split.Tensor(wait_tensor_146, 2); wait_tensor_146 = None + getitem_523 = split_53[0] + getitem_524 = split_53[1] + getitem_525 = split_53[2] + getitem_526 = split_53[3] + getitem_527 = split_53[4] + getitem_528 = split_53[5] + getitem_529 = split_53[6] + getitem_530 = split_53[7]; split_53 = None + cat_45 = torch.ops.aten.cat.default([getitem_523, getitem_524, getitem_525, getitem_526, getitem_527, getitem_528, getitem_529, getitem_530], 1); getitem_523 = getitem_524 = getitem_525 = getitem_526 = getitem_527 = getitem_528 = getitem_529 = getitem_530 = None + view_807 = torch.ops.aten.view.default(cat_45, [16384, 4096]); cat_45 = None + view_808 = torch.ops.aten.view.default(mm_77, [2, 8192, 512]); mm_77 = None + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 8, '0'); convert_element_type_370 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_78 = torch.ops.aten.mm.default(view_807, permute_122) + view_815 = torch.ops.aten.view.default(mm_78, [2, 8192, 128]); mm_78 = None + view_822 = torch.ops.aten.view.default(mm_79, [2, 8192, 128]); mm_79 = None + view_824 = torch.ops.aten.view.default(view_808, [2, 8192, -1, 128]); view_808 = None + view_825 = torch.ops.aten.view.default(view_815, [2, 8192, -1, 128]); view_815 = None + view_826 = torch.ops.aten.view.default(view_822, [2, 8192, -1, 128]); view_822 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_824, torch.float32); view_824 = None + view_827 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 4, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_827); view_827 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_825, torch.float32); view_825 = None + view_828 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 1, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_828); view_828 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_37); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_830 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 4, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_37); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_831 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 1, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_830, torch.bfloat16); view_830 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_831, torch.bfloat16); view_831 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 1, 4, 128]); unsqueeze_22 = None + view_832 = torch.ops.aten.view.default(expand_22, [2, 8192, 4, 128]); expand_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_826, 3); view_826 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 1, 4, 128]); unsqueeze_23 = None + view_833 = torch.ops.aten.view.default(expand_23, [2, 8192, 4, 128]); expand_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_832, [0, 2, 1, 3]); view_832 = None + permute_126 = torch.ops.aten.permute.default(view_833, [0, 2, 1, 3]); view_833 = None + _scaled_dot_product_cudnn_attention_backward_4 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_325, permute_124, permute_125, permute_126, getitem_531, getitem_532, getitem_537, getitem_538, None, None, None, 8192, 8192, 0.0, True); permute_325 = permute_124 = permute_125 = permute_126 = getitem_531 = getitem_532 = getitem_537 = getitem_538 = None + getitem_908 = _scaled_dot_product_cudnn_attention_backward_4[0] + getitem_909 = _scaled_dot_product_cudnn_attention_backward_4[1] + getitem_910 = _scaled_dot_product_cudnn_attention_backward_4[2]; _scaled_dot_product_cudnn_attention_backward_4 = None + permute_326 = torch.ops.aten.permute.default(getitem_910, [0, 2, 1, 3]); getitem_910 = None + permute_327 = torch.ops.aten.permute.default(getitem_909, [0, 2, 1, 3]); getitem_909 = None + permute_328 = torch.ops.aten.permute.default(getitem_908, [0, 2, 1, 3]); getitem_908 = None + view_1276 = torch.ops.aten.view.default(permute_326, [2, 8192, 1, 4, 128]); permute_326 = None + sum_29 = torch.ops.aten.sum.dim_IntList(view_1276, [3], True); view_1276 = None + squeeze_8 = torch.ops.aten.squeeze.dim(sum_29, 3); sum_29 = None + view_1277 = torch.ops.aten.view.default(permute_327, [2, 8192, 1, 4, 128]); permute_327 = None + sum_30 = torch.ops.aten.sum.dim_IntList(view_1277, [3], True); view_1277 = None + squeeze_9 = torch.ops.aten.squeeze.dim(sum_30, 3); sum_30 = None + convert_element_type_791 = torch.ops.prims.convert_element_type.default(squeeze_9, torch.float32); squeeze_9 = None + convert_element_type_792 = torch.ops.prims.convert_element_type.default(permute_328, torch.float32); permute_328 = None + view_1278 = torch.ops.aten.view.default(convert_element_type_791, [2, 8192, 1, 64, 2]); convert_element_type_791 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1278); view_1278 = None + mul_228 = torch.ops.aten.mul.Tensor(view_as_complex_40, _conj); view_as_complex_40 = None + view_1279 = torch.ops.aten.view.default(convert_element_type_792, [2, 8192, 4, 64, 2]); convert_element_type_792 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1279); view_1279 = None + mul_229 = torch.ops.aten.mul.Tensor(view_as_complex_41, _conj); view_as_complex_41 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_228); mul_228 = None + view_1280 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 1, 128]); view_as_real_40 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(view_1280, torch.bfloat16); view_1280 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_229); mul_229 = None + view_1281 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 4, 128]); view_as_real_41 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(view_1281, torch.bfloat16); view_1281 = None + view_1282 = torch.ops.aten.view.default(squeeze_8, [2, 8192, 128]); squeeze_8 = None + view_1283 = torch.ops.aten.view.default(convert_element_type_793, [2, 8192, 128]); convert_element_type_793 = None + view_1284 = torch.ops.aten.view.default(convert_element_type_794, [2, 8192, 512]); convert_element_type_794 = None + view_1285 = torch.ops.aten.view.default(view_1282, [16384, 128]); view_1282 = None + permute_329 = torch.ops.aten.permute.default(view_1285, [1, 0]) + mm_179 = torch.ops.aten.mm.default(permute_329, view_807); permute_329 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 8, '0'); convert_element_type_373 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + permute_331 = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None + mm_180 = torch.ops.aten.mm.default(view_1285, permute_331); view_1285 = permute_331 = None + view_1286 = torch.ops.aten.view.default(mm_180, [2, 8192, 4096]); mm_180 = None + convert_element_type_799 = torch.ops.prims.convert_element_type.default(mm_179, torch.float32); mm_179 = None + reduce_scatter_tensor_86 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_799, 'avg', 8, '0'); convert_element_type_799 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_86); reduce_scatter_tensor_86 = None + view_1287 = torch.ops.aten.view.default(view_1283, [16384, 128]); view_1283 = None + permute_333 = torch.ops.aten.permute.default(view_1287, [1, 0]) + mm_181 = torch.ops.aten.mm.default(permute_333, view_807); permute_333 = None + permute_335 = torch.ops.aten.permute.default(permute_122, [1, 0]); permute_122 = None + mm_182 = torch.ops.aten.mm.default(view_1287, permute_335); view_1287 = permute_335 = None + view_1288 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]); mm_182 = None + add_97 = torch.ops.aten.add.Tensor(view_1286, view_1288); view_1286 = view_1288 = None + convert_element_type_804 = torch.ops.prims.convert_element_type.default(mm_181, torch.float32); mm_181 = None + reduce_scatter_tensor_87 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_804, 'avg', 8, '0'); convert_element_type_804 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_87); reduce_scatter_tensor_87 = None + view_1289 = torch.ops.aten.view.default(view_1284, [16384, 512]); view_1284 = None + permute_337 = torch.ops.aten.permute.default(view_1289, [1, 0]) + mm_183 = torch.ops.aten.mm.default(permute_337, view_807); permute_337 = view_807 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 8, '0'); convert_element_type_367 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + permute_339 = torch.ops.aten.permute.default(permute_121, [1, 0]); permute_121 = None + mm_184 = torch.ops.aten.mm.default(view_1289, permute_339); view_1289 = permute_339 = None + view_1290 = torch.ops.aten.view.default(mm_184, [2, 8192, 4096]); mm_184 = None + add_98 = torch.ops.aten.add.Tensor(add_97, view_1290); add_97 = view_1290 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(mm_183, torch.float32); mm_183 = None + reduce_scatter_tensor_88 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_809, 'avg', 8, '0'); convert_element_type_809 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_88); reduce_scatter_tensor_88 = None + split_94 = torch.ops.aten.split.Tensor(add_98, 1024, 1); add_98 = None + getitem_911 = split_94[0] + getitem_912 = split_94[1] + getitem_913 = split_94[2] + getitem_914 = split_94[3] + getitem_915 = split_94[4] + getitem_916 = split_94[5] + getitem_917 = split_94[6] + getitem_918 = split_94[7]; split_94 = None + cat_86 = torch.ops.aten.cat.default([getitem_911, getitem_912, getitem_913, getitem_914, getitem_915, getitem_916, getitem_917, getitem_918]); getitem_911 = getitem_912 = getitem_913 = getitem_914 = getitem_915 = getitem_916 = getitem_917 = getitem_918 = None + reduce_scatter_tensor_89 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_86, 'sum', 8, '1'); cat_86 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_89); reduce_scatter_tensor_89 = None + convert_element_type_810 = torch.ops.prims.convert_element_type.default(wait_tensor_289, torch.float32); wait_tensor_289 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(wait_tensor_145, torch.float32); wait_tensor_145 = None + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_810, convert_element_type_812); convert_element_type_812 = None + mul_232 = torch.ops.aten.mul.Tensor(mul_88, mul_230) + sum_31 = torch.ops.aten.sum.dim_IntList(mul_232, [2], True); mul_232 = None + div_10 = torch.ops.aten.div.Tensor(mul_88, 4096) + mul_233 = torch.ops.aten.mul.Tensor(div_10, sum_31); div_10 = sum_31 = None + sub_16 = torch.ops.aten.sub.Tensor(mul_230, mul_233); mul_230 = mul_233 = None + mul_234 = torch.ops.aten.mul.Tensor(sub_16, rsqrt_22); sub_16 = rsqrt_22 = None + mul_235 = torch.ops.aten.mul.Tensor(convert_element_type_810, mul_88); convert_element_type_810 = mul_88 = None + sum_32 = torch.ops.aten.sum.dim_IntList(mul_235, [0, 1]); mul_235 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(mul_234, torch.bfloat16); mul_234 = None + convert_element_type_814 = torch.ops.prims.convert_element_type.default(sum_32, torch.bfloat16); sum_32 = None + all_reduce_10 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_814, 'sum', '1'); convert_element_type_814 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_10); all_reduce_10 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(wait_tensor_290, torch.float32); wait_tensor_290 = None + reduce_scatter_tensor_90 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_815, 'avg', 8, '0'); convert_element_type_815 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_90); reduce_scatter_tensor_90 = None + add_99 = torch.ops.aten.add.Tensor(add_96, convert_element_type_813); add_96 = convert_element_type_813 = None + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_99, 8, '1') + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + split_95 = torch.ops.aten.split.Tensor(wait_tensor_292, 2); wait_tensor_292 = None + getitem_919 = split_95[0] + getitem_920 = split_95[1] + getitem_921 = split_95[2] + getitem_922 = split_95[3] + getitem_923 = split_95[4] + getitem_924 = split_95[5] + getitem_925 = split_95[6] + getitem_926 = split_95[7]; split_95 = None + cat_87 = torch.ops.aten.cat.default([getitem_919, getitem_920, getitem_921, getitem_922, getitem_923, getitem_924, getitem_925, getitem_926], 1); getitem_919 = getitem_920 = getitem_921 = getitem_922 = getitem_923 = getitem_924 = getitem_925 = getitem_926 = None + view_1291 = torch.ops.aten.view.default(cat_87, [16384, 4096]); cat_87 = None + permute_341 = torch.ops.aten.permute.default(view_1291, [1, 0]) + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21); reduce_scatter_tensor_21 = None + add_41 = torch.ops.aten.add.Tensor(add_39, wait_tensor_138); wait_tensor_138 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16); primals_99 = None + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 8, '0'); convert_element_type_350 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32); add_41 = None + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_139) + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_352, 8, '1'); convert_element_type_352 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + split_51 = torch.ops.aten.split.Tensor(wait_tensor_140, 2); wait_tensor_140 = None + getitem_507 = split_51[0] + getitem_508 = split_51[1] + getitem_509 = split_51[2] + getitem_510 = split_51[3] + getitem_511 = split_51[4] + getitem_512 = split_51[5] + getitem_513 = split_51[6] + getitem_514 = split_51[7]; split_51 = None + cat_43 = torch.ops.aten.cat.default([getitem_507, getitem_508, getitem_509, getitem_510, getitem_511, getitem_512, getitem_513, getitem_514], 1); getitem_507 = getitem_508 = getitem_509 = getitem_510 = getitem_511 = getitem_512 = getitem_513 = getitem_514 = None + view_780 = torch.ops.aten.view.default(cat_43, [16384, 4096]); cat_43 = None + view_781 = torch.ops.aten.view.default(mm_74, [2, 8192, 1792]); mm_74 = None + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_781, torch.float32); view_781 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 8, '0'); convert_element_type_358 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + mm_75 = torch.ops.aten.mm.default(view_780, permute_119) + view_788 = torch.ops.aten.view.default(mm_75, [2, 8192, 1792]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_788) + view_795 = torch.ops.aten.view.default(mul_87, [16384, 1792]); mul_87 = None + mm_185 = torch.ops.aten.mm.default(permute_341, view_795); permute_341 = view_795 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 8, '0'); convert_element_type_361 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + permute_343 = torch.ops.aten.permute.default(permute_120, [1, 0]); permute_120 = None + mm_186 = torch.ops.aten.mm.default(view_1291, permute_343); view_1291 = permute_343 = None + view_1292 = torch.ops.aten.view.default(mm_186, [2, 8192, 1792]); mm_186 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(mm_185, torch.float32); mm_185 = None + reduce_scatter_tensor_91 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_820, 'avg', 8, '0'); convert_element_type_820 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_91); reduce_scatter_tensor_91 = None + mul_236 = torch.ops.aten.mul.Tensor(view_1292, convert_element_type_357); convert_element_type_357 = None + mul_237 = torch.ops.aten.mul.Tensor(view_1292, view_788); view_1292 = view_788 = None + view_1293 = torch.ops.aten.view.default(mul_236, [16384, 1792]); mul_236 = None + permute_345 = torch.ops.aten.permute.default(view_1293, [1, 0]) + mm_187 = torch.ops.aten.mm.default(permute_345, view_780); permute_345 = None + permute_347 = torch.ops.aten.permute.default(permute_119, [1, 0]); permute_119 = None + mm_188 = torch.ops.aten.mm.default(view_1293, permute_347); view_1293 = permute_347 = None + view_1294 = torch.ops.aten.view.default(mm_188, [2, 8192, 4096]); mm_188 = None + convert_element_type_825 = torch.ops.prims.convert_element_type.default(mm_187, torch.float32); mm_187 = None + reduce_scatter_tensor_92 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_825, 'avg', 8, '0'); convert_element_type_825 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_92); reduce_scatter_tensor_92 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(mul_237, torch.float32); mul_237 = None + neg_5 = torch.ops.aten.neg.default(convert_element_type_356) + exp_5 = torch.ops.aten.exp.default(neg_5); neg_5 = None + add_100 = torch.ops.aten.add.Tensor(exp_5, 1); exp_5 = None + reciprocal_5 = torch.ops.aten.reciprocal.default(add_100); add_100 = None + mul_238 = torch.ops.aten.mul.Tensor(reciprocal_5, 1); reciprocal_5 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_826, mul_238); convert_element_type_826 = None + sub_17 = torch.ops.aten.sub.Tensor(1, mul_238); mul_238 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_356, sub_17); convert_element_type_356 = sub_17 = None + add_101 = torch.ops.aten.add.Tensor(mul_240, 1); mul_240 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_239, add_101); mul_239 = add_101 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + view_1295 = torch.ops.aten.view.default(convert_element_type_828, [16384, 1792]); convert_element_type_828 = None + permute_349 = torch.ops.aten.permute.default(view_1295, [1, 0]) + mm_189 = torch.ops.aten.mm.default(permute_349, view_780); permute_349 = view_780 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 8, '0'); convert_element_type_353 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + permute_351 = torch.ops.aten.permute.default(permute_118, [1, 0]); permute_118 = None + mm_190 = torch.ops.aten.mm.default(view_1295, permute_351); view_1295 = permute_351 = None + view_1296 = torch.ops.aten.view.default(mm_190, [2, 8192, 4096]); mm_190 = None + add_102 = torch.ops.aten.add.Tensor(view_1294, view_1296); view_1294 = view_1296 = None + convert_element_type_833 = torch.ops.prims.convert_element_type.default(mm_189, torch.float32); mm_189 = None + reduce_scatter_tensor_93 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_833, 'avg', 8, '0'); convert_element_type_833 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_93); reduce_scatter_tensor_93 = None + split_96 = torch.ops.aten.split.Tensor(add_102, 1024, 1); add_102 = None + getitem_927 = split_96[0] + getitem_928 = split_96[1] + getitem_929 = split_96[2] + getitem_930 = split_96[3] + getitem_931 = split_96[4] + getitem_932 = split_96[5] + getitem_933 = split_96[6] + getitem_934 = split_96[7]; split_96 = None + cat_88 = torch.ops.aten.cat.default([getitem_927, getitem_928, getitem_929, getitem_930, getitem_931, getitem_932, getitem_933, getitem_934]); getitem_927 = getitem_928 = getitem_929 = getitem_930 = getitem_931 = getitem_932 = getitem_933 = getitem_934 = None + reduce_scatter_tensor_94 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_88, 'sum', 8, '1'); cat_88 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_94); reduce_scatter_tensor_94 = None + convert_element_type_834 = torch.ops.prims.convert_element_type.default(wait_tensor_296, torch.float32); wait_tensor_296 = None + convert_element_type_836 = torch.ops.prims.convert_element_type.default(wait_tensor_139, torch.float32); wait_tensor_139 = None + mul_242 = torch.ops.aten.mul.Tensor(convert_element_type_834, convert_element_type_836); convert_element_type_836 = None + mul_244 = torch.ops.aten.mul.Tensor(mul_84, mul_242) + sum_33 = torch.ops.aten.sum.dim_IntList(mul_244, [2], True); mul_244 = None + div_11 = torch.ops.aten.div.Tensor(mul_84, 4096) + mul_245 = torch.ops.aten.mul.Tensor(div_11, sum_33); div_11 = sum_33 = None + sub_18 = torch.ops.aten.sub.Tensor(mul_242, mul_245); mul_242 = mul_245 = None + mul_246 = torch.ops.aten.mul.Tensor(sub_18, rsqrt_21); sub_18 = rsqrt_21 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_834, mul_84); convert_element_type_834 = mul_84 = None + sum_34 = torch.ops.aten.sum.dim_IntList(mul_247, [0, 1]); mul_247 = None + convert_element_type_837 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(sum_34, torch.bfloat16); sum_34 = None + all_reduce_11 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_838, 'sum', '1'); convert_element_type_838 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_11); all_reduce_11 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(wait_tensor_297, torch.float32); wait_tensor_297 = None + reduce_scatter_tensor_95 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_839, 'avg', 8, '0'); convert_element_type_839 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_95); reduce_scatter_tensor_95 = None + add_103 = torch.ops.aten.add.Tensor(add_99, convert_element_type_837); add_99 = convert_element_type_837 = None + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_103, 8, '1') + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + split_97 = torch.ops.aten.split.Tensor(wait_tensor_299, 2); wait_tensor_299 = None + getitem_935 = split_97[0] + getitem_936 = split_97[1] + getitem_937 = split_97[2] + getitem_938 = split_97[3] + getitem_939 = split_97[4] + getitem_940 = split_97[5] + getitem_941 = split_97[6] + getitem_942 = split_97[7]; split_97 = None + cat_89 = torch.ops.aten.cat.default([getitem_935, getitem_936, getitem_937, getitem_938, getitem_939, getitem_940, getitem_941, getitem_942], 1); getitem_935 = getitem_936 = getitem_937 = getitem_938 = getitem_939 = getitem_940 = getitem_941 = getitem_942 = None + view_1297 = torch.ops.aten.view.default(cat_89, [16384, 4096]); cat_89 = None + permute_353 = torch.ops.aten.permute.default(view_1297, [1, 0]) + permute_116 = torch.ops.aten.permute.default(getitem_490, [0, 2, 1, 3]) + view_762 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + view_768 = torch.ops.aten.view.default(view_762, [16384, 512]); view_762 = None + mm_191 = torch.ops.aten.mm.default(permute_353, view_768); permute_353 = view_768 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16); primals_98 = None + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 8, '0'); convert_element_type_347 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + permute_355 = torch.ops.aten.permute.default(permute_117, [1, 0]); permute_117 = None + mm_192 = torch.ops.aten.mm.default(view_1297, permute_355); view_1297 = permute_355 = None + view_1298 = torch.ops.aten.view.default(mm_192, [2, 8192, 512]); mm_192 = None + convert_element_type_844 = torch.ops.prims.convert_element_type.default(mm_191, torch.float32); mm_191 = None + reduce_scatter_tensor_96 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_844, 'avg', 8, '0'); convert_element_type_844 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_96); reduce_scatter_tensor_96 = None + view_1299 = torch.ops.aten.view.default(view_1298, [2, 8192, 4, 128]); view_1298 = None + permute_357 = torch.ops.aten.permute.default(view_1299, [0, 2, 1, 3]); view_1299 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16); primals_94 = None + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 8, '0'); convert_element_type_331 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32); add_39 = None + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_132) + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_333, 8, '1'); convert_element_type_333 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + split_49 = torch.ops.aten.split.Tensor(wait_tensor_133, 2); wait_tensor_133 = None + getitem_482 = split_49[0] + getitem_483 = split_49[1] + getitem_484 = split_49[2] + getitem_485 = split_49[3] + getitem_486 = split_49[4] + getitem_487 = split_49[5] + getitem_488 = split_49[6] + getitem_489 = split_49[7]; split_49 = None + cat_41 = torch.ops.aten.cat.default([getitem_482, getitem_483, getitem_484, getitem_485, getitem_486, getitem_487, getitem_488, getitem_489], 1); getitem_482 = getitem_483 = getitem_484 = getitem_485 = getitem_486 = getitem_487 = getitem_488 = getitem_489 = None + view_735 = torch.ops.aten.view.default(cat_41, [16384, 4096]); cat_41 = None + view_736 = torch.ops.aten.view.default(mm_70, [2, 8192, 512]); mm_70 = None + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16); primals_96 = None + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 8, '0'); convert_element_type_337 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_71 = torch.ops.aten.mm.default(view_735, permute_111) + view_743 = torch.ops.aten.view.default(mm_71, [2, 8192, 128]); mm_71 = None + view_750 = torch.ops.aten.view.default(mm_72, [2, 8192, 128]); mm_72 = None + view_752 = torch.ops.aten.view.default(view_736, [2, 8192, -1, 128]); view_736 = None + view_753 = torch.ops.aten.view.default(view_743, [2, 8192, -1, 128]); view_743 = None + view_754 = torch.ops.aten.view.default(view_750, [2, 8192, -1, 128]); view_750 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_752, torch.float32); view_752 = None + view_755 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 4, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_755); view_755 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_753, torch.float32); view_753 = None + view_756 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 1, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_756); view_756 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_37); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_758 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 4, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_37); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_759 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 1, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_758, torch.bfloat16); view_758 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_759, torch.bfloat16); view_759 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 1, 4, 128]); unsqueeze_20 = None + view_760 = torch.ops.aten.view.default(expand_20, [2, 8192, 4, 128]); expand_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_754, 3); view_754 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 1, 4, 128]); unsqueeze_21 = None + view_761 = torch.ops.aten.view.default(expand_21, [2, 8192, 4, 128]); expand_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_760, [0, 2, 1, 3]); view_760 = None + permute_115 = torch.ops.aten.permute.default(view_761, [0, 2, 1, 3]); view_761 = None + _scaled_dot_product_cudnn_attention_backward_5 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_357, permute_113, permute_114, permute_115, getitem_490, getitem_491, getitem_496, getitem_497, None, None, None, 8192, 8192, 0.0, True); permute_357 = permute_113 = permute_114 = permute_115 = getitem_490 = getitem_491 = getitem_496 = getitem_497 = None + getitem_943 = _scaled_dot_product_cudnn_attention_backward_5[0] + getitem_944 = _scaled_dot_product_cudnn_attention_backward_5[1] + getitem_945 = _scaled_dot_product_cudnn_attention_backward_5[2]; _scaled_dot_product_cudnn_attention_backward_5 = None + permute_358 = torch.ops.aten.permute.default(getitem_945, [0, 2, 1, 3]); getitem_945 = None + permute_359 = torch.ops.aten.permute.default(getitem_944, [0, 2, 1, 3]); getitem_944 = None + permute_360 = torch.ops.aten.permute.default(getitem_943, [0, 2, 1, 3]); getitem_943 = None + view_1300 = torch.ops.aten.view.default(permute_358, [2, 8192, 1, 4, 128]); permute_358 = None + sum_35 = torch.ops.aten.sum.dim_IntList(view_1300, [3], True); view_1300 = None + squeeze_10 = torch.ops.aten.squeeze.dim(sum_35, 3); sum_35 = None + view_1301 = torch.ops.aten.view.default(permute_359, [2, 8192, 1, 4, 128]); permute_359 = None + sum_36 = torch.ops.aten.sum.dim_IntList(view_1301, [3], True); view_1301 = None + squeeze_11 = torch.ops.aten.squeeze.dim(sum_36, 3); sum_36 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(squeeze_11, torch.float32); squeeze_11 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(permute_360, torch.float32); permute_360 = None + view_1302 = torch.ops.aten.view.default(convert_element_type_845, [2, 8192, 1, 64, 2]); convert_element_type_845 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1302); view_1302 = None + mul_248 = torch.ops.aten.mul.Tensor(view_as_complex_42, _conj); view_as_complex_42 = None + view_1303 = torch.ops.aten.view.default(convert_element_type_846, [2, 8192, 4, 64, 2]); convert_element_type_846 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1303); view_1303 = None + mul_249 = torch.ops.aten.mul.Tensor(view_as_complex_43, _conj); view_as_complex_43 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_248); mul_248 = None + view_1304 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 1, 128]); view_as_real_42 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(view_1304, torch.bfloat16); view_1304 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_249); mul_249 = None + view_1305 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 4, 128]); view_as_real_43 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(view_1305, torch.bfloat16); view_1305 = None + view_1306 = torch.ops.aten.view.default(squeeze_10, [2, 8192, 128]); squeeze_10 = None + view_1307 = torch.ops.aten.view.default(convert_element_type_847, [2, 8192, 128]); convert_element_type_847 = None + view_1308 = torch.ops.aten.view.default(convert_element_type_848, [2, 8192, 512]); convert_element_type_848 = None + view_1309 = torch.ops.aten.view.default(view_1306, [16384, 128]); view_1306 = None + permute_361 = torch.ops.aten.permute.default(view_1309, [1, 0]) + mm_193 = torch.ops.aten.mm.default(permute_361, view_735); permute_361 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16); primals_97 = None + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 8, '0'); convert_element_type_340 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + permute_363 = torch.ops.aten.permute.default(permute_112, [1, 0]); permute_112 = None + mm_194 = torch.ops.aten.mm.default(view_1309, permute_363); view_1309 = permute_363 = None + view_1310 = torch.ops.aten.view.default(mm_194, [2, 8192, 4096]); mm_194 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(mm_193, torch.float32); mm_193 = None + reduce_scatter_tensor_97 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_853, 'avg', 8, '0'); convert_element_type_853 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_97); reduce_scatter_tensor_97 = None + view_1311 = torch.ops.aten.view.default(view_1307, [16384, 128]); view_1307 = None + permute_365 = torch.ops.aten.permute.default(view_1311, [1, 0]) + mm_195 = torch.ops.aten.mm.default(permute_365, view_735); permute_365 = None + permute_367 = torch.ops.aten.permute.default(permute_111, [1, 0]); permute_111 = None + mm_196 = torch.ops.aten.mm.default(view_1311, permute_367); view_1311 = permute_367 = None + view_1312 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]); mm_196 = None + add_104 = torch.ops.aten.add.Tensor(view_1310, view_1312); view_1310 = view_1312 = None + convert_element_type_858 = torch.ops.prims.convert_element_type.default(mm_195, torch.float32); mm_195 = None + reduce_scatter_tensor_98 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_858, 'avg', 8, '0'); convert_element_type_858 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_98); reduce_scatter_tensor_98 = None + view_1313 = torch.ops.aten.view.default(view_1308, [16384, 512]); view_1308 = None + permute_369 = torch.ops.aten.permute.default(view_1313, [1, 0]) + mm_197 = torch.ops.aten.mm.default(permute_369, view_735); permute_369 = view_735 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16); primals_95 = None + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 8, '0'); convert_element_type_334 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + permute_371 = torch.ops.aten.permute.default(permute_110, [1, 0]); permute_110 = None + mm_198 = torch.ops.aten.mm.default(view_1313, permute_371); view_1313 = permute_371 = None + view_1314 = torch.ops.aten.view.default(mm_198, [2, 8192, 4096]); mm_198 = None + add_105 = torch.ops.aten.add.Tensor(add_104, view_1314); add_104 = view_1314 = None + convert_element_type_863 = torch.ops.prims.convert_element_type.default(mm_197, torch.float32); mm_197 = None + reduce_scatter_tensor_99 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_863, 'avg', 8, '0'); convert_element_type_863 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_99); reduce_scatter_tensor_99 = None + split_98 = torch.ops.aten.split.Tensor(add_105, 1024, 1); add_105 = None + getitem_946 = split_98[0] + getitem_947 = split_98[1] + getitem_948 = split_98[2] + getitem_949 = split_98[3] + getitem_950 = split_98[4] + getitem_951 = split_98[5] + getitem_952 = split_98[6] + getitem_953 = split_98[7]; split_98 = None + cat_90 = torch.ops.aten.cat.default([getitem_946, getitem_947, getitem_948, getitem_949, getitem_950, getitem_951, getitem_952, getitem_953]); getitem_946 = getitem_947 = getitem_948 = getitem_949 = getitem_950 = getitem_951 = getitem_952 = getitem_953 = None + reduce_scatter_tensor_100 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_90, 'sum', 8, '1'); cat_90 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_100); reduce_scatter_tensor_100 = None + convert_element_type_864 = torch.ops.prims.convert_element_type.default(wait_tensor_304, torch.float32); wait_tensor_304 = None + convert_element_type_866 = torch.ops.prims.convert_element_type.default(wait_tensor_132, torch.float32); wait_tensor_132 = None + mul_250 = torch.ops.aten.mul.Tensor(convert_element_type_864, convert_element_type_866); convert_element_type_866 = None + mul_252 = torch.ops.aten.mul.Tensor(mul_80, mul_250) + sum_37 = torch.ops.aten.sum.dim_IntList(mul_252, [2], True); mul_252 = None + div_12 = torch.ops.aten.div.Tensor(mul_80, 4096) + mul_253 = torch.ops.aten.mul.Tensor(div_12, sum_37); div_12 = sum_37 = None + sub_19 = torch.ops.aten.sub.Tensor(mul_250, mul_253); mul_250 = mul_253 = None + mul_254 = torch.ops.aten.mul.Tensor(sub_19, rsqrt_20); sub_19 = rsqrt_20 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_864, mul_80); convert_element_type_864 = mul_80 = None + sum_38 = torch.ops.aten.sum.dim_IntList(mul_255, [0, 1]); mul_255 = None + convert_element_type_867 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(sum_38, torch.bfloat16); sum_38 = None + all_reduce_12 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_868, 'sum', '1'); convert_element_type_868 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_12); all_reduce_12 = None + convert_element_type_869 = torch.ops.prims.convert_element_type.default(wait_tensor_305, torch.float32); wait_tensor_305 = None + reduce_scatter_tensor_101 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_869, 'avg', 8, '0'); convert_element_type_869 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_101); reduce_scatter_tensor_101 = None + add_106 = torch.ops.aten.add.Tensor(add_103, convert_element_type_867); add_103 = convert_element_type_867 = None + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_106, 8, '1') + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + split_99 = torch.ops.aten.split.Tensor(wait_tensor_307, 2); wait_tensor_307 = None + getitem_954 = split_99[0] + getitem_955 = split_99[1] + getitem_956 = split_99[2] + getitem_957 = split_99[3] + getitem_958 = split_99[4] + getitem_959 = split_99[5] + getitem_960 = split_99[6] + getitem_961 = split_99[7]; split_99 = None + cat_91 = torch.ops.aten.cat.default([getitem_954, getitem_955, getitem_956, getitem_957, getitem_958, getitem_959, getitem_960, getitem_961], 1); getitem_954 = getitem_955 = getitem_956 = getitem_957 = getitem_958 = getitem_959 = getitem_960 = getitem_961 = None + view_1315 = torch.ops.aten.view.default(cat_91, [16384, 4096]); cat_91 = None + permute_373 = torch.ops.aten.permute.default(view_1315, [1, 0]) + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19); reduce_scatter_tensor_19 = None + add_37 = torch.ops.aten.add.Tensor(add_35, wait_tensor_125); wait_tensor_125 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 8, '0'); convert_element_type_317 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32); add_37 = None + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_126) + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_319, 8, '1'); convert_element_type_319 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + split_47 = torch.ops.aten.split.Tensor(wait_tensor_127, 2); wait_tensor_127 = None + getitem_466 = split_47[0] + getitem_467 = split_47[1] + getitem_468 = split_47[2] + getitem_469 = split_47[3] + getitem_470 = split_47[4] + getitem_471 = split_47[5] + getitem_472 = split_47[6] + getitem_473 = split_47[7]; split_47 = None + cat_39 = torch.ops.aten.cat.default([getitem_466, getitem_467, getitem_468, getitem_469, getitem_470, getitem_471, getitem_472, getitem_473], 1); getitem_466 = getitem_467 = getitem_468 = getitem_469 = getitem_470 = getitem_471 = getitem_472 = getitem_473 = None + view_708 = torch.ops.aten.view.default(cat_39, [16384, 4096]); cat_39 = None + view_709 = torch.ops.aten.view.default(mm_67, [2, 8192, 1792]); mm_67 = None + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_709, torch.float32); view_709 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16); primals_92 = None + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 8, '0'); convert_element_type_325 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_68 = torch.ops.aten.mm.default(view_708, permute_108) + view_716 = torch.ops.aten.view.default(mm_68, [2, 8192, 1792]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_716) + view_723 = torch.ops.aten.view.default(mul_79, [16384, 1792]); mul_79 = None + mm_199 = torch.ops.aten.mm.default(permute_373, view_723); permute_373 = view_723 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16); primals_93 = None + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 8, '0'); convert_element_type_328 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + permute_375 = torch.ops.aten.permute.default(permute_109, [1, 0]); permute_109 = None + mm_200 = torch.ops.aten.mm.default(view_1315, permute_375); view_1315 = permute_375 = None + view_1316 = torch.ops.aten.view.default(mm_200, [2, 8192, 1792]); mm_200 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(mm_199, torch.float32); mm_199 = None + reduce_scatter_tensor_102 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_874, 'avg', 8, '0'); convert_element_type_874 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_102); reduce_scatter_tensor_102 = None + mul_256 = torch.ops.aten.mul.Tensor(view_1316, convert_element_type_324); convert_element_type_324 = None + mul_257 = torch.ops.aten.mul.Tensor(view_1316, view_716); view_1316 = view_716 = None + view_1317 = torch.ops.aten.view.default(mul_256, [16384, 1792]); mul_256 = None + permute_377 = torch.ops.aten.permute.default(view_1317, [1, 0]) + mm_201 = torch.ops.aten.mm.default(permute_377, view_708); permute_377 = None + permute_379 = torch.ops.aten.permute.default(permute_108, [1, 0]); permute_108 = None + mm_202 = torch.ops.aten.mm.default(view_1317, permute_379); view_1317 = permute_379 = None + view_1318 = torch.ops.aten.view.default(mm_202, [2, 8192, 4096]); mm_202 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(mm_201, torch.float32); mm_201 = None + reduce_scatter_tensor_103 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_879, 'avg', 8, '0'); convert_element_type_879 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_103); reduce_scatter_tensor_103 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_257, torch.float32); mul_257 = None + neg_6 = torch.ops.aten.neg.default(convert_element_type_323) + exp_6 = torch.ops.aten.exp.default(neg_6); neg_6 = None + add_107 = torch.ops.aten.add.Tensor(exp_6, 1); exp_6 = None + reciprocal_6 = torch.ops.aten.reciprocal.default(add_107); add_107 = None + mul_258 = torch.ops.aten.mul.Tensor(reciprocal_6, 1); reciprocal_6 = None + mul_259 = torch.ops.aten.mul.Tensor(convert_element_type_880, mul_258); convert_element_type_880 = None + sub_20 = torch.ops.aten.sub.Tensor(1, mul_258); mul_258 = None + mul_260 = torch.ops.aten.mul.Tensor(convert_element_type_323, sub_20); convert_element_type_323 = sub_20 = None + add_108 = torch.ops.aten.add.Tensor(mul_260, 1); mul_260 = None + mul_261 = torch.ops.aten.mul.Tensor(mul_259, add_108); mul_259 = add_108 = None + convert_element_type_882 = torch.ops.prims.convert_element_type.default(mul_261, torch.bfloat16); mul_261 = None + view_1319 = torch.ops.aten.view.default(convert_element_type_882, [16384, 1792]); convert_element_type_882 = None + permute_381 = torch.ops.aten.permute.default(view_1319, [1, 0]) + mm_203 = torch.ops.aten.mm.default(permute_381, view_708); permute_381 = view_708 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16); primals_91 = None + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 8, '0'); convert_element_type_320 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + permute_383 = torch.ops.aten.permute.default(permute_107, [1, 0]); permute_107 = None + mm_204 = torch.ops.aten.mm.default(view_1319, permute_383); view_1319 = permute_383 = None + view_1320 = torch.ops.aten.view.default(mm_204, [2, 8192, 4096]); mm_204 = None + add_109 = torch.ops.aten.add.Tensor(view_1318, view_1320); view_1318 = view_1320 = None + convert_element_type_887 = torch.ops.prims.convert_element_type.default(mm_203, torch.float32); mm_203 = None + reduce_scatter_tensor_104 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_887, 'avg', 8, '0'); convert_element_type_887 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_104); reduce_scatter_tensor_104 = None + split_100 = torch.ops.aten.split.Tensor(add_109, 1024, 1); add_109 = None + getitem_962 = split_100[0] + getitem_963 = split_100[1] + getitem_964 = split_100[2] + getitem_965 = split_100[3] + getitem_966 = split_100[4] + getitem_967 = split_100[5] + getitem_968 = split_100[6] + getitem_969 = split_100[7]; split_100 = None + cat_92 = torch.ops.aten.cat.default([getitem_962, getitem_963, getitem_964, getitem_965, getitem_966, getitem_967, getitem_968, getitem_969]); getitem_962 = getitem_963 = getitem_964 = getitem_965 = getitem_966 = getitem_967 = getitem_968 = getitem_969 = None + reduce_scatter_tensor_105 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_92, 'sum', 8, '1'); cat_92 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_105); reduce_scatter_tensor_105 = None + convert_element_type_888 = torch.ops.prims.convert_element_type.default(wait_tensor_311, torch.float32); wait_tensor_311 = None + convert_element_type_890 = torch.ops.prims.convert_element_type.default(wait_tensor_126, torch.float32); wait_tensor_126 = None + mul_262 = torch.ops.aten.mul.Tensor(convert_element_type_888, convert_element_type_890); convert_element_type_890 = None + mul_264 = torch.ops.aten.mul.Tensor(mul_76, mul_262) + sum_39 = torch.ops.aten.sum.dim_IntList(mul_264, [2], True); mul_264 = None + div_13 = torch.ops.aten.div.Tensor(mul_76, 4096) + mul_265 = torch.ops.aten.mul.Tensor(div_13, sum_39); div_13 = sum_39 = None + sub_21 = torch.ops.aten.sub.Tensor(mul_262, mul_265); mul_262 = mul_265 = None + mul_266 = torch.ops.aten.mul.Tensor(sub_21, rsqrt_19); sub_21 = rsqrt_19 = None + mul_267 = torch.ops.aten.mul.Tensor(convert_element_type_888, mul_76); convert_element_type_888 = mul_76 = None + sum_40 = torch.ops.aten.sum.dim_IntList(mul_267, [0, 1]); mul_267 = None + convert_element_type_891 = torch.ops.prims.convert_element_type.default(mul_266, torch.bfloat16); mul_266 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(sum_40, torch.bfloat16); sum_40 = None + all_reduce_13 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_892, 'sum', '1'); convert_element_type_892 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_13); all_reduce_13 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(wait_tensor_312, torch.float32); wait_tensor_312 = None + reduce_scatter_tensor_106 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_893, 'avg', 8, '0'); convert_element_type_893 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_106); reduce_scatter_tensor_106 = None + add_110 = torch.ops.aten.add.Tensor(add_106, convert_element_type_891); add_106 = convert_element_type_891 = None + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_110, 8, '1') + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + split_101 = torch.ops.aten.split.Tensor(wait_tensor_314, 2); wait_tensor_314 = None + getitem_970 = split_101[0] + getitem_971 = split_101[1] + getitem_972 = split_101[2] + getitem_973 = split_101[3] + getitem_974 = split_101[4] + getitem_975 = split_101[5] + getitem_976 = split_101[6] + getitem_977 = split_101[7]; split_101 = None + cat_93 = torch.ops.aten.cat.default([getitem_970, getitem_971, getitem_972, getitem_973, getitem_974, getitem_975, getitem_976, getitem_977], 1); getitem_970 = getitem_971 = getitem_972 = getitem_973 = getitem_974 = getitem_975 = getitem_976 = getitem_977 = None + view_1321 = torch.ops.aten.view.default(cat_93, [16384, 4096]); cat_93 = None + permute_385 = torch.ops.aten.permute.default(view_1321, [1, 0]) + permute_105 = torch.ops.aten.permute.default(getitem_449, [0, 2, 1, 3]) + view_690 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + view_696 = torch.ops.aten.view.default(view_690, [16384, 512]); view_690 = None + mm_205 = torch.ops.aten.mm.default(permute_385, view_696); permute_385 = view_696 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 8, '0'); convert_element_type_314 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + permute_387 = torch.ops.aten.permute.default(permute_106, [1, 0]); permute_106 = None + mm_206 = torch.ops.aten.mm.default(view_1321, permute_387); view_1321 = permute_387 = None + view_1322 = torch.ops.aten.view.default(mm_206, [2, 8192, 512]); mm_206 = None + convert_element_type_898 = torch.ops.prims.convert_element_type.default(mm_205, torch.float32); mm_205 = None + reduce_scatter_tensor_107 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_898, 'avg', 8, '0'); convert_element_type_898 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_107); reduce_scatter_tensor_107 = None + view_1323 = torch.ops.aten.view.default(view_1322, [2, 8192, 4, 128]); view_1322 = None + permute_389 = torch.ops.aten.permute.default(view_1323, [0, 2, 1, 3]); view_1323 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 8, '0'); convert_element_type_298 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_119) + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_300, 8, '1'); convert_element_type_300 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + split_45 = torch.ops.aten.split.Tensor(wait_tensor_120, 2); wait_tensor_120 = None + getitem_441 = split_45[0] + getitem_442 = split_45[1] + getitem_443 = split_45[2] + getitem_444 = split_45[3] + getitem_445 = split_45[4] + getitem_446 = split_45[5] + getitem_447 = split_45[6] + getitem_448 = split_45[7]; split_45 = None + cat_37 = torch.ops.aten.cat.default([getitem_441, getitem_442, getitem_443, getitem_444, getitem_445, getitem_446, getitem_447, getitem_448], 1); getitem_441 = getitem_442 = getitem_443 = getitem_444 = getitem_445 = getitem_446 = getitem_447 = getitem_448 = None + view_663 = torch.ops.aten.view.default(cat_37, [16384, 4096]); cat_37 = None + view_664 = torch.ops.aten.view.default(mm_63, [2, 8192, 512]); mm_63 = None + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 8, '0'); convert_element_type_304 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + mm_64 = torch.ops.aten.mm.default(view_663, permute_100) + view_671 = torch.ops.aten.view.default(mm_64, [2, 8192, 128]); mm_64 = None + view_678 = torch.ops.aten.view.default(mm_65, [2, 8192, 128]); mm_65 = None + view_680 = torch.ops.aten.view.default(view_664, [2, 8192, -1, 128]); view_664 = None + view_681 = torch.ops.aten.view.default(view_671, [2, 8192, -1, 128]); view_671 = None + view_682 = torch.ops.aten.view.default(view_678, [2, 8192, -1, 128]); view_678 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_680, torch.float32); view_680 = None + view_683 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 4, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_683); view_683 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_681, torch.float32); view_681 = None + view_684 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 1, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_684); view_684 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_37); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_686 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 4, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_37); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_687 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 1, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_686, torch.bfloat16); view_686 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_687, torch.bfloat16); view_687 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 1, 4, 128]); unsqueeze_18 = None + view_688 = torch.ops.aten.view.default(expand_18, [2, 8192, 4, 128]); expand_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_682, 3); view_682 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 1, 4, 128]); unsqueeze_19 = None + view_689 = torch.ops.aten.view.default(expand_19, [2, 8192, 4, 128]); expand_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_688, [0, 2, 1, 3]); view_688 = None + permute_104 = torch.ops.aten.permute.default(view_689, [0, 2, 1, 3]); view_689 = None + _scaled_dot_product_cudnn_attention_backward_6 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_389, permute_102, permute_103, permute_104, getitem_449, getitem_450, getitem_455, getitem_456, None, None, None, 8192, 8192, 0.0, True); permute_389 = permute_102 = permute_103 = permute_104 = getitem_449 = getitem_450 = getitem_455 = getitem_456 = None + getitem_978 = _scaled_dot_product_cudnn_attention_backward_6[0] + getitem_979 = _scaled_dot_product_cudnn_attention_backward_6[1] + getitem_980 = _scaled_dot_product_cudnn_attention_backward_6[2]; _scaled_dot_product_cudnn_attention_backward_6 = None + permute_390 = torch.ops.aten.permute.default(getitem_980, [0, 2, 1, 3]); getitem_980 = None + permute_391 = torch.ops.aten.permute.default(getitem_979, [0, 2, 1, 3]); getitem_979 = None + permute_392 = torch.ops.aten.permute.default(getitem_978, [0, 2, 1, 3]); getitem_978 = None + view_1324 = torch.ops.aten.view.default(permute_390, [2, 8192, 1, 4, 128]); permute_390 = None + sum_41 = torch.ops.aten.sum.dim_IntList(view_1324, [3], True); view_1324 = None + squeeze_12 = torch.ops.aten.squeeze.dim(sum_41, 3); sum_41 = None + view_1325 = torch.ops.aten.view.default(permute_391, [2, 8192, 1, 4, 128]); permute_391 = None + sum_42 = torch.ops.aten.sum.dim_IntList(view_1325, [3], True); view_1325 = None + squeeze_13 = torch.ops.aten.squeeze.dim(sum_42, 3); sum_42 = None + convert_element_type_899 = torch.ops.prims.convert_element_type.default(squeeze_13, torch.float32); squeeze_13 = None + convert_element_type_900 = torch.ops.prims.convert_element_type.default(permute_392, torch.float32); permute_392 = None + view_1326 = torch.ops.aten.view.default(convert_element_type_899, [2, 8192, 1, 64, 2]); convert_element_type_899 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1326); view_1326 = None + mul_268 = torch.ops.aten.mul.Tensor(view_as_complex_44, _conj); view_as_complex_44 = None + view_1327 = torch.ops.aten.view.default(convert_element_type_900, [2, 8192, 4, 64, 2]); convert_element_type_900 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1327); view_1327 = None + mul_269 = torch.ops.aten.mul.Tensor(view_as_complex_45, _conj); view_as_complex_45 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_268); mul_268 = None + view_1328 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 1, 128]); view_as_real_44 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(view_1328, torch.bfloat16); view_1328 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_269); mul_269 = None + view_1329 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 4, 128]); view_as_real_45 = None + convert_element_type_902 = torch.ops.prims.convert_element_type.default(view_1329, torch.bfloat16); view_1329 = None + view_1330 = torch.ops.aten.view.default(squeeze_12, [2, 8192, 128]); squeeze_12 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_901, [2, 8192, 128]); convert_element_type_901 = None + view_1332 = torch.ops.aten.view.default(convert_element_type_902, [2, 8192, 512]); convert_element_type_902 = None + view_1333 = torch.ops.aten.view.default(view_1330, [16384, 128]); view_1330 = None + permute_393 = torch.ops.aten.permute.default(view_1333, [1, 0]) + mm_207 = torch.ops.aten.mm.default(permute_393, view_663); permute_393 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 8, '0'); convert_element_type_307 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + permute_395 = torch.ops.aten.permute.default(permute_101, [1, 0]); permute_101 = None + mm_208 = torch.ops.aten.mm.default(view_1333, permute_395); view_1333 = permute_395 = None + view_1334 = torch.ops.aten.view.default(mm_208, [2, 8192, 4096]); mm_208 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(mm_207, torch.float32); mm_207 = None + reduce_scatter_tensor_108 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_907, 'avg', 8, '0'); convert_element_type_907 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_108); reduce_scatter_tensor_108 = None + view_1335 = torch.ops.aten.view.default(view_1331, [16384, 128]); view_1331 = None + permute_397 = torch.ops.aten.permute.default(view_1335, [1, 0]) + mm_209 = torch.ops.aten.mm.default(permute_397, view_663); permute_397 = None + permute_399 = torch.ops.aten.permute.default(permute_100, [1, 0]); permute_100 = None + mm_210 = torch.ops.aten.mm.default(view_1335, permute_399); view_1335 = permute_399 = None + view_1336 = torch.ops.aten.view.default(mm_210, [2, 8192, 4096]); mm_210 = None + add_111 = torch.ops.aten.add.Tensor(view_1334, view_1336); view_1334 = view_1336 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(mm_209, torch.float32); mm_209 = None + reduce_scatter_tensor_109 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_912, 'avg', 8, '0'); convert_element_type_912 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_109); reduce_scatter_tensor_109 = None + view_1337 = torch.ops.aten.view.default(view_1332, [16384, 512]); view_1332 = None + permute_401 = torch.ops.aten.permute.default(view_1337, [1, 0]) + mm_211 = torch.ops.aten.mm.default(permute_401, view_663); permute_401 = view_663 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 8, '0'); convert_element_type_301 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + permute_403 = torch.ops.aten.permute.default(permute_99, [1, 0]); permute_99 = None + mm_212 = torch.ops.aten.mm.default(view_1337, permute_403); view_1337 = permute_403 = None + view_1338 = torch.ops.aten.view.default(mm_212, [2, 8192, 4096]); mm_212 = None + add_112 = torch.ops.aten.add.Tensor(add_111, view_1338); add_111 = view_1338 = None + convert_element_type_917 = torch.ops.prims.convert_element_type.default(mm_211, torch.float32); mm_211 = None + reduce_scatter_tensor_110 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_917, 'avg', 8, '0'); convert_element_type_917 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_110); reduce_scatter_tensor_110 = None + split_102 = torch.ops.aten.split.Tensor(add_112, 1024, 1); add_112 = None + getitem_981 = split_102[0] + getitem_982 = split_102[1] + getitem_983 = split_102[2] + getitem_984 = split_102[3] + getitem_985 = split_102[4] + getitem_986 = split_102[5] + getitem_987 = split_102[6] + getitem_988 = split_102[7]; split_102 = None + cat_94 = torch.ops.aten.cat.default([getitem_981, getitem_982, getitem_983, getitem_984, getitem_985, getitem_986, getitem_987, getitem_988]); getitem_981 = getitem_982 = getitem_983 = getitem_984 = getitem_985 = getitem_986 = getitem_987 = getitem_988 = None + reduce_scatter_tensor_111 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_94, 'sum', 8, '1'); cat_94 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_111); reduce_scatter_tensor_111 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(wait_tensor_319, torch.float32); wait_tensor_319 = None + convert_element_type_920 = torch.ops.prims.convert_element_type.default(wait_tensor_119, torch.float32); wait_tensor_119 = None + mul_270 = torch.ops.aten.mul.Tensor(convert_element_type_918, convert_element_type_920); convert_element_type_920 = None + mul_272 = torch.ops.aten.mul.Tensor(mul_72, mul_270) + sum_43 = torch.ops.aten.sum.dim_IntList(mul_272, [2], True); mul_272 = None + div_14 = torch.ops.aten.div.Tensor(mul_72, 4096) + mul_273 = torch.ops.aten.mul.Tensor(div_14, sum_43); div_14 = sum_43 = None + sub_22 = torch.ops.aten.sub.Tensor(mul_270, mul_273); mul_270 = mul_273 = None + mul_274 = torch.ops.aten.mul.Tensor(sub_22, rsqrt_18); sub_22 = rsqrt_18 = None + mul_275 = torch.ops.aten.mul.Tensor(convert_element_type_918, mul_72); convert_element_type_918 = mul_72 = None + sum_44 = torch.ops.aten.sum.dim_IntList(mul_275, [0, 1]); mul_275 = None + convert_element_type_921 = torch.ops.prims.convert_element_type.default(mul_274, torch.bfloat16); mul_274 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(sum_44, torch.bfloat16); sum_44 = None + all_reduce_14 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_922, 'sum', '1'); convert_element_type_922 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_14); all_reduce_14 = None + convert_element_type_923 = torch.ops.prims.convert_element_type.default(wait_tensor_320, torch.float32); wait_tensor_320 = None + reduce_scatter_tensor_112 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_923, 'avg', 8, '0'); convert_element_type_923 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_112); reduce_scatter_tensor_112 = None + add_113 = torch.ops.aten.add.Tensor(add_110, convert_element_type_921); add_110 = convert_element_type_921 = None + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_113, 8, '1') + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + split_103 = torch.ops.aten.split.Tensor(wait_tensor_322, 2); wait_tensor_322 = None + getitem_989 = split_103[0] + getitem_990 = split_103[1] + getitem_991 = split_103[2] + getitem_992 = split_103[3] + getitem_993 = split_103[4] + getitem_994 = split_103[5] + getitem_995 = split_103[6] + getitem_996 = split_103[7]; split_103 = None + cat_95 = torch.ops.aten.cat.default([getitem_989, getitem_990, getitem_991, getitem_992, getitem_993, getitem_994, getitem_995, getitem_996], 1); getitem_989 = getitem_990 = getitem_991 = getitem_992 = getitem_993 = getitem_994 = getitem_995 = getitem_996 = None + view_1339 = torch.ops.aten.view.default(cat_95, [16384, 4096]); cat_95 = None + permute_405 = torch.ops.aten.permute.default(view_1339, [1, 0]) + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17); reduce_scatter_tensor_17 = None + add_33 = torch.ops.aten.add.Tensor(add_31, wait_tensor_112); wait_tensor_112 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16); primals_81 = None + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 8, '0'); convert_element_type_284 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_113) + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '1'); convert_element_type_286 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + split_43 = torch.ops.aten.split.Tensor(wait_tensor_114, 2); wait_tensor_114 = None + getitem_425 = split_43[0] + getitem_426 = split_43[1] + getitem_427 = split_43[2] + getitem_428 = split_43[3] + getitem_429 = split_43[4] + getitem_430 = split_43[5] + getitem_431 = split_43[6] + getitem_432 = split_43[7]; split_43 = None + cat_35 = torch.ops.aten.cat.default([getitem_425, getitem_426, getitem_427, getitem_428, getitem_429, getitem_430, getitem_431, getitem_432], 1); getitem_425 = getitem_426 = getitem_427 = getitem_428 = getitem_429 = getitem_430 = getitem_431 = getitem_432 = None + view_636 = torch.ops.aten.view.default(cat_35, [16384, 4096]); cat_35 = None + view_637 = torch.ops.aten.view.default(mm_60, [2, 8192, 1792]); mm_60 = None + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_637, torch.float32); view_637 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 8, '0'); convert_element_type_292 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_61 = torch.ops.aten.mm.default(view_636, permute_97) + view_644 = torch.ops.aten.view.default(mm_61, [2, 8192, 1792]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_644) + view_651 = torch.ops.aten.view.default(mul_71, [16384, 1792]); mul_71 = None + mm_213 = torch.ops.aten.mm.default(permute_405, view_651); permute_405 = view_651 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 8, '0'); convert_element_type_295 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + permute_407 = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None + mm_214 = torch.ops.aten.mm.default(view_1339, permute_407); view_1339 = permute_407 = None + view_1340 = torch.ops.aten.view.default(mm_214, [2, 8192, 1792]); mm_214 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(mm_213, torch.float32); mm_213 = None + reduce_scatter_tensor_113 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_928, 'avg', 8, '0'); convert_element_type_928 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_113); reduce_scatter_tensor_113 = None + mul_276 = torch.ops.aten.mul.Tensor(view_1340, convert_element_type_291); convert_element_type_291 = None + mul_277 = torch.ops.aten.mul.Tensor(view_1340, view_644); view_1340 = view_644 = None + view_1341 = torch.ops.aten.view.default(mul_276, [16384, 1792]); mul_276 = None + permute_409 = torch.ops.aten.permute.default(view_1341, [1, 0]) + mm_215 = torch.ops.aten.mm.default(permute_409, view_636); permute_409 = None + permute_411 = torch.ops.aten.permute.default(permute_97, [1, 0]); permute_97 = None + mm_216 = torch.ops.aten.mm.default(view_1341, permute_411); view_1341 = permute_411 = None + view_1342 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + convert_element_type_933 = torch.ops.prims.convert_element_type.default(mm_215, torch.float32); mm_215 = None + reduce_scatter_tensor_114 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_933, 'avg', 8, '0'); convert_element_type_933 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_114); reduce_scatter_tensor_114 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(mul_277, torch.float32); mul_277 = None + neg_7 = torch.ops.aten.neg.default(convert_element_type_290) + exp_7 = torch.ops.aten.exp.default(neg_7); neg_7 = None + add_114 = torch.ops.aten.add.Tensor(exp_7, 1); exp_7 = None + reciprocal_7 = torch.ops.aten.reciprocal.default(add_114); add_114 = None + mul_278 = torch.ops.aten.mul.Tensor(reciprocal_7, 1); reciprocal_7 = None + mul_279 = torch.ops.aten.mul.Tensor(convert_element_type_934, mul_278); convert_element_type_934 = None + sub_23 = torch.ops.aten.sub.Tensor(1, mul_278); mul_278 = None + mul_280 = torch.ops.aten.mul.Tensor(convert_element_type_290, sub_23); convert_element_type_290 = sub_23 = None + add_115 = torch.ops.aten.add.Tensor(mul_280, 1); mul_280 = None + mul_281 = torch.ops.aten.mul.Tensor(mul_279, add_115); mul_279 = add_115 = None + convert_element_type_936 = torch.ops.prims.convert_element_type.default(mul_281, torch.bfloat16); mul_281 = None + view_1343 = torch.ops.aten.view.default(convert_element_type_936, [16384, 1792]); convert_element_type_936 = None + permute_413 = torch.ops.aten.permute.default(view_1343, [1, 0]) + mm_217 = torch.ops.aten.mm.default(permute_413, view_636); permute_413 = view_636 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 8, '0'); convert_element_type_287 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + permute_415 = torch.ops.aten.permute.default(permute_96, [1, 0]); permute_96 = None + mm_218 = torch.ops.aten.mm.default(view_1343, permute_415); view_1343 = permute_415 = None + view_1344 = torch.ops.aten.view.default(mm_218, [2, 8192, 4096]); mm_218 = None + add_116 = torch.ops.aten.add.Tensor(view_1342, view_1344); view_1342 = view_1344 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(mm_217, torch.float32); mm_217 = None + reduce_scatter_tensor_115 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_941, 'avg', 8, '0'); convert_element_type_941 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_115); reduce_scatter_tensor_115 = None + split_104 = torch.ops.aten.split.Tensor(add_116, 1024, 1); add_116 = None + getitem_997 = split_104[0] + getitem_998 = split_104[1] + getitem_999 = split_104[2] + getitem_1000 = split_104[3] + getitem_1001 = split_104[4] + getitem_1002 = split_104[5] + getitem_1003 = split_104[6] + getitem_1004 = split_104[7]; split_104 = None + cat_96 = torch.ops.aten.cat.default([getitem_997, getitem_998, getitem_999, getitem_1000, getitem_1001, getitem_1002, getitem_1003, getitem_1004]); getitem_997 = getitem_998 = getitem_999 = getitem_1000 = getitem_1001 = getitem_1002 = getitem_1003 = getitem_1004 = None + reduce_scatter_tensor_116 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_96, 'sum', 8, '1'); cat_96 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_116); reduce_scatter_tensor_116 = None + convert_element_type_942 = torch.ops.prims.convert_element_type.default(wait_tensor_326, torch.float32); wait_tensor_326 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(wait_tensor_113, torch.float32); wait_tensor_113 = None + mul_282 = torch.ops.aten.mul.Tensor(convert_element_type_942, convert_element_type_944); convert_element_type_944 = None + mul_284 = torch.ops.aten.mul.Tensor(mul_68, mul_282) + sum_45 = torch.ops.aten.sum.dim_IntList(mul_284, [2], True); mul_284 = None + div_15 = torch.ops.aten.div.Tensor(mul_68, 4096) + mul_285 = torch.ops.aten.mul.Tensor(div_15, sum_45); div_15 = sum_45 = None + sub_24 = torch.ops.aten.sub.Tensor(mul_282, mul_285); mul_282 = mul_285 = None + mul_286 = torch.ops.aten.mul.Tensor(sub_24, rsqrt_17); sub_24 = rsqrt_17 = None + mul_287 = torch.ops.aten.mul.Tensor(convert_element_type_942, mul_68); convert_element_type_942 = mul_68 = None + sum_46 = torch.ops.aten.sum.dim_IntList(mul_287, [0, 1]); mul_287 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(mul_286, torch.bfloat16); mul_286 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(sum_46, torch.bfloat16); sum_46 = None + all_reduce_15 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_946, 'sum', '1'); convert_element_type_946 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_15); all_reduce_15 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(wait_tensor_327, torch.float32); wait_tensor_327 = None + reduce_scatter_tensor_117 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_947, 'avg', 8, '0'); convert_element_type_947 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_117); reduce_scatter_tensor_117 = None + add_117 = torch.ops.aten.add.Tensor(add_113, convert_element_type_945); add_113 = convert_element_type_945 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_117, 8, '1') + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + split_105 = torch.ops.aten.split.Tensor(wait_tensor_329, 2); wait_tensor_329 = None + getitem_1005 = split_105[0] + getitem_1006 = split_105[1] + getitem_1007 = split_105[2] + getitem_1008 = split_105[3] + getitem_1009 = split_105[4] + getitem_1010 = split_105[5] + getitem_1011 = split_105[6] + getitem_1012 = split_105[7]; split_105 = None + cat_97 = torch.ops.aten.cat.default([getitem_1005, getitem_1006, getitem_1007, getitem_1008, getitem_1009, getitem_1010, getitem_1011, getitem_1012], 1); getitem_1005 = getitem_1006 = getitem_1007 = getitem_1008 = getitem_1009 = getitem_1010 = getitem_1011 = getitem_1012 = None + view_1345 = torch.ops.aten.view.default(cat_97, [16384, 4096]); cat_97 = None + permute_417 = torch.ops.aten.permute.default(view_1345, [1, 0]) + permute_94 = torch.ops.aten.permute.default(getitem_408, [0, 2, 1, 3]) + view_618 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + view_624 = torch.ops.aten.view.default(view_618, [16384, 512]); view_618 = None + mm_219 = torch.ops.aten.mm.default(permute_417, view_624); permute_417 = view_624 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16); primals_80 = None + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 8, '0'); convert_element_type_281 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + permute_419 = torch.ops.aten.permute.default(permute_95, [1, 0]); permute_95 = None + mm_220 = torch.ops.aten.mm.default(view_1345, permute_419); view_1345 = permute_419 = None + view_1346 = torch.ops.aten.view.default(mm_220, [2, 8192, 512]); mm_220 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(mm_219, torch.float32); mm_219 = None + reduce_scatter_tensor_118 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_952, 'avg', 8, '0'); convert_element_type_952 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_118); reduce_scatter_tensor_118 = None + view_1347 = torch.ops.aten.view.default(view_1346, [2, 8192, 4, 128]); view_1346 = None + permute_421 = torch.ops.aten.permute.default(view_1347, [0, 2, 1, 3]); view_1347 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16); primals_76 = None + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 8, '0'); convert_element_type_265 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32); add_31 = None + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_106) + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_267, 8, '1'); convert_element_type_267 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + split_41 = torch.ops.aten.split.Tensor(wait_tensor_107, 2); wait_tensor_107 = None + getitem_400 = split_41[0] + getitem_401 = split_41[1] + getitem_402 = split_41[2] + getitem_403 = split_41[3] + getitem_404 = split_41[4] + getitem_405 = split_41[5] + getitem_406 = split_41[6] + getitem_407 = split_41[7]; split_41 = None + cat_33 = torch.ops.aten.cat.default([getitem_400, getitem_401, getitem_402, getitem_403, getitem_404, getitem_405, getitem_406, getitem_407], 1); getitem_400 = getitem_401 = getitem_402 = getitem_403 = getitem_404 = getitem_405 = getitem_406 = getitem_407 = None + view_591 = torch.ops.aten.view.default(cat_33, [16384, 4096]); cat_33 = None + view_592 = torch.ops.aten.view.default(mm_56, [2, 8192, 512]); mm_56 = None + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16); primals_78 = None + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 8, '0'); convert_element_type_271 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_109, [1, 0]); wait_tensor_109 = None + mm_57 = torch.ops.aten.mm.default(view_591, permute_89) + view_599 = torch.ops.aten.view.default(mm_57, [2, 8192, 128]); mm_57 = None + view_606 = torch.ops.aten.view.default(mm_58, [2, 8192, 128]); mm_58 = None + view_608 = torch.ops.aten.view.default(view_592, [2, 8192, -1, 128]); view_592 = None + view_609 = torch.ops.aten.view.default(view_599, [2, 8192, -1, 128]); view_599 = None + view_610 = torch.ops.aten.view.default(view_606, [2, 8192, -1, 128]); view_606 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_608, torch.float32); view_608 = None + view_611 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 4, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_611); view_611 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_609, torch.float32); view_609 = None + view_612 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 1, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_612); view_612 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_37); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_614 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 4, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_37); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_615 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 1, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_614, torch.bfloat16); view_614 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_615, torch.bfloat16); view_615 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 1, 4, 128]); unsqueeze_16 = None + view_616 = torch.ops.aten.view.default(expand_16, [2, 8192, 4, 128]); expand_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_610, 3); view_610 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 1, 4, 128]); unsqueeze_17 = None + view_617 = torch.ops.aten.view.default(expand_17, [2, 8192, 4, 128]); expand_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_616, [0, 2, 1, 3]); view_616 = None + permute_93 = torch.ops.aten.permute.default(view_617, [0, 2, 1, 3]); view_617 = None + _scaled_dot_product_cudnn_attention_backward_7 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_421, permute_91, permute_92, permute_93, getitem_408, getitem_409, getitem_414, getitem_415, None, None, None, 8192, 8192, 0.0, True); permute_421 = permute_91 = permute_92 = permute_93 = getitem_408 = getitem_409 = getitem_414 = getitem_415 = None + getitem_1013 = _scaled_dot_product_cudnn_attention_backward_7[0] + getitem_1014 = _scaled_dot_product_cudnn_attention_backward_7[1] + getitem_1015 = _scaled_dot_product_cudnn_attention_backward_7[2]; _scaled_dot_product_cudnn_attention_backward_7 = None + permute_422 = torch.ops.aten.permute.default(getitem_1015, [0, 2, 1, 3]); getitem_1015 = None + permute_423 = torch.ops.aten.permute.default(getitem_1014, [0, 2, 1, 3]); getitem_1014 = None + permute_424 = torch.ops.aten.permute.default(getitem_1013, [0, 2, 1, 3]); getitem_1013 = None + view_1348 = torch.ops.aten.view.default(permute_422, [2, 8192, 1, 4, 128]); permute_422 = None + sum_47 = torch.ops.aten.sum.dim_IntList(view_1348, [3], True); view_1348 = None + squeeze_14 = torch.ops.aten.squeeze.dim(sum_47, 3); sum_47 = None + view_1349 = torch.ops.aten.view.default(permute_423, [2, 8192, 1, 4, 128]); permute_423 = None + sum_48 = torch.ops.aten.sum.dim_IntList(view_1349, [3], True); view_1349 = None + squeeze_15 = torch.ops.aten.squeeze.dim(sum_48, 3); sum_48 = None + convert_element_type_953 = torch.ops.prims.convert_element_type.default(squeeze_15, torch.float32); squeeze_15 = None + convert_element_type_954 = torch.ops.prims.convert_element_type.default(permute_424, torch.float32); permute_424 = None + view_1350 = torch.ops.aten.view.default(convert_element_type_953, [2, 8192, 1, 64, 2]); convert_element_type_953 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1350); view_1350 = None + mul_288 = torch.ops.aten.mul.Tensor(view_as_complex_46, _conj); view_as_complex_46 = None + view_1351 = torch.ops.aten.view.default(convert_element_type_954, [2, 8192, 4, 64, 2]); convert_element_type_954 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1351); view_1351 = None + mul_289 = torch.ops.aten.mul.Tensor(view_as_complex_47, _conj); view_as_complex_47 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_288); mul_288 = None + view_1352 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 1, 128]); view_as_real_46 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(view_1352, torch.bfloat16); view_1352 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_289); mul_289 = None + view_1353 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 4, 128]); view_as_real_47 = None + convert_element_type_956 = torch.ops.prims.convert_element_type.default(view_1353, torch.bfloat16); view_1353 = None + view_1354 = torch.ops.aten.view.default(squeeze_14, [2, 8192, 128]); squeeze_14 = None + view_1355 = torch.ops.aten.view.default(convert_element_type_955, [2, 8192, 128]); convert_element_type_955 = None + view_1356 = torch.ops.aten.view.default(convert_element_type_956, [2, 8192, 512]); convert_element_type_956 = None + view_1357 = torch.ops.aten.view.default(view_1354, [16384, 128]); view_1354 = None + permute_425 = torch.ops.aten.permute.default(view_1357, [1, 0]) + mm_221 = torch.ops.aten.mm.default(permute_425, view_591); permute_425 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16); primals_79 = None + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 8, '0'); convert_element_type_274 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + permute_427 = torch.ops.aten.permute.default(permute_90, [1, 0]); permute_90 = None + mm_222 = torch.ops.aten.mm.default(view_1357, permute_427); view_1357 = permute_427 = None + view_1358 = torch.ops.aten.view.default(mm_222, [2, 8192, 4096]); mm_222 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(mm_221, torch.float32); mm_221 = None + reduce_scatter_tensor_119 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_961, 'avg', 8, '0'); convert_element_type_961 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_119); reduce_scatter_tensor_119 = None + view_1359 = torch.ops.aten.view.default(view_1355, [16384, 128]); view_1355 = None + permute_429 = torch.ops.aten.permute.default(view_1359, [1, 0]) + mm_223 = torch.ops.aten.mm.default(permute_429, view_591); permute_429 = None + permute_431 = torch.ops.aten.permute.default(permute_89, [1, 0]); permute_89 = None + mm_224 = torch.ops.aten.mm.default(view_1359, permute_431); view_1359 = permute_431 = None + view_1360 = torch.ops.aten.view.default(mm_224, [2, 8192, 4096]); mm_224 = None + add_118 = torch.ops.aten.add.Tensor(view_1358, view_1360); view_1358 = view_1360 = None + convert_element_type_966 = torch.ops.prims.convert_element_type.default(mm_223, torch.float32); mm_223 = None + reduce_scatter_tensor_120 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_966, 'avg', 8, '0'); convert_element_type_966 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_120); reduce_scatter_tensor_120 = None + view_1361 = torch.ops.aten.view.default(view_1356, [16384, 512]); view_1356 = None + permute_433 = torch.ops.aten.permute.default(view_1361, [1, 0]) + mm_225 = torch.ops.aten.mm.default(permute_433, view_591); permute_433 = view_591 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16); primals_77 = None + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 8, '0'); convert_element_type_268 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + permute_435 = torch.ops.aten.permute.default(permute_88, [1, 0]); permute_88 = None + mm_226 = torch.ops.aten.mm.default(view_1361, permute_435); view_1361 = permute_435 = None + view_1362 = torch.ops.aten.view.default(mm_226, [2, 8192, 4096]); mm_226 = None + add_119 = torch.ops.aten.add.Tensor(add_118, view_1362); add_118 = view_1362 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(mm_225, torch.float32); mm_225 = None + reduce_scatter_tensor_121 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_971, 'avg', 8, '0'); convert_element_type_971 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_121); reduce_scatter_tensor_121 = None + split_106 = torch.ops.aten.split.Tensor(add_119, 1024, 1); add_119 = None + getitem_1016 = split_106[0] + getitem_1017 = split_106[1] + getitem_1018 = split_106[2] + getitem_1019 = split_106[3] + getitem_1020 = split_106[4] + getitem_1021 = split_106[5] + getitem_1022 = split_106[6] + getitem_1023 = split_106[7]; split_106 = None + cat_98 = torch.ops.aten.cat.default([getitem_1016, getitem_1017, getitem_1018, getitem_1019, getitem_1020, getitem_1021, getitem_1022, getitem_1023]); getitem_1016 = getitem_1017 = getitem_1018 = getitem_1019 = getitem_1020 = getitem_1021 = getitem_1022 = getitem_1023 = None + reduce_scatter_tensor_122 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_98, 'sum', 8, '1'); cat_98 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_122); reduce_scatter_tensor_122 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(wait_tensor_334, torch.float32); wait_tensor_334 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(wait_tensor_106, torch.float32); wait_tensor_106 = None + mul_290 = torch.ops.aten.mul.Tensor(convert_element_type_972, convert_element_type_974); convert_element_type_974 = None + mul_292 = torch.ops.aten.mul.Tensor(mul_64, mul_290) + sum_49 = torch.ops.aten.sum.dim_IntList(mul_292, [2], True); mul_292 = None + div_16 = torch.ops.aten.div.Tensor(mul_64, 4096) + mul_293 = torch.ops.aten.mul.Tensor(div_16, sum_49); div_16 = sum_49 = None + sub_25 = torch.ops.aten.sub.Tensor(mul_290, mul_293); mul_290 = mul_293 = None + mul_294 = torch.ops.aten.mul.Tensor(sub_25, rsqrt_16); sub_25 = rsqrt_16 = None + mul_295 = torch.ops.aten.mul.Tensor(convert_element_type_972, mul_64); convert_element_type_972 = mul_64 = None + sum_50 = torch.ops.aten.sum.dim_IntList(mul_295, [0, 1]); mul_295 = None + convert_element_type_975 = torch.ops.prims.convert_element_type.default(mul_294, torch.bfloat16); mul_294 = None + convert_element_type_976 = torch.ops.prims.convert_element_type.default(sum_50, torch.bfloat16); sum_50 = None + all_reduce_16 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_976, 'sum', '1'); convert_element_type_976 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_16); all_reduce_16 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(wait_tensor_335, torch.float32); wait_tensor_335 = None + reduce_scatter_tensor_123 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_977, 'avg', 8, '0'); convert_element_type_977 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_123); reduce_scatter_tensor_123 = None + add_120 = torch.ops.aten.add.Tensor(add_117, convert_element_type_975); add_117 = convert_element_type_975 = None + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_120, 8, '1') + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + split_107 = torch.ops.aten.split.Tensor(wait_tensor_337, 2); wait_tensor_337 = None + getitem_1024 = split_107[0] + getitem_1025 = split_107[1] + getitem_1026 = split_107[2] + getitem_1027 = split_107[3] + getitem_1028 = split_107[4] + getitem_1029 = split_107[5] + getitem_1030 = split_107[6] + getitem_1031 = split_107[7]; split_107 = None + cat_99 = torch.ops.aten.cat.default([getitem_1024, getitem_1025, getitem_1026, getitem_1027, getitem_1028, getitem_1029, getitem_1030, getitem_1031], 1); getitem_1024 = getitem_1025 = getitem_1026 = getitem_1027 = getitem_1028 = getitem_1029 = getitem_1030 = getitem_1031 = None + view_1363 = torch.ops.aten.view.default(cat_99, [16384, 4096]); cat_99 = None + permute_437 = torch.ops.aten.permute.default(view_1363, [1, 0]) + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15); reduce_scatter_tensor_15 = None + add_29 = torch.ops.aten.add.Tensor(add_27, wait_tensor_99); wait_tensor_99 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 8, '0'); convert_element_type_251 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32); add_29 = None + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_100) + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 8, '1'); convert_element_type_253 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + split_39 = torch.ops.aten.split.Tensor(wait_tensor_101, 2); wait_tensor_101 = None + getitem_384 = split_39[0] + getitem_385 = split_39[1] + getitem_386 = split_39[2] + getitem_387 = split_39[3] + getitem_388 = split_39[4] + getitem_389 = split_39[5] + getitem_390 = split_39[6] + getitem_391 = split_39[7]; split_39 = None + cat_31 = torch.ops.aten.cat.default([getitem_384, getitem_385, getitem_386, getitem_387, getitem_388, getitem_389, getitem_390, getitem_391], 1); getitem_384 = getitem_385 = getitem_386 = getitem_387 = getitem_388 = getitem_389 = getitem_390 = getitem_391 = None + view_564 = torch.ops.aten.view.default(cat_31, [16384, 4096]); cat_31 = None + view_565 = torch.ops.aten.view.default(mm_53, [2, 8192, 1792]); mm_53 = None + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_565, torch.float32); view_565 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16); primals_74 = None + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 8, '0'); convert_element_type_259 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_54 = torch.ops.aten.mm.default(view_564, permute_86) + view_572 = torch.ops.aten.view.default(mm_54, [2, 8192, 1792]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_572) + view_579 = torch.ops.aten.view.default(mul_63, [16384, 1792]); mul_63 = None + mm_227 = torch.ops.aten.mm.default(permute_437, view_579); permute_437 = view_579 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16); primals_75 = None + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 8, '0'); convert_element_type_262 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + permute_439 = torch.ops.aten.permute.default(permute_87, [1, 0]); permute_87 = None + mm_228 = torch.ops.aten.mm.default(view_1363, permute_439); view_1363 = permute_439 = None + view_1364 = torch.ops.aten.view.default(mm_228, [2, 8192, 1792]); mm_228 = None + convert_element_type_982 = torch.ops.prims.convert_element_type.default(mm_227, torch.float32); mm_227 = None + reduce_scatter_tensor_124 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_982, 'avg', 8, '0'); convert_element_type_982 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_124); reduce_scatter_tensor_124 = None + mul_296 = torch.ops.aten.mul.Tensor(view_1364, convert_element_type_258); convert_element_type_258 = None + mul_297 = torch.ops.aten.mul.Tensor(view_1364, view_572); view_1364 = view_572 = None + view_1365 = torch.ops.aten.view.default(mul_296, [16384, 1792]); mul_296 = None + permute_441 = torch.ops.aten.permute.default(view_1365, [1, 0]) + mm_229 = torch.ops.aten.mm.default(permute_441, view_564); permute_441 = None + permute_443 = torch.ops.aten.permute.default(permute_86, [1, 0]); permute_86 = None + mm_230 = torch.ops.aten.mm.default(view_1365, permute_443); view_1365 = permute_443 = None + view_1366 = torch.ops.aten.view.default(mm_230, [2, 8192, 4096]); mm_230 = None + convert_element_type_987 = torch.ops.prims.convert_element_type.default(mm_229, torch.float32); mm_229 = None + reduce_scatter_tensor_125 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_987, 'avg', 8, '0'); convert_element_type_987 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_125); reduce_scatter_tensor_125 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(mul_297, torch.float32); mul_297 = None + neg_8 = torch.ops.aten.neg.default(convert_element_type_257) + exp_8 = torch.ops.aten.exp.default(neg_8); neg_8 = None + add_121 = torch.ops.aten.add.Tensor(exp_8, 1); exp_8 = None + reciprocal_8 = torch.ops.aten.reciprocal.default(add_121); add_121 = None + mul_298 = torch.ops.aten.mul.Tensor(reciprocal_8, 1); reciprocal_8 = None + mul_299 = torch.ops.aten.mul.Tensor(convert_element_type_988, mul_298); convert_element_type_988 = None + sub_26 = torch.ops.aten.sub.Tensor(1, mul_298); mul_298 = None + mul_300 = torch.ops.aten.mul.Tensor(convert_element_type_257, sub_26); convert_element_type_257 = sub_26 = None + add_122 = torch.ops.aten.add.Tensor(mul_300, 1); mul_300 = None + mul_301 = torch.ops.aten.mul.Tensor(mul_299, add_122); mul_299 = add_122 = None + convert_element_type_990 = torch.ops.prims.convert_element_type.default(mul_301, torch.bfloat16); mul_301 = None + view_1367 = torch.ops.aten.view.default(convert_element_type_990, [16384, 1792]); convert_element_type_990 = None + permute_445 = torch.ops.aten.permute.default(view_1367, [1, 0]) + mm_231 = torch.ops.aten.mm.default(permute_445, view_564); permute_445 = view_564 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16); primals_73 = None + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 8, '0'); convert_element_type_254 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + permute_447 = torch.ops.aten.permute.default(permute_85, [1, 0]); permute_85 = None + mm_232 = torch.ops.aten.mm.default(view_1367, permute_447); view_1367 = permute_447 = None + view_1368 = torch.ops.aten.view.default(mm_232, [2, 8192, 4096]); mm_232 = None + add_123 = torch.ops.aten.add.Tensor(view_1366, view_1368); view_1366 = view_1368 = None + convert_element_type_995 = torch.ops.prims.convert_element_type.default(mm_231, torch.float32); mm_231 = None + reduce_scatter_tensor_126 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_995, 'avg', 8, '0'); convert_element_type_995 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_126); reduce_scatter_tensor_126 = None + split_108 = torch.ops.aten.split.Tensor(add_123, 1024, 1); add_123 = None + getitem_1032 = split_108[0] + getitem_1033 = split_108[1] + getitem_1034 = split_108[2] + getitem_1035 = split_108[3] + getitem_1036 = split_108[4] + getitem_1037 = split_108[5] + getitem_1038 = split_108[6] + getitem_1039 = split_108[7]; split_108 = None + cat_100 = torch.ops.aten.cat.default([getitem_1032, getitem_1033, getitem_1034, getitem_1035, getitem_1036, getitem_1037, getitem_1038, getitem_1039]); getitem_1032 = getitem_1033 = getitem_1034 = getitem_1035 = getitem_1036 = getitem_1037 = getitem_1038 = getitem_1039 = None + reduce_scatter_tensor_127 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_100, 'sum', 8, '1'); cat_100 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_127); reduce_scatter_tensor_127 = None + convert_element_type_996 = torch.ops.prims.convert_element_type.default(wait_tensor_341, torch.float32); wait_tensor_341 = None + convert_element_type_998 = torch.ops.prims.convert_element_type.default(wait_tensor_100, torch.float32); wait_tensor_100 = None + mul_302 = torch.ops.aten.mul.Tensor(convert_element_type_996, convert_element_type_998); convert_element_type_998 = None + mul_304 = torch.ops.aten.mul.Tensor(mul_60, mul_302) + sum_51 = torch.ops.aten.sum.dim_IntList(mul_304, [2], True); mul_304 = None + div_17 = torch.ops.aten.div.Tensor(mul_60, 4096) + mul_305 = torch.ops.aten.mul.Tensor(div_17, sum_51); div_17 = sum_51 = None + sub_27 = torch.ops.aten.sub.Tensor(mul_302, mul_305); mul_302 = mul_305 = None + mul_306 = torch.ops.aten.mul.Tensor(sub_27, rsqrt_15); sub_27 = rsqrt_15 = None + mul_307 = torch.ops.aten.mul.Tensor(convert_element_type_996, mul_60); convert_element_type_996 = mul_60 = None + sum_52 = torch.ops.aten.sum.dim_IntList(mul_307, [0, 1]); mul_307 = None + convert_element_type_999 = torch.ops.prims.convert_element_type.default(mul_306, torch.bfloat16); mul_306 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(sum_52, torch.bfloat16); sum_52 = None + all_reduce_17 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1000, 'sum', '1'); convert_element_type_1000 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_17); all_reduce_17 = None + convert_element_type_1001 = torch.ops.prims.convert_element_type.default(wait_tensor_342, torch.float32); wait_tensor_342 = None + reduce_scatter_tensor_128 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1001, 'avg', 8, '0'); convert_element_type_1001 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_128); reduce_scatter_tensor_128 = None + add_124 = torch.ops.aten.add.Tensor(add_120, convert_element_type_999); add_120 = convert_element_type_999 = None + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_124, 8, '1') + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + split_109 = torch.ops.aten.split.Tensor(wait_tensor_344, 2); wait_tensor_344 = None + getitem_1040 = split_109[0] + getitem_1041 = split_109[1] + getitem_1042 = split_109[2] + getitem_1043 = split_109[3] + getitem_1044 = split_109[4] + getitem_1045 = split_109[5] + getitem_1046 = split_109[6] + getitem_1047 = split_109[7]; split_109 = None + cat_101 = torch.ops.aten.cat.default([getitem_1040, getitem_1041, getitem_1042, getitem_1043, getitem_1044, getitem_1045, getitem_1046, getitem_1047], 1); getitem_1040 = getitem_1041 = getitem_1042 = getitem_1043 = getitem_1044 = getitem_1045 = getitem_1046 = getitem_1047 = None + view_1369 = torch.ops.aten.view.default(cat_101, [16384, 4096]); cat_101 = None + permute_449 = torch.ops.aten.permute.default(view_1369, [1, 0]) + permute_83 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]) + view_546 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + view_552 = torch.ops.aten.view.default(view_546, [16384, 512]); view_546 = None + mm_233 = torch.ops.aten.mm.default(permute_449, view_552); permute_449 = view_552 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 8, '0'); convert_element_type_248 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + permute_451 = torch.ops.aten.permute.default(permute_84, [1, 0]); permute_84 = None + mm_234 = torch.ops.aten.mm.default(view_1369, permute_451); view_1369 = permute_451 = None + view_1370 = torch.ops.aten.view.default(mm_234, [2, 8192, 512]); mm_234 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(mm_233, torch.float32); mm_233 = None + reduce_scatter_tensor_129 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1006, 'avg', 8, '0'); convert_element_type_1006 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_129); reduce_scatter_tensor_129 = None + view_1371 = torch.ops.aten.view.default(view_1370, [2, 8192, 4, 128]); view_1370 = None + permute_453 = torch.ops.aten.permute.default(view_1371, [0, 2, 1, 3]); view_1371 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 8, '0'); convert_element_type_232 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32); add_27 = None + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_93) + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '1'); convert_element_type_234 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + split_37 = torch.ops.aten.split.Tensor(wait_tensor_94, 2); wait_tensor_94 = None + getitem_359 = split_37[0] + getitem_360 = split_37[1] + getitem_361 = split_37[2] + getitem_362 = split_37[3] + getitem_363 = split_37[4] + getitem_364 = split_37[5] + getitem_365 = split_37[6] + getitem_366 = split_37[7]; split_37 = None + cat_29 = torch.ops.aten.cat.default([getitem_359, getitem_360, getitem_361, getitem_362, getitem_363, getitem_364, getitem_365, getitem_366], 1); getitem_359 = getitem_360 = getitem_361 = getitem_362 = getitem_363 = getitem_364 = getitem_365 = getitem_366 = None + view_519 = torch.ops.aten.view.default(cat_29, [16384, 4096]); cat_29 = None + view_520 = torch.ops.aten.view.default(mm_49, [2, 8192, 512]); mm_49 = None + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 8, '0'); convert_element_type_238 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + mm_50 = torch.ops.aten.mm.default(view_519, permute_78) + view_527 = torch.ops.aten.view.default(mm_50, [2, 8192, 128]); mm_50 = None + view_534 = torch.ops.aten.view.default(mm_51, [2, 8192, 128]); mm_51 = None + view_536 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + view_537 = torch.ops.aten.view.default(view_527, [2, 8192, -1, 128]); view_527 = None + view_538 = torch.ops.aten.view.default(view_534, [2, 8192, -1, 128]); view_534 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_536, torch.float32); view_536 = None + view_539 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 4, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_539); view_539 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_537, torch.float32); view_537 = None + view_540 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 1, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_540); view_540 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_37); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_542 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 4, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_37); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_543 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 1, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_542, torch.bfloat16); view_542 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_543, torch.bfloat16); view_543 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 1, 4, 128]); unsqueeze_14 = None + view_544 = torch.ops.aten.view.default(expand_14, [2, 8192, 4, 128]); expand_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_538, 3); view_538 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 1, 4, 128]); unsqueeze_15 = None + view_545 = torch.ops.aten.view.default(expand_15, [2, 8192, 4, 128]); expand_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_544, [0, 2, 1, 3]); view_544 = None + permute_82 = torch.ops.aten.permute.default(view_545, [0, 2, 1, 3]); view_545 = None + _scaled_dot_product_cudnn_attention_backward_8 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_453, permute_80, permute_81, permute_82, getitem_367, getitem_368, getitem_373, getitem_374, None, None, None, 8192, 8192, 0.0, True); permute_453 = permute_80 = permute_81 = permute_82 = getitem_367 = getitem_368 = getitem_373 = getitem_374 = None + getitem_1048 = _scaled_dot_product_cudnn_attention_backward_8[0] + getitem_1049 = _scaled_dot_product_cudnn_attention_backward_8[1] + getitem_1050 = _scaled_dot_product_cudnn_attention_backward_8[2]; _scaled_dot_product_cudnn_attention_backward_8 = None + permute_454 = torch.ops.aten.permute.default(getitem_1050, [0, 2, 1, 3]); getitem_1050 = None + permute_455 = torch.ops.aten.permute.default(getitem_1049, [0, 2, 1, 3]); getitem_1049 = None + permute_456 = torch.ops.aten.permute.default(getitem_1048, [0, 2, 1, 3]); getitem_1048 = None + view_1372 = torch.ops.aten.view.default(permute_454, [2, 8192, 1, 4, 128]); permute_454 = None + sum_53 = torch.ops.aten.sum.dim_IntList(view_1372, [3], True); view_1372 = None + squeeze_16 = torch.ops.aten.squeeze.dim(sum_53, 3); sum_53 = None + view_1373 = torch.ops.aten.view.default(permute_455, [2, 8192, 1, 4, 128]); permute_455 = None + sum_54 = torch.ops.aten.sum.dim_IntList(view_1373, [3], True); view_1373 = None + squeeze_17 = torch.ops.aten.squeeze.dim(sum_54, 3); sum_54 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(squeeze_17, torch.float32); squeeze_17 = None + convert_element_type_1008 = torch.ops.prims.convert_element_type.default(permute_456, torch.float32); permute_456 = None + view_1374 = torch.ops.aten.view.default(convert_element_type_1007, [2, 8192, 1, 64, 2]); convert_element_type_1007 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1374); view_1374 = None + mul_308 = torch.ops.aten.mul.Tensor(view_as_complex_48, _conj); view_as_complex_48 = None + view_1375 = torch.ops.aten.view.default(convert_element_type_1008, [2, 8192, 4, 64, 2]); convert_element_type_1008 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1375); view_1375 = None + mul_309 = torch.ops.aten.mul.Tensor(view_as_complex_49, _conj); view_as_complex_49 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_308); mul_308 = None + view_1376 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 1, 128]); view_as_real_48 = None + convert_element_type_1009 = torch.ops.prims.convert_element_type.default(view_1376, torch.bfloat16); view_1376 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_309); mul_309 = None + view_1377 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 4, 128]); view_as_real_49 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(view_1377, torch.bfloat16); view_1377 = None + view_1378 = torch.ops.aten.view.default(squeeze_16, [2, 8192, 128]); squeeze_16 = None + view_1379 = torch.ops.aten.view.default(convert_element_type_1009, [2, 8192, 128]); convert_element_type_1009 = None + view_1380 = torch.ops.aten.view.default(convert_element_type_1010, [2, 8192, 512]); convert_element_type_1010 = None + view_1381 = torch.ops.aten.view.default(view_1378, [16384, 128]); view_1378 = None + permute_457 = torch.ops.aten.permute.default(view_1381, [1, 0]) + mm_235 = torch.ops.aten.mm.default(permute_457, view_519); permute_457 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 8, '0'); convert_element_type_241 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + permute_459 = torch.ops.aten.permute.default(permute_79, [1, 0]); permute_79 = None + mm_236 = torch.ops.aten.mm.default(view_1381, permute_459); view_1381 = permute_459 = None + view_1382 = torch.ops.aten.view.default(mm_236, [2, 8192, 4096]); mm_236 = None + convert_element_type_1015 = torch.ops.prims.convert_element_type.default(mm_235, torch.float32); mm_235 = None + reduce_scatter_tensor_130 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1015, 'avg', 8, '0'); convert_element_type_1015 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_130); reduce_scatter_tensor_130 = None + view_1383 = torch.ops.aten.view.default(view_1379, [16384, 128]); view_1379 = None + permute_461 = torch.ops.aten.permute.default(view_1383, [1, 0]) + mm_237 = torch.ops.aten.mm.default(permute_461, view_519); permute_461 = None + permute_463 = torch.ops.aten.permute.default(permute_78, [1, 0]); permute_78 = None + mm_238 = torch.ops.aten.mm.default(view_1383, permute_463); view_1383 = permute_463 = None + view_1384 = torch.ops.aten.view.default(mm_238, [2, 8192, 4096]); mm_238 = None + add_125 = torch.ops.aten.add.Tensor(view_1382, view_1384); view_1382 = view_1384 = None + convert_element_type_1020 = torch.ops.prims.convert_element_type.default(mm_237, torch.float32); mm_237 = None + reduce_scatter_tensor_131 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1020, 'avg', 8, '0'); convert_element_type_1020 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_131); reduce_scatter_tensor_131 = None + view_1385 = torch.ops.aten.view.default(view_1380, [16384, 512]); view_1380 = None + permute_465 = torch.ops.aten.permute.default(view_1385, [1, 0]) + mm_239 = torch.ops.aten.mm.default(permute_465, view_519); permute_465 = view_519 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 8, '0'); convert_element_type_235 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + permute_467 = torch.ops.aten.permute.default(permute_77, [1, 0]); permute_77 = None + mm_240 = torch.ops.aten.mm.default(view_1385, permute_467); view_1385 = permute_467 = None + view_1386 = torch.ops.aten.view.default(mm_240, [2, 8192, 4096]); mm_240 = None + add_126 = torch.ops.aten.add.Tensor(add_125, view_1386); add_125 = view_1386 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(mm_239, torch.float32); mm_239 = None + reduce_scatter_tensor_132 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1025, 'avg', 8, '0'); convert_element_type_1025 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_132); reduce_scatter_tensor_132 = None + split_110 = torch.ops.aten.split.Tensor(add_126, 1024, 1); add_126 = None + getitem_1051 = split_110[0] + getitem_1052 = split_110[1] + getitem_1053 = split_110[2] + getitem_1054 = split_110[3] + getitem_1055 = split_110[4] + getitem_1056 = split_110[5] + getitem_1057 = split_110[6] + getitem_1058 = split_110[7]; split_110 = None + cat_102 = torch.ops.aten.cat.default([getitem_1051, getitem_1052, getitem_1053, getitem_1054, getitem_1055, getitem_1056, getitem_1057, getitem_1058]); getitem_1051 = getitem_1052 = getitem_1053 = getitem_1054 = getitem_1055 = getitem_1056 = getitem_1057 = getitem_1058 = None + reduce_scatter_tensor_133 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_102, 'sum', 8, '1'); cat_102 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_133); reduce_scatter_tensor_133 = None + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(wait_tensor_349, torch.float32); wait_tensor_349 = None + convert_element_type_1028 = torch.ops.prims.convert_element_type.default(wait_tensor_93, torch.float32); wait_tensor_93 = None + mul_310 = torch.ops.aten.mul.Tensor(convert_element_type_1026, convert_element_type_1028); convert_element_type_1028 = None + mul_312 = torch.ops.aten.mul.Tensor(mul_56, mul_310) + sum_55 = torch.ops.aten.sum.dim_IntList(mul_312, [2], True); mul_312 = None + div_18 = torch.ops.aten.div.Tensor(mul_56, 4096) + mul_313 = torch.ops.aten.mul.Tensor(div_18, sum_55); div_18 = sum_55 = None + sub_28 = torch.ops.aten.sub.Tensor(mul_310, mul_313); mul_310 = mul_313 = None + mul_314 = torch.ops.aten.mul.Tensor(sub_28, rsqrt_14); sub_28 = rsqrt_14 = None + mul_315 = torch.ops.aten.mul.Tensor(convert_element_type_1026, mul_56); convert_element_type_1026 = mul_56 = None + sum_56 = torch.ops.aten.sum.dim_IntList(mul_315, [0, 1]); mul_315 = None + convert_element_type_1029 = torch.ops.prims.convert_element_type.default(mul_314, torch.bfloat16); mul_314 = None + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(sum_56, torch.bfloat16); sum_56 = None + all_reduce_18 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1030, 'sum', '1'); convert_element_type_1030 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_18); all_reduce_18 = None + convert_element_type_1031 = torch.ops.prims.convert_element_type.default(wait_tensor_350, torch.float32); wait_tensor_350 = None + reduce_scatter_tensor_134 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1031, 'avg', 8, '0'); convert_element_type_1031 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_134); reduce_scatter_tensor_134 = None + add_127 = torch.ops.aten.add.Tensor(add_124, convert_element_type_1029); add_124 = convert_element_type_1029 = None + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_127, 8, '1') + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + split_111 = torch.ops.aten.split.Tensor(wait_tensor_352, 2); wait_tensor_352 = None + getitem_1059 = split_111[0] + getitem_1060 = split_111[1] + getitem_1061 = split_111[2] + getitem_1062 = split_111[3] + getitem_1063 = split_111[4] + getitem_1064 = split_111[5] + getitem_1065 = split_111[6] + getitem_1066 = split_111[7]; split_111 = None + cat_103 = torch.ops.aten.cat.default([getitem_1059, getitem_1060, getitem_1061, getitem_1062, getitem_1063, getitem_1064, getitem_1065, getitem_1066], 1); getitem_1059 = getitem_1060 = getitem_1061 = getitem_1062 = getitem_1063 = getitem_1064 = getitem_1065 = getitem_1066 = None + view_1387 = torch.ops.aten.view.default(cat_103, [16384, 4096]); cat_103 = None + permute_469 = torch.ops.aten.permute.default(view_1387, [1, 0]) + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13); reduce_scatter_tensor_13 = None + add_25 = torch.ops.aten.add.Tensor(add_23, wait_tensor_86); wait_tensor_86 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16); primals_63 = None + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 8, '0'); convert_element_type_218 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32); add_25 = None + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_87) + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_220, 8, '1'); convert_element_type_220 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + split_35 = torch.ops.aten.split.Tensor(wait_tensor_88, 2); wait_tensor_88 = None + getitem_343 = split_35[0] + getitem_344 = split_35[1] + getitem_345 = split_35[2] + getitem_346 = split_35[3] + getitem_347 = split_35[4] + getitem_348 = split_35[5] + getitem_349 = split_35[6] + getitem_350 = split_35[7]; split_35 = None + cat_27 = torch.ops.aten.cat.default([getitem_343, getitem_344, getitem_345, getitem_346, getitem_347, getitem_348, getitem_349, getitem_350], 1); getitem_343 = getitem_344 = getitem_345 = getitem_346 = getitem_347 = getitem_348 = getitem_349 = getitem_350 = None + view_492 = torch.ops.aten.view.default(cat_27, [16384, 4096]); cat_27 = None + view_493 = torch.ops.aten.view.default(mm_46, [2, 8192, 1792]); mm_46 = None + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_493, torch.float32); view_493 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 8, '0'); convert_element_type_226 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + mm_47 = torch.ops.aten.mm.default(view_492, permute_75) + view_500 = torch.ops.aten.view.default(mm_47, [2, 8192, 1792]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_500) + view_507 = torch.ops.aten.view.default(mul_55, [16384, 1792]); mul_55 = None + mm_241 = torch.ops.aten.mm.default(permute_469, view_507); permute_469 = view_507 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 8, '0'); convert_element_type_229 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_91, [1, 0]); wait_tensor_91 = None + permute_471 = torch.ops.aten.permute.default(permute_76, [1, 0]); permute_76 = None + mm_242 = torch.ops.aten.mm.default(view_1387, permute_471); view_1387 = permute_471 = None + view_1388 = torch.ops.aten.view.default(mm_242, [2, 8192, 1792]); mm_242 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(mm_241, torch.float32); mm_241 = None + reduce_scatter_tensor_135 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1036, 'avg', 8, '0'); convert_element_type_1036 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_135); reduce_scatter_tensor_135 = None + mul_316 = torch.ops.aten.mul.Tensor(view_1388, convert_element_type_225); convert_element_type_225 = None + mul_317 = torch.ops.aten.mul.Tensor(view_1388, view_500); view_1388 = view_500 = None + view_1389 = torch.ops.aten.view.default(mul_316, [16384, 1792]); mul_316 = None + permute_473 = torch.ops.aten.permute.default(view_1389, [1, 0]) + mm_243 = torch.ops.aten.mm.default(permute_473, view_492); permute_473 = None + permute_475 = torch.ops.aten.permute.default(permute_75, [1, 0]); permute_75 = None + mm_244 = torch.ops.aten.mm.default(view_1389, permute_475); view_1389 = permute_475 = None + view_1390 = torch.ops.aten.view.default(mm_244, [2, 8192, 4096]); mm_244 = None + convert_element_type_1041 = torch.ops.prims.convert_element_type.default(mm_243, torch.float32); mm_243 = None + reduce_scatter_tensor_136 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1041, 'avg', 8, '0'); convert_element_type_1041 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_136); reduce_scatter_tensor_136 = None + convert_element_type_1042 = torch.ops.prims.convert_element_type.default(mul_317, torch.float32); mul_317 = None + neg_9 = torch.ops.aten.neg.default(convert_element_type_224) + exp_9 = torch.ops.aten.exp.default(neg_9); neg_9 = None + add_128 = torch.ops.aten.add.Tensor(exp_9, 1); exp_9 = None + reciprocal_9 = torch.ops.aten.reciprocal.default(add_128); add_128 = None + mul_318 = torch.ops.aten.mul.Tensor(reciprocal_9, 1); reciprocal_9 = None + mul_319 = torch.ops.aten.mul.Tensor(convert_element_type_1042, mul_318); convert_element_type_1042 = None + sub_29 = torch.ops.aten.sub.Tensor(1, mul_318); mul_318 = None + mul_320 = torch.ops.aten.mul.Tensor(convert_element_type_224, sub_29); convert_element_type_224 = sub_29 = None + add_129 = torch.ops.aten.add.Tensor(mul_320, 1); mul_320 = None + mul_321 = torch.ops.aten.mul.Tensor(mul_319, add_129); mul_319 = add_129 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(mul_321, torch.bfloat16); mul_321 = None + view_1391 = torch.ops.aten.view.default(convert_element_type_1044, [16384, 1792]); convert_element_type_1044 = None + permute_477 = torch.ops.aten.permute.default(view_1391, [1, 0]) + mm_245 = torch.ops.aten.mm.default(permute_477, view_492); permute_477 = view_492 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 8, '0'); convert_element_type_221 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + permute_479 = torch.ops.aten.permute.default(permute_74, [1, 0]); permute_74 = None + mm_246 = torch.ops.aten.mm.default(view_1391, permute_479); view_1391 = permute_479 = None + view_1392 = torch.ops.aten.view.default(mm_246, [2, 8192, 4096]); mm_246 = None + add_130 = torch.ops.aten.add.Tensor(view_1390, view_1392); view_1390 = view_1392 = None + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(mm_245, torch.float32); mm_245 = None + reduce_scatter_tensor_137 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1049, 'avg', 8, '0'); convert_element_type_1049 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_137); reduce_scatter_tensor_137 = None + split_112 = torch.ops.aten.split.Tensor(add_130, 1024, 1); add_130 = None + getitem_1067 = split_112[0] + getitem_1068 = split_112[1] + getitem_1069 = split_112[2] + getitem_1070 = split_112[3] + getitem_1071 = split_112[4] + getitem_1072 = split_112[5] + getitem_1073 = split_112[6] + getitem_1074 = split_112[7]; split_112 = None + cat_104 = torch.ops.aten.cat.default([getitem_1067, getitem_1068, getitem_1069, getitem_1070, getitem_1071, getitem_1072, getitem_1073, getitem_1074]); getitem_1067 = getitem_1068 = getitem_1069 = getitem_1070 = getitem_1071 = getitem_1072 = getitem_1073 = getitem_1074 = None + reduce_scatter_tensor_138 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_104, 'sum', 8, '1'); cat_104 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_138); reduce_scatter_tensor_138 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(wait_tensor_356, torch.float32); wait_tensor_356 = None + convert_element_type_1052 = torch.ops.prims.convert_element_type.default(wait_tensor_87, torch.float32); wait_tensor_87 = None + mul_322 = torch.ops.aten.mul.Tensor(convert_element_type_1050, convert_element_type_1052); convert_element_type_1052 = None + mul_324 = torch.ops.aten.mul.Tensor(mul_52, mul_322) + sum_57 = torch.ops.aten.sum.dim_IntList(mul_324, [2], True); mul_324 = None + div_19 = torch.ops.aten.div.Tensor(mul_52, 4096) + mul_325 = torch.ops.aten.mul.Tensor(div_19, sum_57); div_19 = sum_57 = None + sub_30 = torch.ops.aten.sub.Tensor(mul_322, mul_325); mul_322 = mul_325 = None + mul_326 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_13); sub_30 = rsqrt_13 = None + mul_327 = torch.ops.aten.mul.Tensor(convert_element_type_1050, mul_52); convert_element_type_1050 = mul_52 = None + sum_58 = torch.ops.aten.sum.dim_IntList(mul_327, [0, 1]); mul_327 = None + convert_element_type_1053 = torch.ops.prims.convert_element_type.default(mul_326, torch.bfloat16); mul_326 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(sum_58, torch.bfloat16); sum_58 = None + all_reduce_19 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1054, 'sum', '1'); convert_element_type_1054 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_19); all_reduce_19 = None + convert_element_type_1055 = torch.ops.prims.convert_element_type.default(wait_tensor_357, torch.float32); wait_tensor_357 = None + reduce_scatter_tensor_139 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1055, 'avg', 8, '0'); convert_element_type_1055 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_139); reduce_scatter_tensor_139 = None + add_131 = torch.ops.aten.add.Tensor(add_127, convert_element_type_1053); add_127 = convert_element_type_1053 = None + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_131, 8, '1') + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + split_113 = torch.ops.aten.split.Tensor(wait_tensor_359, 2); wait_tensor_359 = None + getitem_1075 = split_113[0] + getitem_1076 = split_113[1] + getitem_1077 = split_113[2] + getitem_1078 = split_113[3] + getitem_1079 = split_113[4] + getitem_1080 = split_113[5] + getitem_1081 = split_113[6] + getitem_1082 = split_113[7]; split_113 = None + cat_105 = torch.ops.aten.cat.default([getitem_1075, getitem_1076, getitem_1077, getitem_1078, getitem_1079, getitem_1080, getitem_1081, getitem_1082], 1); getitem_1075 = getitem_1076 = getitem_1077 = getitem_1078 = getitem_1079 = getitem_1080 = getitem_1081 = getitem_1082 = None + view_1393 = torch.ops.aten.view.default(cat_105, [16384, 4096]); cat_105 = None + permute_481 = torch.ops.aten.permute.default(view_1393, [1, 0]) + permute_72 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]) + view_474 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + view_480 = torch.ops.aten.view.default(view_474, [16384, 512]); view_474 = None + mm_247 = torch.ops.aten.mm.default(permute_481, view_480); permute_481 = view_480 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16); primals_62 = None + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 8, '0'); convert_element_type_215 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + permute_483 = torch.ops.aten.permute.default(permute_73, [1, 0]); permute_73 = None + mm_248 = torch.ops.aten.mm.default(view_1393, permute_483); view_1393 = permute_483 = None + view_1394 = torch.ops.aten.view.default(mm_248, [2, 8192, 512]); mm_248 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(mm_247, torch.float32); mm_247 = None + reduce_scatter_tensor_140 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1060, 'avg', 8, '0'); convert_element_type_1060 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_140); reduce_scatter_tensor_140 = None + view_1395 = torch.ops.aten.view.default(view_1394, [2, 8192, 4, 128]); view_1394 = None + permute_485 = torch.ops.aten.permute.default(view_1395, [0, 2, 1, 3]); view_1395 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16); primals_58 = None + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 8, '0'); convert_element_type_199 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32); add_23 = None + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_80) + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_201, 8, '1'); convert_element_type_201 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + split_33 = torch.ops.aten.split.Tensor(wait_tensor_81, 2); wait_tensor_81 = None + getitem_318 = split_33[0] + getitem_319 = split_33[1] + getitem_320 = split_33[2] + getitem_321 = split_33[3] + getitem_322 = split_33[4] + getitem_323 = split_33[5] + getitem_324 = split_33[6] + getitem_325 = split_33[7]; split_33 = None + cat_25 = torch.ops.aten.cat.default([getitem_318, getitem_319, getitem_320, getitem_321, getitem_322, getitem_323, getitem_324, getitem_325], 1); getitem_318 = getitem_319 = getitem_320 = getitem_321 = getitem_322 = getitem_323 = getitem_324 = getitem_325 = None + view_447 = torch.ops.aten.view.default(cat_25, [16384, 4096]); cat_25 = None + view_448 = torch.ops.aten.view.default(mm_42, [2, 8192, 512]); mm_42 = None + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16); primals_60 = None + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 8, '0'); convert_element_type_205 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + mm_43 = torch.ops.aten.mm.default(view_447, permute_67) + view_455 = torch.ops.aten.view.default(mm_43, [2, 8192, 128]); mm_43 = None + view_462 = torch.ops.aten.view.default(mm_44, [2, 8192, 128]); mm_44 = None + view_464 = torch.ops.aten.view.default(view_448, [2, 8192, -1, 128]); view_448 = None + view_465 = torch.ops.aten.view.default(view_455, [2, 8192, -1, 128]); view_455 = None + view_466 = torch.ops.aten.view.default(view_462, [2, 8192, -1, 128]); view_462 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_464, torch.float32); view_464 = None + view_467 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 4, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_467); view_467 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_465, torch.float32); view_465 = None + view_468 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 1, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_468); view_468 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_37); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_470 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 4, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_37); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_471 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 1, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_470, torch.bfloat16); view_470 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_471, torch.bfloat16); view_471 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 1, 4, 128]); unsqueeze_12 = None + view_472 = torch.ops.aten.view.default(expand_12, [2, 8192, 4, 128]); expand_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_466, 3); view_466 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 1, 4, 128]); unsqueeze_13 = None + view_473 = torch.ops.aten.view.default(expand_13, [2, 8192, 4, 128]); expand_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_472, [0, 2, 1, 3]); view_472 = None + permute_71 = torch.ops.aten.permute.default(view_473, [0, 2, 1, 3]); view_473 = None + _scaled_dot_product_cudnn_attention_backward_9 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_485, permute_69, permute_70, permute_71, getitem_326, getitem_327, getitem_332, getitem_333, None, None, None, 8192, 8192, 0.0, True); permute_485 = permute_69 = permute_70 = permute_71 = getitem_326 = getitem_327 = getitem_332 = getitem_333 = None + getitem_1083 = _scaled_dot_product_cudnn_attention_backward_9[0] + getitem_1084 = _scaled_dot_product_cudnn_attention_backward_9[1] + getitem_1085 = _scaled_dot_product_cudnn_attention_backward_9[2]; _scaled_dot_product_cudnn_attention_backward_9 = None + permute_486 = torch.ops.aten.permute.default(getitem_1085, [0, 2, 1, 3]); getitem_1085 = None + permute_487 = torch.ops.aten.permute.default(getitem_1084, [0, 2, 1, 3]); getitem_1084 = None + permute_488 = torch.ops.aten.permute.default(getitem_1083, [0, 2, 1, 3]); getitem_1083 = None + view_1396 = torch.ops.aten.view.default(permute_486, [2, 8192, 1, 4, 128]); permute_486 = None + sum_59 = torch.ops.aten.sum.dim_IntList(view_1396, [3], True); view_1396 = None + squeeze_18 = torch.ops.aten.squeeze.dim(sum_59, 3); sum_59 = None + view_1397 = torch.ops.aten.view.default(permute_487, [2, 8192, 1, 4, 128]); permute_487 = None + sum_60 = torch.ops.aten.sum.dim_IntList(view_1397, [3], True); view_1397 = None + squeeze_19 = torch.ops.aten.squeeze.dim(sum_60, 3); sum_60 = None + convert_element_type_1061 = torch.ops.prims.convert_element_type.default(squeeze_19, torch.float32); squeeze_19 = None + convert_element_type_1062 = torch.ops.prims.convert_element_type.default(permute_488, torch.float32); permute_488 = None + view_1398 = torch.ops.aten.view.default(convert_element_type_1061, [2, 8192, 1, 64, 2]); convert_element_type_1061 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1398); view_1398 = None + mul_328 = torch.ops.aten.mul.Tensor(view_as_complex_50, _conj); view_as_complex_50 = None + view_1399 = torch.ops.aten.view.default(convert_element_type_1062, [2, 8192, 4, 64, 2]); convert_element_type_1062 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1399); view_1399 = None + mul_329 = torch.ops.aten.mul.Tensor(view_as_complex_51, _conj); view_as_complex_51 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_328); mul_328 = None + view_1400 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 1, 128]); view_as_real_50 = None + convert_element_type_1063 = torch.ops.prims.convert_element_type.default(view_1400, torch.bfloat16); view_1400 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_329); mul_329 = None + view_1401 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 4, 128]); view_as_real_51 = None + convert_element_type_1064 = torch.ops.prims.convert_element_type.default(view_1401, torch.bfloat16); view_1401 = None + view_1402 = torch.ops.aten.view.default(squeeze_18, [2, 8192, 128]); squeeze_18 = None + view_1403 = torch.ops.aten.view.default(convert_element_type_1063, [2, 8192, 128]); convert_element_type_1063 = None + view_1404 = torch.ops.aten.view.default(convert_element_type_1064, [2, 8192, 512]); convert_element_type_1064 = None + view_1405 = torch.ops.aten.view.default(view_1402, [16384, 128]); view_1402 = None + permute_489 = torch.ops.aten.permute.default(view_1405, [1, 0]) + mm_249 = torch.ops.aten.mm.default(permute_489, view_447); permute_489 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16); primals_61 = None + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 8, '0'); convert_element_type_208 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + permute_491 = torch.ops.aten.permute.default(permute_68, [1, 0]); permute_68 = None + mm_250 = torch.ops.aten.mm.default(view_1405, permute_491); view_1405 = permute_491 = None + view_1406 = torch.ops.aten.view.default(mm_250, [2, 8192, 4096]); mm_250 = None + convert_element_type_1069 = torch.ops.prims.convert_element_type.default(mm_249, torch.float32); mm_249 = None + reduce_scatter_tensor_141 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1069, 'avg', 8, '0'); convert_element_type_1069 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_141); reduce_scatter_tensor_141 = None + view_1407 = torch.ops.aten.view.default(view_1403, [16384, 128]); view_1403 = None + permute_493 = torch.ops.aten.permute.default(view_1407, [1, 0]) + mm_251 = torch.ops.aten.mm.default(permute_493, view_447); permute_493 = None + permute_495 = torch.ops.aten.permute.default(permute_67, [1, 0]); permute_67 = None + mm_252 = torch.ops.aten.mm.default(view_1407, permute_495); view_1407 = permute_495 = None + view_1408 = torch.ops.aten.view.default(mm_252, [2, 8192, 4096]); mm_252 = None + add_132 = torch.ops.aten.add.Tensor(view_1406, view_1408); view_1406 = view_1408 = None + convert_element_type_1074 = torch.ops.prims.convert_element_type.default(mm_251, torch.float32); mm_251 = None + reduce_scatter_tensor_142 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1074, 'avg', 8, '0'); convert_element_type_1074 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_142); reduce_scatter_tensor_142 = None + view_1409 = torch.ops.aten.view.default(view_1404, [16384, 512]); view_1404 = None + permute_497 = torch.ops.aten.permute.default(view_1409, [1, 0]) + mm_253 = torch.ops.aten.mm.default(permute_497, view_447); permute_497 = view_447 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16); primals_59 = None + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 8, '0'); convert_element_type_202 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_82, [1, 0]); wait_tensor_82 = None + permute_499 = torch.ops.aten.permute.default(permute_66, [1, 0]); permute_66 = None + mm_254 = torch.ops.aten.mm.default(view_1409, permute_499); view_1409 = permute_499 = None + view_1410 = torch.ops.aten.view.default(mm_254, [2, 8192, 4096]); mm_254 = None + add_133 = torch.ops.aten.add.Tensor(add_132, view_1410); add_132 = view_1410 = None + convert_element_type_1079 = torch.ops.prims.convert_element_type.default(mm_253, torch.float32); mm_253 = None + reduce_scatter_tensor_143 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1079, 'avg', 8, '0'); convert_element_type_1079 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_143); reduce_scatter_tensor_143 = None + split_114 = torch.ops.aten.split.Tensor(add_133, 1024, 1); add_133 = None + getitem_1086 = split_114[0] + getitem_1087 = split_114[1] + getitem_1088 = split_114[2] + getitem_1089 = split_114[3] + getitem_1090 = split_114[4] + getitem_1091 = split_114[5] + getitem_1092 = split_114[6] + getitem_1093 = split_114[7]; split_114 = None + cat_106 = torch.ops.aten.cat.default([getitem_1086, getitem_1087, getitem_1088, getitem_1089, getitem_1090, getitem_1091, getitem_1092, getitem_1093]); getitem_1086 = getitem_1087 = getitem_1088 = getitem_1089 = getitem_1090 = getitem_1091 = getitem_1092 = getitem_1093 = None + reduce_scatter_tensor_144 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_106, 'sum', 8, '1'); cat_106 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_144); reduce_scatter_tensor_144 = None + convert_element_type_1080 = torch.ops.prims.convert_element_type.default(wait_tensor_364, torch.float32); wait_tensor_364 = None + convert_element_type_1082 = torch.ops.prims.convert_element_type.default(wait_tensor_80, torch.float32); wait_tensor_80 = None + mul_330 = torch.ops.aten.mul.Tensor(convert_element_type_1080, convert_element_type_1082); convert_element_type_1082 = None + mul_332 = torch.ops.aten.mul.Tensor(mul_48, mul_330) + sum_61 = torch.ops.aten.sum.dim_IntList(mul_332, [2], True); mul_332 = None + div_20 = torch.ops.aten.div.Tensor(mul_48, 4096) + mul_333 = torch.ops.aten.mul.Tensor(div_20, sum_61); div_20 = sum_61 = None + sub_31 = torch.ops.aten.sub.Tensor(mul_330, mul_333); mul_330 = mul_333 = None + mul_334 = torch.ops.aten.mul.Tensor(sub_31, rsqrt_12); sub_31 = rsqrt_12 = None + mul_335 = torch.ops.aten.mul.Tensor(convert_element_type_1080, mul_48); convert_element_type_1080 = mul_48 = None + sum_62 = torch.ops.aten.sum.dim_IntList(mul_335, [0, 1]); mul_335 = None + convert_element_type_1083 = torch.ops.prims.convert_element_type.default(mul_334, torch.bfloat16); mul_334 = None + convert_element_type_1084 = torch.ops.prims.convert_element_type.default(sum_62, torch.bfloat16); sum_62 = None + all_reduce_20 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1084, 'sum', '1'); convert_element_type_1084 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_20); all_reduce_20 = None + convert_element_type_1085 = torch.ops.prims.convert_element_type.default(wait_tensor_365, torch.float32); wait_tensor_365 = None + reduce_scatter_tensor_145 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1085, 'avg', 8, '0'); convert_element_type_1085 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_145); reduce_scatter_tensor_145 = None + add_134 = torch.ops.aten.add.Tensor(add_131, convert_element_type_1083); add_131 = convert_element_type_1083 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_134, 8, '1') + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + split_115 = torch.ops.aten.split.Tensor(wait_tensor_367, 2); wait_tensor_367 = None + getitem_1094 = split_115[0] + getitem_1095 = split_115[1] + getitem_1096 = split_115[2] + getitem_1097 = split_115[3] + getitem_1098 = split_115[4] + getitem_1099 = split_115[5] + getitem_1100 = split_115[6] + getitem_1101 = split_115[7]; split_115 = None + cat_107 = torch.ops.aten.cat.default([getitem_1094, getitem_1095, getitem_1096, getitem_1097, getitem_1098, getitem_1099, getitem_1100, getitem_1101], 1); getitem_1094 = getitem_1095 = getitem_1096 = getitem_1097 = getitem_1098 = getitem_1099 = getitem_1100 = getitem_1101 = None + view_1411 = torch.ops.aten.view.default(cat_107, [16384, 4096]); cat_107 = None + permute_501 = torch.ops.aten.permute.default(view_1411, [1, 0]) + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11); reduce_scatter_tensor_11 = None + add_21 = torch.ops.aten.add.Tensor(add_19, wait_tensor_73); wait_tensor_73 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 8, '0'); convert_element_type_185 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32); add_21 = None + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_74) + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_187, 8, '1'); convert_element_type_187 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + split_31 = torch.ops.aten.split.Tensor(wait_tensor_75, 2); wait_tensor_75 = None + getitem_302 = split_31[0] + getitem_303 = split_31[1] + getitem_304 = split_31[2] + getitem_305 = split_31[3] + getitem_306 = split_31[4] + getitem_307 = split_31[5] + getitem_308 = split_31[6] + getitem_309 = split_31[7]; split_31 = None + cat_23 = torch.ops.aten.cat.default([getitem_302, getitem_303, getitem_304, getitem_305, getitem_306, getitem_307, getitem_308, getitem_309], 1); getitem_302 = getitem_303 = getitem_304 = getitem_305 = getitem_306 = getitem_307 = getitem_308 = getitem_309 = None + view_420 = torch.ops.aten.view.default(cat_23, [16384, 4096]); cat_23 = None + view_421 = torch.ops.aten.view.default(mm_39, [2, 8192, 1792]); mm_39 = None + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_421, torch.float32); view_421 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16); primals_56 = None + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 8, '0'); convert_element_type_193 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + mm_40 = torch.ops.aten.mm.default(view_420, permute_64) + view_428 = torch.ops.aten.view.default(mm_40, [2, 8192, 1792]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_428) + view_435 = torch.ops.aten.view.default(mul_47, [16384, 1792]); mul_47 = None + mm_255 = torch.ops.aten.mm.default(permute_501, view_435); permute_501 = view_435 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16); primals_57 = None + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 8, '0'); convert_element_type_196 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + permute_503 = torch.ops.aten.permute.default(permute_65, [1, 0]); permute_65 = None + mm_256 = torch.ops.aten.mm.default(view_1411, permute_503); view_1411 = permute_503 = None + view_1412 = torch.ops.aten.view.default(mm_256, [2, 8192, 1792]); mm_256 = None + convert_element_type_1090 = torch.ops.prims.convert_element_type.default(mm_255, torch.float32); mm_255 = None + reduce_scatter_tensor_146 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1090, 'avg', 8, '0'); convert_element_type_1090 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_146); reduce_scatter_tensor_146 = None + mul_336 = torch.ops.aten.mul.Tensor(view_1412, convert_element_type_192); convert_element_type_192 = None + mul_337 = torch.ops.aten.mul.Tensor(view_1412, view_428); view_1412 = view_428 = None + view_1413 = torch.ops.aten.view.default(mul_336, [16384, 1792]); mul_336 = None + permute_505 = torch.ops.aten.permute.default(view_1413, [1, 0]) + mm_257 = torch.ops.aten.mm.default(permute_505, view_420); permute_505 = None + permute_507 = torch.ops.aten.permute.default(permute_64, [1, 0]); permute_64 = None + mm_258 = torch.ops.aten.mm.default(view_1413, permute_507); view_1413 = permute_507 = None + view_1414 = torch.ops.aten.view.default(mm_258, [2, 8192, 4096]); mm_258 = None + convert_element_type_1095 = torch.ops.prims.convert_element_type.default(mm_257, torch.float32); mm_257 = None + reduce_scatter_tensor_147 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1095, 'avg', 8, '0'); convert_element_type_1095 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_147); reduce_scatter_tensor_147 = None + convert_element_type_1096 = torch.ops.prims.convert_element_type.default(mul_337, torch.float32); mul_337 = None + neg_10 = torch.ops.aten.neg.default(convert_element_type_191) + exp_10 = torch.ops.aten.exp.default(neg_10); neg_10 = None + add_135 = torch.ops.aten.add.Tensor(exp_10, 1); exp_10 = None + reciprocal_10 = torch.ops.aten.reciprocal.default(add_135); add_135 = None + mul_338 = torch.ops.aten.mul.Tensor(reciprocal_10, 1); reciprocal_10 = None + mul_339 = torch.ops.aten.mul.Tensor(convert_element_type_1096, mul_338); convert_element_type_1096 = None + sub_32 = torch.ops.aten.sub.Tensor(1, mul_338); mul_338 = None + mul_340 = torch.ops.aten.mul.Tensor(convert_element_type_191, sub_32); convert_element_type_191 = sub_32 = None + add_136 = torch.ops.aten.add.Tensor(mul_340, 1); mul_340 = None + mul_341 = torch.ops.aten.mul.Tensor(mul_339, add_136); mul_339 = add_136 = None + convert_element_type_1098 = torch.ops.prims.convert_element_type.default(mul_341, torch.bfloat16); mul_341 = None + view_1415 = torch.ops.aten.view.default(convert_element_type_1098, [16384, 1792]); convert_element_type_1098 = None + permute_509 = torch.ops.aten.permute.default(view_1415, [1, 0]) + mm_259 = torch.ops.aten.mm.default(permute_509, view_420); permute_509 = view_420 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16); primals_55 = None + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 8, '0'); convert_element_type_188 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + permute_511 = torch.ops.aten.permute.default(permute_63, [1, 0]); permute_63 = None + mm_260 = torch.ops.aten.mm.default(view_1415, permute_511); view_1415 = permute_511 = None + view_1416 = torch.ops.aten.view.default(mm_260, [2, 8192, 4096]); mm_260 = None + add_137 = torch.ops.aten.add.Tensor(view_1414, view_1416); view_1414 = view_1416 = None + convert_element_type_1103 = torch.ops.prims.convert_element_type.default(mm_259, torch.float32); mm_259 = None + reduce_scatter_tensor_148 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1103, 'avg', 8, '0'); convert_element_type_1103 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_148); reduce_scatter_tensor_148 = None + split_116 = torch.ops.aten.split.Tensor(add_137, 1024, 1); add_137 = None + getitem_1102 = split_116[0] + getitem_1103 = split_116[1] + getitem_1104 = split_116[2] + getitem_1105 = split_116[3] + getitem_1106 = split_116[4] + getitem_1107 = split_116[5] + getitem_1108 = split_116[6] + getitem_1109 = split_116[7]; split_116 = None + cat_108 = torch.ops.aten.cat.default([getitem_1102, getitem_1103, getitem_1104, getitem_1105, getitem_1106, getitem_1107, getitem_1108, getitem_1109]); getitem_1102 = getitem_1103 = getitem_1104 = getitem_1105 = getitem_1106 = getitem_1107 = getitem_1108 = getitem_1109 = None + reduce_scatter_tensor_149 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_108, 'sum', 8, '1'); cat_108 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_149); reduce_scatter_tensor_149 = None + convert_element_type_1104 = torch.ops.prims.convert_element_type.default(wait_tensor_371, torch.float32); wait_tensor_371 = None + convert_element_type_1106 = torch.ops.prims.convert_element_type.default(wait_tensor_74, torch.float32); wait_tensor_74 = None + mul_342 = torch.ops.aten.mul.Tensor(convert_element_type_1104, convert_element_type_1106); convert_element_type_1106 = None + mul_344 = torch.ops.aten.mul.Tensor(mul_44, mul_342) + sum_63 = torch.ops.aten.sum.dim_IntList(mul_344, [2], True); mul_344 = None + div_21 = torch.ops.aten.div.Tensor(mul_44, 4096) + mul_345 = torch.ops.aten.mul.Tensor(div_21, sum_63); div_21 = sum_63 = None + sub_33 = torch.ops.aten.sub.Tensor(mul_342, mul_345); mul_342 = mul_345 = None + mul_346 = torch.ops.aten.mul.Tensor(sub_33, rsqrt_11); sub_33 = rsqrt_11 = None + mul_347 = torch.ops.aten.mul.Tensor(convert_element_type_1104, mul_44); convert_element_type_1104 = mul_44 = None + sum_64 = torch.ops.aten.sum.dim_IntList(mul_347, [0, 1]); mul_347 = None + convert_element_type_1107 = torch.ops.prims.convert_element_type.default(mul_346, torch.bfloat16); mul_346 = None + convert_element_type_1108 = torch.ops.prims.convert_element_type.default(sum_64, torch.bfloat16); sum_64 = None + all_reduce_21 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1108, 'sum', '1'); convert_element_type_1108 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_21); all_reduce_21 = None + convert_element_type_1109 = torch.ops.prims.convert_element_type.default(wait_tensor_372, torch.float32); wait_tensor_372 = None + reduce_scatter_tensor_150 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1109, 'avg', 8, '0'); convert_element_type_1109 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_150); reduce_scatter_tensor_150 = None + add_138 = torch.ops.aten.add.Tensor(add_134, convert_element_type_1107); add_134 = convert_element_type_1107 = None + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_138, 8, '1') + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + split_117 = torch.ops.aten.split.Tensor(wait_tensor_374, 2); wait_tensor_374 = None + getitem_1110 = split_117[0] + getitem_1111 = split_117[1] + getitem_1112 = split_117[2] + getitem_1113 = split_117[3] + getitem_1114 = split_117[4] + getitem_1115 = split_117[5] + getitem_1116 = split_117[6] + getitem_1117 = split_117[7]; split_117 = None + cat_109 = torch.ops.aten.cat.default([getitem_1110, getitem_1111, getitem_1112, getitem_1113, getitem_1114, getitem_1115, getitem_1116, getitem_1117], 1); getitem_1110 = getitem_1111 = getitem_1112 = getitem_1113 = getitem_1114 = getitem_1115 = getitem_1116 = getitem_1117 = None + view_1417 = torch.ops.aten.view.default(cat_109, [16384, 4096]); cat_109 = None + permute_513 = torch.ops.aten.permute.default(view_1417, [1, 0]) + permute_61 = torch.ops.aten.permute.default(getitem_285, [0, 2, 1, 3]) + view_402 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + view_408 = torch.ops.aten.view.default(view_402, [16384, 512]); view_402 = None + mm_261 = torch.ops.aten.mm.default(permute_513, view_408); permute_513 = view_408 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 8, '0'); convert_element_type_182 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + permute_515 = torch.ops.aten.permute.default(permute_62, [1, 0]); permute_62 = None + mm_262 = torch.ops.aten.mm.default(view_1417, permute_515); view_1417 = permute_515 = None + view_1418 = torch.ops.aten.view.default(mm_262, [2, 8192, 512]); mm_262 = None + convert_element_type_1114 = torch.ops.prims.convert_element_type.default(mm_261, torch.float32); mm_261 = None + reduce_scatter_tensor_151 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1114, 'avg', 8, '0'); convert_element_type_1114 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_151); reduce_scatter_tensor_151 = None + view_1419 = torch.ops.aten.view.default(view_1418, [2, 8192, 4, 128]); view_1418 = None + permute_517 = torch.ops.aten.permute.default(view_1419, [0, 2, 1, 3]); view_1419 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 8, '0'); convert_element_type_166 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32); add_19 = None + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_67) + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_168, 8, '1'); convert_element_type_168 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + split_29 = torch.ops.aten.split.Tensor(wait_tensor_68, 2); wait_tensor_68 = None + getitem_277 = split_29[0] + getitem_278 = split_29[1] + getitem_279 = split_29[2] + getitem_280 = split_29[3] + getitem_281 = split_29[4] + getitem_282 = split_29[5] + getitem_283 = split_29[6] + getitem_284 = split_29[7]; split_29 = None + cat_21 = torch.ops.aten.cat.default([getitem_277, getitem_278, getitem_279, getitem_280, getitem_281, getitem_282, getitem_283, getitem_284], 1); getitem_277 = getitem_278 = getitem_279 = getitem_280 = getitem_281 = getitem_282 = getitem_283 = getitem_284 = None + view_375 = torch.ops.aten.view.default(cat_21, [16384, 4096]); cat_21 = None + view_376 = torch.ops.aten.view.default(mm_35, [2, 8192, 512]); mm_35 = None + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 8, '0'); convert_element_type_172 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + mm_36 = torch.ops.aten.mm.default(view_375, permute_56) + view_383 = torch.ops.aten.view.default(mm_36, [2, 8192, 128]); mm_36 = None + view_390 = torch.ops.aten.view.default(mm_37, [2, 8192, 128]); mm_37 = None + view_392 = torch.ops.aten.view.default(view_376, [2, 8192, -1, 128]); view_376 = None + view_393 = torch.ops.aten.view.default(view_383, [2, 8192, -1, 128]); view_383 = None + view_394 = torch.ops.aten.view.default(view_390, [2, 8192, -1, 128]); view_390 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_392, torch.float32); view_392 = None + view_395 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 4, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_395); view_395 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_393, torch.float32); view_393 = None + view_396 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 1, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_396); view_396 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_37); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_398 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 4, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_37); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_399 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 1, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_398, torch.bfloat16); view_398 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_399, torch.bfloat16); view_399 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 1, 4, 128]); unsqueeze_10 = None + view_400 = torch.ops.aten.view.default(expand_10, [2, 8192, 4, 128]); expand_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_394, 3); view_394 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 1, 4, 128]); unsqueeze_11 = None + view_401 = torch.ops.aten.view.default(expand_11, [2, 8192, 4, 128]); expand_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_400, [0, 2, 1, 3]); view_400 = None + permute_60 = torch.ops.aten.permute.default(view_401, [0, 2, 1, 3]); view_401 = None + _scaled_dot_product_cudnn_attention_backward_10 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_517, permute_58, permute_59, permute_60, getitem_285, getitem_286, getitem_291, getitem_292, None, None, None, 8192, 8192, 0.0, True); permute_517 = permute_58 = permute_59 = permute_60 = getitem_285 = getitem_286 = getitem_291 = getitem_292 = None + getitem_1118 = _scaled_dot_product_cudnn_attention_backward_10[0] + getitem_1119 = _scaled_dot_product_cudnn_attention_backward_10[1] + getitem_1120 = _scaled_dot_product_cudnn_attention_backward_10[2]; _scaled_dot_product_cudnn_attention_backward_10 = None + permute_518 = torch.ops.aten.permute.default(getitem_1120, [0, 2, 1, 3]); getitem_1120 = None + permute_519 = torch.ops.aten.permute.default(getitem_1119, [0, 2, 1, 3]); getitem_1119 = None + permute_520 = torch.ops.aten.permute.default(getitem_1118, [0, 2, 1, 3]); getitem_1118 = None + view_1420 = torch.ops.aten.view.default(permute_518, [2, 8192, 1, 4, 128]); permute_518 = None + sum_65 = torch.ops.aten.sum.dim_IntList(view_1420, [3], True); view_1420 = None + squeeze_20 = torch.ops.aten.squeeze.dim(sum_65, 3); sum_65 = None + view_1421 = torch.ops.aten.view.default(permute_519, [2, 8192, 1, 4, 128]); permute_519 = None + sum_66 = torch.ops.aten.sum.dim_IntList(view_1421, [3], True); view_1421 = None + squeeze_21 = torch.ops.aten.squeeze.dim(sum_66, 3); sum_66 = None + convert_element_type_1115 = torch.ops.prims.convert_element_type.default(squeeze_21, torch.float32); squeeze_21 = None + convert_element_type_1116 = torch.ops.prims.convert_element_type.default(permute_520, torch.float32); permute_520 = None + view_1422 = torch.ops.aten.view.default(convert_element_type_1115, [2, 8192, 1, 64, 2]); convert_element_type_1115 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_1422); view_1422 = None + mul_348 = torch.ops.aten.mul.Tensor(view_as_complex_52, _conj); view_as_complex_52 = None + view_1423 = torch.ops.aten.view.default(convert_element_type_1116, [2, 8192, 4, 64, 2]); convert_element_type_1116 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1423); view_1423 = None + mul_349 = torch.ops.aten.mul.Tensor(view_as_complex_53, _conj); view_as_complex_53 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_348); mul_348 = None + view_1424 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 1, 128]); view_as_real_52 = None + convert_element_type_1117 = torch.ops.prims.convert_element_type.default(view_1424, torch.bfloat16); view_1424 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_349); mul_349 = None + view_1425 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 4, 128]); view_as_real_53 = None + convert_element_type_1118 = torch.ops.prims.convert_element_type.default(view_1425, torch.bfloat16); view_1425 = None + view_1426 = torch.ops.aten.view.default(squeeze_20, [2, 8192, 128]); squeeze_20 = None + view_1427 = torch.ops.aten.view.default(convert_element_type_1117, [2, 8192, 128]); convert_element_type_1117 = None + view_1428 = torch.ops.aten.view.default(convert_element_type_1118, [2, 8192, 512]); convert_element_type_1118 = None + view_1429 = torch.ops.aten.view.default(view_1426, [16384, 128]); view_1426 = None + permute_521 = torch.ops.aten.permute.default(view_1429, [1, 0]) + mm_263 = torch.ops.aten.mm.default(permute_521, view_375); permute_521 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 8, '0'); convert_element_type_175 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + permute_523 = torch.ops.aten.permute.default(permute_57, [1, 0]); permute_57 = None + mm_264 = torch.ops.aten.mm.default(view_1429, permute_523); view_1429 = permute_523 = None + view_1430 = torch.ops.aten.view.default(mm_264, [2, 8192, 4096]); mm_264 = None + convert_element_type_1123 = torch.ops.prims.convert_element_type.default(mm_263, torch.float32); mm_263 = None + reduce_scatter_tensor_152 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1123, 'avg', 8, '0'); convert_element_type_1123 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_152); reduce_scatter_tensor_152 = None + view_1431 = torch.ops.aten.view.default(view_1427, [16384, 128]); view_1427 = None + permute_525 = torch.ops.aten.permute.default(view_1431, [1, 0]) + mm_265 = torch.ops.aten.mm.default(permute_525, view_375); permute_525 = None + permute_527 = torch.ops.aten.permute.default(permute_56, [1, 0]); permute_56 = None + mm_266 = torch.ops.aten.mm.default(view_1431, permute_527); view_1431 = permute_527 = None + view_1432 = torch.ops.aten.view.default(mm_266, [2, 8192, 4096]); mm_266 = None + add_139 = torch.ops.aten.add.Tensor(view_1430, view_1432); view_1430 = view_1432 = None + convert_element_type_1128 = torch.ops.prims.convert_element_type.default(mm_265, torch.float32); mm_265 = None + reduce_scatter_tensor_153 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1128, 'avg', 8, '0'); convert_element_type_1128 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_153); reduce_scatter_tensor_153 = None + view_1433 = torch.ops.aten.view.default(view_1428, [16384, 512]); view_1428 = None + permute_529 = torch.ops.aten.permute.default(view_1433, [1, 0]) + mm_267 = torch.ops.aten.mm.default(permute_529, view_375); permute_529 = view_375 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 8, '0'); convert_element_type_169 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_69, [1, 0]); wait_tensor_69 = None + permute_531 = torch.ops.aten.permute.default(permute_55, [1, 0]); permute_55 = None + mm_268 = torch.ops.aten.mm.default(view_1433, permute_531); view_1433 = permute_531 = None + view_1434 = torch.ops.aten.view.default(mm_268, [2, 8192, 4096]); mm_268 = None + add_140 = torch.ops.aten.add.Tensor(add_139, view_1434); add_139 = view_1434 = None + convert_element_type_1133 = torch.ops.prims.convert_element_type.default(mm_267, torch.float32); mm_267 = None + reduce_scatter_tensor_154 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1133, 'avg', 8, '0'); convert_element_type_1133 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_154); reduce_scatter_tensor_154 = None + split_118 = torch.ops.aten.split.Tensor(add_140, 1024, 1); add_140 = None + getitem_1121 = split_118[0] + getitem_1122 = split_118[1] + getitem_1123 = split_118[2] + getitem_1124 = split_118[3] + getitem_1125 = split_118[4] + getitem_1126 = split_118[5] + getitem_1127 = split_118[6] + getitem_1128 = split_118[7]; split_118 = None + cat_110 = torch.ops.aten.cat.default([getitem_1121, getitem_1122, getitem_1123, getitem_1124, getitem_1125, getitem_1126, getitem_1127, getitem_1128]); getitem_1121 = getitem_1122 = getitem_1123 = getitem_1124 = getitem_1125 = getitem_1126 = getitem_1127 = getitem_1128 = None + reduce_scatter_tensor_155 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_110, 'sum', 8, '1'); cat_110 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_155); reduce_scatter_tensor_155 = None + convert_element_type_1134 = torch.ops.prims.convert_element_type.default(wait_tensor_379, torch.float32); wait_tensor_379 = None + convert_element_type_1136 = torch.ops.prims.convert_element_type.default(wait_tensor_67, torch.float32); wait_tensor_67 = None + mul_350 = torch.ops.aten.mul.Tensor(convert_element_type_1134, convert_element_type_1136); convert_element_type_1136 = None + mul_352 = torch.ops.aten.mul.Tensor(mul_40, mul_350) + sum_67 = torch.ops.aten.sum.dim_IntList(mul_352, [2], True); mul_352 = None + div_22 = torch.ops.aten.div.Tensor(mul_40, 4096) + mul_353 = torch.ops.aten.mul.Tensor(div_22, sum_67); div_22 = sum_67 = None + sub_34 = torch.ops.aten.sub.Tensor(mul_350, mul_353); mul_350 = mul_353 = None + mul_354 = torch.ops.aten.mul.Tensor(sub_34, rsqrt_10); sub_34 = rsqrt_10 = None + mul_355 = torch.ops.aten.mul.Tensor(convert_element_type_1134, mul_40); convert_element_type_1134 = mul_40 = None + sum_68 = torch.ops.aten.sum.dim_IntList(mul_355, [0, 1]); mul_355 = None + convert_element_type_1137 = torch.ops.prims.convert_element_type.default(mul_354, torch.bfloat16); mul_354 = None + convert_element_type_1138 = torch.ops.prims.convert_element_type.default(sum_68, torch.bfloat16); sum_68 = None + all_reduce_22 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1138, 'sum', '1'); convert_element_type_1138 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_22); all_reduce_22 = None + convert_element_type_1139 = torch.ops.prims.convert_element_type.default(wait_tensor_380, torch.float32); wait_tensor_380 = None + reduce_scatter_tensor_156 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1139, 'avg', 8, '0'); convert_element_type_1139 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_156); reduce_scatter_tensor_156 = None + add_141 = torch.ops.aten.add.Tensor(add_138, convert_element_type_1137); add_138 = convert_element_type_1137 = None + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_141, 8, '1') + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + split_119 = torch.ops.aten.split.Tensor(wait_tensor_382, 2); wait_tensor_382 = None + getitem_1129 = split_119[0] + getitem_1130 = split_119[1] + getitem_1131 = split_119[2] + getitem_1132 = split_119[3] + getitem_1133 = split_119[4] + getitem_1134 = split_119[5] + getitem_1135 = split_119[6] + getitem_1136 = split_119[7]; split_119 = None + cat_111 = torch.ops.aten.cat.default([getitem_1129, getitem_1130, getitem_1131, getitem_1132, getitem_1133, getitem_1134, getitem_1135, getitem_1136], 1); getitem_1129 = getitem_1130 = getitem_1131 = getitem_1132 = getitem_1133 = getitem_1134 = getitem_1135 = getitem_1136 = None + view_1435 = torch.ops.aten.view.default(cat_111, [16384, 4096]); cat_111 = None + permute_533 = torch.ops.aten.permute.default(view_1435, [1, 0]) + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9); reduce_scatter_tensor_9 = None + add_17 = torch.ops.aten.add.Tensor(add_15, wait_tensor_60); wait_tensor_60 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16); primals_45 = None + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 8, '0'); convert_element_type_152 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32); add_17 = None + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_61) + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_154, 8, '1'); convert_element_type_154 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + split_27 = torch.ops.aten.split.Tensor(wait_tensor_62, 2); wait_tensor_62 = None + getitem_261 = split_27[0] + getitem_262 = split_27[1] + getitem_263 = split_27[2] + getitem_264 = split_27[3] + getitem_265 = split_27[4] + getitem_266 = split_27[5] + getitem_267 = split_27[6] + getitem_268 = split_27[7]; split_27 = None + cat_19 = torch.ops.aten.cat.default([getitem_261, getitem_262, getitem_263, getitem_264, getitem_265, getitem_266, getitem_267, getitem_268], 1); getitem_261 = getitem_262 = getitem_263 = getitem_264 = getitem_265 = getitem_266 = getitem_267 = getitem_268 = None + view_348 = torch.ops.aten.view.default(cat_19, [16384, 4096]); cat_19 = None + view_349 = torch.ops.aten.view.default(mm_32, [2, 8192, 1792]); mm_32 = None + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 8, '0'); convert_element_type_160 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_64, [1, 0]); wait_tensor_64 = None + mm_33 = torch.ops.aten.mm.default(view_348, permute_53) + view_356 = torch.ops.aten.view.default(mm_33, [2, 8192, 1792]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_356) + view_363 = torch.ops.aten.view.default(mul_39, [16384, 1792]); mul_39 = None + mm_269 = torch.ops.aten.mm.default(permute_533, view_363); permute_533 = view_363 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 8, '0'); convert_element_type_163 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + permute_535 = torch.ops.aten.permute.default(permute_54, [1, 0]); permute_54 = None + mm_270 = torch.ops.aten.mm.default(view_1435, permute_535); view_1435 = permute_535 = None + view_1436 = torch.ops.aten.view.default(mm_270, [2, 8192, 1792]); mm_270 = None + convert_element_type_1144 = torch.ops.prims.convert_element_type.default(mm_269, torch.float32); mm_269 = None + reduce_scatter_tensor_157 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1144, 'avg', 8, '0'); convert_element_type_1144 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_157); reduce_scatter_tensor_157 = None + mul_356 = torch.ops.aten.mul.Tensor(view_1436, convert_element_type_159); convert_element_type_159 = None + mul_357 = torch.ops.aten.mul.Tensor(view_1436, view_356); view_1436 = view_356 = None + view_1437 = torch.ops.aten.view.default(mul_356, [16384, 1792]); mul_356 = None + permute_537 = torch.ops.aten.permute.default(view_1437, [1, 0]) + mm_271 = torch.ops.aten.mm.default(permute_537, view_348); permute_537 = None + permute_539 = torch.ops.aten.permute.default(permute_53, [1, 0]); permute_53 = None + mm_272 = torch.ops.aten.mm.default(view_1437, permute_539); view_1437 = permute_539 = None + view_1438 = torch.ops.aten.view.default(mm_272, [2, 8192, 4096]); mm_272 = None + convert_element_type_1149 = torch.ops.prims.convert_element_type.default(mm_271, torch.float32); mm_271 = None + reduce_scatter_tensor_158 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1149, 'avg', 8, '0'); convert_element_type_1149 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_158); reduce_scatter_tensor_158 = None + convert_element_type_1150 = torch.ops.prims.convert_element_type.default(mul_357, torch.float32); mul_357 = None + neg_11 = torch.ops.aten.neg.default(convert_element_type_158) + exp_11 = torch.ops.aten.exp.default(neg_11); neg_11 = None + add_142 = torch.ops.aten.add.Tensor(exp_11, 1); exp_11 = None + reciprocal_11 = torch.ops.aten.reciprocal.default(add_142); add_142 = None + mul_358 = torch.ops.aten.mul.Tensor(reciprocal_11, 1); reciprocal_11 = None + mul_359 = torch.ops.aten.mul.Tensor(convert_element_type_1150, mul_358); convert_element_type_1150 = None + sub_35 = torch.ops.aten.sub.Tensor(1, mul_358); mul_358 = None + mul_360 = torch.ops.aten.mul.Tensor(convert_element_type_158, sub_35); convert_element_type_158 = sub_35 = None + add_143 = torch.ops.aten.add.Tensor(mul_360, 1); mul_360 = None + mul_361 = torch.ops.aten.mul.Tensor(mul_359, add_143); mul_359 = add_143 = None + convert_element_type_1152 = torch.ops.prims.convert_element_type.default(mul_361, torch.bfloat16); mul_361 = None + view_1439 = torch.ops.aten.view.default(convert_element_type_1152, [16384, 1792]); convert_element_type_1152 = None + permute_541 = torch.ops.aten.permute.default(view_1439, [1, 0]) + mm_273 = torch.ops.aten.mm.default(permute_541, view_348); permute_541 = view_348 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 8, '0'); convert_element_type_155 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + permute_543 = torch.ops.aten.permute.default(permute_52, [1, 0]); permute_52 = None + mm_274 = torch.ops.aten.mm.default(view_1439, permute_543); view_1439 = permute_543 = None + view_1440 = torch.ops.aten.view.default(mm_274, [2, 8192, 4096]); mm_274 = None + add_144 = torch.ops.aten.add.Tensor(view_1438, view_1440); view_1438 = view_1440 = None + convert_element_type_1157 = torch.ops.prims.convert_element_type.default(mm_273, torch.float32); mm_273 = None + reduce_scatter_tensor_159 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1157, 'avg', 8, '0'); convert_element_type_1157 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_159); reduce_scatter_tensor_159 = None + split_120 = torch.ops.aten.split.Tensor(add_144, 1024, 1); add_144 = None + getitem_1137 = split_120[0] + getitem_1138 = split_120[1] + getitem_1139 = split_120[2] + getitem_1140 = split_120[3] + getitem_1141 = split_120[4] + getitem_1142 = split_120[5] + getitem_1143 = split_120[6] + getitem_1144 = split_120[7]; split_120 = None + cat_112 = torch.ops.aten.cat.default([getitem_1137, getitem_1138, getitem_1139, getitem_1140, getitem_1141, getitem_1142, getitem_1143, getitem_1144]); getitem_1137 = getitem_1138 = getitem_1139 = getitem_1140 = getitem_1141 = getitem_1142 = getitem_1143 = getitem_1144 = None + reduce_scatter_tensor_160 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_112, 'sum', 8, '1'); cat_112 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_160); reduce_scatter_tensor_160 = None + convert_element_type_1158 = torch.ops.prims.convert_element_type.default(wait_tensor_386, torch.float32); wait_tensor_386 = None + convert_element_type_1160 = torch.ops.prims.convert_element_type.default(wait_tensor_61, torch.float32); wait_tensor_61 = None + mul_362 = torch.ops.aten.mul.Tensor(convert_element_type_1158, convert_element_type_1160); convert_element_type_1160 = None + mul_364 = torch.ops.aten.mul.Tensor(mul_36, mul_362) + sum_69 = torch.ops.aten.sum.dim_IntList(mul_364, [2], True); mul_364 = None + div_23 = torch.ops.aten.div.Tensor(mul_36, 4096) + mul_365 = torch.ops.aten.mul.Tensor(div_23, sum_69); div_23 = sum_69 = None + sub_36 = torch.ops.aten.sub.Tensor(mul_362, mul_365); mul_362 = mul_365 = None + mul_366 = torch.ops.aten.mul.Tensor(sub_36, rsqrt_9); sub_36 = rsqrt_9 = None + mul_367 = torch.ops.aten.mul.Tensor(convert_element_type_1158, mul_36); convert_element_type_1158 = mul_36 = None + sum_70 = torch.ops.aten.sum.dim_IntList(mul_367, [0, 1]); mul_367 = None + convert_element_type_1161 = torch.ops.prims.convert_element_type.default(mul_366, torch.bfloat16); mul_366 = None + convert_element_type_1162 = torch.ops.prims.convert_element_type.default(sum_70, torch.bfloat16); sum_70 = None + all_reduce_23 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1162, 'sum', '1'); convert_element_type_1162 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_23); all_reduce_23 = None + convert_element_type_1163 = torch.ops.prims.convert_element_type.default(wait_tensor_387, torch.float32); wait_tensor_387 = None + reduce_scatter_tensor_161 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1163, 'avg', 8, '0'); convert_element_type_1163 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_161); reduce_scatter_tensor_161 = None + add_145 = torch.ops.aten.add.Tensor(add_141, convert_element_type_1161); add_141 = convert_element_type_1161 = None + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_145, 8, '1') + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + split_121 = torch.ops.aten.split.Tensor(wait_tensor_389, 2); wait_tensor_389 = None + getitem_1145 = split_121[0] + getitem_1146 = split_121[1] + getitem_1147 = split_121[2] + getitem_1148 = split_121[3] + getitem_1149 = split_121[4] + getitem_1150 = split_121[5] + getitem_1151 = split_121[6] + getitem_1152 = split_121[7]; split_121 = None + cat_113 = torch.ops.aten.cat.default([getitem_1145, getitem_1146, getitem_1147, getitem_1148, getitem_1149, getitem_1150, getitem_1151, getitem_1152], 1); getitem_1145 = getitem_1146 = getitem_1147 = getitem_1148 = getitem_1149 = getitem_1150 = getitem_1151 = getitem_1152 = None + view_1441 = torch.ops.aten.view.default(cat_113, [16384, 4096]); cat_113 = None + permute_545 = torch.ops.aten.permute.default(view_1441, [1, 0]) + permute_50 = torch.ops.aten.permute.default(getitem_244, [0, 2, 1, 3]) + view_330 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + view_336 = torch.ops.aten.view.default(view_330, [16384, 512]); view_330 = None + mm_275 = torch.ops.aten.mm.default(permute_545, view_336); permute_545 = view_336 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16); primals_44 = None + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 8, '0'); convert_element_type_149 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + permute_547 = torch.ops.aten.permute.default(permute_51, [1, 0]); permute_51 = None + mm_276 = torch.ops.aten.mm.default(view_1441, permute_547); view_1441 = permute_547 = None + view_1442 = torch.ops.aten.view.default(mm_276, [2, 8192, 512]); mm_276 = None + convert_element_type_1168 = torch.ops.prims.convert_element_type.default(mm_275, torch.float32); mm_275 = None + reduce_scatter_tensor_162 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1168, 'avg', 8, '0'); convert_element_type_1168 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_162); reduce_scatter_tensor_162 = None + view_1443 = torch.ops.aten.view.default(view_1442, [2, 8192, 4, 128]); view_1442 = None + permute_549 = torch.ops.aten.permute.default(view_1443, [0, 2, 1, 3]); view_1443 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16); primals_40 = None + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 8, '0'); convert_element_type_133 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32); add_15 = None + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_54) + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_135, 8, '1'); convert_element_type_135 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + split_25 = torch.ops.aten.split.Tensor(wait_tensor_55, 2); wait_tensor_55 = None + getitem_236 = split_25[0] + getitem_237 = split_25[1] + getitem_238 = split_25[2] + getitem_239 = split_25[3] + getitem_240 = split_25[4] + getitem_241 = split_25[5] + getitem_242 = split_25[6] + getitem_243 = split_25[7]; split_25 = None + cat_17 = torch.ops.aten.cat.default([getitem_236, getitem_237, getitem_238, getitem_239, getitem_240, getitem_241, getitem_242, getitem_243], 1); getitem_236 = getitem_237 = getitem_238 = getitem_239 = getitem_240 = getitem_241 = getitem_242 = getitem_243 = None + view_303 = torch.ops.aten.view.default(cat_17, [16384, 4096]); cat_17 = None + view_304 = torch.ops.aten.view.default(mm_28, [2, 8192, 512]); mm_28 = None + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16); primals_42 = None + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 8, '0'); convert_element_type_139 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_29 = torch.ops.aten.mm.default(view_303, permute_45) + view_311 = torch.ops.aten.view.default(mm_29, [2, 8192, 128]); mm_29 = None + view_318 = torch.ops.aten.view.default(mm_30, [2, 8192, 128]); mm_30 = None + view_320 = torch.ops.aten.view.default(view_304, [2, 8192, -1, 128]); view_304 = None + view_321 = torch.ops.aten.view.default(view_311, [2, 8192, -1, 128]); view_311 = None + view_322 = torch.ops.aten.view.default(view_318, [2, 8192, -1, 128]); view_318 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None + view_323 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 4, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_323); view_323 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_321, torch.float32); view_321 = None + view_324 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 1, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_324); view_324 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_37); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_326 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 4, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_37); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_327 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 1, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_327, torch.bfloat16); view_327 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 1, 4, 128]); unsqueeze_8 = None + view_328 = torch.ops.aten.view.default(expand_8, [2, 8192, 4, 128]); expand_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_322, 3); view_322 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 1, 4, 128]); unsqueeze_9 = None + view_329 = torch.ops.aten.view.default(expand_9, [2, 8192, 4, 128]); expand_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_328, [0, 2, 1, 3]); view_328 = None + permute_49 = torch.ops.aten.permute.default(view_329, [0, 2, 1, 3]); view_329 = None + _scaled_dot_product_cudnn_attention_backward_11 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_549, permute_47, permute_48, permute_49, getitem_244, getitem_245, getitem_250, getitem_251, None, None, None, 8192, 8192, 0.0, True); permute_549 = permute_47 = permute_48 = permute_49 = getitem_244 = getitem_245 = getitem_250 = getitem_251 = None + getitem_1153 = _scaled_dot_product_cudnn_attention_backward_11[0] + getitem_1154 = _scaled_dot_product_cudnn_attention_backward_11[1] + getitem_1155 = _scaled_dot_product_cudnn_attention_backward_11[2]; _scaled_dot_product_cudnn_attention_backward_11 = None + permute_550 = torch.ops.aten.permute.default(getitem_1155, [0, 2, 1, 3]); getitem_1155 = None + permute_551 = torch.ops.aten.permute.default(getitem_1154, [0, 2, 1, 3]); getitem_1154 = None + permute_552 = torch.ops.aten.permute.default(getitem_1153, [0, 2, 1, 3]); getitem_1153 = None + view_1444 = torch.ops.aten.view.default(permute_550, [2, 8192, 1, 4, 128]); permute_550 = None + sum_71 = torch.ops.aten.sum.dim_IntList(view_1444, [3], True); view_1444 = None + squeeze_22 = torch.ops.aten.squeeze.dim(sum_71, 3); sum_71 = None + view_1445 = torch.ops.aten.view.default(permute_551, [2, 8192, 1, 4, 128]); permute_551 = None + sum_72 = torch.ops.aten.sum.dim_IntList(view_1445, [3], True); view_1445 = None + squeeze_23 = torch.ops.aten.squeeze.dim(sum_72, 3); sum_72 = None + convert_element_type_1169 = torch.ops.prims.convert_element_type.default(squeeze_23, torch.float32); squeeze_23 = None + convert_element_type_1170 = torch.ops.prims.convert_element_type.default(permute_552, torch.float32); permute_552 = None + view_1446 = torch.ops.aten.view.default(convert_element_type_1169, [2, 8192, 1, 64, 2]); convert_element_type_1169 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1446); view_1446 = None + mul_368 = torch.ops.aten.mul.Tensor(view_as_complex_54, _conj); view_as_complex_54 = None + view_1447 = torch.ops.aten.view.default(convert_element_type_1170, [2, 8192, 4, 64, 2]); convert_element_type_1170 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1447); view_1447 = None + mul_369 = torch.ops.aten.mul.Tensor(view_as_complex_55, _conj); view_as_complex_55 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_368); mul_368 = None + view_1448 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 1, 128]); view_as_real_54 = None + convert_element_type_1171 = torch.ops.prims.convert_element_type.default(view_1448, torch.bfloat16); view_1448 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_369); mul_369 = None + view_1449 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 4, 128]); view_as_real_55 = None + convert_element_type_1172 = torch.ops.prims.convert_element_type.default(view_1449, torch.bfloat16); view_1449 = None + view_1450 = torch.ops.aten.view.default(squeeze_22, [2, 8192, 128]); squeeze_22 = None + view_1451 = torch.ops.aten.view.default(convert_element_type_1171, [2, 8192, 128]); convert_element_type_1171 = None + view_1452 = torch.ops.aten.view.default(convert_element_type_1172, [2, 8192, 512]); convert_element_type_1172 = None + view_1453 = torch.ops.aten.view.default(view_1450, [16384, 128]); view_1450 = None + permute_553 = torch.ops.aten.permute.default(view_1453, [1, 0]) + mm_277 = torch.ops.aten.mm.default(permute_553, view_303); permute_553 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16); primals_43 = None + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 8, '0'); convert_element_type_142 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + permute_555 = torch.ops.aten.permute.default(permute_46, [1, 0]); permute_46 = None + mm_278 = torch.ops.aten.mm.default(view_1453, permute_555); view_1453 = permute_555 = None + view_1454 = torch.ops.aten.view.default(mm_278, [2, 8192, 4096]); mm_278 = None + convert_element_type_1177 = torch.ops.prims.convert_element_type.default(mm_277, torch.float32); mm_277 = None + reduce_scatter_tensor_163 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1177, 'avg', 8, '0'); convert_element_type_1177 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_163); reduce_scatter_tensor_163 = None + view_1455 = torch.ops.aten.view.default(view_1451, [16384, 128]); view_1451 = None + permute_557 = torch.ops.aten.permute.default(view_1455, [1, 0]) + mm_279 = torch.ops.aten.mm.default(permute_557, view_303); permute_557 = None + permute_559 = torch.ops.aten.permute.default(permute_45, [1, 0]); permute_45 = None + mm_280 = torch.ops.aten.mm.default(view_1455, permute_559); view_1455 = permute_559 = None + view_1456 = torch.ops.aten.view.default(mm_280, [2, 8192, 4096]); mm_280 = None + add_146 = torch.ops.aten.add.Tensor(view_1454, view_1456); view_1454 = view_1456 = None + convert_element_type_1182 = torch.ops.prims.convert_element_type.default(mm_279, torch.float32); mm_279 = None + reduce_scatter_tensor_164 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1182, 'avg', 8, '0'); convert_element_type_1182 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_164); reduce_scatter_tensor_164 = None + view_1457 = torch.ops.aten.view.default(view_1452, [16384, 512]); view_1452 = None + permute_561 = torch.ops.aten.permute.default(view_1457, [1, 0]) + mm_281 = torch.ops.aten.mm.default(permute_561, view_303); permute_561 = view_303 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16); primals_41 = None + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 8, '0'); convert_element_type_136 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + permute_563 = torch.ops.aten.permute.default(permute_44, [1, 0]); permute_44 = None + mm_282 = torch.ops.aten.mm.default(view_1457, permute_563); view_1457 = permute_563 = None + view_1458 = torch.ops.aten.view.default(mm_282, [2, 8192, 4096]); mm_282 = None + add_147 = torch.ops.aten.add.Tensor(add_146, view_1458); add_146 = view_1458 = None + convert_element_type_1187 = torch.ops.prims.convert_element_type.default(mm_281, torch.float32); mm_281 = None + reduce_scatter_tensor_165 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1187, 'avg', 8, '0'); convert_element_type_1187 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_165); reduce_scatter_tensor_165 = None + split_122 = torch.ops.aten.split.Tensor(add_147, 1024, 1); add_147 = None + getitem_1156 = split_122[0] + getitem_1157 = split_122[1] + getitem_1158 = split_122[2] + getitem_1159 = split_122[3] + getitem_1160 = split_122[4] + getitem_1161 = split_122[5] + getitem_1162 = split_122[6] + getitem_1163 = split_122[7]; split_122 = None + cat_114 = torch.ops.aten.cat.default([getitem_1156, getitem_1157, getitem_1158, getitem_1159, getitem_1160, getitem_1161, getitem_1162, getitem_1163]); getitem_1156 = getitem_1157 = getitem_1158 = getitem_1159 = getitem_1160 = getitem_1161 = getitem_1162 = getitem_1163 = None + reduce_scatter_tensor_166 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_114, 'sum', 8, '1'); cat_114 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_166); reduce_scatter_tensor_166 = None + convert_element_type_1188 = torch.ops.prims.convert_element_type.default(wait_tensor_394, torch.float32); wait_tensor_394 = None + convert_element_type_1190 = torch.ops.prims.convert_element_type.default(wait_tensor_54, torch.float32); wait_tensor_54 = None + mul_370 = torch.ops.aten.mul.Tensor(convert_element_type_1188, convert_element_type_1190); convert_element_type_1190 = None + mul_372 = torch.ops.aten.mul.Tensor(mul_32, mul_370) + sum_73 = torch.ops.aten.sum.dim_IntList(mul_372, [2], True); mul_372 = None + div_24 = torch.ops.aten.div.Tensor(mul_32, 4096) + mul_373 = torch.ops.aten.mul.Tensor(div_24, sum_73); div_24 = sum_73 = None + sub_37 = torch.ops.aten.sub.Tensor(mul_370, mul_373); mul_370 = mul_373 = None + mul_374 = torch.ops.aten.mul.Tensor(sub_37, rsqrt_8); sub_37 = rsqrt_8 = None + mul_375 = torch.ops.aten.mul.Tensor(convert_element_type_1188, mul_32); convert_element_type_1188 = mul_32 = None + sum_74 = torch.ops.aten.sum.dim_IntList(mul_375, [0, 1]); mul_375 = None + convert_element_type_1191 = torch.ops.prims.convert_element_type.default(mul_374, torch.bfloat16); mul_374 = None + convert_element_type_1192 = torch.ops.prims.convert_element_type.default(sum_74, torch.bfloat16); sum_74 = None + all_reduce_24 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1192, 'sum', '1'); convert_element_type_1192 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_24); all_reduce_24 = None + convert_element_type_1193 = torch.ops.prims.convert_element_type.default(wait_tensor_395, torch.float32); wait_tensor_395 = None + reduce_scatter_tensor_167 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1193, 'avg', 8, '0'); convert_element_type_1193 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_167); reduce_scatter_tensor_167 = None + add_148 = torch.ops.aten.add.Tensor(add_145, convert_element_type_1191); add_145 = convert_element_type_1191 = None + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_148, 8, '1') + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + split_123 = torch.ops.aten.split.Tensor(wait_tensor_397, 2); wait_tensor_397 = None + getitem_1164 = split_123[0] + getitem_1165 = split_123[1] + getitem_1166 = split_123[2] + getitem_1167 = split_123[3] + getitem_1168 = split_123[4] + getitem_1169 = split_123[5] + getitem_1170 = split_123[6] + getitem_1171 = split_123[7]; split_123 = None + cat_115 = torch.ops.aten.cat.default([getitem_1164, getitem_1165, getitem_1166, getitem_1167, getitem_1168, getitem_1169, getitem_1170, getitem_1171], 1); getitem_1164 = getitem_1165 = getitem_1166 = getitem_1167 = getitem_1168 = getitem_1169 = getitem_1170 = getitem_1171 = None + view_1459 = torch.ops.aten.view.default(cat_115, [16384, 4096]); cat_115 = None + permute_565 = torch.ops.aten.permute.default(view_1459, [1, 0]) + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7); reduce_scatter_tensor_7 = None + add_13 = torch.ops.aten.add.Tensor(add_11, wait_tensor_47); wait_tensor_47 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 8, '0'); convert_element_type_119 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32); add_13 = None + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_48) + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_121, 8, '1'); convert_element_type_121 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + split_23 = torch.ops.aten.split.Tensor(wait_tensor_49, 2); wait_tensor_49 = None + getitem_220 = split_23[0] + getitem_221 = split_23[1] + getitem_222 = split_23[2] + getitem_223 = split_23[3] + getitem_224 = split_23[4] + getitem_225 = split_23[5] + getitem_226 = split_23[6] + getitem_227 = split_23[7]; split_23 = None + cat_15 = torch.ops.aten.cat.default([getitem_220, getitem_221, getitem_222, getitem_223, getitem_224, getitem_225, getitem_226, getitem_227], 1); getitem_220 = getitem_221 = getitem_222 = getitem_223 = getitem_224 = getitem_225 = getitem_226 = getitem_227 = None + view_276 = torch.ops.aten.view.default(cat_15, [16384, 4096]); cat_15 = None + view_277 = torch.ops.aten.view.default(mm_25, [2, 8192, 1792]); mm_25 = None + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_277, torch.float32); view_277 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16); primals_38 = None + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 8, '0'); convert_element_type_127 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + mm_26 = torch.ops.aten.mm.default(view_276, permute_42) + view_284 = torch.ops.aten.view.default(mm_26, [2, 8192, 1792]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_284) + view_291 = torch.ops.aten.view.default(mul_31, [16384, 1792]); mul_31 = None + mm_283 = torch.ops.aten.mm.default(permute_565, view_291); permute_565 = view_291 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16); primals_39 = None + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 8, '0'); convert_element_type_130 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + permute_567 = torch.ops.aten.permute.default(permute_43, [1, 0]); permute_43 = None + mm_284 = torch.ops.aten.mm.default(view_1459, permute_567); view_1459 = permute_567 = None + view_1460 = torch.ops.aten.view.default(mm_284, [2, 8192, 1792]); mm_284 = None + convert_element_type_1198 = torch.ops.prims.convert_element_type.default(mm_283, torch.float32); mm_283 = None + reduce_scatter_tensor_168 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1198, 'avg', 8, '0'); convert_element_type_1198 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_168); reduce_scatter_tensor_168 = None + mul_376 = torch.ops.aten.mul.Tensor(view_1460, convert_element_type_126); convert_element_type_126 = None + mul_377 = torch.ops.aten.mul.Tensor(view_1460, view_284); view_1460 = view_284 = None + view_1461 = torch.ops.aten.view.default(mul_376, [16384, 1792]); mul_376 = None + permute_569 = torch.ops.aten.permute.default(view_1461, [1, 0]) + mm_285 = torch.ops.aten.mm.default(permute_569, view_276); permute_569 = None + permute_571 = torch.ops.aten.permute.default(permute_42, [1, 0]); permute_42 = None + mm_286 = torch.ops.aten.mm.default(view_1461, permute_571); view_1461 = permute_571 = None + view_1462 = torch.ops.aten.view.default(mm_286, [2, 8192, 4096]); mm_286 = None + convert_element_type_1203 = torch.ops.prims.convert_element_type.default(mm_285, torch.float32); mm_285 = None + reduce_scatter_tensor_169 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1203, 'avg', 8, '0'); convert_element_type_1203 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_169); reduce_scatter_tensor_169 = None + convert_element_type_1204 = torch.ops.prims.convert_element_type.default(mul_377, torch.float32); mul_377 = None + neg_12 = torch.ops.aten.neg.default(convert_element_type_125) + exp_12 = torch.ops.aten.exp.default(neg_12); neg_12 = None + add_149 = torch.ops.aten.add.Tensor(exp_12, 1); exp_12 = None + reciprocal_12 = torch.ops.aten.reciprocal.default(add_149); add_149 = None + mul_378 = torch.ops.aten.mul.Tensor(reciprocal_12, 1); reciprocal_12 = None + mul_379 = torch.ops.aten.mul.Tensor(convert_element_type_1204, mul_378); convert_element_type_1204 = None + sub_38 = torch.ops.aten.sub.Tensor(1, mul_378); mul_378 = None + mul_380 = torch.ops.aten.mul.Tensor(convert_element_type_125, sub_38); convert_element_type_125 = sub_38 = None + add_150 = torch.ops.aten.add.Tensor(mul_380, 1); mul_380 = None + mul_381 = torch.ops.aten.mul.Tensor(mul_379, add_150); mul_379 = add_150 = None + convert_element_type_1206 = torch.ops.prims.convert_element_type.default(mul_381, torch.bfloat16); mul_381 = None + view_1463 = torch.ops.aten.view.default(convert_element_type_1206, [16384, 1792]); convert_element_type_1206 = None + permute_573 = torch.ops.aten.permute.default(view_1463, [1, 0]) + mm_287 = torch.ops.aten.mm.default(permute_573, view_276); permute_573 = view_276 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16); primals_37 = None + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 8, '0'); convert_element_type_122 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + permute_575 = torch.ops.aten.permute.default(permute_41, [1, 0]); permute_41 = None + mm_288 = torch.ops.aten.mm.default(view_1463, permute_575); view_1463 = permute_575 = None + view_1464 = torch.ops.aten.view.default(mm_288, [2, 8192, 4096]); mm_288 = None + add_151 = torch.ops.aten.add.Tensor(view_1462, view_1464); view_1462 = view_1464 = None + convert_element_type_1211 = torch.ops.prims.convert_element_type.default(mm_287, torch.float32); mm_287 = None + reduce_scatter_tensor_170 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1211, 'avg', 8, '0'); convert_element_type_1211 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_170); reduce_scatter_tensor_170 = None + split_124 = torch.ops.aten.split.Tensor(add_151, 1024, 1); add_151 = None + getitem_1172 = split_124[0] + getitem_1173 = split_124[1] + getitem_1174 = split_124[2] + getitem_1175 = split_124[3] + getitem_1176 = split_124[4] + getitem_1177 = split_124[5] + getitem_1178 = split_124[6] + getitem_1179 = split_124[7]; split_124 = None + cat_116 = torch.ops.aten.cat.default([getitem_1172, getitem_1173, getitem_1174, getitem_1175, getitem_1176, getitem_1177, getitem_1178, getitem_1179]); getitem_1172 = getitem_1173 = getitem_1174 = getitem_1175 = getitem_1176 = getitem_1177 = getitem_1178 = getitem_1179 = None + reduce_scatter_tensor_171 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_116, 'sum', 8, '1'); cat_116 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_171); reduce_scatter_tensor_171 = None + convert_element_type_1212 = torch.ops.prims.convert_element_type.default(wait_tensor_401, torch.float32); wait_tensor_401 = None + convert_element_type_1214 = torch.ops.prims.convert_element_type.default(wait_tensor_48, torch.float32); wait_tensor_48 = None + mul_382 = torch.ops.aten.mul.Tensor(convert_element_type_1212, convert_element_type_1214); convert_element_type_1214 = None + mul_384 = torch.ops.aten.mul.Tensor(mul_28, mul_382) + sum_75 = torch.ops.aten.sum.dim_IntList(mul_384, [2], True); mul_384 = None + div_25 = torch.ops.aten.div.Tensor(mul_28, 4096) + mul_385 = torch.ops.aten.mul.Tensor(div_25, sum_75); div_25 = sum_75 = None + sub_39 = torch.ops.aten.sub.Tensor(mul_382, mul_385); mul_382 = mul_385 = None + mul_386 = torch.ops.aten.mul.Tensor(sub_39, rsqrt_7); sub_39 = rsqrt_7 = None + mul_387 = torch.ops.aten.mul.Tensor(convert_element_type_1212, mul_28); convert_element_type_1212 = mul_28 = None + sum_76 = torch.ops.aten.sum.dim_IntList(mul_387, [0, 1]); mul_387 = None + convert_element_type_1215 = torch.ops.prims.convert_element_type.default(mul_386, torch.bfloat16); mul_386 = None + convert_element_type_1216 = torch.ops.prims.convert_element_type.default(sum_76, torch.bfloat16); sum_76 = None + all_reduce_25 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1216, 'sum', '1'); convert_element_type_1216 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_25); all_reduce_25 = None + convert_element_type_1217 = torch.ops.prims.convert_element_type.default(wait_tensor_402, torch.float32); wait_tensor_402 = None + reduce_scatter_tensor_172 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1217, 'avg', 8, '0'); convert_element_type_1217 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_172); reduce_scatter_tensor_172 = None + add_152 = torch.ops.aten.add.Tensor(add_148, convert_element_type_1215); add_148 = convert_element_type_1215 = None + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_152, 8, '1') + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + split_125 = torch.ops.aten.split.Tensor(wait_tensor_404, 2); wait_tensor_404 = None + getitem_1180 = split_125[0] + getitem_1181 = split_125[1] + getitem_1182 = split_125[2] + getitem_1183 = split_125[3] + getitem_1184 = split_125[4] + getitem_1185 = split_125[5] + getitem_1186 = split_125[6] + getitem_1187 = split_125[7]; split_125 = None + cat_117 = torch.ops.aten.cat.default([getitem_1180, getitem_1181, getitem_1182, getitem_1183, getitem_1184, getitem_1185, getitem_1186, getitem_1187], 1); getitem_1180 = getitem_1181 = getitem_1182 = getitem_1183 = getitem_1184 = getitem_1185 = getitem_1186 = getitem_1187 = None + view_1465 = torch.ops.aten.view.default(cat_117, [16384, 4096]); cat_117 = None + permute_577 = torch.ops.aten.permute.default(view_1465, [1, 0]) + permute_39 = torch.ops.aten.permute.default(getitem_203, [0, 2, 1, 3]) + view_258 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + view_264 = torch.ops.aten.view.default(view_258, [16384, 512]); view_258 = None + mm_289 = torch.ops.aten.mm.default(permute_577, view_264); permute_577 = view_264 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 8, '0'); convert_element_type_116 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_46, [1, 0]); wait_tensor_46 = None + permute_579 = torch.ops.aten.permute.default(permute_40, [1, 0]); permute_40 = None + mm_290 = torch.ops.aten.mm.default(view_1465, permute_579); view_1465 = permute_579 = None + view_1466 = torch.ops.aten.view.default(mm_290, [2, 8192, 512]); mm_290 = None + convert_element_type_1222 = torch.ops.prims.convert_element_type.default(mm_289, torch.float32); mm_289 = None + reduce_scatter_tensor_173 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1222, 'avg', 8, '0'); convert_element_type_1222 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_173); reduce_scatter_tensor_173 = None + view_1467 = torch.ops.aten.view.default(view_1466, [2, 8192, 4, 128]); view_1466 = None + permute_581 = torch.ops.aten.permute.default(view_1467, [0, 2, 1, 3]); view_1467 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 8, '0'); convert_element_type_100 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32); add_11 = None + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_41) + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_102, 8, '1'); convert_element_type_102 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + split_21 = torch.ops.aten.split.Tensor(wait_tensor_42, 2); wait_tensor_42 = None + getitem_195 = split_21[0] + getitem_196 = split_21[1] + getitem_197 = split_21[2] + getitem_198 = split_21[3] + getitem_199 = split_21[4] + getitem_200 = split_21[5] + getitem_201 = split_21[6] + getitem_202 = split_21[7]; split_21 = None + cat_13 = torch.ops.aten.cat.default([getitem_195, getitem_196, getitem_197, getitem_198, getitem_199, getitem_200, getitem_201, getitem_202], 1); getitem_195 = getitem_196 = getitem_197 = getitem_198 = getitem_199 = getitem_200 = getitem_201 = getitem_202 = None + view_231 = torch.ops.aten.view.default(cat_13, [16384, 4096]); cat_13 = None + view_232 = torch.ops.aten.view.default(mm_21, [2, 8192, 512]); mm_21 = None + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 8, '0'); convert_element_type_106 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_22 = torch.ops.aten.mm.default(view_231, permute_34) + view_239 = torch.ops.aten.view.default(mm_22, [2, 8192, 128]); mm_22 = None + view_246 = torch.ops.aten.view.default(mm_23, [2, 8192, 128]); mm_23 = None + view_248 = torch.ops.aten.view.default(view_232, [2, 8192, -1, 128]); view_232 = None + view_249 = torch.ops.aten.view.default(view_239, [2, 8192, -1, 128]); view_239 = None + view_250 = torch.ops.aten.view.default(view_246, [2, 8192, -1, 128]); view_246 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_248, torch.float32); view_248 = None + view_251 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 4, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_251); view_251 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 1, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_37); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_254 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 4, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_37); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_255 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 1, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_254, torch.bfloat16); view_254 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 1, 4, 128]); unsqueeze_6 = None + view_256 = torch.ops.aten.view.default(expand_6, [2, 8192, 4, 128]); expand_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_250, 3); view_250 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 1, 4, 128]); unsqueeze_7 = None + view_257 = torch.ops.aten.view.default(expand_7, [2, 8192, 4, 128]); expand_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_256, [0, 2, 1, 3]); view_256 = None + permute_38 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + _scaled_dot_product_cudnn_attention_backward_12 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_581, permute_36, permute_37, permute_38, getitem_203, getitem_204, getitem_209, getitem_210, None, None, None, 8192, 8192, 0.0, True); permute_581 = permute_36 = permute_37 = permute_38 = getitem_203 = getitem_204 = getitem_209 = getitem_210 = None + getitem_1188 = _scaled_dot_product_cudnn_attention_backward_12[0] + getitem_1189 = _scaled_dot_product_cudnn_attention_backward_12[1] + getitem_1190 = _scaled_dot_product_cudnn_attention_backward_12[2]; _scaled_dot_product_cudnn_attention_backward_12 = None + permute_582 = torch.ops.aten.permute.default(getitem_1190, [0, 2, 1, 3]); getitem_1190 = None + permute_583 = torch.ops.aten.permute.default(getitem_1189, [0, 2, 1, 3]); getitem_1189 = None + permute_584 = torch.ops.aten.permute.default(getitem_1188, [0, 2, 1, 3]); getitem_1188 = None + view_1468 = torch.ops.aten.view.default(permute_582, [2, 8192, 1, 4, 128]); permute_582 = None + sum_77 = torch.ops.aten.sum.dim_IntList(view_1468, [3], True); view_1468 = None + squeeze_24 = torch.ops.aten.squeeze.dim(sum_77, 3); sum_77 = None + view_1469 = torch.ops.aten.view.default(permute_583, [2, 8192, 1, 4, 128]); permute_583 = None + sum_78 = torch.ops.aten.sum.dim_IntList(view_1469, [3], True); view_1469 = None + squeeze_25 = torch.ops.aten.squeeze.dim(sum_78, 3); sum_78 = None + convert_element_type_1223 = torch.ops.prims.convert_element_type.default(squeeze_25, torch.float32); squeeze_25 = None + convert_element_type_1224 = torch.ops.prims.convert_element_type.default(permute_584, torch.float32); permute_584 = None + view_1470 = torch.ops.aten.view.default(convert_element_type_1223, [2, 8192, 1, 64, 2]); convert_element_type_1223 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_1470); view_1470 = None + mul_388 = torch.ops.aten.mul.Tensor(view_as_complex_56, _conj); view_as_complex_56 = None + view_1471 = torch.ops.aten.view.default(convert_element_type_1224, [2, 8192, 4, 64, 2]); convert_element_type_1224 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_1471); view_1471 = None + mul_389 = torch.ops.aten.mul.Tensor(view_as_complex_57, _conj); view_as_complex_57 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_388); mul_388 = None + view_1472 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 1, 128]); view_as_real_56 = None + convert_element_type_1225 = torch.ops.prims.convert_element_type.default(view_1472, torch.bfloat16); view_1472 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_389); mul_389 = None + view_1473 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 4, 128]); view_as_real_57 = None + convert_element_type_1226 = torch.ops.prims.convert_element_type.default(view_1473, torch.bfloat16); view_1473 = None + view_1474 = torch.ops.aten.view.default(squeeze_24, [2, 8192, 128]); squeeze_24 = None + view_1475 = torch.ops.aten.view.default(convert_element_type_1225, [2, 8192, 128]); convert_element_type_1225 = None + view_1476 = torch.ops.aten.view.default(convert_element_type_1226, [2, 8192, 512]); convert_element_type_1226 = None + view_1477 = torch.ops.aten.view.default(view_1474, [16384, 128]); view_1474 = None + permute_585 = torch.ops.aten.permute.default(view_1477, [1, 0]) + mm_291 = torch.ops.aten.mm.default(permute_585, view_231); permute_585 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 8, '0'); convert_element_type_109 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + permute_587 = torch.ops.aten.permute.default(permute_35, [1, 0]); permute_35 = None + mm_292 = torch.ops.aten.mm.default(view_1477, permute_587); view_1477 = permute_587 = None + view_1478 = torch.ops.aten.view.default(mm_292, [2, 8192, 4096]); mm_292 = None + convert_element_type_1231 = torch.ops.prims.convert_element_type.default(mm_291, torch.float32); mm_291 = None + reduce_scatter_tensor_174 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1231, 'avg', 8, '0'); convert_element_type_1231 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_174); reduce_scatter_tensor_174 = None + view_1479 = torch.ops.aten.view.default(view_1475, [16384, 128]); view_1475 = None + permute_589 = torch.ops.aten.permute.default(view_1479, [1, 0]) + mm_293 = torch.ops.aten.mm.default(permute_589, view_231); permute_589 = None + permute_591 = torch.ops.aten.permute.default(permute_34, [1, 0]); permute_34 = None + mm_294 = torch.ops.aten.mm.default(view_1479, permute_591); view_1479 = permute_591 = None + view_1480 = torch.ops.aten.view.default(mm_294, [2, 8192, 4096]); mm_294 = None + add_153 = torch.ops.aten.add.Tensor(view_1478, view_1480); view_1478 = view_1480 = None + convert_element_type_1236 = torch.ops.prims.convert_element_type.default(mm_293, torch.float32); mm_293 = None + reduce_scatter_tensor_175 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1236, 'avg', 8, '0'); convert_element_type_1236 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_175); reduce_scatter_tensor_175 = None + view_1481 = torch.ops.aten.view.default(view_1476, [16384, 512]); view_1476 = None + permute_593 = torch.ops.aten.permute.default(view_1481, [1, 0]) + mm_295 = torch.ops.aten.mm.default(permute_593, view_231); permute_593 = view_231 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 8, '0'); convert_element_type_103 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + permute_595 = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None + mm_296 = torch.ops.aten.mm.default(view_1481, permute_595); view_1481 = permute_595 = None + view_1482 = torch.ops.aten.view.default(mm_296, [2, 8192, 4096]); mm_296 = None + add_154 = torch.ops.aten.add.Tensor(add_153, view_1482); add_153 = view_1482 = None + convert_element_type_1241 = torch.ops.prims.convert_element_type.default(mm_295, torch.float32); mm_295 = None + reduce_scatter_tensor_176 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1241, 'avg', 8, '0'); convert_element_type_1241 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_176); reduce_scatter_tensor_176 = None + split_126 = torch.ops.aten.split.Tensor(add_154, 1024, 1); add_154 = None + getitem_1191 = split_126[0] + getitem_1192 = split_126[1] + getitem_1193 = split_126[2] + getitem_1194 = split_126[3] + getitem_1195 = split_126[4] + getitem_1196 = split_126[5] + getitem_1197 = split_126[6] + getitem_1198 = split_126[7]; split_126 = None + cat_118 = torch.ops.aten.cat.default([getitem_1191, getitem_1192, getitem_1193, getitem_1194, getitem_1195, getitem_1196, getitem_1197, getitem_1198]); getitem_1191 = getitem_1192 = getitem_1193 = getitem_1194 = getitem_1195 = getitem_1196 = getitem_1197 = getitem_1198 = None + reduce_scatter_tensor_177 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_118, 'sum', 8, '1'); cat_118 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_177); reduce_scatter_tensor_177 = None + convert_element_type_1242 = torch.ops.prims.convert_element_type.default(wait_tensor_409, torch.float32); wait_tensor_409 = None + convert_element_type_1244 = torch.ops.prims.convert_element_type.default(wait_tensor_41, torch.float32); wait_tensor_41 = None + mul_390 = torch.ops.aten.mul.Tensor(convert_element_type_1242, convert_element_type_1244); convert_element_type_1244 = None + mul_392 = torch.ops.aten.mul.Tensor(mul_24, mul_390) + sum_79 = torch.ops.aten.sum.dim_IntList(mul_392, [2], True); mul_392 = None + div_26 = torch.ops.aten.div.Tensor(mul_24, 4096) + mul_393 = torch.ops.aten.mul.Tensor(div_26, sum_79); div_26 = sum_79 = None + sub_40 = torch.ops.aten.sub.Tensor(mul_390, mul_393); mul_390 = mul_393 = None + mul_394 = torch.ops.aten.mul.Tensor(sub_40, rsqrt_6); sub_40 = rsqrt_6 = None + mul_395 = torch.ops.aten.mul.Tensor(convert_element_type_1242, mul_24); convert_element_type_1242 = mul_24 = None + sum_80 = torch.ops.aten.sum.dim_IntList(mul_395, [0, 1]); mul_395 = None + convert_element_type_1245 = torch.ops.prims.convert_element_type.default(mul_394, torch.bfloat16); mul_394 = None + convert_element_type_1246 = torch.ops.prims.convert_element_type.default(sum_80, torch.bfloat16); sum_80 = None + all_reduce_26 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1246, 'sum', '1'); convert_element_type_1246 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_26); all_reduce_26 = None + convert_element_type_1247 = torch.ops.prims.convert_element_type.default(wait_tensor_410, torch.float32); wait_tensor_410 = None + reduce_scatter_tensor_178 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1247, 'avg', 8, '0'); convert_element_type_1247 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_178); reduce_scatter_tensor_178 = None + add_155 = torch.ops.aten.add.Tensor(add_152, convert_element_type_1245); add_152 = convert_element_type_1245 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_155, 8, '1') + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + split_127 = torch.ops.aten.split.Tensor(wait_tensor_412, 2); wait_tensor_412 = None + getitem_1199 = split_127[0] + getitem_1200 = split_127[1] + getitem_1201 = split_127[2] + getitem_1202 = split_127[3] + getitem_1203 = split_127[4] + getitem_1204 = split_127[5] + getitem_1205 = split_127[6] + getitem_1206 = split_127[7]; split_127 = None + cat_119 = torch.ops.aten.cat.default([getitem_1199, getitem_1200, getitem_1201, getitem_1202, getitem_1203, getitem_1204, getitem_1205, getitem_1206], 1); getitem_1199 = getitem_1200 = getitem_1201 = getitem_1202 = getitem_1203 = getitem_1204 = getitem_1205 = getitem_1206 = None + view_1483 = torch.ops.aten.view.default(cat_119, [16384, 4096]); cat_119 = None + permute_597 = torch.ops.aten.permute.default(view_1483, [1, 0]) + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5); reduce_scatter_tensor_5 = None + add_9 = torch.ops.aten.add.Tensor(add_7, wait_tensor_34); wait_tensor_34 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16); primals_27 = None + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 8, '0'); convert_element_type_86 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32); add_9 = None + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_35) + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_88, 8, '1'); convert_element_type_88 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + split_19 = torch.ops.aten.split.Tensor(wait_tensor_36, 2); wait_tensor_36 = None + getitem_179 = split_19[0] + getitem_180 = split_19[1] + getitem_181 = split_19[2] + getitem_182 = split_19[3] + getitem_183 = split_19[4] + getitem_184 = split_19[5] + getitem_185 = split_19[6] + getitem_186 = split_19[7]; split_19 = None + cat_11 = torch.ops.aten.cat.default([getitem_179, getitem_180, getitem_181, getitem_182, getitem_183, getitem_184, getitem_185, getitem_186], 1); getitem_179 = getitem_180 = getitem_181 = getitem_182 = getitem_183 = getitem_184 = getitem_185 = getitem_186 = None + view_204 = torch.ops.aten.view.default(cat_11, [16384, 4096]); cat_11 = None + view_205 = torch.ops.aten.view.default(mm_18, [2, 8192, 1792]); mm_18 = None + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 8, '0'); convert_element_type_94 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + mm_19 = torch.ops.aten.mm.default(view_204, permute_31) + view_212 = torch.ops.aten.view.default(mm_19, [2, 8192, 1792]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_212) + view_219 = torch.ops.aten.view.default(mul_23, [16384, 1792]); mul_23 = None + mm_297 = torch.ops.aten.mm.default(permute_597, view_219); permute_597 = view_219 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 8, '0'); convert_element_type_97 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + permute_599 = torch.ops.aten.permute.default(permute_32, [1, 0]); permute_32 = None + mm_298 = torch.ops.aten.mm.default(view_1483, permute_599); view_1483 = permute_599 = None + view_1484 = torch.ops.aten.view.default(mm_298, [2, 8192, 1792]); mm_298 = None + convert_element_type_1252 = torch.ops.prims.convert_element_type.default(mm_297, torch.float32); mm_297 = None + reduce_scatter_tensor_179 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1252, 'avg', 8, '0'); convert_element_type_1252 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_179); reduce_scatter_tensor_179 = None + mul_396 = torch.ops.aten.mul.Tensor(view_1484, convert_element_type_93); convert_element_type_93 = None + mul_397 = torch.ops.aten.mul.Tensor(view_1484, view_212); view_1484 = view_212 = None + view_1485 = torch.ops.aten.view.default(mul_396, [16384, 1792]); mul_396 = None + permute_601 = torch.ops.aten.permute.default(view_1485, [1, 0]) + mm_299 = torch.ops.aten.mm.default(permute_601, view_204); permute_601 = None + permute_603 = torch.ops.aten.permute.default(permute_31, [1, 0]); permute_31 = None + mm_300 = torch.ops.aten.mm.default(view_1485, permute_603); view_1485 = permute_603 = None + view_1486 = torch.ops.aten.view.default(mm_300, [2, 8192, 4096]); mm_300 = None + convert_element_type_1257 = torch.ops.prims.convert_element_type.default(mm_299, torch.float32); mm_299 = None + reduce_scatter_tensor_180 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1257, 'avg', 8, '0'); convert_element_type_1257 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_180); reduce_scatter_tensor_180 = None + convert_element_type_1258 = torch.ops.prims.convert_element_type.default(mul_397, torch.float32); mul_397 = None + neg_13 = torch.ops.aten.neg.default(convert_element_type_92) + exp_13 = torch.ops.aten.exp.default(neg_13); neg_13 = None + add_156 = torch.ops.aten.add.Tensor(exp_13, 1); exp_13 = None + reciprocal_13 = torch.ops.aten.reciprocal.default(add_156); add_156 = None + mul_398 = torch.ops.aten.mul.Tensor(reciprocal_13, 1); reciprocal_13 = None + mul_399 = torch.ops.aten.mul.Tensor(convert_element_type_1258, mul_398); convert_element_type_1258 = None + sub_41 = torch.ops.aten.sub.Tensor(1, mul_398); mul_398 = None + mul_400 = torch.ops.aten.mul.Tensor(convert_element_type_92, sub_41); convert_element_type_92 = sub_41 = None + add_157 = torch.ops.aten.add.Tensor(mul_400, 1); mul_400 = None + mul_401 = torch.ops.aten.mul.Tensor(mul_399, add_157); mul_399 = add_157 = None + convert_element_type_1260 = torch.ops.prims.convert_element_type.default(mul_401, torch.bfloat16); mul_401 = None + view_1487 = torch.ops.aten.view.default(convert_element_type_1260, [16384, 1792]); convert_element_type_1260 = None + permute_605 = torch.ops.aten.permute.default(view_1487, [1, 0]) + mm_301 = torch.ops.aten.mm.default(permute_605, view_204); permute_605 = view_204 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 8, '0'); convert_element_type_89 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + permute_607 = torch.ops.aten.permute.default(permute_30, [1, 0]); permute_30 = None + mm_302 = torch.ops.aten.mm.default(view_1487, permute_607); view_1487 = permute_607 = None + view_1488 = torch.ops.aten.view.default(mm_302, [2, 8192, 4096]); mm_302 = None + add_158 = torch.ops.aten.add.Tensor(view_1486, view_1488); view_1486 = view_1488 = None + convert_element_type_1265 = torch.ops.prims.convert_element_type.default(mm_301, torch.float32); mm_301 = None + reduce_scatter_tensor_181 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1265, 'avg', 8, '0'); convert_element_type_1265 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_181); reduce_scatter_tensor_181 = None + split_128 = torch.ops.aten.split.Tensor(add_158, 1024, 1); add_158 = None + getitem_1207 = split_128[0] + getitem_1208 = split_128[1] + getitem_1209 = split_128[2] + getitem_1210 = split_128[3] + getitem_1211 = split_128[4] + getitem_1212 = split_128[5] + getitem_1213 = split_128[6] + getitem_1214 = split_128[7]; split_128 = None + cat_120 = torch.ops.aten.cat.default([getitem_1207, getitem_1208, getitem_1209, getitem_1210, getitem_1211, getitem_1212, getitem_1213, getitem_1214]); getitem_1207 = getitem_1208 = getitem_1209 = getitem_1210 = getitem_1211 = getitem_1212 = getitem_1213 = getitem_1214 = None + reduce_scatter_tensor_182 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_120, 'sum', 8, '1'); cat_120 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_182); reduce_scatter_tensor_182 = None + convert_element_type_1266 = torch.ops.prims.convert_element_type.default(wait_tensor_416, torch.float32); wait_tensor_416 = None + convert_element_type_1268 = torch.ops.prims.convert_element_type.default(wait_tensor_35, torch.float32); wait_tensor_35 = None + mul_402 = torch.ops.aten.mul.Tensor(convert_element_type_1266, convert_element_type_1268); convert_element_type_1268 = None + mul_404 = torch.ops.aten.mul.Tensor(mul_20, mul_402) + sum_81 = torch.ops.aten.sum.dim_IntList(mul_404, [2], True); mul_404 = None + div_27 = torch.ops.aten.div.Tensor(mul_20, 4096) + mul_405 = torch.ops.aten.mul.Tensor(div_27, sum_81); div_27 = sum_81 = None + sub_42 = torch.ops.aten.sub.Tensor(mul_402, mul_405); mul_402 = mul_405 = None + mul_406 = torch.ops.aten.mul.Tensor(sub_42, rsqrt_5); sub_42 = rsqrt_5 = None + mul_407 = torch.ops.aten.mul.Tensor(convert_element_type_1266, mul_20); convert_element_type_1266 = mul_20 = None + sum_82 = torch.ops.aten.sum.dim_IntList(mul_407, [0, 1]); mul_407 = None + convert_element_type_1269 = torch.ops.prims.convert_element_type.default(mul_406, torch.bfloat16); mul_406 = None + convert_element_type_1270 = torch.ops.prims.convert_element_type.default(sum_82, torch.bfloat16); sum_82 = None + all_reduce_27 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1270, 'sum', '1'); convert_element_type_1270 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_27); all_reduce_27 = None + convert_element_type_1271 = torch.ops.prims.convert_element_type.default(wait_tensor_417, torch.float32); wait_tensor_417 = None + reduce_scatter_tensor_183 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1271, 'avg', 8, '0'); convert_element_type_1271 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_183); reduce_scatter_tensor_183 = None + add_159 = torch.ops.aten.add.Tensor(add_155, convert_element_type_1269); add_155 = convert_element_type_1269 = None + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_159, 8, '1') + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + split_129 = torch.ops.aten.split.Tensor(wait_tensor_419, 2); wait_tensor_419 = None + getitem_1215 = split_129[0] + getitem_1216 = split_129[1] + getitem_1217 = split_129[2] + getitem_1218 = split_129[3] + getitem_1219 = split_129[4] + getitem_1220 = split_129[5] + getitem_1221 = split_129[6] + getitem_1222 = split_129[7]; split_129 = None + cat_121 = torch.ops.aten.cat.default([getitem_1215, getitem_1216, getitem_1217, getitem_1218, getitem_1219, getitem_1220, getitem_1221, getitem_1222], 1); getitem_1215 = getitem_1216 = getitem_1217 = getitem_1218 = getitem_1219 = getitem_1220 = getitem_1221 = getitem_1222 = None + view_1489 = torch.ops.aten.view.default(cat_121, [16384, 4096]); cat_121 = None + permute_609 = torch.ops.aten.permute.default(view_1489, [1, 0]) + permute_28 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_186 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + view_192 = torch.ops.aten.view.default(view_186, [16384, 512]); view_186 = None + mm_303 = torch.ops.aten.mm.default(permute_609, view_192); permute_609 = view_192 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16); primals_26 = None + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 8, '0'); convert_element_type_83 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + permute_611 = torch.ops.aten.permute.default(permute_29, [1, 0]); permute_29 = None + mm_304 = torch.ops.aten.mm.default(view_1489, permute_611); view_1489 = permute_611 = None + view_1490 = torch.ops.aten.view.default(mm_304, [2, 8192, 512]); mm_304 = None + convert_element_type_1276 = torch.ops.prims.convert_element_type.default(mm_303, torch.float32); mm_303 = None + reduce_scatter_tensor_184 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1276, 'avg', 8, '0'); convert_element_type_1276 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_184); reduce_scatter_tensor_184 = None + view_1491 = torch.ops.aten.view.default(view_1490, [2, 8192, 4, 128]); view_1490 = None + permute_613 = torch.ops.aten.permute.default(view_1491, [0, 2, 1, 3]); view_1491 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16); primals_22 = None + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 8, '0'); convert_element_type_67 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32); add_7 = None + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_28) + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_69, 8, '1'); convert_element_type_69 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + split_17 = torch.ops.aten.split.Tensor(wait_tensor_29, 2); wait_tensor_29 = None + getitem_154 = split_17[0] + getitem_155 = split_17[1] + getitem_156 = split_17[2] + getitem_157 = split_17[3] + getitem_158 = split_17[4] + getitem_159 = split_17[5] + getitem_160 = split_17[6] + getitem_161 = split_17[7]; split_17 = None + cat_9 = torch.ops.aten.cat.default([getitem_154, getitem_155, getitem_156, getitem_157, getitem_158, getitem_159, getitem_160, getitem_161], 1); getitem_154 = getitem_155 = getitem_156 = getitem_157 = getitem_158 = getitem_159 = getitem_160 = getitem_161 = None + view_159 = torch.ops.aten.view.default(cat_9, [16384, 4096]); cat_9 = None + view_160 = torch.ops.aten.view.default(mm_14, [2, 8192, 512]); mm_14 = None + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16); primals_24 = None + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 8, '0'); convert_element_type_73 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_15 = torch.ops.aten.mm.default(view_159, permute_23) + view_167 = torch.ops.aten.view.default(mm_15, [2, 8192, 128]); mm_15 = None + view_174 = torch.ops.aten.view.default(mm_16, [2, 8192, 128]); mm_16 = None + view_176 = torch.ops.aten.view.default(view_160, [2, 8192, -1, 128]); view_160 = None + view_177 = torch.ops.aten.view.default(view_167, [2, 8192, -1, 128]); view_167 = None + view_178 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_176, torch.float32); view_176 = None + view_179 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 4, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_177, torch.float32); view_177 = None + view_180 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 1, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_180); view_180 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_37); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_182 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 4, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_37); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_183 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 1, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_182, torch.bfloat16); view_182 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_183, torch.bfloat16); view_183 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 1, 4, 128]); unsqueeze_4 = None + view_184 = torch.ops.aten.view.default(expand_4, [2, 8192, 4, 128]); expand_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_178, 3); view_178 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 1, 4, 128]); unsqueeze_5 = None + view_185 = torch.ops.aten.view.default(expand_5, [2, 8192, 4, 128]); expand_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_184, [0, 2, 1, 3]); view_184 = None + permute_27 = torch.ops.aten.permute.default(view_185, [0, 2, 1, 3]); view_185 = None + _scaled_dot_product_cudnn_attention_backward_13 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_613, permute_25, permute_26, permute_27, getitem_162, getitem_163, getitem_168, getitem_169, None, None, None, 8192, 8192, 0.0, True); permute_613 = permute_25 = permute_26 = permute_27 = getitem_162 = getitem_163 = getitem_168 = getitem_169 = None + getitem_1223 = _scaled_dot_product_cudnn_attention_backward_13[0] + getitem_1224 = _scaled_dot_product_cudnn_attention_backward_13[1] + getitem_1225 = _scaled_dot_product_cudnn_attention_backward_13[2]; _scaled_dot_product_cudnn_attention_backward_13 = None + permute_614 = torch.ops.aten.permute.default(getitem_1225, [0, 2, 1, 3]); getitem_1225 = None + permute_615 = torch.ops.aten.permute.default(getitem_1224, [0, 2, 1, 3]); getitem_1224 = None + permute_616 = torch.ops.aten.permute.default(getitem_1223, [0, 2, 1, 3]); getitem_1223 = None + view_1492 = torch.ops.aten.view.default(permute_614, [2, 8192, 1, 4, 128]); permute_614 = None + sum_83 = torch.ops.aten.sum.dim_IntList(view_1492, [3], True); view_1492 = None + squeeze_26 = torch.ops.aten.squeeze.dim(sum_83, 3); sum_83 = None + view_1493 = torch.ops.aten.view.default(permute_615, [2, 8192, 1, 4, 128]); permute_615 = None + sum_84 = torch.ops.aten.sum.dim_IntList(view_1493, [3], True); view_1493 = None + squeeze_27 = torch.ops.aten.squeeze.dim(sum_84, 3); sum_84 = None + convert_element_type_1277 = torch.ops.prims.convert_element_type.default(squeeze_27, torch.float32); squeeze_27 = None + convert_element_type_1278 = torch.ops.prims.convert_element_type.default(permute_616, torch.float32); permute_616 = None + view_1494 = torch.ops.aten.view.default(convert_element_type_1277, [2, 8192, 1, 64, 2]); convert_element_type_1277 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1494); view_1494 = None + mul_408 = torch.ops.aten.mul.Tensor(view_as_complex_58, _conj); view_as_complex_58 = None + view_1495 = torch.ops.aten.view.default(convert_element_type_1278, [2, 8192, 4, 64, 2]); convert_element_type_1278 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1495); view_1495 = None + mul_409 = torch.ops.aten.mul.Tensor(view_as_complex_59, _conj); view_as_complex_59 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_408); mul_408 = None + view_1496 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 1, 128]); view_as_real_58 = None + convert_element_type_1279 = torch.ops.prims.convert_element_type.default(view_1496, torch.bfloat16); view_1496 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_409); mul_409 = None + view_1497 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 4, 128]); view_as_real_59 = None + convert_element_type_1280 = torch.ops.prims.convert_element_type.default(view_1497, torch.bfloat16); view_1497 = None + view_1498 = torch.ops.aten.view.default(squeeze_26, [2, 8192, 128]); squeeze_26 = None + view_1499 = torch.ops.aten.view.default(convert_element_type_1279, [2, 8192, 128]); convert_element_type_1279 = None + view_1500 = torch.ops.aten.view.default(convert_element_type_1280, [2, 8192, 512]); convert_element_type_1280 = None + view_1501 = torch.ops.aten.view.default(view_1498, [16384, 128]); view_1498 = None + permute_617 = torch.ops.aten.permute.default(view_1501, [1, 0]) + mm_305 = torch.ops.aten.mm.default(permute_617, view_159); permute_617 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16); primals_25 = None + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 8, '0'); convert_element_type_76 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + permute_619 = torch.ops.aten.permute.default(permute_24, [1, 0]); permute_24 = None + mm_306 = torch.ops.aten.mm.default(view_1501, permute_619); view_1501 = permute_619 = None + view_1502 = torch.ops.aten.view.default(mm_306, [2, 8192, 4096]); mm_306 = None + convert_element_type_1285 = torch.ops.prims.convert_element_type.default(mm_305, torch.float32); mm_305 = None + reduce_scatter_tensor_185 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1285, 'avg', 8, '0'); convert_element_type_1285 = None + wait_tensor_421 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_185); reduce_scatter_tensor_185 = None + view_1503 = torch.ops.aten.view.default(view_1499, [16384, 128]); view_1499 = None + permute_621 = torch.ops.aten.permute.default(view_1503, [1, 0]) + mm_307 = torch.ops.aten.mm.default(permute_621, view_159); permute_621 = None + permute_623 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None + mm_308 = torch.ops.aten.mm.default(view_1503, permute_623); view_1503 = permute_623 = None + view_1504 = torch.ops.aten.view.default(mm_308, [2, 8192, 4096]); mm_308 = None + add_160 = torch.ops.aten.add.Tensor(view_1502, view_1504); view_1502 = view_1504 = None + convert_element_type_1290 = torch.ops.prims.convert_element_type.default(mm_307, torch.float32); mm_307 = None + reduce_scatter_tensor_186 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1290, 'avg', 8, '0'); convert_element_type_1290 = None + wait_tensor_422 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_186); reduce_scatter_tensor_186 = None + view_1505 = torch.ops.aten.view.default(view_1500, [16384, 512]); view_1500 = None + permute_625 = torch.ops.aten.permute.default(view_1505, [1, 0]) + mm_309 = torch.ops.aten.mm.default(permute_625, view_159); permute_625 = view_159 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16); primals_23 = None + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 8, '0'); convert_element_type_70 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + permute_627 = torch.ops.aten.permute.default(permute_22, [1, 0]); permute_22 = None + mm_310 = torch.ops.aten.mm.default(view_1505, permute_627); view_1505 = permute_627 = None + view_1506 = torch.ops.aten.view.default(mm_310, [2, 8192, 4096]); mm_310 = None + add_161 = torch.ops.aten.add.Tensor(add_160, view_1506); add_160 = view_1506 = None + convert_element_type_1295 = torch.ops.prims.convert_element_type.default(mm_309, torch.float32); mm_309 = None + reduce_scatter_tensor_187 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1295, 'avg', 8, '0'); convert_element_type_1295 = None + wait_tensor_423 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_187); reduce_scatter_tensor_187 = None + split_130 = torch.ops.aten.split.Tensor(add_161, 1024, 1); add_161 = None + getitem_1226 = split_130[0] + getitem_1227 = split_130[1] + getitem_1228 = split_130[2] + getitem_1229 = split_130[3] + getitem_1230 = split_130[4] + getitem_1231 = split_130[5] + getitem_1232 = split_130[6] + getitem_1233 = split_130[7]; split_130 = None + cat_122 = torch.ops.aten.cat.default([getitem_1226, getitem_1227, getitem_1228, getitem_1229, getitem_1230, getitem_1231, getitem_1232, getitem_1233]); getitem_1226 = getitem_1227 = getitem_1228 = getitem_1229 = getitem_1230 = getitem_1231 = getitem_1232 = getitem_1233 = None + reduce_scatter_tensor_188 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_122, 'sum', 8, '1'); cat_122 = None + wait_tensor_424 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_188); reduce_scatter_tensor_188 = None + convert_element_type_1296 = torch.ops.prims.convert_element_type.default(wait_tensor_424, torch.float32); wait_tensor_424 = None + convert_element_type_1298 = torch.ops.prims.convert_element_type.default(wait_tensor_28, torch.float32); wait_tensor_28 = None + mul_410 = torch.ops.aten.mul.Tensor(convert_element_type_1296, convert_element_type_1298); convert_element_type_1298 = None + mul_412 = torch.ops.aten.mul.Tensor(mul_16, mul_410) + sum_85 = torch.ops.aten.sum.dim_IntList(mul_412, [2], True); mul_412 = None + div_28 = torch.ops.aten.div.Tensor(mul_16, 4096) + mul_413 = torch.ops.aten.mul.Tensor(div_28, sum_85); div_28 = sum_85 = None + sub_43 = torch.ops.aten.sub.Tensor(mul_410, mul_413); mul_410 = mul_413 = None + mul_414 = torch.ops.aten.mul.Tensor(sub_43, rsqrt_4); sub_43 = rsqrt_4 = None + mul_415 = torch.ops.aten.mul.Tensor(convert_element_type_1296, mul_16); convert_element_type_1296 = mul_16 = None + sum_86 = torch.ops.aten.sum.dim_IntList(mul_415, [0, 1]); mul_415 = None + convert_element_type_1299 = torch.ops.prims.convert_element_type.default(mul_414, torch.bfloat16); mul_414 = None + convert_element_type_1300 = torch.ops.prims.convert_element_type.default(sum_86, torch.bfloat16); sum_86 = None + all_reduce_28 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1300, 'sum', '1'); convert_element_type_1300 = None + wait_tensor_425 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_28); all_reduce_28 = None + convert_element_type_1301 = torch.ops.prims.convert_element_type.default(wait_tensor_425, torch.float32); wait_tensor_425 = None + reduce_scatter_tensor_189 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1301, 'avg', 8, '0'); convert_element_type_1301 = None + wait_tensor_426 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_189); reduce_scatter_tensor_189 = None + add_162 = torch.ops.aten.add.Tensor(add_159, convert_element_type_1299); add_159 = convert_element_type_1299 = None + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_162, 8, '1') + wait_tensor_427 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + split_131 = torch.ops.aten.split.Tensor(wait_tensor_427, 2); wait_tensor_427 = None + getitem_1234 = split_131[0] + getitem_1235 = split_131[1] + getitem_1236 = split_131[2] + getitem_1237 = split_131[3] + getitem_1238 = split_131[4] + getitem_1239 = split_131[5] + getitem_1240 = split_131[6] + getitem_1241 = split_131[7]; split_131 = None + cat_123 = torch.ops.aten.cat.default([getitem_1234, getitem_1235, getitem_1236, getitem_1237, getitem_1238, getitem_1239, getitem_1240, getitem_1241], 1); getitem_1234 = getitem_1235 = getitem_1236 = getitem_1237 = getitem_1238 = getitem_1239 = getitem_1240 = getitem_1241 = None + view_1507 = torch.ops.aten.view.default(cat_123, [16384, 4096]); cat_123 = None + permute_629 = torch.ops.aten.permute.default(view_1507, [1, 0]) + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3); reduce_scatter_tensor_3 = None + add_5 = torch.ops.aten.add.Tensor(add_3, wait_tensor_21); wait_tensor_21 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16); primals_18 = None + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 8, '0'); convert_element_type_53 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32); add_5 = None + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_22) + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_55, 8, '1'); convert_element_type_55 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + split_15 = torch.ops.aten.split.Tensor(wait_tensor_23, 2); wait_tensor_23 = None + getitem_138 = split_15[0] + getitem_139 = split_15[1] + getitem_140 = split_15[2] + getitem_141 = split_15[3] + getitem_142 = split_15[4] + getitem_143 = split_15[5] + getitem_144 = split_15[6] + getitem_145 = split_15[7]; split_15 = None + cat_7 = torch.ops.aten.cat.default([getitem_138, getitem_139, getitem_140, getitem_141, getitem_142, getitem_143, getitem_144, getitem_145], 1); getitem_138 = getitem_139 = getitem_140 = getitem_141 = getitem_142 = getitem_143 = getitem_144 = getitem_145 = None + view_132 = torch.ops.aten.view.default(cat_7, [16384, 4096]); cat_7 = None + view_133 = torch.ops.aten.view.default(mm_11, [2, 8192, 1792]); mm_11 = None + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_133, torch.float32); view_133 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16); primals_20 = None + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 8, '0'); convert_element_type_61 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + mm_12 = torch.ops.aten.mm.default(view_132, permute_20) + view_140 = torch.ops.aten.view.default(mm_12, [2, 8192, 1792]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_140) + view_147 = torch.ops.aten.view.default(mul_15, [16384, 1792]); mul_15 = None + mm_311 = torch.ops.aten.mm.default(permute_629, view_147); permute_629 = view_147 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16); primals_21 = None + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 8, '0'); convert_element_type_64 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + permute_631 = torch.ops.aten.permute.default(permute_21, [1, 0]); permute_21 = None + mm_312 = torch.ops.aten.mm.default(view_1507, permute_631); view_1507 = permute_631 = None + view_1508 = torch.ops.aten.view.default(mm_312, [2, 8192, 1792]); mm_312 = None + convert_element_type_1306 = torch.ops.prims.convert_element_type.default(mm_311, torch.float32); mm_311 = None + reduce_scatter_tensor_190 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1306, 'avg', 8, '0'); convert_element_type_1306 = None + wait_tensor_428 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_190); reduce_scatter_tensor_190 = None + mul_416 = torch.ops.aten.mul.Tensor(view_1508, convert_element_type_60); convert_element_type_60 = None + mul_417 = torch.ops.aten.mul.Tensor(view_1508, view_140); view_1508 = view_140 = None + view_1509 = torch.ops.aten.view.default(mul_416, [16384, 1792]); mul_416 = None + permute_633 = torch.ops.aten.permute.default(view_1509, [1, 0]) + mm_313 = torch.ops.aten.mm.default(permute_633, view_132); permute_633 = None + permute_635 = torch.ops.aten.permute.default(permute_20, [1, 0]); permute_20 = None + mm_314 = torch.ops.aten.mm.default(view_1509, permute_635); view_1509 = permute_635 = None + view_1510 = torch.ops.aten.view.default(mm_314, [2, 8192, 4096]); mm_314 = None + convert_element_type_1311 = torch.ops.prims.convert_element_type.default(mm_313, torch.float32); mm_313 = None + reduce_scatter_tensor_191 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1311, 'avg', 8, '0'); convert_element_type_1311 = None + wait_tensor_429 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_191); reduce_scatter_tensor_191 = None + convert_element_type_1312 = torch.ops.prims.convert_element_type.default(mul_417, torch.float32); mul_417 = None + neg_14 = torch.ops.aten.neg.default(convert_element_type_59) + exp_14 = torch.ops.aten.exp.default(neg_14); neg_14 = None + add_163 = torch.ops.aten.add.Tensor(exp_14, 1); exp_14 = None + reciprocal_14 = torch.ops.aten.reciprocal.default(add_163); add_163 = None + mul_418 = torch.ops.aten.mul.Tensor(reciprocal_14, 1); reciprocal_14 = None + mul_419 = torch.ops.aten.mul.Tensor(convert_element_type_1312, mul_418); convert_element_type_1312 = None + sub_44 = torch.ops.aten.sub.Tensor(1, mul_418); mul_418 = None + mul_420 = torch.ops.aten.mul.Tensor(convert_element_type_59, sub_44); convert_element_type_59 = sub_44 = None + add_164 = torch.ops.aten.add.Tensor(mul_420, 1); mul_420 = None + mul_421 = torch.ops.aten.mul.Tensor(mul_419, add_164); mul_419 = add_164 = None + convert_element_type_1314 = torch.ops.prims.convert_element_type.default(mul_421, torch.bfloat16); mul_421 = None + view_1511 = torch.ops.aten.view.default(convert_element_type_1314, [16384, 1792]); convert_element_type_1314 = None + permute_637 = torch.ops.aten.permute.default(view_1511, [1, 0]) + mm_315 = torch.ops.aten.mm.default(permute_637, view_132); permute_637 = view_132 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16); primals_19 = None + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 8, '0'); convert_element_type_56 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_24, [1, 0]); wait_tensor_24 = None + permute_639 = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None + mm_316 = torch.ops.aten.mm.default(view_1511, permute_639); view_1511 = permute_639 = None + view_1512 = torch.ops.aten.view.default(mm_316, [2, 8192, 4096]); mm_316 = None + add_165 = torch.ops.aten.add.Tensor(view_1510, view_1512); view_1510 = view_1512 = None + convert_element_type_1319 = torch.ops.prims.convert_element_type.default(mm_315, torch.float32); mm_315 = None + reduce_scatter_tensor_192 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1319, 'avg', 8, '0'); convert_element_type_1319 = None + wait_tensor_430 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_192); reduce_scatter_tensor_192 = None + split_132 = torch.ops.aten.split.Tensor(add_165, 1024, 1); add_165 = None + getitem_1242 = split_132[0] + getitem_1243 = split_132[1] + getitem_1244 = split_132[2] + getitem_1245 = split_132[3] + getitem_1246 = split_132[4] + getitem_1247 = split_132[5] + getitem_1248 = split_132[6] + getitem_1249 = split_132[7]; split_132 = None + cat_124 = torch.ops.aten.cat.default([getitem_1242, getitem_1243, getitem_1244, getitem_1245, getitem_1246, getitem_1247, getitem_1248, getitem_1249]); getitem_1242 = getitem_1243 = getitem_1244 = getitem_1245 = getitem_1246 = getitem_1247 = getitem_1248 = getitem_1249 = None + reduce_scatter_tensor_193 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_124, 'sum', 8, '1'); cat_124 = None + wait_tensor_431 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_193); reduce_scatter_tensor_193 = None + convert_element_type_1320 = torch.ops.prims.convert_element_type.default(wait_tensor_431, torch.float32); wait_tensor_431 = None + convert_element_type_1322 = torch.ops.prims.convert_element_type.default(wait_tensor_22, torch.float32); wait_tensor_22 = None + mul_422 = torch.ops.aten.mul.Tensor(convert_element_type_1320, convert_element_type_1322); convert_element_type_1322 = None + mul_424 = torch.ops.aten.mul.Tensor(mul_12, mul_422) + sum_87 = torch.ops.aten.sum.dim_IntList(mul_424, [2], True); mul_424 = None + div_29 = torch.ops.aten.div.Tensor(mul_12, 4096) + mul_425 = torch.ops.aten.mul.Tensor(div_29, sum_87); div_29 = sum_87 = None + sub_45 = torch.ops.aten.sub.Tensor(mul_422, mul_425); mul_422 = mul_425 = None + mul_426 = torch.ops.aten.mul.Tensor(sub_45, rsqrt_3); sub_45 = rsqrt_3 = None + mul_427 = torch.ops.aten.mul.Tensor(convert_element_type_1320, mul_12); convert_element_type_1320 = mul_12 = None + sum_88 = torch.ops.aten.sum.dim_IntList(mul_427, [0, 1]); mul_427 = None + convert_element_type_1323 = torch.ops.prims.convert_element_type.default(mul_426, torch.bfloat16); mul_426 = None + convert_element_type_1324 = torch.ops.prims.convert_element_type.default(sum_88, torch.bfloat16); sum_88 = None + all_reduce_29 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1324, 'sum', '1'); convert_element_type_1324 = None + wait_tensor_432 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_29); all_reduce_29 = None + convert_element_type_1325 = torch.ops.prims.convert_element_type.default(wait_tensor_432, torch.float32); wait_tensor_432 = None + reduce_scatter_tensor_194 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1325, 'avg', 8, '0'); convert_element_type_1325 = None + wait_tensor_433 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_194); reduce_scatter_tensor_194 = None + add_166 = torch.ops.aten.add.Tensor(add_162, convert_element_type_1323); add_162 = convert_element_type_1323 = None + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_166, 8, '1') + wait_tensor_434 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + split_133 = torch.ops.aten.split.Tensor(wait_tensor_434, 2); wait_tensor_434 = None + getitem_1250 = split_133[0] + getitem_1251 = split_133[1] + getitem_1252 = split_133[2] + getitem_1253 = split_133[3] + getitem_1254 = split_133[4] + getitem_1255 = split_133[5] + getitem_1256 = split_133[6] + getitem_1257 = split_133[7]; split_133 = None + cat_125 = torch.ops.aten.cat.default([getitem_1250, getitem_1251, getitem_1252, getitem_1253, getitem_1254, getitem_1255, getitem_1256, getitem_1257], 1); getitem_1250 = getitem_1251 = getitem_1252 = getitem_1253 = getitem_1254 = getitem_1255 = getitem_1256 = getitem_1257 = None + view_1513 = torch.ops.aten.view.default(cat_125, [16384, 4096]); cat_125 = None + permute_641 = torch.ops.aten.permute.default(view_1513, [1, 0]) + permute_17 = torch.ops.aten.permute.default(getitem_121, [0, 2, 1, 3]) + view_114 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + view_120 = torch.ops.aten.view.default(view_114, [16384, 512]); view_114 = None + mm_317 = torch.ops.aten.mm.default(permute_641, view_120); permute_641 = view_120 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 8, '0'); convert_element_type_50 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + permute_643 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None + mm_318 = torch.ops.aten.mm.default(view_1513, permute_643); view_1513 = permute_643 = None + view_1514 = torch.ops.aten.view.default(mm_318, [2, 8192, 512]); mm_318 = None + convert_element_type_1330 = torch.ops.prims.convert_element_type.default(mm_317, torch.float32); mm_317 = None + reduce_scatter_tensor_195 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1330, 'avg', 8, '0'); convert_element_type_1330 = None + wait_tensor_435 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_195); reduce_scatter_tensor_195 = None + view_1515 = torch.ops.aten.view.default(view_1514, [2, 8192, 4, 128]); view_1514 = None + permute_645 = torch.ops.aten.permute.default(view_1515, [0, 2, 1, 3]); view_1515 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); primals_13 = None + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 8, '0'); convert_element_type_34 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32); add_3 = None + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_15) + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_36, 8, '1'); convert_element_type_36 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + split_13 = torch.ops.aten.split.Tensor(wait_tensor_16, 2); wait_tensor_16 = None + getitem_113 = split_13[0] + getitem_114 = split_13[1] + getitem_115 = split_13[2] + getitem_116 = split_13[3] + getitem_117 = split_13[4] + getitem_118 = split_13[5] + getitem_119 = split_13[6] + getitem_120 = split_13[7]; split_13 = None + cat_5 = torch.ops.aten.cat.default([getitem_113, getitem_114, getitem_115, getitem_116, getitem_117, getitem_118, getitem_119, getitem_120], 1); getitem_113 = getitem_114 = getitem_115 = getitem_116 = getitem_117 = getitem_118 = getitem_119 = getitem_120 = None + view_87 = torch.ops.aten.view.default(cat_5, [16384, 4096]); cat_5 = None + view_88 = torch.ops.aten.view.default(mm_7, [2, 8192, 512]); mm_7 = None + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 8, '0'); convert_element_type_40 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + mm_8 = torch.ops.aten.mm.default(view_87, permute_12) + view_95 = torch.ops.aten.view.default(mm_8, [2, 8192, 128]); mm_8 = None + view_102 = torch.ops.aten.view.default(mm_9, [2, 8192, 128]); mm_9 = None + view_104 = torch.ops.aten.view.default(view_88, [2, 8192, -1, 128]); view_88 = None + view_105 = torch.ops.aten.view.default(view_95, [2, 8192, -1, 128]); view_95 = None + view_106 = torch.ops.aten.view.default(view_102, [2, 8192, -1, 128]); view_102 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_104, torch.float32); view_104 = None + view_107 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 4, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_107); view_107 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_105, torch.float32); view_105 = None + view_108 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 1, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_108); view_108 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_37); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_110 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 4, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_37); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_111 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 1, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_110, torch.bfloat16); view_110 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_111, torch.bfloat16); view_111 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 1, 4, 128]); unsqueeze_2 = None + view_112 = torch.ops.aten.view.default(expand_2, [2, 8192, 4, 128]); expand_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_106, 3); view_106 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 1, 4, 128]); unsqueeze_3 = None + view_113 = torch.ops.aten.view.default(expand_3, [2, 8192, 4, 128]); expand_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None + permute_16 = torch.ops.aten.permute.default(view_113, [0, 2, 1, 3]); view_113 = None + _scaled_dot_product_cudnn_attention_backward_14 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_645, permute_14, permute_15, permute_16, getitem_121, getitem_122, getitem_127, getitem_128, None, None, None, 8192, 8192, 0.0, True); permute_645 = permute_14 = permute_15 = permute_16 = getitem_121 = getitem_122 = getitem_127 = getitem_128 = None + getitem_1258 = _scaled_dot_product_cudnn_attention_backward_14[0] + getitem_1259 = _scaled_dot_product_cudnn_attention_backward_14[1] + getitem_1260 = _scaled_dot_product_cudnn_attention_backward_14[2]; _scaled_dot_product_cudnn_attention_backward_14 = None + permute_646 = torch.ops.aten.permute.default(getitem_1260, [0, 2, 1, 3]); getitem_1260 = None + permute_647 = torch.ops.aten.permute.default(getitem_1259, [0, 2, 1, 3]); getitem_1259 = None + permute_648 = torch.ops.aten.permute.default(getitem_1258, [0, 2, 1, 3]); getitem_1258 = None + view_1516 = torch.ops.aten.view.default(permute_646, [2, 8192, 1, 4, 128]); permute_646 = None + sum_89 = torch.ops.aten.sum.dim_IntList(view_1516, [3], True); view_1516 = None + squeeze_28 = torch.ops.aten.squeeze.dim(sum_89, 3); sum_89 = None + view_1517 = torch.ops.aten.view.default(permute_647, [2, 8192, 1, 4, 128]); permute_647 = None + sum_90 = torch.ops.aten.sum.dim_IntList(view_1517, [3], True); view_1517 = None + squeeze_29 = torch.ops.aten.squeeze.dim(sum_90, 3); sum_90 = None + convert_element_type_1331 = torch.ops.prims.convert_element_type.default(squeeze_29, torch.float32); squeeze_29 = None + convert_element_type_1332 = torch.ops.prims.convert_element_type.default(permute_648, torch.float32); permute_648 = None + view_1518 = torch.ops.aten.view.default(convert_element_type_1331, [2, 8192, 1, 64, 2]); convert_element_type_1331 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1518); view_1518 = None + mul_428 = torch.ops.aten.mul.Tensor(view_as_complex_60, _conj); view_as_complex_60 = None + view_1519 = torch.ops.aten.view.default(convert_element_type_1332, [2, 8192, 4, 64, 2]); convert_element_type_1332 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1519); view_1519 = None + mul_429 = torch.ops.aten.mul.Tensor(view_as_complex_61, _conj); view_as_complex_61 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_428); mul_428 = None + view_1520 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 1, 128]); view_as_real_60 = None + convert_element_type_1333 = torch.ops.prims.convert_element_type.default(view_1520, torch.bfloat16); view_1520 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_429); mul_429 = None + view_1521 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 4, 128]); view_as_real_61 = None + convert_element_type_1334 = torch.ops.prims.convert_element_type.default(view_1521, torch.bfloat16); view_1521 = None + view_1522 = torch.ops.aten.view.default(squeeze_28, [2, 8192, 128]); squeeze_28 = None + view_1523 = torch.ops.aten.view.default(convert_element_type_1333, [2, 8192, 128]); convert_element_type_1333 = None + view_1524 = torch.ops.aten.view.default(convert_element_type_1334, [2, 8192, 512]); convert_element_type_1334 = None + view_1525 = torch.ops.aten.view.default(view_1522, [16384, 128]); view_1522 = None + permute_649 = torch.ops.aten.permute.default(view_1525, [1, 0]) + mm_319 = torch.ops.aten.mm.default(permute_649, view_87); permute_649 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 8, '0'); convert_element_type_43 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_19, [1, 0]); wait_tensor_19 = None + permute_651 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None + mm_320 = torch.ops.aten.mm.default(view_1525, permute_651); view_1525 = permute_651 = None + view_1526 = torch.ops.aten.view.default(mm_320, [2, 8192, 4096]); mm_320 = None + convert_element_type_1339 = torch.ops.prims.convert_element_type.default(mm_319, torch.float32); mm_319 = None + reduce_scatter_tensor_196 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1339, 'avg', 8, '0'); convert_element_type_1339 = None + wait_tensor_436 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_196); reduce_scatter_tensor_196 = None + view_1527 = torch.ops.aten.view.default(view_1523, [16384, 128]); view_1523 = None + permute_653 = torch.ops.aten.permute.default(view_1527, [1, 0]) + mm_321 = torch.ops.aten.mm.default(permute_653, view_87); permute_653 = None + permute_655 = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None + mm_322 = torch.ops.aten.mm.default(view_1527, permute_655); view_1527 = permute_655 = None + view_1528 = torch.ops.aten.view.default(mm_322, [2, 8192, 4096]); mm_322 = None + add_167 = torch.ops.aten.add.Tensor(view_1526, view_1528); view_1526 = view_1528 = None + convert_element_type_1344 = torch.ops.prims.convert_element_type.default(mm_321, torch.float32); mm_321 = None + reduce_scatter_tensor_197 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1344, 'avg', 8, '0'); convert_element_type_1344 = None + wait_tensor_437 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_197); reduce_scatter_tensor_197 = None + view_1529 = torch.ops.aten.view.default(view_1524, [16384, 512]); view_1524 = None + permute_657 = torch.ops.aten.permute.default(view_1529, [1, 0]) + mm_323 = torch.ops.aten.mm.default(permute_657, view_87); permute_657 = view_87 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); primals_14 = None + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 8, '0'); convert_element_type_37 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + permute_659 = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None + mm_324 = torch.ops.aten.mm.default(view_1529, permute_659); view_1529 = permute_659 = None + view_1530 = torch.ops.aten.view.default(mm_324, [2, 8192, 4096]); mm_324 = None + add_168 = torch.ops.aten.add.Tensor(add_167, view_1530); add_167 = view_1530 = None + convert_element_type_1349 = torch.ops.prims.convert_element_type.default(mm_323, torch.float32); mm_323 = None + reduce_scatter_tensor_198 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1349, 'avg', 8, '0'); convert_element_type_1349 = None + wait_tensor_438 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_198); reduce_scatter_tensor_198 = None + split_134 = torch.ops.aten.split.Tensor(add_168, 1024, 1); add_168 = None + getitem_1261 = split_134[0] + getitem_1262 = split_134[1] + getitem_1263 = split_134[2] + getitem_1264 = split_134[3] + getitem_1265 = split_134[4] + getitem_1266 = split_134[5] + getitem_1267 = split_134[6] + getitem_1268 = split_134[7]; split_134 = None + cat_126 = torch.ops.aten.cat.default([getitem_1261, getitem_1262, getitem_1263, getitem_1264, getitem_1265, getitem_1266, getitem_1267, getitem_1268]); getitem_1261 = getitem_1262 = getitem_1263 = getitem_1264 = getitem_1265 = getitem_1266 = getitem_1267 = getitem_1268 = None + reduce_scatter_tensor_199 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_126, 'sum', 8, '1'); cat_126 = None + wait_tensor_439 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_199); reduce_scatter_tensor_199 = None + convert_element_type_1350 = torch.ops.prims.convert_element_type.default(wait_tensor_439, torch.float32); wait_tensor_439 = None + convert_element_type_1352 = torch.ops.prims.convert_element_type.default(wait_tensor_15, torch.float32); wait_tensor_15 = None + mul_430 = torch.ops.aten.mul.Tensor(convert_element_type_1350, convert_element_type_1352); convert_element_type_1352 = None + mul_432 = torch.ops.aten.mul.Tensor(mul_8, mul_430) + sum_91 = torch.ops.aten.sum.dim_IntList(mul_432, [2], True); mul_432 = None + div_30 = torch.ops.aten.div.Tensor(mul_8, 4096) + mul_433 = torch.ops.aten.mul.Tensor(div_30, sum_91); div_30 = sum_91 = None + sub_46 = torch.ops.aten.sub.Tensor(mul_430, mul_433); mul_430 = mul_433 = None + mul_434 = torch.ops.aten.mul.Tensor(sub_46, rsqrt_2); sub_46 = rsqrt_2 = None + mul_435 = torch.ops.aten.mul.Tensor(convert_element_type_1350, mul_8); convert_element_type_1350 = mul_8 = None + sum_92 = torch.ops.aten.sum.dim_IntList(mul_435, [0, 1]); mul_435 = None + convert_element_type_1353 = torch.ops.prims.convert_element_type.default(mul_434, torch.bfloat16); mul_434 = None + convert_element_type_1354 = torch.ops.prims.convert_element_type.default(sum_92, torch.bfloat16); sum_92 = None + all_reduce_30 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1354, 'sum', '1'); convert_element_type_1354 = None + wait_tensor_440 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_30); all_reduce_30 = None + convert_element_type_1355 = torch.ops.prims.convert_element_type.default(wait_tensor_440, torch.float32); wait_tensor_440 = None + reduce_scatter_tensor_200 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1355, 'avg', 8, '0'); convert_element_type_1355 = None + wait_tensor_441 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_200); reduce_scatter_tensor_200 = None + add_169 = torch.ops.aten.add.Tensor(add_166, convert_element_type_1353); add_166 = convert_element_type_1353 = None + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_169, 8, '1') + wait_tensor_442 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + split_135 = torch.ops.aten.split.Tensor(wait_tensor_442, 2); wait_tensor_442 = None + getitem_1269 = split_135[0] + getitem_1270 = split_135[1] + getitem_1271 = split_135[2] + getitem_1272 = split_135[3] + getitem_1273 = split_135[4] + getitem_1274 = split_135[5] + getitem_1275 = split_135[6] + getitem_1276 = split_135[7]; split_135 = None + cat_127 = torch.ops.aten.cat.default([getitem_1269, getitem_1270, getitem_1271, getitem_1272, getitem_1273, getitem_1274, getitem_1275, getitem_1276], 1); getitem_1269 = getitem_1270 = getitem_1271 = getitem_1272 = getitem_1273 = getitem_1274 = getitem_1275 = getitem_1276 = None + view_1531 = torch.ops.aten.view.default(cat_127, [16384, 4096]); cat_127 = None + permute_661 = torch.ops.aten.permute.default(view_1531, [1, 0]) + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None + add_1 = torch.ops.aten.add.Tensor(wait_tensor_1, wait_tensor_8); wait_tensor_8 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = None + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 8, '0'); convert_element_type_20 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32); add_1 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_9) + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_22, 8, '1'); convert_element_type_22 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + split_11 = torch.ops.aten.split.Tensor(wait_tensor_10, 2); wait_tensor_10 = None + getitem_97 = split_11[0] + getitem_98 = split_11[1] + getitem_99 = split_11[2] + getitem_100 = split_11[3] + getitem_101 = split_11[4] + getitem_102 = split_11[5] + getitem_103 = split_11[6] + getitem_104 = split_11[7]; split_11 = None + cat_3 = torch.ops.aten.cat.default([getitem_97, getitem_98, getitem_99, getitem_100, getitem_101, getitem_102, getitem_103, getitem_104], 1); getitem_97 = getitem_98 = getitem_99 = getitem_100 = getitem_101 = getitem_102 = getitem_103 = getitem_104 = None + view_60 = torch.ops.aten.view.default(cat_3, [16384, 4096]); cat_3 = None + view_61 = torch.ops.aten.view.default(mm_4, [2, 8192, 1792]); mm_4 = None + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_61, torch.float32); view_61 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 8, '0'); convert_element_type_28 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_5 = torch.ops.aten.mm.default(view_60, permute_9) + view_68 = torch.ops.aten.view.default(mm_5, [2, 8192, 1792]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_68) + view_75 = torch.ops.aten.view.default(mul_7, [16384, 1792]); mul_7 = None + mm_325 = torch.ops.aten.mm.default(permute_661, view_75); permute_661 = view_75 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 8, '0'); convert_element_type_31 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + permute_663 = torch.ops.aten.permute.default(permute_10, [1, 0]); permute_10 = None + mm_326 = torch.ops.aten.mm.default(view_1531, permute_663); view_1531 = permute_663 = None + view_1532 = torch.ops.aten.view.default(mm_326, [2, 8192, 1792]); mm_326 = None + convert_element_type_1360 = torch.ops.prims.convert_element_type.default(mm_325, torch.float32); mm_325 = None + reduce_scatter_tensor_201 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1360, 'avg', 8, '0'); convert_element_type_1360 = None + wait_tensor_443 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_201); reduce_scatter_tensor_201 = None + mul_436 = torch.ops.aten.mul.Tensor(view_1532, convert_element_type_27); convert_element_type_27 = None + mul_437 = torch.ops.aten.mul.Tensor(view_1532, view_68); view_1532 = view_68 = None + view_1533 = torch.ops.aten.view.default(mul_436, [16384, 1792]); mul_436 = None + permute_665 = torch.ops.aten.permute.default(view_1533, [1, 0]) + mm_327 = torch.ops.aten.mm.default(permute_665, view_60); permute_665 = None + permute_667 = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None + mm_328 = torch.ops.aten.mm.default(view_1533, permute_667); view_1533 = permute_667 = None + view_1534 = torch.ops.aten.view.default(mm_328, [2, 8192, 4096]); mm_328 = None + convert_element_type_1365 = torch.ops.prims.convert_element_type.default(mm_327, torch.float32); mm_327 = None + reduce_scatter_tensor_202 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1365, 'avg', 8, '0'); convert_element_type_1365 = None + wait_tensor_444 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_202); reduce_scatter_tensor_202 = None + convert_element_type_1366 = torch.ops.prims.convert_element_type.default(mul_437, torch.float32); mul_437 = None + neg_15 = torch.ops.aten.neg.default(convert_element_type_26) + exp_15 = torch.ops.aten.exp.default(neg_15); neg_15 = None + add_170 = torch.ops.aten.add.Tensor(exp_15, 1); exp_15 = None + reciprocal_15 = torch.ops.aten.reciprocal.default(add_170); add_170 = None + mul_438 = torch.ops.aten.mul.Tensor(reciprocal_15, 1); reciprocal_15 = None + mul_439 = torch.ops.aten.mul.Tensor(convert_element_type_1366, mul_438); convert_element_type_1366 = None + sub_47 = torch.ops.aten.sub.Tensor(1, mul_438); mul_438 = None + mul_440 = torch.ops.aten.mul.Tensor(convert_element_type_26, sub_47); convert_element_type_26 = sub_47 = None + add_171 = torch.ops.aten.add.Tensor(mul_440, 1); mul_440 = None + mul_441 = torch.ops.aten.mul.Tensor(mul_439, add_171); mul_439 = add_171 = None + convert_element_type_1368 = torch.ops.prims.convert_element_type.default(mul_441, torch.bfloat16); mul_441 = None + view_1535 = torch.ops.aten.view.default(convert_element_type_1368, [16384, 1792]); convert_element_type_1368 = None + permute_669 = torch.ops.aten.permute.default(view_1535, [1, 0]) + mm_329 = torch.ops.aten.mm.default(permute_669, view_60); permute_669 = view_60 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 8, '0'); convert_element_type_23 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + permute_671 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None + mm_330 = torch.ops.aten.mm.default(view_1535, permute_671); view_1535 = permute_671 = None + view_1536 = torch.ops.aten.view.default(mm_330, [2, 8192, 4096]); mm_330 = None + add_172 = torch.ops.aten.add.Tensor(view_1534, view_1536); view_1534 = view_1536 = None + convert_element_type_1373 = torch.ops.prims.convert_element_type.default(mm_329, torch.float32); mm_329 = None + reduce_scatter_tensor_203 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1373, 'avg', 8, '0'); convert_element_type_1373 = None + wait_tensor_445 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_203); reduce_scatter_tensor_203 = None + split_136 = torch.ops.aten.split.Tensor(add_172, 1024, 1); add_172 = None + getitem_1277 = split_136[0] + getitem_1278 = split_136[1] + getitem_1279 = split_136[2] + getitem_1280 = split_136[3] + getitem_1281 = split_136[4] + getitem_1282 = split_136[5] + getitem_1283 = split_136[6] + getitem_1284 = split_136[7]; split_136 = None + cat_128 = torch.ops.aten.cat.default([getitem_1277, getitem_1278, getitem_1279, getitem_1280, getitem_1281, getitem_1282, getitem_1283, getitem_1284]); getitem_1277 = getitem_1278 = getitem_1279 = getitem_1280 = getitem_1281 = getitem_1282 = getitem_1283 = getitem_1284 = None + reduce_scatter_tensor_204 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_128, 'sum', 8, '1'); cat_128 = None + wait_tensor_446 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_204); reduce_scatter_tensor_204 = None + convert_element_type_1374 = torch.ops.prims.convert_element_type.default(wait_tensor_446, torch.float32); wait_tensor_446 = None + convert_element_type_1376 = torch.ops.prims.convert_element_type.default(wait_tensor_9, torch.float32); wait_tensor_9 = None + mul_442 = torch.ops.aten.mul.Tensor(convert_element_type_1374, convert_element_type_1376); convert_element_type_1376 = None + mul_444 = torch.ops.aten.mul.Tensor(mul_4, mul_442) + sum_93 = torch.ops.aten.sum.dim_IntList(mul_444, [2], True); mul_444 = None + div_31 = torch.ops.aten.div.Tensor(mul_4, 4096) + mul_445 = torch.ops.aten.mul.Tensor(div_31, sum_93); div_31 = sum_93 = None + sub_48 = torch.ops.aten.sub.Tensor(mul_442, mul_445); mul_442 = mul_445 = None + mul_446 = torch.ops.aten.mul.Tensor(sub_48, rsqrt_1); sub_48 = rsqrt_1 = None + mul_447 = torch.ops.aten.mul.Tensor(convert_element_type_1374, mul_4); convert_element_type_1374 = mul_4 = None + sum_94 = torch.ops.aten.sum.dim_IntList(mul_447, [0, 1]); mul_447 = None + convert_element_type_1377 = torch.ops.prims.convert_element_type.default(mul_446, torch.bfloat16); mul_446 = None + convert_element_type_1378 = torch.ops.prims.convert_element_type.default(sum_94, torch.bfloat16); sum_94 = None + all_reduce_31 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1378, 'sum', '1'); convert_element_type_1378 = None + wait_tensor_447 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_31); all_reduce_31 = None + convert_element_type_1379 = torch.ops.prims.convert_element_type.default(wait_tensor_447, torch.float32); wait_tensor_447 = None + reduce_scatter_tensor_205 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1379, 'avg', 8, '0'); convert_element_type_1379 = None + wait_tensor_448 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_205); reduce_scatter_tensor_205 = None + add_173 = torch.ops.aten.add.Tensor(add_169, convert_element_type_1377); add_169 = convert_element_type_1377 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_173, 8, '1') + wait_tensor_449 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + split_137 = torch.ops.aten.split.Tensor(wait_tensor_449, 2); wait_tensor_449 = None + getitem_1285 = split_137[0] + getitem_1286 = split_137[1] + getitem_1287 = split_137[2] + getitem_1288 = split_137[3] + getitem_1289 = split_137[4] + getitem_1290 = split_137[5] + getitem_1291 = split_137[6] + getitem_1292 = split_137[7]; split_137 = None + cat_129 = torch.ops.aten.cat.default([getitem_1285, getitem_1286, getitem_1287, getitem_1288, getitem_1289, getitem_1290, getitem_1291, getitem_1292], 1); getitem_1285 = getitem_1286 = getitem_1287 = getitem_1288 = getitem_1289 = getitem_1290 = getitem_1291 = getitem_1292 = None + view_1537 = torch.ops.aten.view.default(cat_129, [16384, 4096]); cat_129 = None + permute_673 = torch.ops.aten.permute.default(view_1537, [1, 0]) + permute_6 = torch.ops.aten.permute.default(getitem_80, [0, 2, 1, 3]) + view_42 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + view_48 = torch.ops.aten.view.default(view_42, [16384, 512]); view_42 = None + mm_331 = torch.ops.aten.mm.default(permute_673, view_48); permute_673 = view_48 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16); primals_8 = None + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 8, '0'); convert_element_type_17 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + permute_675 = torch.ops.aten.permute.default(permute_7, [1, 0]); permute_7 = None + mm_332 = torch.ops.aten.mm.default(view_1537, permute_675); view_1537 = permute_675 = None + view_1538 = torch.ops.aten.view.default(mm_332, [2, 8192, 512]); mm_332 = None + convert_element_type_1384 = torch.ops.prims.convert_element_type.default(mm_331, torch.float32); mm_331 = None + reduce_scatter_tensor_206 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1384, 'avg', 8, '0'); convert_element_type_1384 = None + wait_tensor_450 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_206); reduce_scatter_tensor_206 = None + view_1539 = torch.ops.aten.view.default(view_1538, [2, 8192, 4, 128]); view_1538 = None + permute_677 = torch.ops.aten.permute.default(view_1539, [0, 2, 1, 3]); view_1539 = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 8, '0'); convert_element_type_1 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32); wait_tensor_1 = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_2) + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_3, 8, '1'); convert_element_type_3 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + split_9 = torch.ops.aten.split.Tensor(wait_tensor_3, 2); wait_tensor_3 = None + getitem_72 = split_9[0] + getitem_73 = split_9[1] + getitem_74 = split_9[2] + getitem_75 = split_9[3] + getitem_76 = split_9[4] + getitem_77 = split_9[5] + getitem_78 = split_9[6] + getitem_79 = split_9[7]; split_9 = None + cat_1 = torch.ops.aten.cat.default([getitem_72, getitem_73, getitem_74, getitem_75, getitem_76, getitem_77, getitem_78, getitem_79], 1); getitem_72 = getitem_73 = getitem_74 = getitem_75 = getitem_76 = getitem_77 = getitem_78 = getitem_79 = None + view_15 = torch.ops.aten.view.default(cat_1, [16384, 4096]); cat_1 = None + view_16 = torch.ops.aten.view.default(mm, [2, 8192, 512]); mm = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16); primals_6 = None + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 8, '0'); convert_element_type_7 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + mm_1 = torch.ops.aten.mm.default(view_15, permute_1) + view_23 = torch.ops.aten.view.default(mm_1, [2, 8192, 128]); mm_1 = None + view_30 = torch.ops.aten.view.default(mm_2, [2, 8192, 128]); mm_2 = None + view_32 = torch.ops.aten.view.default(view_16, [2, 8192, -1, 128]); view_16 = None + view_33 = torch.ops.aten.view.default(view_23, [2, 8192, -1, 128]); view_23 = None + view_34 = torch.ops.aten.view.default(view_30, [2, 8192, -1, 128]); view_30 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_32, torch.float32); view_32 = None + view_35 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 4, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_35); view_35 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_33, torch.float32); view_33 = None + view_36 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 1, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_36); view_36 = None + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_37); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_38 = torch.ops.aten.view.default(view_as_real, [2, 8192, 4, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_37); view_as_complex_1 = view_37 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_39 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 1, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_38, torch.bfloat16); view_38 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_39, torch.bfloat16); view_39 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 1, 4, 128]); unsqueeze = None + view_40 = torch.ops.aten.view.default(expand, [2, 8192, 4, 128]); expand = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_34, 3); view_34 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 1, 4, 128]); unsqueeze_1 = None + view_41 = torch.ops.aten.view.default(expand_1, [2, 8192, 4, 128]); expand_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_40, [0, 2, 1, 3]); view_40 = None + permute_5 = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None + _scaled_dot_product_cudnn_attention_backward_15 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_677, permute_3, permute_4, permute_5, getitem_80, getitem_81, getitem_86, getitem_87, None, None, None, 8192, 8192, 0.0, True); permute_677 = permute_3 = permute_4 = permute_5 = getitem_80 = getitem_81 = getitem_86 = getitem_87 = None + getitem_1293 = _scaled_dot_product_cudnn_attention_backward_15[0] + getitem_1294 = _scaled_dot_product_cudnn_attention_backward_15[1] + getitem_1295 = _scaled_dot_product_cudnn_attention_backward_15[2]; _scaled_dot_product_cudnn_attention_backward_15 = None + permute_678 = torch.ops.aten.permute.default(getitem_1295, [0, 2, 1, 3]); getitem_1295 = None + permute_679 = torch.ops.aten.permute.default(getitem_1294, [0, 2, 1, 3]); getitem_1294 = None + permute_680 = torch.ops.aten.permute.default(getitem_1293, [0, 2, 1, 3]); getitem_1293 = None + view_1540 = torch.ops.aten.view.default(permute_678, [2, 8192, 1, 4, 128]); permute_678 = None + sum_95 = torch.ops.aten.sum.dim_IntList(view_1540, [3], True); view_1540 = None + squeeze_30 = torch.ops.aten.squeeze.dim(sum_95, 3); sum_95 = None + view_1541 = torch.ops.aten.view.default(permute_679, [2, 8192, 1, 4, 128]); permute_679 = None + sum_96 = torch.ops.aten.sum.dim_IntList(view_1541, [3], True); view_1541 = None + squeeze_31 = torch.ops.aten.squeeze.dim(sum_96, 3); sum_96 = None + convert_element_type_1385 = torch.ops.prims.convert_element_type.default(squeeze_31, torch.float32); squeeze_31 = None + convert_element_type_1386 = torch.ops.prims.convert_element_type.default(permute_680, torch.float32); permute_680 = None + view_1542 = torch.ops.aten.view.default(convert_element_type_1385, [2, 8192, 1, 64, 2]); convert_element_type_1385 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1542); view_1542 = None + mul_448 = torch.ops.aten.mul.Tensor(view_as_complex_62, _conj); view_as_complex_62 = None + view_1543 = torch.ops.aten.view.default(convert_element_type_1386, [2, 8192, 4, 64, 2]); convert_element_type_1386 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1543); view_1543 = None + mul_449 = torch.ops.aten.mul.Tensor(view_as_complex_63, _conj); view_as_complex_63 = _conj = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_448); mul_448 = None + view_1544 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 1, 128]); view_as_real_62 = None + convert_element_type_1387 = torch.ops.prims.convert_element_type.default(view_1544, torch.bfloat16); view_1544 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_449); mul_449 = None + view_1545 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 4, 128]); view_as_real_63 = None + convert_element_type_1388 = torch.ops.prims.convert_element_type.default(view_1545, torch.bfloat16); view_1545 = None + view_1546 = torch.ops.aten.view.default(squeeze_30, [2, 8192, 128]); squeeze_30 = None + view_1547 = torch.ops.aten.view.default(convert_element_type_1387, [2, 8192, 128]); convert_element_type_1387 = None + view_1548 = torch.ops.aten.view.default(convert_element_type_1388, [2, 8192, 512]); convert_element_type_1388 = None + view_1549 = torch.ops.aten.view.default(view_1546, [16384, 128]); view_1546 = None + permute_681 = torch.ops.aten.permute.default(view_1549, [1, 0]) + mm_333 = torch.ops.aten.mm.default(permute_681, view_15); permute_681 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16); primals_7 = None + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 8, '0'); convert_element_type_10 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + permute_683 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + mm_334 = torch.ops.aten.mm.default(view_1549, permute_683); view_1549 = permute_683 = None + view_1550 = torch.ops.aten.view.default(mm_334, [2, 8192, 4096]); mm_334 = None + convert_element_type_1393 = torch.ops.prims.convert_element_type.default(mm_333, torch.float32); mm_333 = None + reduce_scatter_tensor_207 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1393, 'avg', 8, '0'); convert_element_type_1393 = None + wait_tensor_451 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_207); reduce_scatter_tensor_207 = None + view_1551 = torch.ops.aten.view.default(view_1547, [16384, 128]); view_1547 = None + permute_685 = torch.ops.aten.permute.default(view_1551, [1, 0]) + mm_335 = torch.ops.aten.mm.default(permute_685, view_15); permute_685 = None + permute_687 = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None + mm_336 = torch.ops.aten.mm.default(view_1551, permute_687); view_1551 = permute_687 = None + view_1552 = torch.ops.aten.view.default(mm_336, [2, 8192, 4096]); mm_336 = None + add_174 = torch.ops.aten.add.Tensor(view_1550, view_1552); view_1550 = view_1552 = None + convert_element_type_1398 = torch.ops.prims.convert_element_type.default(mm_335, torch.float32); mm_335 = None + reduce_scatter_tensor_208 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1398, 'avg', 8, '0'); convert_element_type_1398 = None + wait_tensor_452 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_208); reduce_scatter_tensor_208 = None + view_1553 = torch.ops.aten.view.default(view_1548, [16384, 512]); view_1548 = None + permute_689 = torch.ops.aten.permute.default(view_1553, [1, 0]) + mm_337 = torch.ops.aten.mm.default(permute_689, view_15); permute_689 = view_15 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16); primals_5 = None + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 8, '0'); convert_element_type_4 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + permute_691 = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + mm_338 = torch.ops.aten.mm.default(view_1553, permute_691); view_1553 = permute_691 = None + view_1554 = torch.ops.aten.view.default(mm_338, [2, 8192, 4096]); mm_338 = None + add_175 = torch.ops.aten.add.Tensor(add_174, view_1554); add_174 = view_1554 = None + convert_element_type_1403 = torch.ops.prims.convert_element_type.default(mm_337, torch.float32); mm_337 = None + reduce_scatter_tensor_209 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1403, 'avg', 8, '0'); convert_element_type_1403 = None + wait_tensor_453 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_209); reduce_scatter_tensor_209 = None + split_138 = torch.ops.aten.split.Tensor(add_175, 1024, 1); add_175 = None + getitem_1296 = split_138[0] + getitem_1297 = split_138[1] + getitem_1298 = split_138[2] + getitem_1299 = split_138[3] + getitem_1300 = split_138[4] + getitem_1301 = split_138[5] + getitem_1302 = split_138[6] + getitem_1303 = split_138[7]; split_138 = None + cat_130 = torch.ops.aten.cat.default([getitem_1296, getitem_1297, getitem_1298, getitem_1299, getitem_1300, getitem_1301, getitem_1302, getitem_1303]); getitem_1296 = getitem_1297 = getitem_1298 = getitem_1299 = getitem_1300 = getitem_1301 = getitem_1302 = getitem_1303 = None + reduce_scatter_tensor_210 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_130, 'sum', 8, '1'); cat_130 = None + wait_tensor_454 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_210); reduce_scatter_tensor_210 = None + convert_element_type_1404 = torch.ops.prims.convert_element_type.default(wait_tensor_454, torch.float32); wait_tensor_454 = None + convert_element_type_1406 = torch.ops.prims.convert_element_type.default(wait_tensor_2, torch.float32); wait_tensor_2 = None + mul_450 = torch.ops.aten.mul.Tensor(convert_element_type_1404, convert_element_type_1406); convert_element_type_1406 = None + mul_452 = torch.ops.aten.mul.Tensor(mul, mul_450) + sum_97 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True); mul_452 = None + div_32 = torch.ops.aten.div.Tensor(mul, 4096) + mul_453 = torch.ops.aten.mul.Tensor(div_32, sum_97); div_32 = sum_97 = None + sub_49 = torch.ops.aten.sub.Tensor(mul_450, mul_453); mul_450 = mul_453 = None + mul_454 = torch.ops.aten.mul.Tensor(sub_49, rsqrt); sub_49 = rsqrt = None + mul_455 = torch.ops.aten.mul.Tensor(convert_element_type_1404, mul); convert_element_type_1404 = mul = None + sum_98 = torch.ops.aten.sum.dim_IntList(mul_455, [0, 1]); mul_455 = None + convert_element_type_1407 = torch.ops.prims.convert_element_type.default(mul_454, torch.bfloat16); mul_454 = None + convert_element_type_1408 = torch.ops.prims.convert_element_type.default(sum_98, torch.bfloat16); sum_98 = None + all_reduce_32 = torch.ops._c10d_functional.all_reduce.default(convert_element_type_1408, 'sum', '1'); convert_element_type_1408 = None + wait_tensor_455 = torch.ops._c10d_functional.wait_tensor.default(all_reduce_32); all_reduce_32 = None + convert_element_type_1409 = torch.ops.prims.convert_element_type.default(wait_tensor_455, torch.float32); wait_tensor_455 = None + reduce_scatter_tensor_211 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1409, 'avg', 8, '0'); convert_element_type_1409 = None + wait_tensor_456 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_211); reduce_scatter_tensor_211 = None + add_176 = torch.ops.aten.add.Tensor(add_173, convert_element_type_1407); add_173 = convert_element_type_1407 = None + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(add_176, 8, '1'); add_176 = None + wait_tensor_457 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + split_139 = torch.ops.aten.split.Tensor(wait_tensor_457, 2); wait_tensor_457 = None + getitem_1304 = split_139[0] + getitem_1305 = split_139[1] + getitem_1306 = split_139[2] + getitem_1307 = split_139[3] + getitem_1308 = split_139[4] + getitem_1309 = split_139[5] + getitem_1310 = split_139[6] + getitem_1311 = split_139[7]; split_139 = None + cat_131 = torch.ops.aten.cat.default([getitem_1304, getitem_1305, getitem_1306, getitem_1307, getitem_1308, getitem_1309, getitem_1310, getitem_1311], 1); getitem_1304 = getitem_1305 = getitem_1306 = getitem_1307 = getitem_1308 = getitem_1309 = getitem_1310 = getitem_1311 = None + convert_element_type_1410 = torch.ops.prims.convert_element_type.default(cat_131, torch.float32); cat_131 = None + eq = torch.ops.aten.eq.Scalar(primals_1, -1) + unsqueeze_32 = torch.ops.aten.unsqueeze.default(eq, -1); eq = None + full_default_2 = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + where = torch.ops.aten.where.self(unsqueeze_32, full_default_2, convert_element_type_1410); unsqueeze_32 = full_default_2 = convert_element_type_1410 = None + full_default_3 = torch.ops.aten.full.default([128256, 4096], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) + index_put_2 = torch.ops.aten.index_put.default(full_default_3, [primals_1], where, True); full_default_3 = primals_1 = where = None + convert_element_type_1411 = torch.ops.prims.convert_element_type.default(index_put_2, torch.bfloat16); index_put_2 = None + split_140 = torch.ops.aten.split.Tensor(convert_element_type_1411, 16032); convert_element_type_1411 = None + getitem_1312 = split_140[0]; split_140 = None + convert_element_type_1412 = torch.ops.prims.convert_element_type.default(getitem_1312, torch.float32); getitem_1312 = None + reduce_scatter_tensor_212 = torch.ops._c10d_functional.reduce_scatter_tensor.default(convert_element_type_1412, 'avg', 8, '0'); convert_element_type_1412 = None + wait_tensor_458 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_212); reduce_scatter_tensor_212 = None + return (None, wait_tensor_458, None, wait_tensor_456, wait_tensor_453, wait_tensor_452, wait_tensor_451, wait_tensor_450, wait_tensor_448, wait_tensor_445, wait_tensor_444, wait_tensor_443, wait_tensor_441, wait_tensor_438, wait_tensor_437, wait_tensor_436, wait_tensor_435, wait_tensor_433, wait_tensor_430, wait_tensor_429, wait_tensor_428, wait_tensor_426, wait_tensor_423, wait_tensor_422, wait_tensor_421, wait_tensor_420, wait_tensor_418, wait_tensor_415, wait_tensor_414, wait_tensor_413, wait_tensor_411, wait_tensor_408, wait_tensor_407, wait_tensor_406, wait_tensor_405, wait_tensor_403, wait_tensor_400, wait_tensor_399, wait_tensor_398, wait_tensor_396, wait_tensor_393, wait_tensor_392, wait_tensor_391, wait_tensor_390, wait_tensor_388, wait_tensor_385, wait_tensor_384, wait_tensor_383, wait_tensor_381, wait_tensor_378, wait_tensor_377, wait_tensor_376, wait_tensor_375, wait_tensor_373, wait_tensor_370, wait_tensor_369, wait_tensor_368, wait_tensor_366, wait_tensor_363, wait_tensor_362, wait_tensor_361, wait_tensor_360, wait_tensor_358, wait_tensor_355, wait_tensor_354, wait_tensor_353, wait_tensor_351, wait_tensor_348, wait_tensor_347, wait_tensor_346, wait_tensor_345, wait_tensor_343, wait_tensor_340, wait_tensor_339, wait_tensor_338, wait_tensor_336, wait_tensor_333, wait_tensor_332, wait_tensor_331, wait_tensor_330, wait_tensor_328, wait_tensor_325, wait_tensor_324, wait_tensor_323, wait_tensor_321, wait_tensor_318, wait_tensor_317, wait_tensor_316, wait_tensor_315, wait_tensor_313, wait_tensor_310, wait_tensor_309, wait_tensor_308, wait_tensor_306, wait_tensor_303, wait_tensor_302, wait_tensor_301, wait_tensor_300, wait_tensor_298, wait_tensor_295, wait_tensor_294, wait_tensor_293, wait_tensor_291, wait_tensor_288, wait_tensor_287, wait_tensor_286, wait_tensor_285, wait_tensor_283, wait_tensor_280, wait_tensor_279, wait_tensor_278, wait_tensor_276, wait_tensor_273, wait_tensor_272, wait_tensor_271, wait_tensor_270, wait_tensor_268, wait_tensor_265, wait_tensor_264, wait_tensor_263, wait_tensor_261, wait_tensor_258, wait_tensor_257, wait_tensor_256, wait_tensor_255, wait_tensor_253, wait_tensor_250, wait_tensor_249, wait_tensor_248, wait_tensor_246, wait_tensor_243, wait_tensor_242, wait_tensor_241, wait_tensor_240, wait_tensor_238, wait_tensor_235, wait_tensor_234, wait_tensor_233, wait_tensor_231, wait_tensor_228, wait_tensor_227, wait_tensor_226, wait_tensor_225, wait_tensor_223, wait_tensor_220, wait_tensor_219, wait_tensor_218, wait_tensor_216, wait_tensor_213) + +def load_args(reader): + buf0 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf0, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_1 + buf1 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf1, (2004, 4096), is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf3, (512,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf4, (64, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf5, (16, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf6, (16, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf7, (512, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf8, (512,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf9, (224, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf10, (224, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf11, (512, 1792), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf12, (512,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf13, (64, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf14, (16, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf15, (16, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf16, (512, 512), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf17, (512,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf18, (224, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf19, (224, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf20, (512, 1792), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf21, (512,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf22, (64, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf23, (16, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf24, (16, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf25, (512, 512), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf26, (512,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf27, (224, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf28, (224, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf29, (512, 1792), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf30, (512,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf31, (64, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf32, (16, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf33, (16, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf34, (512, 512), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf35, (512,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf36, (224, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf37, (224, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf38, (512, 1792), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf39, (512,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf40, (64, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf41, (16, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf42, (16, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf43, (512, 512), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf44, (512,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf45, (224, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf46, (224, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf47, (512, 1792), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf48, (512,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf49, (64, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf50, (16, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf51, (16, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf52, (512, 512), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf53, (512,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf54, (224, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf55, (224, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf56, (512, 1792), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf57, (512,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf58, (64, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf59, (16, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf60, (16, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf61, (512, 512), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf62, (512,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf63, (224, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf64, (224, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf65, (512, 1792), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf66, (512,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf67, (64, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf68, (16, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf70, (512, 512), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf71, (512,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf72, (224, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf73, (224, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf74, (512, 1792), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf75, (512,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf76, (64, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf77, (16, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf78, (16, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf79, (512, 512), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf80, (512,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf81, (224, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf82, (224, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf83, (512, 1792), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf84, (512,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf85, (64, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf86, (16, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf87, (16, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf88, (512, 512), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf89, (512,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf90, (224, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf91, (224, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf92, (512, 1792), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf93, (512,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf94, (64, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf95, (16, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf96, (16, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf97, (512, 512), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf98, (512,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf99, (224, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf100, (224, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf101, (512, 1792), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf102, (512,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf103, (64, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf104, (16, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf105, (16, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf106, (512, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf107, (512,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf108, (224, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf109, (224, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf110, (512, 1792), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf111, (512,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf112, (64, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf113, (16, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf114, (16, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf115, (512, 512), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf116, (512,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf117, (224, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf118, (224, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf119, (512, 1792), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf120, (512,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf121, (64, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf122, (16, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf123, (16, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf124, (512, 512), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf125, (512,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf126, (224, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf127, (224, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf128, (512, 1792), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf129, (512,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf130, (64, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf131, (16, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf132, (16, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf133, (512, 512), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf134, (512,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf135, (224, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf136, (224, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf137, (512, 1792), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf138, (512,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf139, (64, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf141, (16, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf142, (512, 512), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf143, (512,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf144, (224, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf145, (224, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf146, (512, 1792), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf147, (512,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf148, (2004, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf149, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # wait_tensor_1 + buf150 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf150, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm + buf151 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf151, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_2 + buf152 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf152, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_80 + buf153 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf153, (2, 4, 8192, 1), is_leaf=True) # getitem_81 + buf154 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf154, (), dtype=torch.int64, is_leaf=True) # getitem_86 + buf155 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf155, (), dtype=torch.int64, is_leaf=True) # getitem_87 + buf156 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf156, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_1 + buf157 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf157, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_4 + buf158 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf158, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_3 + buf159 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf159, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_7 + buf160 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf160, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_9 + buf161 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf161, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_121 + buf162 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf162, (2, 4, 8192, 1), is_leaf=True) # getitem_122 + buf163 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf163, (), dtype=torch.int64, is_leaf=True) # getitem_127 + buf164 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf164, (), dtype=torch.int64, is_leaf=True) # getitem_128 + buf165 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf165, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_3 + buf166 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf166, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_11 + buf167 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf167, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_7 + buf168 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf168, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_14 + buf169 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf169, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_16 + buf170 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf170, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_162 + buf171 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf171, (2, 4, 8192, 1), is_leaf=True) # getitem_163 + buf172 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf172, (), dtype=torch.int64, is_leaf=True) # getitem_168 + buf173 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf173, (), dtype=torch.int64, is_leaf=True) # getitem_169 + buf174 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf174, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_5 + buf175 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf175, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_18 + buf176 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf176, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_11 + buf177 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf177, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_21 + buf178 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf178, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_23 + buf179 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf179, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_203 + buf180 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf180, (2, 4, 8192, 1), is_leaf=True) # getitem_204 + buf181 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf181, (), dtype=torch.int64, is_leaf=True) # getitem_209 + buf182 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf182, (), dtype=torch.int64, is_leaf=True) # getitem_210 + buf183 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf183, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_7 + buf184 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf184, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_25 + buf185 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf185, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_15 + buf186 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf186, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_28 + buf187 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf187, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_30 + buf188 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf188, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_244 + buf189 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf189, (2, 4, 8192, 1), is_leaf=True) # getitem_245 + buf190 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf190, (), dtype=torch.int64, is_leaf=True) # getitem_250 + buf191 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf191, (), dtype=torch.int64, is_leaf=True) # getitem_251 + buf192 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf192, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_9 + buf193 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf193, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_32 + buf194 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf194, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_19 + buf195 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf195, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_35 + buf196 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf196, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_37 + buf197 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf197, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_285 + buf198 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf198, (2, 4, 8192, 1), is_leaf=True) # getitem_286 + buf199 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf199, (), dtype=torch.int64, is_leaf=True) # getitem_291 + buf200 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf200, (), dtype=torch.int64, is_leaf=True) # getitem_292 + buf201 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf201, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_11 + buf202 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf202, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_39 + buf203 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf203, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_23 + buf204 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf204, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_42 + buf205 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf205, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_44 + buf206 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf206, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_326 + buf207 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf207, (2, 4, 8192, 1), is_leaf=True) # getitem_327 + buf208 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf208, (), dtype=torch.int64, is_leaf=True) # getitem_332 + buf209 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf209, (), dtype=torch.int64, is_leaf=True) # getitem_333 + buf210 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf210, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_13 + buf211 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf211, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_46 + buf212 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf212, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_27 + buf213 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf213, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_49 + buf214 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf214, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_51 + buf215 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf215, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_367 + buf216 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf216, (2, 4, 8192, 1), is_leaf=True) # getitem_368 + buf217 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf217, (), dtype=torch.int64, is_leaf=True) # getitem_373 + buf218 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf218, (), dtype=torch.int64, is_leaf=True) # getitem_374 + buf219 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf219, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_15 + buf220 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf220, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_53 + buf221 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf221, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_31 + buf222 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf222, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_56 + buf223 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf223, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_58 + buf224 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf224, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_408 + buf225 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf225, (2, 4, 8192, 1), is_leaf=True) # getitem_409 + buf226 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf226, (), dtype=torch.int64, is_leaf=True) # getitem_414 + buf227 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf227, (), dtype=torch.int64, is_leaf=True) # getitem_415 + buf228 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf228, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_17 + buf229 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf229, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_60 + buf230 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf230, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_35 + buf231 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf231, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_63 + buf232 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf232, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_65 + buf233 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf233, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_449 + buf234 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf234, (2, 4, 8192, 1), is_leaf=True) # getitem_450 + buf235 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf235, (), dtype=torch.int64, is_leaf=True) # getitem_455 + buf236 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf236, (), dtype=torch.int64, is_leaf=True) # getitem_456 + buf237 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf237, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_19 + buf238 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf238, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_67 + buf239 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf239, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_39 + buf240 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf240, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_70 + buf241 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf241, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_72 + buf242 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf242, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_490 + buf243 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf243, (2, 4, 8192, 1), is_leaf=True) # getitem_491 + buf244 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf244, (), dtype=torch.int64, is_leaf=True) # getitem_496 + buf245 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf245, (), dtype=torch.int64, is_leaf=True) # getitem_497 + buf246 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf246, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_21 + buf247 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf247, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_74 + buf248 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf248, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_43 + buf249 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf249, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_77 + buf250 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf250, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_79 + buf251 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf251, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_531 + buf252 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf252, (2, 4, 8192, 1), is_leaf=True) # getitem_532 + buf253 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf253, (), dtype=torch.int64, is_leaf=True) # getitem_537 + buf254 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf254, (), dtype=torch.int64, is_leaf=True) # getitem_538 + buf255 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf255, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_23 + buf256 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf256, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_81 + buf257 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf257, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_47 + buf258 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf258, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_84 + buf259 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf259, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_86 + buf260 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf260, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_572 + buf261 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf261, (2, 4, 8192, 1), is_leaf=True) # getitem_573 + buf262 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf262, (), dtype=torch.int64, is_leaf=True) # getitem_578 + buf263 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf263, (), dtype=torch.int64, is_leaf=True) # getitem_579 + buf264 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf264, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_25 + buf265 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf265, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_88 + buf266 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf266, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_51 + buf267 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf267, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_91 + buf268 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf268, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_93 + buf269 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf269, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_613 + buf270 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf270, (2, 4, 8192, 1), is_leaf=True) # getitem_614 + buf271 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf271, (), dtype=torch.int64, is_leaf=True) # getitem_619 + buf272 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf272, (), dtype=torch.int64, is_leaf=True) # getitem_620 + buf273 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf273, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_27 + buf274 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf274, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_95 + buf275 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf275, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_55 + buf276 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf276, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_98 + buf277 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf277, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_100 + buf278 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf278, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_654 + buf279 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf279, (2, 4, 8192, 1), is_leaf=True) # getitem_655 + buf280 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf280, (), dtype=torch.int64, is_leaf=True) # getitem_660 + buf281 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf281, (), dtype=torch.int64, is_leaf=True) # getitem_661 + buf282 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf282, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_29 + buf283 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf283, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_102 + buf284 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf284, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # add_59 + buf285 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf285, (16384, 512), dtype=torch.bfloat16, is_leaf=True) # mm_105 + buf286 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf286, (16384, 128), dtype=torch.bfloat16, is_leaf=True) # mm_107 + buf287 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf287, (2, 4, 8192, 128), (4194304, 128, 512, 1), dtype=torch.bfloat16, is_leaf=True) # getitem_695 + buf288 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf288, (2, 4, 8192, 1), is_leaf=True) # getitem_696 + buf289 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf289, (), dtype=torch.int64, is_leaf=True) # getitem_701 + buf290 = reader.storage(None, 8, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf290, (), dtype=torch.int64, is_leaf=True) # getitem_702 + buf291 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf291, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_31 + buf292 = reader.storage(None, 58720256, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf292, (16384, 1792), dtype=torch.bfloat16, is_leaf=True) # mm_109 + buf293 = reader.storage(None, 16777216, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf293, (2, 1024, 4096), dtype=torch.bfloat16, is_leaf=True) # reduce_scatter_tensor_32 + buf294 = reader.storage(None, 8192, device=device(type='cuda', index=0)) + reader.tensor(buf294, (2, 1024, 1), is_leaf=True) # rsqrt_32 + buf295 = reader.storage(None, 134217728, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf295, (16384, 4096), dtype=torch.bfloat16, is_leaf=True) # view_1167 + buf296 = reader.storage(None, 525336576, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16) + reader.tensor(buf296, (2, 8192, 16032), dtype=torch.bfloat16, is_leaf=True) # tangents_1 + +load_args._version = 0 + +def get_pg_config(): + return {'0': {'size': 8, 'rank': 0}, '1': {'size': 8, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls8_8.table" diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_1d.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_1d.py new file mode 100644 index 00000000..bef8b924 --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_1d.py @@ -0,0 +1,4153 @@ +# fmt: off +# flake8: noqa +# isort: skip_file +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_1, torch.bfloat16) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 256, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + embedding = torch.ops.aten.embedding.default(wait_tensor, primals_2); wait_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 256, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = rsqrt = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1); mul = wait_tensor_1 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 256, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [16384, 4096]); convert_element_type_3 = None + mm = torch.ops.aten.mm.default(view_3, permute); permute = None + view_4 = torch.ops.aten.view.default(mm, [2, 8192, 4096]) + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 256, '0'); convert_element_type_7 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1); permute_1 = None + view_7 = torch.ops.aten.view.default(mm_1, [2, 8192, 1024]); mm_1 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 256, '0'); convert_element_type_10 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + mm_2 = torch.ops.aten.mm.default(view_3, permute_2); view_3 = permute_2 = None + view_10 = torch.ops.aten.view.default(mm_2, [2, 8192, 1024]) + view_11 = torch.ops.aten.view.default(view_4, [2, 8192, -1, 128]); view_4 = None + view_12 = torch.ops.aten.view.default(view_7, [2, 8192, -1, 128]); view_7 = None + view_13 = torch.ops.aten.view.default(view_10, [2, 8192, -1, 128]); view_10 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None + view_14 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 32, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_14); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_12, torch.float32); view_12 = None + view_15 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 8, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_15); view_15 = None + view_16 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_16); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_17 = torch.ops.aten.view.default(view_as_real, [2, 8192, 32, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_16); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_18 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 8, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_17, torch.bfloat16); view_17 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_18, torch.bfloat16); view_18 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 8, 4, 128]); unsqueeze = None + clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None + view_19 = torch.ops.aten.view.default(clone, [2, 8192, 32, 128]); clone = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_13, 3); view_13 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 8, 4, 128]); unsqueeze_1 = None + clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None + view_20 = torch.ops.aten.view.default(clone_1, [2, 8192, 32, 128]); clone_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]); view_19 = None + permute_5 = torch.ops.aten.permute.default(view_20, [0, 2, 1, 3]); view_20 = None + _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_3, permute_4, permute_5, None, True, 0.0, True); permute_3 = permute_4 = permute_5 = None + getitem = _scaled_dot_product_cudnn_attention[0] + getitem_1 = _scaled_dot_product_cudnn_attention[1] + getitem_6 = _scaled_dot_product_cudnn_attention[6] + getitem_7 = _scaled_dot_product_cudnn_attention[7]; _scaled_dot_product_cudnn_attention = None + permute_6 = torch.ops.aten.permute.default(getitem, [0, 2, 1, 3]) + view_21 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 256, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_23 = torch.ops.aten.view.default(view_21, [16384, 4096]); view_21 = None + mm_3 = torch.ops.aten.mm.default(view_23, permute_7); view_23 = permute_7 = None + view_24 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + add_1 = torch.ops.aten.add.Tensor(embedding, view_24); view_24 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 256, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = rsqrt_1 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_6); mul_4 = wait_tensor_6 = None + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16) + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 256, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + view_27 = torch.ops.aten.view.default(convert_element_type_22, [16384, 4096]); convert_element_type_22 = None + mm_4 = torch.ops.aten.mm.default(view_27, permute_8); permute_8 = None + view_28 = torch.ops.aten.view.default(mm_4, [2, 8192, 14336]) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_28, torch.float32); view_28 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); convert_element_type_26 = sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16) + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 256, '0'); convert_element_type_28 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + mm_5 = torch.ops.aten.mm.default(view_27, permute_9); view_27 = permute_9 = None + view_31 = torch.ops.aten.view.default(mm_5, [2, 8192, 14336]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_31); convert_element_type_27 = view_31 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 256, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + view_33 = torch.ops.aten.view.default(mul_7, [16384, 14336]); mul_7 = None + mm_6 = torch.ops.aten.mm.default(view_33, permute_10); view_33 = permute_10 = None + view_34 = torch.ops.aten.view.default(mm_6, [2, 8192, 4096]); mm_6 = None + add_3 = torch.ops.aten.add.Tensor(add_1, view_34); add_1 = view_34 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16) + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 256, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = rsqrt_2 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_10); mul_8 = wait_tensor_10 = None + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 256, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + view_37 = torch.ops.aten.view.default(convert_element_type_36, [16384, 4096]); convert_element_type_36 = None + mm_7 = torch.ops.aten.mm.default(view_37, permute_11); permute_11 = None + view_38 = torch.ops.aten.view.default(mm_7, [2, 8192, 4096]) + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 256, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_8 = torch.ops.aten.mm.default(view_37, permute_12); permute_12 = None + view_41 = torch.ops.aten.view.default(mm_8, [2, 8192, 1024]); mm_8 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16) + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 256, '0'); convert_element_type_43 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + mm_9 = torch.ops.aten.mm.default(view_37, permute_13); view_37 = permute_13 = None + view_44 = torch.ops.aten.view.default(mm_9, [2, 8192, 1024]) + view_45 = torch.ops.aten.view.default(view_38, [2, 8192, -1, 128]); view_38 = None + view_46 = torch.ops.aten.view.default(view_41, [2, 8192, -1, 128]); view_41 = None + view_47 = torch.ops.aten.view.default(view_44, [2, 8192, -1, 128]); view_44 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_45, torch.float32); view_45 = None + view_48 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 32, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_48); view_48 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_46, torch.float32); view_46 = None + view_49 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 8, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_49); view_49 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_16); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_51 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 32, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_16); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_52 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 8, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_51, torch.bfloat16); view_51 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_52, torch.bfloat16); view_52 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 8, 4, 128]); unsqueeze_2 = None + clone_2 = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format); expand_2 = None + view_53 = torch.ops.aten.view.default(clone_2, [2, 8192, 32, 128]); clone_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_47, 3); view_47 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 8, 4, 128]); unsqueeze_3 = None + clone_3 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None + view_54 = torch.ops.aten.view.default(clone_3, [2, 8192, 32, 128]); clone_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_53, [0, 2, 1, 3]); view_53 = None + permute_16 = torch.ops.aten.permute.default(view_54, [0, 2, 1, 3]); view_54 = None + _scaled_dot_product_cudnn_attention_1 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_14, permute_15, permute_16, None, True, 0.0, True); permute_14 = permute_15 = permute_16 = None + getitem_9 = _scaled_dot_product_cudnn_attention_1[0] + getitem_10 = _scaled_dot_product_cudnn_attention_1[1] + getitem_15 = _scaled_dot_product_cudnn_attention_1[6] + getitem_16 = _scaled_dot_product_cudnn_attention_1[7]; _scaled_dot_product_cudnn_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_9, [0, 2, 1, 3]) + view_55 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 256, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_14, [1, 0]); wait_tensor_14 = None + view_57 = torch.ops.aten.view.default(view_55, [16384, 4096]); view_55 = None + mm_10 = torch.ops.aten.mm.default(view_57, permute_18); view_57 = permute_18 = None + view_58 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + add_5 = torch.ops.aten.add.Tensor(add_3, view_58); view_58 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 256, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = rsqrt_3 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_15); mul_12 = wait_tensor_15 = None + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 256, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + view_61 = torch.ops.aten.view.default(convert_element_type_55, [16384, 4096]); convert_element_type_55 = None + mm_11 = torch.ops.aten.mm.default(view_61, permute_19); permute_19 = None + view_62 = torch.ops.aten.view.default(mm_11, [2, 8192, 14336]) + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_62, torch.float32); view_62 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); convert_element_type_59 = sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 256, '0'); convert_element_type_61 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + mm_12 = torch.ops.aten.mm.default(view_61, permute_20); view_61 = permute_20 = None + view_65 = torch.ops.aten.view.default(mm_12, [2, 8192, 14336]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_65); convert_element_type_60 = view_65 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 256, '0'); convert_element_type_64 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + view_67 = torch.ops.aten.view.default(mul_15, [16384, 14336]); mul_15 = None + mm_13 = torch.ops.aten.mm.default(view_67, permute_21); view_67 = permute_21 = None + view_68 = torch.ops.aten.view.default(mm_13, [2, 8192, 4096]); mm_13 = None + add_7 = torch.ops.aten.add.Tensor(add_5, view_68); add_5 = view_68 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 256, '0'); convert_element_type_67 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = rsqrt_4 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_19); mul_16 = wait_tensor_19 = None + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 256, '0'); convert_element_type_70 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + view_71 = torch.ops.aten.view.default(convert_element_type_69, [16384, 4096]); convert_element_type_69 = None + mm_14 = torch.ops.aten.mm.default(view_71, permute_22); permute_22 = None + view_72 = torch.ops.aten.view.default(mm_14, [2, 8192, 4096]) + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 256, '0'); convert_element_type_73 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_21, [1, 0]); wait_tensor_21 = None + mm_15 = torch.ops.aten.mm.default(view_71, permute_23); permute_23 = None + view_75 = torch.ops.aten.view.default(mm_15, [2, 8192, 1024]); mm_15 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 256, '0'); convert_element_type_76 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_22, [1, 0]); wait_tensor_22 = None + mm_16 = torch.ops.aten.mm.default(view_71, permute_24); view_71 = permute_24 = None + view_78 = torch.ops.aten.view.default(mm_16, [2, 8192, 1024]) + view_79 = torch.ops.aten.view.default(view_72, [2, 8192, -1, 128]); view_72 = None + view_80 = torch.ops.aten.view.default(view_75, [2, 8192, -1, 128]); view_75 = None + view_81 = torch.ops.aten.view.default(view_78, [2, 8192, -1, 128]); view_78 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_79, torch.float32); view_79 = None + view_82 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 32, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_82); view_82 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_80, torch.float32); view_80 = None + view_83 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 8, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_83); view_83 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_16); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_85 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 32, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_16); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_86 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 8, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_85, torch.bfloat16); view_85 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_86, torch.bfloat16); view_86 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 8, 4, 128]); unsqueeze_4 = None + clone_4 = torch.ops.aten.clone.default(expand_4, memory_format = torch.contiguous_format); expand_4 = None + view_87 = torch.ops.aten.view.default(clone_4, [2, 8192, 32, 128]); clone_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_81, 3); view_81 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 8, 4, 128]); unsqueeze_5 = None + clone_5 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format); expand_5 = None + view_88 = torch.ops.aten.view.default(clone_5, [2, 8192, 32, 128]); clone_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_87, [0, 2, 1, 3]); view_87 = None + permute_27 = torch.ops.aten.permute.default(view_88, [0, 2, 1, 3]); view_88 = None + _scaled_dot_product_cudnn_attention_2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_25, permute_26, permute_27, None, True, 0.0, True); permute_25 = permute_26 = permute_27 = None + getitem_18 = _scaled_dot_product_cudnn_attention_2[0] + getitem_19 = _scaled_dot_product_cudnn_attention_2[1] + getitem_24 = _scaled_dot_product_cudnn_attention_2[6] + getitem_25 = _scaled_dot_product_cudnn_attention_2[7]; _scaled_dot_product_cudnn_attention_2 = None + permute_28 = torch.ops.aten.permute.default(getitem_18, [0, 2, 1, 3]) + view_89 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 256, '0'); convert_element_type_83 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_23, [1, 0]); wait_tensor_23 = None + view_91 = torch.ops.aten.view.default(view_89, [16384, 4096]); view_89 = None + mm_17 = torch.ops.aten.mm.default(view_91, permute_29); view_91 = permute_29 = None + view_92 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + add_9 = torch.ops.aten.add.Tensor(add_7, view_92); view_92 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 256, '0'); convert_element_type_86 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = rsqrt_5 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_24); mul_20 = wait_tensor_24 = None + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 256, '0'); convert_element_type_89 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + view_95 = torch.ops.aten.view.default(convert_element_type_88, [16384, 4096]); convert_element_type_88 = None + mm_18 = torch.ops.aten.mm.default(view_95, permute_30); permute_30 = None + view_96 = torch.ops.aten.view.default(mm_18, [2, 8192, 14336]) + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_96, torch.float32); view_96 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); convert_element_type_92 = sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 256, '0'); convert_element_type_94 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + mm_19 = torch.ops.aten.mm.default(view_95, permute_31); view_95 = permute_31 = None + view_99 = torch.ops.aten.view.default(mm_19, [2, 8192, 14336]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_99); convert_element_type_93 = view_99 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 256, '0'); convert_element_type_97 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_27, [1, 0]); wait_tensor_27 = None + view_101 = torch.ops.aten.view.default(mul_23, [16384, 14336]); mul_23 = None + mm_20 = torch.ops.aten.mm.default(view_101, permute_32); view_101 = permute_32 = None + view_102 = torch.ops.aten.view.default(mm_20, [2, 8192, 4096]); mm_20 = None + add_11 = torch.ops.aten.add.Tensor(add_9, view_102); add_9 = view_102 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 256, '0'); convert_element_type_100 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = rsqrt_6 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_28); mul_24 = wait_tensor_28 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 256, '0'); convert_element_type_103 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + view_105 = torch.ops.aten.view.default(convert_element_type_102, [16384, 4096]); convert_element_type_102 = None + mm_21 = torch.ops.aten.mm.default(view_105, permute_33); permute_33 = None + view_106 = torch.ops.aten.view.default(mm_21, [2, 8192, 4096]) + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 256, '0'); convert_element_type_106 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_22 = torch.ops.aten.mm.default(view_105, permute_34); permute_34 = None + view_109 = torch.ops.aten.view.default(mm_22, [2, 8192, 1024]); mm_22 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 256, '0'); convert_element_type_109 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_23 = torch.ops.aten.mm.default(view_105, permute_35); view_105 = permute_35 = None + view_112 = torch.ops.aten.view.default(mm_23, [2, 8192, 1024]) + view_113 = torch.ops.aten.view.default(view_106, [2, 8192, -1, 128]); view_106 = None + view_114 = torch.ops.aten.view.default(view_109, [2, 8192, -1, 128]); view_109 = None + view_115 = torch.ops.aten.view.default(view_112, [2, 8192, -1, 128]); view_112 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_113, torch.float32); view_113 = None + view_116 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 32, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_116); view_116 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_114, torch.float32); view_114 = None + view_117 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 8, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_117); view_117 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_16); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_119 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 32, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_16); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_120 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 8, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_119, torch.bfloat16); view_119 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_120, torch.bfloat16); view_120 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 8, 4, 128]); unsqueeze_6 = None + clone_6 = torch.ops.aten.clone.default(expand_6, memory_format = torch.contiguous_format); expand_6 = None + view_121 = torch.ops.aten.view.default(clone_6, [2, 8192, 32, 128]); clone_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_115, 3); view_115 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 8, 4, 128]); unsqueeze_7 = None + clone_7 = torch.ops.aten.clone.default(expand_7, memory_format = torch.contiguous_format); expand_7 = None + view_122 = torch.ops.aten.view.default(clone_7, [2, 8192, 32, 128]); clone_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_121, [0, 2, 1, 3]); view_121 = None + permute_38 = torch.ops.aten.permute.default(view_122, [0, 2, 1, 3]); view_122 = None + _scaled_dot_product_cudnn_attention_3 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_36, permute_37, permute_38, None, True, 0.0, True); permute_36 = permute_37 = permute_38 = None + getitem_27 = _scaled_dot_product_cudnn_attention_3[0] + getitem_28 = _scaled_dot_product_cudnn_attention_3[1] + getitem_33 = _scaled_dot_product_cudnn_attention_3[6] + getitem_34 = _scaled_dot_product_cudnn_attention_3[7]; _scaled_dot_product_cudnn_attention_3 = None + permute_39 = torch.ops.aten.permute.default(getitem_27, [0, 2, 1, 3]) + view_123 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 256, '0'); convert_element_type_116 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + view_125 = torch.ops.aten.view.default(view_123, [16384, 4096]); view_123 = None + mm_24 = torch.ops.aten.mm.default(view_125, permute_40); view_125 = permute_40 = None + view_126 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + add_13 = torch.ops.aten.add.Tensor(add_11, view_126); view_126 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 256, '0'); convert_element_type_119 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = rsqrt_7 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_33); mul_28 = wait_tensor_33 = None + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 256, '0'); convert_element_type_122 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + view_129 = torch.ops.aten.view.default(convert_element_type_121, [16384, 4096]); convert_element_type_121 = None + mm_25 = torch.ops.aten.mm.default(view_129, permute_41); permute_41 = None + view_130 = torch.ops.aten.view.default(mm_25, [2, 8192, 14336]) + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); convert_element_type_125 = sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 256, '0'); convert_element_type_127 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_35, [1, 0]); wait_tensor_35 = None + mm_26 = torch.ops.aten.mm.default(view_129, permute_42); view_129 = permute_42 = None + view_133 = torch.ops.aten.view.default(mm_26, [2, 8192, 14336]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_133); convert_element_type_126 = view_133 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 256, '0'); convert_element_type_130 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + view_135 = torch.ops.aten.view.default(mul_31, [16384, 14336]); mul_31 = None + mm_27 = torch.ops.aten.mm.default(view_135, permute_43); view_135 = permute_43 = None + view_136 = torch.ops.aten.view.default(mm_27, [2, 8192, 4096]); mm_27 = None + add_15 = torch.ops.aten.add.Tensor(add_13, view_136); add_13 = view_136 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 256, '0'); convert_element_type_133 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = rsqrt_8 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_37); mul_32 = wait_tensor_37 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 256, '0'); convert_element_type_136 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + view_139 = torch.ops.aten.view.default(convert_element_type_135, [16384, 4096]); convert_element_type_135 = None + mm_28 = torch.ops.aten.mm.default(view_139, permute_44); permute_44 = None + view_140 = torch.ops.aten.view.default(mm_28, [2, 8192, 4096]) + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 256, '0'); convert_element_type_139 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + mm_29 = torch.ops.aten.mm.default(view_139, permute_45); permute_45 = None + view_143 = torch.ops.aten.view.default(mm_29, [2, 8192, 1024]); mm_29 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 256, '0'); convert_element_type_142 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_40, [1, 0]); wait_tensor_40 = None + mm_30 = torch.ops.aten.mm.default(view_139, permute_46); view_139 = permute_46 = None + view_146 = torch.ops.aten.view.default(mm_30, [2, 8192, 1024]) + view_147 = torch.ops.aten.view.default(view_140, [2, 8192, -1, 128]); view_140 = None + view_148 = torch.ops.aten.view.default(view_143, [2, 8192, -1, 128]); view_143 = None + view_149 = torch.ops.aten.view.default(view_146, [2, 8192, -1, 128]); view_146 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_147, torch.float32); view_147 = None + view_150 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 32, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_150); view_150 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None + view_151 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 8, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_151); view_151 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_16); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_153 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 32, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_16); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_154 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 8, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_153, torch.bfloat16); view_153 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 8, 4, 128]); unsqueeze_8 = None + clone_8 = torch.ops.aten.clone.default(expand_8, memory_format = torch.contiguous_format); expand_8 = None + view_155 = torch.ops.aten.view.default(clone_8, [2, 8192, 32, 128]); clone_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_149, 3); view_149 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 8, 4, 128]); unsqueeze_9 = None + clone_9 = torch.ops.aten.clone.default(expand_9, memory_format = torch.contiguous_format); expand_9 = None + view_156 = torch.ops.aten.view.default(clone_9, [2, 8192, 32, 128]); clone_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_155, [0, 2, 1, 3]); view_155 = None + permute_49 = torch.ops.aten.permute.default(view_156, [0, 2, 1, 3]); view_156 = None + _scaled_dot_product_cudnn_attention_4 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_47, permute_48, permute_49, None, True, 0.0, True); permute_47 = permute_48 = permute_49 = None + getitem_36 = _scaled_dot_product_cudnn_attention_4[0] + getitem_37 = _scaled_dot_product_cudnn_attention_4[1] + getitem_42 = _scaled_dot_product_cudnn_attention_4[6] + getitem_43 = _scaled_dot_product_cudnn_attention_4[7]; _scaled_dot_product_cudnn_attention_4 = None + permute_50 = torch.ops.aten.permute.default(getitem_36, [0, 2, 1, 3]) + view_157 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 256, '0'); convert_element_type_149 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_41, [1, 0]); wait_tensor_41 = None + view_159 = torch.ops.aten.view.default(view_157, [16384, 4096]); view_157 = None + mm_31 = torch.ops.aten.mm.default(view_159, permute_51); view_159 = permute_51 = None + view_160 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + add_17 = torch.ops.aten.add.Tensor(add_15, view_160); view_160 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 256, '0'); convert_element_type_152 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = rsqrt_9 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_42); mul_36 = wait_tensor_42 = None + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 256, '0'); convert_element_type_155 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + view_163 = torch.ops.aten.view.default(convert_element_type_154, [16384, 4096]); convert_element_type_154 = None + mm_32 = torch.ops.aten.mm.default(view_163, permute_52); permute_52 = None + view_164 = torch.ops.aten.view.default(mm_32, [2, 8192, 14336]) + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_164, torch.float32); view_164 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); convert_element_type_158 = sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 256, '0'); convert_element_type_160 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_33 = torch.ops.aten.mm.default(view_163, permute_53); view_163 = permute_53 = None + view_167 = torch.ops.aten.view.default(mm_33, [2, 8192, 14336]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_167); convert_element_type_159 = view_167 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 256, '0'); convert_element_type_163 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + view_169 = torch.ops.aten.view.default(mul_39, [16384, 14336]); mul_39 = None + mm_34 = torch.ops.aten.mm.default(view_169, permute_54); view_169 = permute_54 = None + view_170 = torch.ops.aten.view.default(mm_34, [2, 8192, 4096]); mm_34 = None + add_19 = torch.ops.aten.add.Tensor(add_17, view_170); add_17 = view_170 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 256, '0'); convert_element_type_166 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = rsqrt_10 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_46); mul_40 = wait_tensor_46 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 256, '0'); convert_element_type_169 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_47, [1, 0]); wait_tensor_47 = None + view_173 = torch.ops.aten.view.default(convert_element_type_168, [16384, 4096]); convert_element_type_168 = None + mm_35 = torch.ops.aten.mm.default(view_173, permute_55); permute_55 = None + view_174 = torch.ops.aten.view.default(mm_35, [2, 8192, 4096]) + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 256, '0'); convert_element_type_172 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_48, [1, 0]); wait_tensor_48 = None + mm_36 = torch.ops.aten.mm.default(view_173, permute_56); permute_56 = None + view_177 = torch.ops.aten.view.default(mm_36, [2, 8192, 1024]); mm_36 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 256, '0'); convert_element_type_175 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_49, [1, 0]); wait_tensor_49 = None + mm_37 = torch.ops.aten.mm.default(view_173, permute_57); view_173 = permute_57 = None + view_180 = torch.ops.aten.view.default(mm_37, [2, 8192, 1024]) + view_181 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + view_182 = torch.ops.aten.view.default(view_177, [2, 8192, -1, 128]); view_177 = None + view_183 = torch.ops.aten.view.default(view_180, [2, 8192, -1, 128]); view_180 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_181, torch.float32); view_181 = None + view_184 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 32, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_184); view_184 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_182, torch.float32); view_182 = None + view_185 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 8, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_185); view_185 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_16); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_187 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 32, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_16); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_188 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 8, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_187, torch.bfloat16); view_187 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_188, torch.bfloat16); view_188 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 8, 4, 128]); unsqueeze_10 = None + clone_10 = torch.ops.aten.clone.default(expand_10, memory_format = torch.contiguous_format); expand_10 = None + view_189 = torch.ops.aten.view.default(clone_10, [2, 8192, 32, 128]); clone_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_183, 3); view_183 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 8, 4, 128]); unsqueeze_11 = None + clone_11 = torch.ops.aten.clone.default(expand_11, memory_format = torch.contiguous_format); expand_11 = None + view_190 = torch.ops.aten.view.default(clone_11, [2, 8192, 32, 128]); clone_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_189, [0, 2, 1, 3]); view_189 = None + permute_60 = torch.ops.aten.permute.default(view_190, [0, 2, 1, 3]); view_190 = None + _scaled_dot_product_cudnn_attention_5 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_58, permute_59, permute_60, None, True, 0.0, True); permute_58 = permute_59 = permute_60 = None + getitem_45 = _scaled_dot_product_cudnn_attention_5[0] + getitem_46 = _scaled_dot_product_cudnn_attention_5[1] + getitem_51 = _scaled_dot_product_cudnn_attention_5[6] + getitem_52 = _scaled_dot_product_cudnn_attention_5[7]; _scaled_dot_product_cudnn_attention_5 = None + permute_61 = torch.ops.aten.permute.default(getitem_45, [0, 2, 1, 3]) + view_191 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 256, '0'); convert_element_type_182 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_193 = torch.ops.aten.view.default(view_191, [16384, 4096]); view_191 = None + mm_38 = torch.ops.aten.mm.default(view_193, permute_62); view_193 = permute_62 = None + view_194 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + add_21 = torch.ops.aten.add.Tensor(add_19, view_194); view_194 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 256, '0'); convert_element_type_185 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = rsqrt_11 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_51); mul_44 = wait_tensor_51 = None + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 256, '0'); convert_element_type_188 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + view_197 = torch.ops.aten.view.default(convert_element_type_187, [16384, 4096]); convert_element_type_187 = None + mm_39 = torch.ops.aten.mm.default(view_197, permute_63); permute_63 = None + view_198 = torch.ops.aten.view.default(mm_39, [2, 8192, 14336]) + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_198, torch.float32); view_198 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); convert_element_type_191 = sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 256, '0'); convert_element_type_193 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_53, [1, 0]); wait_tensor_53 = None + mm_40 = torch.ops.aten.mm.default(view_197, permute_64); view_197 = permute_64 = None + view_201 = torch.ops.aten.view.default(mm_40, [2, 8192, 14336]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_201); convert_element_type_192 = view_201 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 256, '0'); convert_element_type_196 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + view_203 = torch.ops.aten.view.default(mul_47, [16384, 14336]); mul_47 = None + mm_41 = torch.ops.aten.mm.default(view_203, permute_65); view_203 = permute_65 = None + view_204 = torch.ops.aten.view.default(mm_41, [2, 8192, 4096]); mm_41 = None + add_23 = torch.ops.aten.add.Tensor(add_21, view_204); add_21 = view_204 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 256, '0'); convert_element_type_199 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32) + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = rsqrt_12 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_55); mul_48 = wait_tensor_55 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 256, '0'); convert_element_type_202 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + view_207 = torch.ops.aten.view.default(convert_element_type_201, [16384, 4096]); convert_element_type_201 = None + mm_42 = torch.ops.aten.mm.default(view_207, permute_66); permute_66 = None + view_208 = torch.ops.aten.view.default(mm_42, [2, 8192, 4096]) + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 256, '0'); convert_element_type_205 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_43 = torch.ops.aten.mm.default(view_207, permute_67); permute_67 = None + view_211 = torch.ops.aten.view.default(mm_43, [2, 8192, 1024]); mm_43 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 256, '0'); convert_element_type_208 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + mm_44 = torch.ops.aten.mm.default(view_207, permute_68); view_207 = permute_68 = None + view_214 = torch.ops.aten.view.default(mm_44, [2, 8192, 1024]) + view_215 = torch.ops.aten.view.default(view_208, [2, 8192, -1, 128]); view_208 = None + view_216 = torch.ops.aten.view.default(view_211, [2, 8192, -1, 128]); view_211 = None + view_217 = torch.ops.aten.view.default(view_214, [2, 8192, -1, 128]); view_214 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_215, torch.float32); view_215 = None + view_218 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 32, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_218); view_218 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_216, torch.float32); view_216 = None + view_219 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 8, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_219); view_219 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_16); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_221 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 32, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_16); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_222 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 8, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_221, torch.bfloat16); view_221 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_222, torch.bfloat16); view_222 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 8, 4, 128]); unsqueeze_12 = None + clone_12 = torch.ops.aten.clone.default(expand_12, memory_format = torch.contiguous_format); expand_12 = None + view_223 = torch.ops.aten.view.default(clone_12, [2, 8192, 32, 128]); clone_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_217, 3); view_217 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 8, 4, 128]); unsqueeze_13 = None + clone_13 = torch.ops.aten.clone.default(expand_13, memory_format = torch.contiguous_format); expand_13 = None + view_224 = torch.ops.aten.view.default(clone_13, [2, 8192, 32, 128]); clone_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_223, [0, 2, 1, 3]); view_223 = None + permute_71 = torch.ops.aten.permute.default(view_224, [0, 2, 1, 3]); view_224 = None + _scaled_dot_product_cudnn_attention_6 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_69, permute_70, permute_71, None, True, 0.0, True); permute_69 = permute_70 = permute_71 = None + getitem_54 = _scaled_dot_product_cudnn_attention_6[0] + getitem_55 = _scaled_dot_product_cudnn_attention_6[1] + getitem_60 = _scaled_dot_product_cudnn_attention_6[6] + getitem_61 = _scaled_dot_product_cudnn_attention_6[7]; _scaled_dot_product_cudnn_attention_6 = None + permute_72 = torch.ops.aten.permute.default(getitem_54, [0, 2, 1, 3]) + view_225 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16) + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 256, '0'); convert_element_type_215 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_227 = torch.ops.aten.view.default(view_225, [16384, 4096]); view_225 = None + mm_45 = torch.ops.aten.mm.default(view_227, permute_73); view_227 = permute_73 = None + view_228 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + add_25 = torch.ops.aten.add.Tensor(add_23, view_228); view_228 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 256, '0'); convert_element_type_218 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = rsqrt_13 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_60); mul_52 = wait_tensor_60 = None + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 256, '0'); convert_element_type_221 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_61, [1, 0]); wait_tensor_61 = None + view_231 = torch.ops.aten.view.default(convert_element_type_220, [16384, 4096]); convert_element_type_220 = None + mm_46 = torch.ops.aten.mm.default(view_231, permute_74); permute_74 = None + view_232 = torch.ops.aten.view.default(mm_46, [2, 8192, 14336]) + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_232, torch.float32); view_232 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); convert_element_type_224 = sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 256, '0'); convert_element_type_226 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_62, [1, 0]); wait_tensor_62 = None + mm_47 = torch.ops.aten.mm.default(view_231, permute_75); view_231 = permute_75 = None + view_235 = torch.ops.aten.view.default(mm_47, [2, 8192, 14336]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_235); convert_element_type_225 = view_235 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 256, '0'); convert_element_type_229 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + view_237 = torch.ops.aten.view.default(mul_55, [16384, 14336]); mul_55 = None + mm_48 = torch.ops.aten.mm.default(view_237, permute_76); view_237 = permute_76 = None + view_238 = torch.ops.aten.view.default(mm_48, [2, 8192, 4096]); mm_48 = None + add_27 = torch.ops.aten.add.Tensor(add_25, view_238); add_25 = view_238 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 256, '0'); convert_element_type_232 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = rsqrt_14 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_64); mul_56 = wait_tensor_64 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 256, '0'); convert_element_type_235 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + view_241 = torch.ops.aten.view.default(convert_element_type_234, [16384, 4096]); convert_element_type_234 = None + mm_49 = torch.ops.aten.mm.default(view_241, permute_77); permute_77 = None + view_242 = torch.ops.aten.view.default(mm_49, [2, 8192, 4096]) + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 256, '0'); convert_element_type_238 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_66, [1, 0]); wait_tensor_66 = None + mm_50 = torch.ops.aten.mm.default(view_241, permute_78); permute_78 = None + view_245 = torch.ops.aten.view.default(mm_50, [2, 8192, 1024]); mm_50 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 256, '0'); convert_element_type_241 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_67, [1, 0]); wait_tensor_67 = None + mm_51 = torch.ops.aten.mm.default(view_241, permute_79); view_241 = permute_79 = None + view_248 = torch.ops.aten.view.default(mm_51, [2, 8192, 1024]) + view_249 = torch.ops.aten.view.default(view_242, [2, 8192, -1, 128]); view_242 = None + view_250 = torch.ops.aten.view.default(view_245, [2, 8192, -1, 128]); view_245 = None + view_251 = torch.ops.aten.view.default(view_248, [2, 8192, -1, 128]); view_248 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 32, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_250, torch.float32); view_250 = None + view_253 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 8, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_253); view_253 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_16); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_255 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 32, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_16); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_256 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 8, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_256, torch.bfloat16); view_256 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 8, 4, 128]); unsqueeze_14 = None + clone_14 = torch.ops.aten.clone.default(expand_14, memory_format = torch.contiguous_format); expand_14 = None + view_257 = torch.ops.aten.view.default(clone_14, [2, 8192, 32, 128]); clone_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_251, 3); view_251 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 8, 4, 128]); unsqueeze_15 = None + clone_15 = torch.ops.aten.clone.default(expand_15, memory_format = torch.contiguous_format); expand_15 = None + view_258 = torch.ops.aten.view.default(clone_15, [2, 8192, 32, 128]); clone_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + permute_82 = torch.ops.aten.permute.default(view_258, [0, 2, 1, 3]); view_258 = None + _scaled_dot_product_cudnn_attention_7 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_80, permute_81, permute_82, None, True, 0.0, True); permute_80 = permute_81 = permute_82 = None + getitem_63 = _scaled_dot_product_cudnn_attention_7[0] + getitem_64 = _scaled_dot_product_cudnn_attention_7[1] + getitem_69 = _scaled_dot_product_cudnn_attention_7[6] + getitem_70 = _scaled_dot_product_cudnn_attention_7[7]; _scaled_dot_product_cudnn_attention_7 = None + permute_83 = torch.ops.aten.permute.default(getitem_63, [0, 2, 1, 3]) + view_259 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 256, '0'); convert_element_type_248 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_68, [1, 0]); wait_tensor_68 = None + view_261 = torch.ops.aten.view.default(view_259, [16384, 4096]); view_259 = None + mm_52 = torch.ops.aten.mm.default(view_261, permute_84); view_261 = permute_84 = None + view_262 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + add_29 = torch.ops.aten.add.Tensor(add_27, view_262); view_262 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 256, '0'); convert_element_type_251 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = rsqrt_15 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_69); mul_60 = wait_tensor_69 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 256, '0'); convert_element_type_254 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + view_265 = torch.ops.aten.view.default(convert_element_type_253, [16384, 4096]); convert_element_type_253 = None + mm_53 = torch.ops.aten.mm.default(view_265, permute_85); permute_85 = None + view_266 = torch.ops.aten.view.default(mm_53, [2, 8192, 14336]) + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_266, torch.float32); view_266 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); convert_element_type_257 = sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 256, '0'); convert_element_type_259 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_54 = torch.ops.aten.mm.default(view_265, permute_86); view_265 = permute_86 = None + view_269 = torch.ops.aten.view.default(mm_54, [2, 8192, 14336]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_269); convert_element_type_258 = view_269 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 256, '0'); convert_element_type_262 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + view_271 = torch.ops.aten.view.default(mul_63, [16384, 14336]); mul_63 = None + mm_55 = torch.ops.aten.mm.default(view_271, permute_87); view_271 = permute_87 = None + view_272 = torch.ops.aten.view.default(mm_55, [2, 8192, 4096]); mm_55 = None + add_31 = torch.ops.aten.add.Tensor(add_29, view_272); add_29 = view_272 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 256, '0'); convert_element_type_265 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = rsqrt_16 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_73); mul_64 = wait_tensor_73 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 256, '0'); convert_element_type_268 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_74, [1, 0]); wait_tensor_74 = None + view_275 = torch.ops.aten.view.default(convert_element_type_267, [16384, 4096]); convert_element_type_267 = None + mm_56 = torch.ops.aten.mm.default(view_275, permute_88); permute_88 = None + view_276 = torch.ops.aten.view.default(mm_56, [2, 8192, 4096]) + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16) + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 256, '0'); convert_element_type_271 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + mm_57 = torch.ops.aten.mm.default(view_275, permute_89); permute_89 = None + view_279 = torch.ops.aten.view.default(mm_57, [2, 8192, 1024]); mm_57 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 256, '0'); convert_element_type_274 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + mm_58 = torch.ops.aten.mm.default(view_275, permute_90); view_275 = permute_90 = None + view_282 = torch.ops.aten.view.default(mm_58, [2, 8192, 1024]) + view_283 = torch.ops.aten.view.default(view_276, [2, 8192, -1, 128]); view_276 = None + view_284 = torch.ops.aten.view.default(view_279, [2, 8192, -1, 128]); view_279 = None + view_285 = torch.ops.aten.view.default(view_282, [2, 8192, -1, 128]); view_282 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_283, torch.float32); view_283 = None + view_286 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 32, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_286); view_286 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_284, torch.float32); view_284 = None + view_287 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 8, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_287); view_287 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_16); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_289 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 32, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_16); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_290 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 8, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_289, torch.bfloat16); view_289 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_290, torch.bfloat16); view_290 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 8, 4, 128]); unsqueeze_16 = None + clone_16 = torch.ops.aten.clone.default(expand_16, memory_format = torch.contiguous_format); expand_16 = None + view_291 = torch.ops.aten.view.default(clone_16, [2, 8192, 32, 128]); clone_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_285, 3); view_285 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 8, 4, 128]); unsqueeze_17 = None + clone_17 = torch.ops.aten.clone.default(expand_17, memory_format = torch.contiguous_format); expand_17 = None + view_292 = torch.ops.aten.view.default(clone_17, [2, 8192, 32, 128]); clone_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_291, [0, 2, 1, 3]); view_291 = None + permute_93 = torch.ops.aten.permute.default(view_292, [0, 2, 1, 3]); view_292 = None + _scaled_dot_product_cudnn_attention_8 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_91, permute_92, permute_93, None, True, 0.0, True); permute_91 = permute_92 = permute_93 = None + getitem_72 = _scaled_dot_product_cudnn_attention_8[0] + getitem_73 = _scaled_dot_product_cudnn_attention_8[1] + getitem_78 = _scaled_dot_product_cudnn_attention_8[6] + getitem_79 = _scaled_dot_product_cudnn_attention_8[7]; _scaled_dot_product_cudnn_attention_8 = None + permute_94 = torch.ops.aten.permute.default(getitem_72, [0, 2, 1, 3]) + view_293 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 256, '0'); convert_element_type_281 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + view_295 = torch.ops.aten.view.default(view_293, [16384, 4096]); view_293 = None + mm_59 = torch.ops.aten.mm.default(view_295, permute_95); view_295 = permute_95 = None + view_296 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + add_33 = torch.ops.aten.add.Tensor(add_31, view_296); view_296 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 256, '0'); convert_element_type_284 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = rsqrt_17 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_78); mul_68 = wait_tensor_78 = None + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 256, '0'); convert_element_type_287 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + view_299 = torch.ops.aten.view.default(convert_element_type_286, [16384, 4096]); convert_element_type_286 = None + mm_60 = torch.ops.aten.mm.default(view_299, permute_96); permute_96 = None + view_300 = torch.ops.aten.view.default(mm_60, [2, 8192, 14336]) + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); convert_element_type_290 = sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 256, '0'); convert_element_type_292 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_80, [1, 0]); wait_tensor_80 = None + mm_61 = torch.ops.aten.mm.default(view_299, permute_97); view_299 = permute_97 = None + view_303 = torch.ops.aten.view.default(mm_61, [2, 8192, 14336]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_303); convert_element_type_291 = view_303 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 256, '0'); convert_element_type_295 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + view_305 = torch.ops.aten.view.default(mul_71, [16384, 14336]); mul_71 = None + mm_62 = torch.ops.aten.mm.default(view_305, permute_98); view_305 = permute_98 = None + view_306 = torch.ops.aten.view.default(mm_62, [2, 8192, 4096]); mm_62 = None + add_35 = torch.ops.aten.add.Tensor(add_33, view_306); add_33 = view_306 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 256, '0'); convert_element_type_298 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = rsqrt_18 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_82); mul_72 = wait_tensor_82 = None + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 256, '0'); convert_element_type_301 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + view_309 = torch.ops.aten.view.default(convert_element_type_300, [16384, 4096]); convert_element_type_300 = None + mm_63 = torch.ops.aten.mm.default(view_309, permute_99); permute_99 = None + view_310 = torch.ops.aten.view.default(mm_63, [2, 8192, 4096]) + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 256, '0'); convert_element_type_304 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_64 = torch.ops.aten.mm.default(view_309, permute_100); permute_100 = None + view_313 = torch.ops.aten.view.default(mm_64, [2, 8192, 1024]); mm_64 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 256, '0'); convert_element_type_307 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + mm_65 = torch.ops.aten.mm.default(view_309, permute_101); view_309 = permute_101 = None + view_316 = torch.ops.aten.view.default(mm_65, [2, 8192, 1024]) + view_317 = torch.ops.aten.view.default(view_310, [2, 8192, -1, 128]); view_310 = None + view_318 = torch.ops.aten.view.default(view_313, [2, 8192, -1, 128]); view_313 = None + view_319 = torch.ops.aten.view.default(view_316, [2, 8192, -1, 128]); view_316 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_317, torch.float32); view_317 = None + view_320 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 32, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_320); view_320 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_318, torch.float32); view_318 = None + view_321 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 8, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_321); view_321 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_16); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_323 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 32, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_16); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_324 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 8, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_323, torch.bfloat16); view_323 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_324, torch.bfloat16); view_324 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 8, 4, 128]); unsqueeze_18 = None + clone_18 = torch.ops.aten.clone.default(expand_18, memory_format = torch.contiguous_format); expand_18 = None + view_325 = torch.ops.aten.view.default(clone_18, [2, 8192, 32, 128]); clone_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_319, 3); view_319 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 8, 4, 128]); unsqueeze_19 = None + clone_19 = torch.ops.aten.clone.default(expand_19, memory_format = torch.contiguous_format); expand_19 = None + view_326 = torch.ops.aten.view.default(clone_19, [2, 8192, 32, 128]); clone_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_325, [0, 2, 1, 3]); view_325 = None + permute_104 = torch.ops.aten.permute.default(view_326, [0, 2, 1, 3]); view_326 = None + _scaled_dot_product_cudnn_attention_9 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_102, permute_103, permute_104, None, True, 0.0, True); permute_102 = permute_103 = permute_104 = None + getitem_81 = _scaled_dot_product_cudnn_attention_9[0] + getitem_82 = _scaled_dot_product_cudnn_attention_9[1] + getitem_87 = _scaled_dot_product_cudnn_attention_9[6] + getitem_88 = _scaled_dot_product_cudnn_attention_9[7]; _scaled_dot_product_cudnn_attention_9 = None + permute_105 = torch.ops.aten.permute.default(getitem_81, [0, 2, 1, 3]) + view_327 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 256, '0'); convert_element_type_314 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_86, [1, 0]); wait_tensor_86 = None + view_329 = torch.ops.aten.view.default(view_327, [16384, 4096]); view_327 = None + mm_66 = torch.ops.aten.mm.default(view_329, permute_106); view_329 = permute_106 = None + view_330 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + add_37 = torch.ops.aten.add.Tensor(add_35, view_330); view_330 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 256, '0'); convert_element_type_317 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = rsqrt_19 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_87); mul_76 = wait_tensor_87 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 256, '0'); convert_element_type_320 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_88, [1, 0]); wait_tensor_88 = None + view_333 = torch.ops.aten.view.default(convert_element_type_319, [16384, 4096]); convert_element_type_319 = None + mm_67 = torch.ops.aten.mm.default(view_333, permute_107); permute_107 = None + view_334 = torch.ops.aten.view.default(mm_67, [2, 8192, 14336]) + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_334, torch.float32); view_334 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); convert_element_type_323 = sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 256, '0'); convert_element_type_325 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + mm_68 = torch.ops.aten.mm.default(view_333, permute_108); view_333 = permute_108 = None + view_337 = torch.ops.aten.view.default(mm_68, [2, 8192, 14336]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_337); convert_element_type_324 = view_337 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 256, '0'); convert_element_type_328 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + view_339 = torch.ops.aten.view.default(mul_79, [16384, 14336]); mul_79 = None + mm_69 = torch.ops.aten.mm.default(view_339, permute_109); view_339 = permute_109 = None + view_340 = torch.ops.aten.view.default(mm_69, [2, 8192, 4096]); mm_69 = None + add_39 = torch.ops.aten.add.Tensor(add_37, view_340); add_37 = view_340 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16) + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 256, '0'); convert_element_type_331 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = rsqrt_20 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_91); mul_80 = wait_tensor_91 = None + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 256, '0'); convert_element_type_334 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + view_343 = torch.ops.aten.view.default(convert_element_type_333, [16384, 4096]); convert_element_type_333 = None + mm_70 = torch.ops.aten.mm.default(view_343, permute_110); permute_110 = None + view_344 = torch.ops.aten.view.default(mm_70, [2, 8192, 4096]) + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 256, '0'); convert_element_type_337 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_71 = torch.ops.aten.mm.default(view_343, permute_111); permute_111 = None + view_347 = torch.ops.aten.view.default(mm_71, [2, 8192, 1024]); mm_71 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 256, '0'); convert_element_type_340 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + mm_72 = torch.ops.aten.mm.default(view_343, permute_112); view_343 = permute_112 = None + view_350 = torch.ops.aten.view.default(mm_72, [2, 8192, 1024]) + view_351 = torch.ops.aten.view.default(view_344, [2, 8192, -1, 128]); view_344 = None + view_352 = torch.ops.aten.view.default(view_347, [2, 8192, -1, 128]); view_347 = None + view_353 = torch.ops.aten.view.default(view_350, [2, 8192, -1, 128]); view_350 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_351, torch.float32); view_351 = None + view_354 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 32, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_354); view_354 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_352, torch.float32); view_352 = None + view_355 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 8, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_355); view_355 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_16); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_357 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 32, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_16); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_358 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 8, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_357, torch.bfloat16); view_357 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_358, torch.bfloat16); view_358 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 8, 4, 128]); unsqueeze_20 = None + clone_20 = torch.ops.aten.clone.default(expand_20, memory_format = torch.contiguous_format); expand_20 = None + view_359 = torch.ops.aten.view.default(clone_20, [2, 8192, 32, 128]); clone_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_353, 3); view_353 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 8, 4, 128]); unsqueeze_21 = None + clone_21 = torch.ops.aten.clone.default(expand_21, memory_format = torch.contiguous_format); expand_21 = None + view_360 = torch.ops.aten.view.default(clone_21, [2, 8192, 32, 128]); clone_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_359, [0, 2, 1, 3]); view_359 = None + permute_115 = torch.ops.aten.permute.default(view_360, [0, 2, 1, 3]); view_360 = None + _scaled_dot_product_cudnn_attention_10 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_113, permute_114, permute_115, None, True, 0.0, True); permute_113 = permute_114 = permute_115 = None + getitem_90 = _scaled_dot_product_cudnn_attention_10[0] + getitem_91 = _scaled_dot_product_cudnn_attention_10[1] + getitem_96 = _scaled_dot_product_cudnn_attention_10[6] + getitem_97 = _scaled_dot_product_cudnn_attention_10[7]; _scaled_dot_product_cudnn_attention_10 = None + permute_116 = torch.ops.aten.permute.default(getitem_90, [0, 2, 1, 3]) + view_361 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 256, '0'); convert_element_type_347 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_363 = torch.ops.aten.view.default(view_361, [16384, 4096]); view_361 = None + mm_73 = torch.ops.aten.mm.default(view_363, permute_117); view_363 = permute_117 = None + view_364 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + add_41 = torch.ops.aten.add.Tensor(add_39, view_364); view_364 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 256, '0'); convert_element_type_350 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = rsqrt_21 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_96); mul_84 = wait_tensor_96 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 256, '0'); convert_element_type_353 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + view_367 = torch.ops.aten.view.default(convert_element_type_352, [16384, 4096]); convert_element_type_352 = None + mm_74 = torch.ops.aten.mm.default(view_367, permute_118); permute_118 = None + view_368 = torch.ops.aten.view.default(mm_74, [2, 8192, 14336]) + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_368, torch.float32); view_368 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); convert_element_type_356 = sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 256, '0'); convert_element_type_358 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + mm_75 = torch.ops.aten.mm.default(view_367, permute_119); view_367 = permute_119 = None + view_371 = torch.ops.aten.view.default(mm_75, [2, 8192, 14336]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_371); convert_element_type_357 = view_371 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 256, '0'); convert_element_type_361 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + view_373 = torch.ops.aten.view.default(mul_87, [16384, 14336]); mul_87 = None + mm_76 = torch.ops.aten.mm.default(view_373, permute_120); view_373 = permute_120 = None + view_374 = torch.ops.aten.view.default(mm_76, [2, 8192, 4096]); mm_76 = None + add_43 = torch.ops.aten.add.Tensor(add_41, view_374); add_41 = view_374 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 256, '0'); convert_element_type_364 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = rsqrt_22 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_100); mul_88 = wait_tensor_100 = None + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 256, '0'); convert_element_type_367 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_101, [1, 0]); wait_tensor_101 = None + view_377 = torch.ops.aten.view.default(convert_element_type_366, [16384, 4096]); convert_element_type_366 = None + mm_77 = torch.ops.aten.mm.default(view_377, permute_121); permute_121 = None + view_378 = torch.ops.aten.view.default(mm_77, [2, 8192, 4096]) + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 256, '0'); convert_element_type_370 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + mm_78 = torch.ops.aten.mm.default(view_377, permute_122); permute_122 = None + view_381 = torch.ops.aten.view.default(mm_78, [2, 8192, 1024]); mm_78 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 256, '0'); convert_element_type_373 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_79 = torch.ops.aten.mm.default(view_377, permute_123); view_377 = permute_123 = None + view_384 = torch.ops.aten.view.default(mm_79, [2, 8192, 1024]) + view_385 = torch.ops.aten.view.default(view_378, [2, 8192, -1, 128]); view_378 = None + view_386 = torch.ops.aten.view.default(view_381, [2, 8192, -1, 128]); view_381 = None + view_387 = torch.ops.aten.view.default(view_384, [2, 8192, -1, 128]); view_384 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_385, torch.float32); view_385 = None + view_388 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 32, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_388); view_388 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_386, torch.float32); view_386 = None + view_389 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 8, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_389); view_389 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_16); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_391 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 32, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_16); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_392 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 8, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_391, torch.bfloat16); view_391 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_392, torch.bfloat16); view_392 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 8, 4, 128]); unsqueeze_22 = None + clone_22 = torch.ops.aten.clone.default(expand_22, memory_format = torch.contiguous_format); expand_22 = None + view_393 = torch.ops.aten.view.default(clone_22, [2, 8192, 32, 128]); clone_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_387, 3); view_387 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 8, 4, 128]); unsqueeze_23 = None + clone_23 = torch.ops.aten.clone.default(expand_23, memory_format = torch.contiguous_format); expand_23 = None + view_394 = torch.ops.aten.view.default(clone_23, [2, 8192, 32, 128]); clone_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_393, [0, 2, 1, 3]); view_393 = None + permute_126 = torch.ops.aten.permute.default(view_394, [0, 2, 1, 3]); view_394 = None + _scaled_dot_product_cudnn_attention_11 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_124, permute_125, permute_126, None, True, 0.0, True); permute_124 = permute_125 = permute_126 = None + getitem_99 = _scaled_dot_product_cudnn_attention_11[0] + getitem_100 = _scaled_dot_product_cudnn_attention_11[1] + getitem_105 = _scaled_dot_product_cudnn_attention_11[6] + getitem_106 = _scaled_dot_product_cudnn_attention_11[7]; _scaled_dot_product_cudnn_attention_11 = None + permute_127 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_395 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 256, '0'); convert_element_type_380 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_397 = torch.ops.aten.view.default(view_395, [16384, 4096]); view_395 = None + mm_80 = torch.ops.aten.mm.default(view_397, permute_128); view_397 = permute_128 = None + view_398 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + add_45 = torch.ops.aten.add.Tensor(add_43, view_398); view_398 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 256, '0'); convert_element_type_383 = None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = rsqrt_23 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_105); mul_92 = wait_tensor_105 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 256, '0'); convert_element_type_386 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_106, [1, 0]); wait_tensor_106 = None + view_401 = torch.ops.aten.view.default(convert_element_type_385, [16384, 4096]); convert_element_type_385 = None + mm_81 = torch.ops.aten.mm.default(view_401, permute_129); permute_129 = None + view_402 = torch.ops.aten.view.default(mm_81, [2, 8192, 14336]) + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_402, torch.float32); view_402 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); convert_element_type_389 = sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16) + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 256, '0'); convert_element_type_391 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_107, [1, 0]); wait_tensor_107 = None + mm_82 = torch.ops.aten.mm.default(view_401, permute_130); view_401 = permute_130 = None + view_405 = torch.ops.aten.view.default(mm_82, [2, 8192, 14336]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_405); convert_element_type_390 = view_405 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 256, '0'); convert_element_type_394 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + view_407 = torch.ops.aten.view.default(mul_95, [16384, 14336]); mul_95 = None + mm_83 = torch.ops.aten.mm.default(view_407, permute_131); view_407 = permute_131 = None + view_408 = torch.ops.aten.view.default(mm_83, [2, 8192, 4096]); mm_83 = None + add_47 = torch.ops.aten.add.Tensor(add_45, view_408); add_45 = view_408 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16) + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 256, '0'); convert_element_type_397 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = rsqrt_24 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_109); mul_96 = wait_tensor_109 = None + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 256, '0'); convert_element_type_400 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + view_411 = torch.ops.aten.view.default(convert_element_type_399, [16384, 4096]); convert_element_type_399 = None + mm_84 = torch.ops.aten.mm.default(view_411, permute_132); permute_132 = None + view_412 = torch.ops.aten.view.default(mm_84, [2, 8192, 4096]) + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 256, '0'); convert_element_type_403 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + mm_85 = torch.ops.aten.mm.default(view_411, permute_133); permute_133 = None + view_415 = torch.ops.aten.view.default(mm_85, [2, 8192, 1024]); mm_85 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 256, '0'); convert_element_type_406 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_112, [1, 0]); wait_tensor_112 = None + mm_86 = torch.ops.aten.mm.default(view_411, permute_134); view_411 = permute_134 = None + view_418 = torch.ops.aten.view.default(mm_86, [2, 8192, 1024]) + view_419 = torch.ops.aten.view.default(view_412, [2, 8192, -1, 128]); view_412 = None + view_420 = torch.ops.aten.view.default(view_415, [2, 8192, -1, 128]); view_415 = None + view_421 = torch.ops.aten.view.default(view_418, [2, 8192, -1, 128]); view_418 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_419, torch.float32); view_419 = None + view_422 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 32, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_422); view_422 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_420, torch.float32); view_420 = None + view_423 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 8, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_423); view_423 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_16); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_425 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 32, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_16); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_426 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 8, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_425, torch.bfloat16); view_425 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_426, torch.bfloat16); view_426 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 8, 4, 128]); unsqueeze_24 = None + clone_24 = torch.ops.aten.clone.default(expand_24, memory_format = torch.contiguous_format); expand_24 = None + view_427 = torch.ops.aten.view.default(clone_24, [2, 8192, 32, 128]); clone_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_421, 3); view_421 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 8, 4, 128]); unsqueeze_25 = None + clone_25 = torch.ops.aten.clone.default(expand_25, memory_format = torch.contiguous_format); expand_25 = None + view_428 = torch.ops.aten.view.default(clone_25, [2, 8192, 32, 128]); clone_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_427, [0, 2, 1, 3]); view_427 = None + permute_137 = torch.ops.aten.permute.default(view_428, [0, 2, 1, 3]); view_428 = None + _scaled_dot_product_cudnn_attention_12 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_135, permute_136, permute_137, None, True, 0.0, True); permute_135 = permute_136 = permute_137 = None + getitem_108 = _scaled_dot_product_cudnn_attention_12[0] + getitem_109 = _scaled_dot_product_cudnn_attention_12[1] + getitem_114 = _scaled_dot_product_cudnn_attention_12[6] + getitem_115 = _scaled_dot_product_cudnn_attention_12[7]; _scaled_dot_product_cudnn_attention_12 = None + permute_138 = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]) + view_429 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 256, '0'); convert_element_type_413 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + view_431 = torch.ops.aten.view.default(view_429, [16384, 4096]); view_429 = None + mm_87 = torch.ops.aten.mm.default(view_431, permute_139); view_431 = permute_139 = None + view_432 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + add_49 = torch.ops.aten.add.Tensor(add_47, view_432); view_432 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 256, '0'); convert_element_type_416 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = rsqrt_25 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_114); mul_100 = wait_tensor_114 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 256, '0'); convert_element_type_419 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + view_435 = torch.ops.aten.view.default(convert_element_type_418, [16384, 4096]); convert_element_type_418 = None + mm_88 = torch.ops.aten.mm.default(view_435, permute_140); permute_140 = None + view_436 = torch.ops.aten.view.default(mm_88, [2, 8192, 14336]) + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_436, torch.float32); view_436 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); convert_element_type_422 = sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 256, '0'); convert_element_type_424 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_89 = torch.ops.aten.mm.default(view_435, permute_141); view_435 = permute_141 = None + view_439 = torch.ops.aten.view.default(mm_89, [2, 8192, 14336]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_439); convert_element_type_423 = view_439 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 256, '0'); convert_element_type_427 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_441 = torch.ops.aten.view.default(mul_103, [16384, 14336]); mul_103 = None + mm_90 = torch.ops.aten.mm.default(view_441, permute_142); view_441 = permute_142 = None + view_442 = torch.ops.aten.view.default(mm_90, [2, 8192, 4096]); mm_90 = None + add_51 = torch.ops.aten.add.Tensor(add_49, view_442); add_49 = view_442 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 256, '0'); convert_element_type_430 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = rsqrt_26 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_118); mul_104 = wait_tensor_118 = None + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 256, '0'); convert_element_type_433 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_119, [1, 0]); wait_tensor_119 = None + view_445 = torch.ops.aten.view.default(convert_element_type_432, [16384, 4096]); convert_element_type_432 = None + mm_91 = torch.ops.aten.mm.default(view_445, permute_143); permute_143 = None + view_446 = torch.ops.aten.view.default(mm_91, [2, 8192, 4096]) + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 256, '0'); convert_element_type_436 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + mm_92 = torch.ops.aten.mm.default(view_445, permute_144); permute_144 = None + view_449 = torch.ops.aten.view.default(mm_92, [2, 8192, 1024]); mm_92 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 256, '0'); convert_element_type_439 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + mm_93 = torch.ops.aten.mm.default(view_445, permute_145); view_445 = permute_145 = None + view_452 = torch.ops.aten.view.default(mm_93, [2, 8192, 1024]) + view_453 = torch.ops.aten.view.default(view_446, [2, 8192, -1, 128]); view_446 = None + view_454 = torch.ops.aten.view.default(view_449, [2, 8192, -1, 128]); view_449 = None + view_455 = torch.ops.aten.view.default(view_452, [2, 8192, -1, 128]); view_452 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_453, torch.float32); view_453 = None + view_456 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 32, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_456); view_456 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_454, torch.float32); view_454 = None + view_457 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 8, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_457); view_457 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_16); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_459 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 32, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_16); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_460 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 8, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_459, torch.bfloat16); view_459 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_460, torch.bfloat16); view_460 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 8, 4, 128]); unsqueeze_26 = None + clone_26 = torch.ops.aten.clone.default(expand_26, memory_format = torch.contiguous_format); expand_26 = None + view_461 = torch.ops.aten.view.default(clone_26, [2, 8192, 32, 128]); clone_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_455, 3); view_455 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 8, 4, 128]); unsqueeze_27 = None + clone_27 = torch.ops.aten.clone.default(expand_27, memory_format = torch.contiguous_format); expand_27 = None + view_462 = torch.ops.aten.view.default(clone_27, [2, 8192, 32, 128]); clone_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_461, [0, 2, 1, 3]); view_461 = None + permute_148 = torch.ops.aten.permute.default(view_462, [0, 2, 1, 3]); view_462 = None + _scaled_dot_product_cudnn_attention_13 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_146, permute_147, permute_148, None, True, 0.0, True); permute_146 = permute_147 = permute_148 = None + getitem_117 = _scaled_dot_product_cudnn_attention_13[0] + getitem_118 = _scaled_dot_product_cudnn_attention_13[1] + getitem_123 = _scaled_dot_product_cudnn_attention_13[6] + getitem_124 = _scaled_dot_product_cudnn_attention_13[7]; _scaled_dot_product_cudnn_attention_13 = None + permute_149 = torch.ops.aten.permute.default(getitem_117, [0, 2, 1, 3]) + view_463 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 256, '0'); convert_element_type_446 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + view_465 = torch.ops.aten.view.default(view_463, [16384, 4096]); view_463 = None + mm_94 = torch.ops.aten.mm.default(view_465, permute_150); view_465 = permute_150 = None + view_466 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + add_53 = torch.ops.aten.add.Tensor(add_51, view_466); view_466 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16) + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 256, '0'); convert_element_type_449 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = rsqrt_27 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_123); mul_108 = wait_tensor_123 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 256, '0'); convert_element_type_452 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + view_469 = torch.ops.aten.view.default(convert_element_type_451, [16384, 4096]); convert_element_type_451 = None + mm_95 = torch.ops.aten.mm.default(view_469, permute_151); permute_151 = None + view_470 = torch.ops.aten.view.default(mm_95, [2, 8192, 14336]) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_470, torch.float32); view_470 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); convert_element_type_455 = sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16) + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 256, '0'); convert_element_type_457 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_125, [1, 0]); wait_tensor_125 = None + mm_96 = torch.ops.aten.mm.default(view_469, permute_152); view_469 = permute_152 = None + view_473 = torch.ops.aten.view.default(mm_96, [2, 8192, 14336]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_473); convert_element_type_456 = view_473 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 256, '0'); convert_element_type_460 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_126, [1, 0]); wait_tensor_126 = None + view_475 = torch.ops.aten.view.default(mul_111, [16384, 14336]); mul_111 = None + mm_97 = torch.ops.aten.mm.default(view_475, permute_153); view_475 = permute_153 = None + view_476 = torch.ops.aten.view.default(mm_97, [2, 8192, 4096]); mm_97 = None + add_55 = torch.ops.aten.add.Tensor(add_53, view_476); add_53 = view_476 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 256, '0'); convert_element_type_463 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = rsqrt_28 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_127); mul_112 = wait_tensor_127 = None + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 256, '0'); convert_element_type_466 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + view_479 = torch.ops.aten.view.default(convert_element_type_465, [16384, 4096]); convert_element_type_465 = None + mm_98 = torch.ops.aten.mm.default(view_479, permute_154); permute_154 = None + view_480 = torch.ops.aten.view.default(mm_98, [2, 8192, 4096]) + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 256, '0'); convert_element_type_469 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_99 = torch.ops.aten.mm.default(view_479, permute_155); permute_155 = None + view_483 = torch.ops.aten.view.default(mm_99, [2, 8192, 1024]); mm_99 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 256, '0'); convert_element_type_472 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + mm_100 = torch.ops.aten.mm.default(view_479, permute_156); view_479 = permute_156 = None + view_486 = torch.ops.aten.view.default(mm_100, [2, 8192, 1024]) + view_487 = torch.ops.aten.view.default(view_480, [2, 8192, -1, 128]); view_480 = None + view_488 = torch.ops.aten.view.default(view_483, [2, 8192, -1, 128]); view_483 = None + view_489 = torch.ops.aten.view.default(view_486, [2, 8192, -1, 128]); view_486 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_487, torch.float32); view_487 = None + view_490 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 32, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_490); view_490 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_488, torch.float32); view_488 = None + view_491 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 8, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_491); view_491 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_16); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_493 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 32, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_16); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_494 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 8, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_493, torch.bfloat16); view_493 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_494, torch.bfloat16); view_494 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 8, 4, 128]); unsqueeze_28 = None + clone_28 = torch.ops.aten.clone.default(expand_28, memory_format = torch.contiguous_format); expand_28 = None + view_495 = torch.ops.aten.view.default(clone_28, [2, 8192, 32, 128]); clone_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_489, 3); view_489 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 8, 4, 128]); unsqueeze_29 = None + clone_29 = torch.ops.aten.clone.default(expand_29, memory_format = torch.contiguous_format); expand_29 = None + view_496 = torch.ops.aten.view.default(clone_29, [2, 8192, 32, 128]); clone_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_495, [0, 2, 1, 3]); view_495 = None + permute_159 = torch.ops.aten.permute.default(view_496, [0, 2, 1, 3]); view_496 = None + _scaled_dot_product_cudnn_attention_14 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_157, permute_158, permute_159, None, True, 0.0, True); permute_157 = permute_158 = permute_159 = None + getitem_126 = _scaled_dot_product_cudnn_attention_14[0] + getitem_127 = _scaled_dot_product_cudnn_attention_14[1] + getitem_132 = _scaled_dot_product_cudnn_attention_14[6] + getitem_133 = _scaled_dot_product_cudnn_attention_14[7]; _scaled_dot_product_cudnn_attention_14 = None + permute_160 = torch.ops.aten.permute.default(getitem_126, [0, 2, 1, 3]) + view_497 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 256, '0'); convert_element_type_479 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_131, [1, 0]); wait_tensor_131 = None + view_499 = torch.ops.aten.view.default(view_497, [16384, 4096]); view_497 = None + mm_101 = torch.ops.aten.mm.default(view_499, permute_161); view_499 = permute_161 = None + view_500 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + add_57 = torch.ops.aten.add.Tensor(add_55, view_500); view_500 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 256, '0'); convert_element_type_482 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = rsqrt_29 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_132); mul_116 = wait_tensor_132 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 256, '0'); convert_element_type_485 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_133, [1, 0]); wait_tensor_133 = None + view_503 = torch.ops.aten.view.default(convert_element_type_484, [16384, 4096]); convert_element_type_484 = None + mm_102 = torch.ops.aten.mm.default(view_503, permute_162); permute_162 = None + view_504 = torch.ops.aten.view.default(mm_102, [2, 8192, 14336]) + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_504, torch.float32); view_504 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); convert_element_type_488 = sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16) + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 256, '0'); convert_element_type_490 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_103 = torch.ops.aten.mm.default(view_503, permute_163); view_503 = permute_163 = None + view_507 = torch.ops.aten.view.default(mm_103, [2, 8192, 14336]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_507); convert_element_type_489 = view_507 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 256, '0'); convert_element_type_493 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + view_509 = torch.ops.aten.view.default(mul_119, [16384, 14336]); mul_119 = None + mm_104 = torch.ops.aten.mm.default(view_509, permute_164); view_509 = permute_164 = None + view_510 = torch.ops.aten.view.default(mm_104, [2, 8192, 4096]); mm_104 = None + add_59 = torch.ops.aten.add.Tensor(add_57, view_510); add_57 = view_510 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 256, '0'); convert_element_type_496 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32) + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = rsqrt_30 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_136); mul_120 = wait_tensor_136 = None + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 256, '0'); convert_element_type_499 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + view_513 = torch.ops.aten.view.default(convert_element_type_498, [16384, 4096]); convert_element_type_498 = None + mm_105 = torch.ops.aten.mm.default(view_513, permute_165); permute_165 = None + view_514 = torch.ops.aten.view.default(mm_105, [2, 8192, 4096]) + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 256, '0'); convert_element_type_502 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + mm_106 = torch.ops.aten.mm.default(view_513, permute_166); permute_166 = None + view_517 = torch.ops.aten.view.default(mm_106, [2, 8192, 1024]); mm_106 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16) + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 256, '0'); convert_element_type_505 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + mm_107 = torch.ops.aten.mm.default(view_513, permute_167); view_513 = permute_167 = None + view_520 = torch.ops.aten.view.default(mm_107, [2, 8192, 1024]) + view_521 = torch.ops.aten.view.default(view_514, [2, 8192, -1, 128]); view_514 = None + view_522 = torch.ops.aten.view.default(view_517, [2, 8192, -1, 128]); view_517 = None + view_523 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_521, torch.float32); view_521 = None + view_524 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 32, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_524); view_524 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_522, torch.float32); view_522 = None + view_525 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 8, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_525); view_525 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_16); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_527 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 32, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_16); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_528 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 8, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_527, torch.bfloat16); view_527 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_528, torch.bfloat16); view_528 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 8, 4, 128]); unsqueeze_30 = None + clone_30 = torch.ops.aten.clone.default(expand_30, memory_format = torch.contiguous_format); expand_30 = None + view_529 = torch.ops.aten.view.default(clone_30, [2, 8192, 32, 128]); clone_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_523, 3); view_523 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 8, 4, 128]); unsqueeze_31 = None + clone_31 = torch.ops.aten.clone.default(expand_31, memory_format = torch.contiguous_format); expand_31 = None + view_530 = torch.ops.aten.view.default(clone_31, [2, 8192, 32, 128]); clone_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_529, [0, 2, 1, 3]); view_529 = None + permute_170 = torch.ops.aten.permute.default(view_530, [0, 2, 1, 3]); view_530 = None + _scaled_dot_product_cudnn_attention_15 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_168, permute_169, permute_170, None, True, 0.0, True); permute_168 = permute_169 = permute_170 = None + getitem_135 = _scaled_dot_product_cudnn_attention_15[0] + getitem_136 = _scaled_dot_product_cudnn_attention_15[1] + getitem_141 = _scaled_dot_product_cudnn_attention_15[6] + getitem_142 = _scaled_dot_product_cudnn_attention_15[7]; _scaled_dot_product_cudnn_attention_15 = None + permute_171 = torch.ops.aten.permute.default(getitem_135, [0, 2, 1, 3]) + view_531 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 256, '0'); convert_element_type_512 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_140, [1, 0]); wait_tensor_140 = None + view_533 = torch.ops.aten.view.default(view_531, [16384, 4096]); view_531 = None + mm_108 = torch.ops.aten.mm.default(view_533, permute_172); view_533 = permute_172 = None + view_534 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + add_61 = torch.ops.aten.add.Tensor(add_59, view_534); view_534 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 256, '0'); convert_element_type_515 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32) + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = rsqrt_31 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_141); mul_124 = wait_tensor_141 = None + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 256, '0'); convert_element_type_518 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + view_537 = torch.ops.aten.view.default(convert_element_type_517, [16384, 4096]); convert_element_type_517 = None + mm_109 = torch.ops.aten.mm.default(view_537, permute_173); permute_173 = None + view_538 = torch.ops.aten.view.default(mm_109, [2, 8192, 14336]) + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_538, torch.float32); view_538 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); convert_element_type_521 = sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 256, '0'); convert_element_type_523 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + mm_110 = torch.ops.aten.mm.default(view_537, permute_174); view_537 = permute_174 = None + view_541 = torch.ops.aten.view.default(mm_110, [2, 8192, 14336]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_541); convert_element_type_522 = view_541 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 256, '0'); convert_element_type_526 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + view_543 = torch.ops.aten.view.default(mul_127, [16384, 14336]); mul_127 = None + mm_111 = torch.ops.aten.mm.default(view_543, permute_175); view_543 = permute_175 = None + view_544 = torch.ops.aten.view.default(mm_111, [2, 8192, 4096]); mm_111 = None + add_63 = torch.ops.aten.add.Tensor(add_61, view_544); add_61 = view_544 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 256, '0'); convert_element_type_529 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = rsqrt_32 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_145); mul_128 = wait_tensor_145 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 256, '0'); convert_element_type_532 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_146, [1, 0]); wait_tensor_146 = None + view_547 = torch.ops.aten.view.default(convert_element_type_531, [16384, 4096]); convert_element_type_531 = None + mm_112 = torch.ops.aten.mm.default(view_547, permute_176); permute_176 = None + view_548 = torch.ops.aten.view.default(mm_112, [2, 8192, 4096]) + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 256, '0'); convert_element_type_535 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + mm_113 = torch.ops.aten.mm.default(view_547, permute_177); permute_177 = None + view_551 = torch.ops.aten.view.default(mm_113, [2, 8192, 1024]); mm_113 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 256, '0'); convert_element_type_538 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_114 = torch.ops.aten.mm.default(view_547, permute_178); view_547 = permute_178 = None + view_554 = torch.ops.aten.view.default(mm_114, [2, 8192, 1024]) + view_555 = torch.ops.aten.view.default(view_548, [2, 8192, -1, 128]); view_548 = None + view_556 = torch.ops.aten.view.default(view_551, [2, 8192, -1, 128]); view_551 = None + view_557 = torch.ops.aten.view.default(view_554, [2, 8192, -1, 128]); view_554 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_555, torch.float32); view_555 = None + view_558 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 32, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_558); view_558 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_556, torch.float32); view_556 = None + view_559 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 8, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_559); view_559 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_16); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_561 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 32, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_16); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_562 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 8, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_561, torch.bfloat16); view_561 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_562, torch.bfloat16); view_562 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 8, 4, 128]); unsqueeze_32 = None + clone_32 = torch.ops.aten.clone.default(expand_32, memory_format = torch.contiguous_format); expand_32 = None + view_563 = torch.ops.aten.view.default(clone_32, [2, 8192, 32, 128]); clone_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_557, 3); view_557 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 8, 4, 128]); unsqueeze_33 = None + clone_33 = torch.ops.aten.clone.default(expand_33, memory_format = torch.contiguous_format); expand_33 = None + view_564 = torch.ops.aten.view.default(clone_33, [2, 8192, 32, 128]); clone_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_563, [0, 2, 1, 3]); view_563 = None + permute_181 = torch.ops.aten.permute.default(view_564, [0, 2, 1, 3]); view_564 = None + _scaled_dot_product_cudnn_attention_16 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_179, permute_180, permute_181, None, True, 0.0, True); permute_179 = permute_180 = permute_181 = None + getitem_144 = _scaled_dot_product_cudnn_attention_16[0] + getitem_145 = _scaled_dot_product_cudnn_attention_16[1] + getitem_150 = _scaled_dot_product_cudnn_attention_16[6] + getitem_151 = _scaled_dot_product_cudnn_attention_16[7]; _scaled_dot_product_cudnn_attention_16 = None + permute_182 = torch.ops.aten.permute.default(getitem_144, [0, 2, 1, 3]) + view_565 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 256, '0'); convert_element_type_545 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + view_567 = torch.ops.aten.view.default(view_565, [16384, 4096]); view_565 = None + mm_115 = torch.ops.aten.mm.default(view_567, permute_183); view_567 = permute_183 = None + view_568 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + add_65 = torch.ops.aten.add.Tensor(add_63, view_568); view_568 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 256, '0'); convert_element_type_548 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = rsqrt_33 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_150); mul_132 = wait_tensor_150 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16) + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 256, '0'); convert_element_type_551 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_151, [1, 0]); wait_tensor_151 = None + view_571 = torch.ops.aten.view.default(convert_element_type_550, [16384, 4096]); convert_element_type_550 = None + mm_116 = torch.ops.aten.mm.default(view_571, permute_184); permute_184 = None + view_572 = torch.ops.aten.view.default(mm_116, [2, 8192, 14336]) + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_572, torch.float32); view_572 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); convert_element_type_554 = sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 256, '0'); convert_element_type_556 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_152, [1, 0]); wait_tensor_152 = None + mm_117 = torch.ops.aten.mm.default(view_571, permute_185); view_571 = permute_185 = None + view_575 = torch.ops.aten.view.default(mm_117, [2, 8192, 14336]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_575); convert_element_type_555 = view_575 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 256, '0'); convert_element_type_559 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_153, [1, 0]); wait_tensor_153 = None + view_577 = torch.ops.aten.view.default(mul_135, [16384, 14336]); mul_135 = None + mm_118 = torch.ops.aten.mm.default(view_577, permute_186); view_577 = permute_186 = None + view_578 = torch.ops.aten.view.default(mm_118, [2, 8192, 4096]); mm_118 = None + add_67 = torch.ops.aten.add.Tensor(add_65, view_578); add_65 = view_578 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 256, '0'); convert_element_type_562 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = rsqrt_34 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_154); mul_136 = wait_tensor_154 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16) + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 256, '0'); convert_element_type_565 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + view_581 = torch.ops.aten.view.default(convert_element_type_564, [16384, 4096]); convert_element_type_564 = None + mm_119 = torch.ops.aten.mm.default(view_581, permute_187); permute_187 = None + view_582 = torch.ops.aten.view.default(mm_119, [2, 8192, 4096]) + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 256, '0'); convert_element_type_568 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_120 = torch.ops.aten.mm.default(view_581, permute_188); permute_188 = None + view_585 = torch.ops.aten.view.default(mm_120, [2, 8192, 1024]); mm_120 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 256, '0'); convert_element_type_571 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + mm_121 = torch.ops.aten.mm.default(view_581, permute_189); view_581 = permute_189 = None + view_588 = torch.ops.aten.view.default(mm_121, [2, 8192, 1024]) + view_589 = torch.ops.aten.view.default(view_582, [2, 8192, -1, 128]); view_582 = None + view_590 = torch.ops.aten.view.default(view_585, [2, 8192, -1, 128]); view_585 = None + view_591 = torch.ops.aten.view.default(view_588, [2, 8192, -1, 128]); view_588 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_589, torch.float32); view_589 = None + view_592 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 32, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_592); view_592 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_590, torch.float32); view_590 = None + view_593 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 8, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_593); view_593 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_16); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_595 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 32, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_16); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_596 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 8, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_595, torch.bfloat16); view_595 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_596, torch.bfloat16); view_596 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 8, 4, 128]); unsqueeze_34 = None + clone_34 = torch.ops.aten.clone.default(expand_34, memory_format = torch.contiguous_format); expand_34 = None + view_597 = torch.ops.aten.view.default(clone_34, [2, 8192, 32, 128]); clone_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_591, 3); view_591 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 8, 4, 128]); unsqueeze_35 = None + clone_35 = torch.ops.aten.clone.default(expand_35, memory_format = torch.contiguous_format); expand_35 = None + view_598 = torch.ops.aten.view.default(clone_35, [2, 8192, 32, 128]); clone_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_597, [0, 2, 1, 3]); view_597 = None + permute_192 = torch.ops.aten.permute.default(view_598, [0, 2, 1, 3]); view_598 = None + _scaled_dot_product_cudnn_attention_17 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_190, permute_191, permute_192, None, True, 0.0, True); permute_190 = permute_191 = permute_192 = None + getitem_153 = _scaled_dot_product_cudnn_attention_17[0] + getitem_154 = _scaled_dot_product_cudnn_attention_17[1] + getitem_159 = _scaled_dot_product_cudnn_attention_17[6] + getitem_160 = _scaled_dot_product_cudnn_attention_17[7]; _scaled_dot_product_cudnn_attention_17 = None + permute_193 = torch.ops.aten.permute.default(getitem_153, [0, 2, 1, 3]) + view_599 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 256, '0'); convert_element_type_578 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_158, [1, 0]); wait_tensor_158 = None + view_601 = torch.ops.aten.view.default(view_599, [16384, 4096]); view_599 = None + mm_122 = torch.ops.aten.mm.default(view_601, permute_194); view_601 = permute_194 = None + view_602 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + add_69 = torch.ops.aten.add.Tensor(add_67, view_602); view_602 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 256, '0'); convert_element_type_581 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = rsqrt_35 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_159); mul_140 = wait_tensor_159 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 256, '0'); convert_element_type_584 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + view_605 = torch.ops.aten.view.default(convert_element_type_583, [16384, 4096]); convert_element_type_583 = None + mm_123 = torch.ops.aten.mm.default(view_605, permute_195); permute_195 = None + view_606 = torch.ops.aten.view.default(mm_123, [2, 8192, 14336]) + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_606, torch.float32); view_606 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); convert_element_type_587 = sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 256, '0'); convert_element_type_589 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_124 = torch.ops.aten.mm.default(view_605, permute_196); view_605 = permute_196 = None + view_609 = torch.ops.aten.view.default(mm_124, [2, 8192, 14336]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_609); convert_element_type_588 = view_609 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 256, '0'); convert_element_type_592 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + view_611 = torch.ops.aten.view.default(mul_143, [16384, 14336]); mul_143 = None + mm_125 = torch.ops.aten.mm.default(view_611, permute_197); view_611 = permute_197 = None + view_612 = torch.ops.aten.view.default(mm_125, [2, 8192, 4096]); mm_125 = None + add_71 = torch.ops.aten.add.Tensor(add_69, view_612); add_69 = view_612 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 256, '0'); convert_element_type_595 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32) + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = rsqrt_36 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_163); mul_144 = wait_tensor_163 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 256, '0'); convert_element_type_598 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_164, [1, 0]); wait_tensor_164 = None + view_615 = torch.ops.aten.view.default(convert_element_type_597, [16384, 4096]); convert_element_type_597 = None + mm_126 = torch.ops.aten.mm.default(view_615, permute_198); permute_198 = None + view_616 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]) + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 256, '0'); convert_element_type_601 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + mm_127 = torch.ops.aten.mm.default(view_615, permute_199); permute_199 = None + view_619 = torch.ops.aten.view.default(mm_127, [2, 8192, 1024]); mm_127 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 256, '0'); convert_element_type_604 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_166, [1, 0]); wait_tensor_166 = None + mm_128 = torch.ops.aten.mm.default(view_615, permute_200); view_615 = permute_200 = None + view_622 = torch.ops.aten.view.default(mm_128, [2, 8192, 1024]) + view_623 = torch.ops.aten.view.default(view_616, [2, 8192, -1, 128]); view_616 = None + view_624 = torch.ops.aten.view.default(view_619, [2, 8192, -1, 128]); view_619 = None + view_625 = torch.ops.aten.view.default(view_622, [2, 8192, -1, 128]); view_622 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_623, torch.float32); view_623 = None + view_626 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 32, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_626); view_626 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_624, torch.float32); view_624 = None + view_627 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 8, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_627); view_627 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_16); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_629 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 32, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_16); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_630 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 8, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_629, torch.bfloat16); view_629 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_630, torch.bfloat16); view_630 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 8, 4, 128]); unsqueeze_36 = None + clone_36 = torch.ops.aten.clone.default(expand_36, memory_format = torch.contiguous_format); expand_36 = None + view_631 = torch.ops.aten.view.default(clone_36, [2, 8192, 32, 128]); clone_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_625, 3); view_625 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 8, 4, 128]); unsqueeze_37 = None + clone_37 = torch.ops.aten.clone.default(expand_37, memory_format = torch.contiguous_format); expand_37 = None + view_632 = torch.ops.aten.view.default(clone_37, [2, 8192, 32, 128]); clone_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_631, [0, 2, 1, 3]); view_631 = None + permute_203 = torch.ops.aten.permute.default(view_632, [0, 2, 1, 3]); view_632 = None + _scaled_dot_product_cudnn_attention_18 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_201, permute_202, permute_203, None, True, 0.0, True); permute_201 = permute_202 = permute_203 = None + getitem_162 = _scaled_dot_product_cudnn_attention_18[0] + getitem_163 = _scaled_dot_product_cudnn_attention_18[1] + getitem_168 = _scaled_dot_product_cudnn_attention_18[6] + getitem_169 = _scaled_dot_product_cudnn_attention_18[7]; _scaled_dot_product_cudnn_attention_18 = None + permute_204 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_633 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 256, '0'); convert_element_type_611 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_635 = torch.ops.aten.view.default(view_633, [16384, 4096]); view_633 = None + mm_129 = torch.ops.aten.mm.default(view_635, permute_205); view_635 = permute_205 = None + view_636 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + add_73 = torch.ops.aten.add.Tensor(add_71, view_636); view_636 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 256, '0'); convert_element_type_614 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = rsqrt_37 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_168); mul_148 = wait_tensor_168 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 256, '0'); convert_element_type_617 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + view_639 = torch.ops.aten.view.default(convert_element_type_616, [16384, 4096]); convert_element_type_616 = None + mm_130 = torch.ops.aten.mm.default(view_639, permute_206); permute_206 = None + view_640 = torch.ops.aten.view.default(mm_130, [2, 8192, 14336]) + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_640, torch.float32); view_640 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); convert_element_type_620 = sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 256, '0'); convert_element_type_622 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_170, [1, 0]); wait_tensor_170 = None + mm_131 = torch.ops.aten.mm.default(view_639, permute_207); view_639 = permute_207 = None + view_643 = torch.ops.aten.view.default(mm_131, [2, 8192, 14336]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_643); convert_element_type_621 = view_643 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 256, '0'); convert_element_type_625 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_171, [1, 0]); wait_tensor_171 = None + view_645 = torch.ops.aten.view.default(mul_151, [16384, 14336]); mul_151 = None + mm_132 = torch.ops.aten.mm.default(view_645, permute_208); view_645 = permute_208 = None + view_646 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + add_75 = torch.ops.aten.add.Tensor(add_73, view_646); add_73 = view_646 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 256, '0'); convert_element_type_628 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32) + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = rsqrt_38 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_172); mul_152 = wait_tensor_172 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16) + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 256, '0'); convert_element_type_631 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + view_649 = torch.ops.aten.view.default(convert_element_type_630, [16384, 4096]); convert_element_type_630 = None + mm_133 = torch.ops.aten.mm.default(view_649, permute_209); permute_209 = None + view_650 = torch.ops.aten.view.default(mm_133, [2, 8192, 4096]) + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 256, '0'); convert_element_type_634 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_134 = torch.ops.aten.mm.default(view_649, permute_210); permute_210 = None + view_653 = torch.ops.aten.view.default(mm_134, [2, 8192, 1024]); mm_134 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 256, '0'); convert_element_type_637 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + mm_135 = torch.ops.aten.mm.default(view_649, permute_211); view_649 = permute_211 = None + view_656 = torch.ops.aten.view.default(mm_135, [2, 8192, 1024]) + view_657 = torch.ops.aten.view.default(view_650, [2, 8192, -1, 128]); view_650 = None + view_658 = torch.ops.aten.view.default(view_653, [2, 8192, -1, 128]); view_653 = None + view_659 = torch.ops.aten.view.default(view_656, [2, 8192, -1, 128]); view_656 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_657, torch.float32); view_657 = None + view_660 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 32, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_660); view_660 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_658, torch.float32); view_658 = None + view_661 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 8, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_661); view_661 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_16); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_663 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 32, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_16); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_664 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 8, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_663, torch.bfloat16); view_663 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_664, torch.bfloat16); view_664 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 8, 4, 128]); unsqueeze_38 = None + clone_38 = torch.ops.aten.clone.default(expand_38, memory_format = torch.contiguous_format); expand_38 = None + view_665 = torch.ops.aten.view.default(clone_38, [2, 8192, 32, 128]); clone_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_659, 3); view_659 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 8, 4, 128]); unsqueeze_39 = None + clone_39 = torch.ops.aten.clone.default(expand_39, memory_format = torch.contiguous_format); expand_39 = None + view_666 = torch.ops.aten.view.default(clone_39, [2, 8192, 32, 128]); clone_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_665, [0, 2, 1, 3]); view_665 = None + permute_214 = torch.ops.aten.permute.default(view_666, [0, 2, 1, 3]); view_666 = None + _scaled_dot_product_cudnn_attention_19 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_212, permute_213, permute_214, None, True, 0.0, True); permute_212 = permute_213 = permute_214 = None + getitem_171 = _scaled_dot_product_cudnn_attention_19[0] + getitem_172 = _scaled_dot_product_cudnn_attention_19[1] + getitem_177 = _scaled_dot_product_cudnn_attention_19[6] + getitem_178 = _scaled_dot_product_cudnn_attention_19[7]; _scaled_dot_product_cudnn_attention_19 = None + permute_215 = torch.ops.aten.permute.default(getitem_171, [0, 2, 1, 3]) + view_667 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 256, '0'); convert_element_type_644 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_669 = torch.ops.aten.view.default(view_667, [16384, 4096]); view_667 = None + mm_136 = torch.ops.aten.mm.default(view_669, permute_216); view_669 = permute_216 = None + view_670 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + add_77 = torch.ops.aten.add.Tensor(add_75, view_670); view_670 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 256, '0'); convert_element_type_647 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = rsqrt_39 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_177); mul_156 = wait_tensor_177 = None + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 256, '0'); convert_element_type_650 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + view_673 = torch.ops.aten.view.default(convert_element_type_649, [16384, 4096]); convert_element_type_649 = None + mm_137 = torch.ops.aten.mm.default(view_673, permute_217); permute_217 = None + view_674 = torch.ops.aten.view.default(mm_137, [2, 8192, 14336]) + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_674, torch.float32); view_674 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); convert_element_type_653 = sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 256, '0'); convert_element_type_655 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_179, [1, 0]); wait_tensor_179 = None + mm_138 = torch.ops.aten.mm.default(view_673, permute_218); view_673 = permute_218 = None + view_677 = torch.ops.aten.view.default(mm_138, [2, 8192, 14336]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_677); convert_element_type_654 = view_677 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 256, '0'); convert_element_type_658 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_679 = torch.ops.aten.view.default(mul_159, [16384, 14336]); mul_159 = None + mm_139 = torch.ops.aten.mm.default(view_679, permute_219); view_679 = permute_219 = None + view_680 = torch.ops.aten.view.default(mm_139, [2, 8192, 4096]); mm_139 = None + add_79 = torch.ops.aten.add.Tensor(add_77, view_680); add_77 = view_680 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 256, '0'); convert_element_type_661 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = rsqrt_40 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_181); mul_160 = wait_tensor_181 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 256, '0'); convert_element_type_664 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + view_683 = torch.ops.aten.view.default(convert_element_type_663, [16384, 4096]); convert_element_type_663 = None + mm_140 = torch.ops.aten.mm.default(view_683, permute_220); permute_220 = None + view_684 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]) + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 256, '0'); convert_element_type_667 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + mm_141 = torch.ops.aten.mm.default(view_683, permute_221); permute_221 = None + view_687 = torch.ops.aten.view.default(mm_141, [2, 8192, 1024]); mm_141 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 256, '0'); convert_element_type_670 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + mm_142 = torch.ops.aten.mm.default(view_683, permute_222); view_683 = permute_222 = None + view_690 = torch.ops.aten.view.default(mm_142, [2, 8192, 1024]) + view_691 = torch.ops.aten.view.default(view_684, [2, 8192, -1, 128]); view_684 = None + view_692 = torch.ops.aten.view.default(view_687, [2, 8192, -1, 128]); view_687 = None + view_693 = torch.ops.aten.view.default(view_690, [2, 8192, -1, 128]); view_690 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_691, torch.float32); view_691 = None + view_694 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 32, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_694); view_694 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_692, torch.float32); view_692 = None + view_695 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 8, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_695); view_695 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_16); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_697 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 32, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_16); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_698 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 8, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_697, torch.bfloat16); view_697 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_698, torch.bfloat16); view_698 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 8, 4, 128]); unsqueeze_40 = None + clone_40 = torch.ops.aten.clone.default(expand_40, memory_format = torch.contiguous_format); expand_40 = None + view_699 = torch.ops.aten.view.default(clone_40, [2, 8192, 32, 128]); clone_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_693, 3); view_693 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 8, 4, 128]); unsqueeze_41 = None + clone_41 = torch.ops.aten.clone.default(expand_41, memory_format = torch.contiguous_format); expand_41 = None + view_700 = torch.ops.aten.view.default(clone_41, [2, 8192, 32, 128]); clone_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_699, [0, 2, 1, 3]); view_699 = None + permute_225 = torch.ops.aten.permute.default(view_700, [0, 2, 1, 3]); view_700 = None + _scaled_dot_product_cudnn_attention_20 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_223, permute_224, permute_225, None, True, 0.0, True); permute_223 = permute_224 = permute_225 = None + getitem_180 = _scaled_dot_product_cudnn_attention_20[0] + getitem_181 = _scaled_dot_product_cudnn_attention_20[1] + getitem_186 = _scaled_dot_product_cudnn_attention_20[6] + getitem_187 = _scaled_dot_product_cudnn_attention_20[7]; _scaled_dot_product_cudnn_attention_20 = None + permute_226 = torch.ops.aten.permute.default(getitem_180, [0, 2, 1, 3]) + view_701 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 256, '0'); convert_element_type_677 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_185, [1, 0]); wait_tensor_185 = None + view_703 = torch.ops.aten.view.default(view_701, [16384, 4096]); view_701 = None + mm_143 = torch.ops.aten.mm.default(view_703, permute_227); view_703 = permute_227 = None + view_704 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + add_81 = torch.ops.aten.add.Tensor(add_79, view_704); view_704 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 256, '0'); convert_element_type_680 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = rsqrt_41 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_186); mul_164 = wait_tensor_186 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 256, '0'); convert_element_type_683 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + view_707 = torch.ops.aten.view.default(convert_element_type_682, [16384, 4096]); convert_element_type_682 = None + mm_144 = torch.ops.aten.mm.default(view_707, permute_228); permute_228 = None + view_708 = torch.ops.aten.view.default(mm_144, [2, 8192, 14336]) + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_708, torch.float32); view_708 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); convert_element_type_686 = sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 256, '0'); convert_element_type_688 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_145 = torch.ops.aten.mm.default(view_707, permute_229); view_707 = permute_229 = None + view_711 = torch.ops.aten.view.default(mm_145, [2, 8192, 14336]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_711); convert_element_type_687 = view_711 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16) + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 256, '0'); convert_element_type_691 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + view_713 = torch.ops.aten.view.default(mul_167, [16384, 14336]); mul_167 = None + mm_146 = torch.ops.aten.mm.default(view_713, permute_230); view_713 = permute_230 = None + view_714 = torch.ops.aten.view.default(mm_146, [2, 8192, 4096]); mm_146 = None + add_83 = torch.ops.aten.add.Tensor(add_81, view_714); add_81 = view_714 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 256, '0'); convert_element_type_694 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32) + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = rsqrt_42 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_190); mul_168 = wait_tensor_190 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 256, '0'); convert_element_type_697 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_191, [1, 0]); wait_tensor_191 = None + view_717 = torch.ops.aten.view.default(convert_element_type_696, [16384, 4096]); convert_element_type_696 = None + mm_147 = torch.ops.aten.mm.default(view_717, permute_231); permute_231 = None + view_718 = torch.ops.aten.view.default(mm_147, [2, 8192, 4096]) + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 256, '0'); convert_element_type_700 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_192, [1, 0]); wait_tensor_192 = None + mm_148 = torch.ops.aten.mm.default(view_717, permute_232); permute_232 = None + view_721 = torch.ops.aten.view.default(mm_148, [2, 8192, 1024]); mm_148 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 256, '0'); convert_element_type_703 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + mm_149 = torch.ops.aten.mm.default(view_717, permute_233); view_717 = permute_233 = None + view_724 = torch.ops.aten.view.default(mm_149, [2, 8192, 1024]) + view_725 = torch.ops.aten.view.default(view_718, [2, 8192, -1, 128]); view_718 = None + view_726 = torch.ops.aten.view.default(view_721, [2, 8192, -1, 128]); view_721 = None + view_727 = torch.ops.aten.view.default(view_724, [2, 8192, -1, 128]); view_724 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_725, torch.float32); view_725 = None + view_728 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 32, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_728); view_728 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_726, torch.float32); view_726 = None + view_729 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 8, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_729); view_729 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_16); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_731 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 32, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_16); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_732 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 8, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_731, torch.bfloat16); view_731 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_732, torch.bfloat16); view_732 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 8, 4, 128]); unsqueeze_42 = None + clone_42 = torch.ops.aten.clone.default(expand_42, memory_format = torch.contiguous_format); expand_42 = None + view_733 = torch.ops.aten.view.default(clone_42, [2, 8192, 32, 128]); clone_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_727, 3); view_727 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 8, 4, 128]); unsqueeze_43 = None + clone_43 = torch.ops.aten.clone.default(expand_43, memory_format = torch.contiguous_format); expand_43 = None + view_734 = torch.ops.aten.view.default(clone_43, [2, 8192, 32, 128]); clone_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_733, [0, 2, 1, 3]); view_733 = None + permute_236 = torch.ops.aten.permute.default(view_734, [0, 2, 1, 3]); view_734 = None + _scaled_dot_product_cudnn_attention_21 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_234, permute_235, permute_236, None, True, 0.0, True); permute_234 = permute_235 = permute_236 = None + getitem_189 = _scaled_dot_product_cudnn_attention_21[0] + getitem_190 = _scaled_dot_product_cudnn_attention_21[1] + getitem_195 = _scaled_dot_product_cudnn_attention_21[6] + getitem_196 = _scaled_dot_product_cudnn_attention_21[7]; _scaled_dot_product_cudnn_attention_21 = None + permute_237 = torch.ops.aten.permute.default(getitem_189, [0, 2, 1, 3]) + view_735 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 256, '0'); convert_element_type_710 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + view_737 = torch.ops.aten.view.default(view_735, [16384, 4096]); view_735 = None + mm_150 = torch.ops.aten.mm.default(view_737, permute_238); view_737 = permute_238 = None + view_738 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + add_85 = torch.ops.aten.add.Tensor(add_83, view_738); view_738 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 256, '0'); convert_element_type_713 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = rsqrt_43 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_195); mul_172 = wait_tensor_195 = None + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 256, '0'); convert_element_type_716 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_196, [1, 0]); wait_tensor_196 = None + view_741 = torch.ops.aten.view.default(convert_element_type_715, [16384, 4096]); convert_element_type_715 = None + mm_151 = torch.ops.aten.mm.default(view_741, permute_239); permute_239 = None + view_742 = torch.ops.aten.view.default(mm_151, [2, 8192, 14336]) + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_742, torch.float32); view_742 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); convert_element_type_719 = sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 256, '0'); convert_element_type_721 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_152 = torch.ops.aten.mm.default(view_741, permute_240); view_741 = permute_240 = None + view_745 = torch.ops.aten.view.default(mm_152, [2, 8192, 14336]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_745); convert_element_type_720 = view_745 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 256, '0'); convert_element_type_724 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + view_747 = torch.ops.aten.view.default(mul_175, [16384, 14336]); mul_175 = None + mm_153 = torch.ops.aten.mm.default(view_747, permute_241); view_747 = permute_241 = None + view_748 = torch.ops.aten.view.default(mm_153, [2, 8192, 4096]); mm_153 = None + add_87 = torch.ops.aten.add.Tensor(add_85, view_748); add_85 = view_748 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 256, '0'); convert_element_type_727 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32) + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = rsqrt_44 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_199); mul_176 = wait_tensor_199 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 256, '0'); convert_element_type_730 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + view_751 = torch.ops.aten.view.default(convert_element_type_729, [16384, 4096]); convert_element_type_729 = None + mm_154 = torch.ops.aten.mm.default(view_751, permute_242); permute_242 = None + view_752 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]) + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 256, '0'); convert_element_type_733 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_155 = torch.ops.aten.mm.default(view_751, permute_243); permute_243 = None + view_755 = torch.ops.aten.view.default(mm_155, [2, 8192, 1024]); mm_155 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 256, '0'); convert_element_type_736 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + mm_156 = torch.ops.aten.mm.default(view_751, permute_244); view_751 = permute_244 = None + view_758 = torch.ops.aten.view.default(mm_156, [2, 8192, 1024]) + view_759 = torch.ops.aten.view.default(view_752, [2, 8192, -1, 128]); view_752 = None + view_760 = torch.ops.aten.view.default(view_755, [2, 8192, -1, 128]); view_755 = None + view_761 = torch.ops.aten.view.default(view_758, [2, 8192, -1, 128]); view_758 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_759, torch.float32); view_759 = None + view_762 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 32, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_762); view_762 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_760, torch.float32); view_760 = None + view_763 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 8, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_763); view_763 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_16); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_765 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 32, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_16); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_766 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 8, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_765, torch.bfloat16); view_765 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_766, torch.bfloat16); view_766 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 8, 4, 128]); unsqueeze_44 = None + clone_44 = torch.ops.aten.clone.default(expand_44, memory_format = torch.contiguous_format); expand_44 = None + view_767 = torch.ops.aten.view.default(clone_44, [2, 8192, 32, 128]); clone_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_761, 3); view_761 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 8, 4, 128]); unsqueeze_45 = None + clone_45 = torch.ops.aten.clone.default(expand_45, memory_format = torch.contiguous_format); expand_45 = None + view_768 = torch.ops.aten.view.default(clone_45, [2, 8192, 32, 128]); clone_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_767, [0, 2, 1, 3]); view_767 = None + permute_247 = torch.ops.aten.permute.default(view_768, [0, 2, 1, 3]); view_768 = None + _scaled_dot_product_cudnn_attention_22 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_245, permute_246, permute_247, None, True, 0.0, True); permute_245 = permute_246 = permute_247 = None + getitem_198 = _scaled_dot_product_cudnn_attention_22[0] + getitem_199 = _scaled_dot_product_cudnn_attention_22[1] + getitem_204 = _scaled_dot_product_cudnn_attention_22[6] + getitem_205 = _scaled_dot_product_cudnn_attention_22[7]; _scaled_dot_product_cudnn_attention_22 = None + permute_248 = torch.ops.aten.permute.default(getitem_198, [0, 2, 1, 3]) + view_769 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 256, '0'); convert_element_type_743 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_203, [1, 0]); wait_tensor_203 = None + view_771 = torch.ops.aten.view.default(view_769, [16384, 4096]); view_769 = None + mm_157 = torch.ops.aten.mm.default(view_771, permute_249); view_771 = permute_249 = None + view_772 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + add_89 = torch.ops.aten.add.Tensor(add_87, view_772); view_772 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 256, '0'); convert_element_type_746 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = rsqrt_45 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_204); mul_180 = wait_tensor_204 = None + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 256, '0'); convert_element_type_749 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + view_775 = torch.ops.aten.view.default(convert_element_type_748, [16384, 4096]); convert_element_type_748 = None + mm_158 = torch.ops.aten.mm.default(view_775, permute_250); permute_250 = None + view_776 = torch.ops.aten.view.default(mm_158, [2, 8192, 14336]) + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_776, torch.float32); view_776 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); convert_element_type_752 = sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 256, '0'); convert_element_type_754 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + mm_159 = torch.ops.aten.mm.default(view_775, permute_251); view_775 = permute_251 = None + view_779 = torch.ops.aten.view.default(mm_159, [2, 8192, 14336]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_779); convert_element_type_753 = view_779 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 256, '0'); convert_element_type_757 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + view_781 = torch.ops.aten.view.default(mul_183, [16384, 14336]); mul_183 = None + mm_160 = torch.ops.aten.mm.default(view_781, permute_252); view_781 = permute_252 = None + view_782 = torch.ops.aten.view.default(mm_160, [2, 8192, 4096]); mm_160 = None + add_91 = torch.ops.aten.add.Tensor(add_89, view_782); add_89 = view_782 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 256, '0'); convert_element_type_760 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = rsqrt_46 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_208); mul_184 = wait_tensor_208 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 256, '0'); convert_element_type_763 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_209, [1, 0]); wait_tensor_209 = None + view_785 = torch.ops.aten.view.default(convert_element_type_762, [16384, 4096]); convert_element_type_762 = None + mm_161 = torch.ops.aten.mm.default(view_785, permute_253); permute_253 = None + view_786 = torch.ops.aten.view.default(mm_161, [2, 8192, 4096]) + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 256, '0'); convert_element_type_766 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_210, [1, 0]); wait_tensor_210 = None + mm_162 = torch.ops.aten.mm.default(view_785, permute_254); permute_254 = None + view_789 = torch.ops.aten.view.default(mm_162, [2, 8192, 1024]); mm_162 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 256, '0'); convert_element_type_769 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_211, [1, 0]); wait_tensor_211 = None + mm_163 = torch.ops.aten.mm.default(view_785, permute_255); view_785 = permute_255 = None + view_792 = torch.ops.aten.view.default(mm_163, [2, 8192, 1024]) + view_793 = torch.ops.aten.view.default(view_786, [2, 8192, -1, 128]); view_786 = None + view_794 = torch.ops.aten.view.default(view_789, [2, 8192, -1, 128]); view_789 = None + view_795 = torch.ops.aten.view.default(view_792, [2, 8192, -1, 128]); view_792 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_793, torch.float32); view_793 = None + view_796 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 32, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_796); view_796 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_794, torch.float32); view_794 = None + view_797 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 8, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_797); view_797 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_16); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_799 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 32, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_16); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_800 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 8, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_799, torch.bfloat16); view_799 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_800, torch.bfloat16); view_800 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 8, 4, 128]); unsqueeze_46 = None + clone_46 = torch.ops.aten.clone.default(expand_46, memory_format = torch.contiguous_format); expand_46 = None + view_801 = torch.ops.aten.view.default(clone_46, [2, 8192, 32, 128]); clone_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_795, 3); view_795 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 8, 4, 128]); unsqueeze_47 = None + clone_47 = torch.ops.aten.clone.default(expand_47, memory_format = torch.contiguous_format); expand_47 = None + view_802 = torch.ops.aten.view.default(clone_47, [2, 8192, 32, 128]); clone_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_801, [0, 2, 1, 3]); view_801 = None + permute_258 = torch.ops.aten.permute.default(view_802, [0, 2, 1, 3]); view_802 = None + _scaled_dot_product_cudnn_attention_23 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_256, permute_257, permute_258, None, True, 0.0, True); permute_256 = permute_257 = permute_258 = None + getitem_207 = _scaled_dot_product_cudnn_attention_23[0] + getitem_208 = _scaled_dot_product_cudnn_attention_23[1] + getitem_213 = _scaled_dot_product_cudnn_attention_23[6] + getitem_214 = _scaled_dot_product_cudnn_attention_23[7]; _scaled_dot_product_cudnn_attention_23 = None + permute_259 = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]) + view_803 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 256, '0'); convert_element_type_776 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_805 = torch.ops.aten.view.default(view_803, [16384, 4096]); view_803 = None + mm_164 = torch.ops.aten.mm.default(view_805, permute_260); view_805 = permute_260 = None + view_806 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + add_93 = torch.ops.aten.add.Tensor(add_91, view_806); view_806 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 256, '0'); convert_element_type_779 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = rsqrt_47 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_213); mul_188 = wait_tensor_213 = None + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 256, '0'); convert_element_type_782 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + view_809 = torch.ops.aten.view.default(convert_element_type_781, [16384, 4096]); convert_element_type_781 = None + mm_165 = torch.ops.aten.mm.default(view_809, permute_261); permute_261 = None + view_810 = torch.ops.aten.view.default(mm_165, [2, 8192, 14336]) + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_810, torch.float32); view_810 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); convert_element_type_785 = sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 256, '0'); convert_element_type_787 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + mm_166 = torch.ops.aten.mm.default(view_809, permute_262); view_809 = permute_262 = None + view_813 = torch.ops.aten.view.default(mm_166, [2, 8192, 14336]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_813); convert_element_type_786 = view_813 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 256, '0'); convert_element_type_790 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_216, [1, 0]); wait_tensor_216 = None + view_815 = torch.ops.aten.view.default(mul_191, [16384, 14336]); mul_191 = None + mm_167 = torch.ops.aten.mm.default(view_815, permute_263); view_815 = permute_263 = None + view_816 = torch.ops.aten.view.default(mm_167, [2, 8192, 4096]); mm_167 = None + add_95 = torch.ops.aten.add.Tensor(add_93, view_816); add_93 = view_816 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 256, '0'); convert_element_type_793 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32) + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = rsqrt_48 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_217); mul_192 = wait_tensor_217 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 256, '0'); convert_element_type_796 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + view_819 = torch.ops.aten.view.default(convert_element_type_795, [16384, 4096]); convert_element_type_795 = None + mm_168 = torch.ops.aten.mm.default(view_819, permute_264); permute_264 = None + view_820 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]) + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 256, '0'); convert_element_type_799 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_169 = torch.ops.aten.mm.default(view_819, permute_265); permute_265 = None + view_823 = torch.ops.aten.view.default(mm_169, [2, 8192, 1024]); mm_169 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 256, '0'); convert_element_type_802 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_170 = torch.ops.aten.mm.default(view_819, permute_266); view_819 = permute_266 = None + view_826 = torch.ops.aten.view.default(mm_170, [2, 8192, 1024]) + view_827 = torch.ops.aten.view.default(view_820, [2, 8192, -1, 128]); view_820 = None + view_828 = torch.ops.aten.view.default(view_823, [2, 8192, -1, 128]); view_823 = None + view_829 = torch.ops.aten.view.default(view_826, [2, 8192, -1, 128]); view_826 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_827, torch.float32); view_827 = None + view_830 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 32, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_830); view_830 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_828, torch.float32); view_828 = None + view_831 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 8, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_831); view_831 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_16); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_833 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 32, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_16); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_834 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 8, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_833, torch.bfloat16); view_833 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_834, torch.bfloat16); view_834 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 8, 4, 128]); unsqueeze_48 = None + clone_48 = torch.ops.aten.clone.default(expand_48, memory_format = torch.contiguous_format); expand_48 = None + view_835 = torch.ops.aten.view.default(clone_48, [2, 8192, 32, 128]); clone_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_829, 3); view_829 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 8, 4, 128]); unsqueeze_49 = None + clone_49 = torch.ops.aten.clone.default(expand_49, memory_format = torch.contiguous_format); expand_49 = None + view_836 = torch.ops.aten.view.default(clone_49, [2, 8192, 32, 128]); clone_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_835, [0, 2, 1, 3]); view_835 = None + permute_269 = torch.ops.aten.permute.default(view_836, [0, 2, 1, 3]); view_836 = None + _scaled_dot_product_cudnn_attention_24 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_267, permute_268, permute_269, None, True, 0.0, True); permute_267 = permute_268 = permute_269 = None + getitem_216 = _scaled_dot_product_cudnn_attention_24[0] + getitem_217 = _scaled_dot_product_cudnn_attention_24[1] + getitem_222 = _scaled_dot_product_cudnn_attention_24[6] + getitem_223 = _scaled_dot_product_cudnn_attention_24[7]; _scaled_dot_product_cudnn_attention_24 = None + permute_270 = torch.ops.aten.permute.default(getitem_216, [0, 2, 1, 3]) + view_837 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 256, '0'); convert_element_type_809 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_839 = torch.ops.aten.view.default(view_837, [16384, 4096]); view_837 = None + mm_171 = torch.ops.aten.mm.default(view_839, permute_271); view_839 = permute_271 = None + view_840 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + add_97 = torch.ops.aten.add.Tensor(add_95, view_840); view_840 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 256, '0'); convert_element_type_812 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = rsqrt_49 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_222); mul_196 = wait_tensor_222 = None + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 256, '0'); convert_element_type_815 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + view_843 = torch.ops.aten.view.default(convert_element_type_814, [16384, 4096]); convert_element_type_814 = None + mm_172 = torch.ops.aten.mm.default(view_843, permute_272); permute_272 = None + view_844 = torch.ops.aten.view.default(mm_172, [2, 8192, 14336]) + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_844, torch.float32); view_844 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); convert_element_type_818 = sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 256, '0'); convert_element_type_820 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_224, [1, 0]); wait_tensor_224 = None + mm_173 = torch.ops.aten.mm.default(view_843, permute_273); view_843 = permute_273 = None + view_847 = torch.ops.aten.view.default(mm_173, [2, 8192, 14336]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_847); convert_element_type_819 = view_847 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 256, '0'); convert_element_type_823 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_849 = torch.ops.aten.view.default(mul_199, [16384, 14336]); mul_199 = None + mm_174 = torch.ops.aten.mm.default(view_849, permute_274); view_849 = permute_274 = None + view_850 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = None + add_99 = torch.ops.aten.add.Tensor(add_97, view_850); add_97 = view_850 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 256, '0'); convert_element_type_826 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32) + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = rsqrt_50 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_226); mul_200 = wait_tensor_226 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 256, '0'); convert_element_type_829 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + view_853 = torch.ops.aten.view.default(convert_element_type_828, [16384, 4096]); convert_element_type_828 = None + mm_175 = torch.ops.aten.mm.default(view_853, permute_275); permute_275 = None + view_854 = torch.ops.aten.view.default(mm_175, [2, 8192, 4096]) + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 256, '0'); convert_element_type_832 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + mm_176 = torch.ops.aten.mm.default(view_853, permute_276); permute_276 = None + view_857 = torch.ops.aten.view.default(mm_176, [2, 8192, 1024]); mm_176 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 256, '0'); convert_element_type_835 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_229, [1, 0]); wait_tensor_229 = None + mm_177 = torch.ops.aten.mm.default(view_853, permute_277); view_853 = permute_277 = None + view_860 = torch.ops.aten.view.default(mm_177, [2, 8192, 1024]) + view_861 = torch.ops.aten.view.default(view_854, [2, 8192, -1, 128]); view_854 = None + view_862 = torch.ops.aten.view.default(view_857, [2, 8192, -1, 128]); view_857 = None + view_863 = torch.ops.aten.view.default(view_860, [2, 8192, -1, 128]); view_860 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_861, torch.float32); view_861 = None + view_864 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 32, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_864); view_864 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_862, torch.float32); view_862 = None + view_865 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 8, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_865); view_865 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_16); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_867 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 32, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_16); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_868 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 8, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_867, torch.bfloat16); view_867 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_868, torch.bfloat16); view_868 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 8, 4, 128]); unsqueeze_50 = None + clone_50 = torch.ops.aten.clone.default(expand_50, memory_format = torch.contiguous_format); expand_50 = None + view_869 = torch.ops.aten.view.default(clone_50, [2, 8192, 32, 128]); clone_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_863, 3); view_863 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 8, 4, 128]); unsqueeze_51 = None + clone_51 = torch.ops.aten.clone.default(expand_51, memory_format = torch.contiguous_format); expand_51 = None + view_870 = torch.ops.aten.view.default(clone_51, [2, 8192, 32, 128]); clone_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_869, [0, 2, 1, 3]); view_869 = None + permute_280 = torch.ops.aten.permute.default(view_870, [0, 2, 1, 3]); view_870 = None + _scaled_dot_product_cudnn_attention_25 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_278, permute_279, permute_280, None, True, 0.0, True); permute_278 = permute_279 = permute_280 = None + getitem_225 = _scaled_dot_product_cudnn_attention_25[0] + getitem_226 = _scaled_dot_product_cudnn_attention_25[1] + getitem_231 = _scaled_dot_product_cudnn_attention_25[6] + getitem_232 = _scaled_dot_product_cudnn_attention_25[7]; _scaled_dot_product_cudnn_attention_25 = None + permute_281 = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]) + view_871 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 256, '0'); convert_element_type_842 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_230, [1, 0]); wait_tensor_230 = None + view_873 = torch.ops.aten.view.default(view_871, [16384, 4096]); view_871 = None + mm_178 = torch.ops.aten.mm.default(view_873, permute_282); view_873 = permute_282 = None + view_874 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + add_101 = torch.ops.aten.add.Tensor(add_99, view_874); view_874 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 256, '0'); convert_element_type_845 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = rsqrt_51 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_231); mul_204 = wait_tensor_231 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 256, '0'); convert_element_type_848 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + view_877 = torch.ops.aten.view.default(convert_element_type_847, [16384, 4096]); convert_element_type_847 = None + mm_179 = torch.ops.aten.mm.default(view_877, permute_283); permute_283 = None + view_878 = torch.ops.aten.view.default(mm_179, [2, 8192, 14336]) + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_878, torch.float32); view_878 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); convert_element_type_851 = sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 256, '0'); convert_element_type_853 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_180 = torch.ops.aten.mm.default(view_877, permute_284); view_877 = permute_284 = None + view_881 = torch.ops.aten.view.default(mm_180, [2, 8192, 14336]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_881); convert_element_type_852 = view_881 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 256, '0'); convert_element_type_856 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + view_883 = torch.ops.aten.view.default(mul_207, [16384, 14336]); mul_207 = None + mm_181 = torch.ops.aten.mm.default(view_883, permute_285); view_883 = permute_285 = None + view_884 = torch.ops.aten.view.default(mm_181, [2, 8192, 4096]); mm_181 = None + add_103 = torch.ops.aten.add.Tensor(add_101, view_884); add_101 = view_884 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 256, '0'); convert_element_type_859 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = rsqrt_52 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_235); mul_208 = wait_tensor_235 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 256, '0'); convert_element_type_862 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_236, [1, 0]); wait_tensor_236 = None + view_887 = torch.ops.aten.view.default(convert_element_type_861, [16384, 4096]); convert_element_type_861 = None + mm_182 = torch.ops.aten.mm.default(view_887, permute_286); permute_286 = None + view_888 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]) + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 256, '0'); convert_element_type_865 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_237, [1, 0]); wait_tensor_237 = None + mm_183 = torch.ops.aten.mm.default(view_887, permute_287); permute_287 = None + view_891 = torch.ops.aten.view.default(mm_183, [2, 8192, 1024]); mm_183 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 256, '0'); convert_element_type_868 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + mm_184 = torch.ops.aten.mm.default(view_887, permute_288); view_887 = permute_288 = None + view_894 = torch.ops.aten.view.default(mm_184, [2, 8192, 1024]) + view_895 = torch.ops.aten.view.default(view_888, [2, 8192, -1, 128]); view_888 = None + view_896 = torch.ops.aten.view.default(view_891, [2, 8192, -1, 128]); view_891 = None + view_897 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_895, torch.float32); view_895 = None + view_898 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 32, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_898); view_898 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 8, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_16); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_901 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 32, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_16); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_902 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 8, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_901, torch.bfloat16); view_901 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 8, 4, 128]); unsqueeze_52 = None + clone_52 = torch.ops.aten.clone.default(expand_52, memory_format = torch.contiguous_format); expand_52 = None + view_903 = torch.ops.aten.view.default(clone_52, [2, 8192, 32, 128]); clone_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_897, 3); view_897 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 8, 4, 128]); unsqueeze_53 = None + clone_53 = torch.ops.aten.clone.default(expand_53, memory_format = torch.contiguous_format); expand_53 = None + view_904 = torch.ops.aten.view.default(clone_53, [2, 8192, 32, 128]); clone_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_903, [0, 2, 1, 3]); view_903 = None + permute_291 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + _scaled_dot_product_cudnn_attention_26 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_289, permute_290, permute_291, None, True, 0.0, True); permute_289 = permute_290 = permute_291 = None + getitem_234 = _scaled_dot_product_cudnn_attention_26[0] + getitem_235 = _scaled_dot_product_cudnn_attention_26[1] + getitem_240 = _scaled_dot_product_cudnn_attention_26[6] + getitem_241 = _scaled_dot_product_cudnn_attention_26[7]; _scaled_dot_product_cudnn_attention_26 = None + permute_292 = torch.ops.aten.permute.default(getitem_234, [0, 2, 1, 3]) + view_905 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 256, '0'); convert_element_type_875 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + view_907 = torch.ops.aten.view.default(view_905, [16384, 4096]); view_905 = None + mm_185 = torch.ops.aten.mm.default(view_907, permute_293); view_907 = permute_293 = None + view_908 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + add_105 = torch.ops.aten.add.Tensor(add_103, view_908); view_908 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 256, '0'); convert_element_type_878 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = rsqrt_53 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_240); mul_212 = wait_tensor_240 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 256, '0'); convert_element_type_881 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + view_911 = torch.ops.aten.view.default(convert_element_type_880, [16384, 4096]); convert_element_type_880 = None + mm_186 = torch.ops.aten.mm.default(view_911, permute_294); permute_294 = None + view_912 = torch.ops.aten.view.default(mm_186, [2, 8192, 14336]) + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_912, torch.float32); view_912 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); convert_element_type_884 = sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 256, '0'); convert_element_type_886 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_242, [1, 0]); wait_tensor_242 = None + mm_187 = torch.ops.aten.mm.default(view_911, permute_295); view_911 = permute_295 = None + view_915 = torch.ops.aten.view.default(mm_187, [2, 8192, 14336]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_915); convert_element_type_885 = view_915 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 256, '0'); convert_element_type_889 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + view_917 = torch.ops.aten.view.default(mul_215, [16384, 14336]); mul_215 = None + mm_188 = torch.ops.aten.mm.default(view_917, permute_296); view_917 = permute_296 = None + view_918 = torch.ops.aten.view.default(mm_188, [2, 8192, 4096]); mm_188 = None + add_107 = torch.ops.aten.add.Tensor(add_105, view_918); add_105 = view_918 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 256, '0'); convert_element_type_892 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32) + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = rsqrt_54 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_244); mul_216 = wait_tensor_244 = None + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 256, '0'); convert_element_type_895 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + view_921 = torch.ops.aten.view.default(convert_element_type_894, [16384, 4096]); convert_element_type_894 = None + mm_189 = torch.ops.aten.mm.default(view_921, permute_297); permute_297 = None + view_922 = torch.ops.aten.view.default(mm_189, [2, 8192, 4096]) + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 256, '0'); convert_element_type_898 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_190 = torch.ops.aten.mm.default(view_921, permute_298); permute_298 = None + view_925 = torch.ops.aten.view.default(mm_190, [2, 8192, 1024]); mm_190 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 256, '0'); convert_element_type_901 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + mm_191 = torch.ops.aten.mm.default(view_921, permute_299); view_921 = permute_299 = None + view_928 = torch.ops.aten.view.default(mm_191, [2, 8192, 1024]) + view_929 = torch.ops.aten.view.default(view_922, [2, 8192, -1, 128]); view_922 = None + view_930 = torch.ops.aten.view.default(view_925, [2, 8192, -1, 128]); view_925 = None + view_931 = torch.ops.aten.view.default(view_928, [2, 8192, -1, 128]); view_928 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_929, torch.float32); view_929 = None + view_932 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 32, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_932); view_932 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_930, torch.float32); view_930 = None + view_933 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 8, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_933); view_933 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_16); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_935 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 32, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_16); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_936 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 8, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_935, torch.bfloat16); view_935 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_936, torch.bfloat16); view_936 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 8, 4, 128]); unsqueeze_54 = None + clone_54 = torch.ops.aten.clone.default(expand_54, memory_format = torch.contiguous_format); expand_54 = None + view_937 = torch.ops.aten.view.default(clone_54, [2, 8192, 32, 128]); clone_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_931, 3); view_931 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 8, 4, 128]); unsqueeze_55 = None + clone_55 = torch.ops.aten.clone.default(expand_55, memory_format = torch.contiguous_format); expand_55 = None + view_938 = torch.ops.aten.view.default(clone_55, [2, 8192, 32, 128]); clone_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_937, [0, 2, 1, 3]); view_937 = None + permute_302 = torch.ops.aten.permute.default(view_938, [0, 2, 1, 3]); view_938 = None + _scaled_dot_product_cudnn_attention_27 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_300, permute_301, permute_302, None, True, 0.0, True); permute_300 = permute_301 = permute_302 = None + getitem_243 = _scaled_dot_product_cudnn_attention_27[0] + getitem_244 = _scaled_dot_product_cudnn_attention_27[1] + getitem_249 = _scaled_dot_product_cudnn_attention_27[6] + getitem_250 = _scaled_dot_product_cudnn_attention_27[7]; _scaled_dot_product_cudnn_attention_27 = None + permute_303 = torch.ops.aten.permute.default(getitem_243, [0, 2, 1, 3]) + view_939 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 256, '0'); convert_element_type_908 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_248, [1, 0]); wait_tensor_248 = None + view_941 = torch.ops.aten.view.default(view_939, [16384, 4096]); view_939 = None + mm_192 = torch.ops.aten.mm.default(view_941, permute_304); view_941 = permute_304 = None + view_942 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + add_109 = torch.ops.aten.add.Tensor(add_107, view_942); view_942 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 256, '0'); convert_element_type_911 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = rsqrt_55 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_249); mul_220 = wait_tensor_249 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 256, '0'); convert_element_type_914 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_250, [1, 0]); wait_tensor_250 = None + view_945 = torch.ops.aten.view.default(convert_element_type_913, [16384, 4096]); convert_element_type_913 = None + mm_193 = torch.ops.aten.mm.default(view_945, permute_305); permute_305 = None + view_946 = torch.ops.aten.view.default(mm_193, [2, 8192, 14336]) + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_946, torch.float32); view_946 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); convert_element_type_917 = sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16) + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 256, '0'); convert_element_type_919 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + mm_194 = torch.ops.aten.mm.default(view_945, permute_306); view_945 = permute_306 = None + view_949 = torch.ops.aten.view.default(mm_194, [2, 8192, 14336]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_949); convert_element_type_918 = view_949 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 256, '0'); convert_element_type_922 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + view_951 = torch.ops.aten.view.default(mul_223, [16384, 14336]); mul_223 = None + mm_195 = torch.ops.aten.mm.default(view_951, permute_307); view_951 = permute_307 = None + view_952 = torch.ops.aten.view.default(mm_195, [2, 8192, 4096]); mm_195 = None + add_111 = torch.ops.aten.add.Tensor(add_109, view_952); add_109 = view_952 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16) + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 256, '0'); convert_element_type_925 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32) + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = rsqrt_56 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_253); mul_224 = wait_tensor_253 = None + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 256, '0'); convert_element_type_928 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + view_955 = torch.ops.aten.view.default(convert_element_type_927, [16384, 4096]); convert_element_type_927 = None + mm_196 = torch.ops.aten.mm.default(view_955, permute_308); permute_308 = None + view_956 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]) + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 256, '0'); convert_element_type_931 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_255, [1, 0]); wait_tensor_255 = None + mm_197 = torch.ops.aten.mm.default(view_955, permute_309); permute_309 = None + view_959 = torch.ops.aten.view.default(mm_197, [2, 8192, 1024]); mm_197 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 256, '0'); convert_element_type_934 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_256, [1, 0]); wait_tensor_256 = None + mm_198 = torch.ops.aten.mm.default(view_955, permute_310); view_955 = permute_310 = None + view_962 = torch.ops.aten.view.default(mm_198, [2, 8192, 1024]) + view_963 = torch.ops.aten.view.default(view_956, [2, 8192, -1, 128]); view_956 = None + view_964 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_965 = torch.ops.aten.view.default(view_962, [2, 8192, -1, 128]); view_962 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_963, torch.float32); view_963 = None + view_966 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 32, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_966); view_966 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_964, torch.float32); view_964 = None + view_967 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 8, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_967); view_967 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_16); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_969 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 32, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_16); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_970 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 8, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_969, torch.bfloat16); view_969 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_970, torch.bfloat16); view_970 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 8, 4, 128]); unsqueeze_56 = None + clone_56 = torch.ops.aten.clone.default(expand_56, memory_format = torch.contiguous_format); expand_56 = None + view_971 = torch.ops.aten.view.default(clone_56, [2, 8192, 32, 128]); clone_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_965, 3); view_965 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 8, 4, 128]); unsqueeze_57 = None + clone_57 = torch.ops.aten.clone.default(expand_57, memory_format = torch.contiguous_format); expand_57 = None + view_972 = torch.ops.aten.view.default(clone_57, [2, 8192, 32, 128]); clone_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_971, [0, 2, 1, 3]); view_971 = None + permute_313 = torch.ops.aten.permute.default(view_972, [0, 2, 1, 3]); view_972 = None + _scaled_dot_product_cudnn_attention_28 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_311, permute_312, permute_313, None, True, 0.0, True); permute_311 = permute_312 = permute_313 = None + getitem_252 = _scaled_dot_product_cudnn_attention_28[0] + getitem_253 = _scaled_dot_product_cudnn_attention_28[1] + getitem_258 = _scaled_dot_product_cudnn_attention_28[6] + getitem_259 = _scaled_dot_product_cudnn_attention_28[7]; _scaled_dot_product_cudnn_attention_28 = None + permute_314 = torch.ops.aten.permute.default(getitem_252, [0, 2, 1, 3]) + view_973 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 256, '0'); convert_element_type_941 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_257, [1, 0]); wait_tensor_257 = None + view_975 = torch.ops.aten.view.default(view_973, [16384, 4096]); view_973 = None + mm_199 = torch.ops.aten.mm.default(view_975, permute_315); view_975 = permute_315 = None + view_976 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + add_113 = torch.ops.aten.add.Tensor(add_111, view_976); view_976 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 256, '0'); convert_element_type_944 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = rsqrt_57 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_258); mul_228 = wait_tensor_258 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 256, '0'); convert_element_type_947 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + view_979 = torch.ops.aten.view.default(convert_element_type_946, [16384, 4096]); convert_element_type_946 = None + mm_200 = torch.ops.aten.mm.default(view_979, permute_316); permute_316 = None + view_980 = torch.ops.aten.view.default(mm_200, [2, 8192, 14336]) + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_980, torch.float32); view_980 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); convert_element_type_950 = sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 256, '0'); convert_element_type_952 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_201 = torch.ops.aten.mm.default(view_979, permute_317); view_979 = permute_317 = None + view_983 = torch.ops.aten.view.default(mm_201, [2, 8192, 14336]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_983); convert_element_type_951 = view_983 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 256, '0'); convert_element_type_955 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + view_985 = torch.ops.aten.view.default(mul_231, [16384, 14336]); mul_231 = None + mm_202 = torch.ops.aten.mm.default(view_985, permute_318); view_985 = permute_318 = None + view_986 = torch.ops.aten.view.default(mm_202, [2, 8192, 4096]); mm_202 = None + add_115 = torch.ops.aten.add.Tensor(add_113, view_986); add_113 = view_986 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 256, '0'); convert_element_type_958 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = rsqrt_58 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_262); mul_232 = wait_tensor_262 = None + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 256, '0'); convert_element_type_961 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_263, [1, 0]); wait_tensor_263 = None + view_989 = torch.ops.aten.view.default(convert_element_type_960, [16384, 4096]); convert_element_type_960 = None + mm_203 = torch.ops.aten.mm.default(view_989, permute_319); permute_319 = None + view_990 = torch.ops.aten.view.default(mm_203, [2, 8192, 4096]) + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 256, '0'); convert_element_type_964 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + mm_204 = torch.ops.aten.mm.default(view_989, permute_320); permute_320 = None + view_993 = torch.ops.aten.view.default(mm_204, [2, 8192, 1024]); mm_204 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 256, '0'); convert_element_type_967 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_205 = torch.ops.aten.mm.default(view_989, permute_321); view_989 = permute_321 = None + view_996 = torch.ops.aten.view.default(mm_205, [2, 8192, 1024]) + view_997 = torch.ops.aten.view.default(view_990, [2, 8192, -1, 128]); view_990 = None + view_998 = torch.ops.aten.view.default(view_993, [2, 8192, -1, 128]); view_993 = None + view_999 = torch.ops.aten.view.default(view_996, [2, 8192, -1, 128]); view_996 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + view_1000 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 32, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1000); view_1000 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_998, torch.float32); view_998 = None + view_1001 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 8, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1001); view_1001 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_16); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_1003 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 32, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_16); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_1004 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 8, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_1003, torch.bfloat16); view_1003 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_1004, torch.bfloat16); view_1004 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 8, 4, 128]); unsqueeze_58 = None + clone_58 = torch.ops.aten.clone.default(expand_58, memory_format = torch.contiguous_format); expand_58 = None + view_1005 = torch.ops.aten.view.default(clone_58, [2, 8192, 32, 128]); clone_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_999, 3); view_999 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 8, 4, 128]); unsqueeze_59 = None + clone_59 = torch.ops.aten.clone.default(expand_59, memory_format = torch.contiguous_format); expand_59 = None + view_1006 = torch.ops.aten.view.default(clone_59, [2, 8192, 32, 128]); clone_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_1005, [0, 2, 1, 3]); view_1005 = None + permute_324 = torch.ops.aten.permute.default(view_1006, [0, 2, 1, 3]); view_1006 = None + _scaled_dot_product_cudnn_attention_29 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_322, permute_323, permute_324, None, True, 0.0, True); permute_322 = permute_323 = permute_324 = None + getitem_261 = _scaled_dot_product_cudnn_attention_29[0] + getitem_262 = _scaled_dot_product_cudnn_attention_29[1] + getitem_267 = _scaled_dot_product_cudnn_attention_29[6] + getitem_268 = _scaled_dot_product_cudnn_attention_29[7]; _scaled_dot_product_cudnn_attention_29 = None + permute_325 = torch.ops.aten.permute.default(getitem_261, [0, 2, 1, 3]) + view_1007 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 256, '0'); convert_element_type_974 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + view_1009 = torch.ops.aten.view.default(view_1007, [16384, 4096]); view_1007 = None + mm_206 = torch.ops.aten.mm.default(view_1009, permute_326); view_1009 = permute_326 = None + view_1010 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + add_117 = torch.ops.aten.add.Tensor(add_115, view_1010); view_1010 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 256, '0'); convert_element_type_977 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = rsqrt_59 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_267); mul_236 = wait_tensor_267 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 256, '0'); convert_element_type_980 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + view_1013 = torch.ops.aten.view.default(convert_element_type_979, [16384, 4096]); convert_element_type_979 = None + mm_207 = torch.ops.aten.mm.default(view_1013, permute_327); permute_327 = None + view_1014 = torch.ops.aten.view.default(mm_207, [2, 8192, 14336]) + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_1014, torch.float32); view_1014 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); convert_element_type_983 = sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16) + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 256, '0'); convert_element_type_985 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_269, [1, 0]); wait_tensor_269 = None + mm_208 = torch.ops.aten.mm.default(view_1013, permute_328); view_1013 = permute_328 = None + view_1017 = torch.ops.aten.view.default(mm_208, [2, 8192, 14336]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_1017); convert_element_type_984 = view_1017 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 256, '0'); convert_element_type_988 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + view_1019 = torch.ops.aten.view.default(mul_239, [16384, 14336]); mul_239 = None + mm_209 = torch.ops.aten.mm.default(view_1019, permute_329); view_1019 = permute_329 = None + view_1020 = torch.ops.aten.view.default(mm_209, [2, 8192, 4096]); mm_209 = None + add_119 = torch.ops.aten.add.Tensor(add_117, view_1020); add_117 = view_1020 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 256, '0'); convert_element_type_991 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32) + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = rsqrt_60 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_271); mul_240 = wait_tensor_271 = None + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 256, '0'); convert_element_type_994 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + view_1023 = torch.ops.aten.view.default(convert_element_type_993, [16384, 4096]); convert_element_type_993 = None + mm_210 = torch.ops.aten.mm.default(view_1023, permute_330); permute_330 = None + view_1024 = torch.ops.aten.view.default(mm_210, [2, 8192, 4096]) + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 256, '0'); convert_element_type_997 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + mm_211 = torch.ops.aten.mm.default(view_1023, permute_331); permute_331 = None + view_1027 = torch.ops.aten.view.default(mm_211, [2, 8192, 1024]); mm_211 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 256, '0'); convert_element_type_1000 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_274, [1, 0]); wait_tensor_274 = None + mm_212 = torch.ops.aten.mm.default(view_1023, permute_332); view_1023 = permute_332 = None + view_1030 = torch.ops.aten.view.default(mm_212, [2, 8192, 1024]) + view_1031 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1032 = torch.ops.aten.view.default(view_1027, [2, 8192, -1, 128]); view_1027 = None + view_1033 = torch.ops.aten.view.default(view_1030, [2, 8192, -1, 128]); view_1030 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_1031, torch.float32); view_1031 = None + view_1034 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 32, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1034); view_1034 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_1032, torch.float32); view_1032 = None + view_1035 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 8, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1035); view_1035 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_16); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_1037 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 32, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_16); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_1038 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 8, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_1037, torch.bfloat16); view_1037 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_1038, torch.bfloat16); view_1038 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 8, 4, 128]); unsqueeze_60 = None + clone_60 = torch.ops.aten.clone.default(expand_60, memory_format = torch.contiguous_format); expand_60 = None + view_1039 = torch.ops.aten.view.default(clone_60, [2, 8192, 32, 128]); clone_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1033, 3); view_1033 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 8, 4, 128]); unsqueeze_61 = None + clone_61 = torch.ops.aten.clone.default(expand_61, memory_format = torch.contiguous_format); expand_61 = None + view_1040 = torch.ops.aten.view.default(clone_61, [2, 8192, 32, 128]); clone_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_1039, [0, 2, 1, 3]); view_1039 = None + permute_335 = torch.ops.aten.permute.default(view_1040, [0, 2, 1, 3]); view_1040 = None + _scaled_dot_product_cudnn_attention_30 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_333, permute_334, permute_335, None, True, 0.0, True); permute_333 = permute_334 = permute_335 = None + getitem_270 = _scaled_dot_product_cudnn_attention_30[0] + getitem_271 = _scaled_dot_product_cudnn_attention_30[1] + getitem_276 = _scaled_dot_product_cudnn_attention_30[6] + getitem_277 = _scaled_dot_product_cudnn_attention_30[7]; _scaled_dot_product_cudnn_attention_30 = None + permute_336 = torch.ops.aten.permute.default(getitem_270, [0, 2, 1, 3]) + view_1041 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 256, '0'); convert_element_type_1007 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_275, [1, 0]); wait_tensor_275 = None + view_1043 = torch.ops.aten.view.default(view_1041, [16384, 4096]); view_1041 = None + mm_213 = torch.ops.aten.mm.default(view_1043, permute_337); view_1043 = permute_337 = None + view_1044 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + add_121 = torch.ops.aten.add.Tensor(add_119, view_1044); view_1044 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 256, '0'); convert_element_type_1010 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = rsqrt_61 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_276); mul_244 = wait_tensor_276 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 256, '0'); convert_element_type_1013 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + view_1047 = torch.ops.aten.view.default(convert_element_type_1012, [16384, 4096]); convert_element_type_1012 = None + mm_214 = torch.ops.aten.mm.default(view_1047, permute_338); permute_338 = None + view_1048 = torch.ops.aten.view.default(mm_214, [2, 8192, 14336]) + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1048, torch.float32); view_1048 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); convert_element_type_1016 = sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 256, '0'); convert_element_type_1018 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_215 = torch.ops.aten.mm.default(view_1047, permute_339); view_1047 = permute_339 = None + view_1051 = torch.ops.aten.view.default(mm_215, [2, 8192, 14336]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_1051); convert_element_type_1017 = view_1051 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 256, '0'); convert_element_type_1021 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + view_1053 = torch.ops.aten.view.default(mul_247, [16384, 14336]); mul_247 = None + mm_216 = torch.ops.aten.mm.default(view_1053, permute_340); view_1053 = permute_340 = None + view_1054 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + add_123 = torch.ops.aten.add.Tensor(add_121, view_1054); add_121 = view_1054 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 256, '0'); convert_element_type_1024 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = rsqrt_62 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_280); mul_248 = wait_tensor_280 = None + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 256, '0'); convert_element_type_1027 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + view_1057 = torch.ops.aten.view.default(convert_element_type_1026, [16384, 4096]); convert_element_type_1026 = None + mm_217 = torch.ops.aten.mm.default(view_1057, permute_341); permute_341 = None + view_1058 = torch.ops.aten.view.default(mm_217, [2, 8192, 4096]) + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 256, '0'); convert_element_type_1030 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + mm_218 = torch.ops.aten.mm.default(view_1057, permute_342); permute_342 = None + view_1061 = torch.ops.aten.view.default(mm_218, [2, 8192, 1024]); mm_218 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16) + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 256, '0'); convert_element_type_1033 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + mm_219 = torch.ops.aten.mm.default(view_1057, permute_343); view_1057 = permute_343 = None + view_1064 = torch.ops.aten.view.default(mm_219, [2, 8192, 1024]) + view_1065 = torch.ops.aten.view.default(view_1058, [2, 8192, -1, 128]); view_1058 = None + view_1066 = torch.ops.aten.view.default(view_1061, [2, 8192, -1, 128]); view_1061 = None + view_1067 = torch.ops.aten.view.default(view_1064, [2, 8192, -1, 128]); view_1064 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_1065, torch.float32); view_1065 = None + view_1068 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 32, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1068); view_1068 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_1066, torch.float32); view_1066 = None + view_1069 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 8, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1069); view_1069 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_16); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_1071 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 32, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_16); view_as_complex_63 = view_16 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_1072 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 8, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_1071, torch.bfloat16); view_1071 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_1072, torch.bfloat16); view_1072 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 8, 4, 128]); unsqueeze_62 = None + clone_62 = torch.ops.aten.clone.default(expand_62, memory_format = torch.contiguous_format); expand_62 = None + view_1073 = torch.ops.aten.view.default(clone_62, [2, 8192, 32, 128]); clone_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1067, 3); view_1067 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 8, 4, 128]); unsqueeze_63 = None + clone_63 = torch.ops.aten.clone.default(expand_63, memory_format = torch.contiguous_format); expand_63 = None + view_1074 = torch.ops.aten.view.default(clone_63, [2, 8192, 32, 128]); clone_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_1073, [0, 2, 1, 3]); view_1073 = None + permute_346 = torch.ops.aten.permute.default(view_1074, [0, 2, 1, 3]); view_1074 = None + _scaled_dot_product_cudnn_attention_31 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_344, permute_345, permute_346, None, True, 0.0, True); permute_344 = permute_345 = permute_346 = None + getitem_279 = _scaled_dot_product_cudnn_attention_31[0] + getitem_280 = _scaled_dot_product_cudnn_attention_31[1] + getitem_285 = _scaled_dot_product_cudnn_attention_31[6] + getitem_286 = _scaled_dot_product_cudnn_attention_31[7]; _scaled_dot_product_cudnn_attention_31 = None + permute_347 = torch.ops.aten.permute.default(getitem_279, [0, 2, 1, 3]) + view_1075 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 256, '0'); convert_element_type_1040 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1077 = torch.ops.aten.view.default(view_1075, [16384, 4096]); view_1075 = None + mm_220 = torch.ops.aten.mm.default(view_1077, permute_348); view_1077 = permute_348 = None + view_1078 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + add_125 = torch.ops.aten.add.Tensor(add_123, view_1078); view_1078 = None + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 256, '0'); convert_element_type_1043 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = rsqrt_63 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_285); mul_252 = wait_tensor_285 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 256, '0'); convert_element_type_1046 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + view_1081 = torch.ops.aten.view.default(convert_element_type_1045, [16384, 4096]); convert_element_type_1045 = None + mm_221 = torch.ops.aten.mm.default(view_1081, permute_349); permute_349 = None + view_1082 = torch.ops.aten.view.default(mm_221, [2, 8192, 14336]) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_1082, torch.float32); view_1082 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); convert_element_type_1049 = sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 256, '0'); convert_element_type_1051 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_287, [1, 0]); wait_tensor_287 = None + mm_222 = torch.ops.aten.mm.default(view_1081, permute_350); view_1081 = permute_350 = None + view_1085 = torch.ops.aten.view.default(mm_222, [2, 8192, 14336]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_1085); convert_element_type_1050 = view_1085 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 256, '0'); convert_element_type_1054 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + view_1087 = torch.ops.aten.view.default(mul_255, [16384, 14336]); mul_255 = None + mm_223 = torch.ops.aten.mm.default(view_1087, permute_351); view_1087 = permute_351 = None + view_1088 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]) + add_127 = torch.ops.aten.add.Tensor(add_125, view_1088); add_125 = view_1088 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 256, '0'); convert_element_type_1057 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1058, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_128 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_128); add_128 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_257 = torch.ops.aten.mul.Tensor(mul_256, wait_tensor_289); mul_256 = wait_tensor_289 = None + convert_element_type_1059 = torch.ops.prims.convert_element_type.default(mul_257, torch.bfloat16); mul_257 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 256, '0'); convert_element_type_1060 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + view_1091 = torch.ops.aten.view.default(convert_element_type_1059, [16384, 4096]); convert_element_type_1059 = None + mm_224 = torch.ops.aten.mm.default(view_1091, permute_352); permute_352 = None + view_1092 = torch.ops.aten.view.default(mm_224, [2, 8192, 128256]); mm_224 = None + return (view_1092, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, embedding, mm, mm_2, getitem, getitem_1, getitem_6, getitem_7, mm_4, add_3, mm_7, mm_9, getitem_9, getitem_10, getitem_15, getitem_16, mm_11, add_7, mm_14, mm_16, getitem_18, getitem_19, getitem_24, getitem_25, mm_18, add_11, mm_21, mm_23, getitem_27, getitem_28, getitem_33, getitem_34, mm_25, add_15, mm_28, mm_30, getitem_36, getitem_37, getitem_42, getitem_43, mm_32, add_19, mm_35, mm_37, getitem_45, getitem_46, getitem_51, getitem_52, mm_39, add_23, mm_42, mm_44, getitem_54, getitem_55, getitem_60, getitem_61, mm_46, add_27, mm_49, mm_51, getitem_63, getitem_64, getitem_69, getitem_70, mm_53, add_31, mm_56, mm_58, getitem_72, getitem_73, getitem_78, getitem_79, mm_60, add_35, mm_63, mm_65, getitem_81, getitem_82, getitem_87, getitem_88, mm_67, add_39, mm_70, mm_72, getitem_90, getitem_91, getitem_96, getitem_97, mm_74, add_43, mm_77, mm_79, getitem_99, getitem_100, getitem_105, getitem_106, mm_81, add_47, mm_84, mm_86, getitem_108, getitem_109, getitem_114, getitem_115, mm_88, add_51, mm_91, mm_93, getitem_117, getitem_118, getitem_123, getitem_124, mm_95, add_55, mm_98, mm_100, getitem_126, getitem_127, getitem_132, getitem_133, mm_102, add_59, mm_105, mm_107, getitem_135, getitem_136, getitem_141, getitem_142, mm_109, add_63, mm_112, mm_114, getitem_144, getitem_145, getitem_150, getitem_151, mm_116, add_67, mm_119, mm_121, getitem_153, getitem_154, getitem_159, getitem_160, mm_123, add_71, mm_126, mm_128, getitem_162, getitem_163, getitem_168, getitem_169, mm_130, add_75, mm_133, mm_135, getitem_171, getitem_172, getitem_177, getitem_178, mm_137, add_79, mm_140, mm_142, getitem_180, getitem_181, getitem_186, getitem_187, mm_144, add_83, mm_147, mm_149, getitem_189, getitem_190, getitem_195, getitem_196, mm_151, add_87, mm_154, mm_156, getitem_198, getitem_199, getitem_204, getitem_205, mm_158, add_91, mm_161, mm_163, getitem_207, getitem_208, getitem_213, getitem_214, mm_165, add_95, mm_168, mm_170, getitem_216, getitem_217, getitem_222, getitem_223, mm_172, add_99, mm_175, mm_177, getitem_225, getitem_226, getitem_231, getitem_232, mm_179, add_103, mm_182, mm_184, getitem_234, getitem_235, getitem_240, getitem_241, mm_186, add_107, mm_189, mm_191, getitem_243, getitem_244, getitem_249, getitem_250, mm_193, add_111, mm_196, mm_198, getitem_252, getitem_253, getitem_258, getitem_259, mm_200, add_115, mm_203, mm_205, getitem_261, getitem_262, getitem_267, getitem_268, mm_207, add_119, mm_210, mm_212, getitem_270, getitem_271, getitem_276, getitem_277, mm_214, add_123, mm_217, mm_219, getitem_279, getitem_280, getitem_285, getitem_286, mm_221, mm_223, rsqrt_64, view_1091) + +def load_args(reader): + buf0 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf0, (501, 4096), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf3, (16,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf4, (16, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf5, (4, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf7, (16, 4096), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf8, (16,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf9, (56, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf10, (56, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf11, (16, 14336), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf12, (16,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf13, (16, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf14, (4, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf15, (4, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf16, (16, 4096), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf17, (16,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf18, (56, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf19, (56, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf20, (16, 14336), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf21, (16,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf23, (4, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf24, (4, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf25, (16, 4096), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf26, (16,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf27, (56, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf28, (56, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf29, (16, 14336), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf30, (16,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf31, (16, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf32, (4, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf33, (4, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf34, (16, 4096), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf35, (16,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf36, (56, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf37, (56, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf38, (16, 14336), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf39, (16,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf40, (16, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf41, (4, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf42, (4, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf43, (16, 4096), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf44, (16,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf45, (56, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf46, (56, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf47, (16, 14336), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf48, (16,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf49, (16, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf50, (4, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf51, (4, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf52, (16, 4096), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf53, (16,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf54, (56, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf55, (56, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf56, (16, 14336), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf57, (16,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf58, (16, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf59, (4, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf60, (4, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf61, (16, 4096), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf62, (16,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf63, (56, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf64, (56, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf65, (16, 14336), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf66, (16,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf67, (16, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf68, (4, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf69, (4, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf70, (16, 4096), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf71, (16,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf72, (56, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf73, (56, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf74, (16, 14336), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf75, (16,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf76, (16, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf77, (4, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf78, (4, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf79, (16, 4096), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf80, (16,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf81, (56, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf82, (56, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf83, (16, 14336), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf84, (16,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf85, (16, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf86, (4, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf87, (4, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf88, (16, 4096), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf89, (16,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf90, (56, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf91, (56, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf92, (16, 14336), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf93, (16,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf94, (16, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf95, (4, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf96, (4, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf97, (16, 4096), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf98, (16,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf99, (56, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf100, (56, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf101, (16, 14336), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf102, (16,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf103, (16, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf104, (4, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf105, (4, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf106, (16, 4096), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf107, (16,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf108, (56, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf109, (56, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf110, (16, 14336), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf111, (16,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf112, (16, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf113, (4, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf114, (4, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf115, (16, 4096), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf116, (16,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf117, (56, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf118, (56, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf119, (16, 14336), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf120, (16,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf121, (16, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf122, (4, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf123, (4, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf124, (16, 4096), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf125, (16,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf126, (56, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf127, (56, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf128, (16, 14336), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf129, (16,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf130, (16, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf131, (4, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf132, (4, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf133, (16, 4096), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf134, (16,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf135, (56, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf136, (56, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf137, (16, 14336), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf138, (16,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf140, (4, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf141, (4, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf142, (16, 4096), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf143, (16,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf144, (56, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf145, (56, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf146, (16, 14336), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf147, (16,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf148, (16, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf149, (4, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf150, (4, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf151, (16, 4096), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf152, (16,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf153, (56, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf154, (56, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf155, (16, 14336), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf156, (16,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf157, (16, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf158, (4, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf159, (4, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf160, (16, 4096), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf161, (16,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf162, (56, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf163, (56, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf164, (16, 14336), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf165, (16,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf166, (16, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf167, (4, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf168, (4, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf169, (16, 4096), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf170, (16,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf171, (56, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf172, (56, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf173, (16, 14336), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf174, (16,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf175, (16, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf176, (4, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf177, (4, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf178, (16, 4096), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf179, (16,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf180, (56, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf181, (56, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf182, (16, 14336), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf183, (16,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf184, (16, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf185, (4, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf186, (4, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf187, (16, 4096), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf188, (16,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf189, (56, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf190, (56, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf191, (16, 14336), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf192, (16,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf193, (16, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf194, (4, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf195, (4, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf196, (16, 4096), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf197, (16,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf198, (56, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf199, (56, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf200, (16, 14336), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf201, (16,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf202, (16, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf203, (4, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf204, (4, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf205, (16, 4096), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf206, (16,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf207, (56, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf208, (56, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf209, (16, 14336), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf210, (16,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf211, (16, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf212, (4, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf213, (4, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf214, (16, 4096), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf215, (16,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf216, (56, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf217, (56, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf218, (16, 14336), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf219, (16,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf220, (16, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf221, (4, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf222, (4, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf223, (16, 4096), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf224, (16,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf225, (56, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf226, (56, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf227, (16, 14336), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf228, (16,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf229, (16, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf230, (4, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf231, (4, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf232, (16, 4096), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf233, (16,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf234, (56, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf235, (56, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf236, (16, 14336), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf237, (16,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf238, (16, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf239, (4, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf240, (4, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf241, (16, 4096), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf242, (16,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf243, (56, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf244, (56, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf245, (16, 14336), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf246, (16,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf247, (16, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf248, (4, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf250, (16, 4096), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf251, (16,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf252, (56, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf253, (56, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf254, (16, 14336), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf255, (16,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf256, (16, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf257, (4, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf258, (4, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf259, (16, 4096), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf260, (16,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf261, (56, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf262, (56, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf263, (16, 14336), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf264, (16,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf265, (16, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf266, (4, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf267, (4, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf268, (16, 4096), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf269, (16,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf270, (56, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf271, (56, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf272, (16, 14336), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf273, (16,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf274, (16, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf275, (4, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf276, (4, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf277, (16, 4096), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf278, (16,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf279, (56, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf280, (56, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf281, (16, 14336), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf282, (16,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf283, (16, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf284, (4, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf285, (4, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf286, (16, 4096), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf287, (16,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf288, (56, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf289, (56, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf290, (16, 14336), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 64, device=device(type='cuda', index=0)) + reader.tensor(buf291, (16,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf292, (501, 4096), is_leaf=True) # primals_293 + +load_args._version = 0 + +def get_pg_config(): + return {'0': {'size': 256, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls32_8.table" diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_2d.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_2d.py new file mode 100644 index 00000000..d406d50f --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_256_2d.py @@ -0,0 +1,5658 @@ +# fmt: off +# flake8: noqa +# isort: skip_file +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_2, torch.bfloat16) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 32, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + lt = torch.ops.aten.lt.Scalar(primals_1, 0) + ge = torch.ops.aten.ge.Scalar(primals_1, 16032) + bitwise_or = torch.ops.aten.bitwise_or.Tensor(lt, ge); lt = ge = None + sub = torch.ops.aten.sub.Tensor(primals_1, 0) + full_default = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + index_put = torch.ops.aten.index_put.default(sub, [bitwise_or], full_default); sub = full_default = None + embedding = torch.ops.aten.embedding.default(wait_tensor, index_put); wait_tensor = index_put = None + full_default_1 = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + index_put_1 = torch.ops.aten.index_put.default(embedding, [bitwise_or], full_default_1); embedding = bitwise_or = full_default_1 = None + split_1 = torch.ops.aten.split.Tensor(index_put_1, 1024, 1); index_put_1 = None + getitem_8 = split_1[0] + getitem_17 = split_1[1] + getitem_26 = split_1[2] + getitem_35 = split_1[3] + getitem_44 = split_1[4] + getitem_53 = split_1[5] + getitem_62 = split_1[6] + getitem_71 = split_1[7]; split_1 = None + cat = torch.ops.aten.cat.default([getitem_8, getitem_17, getitem_26, getitem_35, getitem_44, getitem_53, getitem_62, getitem_71]); getitem_8 = getitem_17 = getitem_26 = getitem_35 = getitem_44 = getitem_53 = getitem_62 = getitem_71 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat, 'sum', 8, '1'); cat = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 32, '0'); convert_element_type_1 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = rsqrt = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_2); mul = wait_tensor_2 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_3, 8, '1'); convert_element_type_3 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + split_9 = torch.ops.aten.split.Tensor(wait_tensor_3, 2); wait_tensor_3 = None + getitem_72 = split_9[0] + getitem_73 = split_9[1] + getitem_74 = split_9[2] + getitem_75 = split_9[3] + getitem_76 = split_9[4] + getitem_77 = split_9[5] + getitem_78 = split_9[6] + getitem_79 = split_9[7]; split_9 = None + cat_1 = torch.ops.aten.cat.default([getitem_72, getitem_73, getitem_74, getitem_75, getitem_76, getitem_77, getitem_78, getitem_79], 1); getitem_72 = getitem_73 = getitem_74 = getitem_75 = getitem_76 = getitem_77 = getitem_78 = getitem_79 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 32, '0'); convert_element_type_4 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + view_15 = torch.ops.aten.view.default(cat_1, [16384, 4096]); cat_1 = None + mm = torch.ops.aten.mm.default(view_15, permute); permute = None + view_16 = torch.ops.aten.view.default(mm, [2, 8192, 512]) + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 32, '0'); convert_element_type_7 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + mm_1 = torch.ops.aten.mm.default(view_15, permute_1); permute_1 = None + view_23 = torch.ops.aten.view.default(mm_1, [2, 8192, 128]); mm_1 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 32, '0'); convert_element_type_10 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + mm_2 = torch.ops.aten.mm.default(view_15, permute_2); view_15 = permute_2 = None + view_30 = torch.ops.aten.view.default(mm_2, [2, 8192, 128]) + view_32 = torch.ops.aten.view.default(view_16, [2, 8192, -1, 128]); view_16 = None + view_33 = torch.ops.aten.view.default(view_23, [2, 8192, -1, 128]); view_23 = None + view_34 = torch.ops.aten.view.default(view_30, [2, 8192, -1, 128]); view_30 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_32, torch.float32); view_32 = None + view_35 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 4, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_35); view_35 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_33, torch.float32); view_33 = None + view_36 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 1, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_36); view_36 = None + view_37 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_37); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_38 = torch.ops.aten.view.default(view_as_real, [2, 8192, 4, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_37); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_39 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 1, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_38, torch.bfloat16); view_38 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_39, torch.bfloat16); view_39 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 1, 4, 128]); unsqueeze = None + view_40 = torch.ops.aten.view.default(expand, [2, 8192, 4, 128]); expand = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_34, 3); view_34 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 1, 4, 128]); unsqueeze_1 = None + view_41 = torch.ops.aten.view.default(expand_1, [2, 8192, 4, 128]); expand_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_40, [0, 2, 1, 3]); view_40 = None + permute_5 = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None + _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_3, permute_4, permute_5, None, True, 0.0, True); permute_3 = permute_4 = permute_5 = None + getitem_80 = _scaled_dot_product_cudnn_attention[0] + getitem_81 = _scaled_dot_product_cudnn_attention[1] + getitem_86 = _scaled_dot_product_cudnn_attention[6] + getitem_87 = _scaled_dot_product_cudnn_attention[7]; _scaled_dot_product_cudnn_attention = None + permute_6 = torch.ops.aten.permute.default(getitem_80, [0, 2, 1, 3]) + view_42 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 32, '0'); convert_element_type_17 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + view_48 = torch.ops.aten.view.default(view_42, [16384, 512]); view_42 = None + mm_3 = torch.ops.aten.mm.default(view_48, permute_7); view_48 = permute_7 = None + view_49 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + split_10 = torch.ops.aten.split.Tensor(view_49, 1024, 1); view_49 = None + getitem_89 = split_10[0] + getitem_90 = split_10[1] + getitem_91 = split_10[2] + getitem_92 = split_10[3] + getitem_93 = split_10[4] + getitem_94 = split_10[5] + getitem_95 = split_10[6] + getitem_96 = split_10[7]; split_10 = None + cat_2 = torch.ops.aten.cat.default([getitem_89, getitem_90, getitem_91, getitem_92, getitem_93, getitem_94, getitem_95, getitem_96]); getitem_89 = getitem_90 = getitem_91 = getitem_92 = getitem_93 = getitem_94 = getitem_95 = getitem_96 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_2, 'sum', 8, '1'); cat_2 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1) + add_1 = torch.ops.aten.add.Tensor(wait_tensor_1, wait_tensor_8); wait_tensor_8 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16) + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 32, '0'); convert_element_type_20 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = rsqrt_1 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_9); mul_4 = wait_tensor_9 = None + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_22, 8, '1'); convert_element_type_22 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + split_11 = torch.ops.aten.split.Tensor(wait_tensor_10, 2); wait_tensor_10 = None + getitem_97 = split_11[0] + getitem_98 = split_11[1] + getitem_99 = split_11[2] + getitem_100 = split_11[3] + getitem_101 = split_11[4] + getitem_102 = split_11[5] + getitem_103 = split_11[6] + getitem_104 = split_11[7]; split_11 = None + cat_3 = torch.ops.aten.cat.default([getitem_97, getitem_98, getitem_99, getitem_100, getitem_101, getitem_102, getitem_103, getitem_104], 1); getitem_97 = getitem_98 = getitem_99 = getitem_100 = getitem_101 = getitem_102 = getitem_103 = getitem_104 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 32, '0'); convert_element_type_23 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + view_60 = torch.ops.aten.view.default(cat_3, [16384, 4096]); cat_3 = None + mm_4 = torch.ops.aten.mm.default(view_60, permute_8); permute_8 = None + view_61 = torch.ops.aten.view.default(mm_4, [2, 8192, 1792]) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_61, torch.float32); view_61 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); convert_element_type_26 = sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16) + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 32, '0'); convert_element_type_28 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_5 = torch.ops.aten.mm.default(view_60, permute_9); view_60 = permute_9 = None + view_68 = torch.ops.aten.view.default(mm_5, [2, 8192, 1792]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_68); convert_element_type_27 = view_68 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 32, '0'); convert_element_type_31 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + view_75 = torch.ops.aten.view.default(mul_7, [16384, 1792]); mul_7 = None + mm_6 = torch.ops.aten.mm.default(view_75, permute_10); view_75 = permute_10 = None + view_76 = torch.ops.aten.view.default(mm_6, [2, 8192, 4096]); mm_6 = None + split_12 = torch.ops.aten.split.Tensor(view_76, 1024, 1); view_76 = None + getitem_105 = split_12[0] + getitem_106 = split_12[1] + getitem_107 = split_12[2] + getitem_108 = split_12[3] + getitem_109 = split_12[4] + getitem_110 = split_12[5] + getitem_111 = split_12[6] + getitem_112 = split_12[7]; split_12 = None + cat_4 = torch.ops.aten.cat.default([getitem_105, getitem_106, getitem_107, getitem_108, getitem_109, getitem_110, getitem_111, getitem_112]); getitem_105 = getitem_106 = getitem_107 = getitem_108 = getitem_109 = getitem_110 = getitem_111 = getitem_112 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_4, 'sum', 8, '1'); cat_4 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None + add_3 = torch.ops.aten.add.Tensor(add_1, wait_tensor_14); add_1 = wait_tensor_14 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 32, '0'); convert_element_type_34 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = rsqrt_2 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_15); mul_8 = wait_tensor_15 = None + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_36, 8, '1'); convert_element_type_36 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + split_13 = torch.ops.aten.split.Tensor(wait_tensor_16, 2); wait_tensor_16 = None + getitem_113 = split_13[0] + getitem_114 = split_13[1] + getitem_115 = split_13[2] + getitem_116 = split_13[3] + getitem_117 = split_13[4] + getitem_118 = split_13[5] + getitem_119 = split_13[6] + getitem_120 = split_13[7]; split_13 = None + cat_5 = torch.ops.aten.cat.default([getitem_113, getitem_114, getitem_115, getitem_116, getitem_117, getitem_118, getitem_119, getitem_120], 1); getitem_113 = getitem_114 = getitem_115 = getitem_116 = getitem_117 = getitem_118 = getitem_119 = getitem_120 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 32, '0'); convert_element_type_37 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + view_87 = torch.ops.aten.view.default(cat_5, [16384, 4096]); cat_5 = None + mm_7 = torch.ops.aten.mm.default(view_87, permute_11); permute_11 = None + view_88 = torch.ops.aten.view.default(mm_7, [2, 8192, 512]) + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 32, '0'); convert_element_type_40 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + mm_8 = torch.ops.aten.mm.default(view_87, permute_12); permute_12 = None + view_95 = torch.ops.aten.view.default(mm_8, [2, 8192, 128]); mm_8 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 32, '0'); convert_element_type_43 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_19, [1, 0]); wait_tensor_19 = None + mm_9 = torch.ops.aten.mm.default(view_87, permute_13); view_87 = permute_13 = None + view_102 = torch.ops.aten.view.default(mm_9, [2, 8192, 128]) + view_104 = torch.ops.aten.view.default(view_88, [2, 8192, -1, 128]); view_88 = None + view_105 = torch.ops.aten.view.default(view_95, [2, 8192, -1, 128]); view_95 = None + view_106 = torch.ops.aten.view.default(view_102, [2, 8192, -1, 128]); view_102 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_104, torch.float32); view_104 = None + view_107 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 4, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_107); view_107 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_105, torch.float32); view_105 = None + view_108 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 1, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_108); view_108 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_37); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_110 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 4, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_37); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_111 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 1, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_110, torch.bfloat16); view_110 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_111, torch.bfloat16); view_111 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 1, 4, 128]); unsqueeze_2 = None + view_112 = torch.ops.aten.view.default(expand_2, [2, 8192, 4, 128]); expand_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_106, 3); view_106 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 1, 4, 128]); unsqueeze_3 = None + view_113 = torch.ops.aten.view.default(expand_3, [2, 8192, 4, 128]); expand_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None + permute_16 = torch.ops.aten.permute.default(view_113, [0, 2, 1, 3]); view_113 = None + _scaled_dot_product_cudnn_attention_1 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_14, permute_15, permute_16, None, True, 0.0, True); permute_14 = permute_15 = permute_16 = None + getitem_121 = _scaled_dot_product_cudnn_attention_1[0] + getitem_122 = _scaled_dot_product_cudnn_attention_1[1] + getitem_127 = _scaled_dot_product_cudnn_attention_1[6] + getitem_128 = _scaled_dot_product_cudnn_attention_1[7]; _scaled_dot_product_cudnn_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_121, [0, 2, 1, 3]) + view_114 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 32, '0'); convert_element_type_50 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + view_120 = torch.ops.aten.view.default(view_114, [16384, 512]); view_114 = None + mm_10 = torch.ops.aten.mm.default(view_120, permute_18); view_120 = permute_18 = None + view_121 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + split_14 = torch.ops.aten.split.Tensor(view_121, 1024, 1); view_121 = None + getitem_130 = split_14[0] + getitem_131 = split_14[1] + getitem_132 = split_14[2] + getitem_133 = split_14[3] + getitem_134 = split_14[4] + getitem_135 = split_14[5] + getitem_136 = split_14[6] + getitem_137 = split_14[7]; split_14 = None + cat_6 = torch.ops.aten.cat.default([getitem_130, getitem_131, getitem_132, getitem_133, getitem_134, getitem_135, getitem_136, getitem_137]); getitem_130 = getitem_131 = getitem_132 = getitem_133 = getitem_134 = getitem_135 = getitem_136 = getitem_137 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_6, 'sum', 8, '1'); cat_6 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3) + add_5 = torch.ops.aten.add.Tensor(add_3, wait_tensor_21); wait_tensor_21 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 32, '0'); convert_element_type_53 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = rsqrt_3 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_22); mul_12 = wait_tensor_22 = None + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_55, 8, '1'); convert_element_type_55 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + split_15 = torch.ops.aten.split.Tensor(wait_tensor_23, 2); wait_tensor_23 = None + getitem_138 = split_15[0] + getitem_139 = split_15[1] + getitem_140 = split_15[2] + getitem_141 = split_15[3] + getitem_142 = split_15[4] + getitem_143 = split_15[5] + getitem_144 = split_15[6] + getitem_145 = split_15[7]; split_15 = None + cat_7 = torch.ops.aten.cat.default([getitem_138, getitem_139, getitem_140, getitem_141, getitem_142, getitem_143, getitem_144, getitem_145], 1); getitem_138 = getitem_139 = getitem_140 = getitem_141 = getitem_142 = getitem_143 = getitem_144 = getitem_145 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 32, '0'); convert_element_type_56 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_24, [1, 0]); wait_tensor_24 = None + view_132 = torch.ops.aten.view.default(cat_7, [16384, 4096]); cat_7 = None + mm_11 = torch.ops.aten.mm.default(view_132, permute_19); permute_19 = None + view_133 = torch.ops.aten.view.default(mm_11, [2, 8192, 1792]) + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_133, torch.float32); view_133 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); convert_element_type_59 = sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 32, '0'); convert_element_type_61 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + mm_12 = torch.ops.aten.mm.default(view_132, permute_20); view_132 = permute_20 = None + view_140 = torch.ops.aten.view.default(mm_12, [2, 8192, 1792]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_140); convert_element_type_60 = view_140 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 32, '0'); convert_element_type_64 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + view_147 = torch.ops.aten.view.default(mul_15, [16384, 1792]); mul_15 = None + mm_13 = torch.ops.aten.mm.default(view_147, permute_21); view_147 = permute_21 = None + view_148 = torch.ops.aten.view.default(mm_13, [2, 8192, 4096]); mm_13 = None + split_16 = torch.ops.aten.split.Tensor(view_148, 1024, 1); view_148 = None + getitem_146 = split_16[0] + getitem_147 = split_16[1] + getitem_148 = split_16[2] + getitem_149 = split_16[3] + getitem_150 = split_16[4] + getitem_151 = split_16[5] + getitem_152 = split_16[6] + getitem_153 = split_16[7]; split_16 = None + cat_8 = torch.ops.aten.cat.default([getitem_146, getitem_147, getitem_148, getitem_149, getitem_150, getitem_151, getitem_152, getitem_153]); getitem_146 = getitem_147 = getitem_148 = getitem_149 = getitem_150 = getitem_151 = getitem_152 = getitem_153 = None + reduce_scatter_tensor_4 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_8, 'sum', 8, '1'); cat_8 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + add_7 = torch.ops.aten.add.Tensor(add_5, wait_tensor_27); add_5 = wait_tensor_27 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 32, '0'); convert_element_type_67 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = rsqrt_4 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_28); mul_16 = wait_tensor_28 = None + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_69, 8, '1'); convert_element_type_69 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + split_17 = torch.ops.aten.split.Tensor(wait_tensor_29, 2); wait_tensor_29 = None + getitem_154 = split_17[0] + getitem_155 = split_17[1] + getitem_156 = split_17[2] + getitem_157 = split_17[3] + getitem_158 = split_17[4] + getitem_159 = split_17[5] + getitem_160 = split_17[6] + getitem_161 = split_17[7]; split_17 = None + cat_9 = torch.ops.aten.cat.default([getitem_154, getitem_155, getitem_156, getitem_157, getitem_158, getitem_159, getitem_160, getitem_161], 1); getitem_154 = getitem_155 = getitem_156 = getitem_157 = getitem_158 = getitem_159 = getitem_160 = getitem_161 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 32, '0'); convert_element_type_70 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + view_159 = torch.ops.aten.view.default(cat_9, [16384, 4096]); cat_9 = None + mm_14 = torch.ops.aten.mm.default(view_159, permute_22); permute_22 = None + view_160 = torch.ops.aten.view.default(mm_14, [2, 8192, 512]) + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 32, '0'); convert_element_type_73 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_15 = torch.ops.aten.mm.default(view_159, permute_23); permute_23 = None + view_167 = torch.ops.aten.view.default(mm_15, [2, 8192, 128]); mm_15 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 32, '0'); convert_element_type_76 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + mm_16 = torch.ops.aten.mm.default(view_159, permute_24); view_159 = permute_24 = None + view_174 = torch.ops.aten.view.default(mm_16, [2, 8192, 128]) + view_176 = torch.ops.aten.view.default(view_160, [2, 8192, -1, 128]); view_160 = None + view_177 = torch.ops.aten.view.default(view_167, [2, 8192, -1, 128]); view_167 = None + view_178 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_176, torch.float32); view_176 = None + view_179 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 4, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_177, torch.float32); view_177 = None + view_180 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 1, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_180); view_180 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_37); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_182 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 4, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_37); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_183 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 1, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_182, torch.bfloat16); view_182 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_183, torch.bfloat16); view_183 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 1, 4, 128]); unsqueeze_4 = None + view_184 = torch.ops.aten.view.default(expand_4, [2, 8192, 4, 128]); expand_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_178, 3); view_178 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 1, 4, 128]); unsqueeze_5 = None + view_185 = torch.ops.aten.view.default(expand_5, [2, 8192, 4, 128]); expand_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_184, [0, 2, 1, 3]); view_184 = None + permute_27 = torch.ops.aten.permute.default(view_185, [0, 2, 1, 3]); view_185 = None + _scaled_dot_product_cudnn_attention_2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_25, permute_26, permute_27, None, True, 0.0, True); permute_25 = permute_26 = permute_27 = None + getitem_162 = _scaled_dot_product_cudnn_attention_2[0] + getitem_163 = _scaled_dot_product_cudnn_attention_2[1] + getitem_168 = _scaled_dot_product_cudnn_attention_2[6] + getitem_169 = _scaled_dot_product_cudnn_attention_2[7]; _scaled_dot_product_cudnn_attention_2 = None + permute_28 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_186 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 32, '0'); convert_element_type_83 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + view_192 = torch.ops.aten.view.default(view_186, [16384, 512]); view_186 = None + mm_17 = torch.ops.aten.mm.default(view_192, permute_29); view_192 = permute_29 = None + view_193 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + split_18 = torch.ops.aten.split.Tensor(view_193, 1024, 1); view_193 = None + getitem_171 = split_18[0] + getitem_172 = split_18[1] + getitem_173 = split_18[2] + getitem_174 = split_18[3] + getitem_175 = split_18[4] + getitem_176 = split_18[5] + getitem_177 = split_18[6] + getitem_178 = split_18[7]; split_18 = None + cat_10 = torch.ops.aten.cat.default([getitem_171, getitem_172, getitem_173, getitem_174, getitem_175, getitem_176, getitem_177, getitem_178]); getitem_171 = getitem_172 = getitem_173 = getitem_174 = getitem_175 = getitem_176 = getitem_177 = getitem_178 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_10, 'sum', 8, '1'); cat_10 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5) + add_9 = torch.ops.aten.add.Tensor(add_7, wait_tensor_34); wait_tensor_34 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 32, '0'); convert_element_type_86 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = rsqrt_5 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_35); mul_20 = wait_tensor_35 = None + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_88, 8, '1'); convert_element_type_88 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + split_19 = torch.ops.aten.split.Tensor(wait_tensor_36, 2); wait_tensor_36 = None + getitem_179 = split_19[0] + getitem_180 = split_19[1] + getitem_181 = split_19[2] + getitem_182 = split_19[3] + getitem_183 = split_19[4] + getitem_184 = split_19[5] + getitem_185 = split_19[6] + getitem_186 = split_19[7]; split_19 = None + cat_11 = torch.ops.aten.cat.default([getitem_179, getitem_180, getitem_181, getitem_182, getitem_183, getitem_184, getitem_185, getitem_186], 1); getitem_179 = getitem_180 = getitem_181 = getitem_182 = getitem_183 = getitem_184 = getitem_185 = getitem_186 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 32, '0'); convert_element_type_89 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + view_204 = torch.ops.aten.view.default(cat_11, [16384, 4096]); cat_11 = None + mm_18 = torch.ops.aten.mm.default(view_204, permute_30); permute_30 = None + view_205 = torch.ops.aten.view.default(mm_18, [2, 8192, 1792]) + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); convert_element_type_92 = sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 32, '0'); convert_element_type_94 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + mm_19 = torch.ops.aten.mm.default(view_204, permute_31); view_204 = permute_31 = None + view_212 = torch.ops.aten.view.default(mm_19, [2, 8192, 1792]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_212); convert_element_type_93 = view_212 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 32, '0'); convert_element_type_97 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + view_219 = torch.ops.aten.view.default(mul_23, [16384, 1792]); mul_23 = None + mm_20 = torch.ops.aten.mm.default(view_219, permute_32); view_219 = permute_32 = None + view_220 = torch.ops.aten.view.default(mm_20, [2, 8192, 4096]); mm_20 = None + split_20 = torch.ops.aten.split.Tensor(view_220, 1024, 1); view_220 = None + getitem_187 = split_20[0] + getitem_188 = split_20[1] + getitem_189 = split_20[2] + getitem_190 = split_20[3] + getitem_191 = split_20[4] + getitem_192 = split_20[5] + getitem_193 = split_20[6] + getitem_194 = split_20[7]; split_20 = None + cat_12 = torch.ops.aten.cat.default([getitem_187, getitem_188, getitem_189, getitem_190, getitem_191, getitem_192, getitem_193, getitem_194]); getitem_187 = getitem_188 = getitem_189 = getitem_190 = getitem_191 = getitem_192 = getitem_193 = getitem_194 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_12, 'sum', 8, '1'); cat_12 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + add_11 = torch.ops.aten.add.Tensor(add_9, wait_tensor_40); add_9 = wait_tensor_40 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 32, '0'); convert_element_type_100 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = rsqrt_6 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_41); mul_24 = wait_tensor_41 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_102, 8, '1'); convert_element_type_102 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + split_21 = torch.ops.aten.split.Tensor(wait_tensor_42, 2); wait_tensor_42 = None + getitem_195 = split_21[0] + getitem_196 = split_21[1] + getitem_197 = split_21[2] + getitem_198 = split_21[3] + getitem_199 = split_21[4] + getitem_200 = split_21[5] + getitem_201 = split_21[6] + getitem_202 = split_21[7]; split_21 = None + cat_13 = torch.ops.aten.cat.default([getitem_195, getitem_196, getitem_197, getitem_198, getitem_199, getitem_200, getitem_201, getitem_202], 1); getitem_195 = getitem_196 = getitem_197 = getitem_198 = getitem_199 = getitem_200 = getitem_201 = getitem_202 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 32, '0'); convert_element_type_103 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + view_231 = torch.ops.aten.view.default(cat_13, [16384, 4096]); cat_13 = None + mm_21 = torch.ops.aten.mm.default(view_231, permute_33); permute_33 = None + view_232 = torch.ops.aten.view.default(mm_21, [2, 8192, 512]) + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 32, '0'); convert_element_type_106 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_22 = torch.ops.aten.mm.default(view_231, permute_34); permute_34 = None + view_239 = torch.ops.aten.view.default(mm_22, [2, 8192, 128]); mm_22 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 32, '0'); convert_element_type_109 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + mm_23 = torch.ops.aten.mm.default(view_231, permute_35); view_231 = permute_35 = None + view_246 = torch.ops.aten.view.default(mm_23, [2, 8192, 128]) + view_248 = torch.ops.aten.view.default(view_232, [2, 8192, -1, 128]); view_232 = None + view_249 = torch.ops.aten.view.default(view_239, [2, 8192, -1, 128]); view_239 = None + view_250 = torch.ops.aten.view.default(view_246, [2, 8192, -1, 128]); view_246 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_248, torch.float32); view_248 = None + view_251 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 4, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_251); view_251 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 1, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_37); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_254 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 4, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_37); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_255 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 1, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_254, torch.bfloat16); view_254 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 1, 4, 128]); unsqueeze_6 = None + view_256 = torch.ops.aten.view.default(expand_6, [2, 8192, 4, 128]); expand_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_250, 3); view_250 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 1, 4, 128]); unsqueeze_7 = None + view_257 = torch.ops.aten.view.default(expand_7, [2, 8192, 4, 128]); expand_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_256, [0, 2, 1, 3]); view_256 = None + permute_38 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + _scaled_dot_product_cudnn_attention_3 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_36, permute_37, permute_38, None, True, 0.0, True); permute_36 = permute_37 = permute_38 = None + getitem_203 = _scaled_dot_product_cudnn_attention_3[0] + getitem_204 = _scaled_dot_product_cudnn_attention_3[1] + getitem_209 = _scaled_dot_product_cudnn_attention_3[6] + getitem_210 = _scaled_dot_product_cudnn_attention_3[7]; _scaled_dot_product_cudnn_attention_3 = None + permute_39 = torch.ops.aten.permute.default(getitem_203, [0, 2, 1, 3]) + view_258 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 32, '0'); convert_element_type_116 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_46, [1, 0]); wait_tensor_46 = None + view_264 = torch.ops.aten.view.default(view_258, [16384, 512]); view_258 = None + mm_24 = torch.ops.aten.mm.default(view_264, permute_40); view_264 = permute_40 = None + view_265 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + split_22 = torch.ops.aten.split.Tensor(view_265, 1024, 1); view_265 = None + getitem_212 = split_22[0] + getitem_213 = split_22[1] + getitem_214 = split_22[2] + getitem_215 = split_22[3] + getitem_216 = split_22[4] + getitem_217 = split_22[5] + getitem_218 = split_22[6] + getitem_219 = split_22[7]; split_22 = None + cat_14 = torch.ops.aten.cat.default([getitem_212, getitem_213, getitem_214, getitem_215, getitem_216, getitem_217, getitem_218, getitem_219]); getitem_212 = getitem_213 = getitem_214 = getitem_215 = getitem_216 = getitem_217 = getitem_218 = getitem_219 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_14, 'sum', 8, '1'); cat_14 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7) + add_13 = torch.ops.aten.add.Tensor(add_11, wait_tensor_47); wait_tensor_47 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 32, '0'); convert_element_type_119 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = rsqrt_7 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_48); mul_28 = wait_tensor_48 = None + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_121, 8, '1'); convert_element_type_121 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + split_23 = torch.ops.aten.split.Tensor(wait_tensor_49, 2); wait_tensor_49 = None + getitem_220 = split_23[0] + getitem_221 = split_23[1] + getitem_222 = split_23[2] + getitem_223 = split_23[3] + getitem_224 = split_23[4] + getitem_225 = split_23[5] + getitem_226 = split_23[6] + getitem_227 = split_23[7]; split_23 = None + cat_15 = torch.ops.aten.cat.default([getitem_220, getitem_221, getitem_222, getitem_223, getitem_224, getitem_225, getitem_226, getitem_227], 1); getitem_220 = getitem_221 = getitem_222 = getitem_223 = getitem_224 = getitem_225 = getitem_226 = getitem_227 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 32, '0'); convert_element_type_122 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_276 = torch.ops.aten.view.default(cat_15, [16384, 4096]); cat_15 = None + mm_25 = torch.ops.aten.mm.default(view_276, permute_41); permute_41 = None + view_277 = torch.ops.aten.view.default(mm_25, [2, 8192, 1792]) + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_277, torch.float32); view_277 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); convert_element_type_125 = sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 32, '0'); convert_element_type_127 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + mm_26 = torch.ops.aten.mm.default(view_276, permute_42); view_276 = permute_42 = None + view_284 = torch.ops.aten.view.default(mm_26, [2, 8192, 1792]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_284); convert_element_type_126 = view_284 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 32, '0'); convert_element_type_130 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + view_291 = torch.ops.aten.view.default(mul_31, [16384, 1792]); mul_31 = None + mm_27 = torch.ops.aten.mm.default(view_291, permute_43); view_291 = permute_43 = None + view_292 = torch.ops.aten.view.default(mm_27, [2, 8192, 4096]); mm_27 = None + split_24 = torch.ops.aten.split.Tensor(view_292, 1024, 1); view_292 = None + getitem_228 = split_24[0] + getitem_229 = split_24[1] + getitem_230 = split_24[2] + getitem_231 = split_24[3] + getitem_232 = split_24[4] + getitem_233 = split_24[5] + getitem_234 = split_24[6] + getitem_235 = split_24[7]; split_24 = None + cat_16 = torch.ops.aten.cat.default([getitem_228, getitem_229, getitem_230, getitem_231, getitem_232, getitem_233, getitem_234, getitem_235]); getitem_228 = getitem_229 = getitem_230 = getitem_231 = getitem_232 = getitem_233 = getitem_234 = getitem_235 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_16, 'sum', 8, '1'); cat_16 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + add_15 = torch.ops.aten.add.Tensor(add_13, wait_tensor_53); add_13 = wait_tensor_53 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 32, '0'); convert_element_type_133 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = rsqrt_8 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_54); mul_32 = wait_tensor_54 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_135, 8, '1'); convert_element_type_135 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + split_25 = torch.ops.aten.split.Tensor(wait_tensor_55, 2); wait_tensor_55 = None + getitem_236 = split_25[0] + getitem_237 = split_25[1] + getitem_238 = split_25[2] + getitem_239 = split_25[3] + getitem_240 = split_25[4] + getitem_241 = split_25[5] + getitem_242 = split_25[6] + getitem_243 = split_25[7]; split_25 = None + cat_17 = torch.ops.aten.cat.default([getitem_236, getitem_237, getitem_238, getitem_239, getitem_240, getitem_241, getitem_242, getitem_243], 1); getitem_236 = getitem_237 = getitem_238 = getitem_239 = getitem_240 = getitem_241 = getitem_242 = getitem_243 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 32, '0'); convert_element_type_136 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + view_303 = torch.ops.aten.view.default(cat_17, [16384, 4096]); cat_17 = None + mm_28 = torch.ops.aten.mm.default(view_303, permute_44); permute_44 = None + view_304 = torch.ops.aten.view.default(mm_28, [2, 8192, 512]) + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 32, '0'); convert_element_type_139 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_29 = torch.ops.aten.mm.default(view_303, permute_45); permute_45 = None + view_311 = torch.ops.aten.view.default(mm_29, [2, 8192, 128]); mm_29 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 32, '0'); convert_element_type_142 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + mm_30 = torch.ops.aten.mm.default(view_303, permute_46); view_303 = permute_46 = None + view_318 = torch.ops.aten.view.default(mm_30, [2, 8192, 128]) + view_320 = torch.ops.aten.view.default(view_304, [2, 8192, -1, 128]); view_304 = None + view_321 = torch.ops.aten.view.default(view_311, [2, 8192, -1, 128]); view_311 = None + view_322 = torch.ops.aten.view.default(view_318, [2, 8192, -1, 128]); view_318 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None + view_323 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 4, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_323); view_323 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_321, torch.float32); view_321 = None + view_324 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 1, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_324); view_324 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_37); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_326 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 4, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_37); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_327 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 1, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_327, torch.bfloat16); view_327 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 1, 4, 128]); unsqueeze_8 = None + view_328 = torch.ops.aten.view.default(expand_8, [2, 8192, 4, 128]); expand_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_322, 3); view_322 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 1, 4, 128]); unsqueeze_9 = None + view_329 = torch.ops.aten.view.default(expand_9, [2, 8192, 4, 128]); expand_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_328, [0, 2, 1, 3]); view_328 = None + permute_49 = torch.ops.aten.permute.default(view_329, [0, 2, 1, 3]); view_329 = None + _scaled_dot_product_cudnn_attention_4 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_47, permute_48, permute_49, None, True, 0.0, True); permute_47 = permute_48 = permute_49 = None + getitem_244 = _scaled_dot_product_cudnn_attention_4[0] + getitem_245 = _scaled_dot_product_cudnn_attention_4[1] + getitem_250 = _scaled_dot_product_cudnn_attention_4[6] + getitem_251 = _scaled_dot_product_cudnn_attention_4[7]; _scaled_dot_product_cudnn_attention_4 = None + permute_50 = torch.ops.aten.permute.default(getitem_244, [0, 2, 1, 3]) + view_330 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 32, '0'); convert_element_type_149 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_336 = torch.ops.aten.view.default(view_330, [16384, 512]); view_330 = None + mm_31 = torch.ops.aten.mm.default(view_336, permute_51); view_336 = permute_51 = None + view_337 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + split_26 = torch.ops.aten.split.Tensor(view_337, 1024, 1); view_337 = None + getitem_253 = split_26[0] + getitem_254 = split_26[1] + getitem_255 = split_26[2] + getitem_256 = split_26[3] + getitem_257 = split_26[4] + getitem_258 = split_26[5] + getitem_259 = split_26[6] + getitem_260 = split_26[7]; split_26 = None + cat_18 = torch.ops.aten.cat.default([getitem_253, getitem_254, getitem_255, getitem_256, getitem_257, getitem_258, getitem_259, getitem_260]); getitem_253 = getitem_254 = getitem_255 = getitem_256 = getitem_257 = getitem_258 = getitem_259 = getitem_260 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_18, 'sum', 8, '1'); cat_18 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9) + add_17 = torch.ops.aten.add.Tensor(add_15, wait_tensor_60); wait_tensor_60 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 32, '0'); convert_element_type_152 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = rsqrt_9 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_61); mul_36 = wait_tensor_61 = None + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_154, 8, '1'); convert_element_type_154 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + split_27 = torch.ops.aten.split.Tensor(wait_tensor_62, 2); wait_tensor_62 = None + getitem_261 = split_27[0] + getitem_262 = split_27[1] + getitem_263 = split_27[2] + getitem_264 = split_27[3] + getitem_265 = split_27[4] + getitem_266 = split_27[5] + getitem_267 = split_27[6] + getitem_268 = split_27[7]; split_27 = None + cat_19 = torch.ops.aten.cat.default([getitem_261, getitem_262, getitem_263, getitem_264, getitem_265, getitem_266, getitem_267, getitem_268], 1); getitem_261 = getitem_262 = getitem_263 = getitem_264 = getitem_265 = getitem_266 = getitem_267 = getitem_268 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 32, '0'); convert_element_type_155 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + view_348 = torch.ops.aten.view.default(cat_19, [16384, 4096]); cat_19 = None + mm_32 = torch.ops.aten.mm.default(view_348, permute_52); permute_52 = None + view_349 = torch.ops.aten.view.default(mm_32, [2, 8192, 1792]) + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); convert_element_type_158 = sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 32, '0'); convert_element_type_160 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_64, [1, 0]); wait_tensor_64 = None + mm_33 = torch.ops.aten.mm.default(view_348, permute_53); view_348 = permute_53 = None + view_356 = torch.ops.aten.view.default(mm_33, [2, 8192, 1792]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_356); convert_element_type_159 = view_356 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 32, '0'); convert_element_type_163 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + view_363 = torch.ops.aten.view.default(mul_39, [16384, 1792]); mul_39 = None + mm_34 = torch.ops.aten.mm.default(view_363, permute_54); view_363 = permute_54 = None + view_364 = torch.ops.aten.view.default(mm_34, [2, 8192, 4096]); mm_34 = None + split_28 = torch.ops.aten.split.Tensor(view_364, 1024, 1); view_364 = None + getitem_269 = split_28[0] + getitem_270 = split_28[1] + getitem_271 = split_28[2] + getitem_272 = split_28[3] + getitem_273 = split_28[4] + getitem_274 = split_28[5] + getitem_275 = split_28[6] + getitem_276 = split_28[7]; split_28 = None + cat_20 = torch.ops.aten.cat.default([getitem_269, getitem_270, getitem_271, getitem_272, getitem_273, getitem_274, getitem_275, getitem_276]); getitem_269 = getitem_270 = getitem_271 = getitem_272 = getitem_273 = getitem_274 = getitem_275 = getitem_276 = None + reduce_scatter_tensor_10 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_20, 'sum', 8, '1'); cat_20 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + add_19 = torch.ops.aten.add.Tensor(add_17, wait_tensor_66); add_17 = wait_tensor_66 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 32, '0'); convert_element_type_166 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = rsqrt_10 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_67); mul_40 = wait_tensor_67 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_168, 8, '1'); convert_element_type_168 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + split_29 = torch.ops.aten.split.Tensor(wait_tensor_68, 2); wait_tensor_68 = None + getitem_277 = split_29[0] + getitem_278 = split_29[1] + getitem_279 = split_29[2] + getitem_280 = split_29[3] + getitem_281 = split_29[4] + getitem_282 = split_29[5] + getitem_283 = split_29[6] + getitem_284 = split_29[7]; split_29 = None + cat_21 = torch.ops.aten.cat.default([getitem_277, getitem_278, getitem_279, getitem_280, getitem_281, getitem_282, getitem_283, getitem_284], 1); getitem_277 = getitem_278 = getitem_279 = getitem_280 = getitem_281 = getitem_282 = getitem_283 = getitem_284 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 32, '0'); convert_element_type_169 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_69, [1, 0]); wait_tensor_69 = None + view_375 = torch.ops.aten.view.default(cat_21, [16384, 4096]); cat_21 = None + mm_35 = torch.ops.aten.mm.default(view_375, permute_55); permute_55 = None + view_376 = torch.ops.aten.view.default(mm_35, [2, 8192, 512]) + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 32, '0'); convert_element_type_172 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + mm_36 = torch.ops.aten.mm.default(view_375, permute_56); permute_56 = None + view_383 = torch.ops.aten.view.default(mm_36, [2, 8192, 128]); mm_36 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 32, '0'); convert_element_type_175 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_37 = torch.ops.aten.mm.default(view_375, permute_57); view_375 = permute_57 = None + view_390 = torch.ops.aten.view.default(mm_37, [2, 8192, 128]) + view_392 = torch.ops.aten.view.default(view_376, [2, 8192, -1, 128]); view_376 = None + view_393 = torch.ops.aten.view.default(view_383, [2, 8192, -1, 128]); view_383 = None + view_394 = torch.ops.aten.view.default(view_390, [2, 8192, -1, 128]); view_390 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_392, torch.float32); view_392 = None + view_395 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 4, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_395); view_395 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_393, torch.float32); view_393 = None + view_396 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 1, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_396); view_396 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_37); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_398 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 4, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_37); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_399 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 1, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_398, torch.bfloat16); view_398 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_399, torch.bfloat16); view_399 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 1, 4, 128]); unsqueeze_10 = None + view_400 = torch.ops.aten.view.default(expand_10, [2, 8192, 4, 128]); expand_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_394, 3); view_394 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 1, 4, 128]); unsqueeze_11 = None + view_401 = torch.ops.aten.view.default(expand_11, [2, 8192, 4, 128]); expand_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_400, [0, 2, 1, 3]); view_400 = None + permute_60 = torch.ops.aten.permute.default(view_401, [0, 2, 1, 3]); view_401 = None + _scaled_dot_product_cudnn_attention_5 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_58, permute_59, permute_60, None, True, 0.0, True); permute_58 = permute_59 = permute_60 = None + getitem_285 = _scaled_dot_product_cudnn_attention_5[0] + getitem_286 = _scaled_dot_product_cudnn_attention_5[1] + getitem_291 = _scaled_dot_product_cudnn_attention_5[6] + getitem_292 = _scaled_dot_product_cudnn_attention_5[7]; _scaled_dot_product_cudnn_attention_5 = None + permute_61 = torch.ops.aten.permute.default(getitem_285, [0, 2, 1, 3]) + view_402 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 32, '0'); convert_element_type_182 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + view_408 = torch.ops.aten.view.default(view_402, [16384, 512]); view_402 = None + mm_38 = torch.ops.aten.mm.default(view_408, permute_62); view_408 = permute_62 = None + view_409 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + split_30 = torch.ops.aten.split.Tensor(view_409, 1024, 1); view_409 = None + getitem_294 = split_30[0] + getitem_295 = split_30[1] + getitem_296 = split_30[2] + getitem_297 = split_30[3] + getitem_298 = split_30[4] + getitem_299 = split_30[5] + getitem_300 = split_30[6] + getitem_301 = split_30[7]; split_30 = None + cat_22 = torch.ops.aten.cat.default([getitem_294, getitem_295, getitem_296, getitem_297, getitem_298, getitem_299, getitem_300, getitem_301]); getitem_294 = getitem_295 = getitem_296 = getitem_297 = getitem_298 = getitem_299 = getitem_300 = getitem_301 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_22, 'sum', 8, '1'); cat_22 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11) + add_21 = torch.ops.aten.add.Tensor(add_19, wait_tensor_73); wait_tensor_73 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 32, '0'); convert_element_type_185 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = rsqrt_11 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_74); mul_44 = wait_tensor_74 = None + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_187, 8, '1'); convert_element_type_187 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + split_31 = torch.ops.aten.split.Tensor(wait_tensor_75, 2); wait_tensor_75 = None + getitem_302 = split_31[0] + getitem_303 = split_31[1] + getitem_304 = split_31[2] + getitem_305 = split_31[3] + getitem_306 = split_31[4] + getitem_307 = split_31[5] + getitem_308 = split_31[6] + getitem_309 = split_31[7]; split_31 = None + cat_23 = torch.ops.aten.cat.default([getitem_302, getitem_303, getitem_304, getitem_305, getitem_306, getitem_307, getitem_308, getitem_309], 1); getitem_302 = getitem_303 = getitem_304 = getitem_305 = getitem_306 = getitem_307 = getitem_308 = getitem_309 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 32, '0'); convert_element_type_188 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + view_420 = torch.ops.aten.view.default(cat_23, [16384, 4096]); cat_23 = None + mm_39 = torch.ops.aten.mm.default(view_420, permute_63); permute_63 = None + view_421 = torch.ops.aten.view.default(mm_39, [2, 8192, 1792]) + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_421, torch.float32); view_421 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); convert_element_type_191 = sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 32, '0'); convert_element_type_193 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + mm_40 = torch.ops.aten.mm.default(view_420, permute_64); view_420 = permute_64 = None + view_428 = torch.ops.aten.view.default(mm_40, [2, 8192, 1792]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_428); convert_element_type_192 = view_428 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 32, '0'); convert_element_type_196 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + view_435 = torch.ops.aten.view.default(mul_47, [16384, 1792]); mul_47 = None + mm_41 = torch.ops.aten.mm.default(view_435, permute_65); view_435 = permute_65 = None + view_436 = torch.ops.aten.view.default(mm_41, [2, 8192, 4096]); mm_41 = None + split_32 = torch.ops.aten.split.Tensor(view_436, 1024, 1); view_436 = None + getitem_310 = split_32[0] + getitem_311 = split_32[1] + getitem_312 = split_32[2] + getitem_313 = split_32[3] + getitem_314 = split_32[4] + getitem_315 = split_32[5] + getitem_316 = split_32[6] + getitem_317 = split_32[7]; split_32 = None + cat_24 = torch.ops.aten.cat.default([getitem_310, getitem_311, getitem_312, getitem_313, getitem_314, getitem_315, getitem_316, getitem_317]); getitem_310 = getitem_311 = getitem_312 = getitem_313 = getitem_314 = getitem_315 = getitem_316 = getitem_317 = None + reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_24, 'sum', 8, '1'); cat_24 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + add_23 = torch.ops.aten.add.Tensor(add_21, wait_tensor_79); add_21 = wait_tensor_79 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 32, '0'); convert_element_type_199 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32) + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = rsqrt_12 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_80); mul_48 = wait_tensor_80 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_201, 8, '1'); convert_element_type_201 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + split_33 = torch.ops.aten.split.Tensor(wait_tensor_81, 2); wait_tensor_81 = None + getitem_318 = split_33[0] + getitem_319 = split_33[1] + getitem_320 = split_33[2] + getitem_321 = split_33[3] + getitem_322 = split_33[4] + getitem_323 = split_33[5] + getitem_324 = split_33[6] + getitem_325 = split_33[7]; split_33 = None + cat_25 = torch.ops.aten.cat.default([getitem_318, getitem_319, getitem_320, getitem_321, getitem_322, getitem_323, getitem_324, getitem_325], 1); getitem_318 = getitem_319 = getitem_320 = getitem_321 = getitem_322 = getitem_323 = getitem_324 = getitem_325 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 32, '0'); convert_element_type_202 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_82, [1, 0]); wait_tensor_82 = None + view_447 = torch.ops.aten.view.default(cat_25, [16384, 4096]); cat_25 = None + mm_42 = torch.ops.aten.mm.default(view_447, permute_66); permute_66 = None + view_448 = torch.ops.aten.view.default(mm_42, [2, 8192, 512]) + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 32, '0'); convert_element_type_205 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + mm_43 = torch.ops.aten.mm.default(view_447, permute_67); permute_67 = None + view_455 = torch.ops.aten.view.default(mm_43, [2, 8192, 128]); mm_43 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 32, '0'); convert_element_type_208 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_44 = torch.ops.aten.mm.default(view_447, permute_68); view_447 = permute_68 = None + view_462 = torch.ops.aten.view.default(mm_44, [2, 8192, 128]) + view_464 = torch.ops.aten.view.default(view_448, [2, 8192, -1, 128]); view_448 = None + view_465 = torch.ops.aten.view.default(view_455, [2, 8192, -1, 128]); view_455 = None + view_466 = torch.ops.aten.view.default(view_462, [2, 8192, -1, 128]); view_462 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_464, torch.float32); view_464 = None + view_467 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 4, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_467); view_467 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_465, torch.float32); view_465 = None + view_468 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 1, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_468); view_468 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_37); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_470 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 4, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_37); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_471 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 1, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_470, torch.bfloat16); view_470 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_471, torch.bfloat16); view_471 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 1, 4, 128]); unsqueeze_12 = None + view_472 = torch.ops.aten.view.default(expand_12, [2, 8192, 4, 128]); expand_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_466, 3); view_466 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 1, 4, 128]); unsqueeze_13 = None + view_473 = torch.ops.aten.view.default(expand_13, [2, 8192, 4, 128]); expand_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_472, [0, 2, 1, 3]); view_472 = None + permute_71 = torch.ops.aten.permute.default(view_473, [0, 2, 1, 3]); view_473 = None + _scaled_dot_product_cudnn_attention_6 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_69, permute_70, permute_71, None, True, 0.0, True); permute_69 = permute_70 = permute_71 = None + getitem_326 = _scaled_dot_product_cudnn_attention_6[0] + getitem_327 = _scaled_dot_product_cudnn_attention_6[1] + getitem_332 = _scaled_dot_product_cudnn_attention_6[6] + getitem_333 = _scaled_dot_product_cudnn_attention_6[7]; _scaled_dot_product_cudnn_attention_6 = None + permute_72 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]) + view_474 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16) + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 32, '0'); convert_element_type_215 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + view_480 = torch.ops.aten.view.default(view_474, [16384, 512]); view_474 = None + mm_45 = torch.ops.aten.mm.default(view_480, permute_73); view_480 = permute_73 = None + view_481 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + split_34 = torch.ops.aten.split.Tensor(view_481, 1024, 1); view_481 = None + getitem_335 = split_34[0] + getitem_336 = split_34[1] + getitem_337 = split_34[2] + getitem_338 = split_34[3] + getitem_339 = split_34[4] + getitem_340 = split_34[5] + getitem_341 = split_34[6] + getitem_342 = split_34[7]; split_34 = None + cat_26 = torch.ops.aten.cat.default([getitem_335, getitem_336, getitem_337, getitem_338, getitem_339, getitem_340, getitem_341, getitem_342]); getitem_335 = getitem_336 = getitem_337 = getitem_338 = getitem_339 = getitem_340 = getitem_341 = getitem_342 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_26, 'sum', 8, '1'); cat_26 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13) + add_25 = torch.ops.aten.add.Tensor(add_23, wait_tensor_86); wait_tensor_86 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 32, '0'); convert_element_type_218 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = rsqrt_13 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_87); mul_52 = wait_tensor_87 = None + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_220, 8, '1'); convert_element_type_220 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + split_35 = torch.ops.aten.split.Tensor(wait_tensor_88, 2); wait_tensor_88 = None + getitem_343 = split_35[0] + getitem_344 = split_35[1] + getitem_345 = split_35[2] + getitem_346 = split_35[3] + getitem_347 = split_35[4] + getitem_348 = split_35[5] + getitem_349 = split_35[6] + getitem_350 = split_35[7]; split_35 = None + cat_27 = torch.ops.aten.cat.default([getitem_343, getitem_344, getitem_345, getitem_346, getitem_347, getitem_348, getitem_349, getitem_350], 1); getitem_343 = getitem_344 = getitem_345 = getitem_346 = getitem_347 = getitem_348 = getitem_349 = getitem_350 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16) + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 32, '0'); convert_element_type_221 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + view_492 = torch.ops.aten.view.default(cat_27, [16384, 4096]); cat_27 = None + mm_46 = torch.ops.aten.mm.default(view_492, permute_74); permute_74 = None + view_493 = torch.ops.aten.view.default(mm_46, [2, 8192, 1792]) + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_493, torch.float32); view_493 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); convert_element_type_224 = sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 32, '0'); convert_element_type_226 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + mm_47 = torch.ops.aten.mm.default(view_492, permute_75); view_492 = permute_75 = None + view_500 = torch.ops.aten.view.default(mm_47, [2, 8192, 1792]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_500); convert_element_type_225 = view_500 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 32, '0'); convert_element_type_229 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_91, [1, 0]); wait_tensor_91 = None + view_507 = torch.ops.aten.view.default(mul_55, [16384, 1792]); mul_55 = None + mm_48 = torch.ops.aten.mm.default(view_507, permute_76); view_507 = permute_76 = None + view_508 = torch.ops.aten.view.default(mm_48, [2, 8192, 4096]); mm_48 = None + split_36 = torch.ops.aten.split.Tensor(view_508, 1024, 1); view_508 = None + getitem_351 = split_36[0] + getitem_352 = split_36[1] + getitem_353 = split_36[2] + getitem_354 = split_36[3] + getitem_355 = split_36[4] + getitem_356 = split_36[5] + getitem_357 = split_36[6] + getitem_358 = split_36[7]; split_36 = None + cat_28 = torch.ops.aten.cat.default([getitem_351, getitem_352, getitem_353, getitem_354, getitem_355, getitem_356, getitem_357, getitem_358]); getitem_351 = getitem_352 = getitem_353 = getitem_354 = getitem_355 = getitem_356 = getitem_357 = getitem_358 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_28, 'sum', 8, '1'); cat_28 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + add_27 = torch.ops.aten.add.Tensor(add_25, wait_tensor_92); add_25 = wait_tensor_92 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 32, '0'); convert_element_type_232 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = rsqrt_14 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_93); mul_56 = wait_tensor_93 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '1'); convert_element_type_234 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + split_37 = torch.ops.aten.split.Tensor(wait_tensor_94, 2); wait_tensor_94 = None + getitem_359 = split_37[0] + getitem_360 = split_37[1] + getitem_361 = split_37[2] + getitem_362 = split_37[3] + getitem_363 = split_37[4] + getitem_364 = split_37[5] + getitem_365 = split_37[6] + getitem_366 = split_37[7]; split_37 = None + cat_29 = torch.ops.aten.cat.default([getitem_359, getitem_360, getitem_361, getitem_362, getitem_363, getitem_364, getitem_365, getitem_366], 1); getitem_359 = getitem_360 = getitem_361 = getitem_362 = getitem_363 = getitem_364 = getitem_365 = getitem_366 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 32, '0'); convert_element_type_235 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_519 = torch.ops.aten.view.default(cat_29, [16384, 4096]); cat_29 = None + mm_49 = torch.ops.aten.mm.default(view_519, permute_77); permute_77 = None + view_520 = torch.ops.aten.view.default(mm_49, [2, 8192, 512]) + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 32, '0'); convert_element_type_238 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + mm_50 = torch.ops.aten.mm.default(view_519, permute_78); permute_78 = None + view_527 = torch.ops.aten.view.default(mm_50, [2, 8192, 128]); mm_50 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 32, '0'); convert_element_type_241 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + mm_51 = torch.ops.aten.mm.default(view_519, permute_79); view_519 = permute_79 = None + view_534 = torch.ops.aten.view.default(mm_51, [2, 8192, 128]) + view_536 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + view_537 = torch.ops.aten.view.default(view_527, [2, 8192, -1, 128]); view_527 = None + view_538 = torch.ops.aten.view.default(view_534, [2, 8192, -1, 128]); view_534 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_536, torch.float32); view_536 = None + view_539 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 4, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_539); view_539 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_537, torch.float32); view_537 = None + view_540 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 1, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_540); view_540 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_37); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_542 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 4, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_37); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_543 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 1, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_542, torch.bfloat16); view_542 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_543, torch.bfloat16); view_543 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 1, 4, 128]); unsqueeze_14 = None + view_544 = torch.ops.aten.view.default(expand_14, [2, 8192, 4, 128]); expand_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_538, 3); view_538 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 1, 4, 128]); unsqueeze_15 = None + view_545 = torch.ops.aten.view.default(expand_15, [2, 8192, 4, 128]); expand_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_544, [0, 2, 1, 3]); view_544 = None + permute_82 = torch.ops.aten.permute.default(view_545, [0, 2, 1, 3]); view_545 = None + _scaled_dot_product_cudnn_attention_7 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_80, permute_81, permute_82, None, True, 0.0, True); permute_80 = permute_81 = permute_82 = None + getitem_367 = _scaled_dot_product_cudnn_attention_7[0] + getitem_368 = _scaled_dot_product_cudnn_attention_7[1] + getitem_373 = _scaled_dot_product_cudnn_attention_7[6] + getitem_374 = _scaled_dot_product_cudnn_attention_7[7]; _scaled_dot_product_cudnn_attention_7 = None + permute_83 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]) + view_546 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 32, '0'); convert_element_type_248 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + view_552 = torch.ops.aten.view.default(view_546, [16384, 512]); view_546 = None + mm_52 = torch.ops.aten.mm.default(view_552, permute_84); view_552 = permute_84 = None + view_553 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + split_38 = torch.ops.aten.split.Tensor(view_553, 1024, 1); view_553 = None + getitem_376 = split_38[0] + getitem_377 = split_38[1] + getitem_378 = split_38[2] + getitem_379 = split_38[3] + getitem_380 = split_38[4] + getitem_381 = split_38[5] + getitem_382 = split_38[6] + getitem_383 = split_38[7]; split_38 = None + cat_30 = torch.ops.aten.cat.default([getitem_376, getitem_377, getitem_378, getitem_379, getitem_380, getitem_381, getitem_382, getitem_383]); getitem_376 = getitem_377 = getitem_378 = getitem_379 = getitem_380 = getitem_381 = getitem_382 = getitem_383 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_30, 'sum', 8, '1'); cat_30 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15) + add_29 = torch.ops.aten.add.Tensor(add_27, wait_tensor_99); wait_tensor_99 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 32, '0'); convert_element_type_251 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = rsqrt_15 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_100); mul_60 = wait_tensor_100 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 8, '1'); convert_element_type_253 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + split_39 = torch.ops.aten.split.Tensor(wait_tensor_101, 2); wait_tensor_101 = None + getitem_384 = split_39[0] + getitem_385 = split_39[1] + getitem_386 = split_39[2] + getitem_387 = split_39[3] + getitem_388 = split_39[4] + getitem_389 = split_39[5] + getitem_390 = split_39[6] + getitem_391 = split_39[7]; split_39 = None + cat_31 = torch.ops.aten.cat.default([getitem_384, getitem_385, getitem_386, getitem_387, getitem_388, getitem_389, getitem_390, getitem_391], 1); getitem_384 = getitem_385 = getitem_386 = getitem_387 = getitem_388 = getitem_389 = getitem_390 = getitem_391 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 32, '0'); convert_element_type_254 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + view_564 = torch.ops.aten.view.default(cat_31, [16384, 4096]); cat_31 = None + mm_53 = torch.ops.aten.mm.default(view_564, permute_85); permute_85 = None + view_565 = torch.ops.aten.view.default(mm_53, [2, 8192, 1792]) + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_565, torch.float32); view_565 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); convert_element_type_257 = sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 32, '0'); convert_element_type_259 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_54 = torch.ops.aten.mm.default(view_564, permute_86); view_564 = permute_86 = None + view_572 = torch.ops.aten.view.default(mm_54, [2, 8192, 1792]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_572); convert_element_type_258 = view_572 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 32, '0'); convert_element_type_262 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_579 = torch.ops.aten.view.default(mul_63, [16384, 1792]); mul_63 = None + mm_55 = torch.ops.aten.mm.default(view_579, permute_87); view_579 = permute_87 = None + view_580 = torch.ops.aten.view.default(mm_55, [2, 8192, 4096]); mm_55 = None + split_40 = torch.ops.aten.split.Tensor(view_580, 1024, 1); view_580 = None + getitem_392 = split_40[0] + getitem_393 = split_40[1] + getitem_394 = split_40[2] + getitem_395 = split_40[3] + getitem_396 = split_40[4] + getitem_397 = split_40[5] + getitem_398 = split_40[6] + getitem_399 = split_40[7]; split_40 = None + cat_32 = torch.ops.aten.cat.default([getitem_392, getitem_393, getitem_394, getitem_395, getitem_396, getitem_397, getitem_398, getitem_399]); getitem_392 = getitem_393 = getitem_394 = getitem_395 = getitem_396 = getitem_397 = getitem_398 = getitem_399 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_32, 'sum', 8, '1'); cat_32 = None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + add_31 = torch.ops.aten.add.Tensor(add_29, wait_tensor_105); add_29 = wait_tensor_105 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 32, '0'); convert_element_type_265 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = rsqrt_16 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_106); mul_64 = wait_tensor_106 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_267, 8, '1'); convert_element_type_267 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + split_41 = torch.ops.aten.split.Tensor(wait_tensor_107, 2); wait_tensor_107 = None + getitem_400 = split_41[0] + getitem_401 = split_41[1] + getitem_402 = split_41[2] + getitem_403 = split_41[3] + getitem_404 = split_41[4] + getitem_405 = split_41[5] + getitem_406 = split_41[6] + getitem_407 = split_41[7]; split_41 = None + cat_33 = torch.ops.aten.cat.default([getitem_400, getitem_401, getitem_402, getitem_403, getitem_404, getitem_405, getitem_406, getitem_407], 1); getitem_400 = getitem_401 = getitem_402 = getitem_403 = getitem_404 = getitem_405 = getitem_406 = getitem_407 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 32, '0'); convert_element_type_268 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + view_591 = torch.ops.aten.view.default(cat_33, [16384, 4096]); cat_33 = None + mm_56 = torch.ops.aten.mm.default(view_591, permute_88); permute_88 = None + view_592 = torch.ops.aten.view.default(mm_56, [2, 8192, 512]) + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16) + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 32, '0'); convert_element_type_271 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_109, [1, 0]); wait_tensor_109 = None + mm_57 = torch.ops.aten.mm.default(view_591, permute_89); permute_89 = None + view_599 = torch.ops.aten.view.default(mm_57, [2, 8192, 128]); mm_57 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 32, '0'); convert_element_type_274 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + mm_58 = torch.ops.aten.mm.default(view_591, permute_90); view_591 = permute_90 = None + view_606 = torch.ops.aten.view.default(mm_58, [2, 8192, 128]) + view_608 = torch.ops.aten.view.default(view_592, [2, 8192, -1, 128]); view_592 = None + view_609 = torch.ops.aten.view.default(view_599, [2, 8192, -1, 128]); view_599 = None + view_610 = torch.ops.aten.view.default(view_606, [2, 8192, -1, 128]); view_606 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_608, torch.float32); view_608 = None + view_611 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 4, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_611); view_611 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_609, torch.float32); view_609 = None + view_612 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 1, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_612); view_612 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_37); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_614 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 4, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_37); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_615 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 1, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_614, torch.bfloat16); view_614 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_615, torch.bfloat16); view_615 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 1, 4, 128]); unsqueeze_16 = None + view_616 = torch.ops.aten.view.default(expand_16, [2, 8192, 4, 128]); expand_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_610, 3); view_610 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 1, 4, 128]); unsqueeze_17 = None + view_617 = torch.ops.aten.view.default(expand_17, [2, 8192, 4, 128]); expand_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_616, [0, 2, 1, 3]); view_616 = None + permute_93 = torch.ops.aten.permute.default(view_617, [0, 2, 1, 3]); view_617 = None + _scaled_dot_product_cudnn_attention_8 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_91, permute_92, permute_93, None, True, 0.0, True); permute_91 = permute_92 = permute_93 = None + getitem_408 = _scaled_dot_product_cudnn_attention_8[0] + getitem_409 = _scaled_dot_product_cudnn_attention_8[1] + getitem_414 = _scaled_dot_product_cudnn_attention_8[6] + getitem_415 = _scaled_dot_product_cudnn_attention_8[7]; _scaled_dot_product_cudnn_attention_8 = None + permute_94 = torch.ops.aten.permute.default(getitem_408, [0, 2, 1, 3]) + view_618 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 32, '0'); convert_element_type_281 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + view_624 = torch.ops.aten.view.default(view_618, [16384, 512]); view_618 = None + mm_59 = torch.ops.aten.mm.default(view_624, permute_95); view_624 = permute_95 = None + view_625 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + split_42 = torch.ops.aten.split.Tensor(view_625, 1024, 1); view_625 = None + getitem_417 = split_42[0] + getitem_418 = split_42[1] + getitem_419 = split_42[2] + getitem_420 = split_42[3] + getitem_421 = split_42[4] + getitem_422 = split_42[5] + getitem_423 = split_42[6] + getitem_424 = split_42[7]; split_42 = None + cat_34 = torch.ops.aten.cat.default([getitem_417, getitem_418, getitem_419, getitem_420, getitem_421, getitem_422, getitem_423, getitem_424]); getitem_417 = getitem_418 = getitem_419 = getitem_420 = getitem_421 = getitem_422 = getitem_423 = getitem_424 = None + reduce_scatter_tensor_17 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_34, 'sum', 8, '1'); cat_34 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17) + add_33 = torch.ops.aten.add.Tensor(add_31, wait_tensor_112); wait_tensor_112 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 32, '0'); convert_element_type_284 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = rsqrt_17 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_113); mul_68 = wait_tensor_113 = None + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '1'); convert_element_type_286 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + split_43 = torch.ops.aten.split.Tensor(wait_tensor_114, 2); wait_tensor_114 = None + getitem_425 = split_43[0] + getitem_426 = split_43[1] + getitem_427 = split_43[2] + getitem_428 = split_43[3] + getitem_429 = split_43[4] + getitem_430 = split_43[5] + getitem_431 = split_43[6] + getitem_432 = split_43[7]; split_43 = None + cat_35 = torch.ops.aten.cat.default([getitem_425, getitem_426, getitem_427, getitem_428, getitem_429, getitem_430, getitem_431, getitem_432], 1); getitem_425 = getitem_426 = getitem_427 = getitem_428 = getitem_429 = getitem_430 = getitem_431 = getitem_432 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 32, '0'); convert_element_type_287 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + view_636 = torch.ops.aten.view.default(cat_35, [16384, 4096]); cat_35 = None + mm_60 = torch.ops.aten.mm.default(view_636, permute_96); permute_96 = None + view_637 = torch.ops.aten.view.default(mm_60, [2, 8192, 1792]) + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_637, torch.float32); view_637 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); convert_element_type_290 = sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 32, '0'); convert_element_type_292 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_61 = torch.ops.aten.mm.default(view_636, permute_97); view_636 = permute_97 = None + view_644 = torch.ops.aten.view.default(mm_61, [2, 8192, 1792]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_644); convert_element_type_291 = view_644 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 32, '0'); convert_element_type_295 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_651 = torch.ops.aten.view.default(mul_71, [16384, 1792]); mul_71 = None + mm_62 = torch.ops.aten.mm.default(view_651, permute_98); view_651 = permute_98 = None + view_652 = torch.ops.aten.view.default(mm_62, [2, 8192, 4096]); mm_62 = None + split_44 = torch.ops.aten.split.Tensor(view_652, 1024, 1); view_652 = None + getitem_433 = split_44[0] + getitem_434 = split_44[1] + getitem_435 = split_44[2] + getitem_436 = split_44[3] + getitem_437 = split_44[4] + getitem_438 = split_44[5] + getitem_439 = split_44[6] + getitem_440 = split_44[7]; split_44 = None + cat_36 = torch.ops.aten.cat.default([getitem_433, getitem_434, getitem_435, getitem_436, getitem_437, getitem_438, getitem_439, getitem_440]); getitem_433 = getitem_434 = getitem_435 = getitem_436 = getitem_437 = getitem_438 = getitem_439 = getitem_440 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_36, 'sum', 8, '1'); cat_36 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + add_35 = torch.ops.aten.add.Tensor(add_33, wait_tensor_118); add_33 = wait_tensor_118 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 32, '0'); convert_element_type_298 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = rsqrt_18 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_119); mul_72 = wait_tensor_119 = None + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_300, 8, '1'); convert_element_type_300 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + split_45 = torch.ops.aten.split.Tensor(wait_tensor_120, 2); wait_tensor_120 = None + getitem_441 = split_45[0] + getitem_442 = split_45[1] + getitem_443 = split_45[2] + getitem_444 = split_45[3] + getitem_445 = split_45[4] + getitem_446 = split_45[5] + getitem_447 = split_45[6] + getitem_448 = split_45[7]; split_45 = None + cat_37 = torch.ops.aten.cat.default([getitem_441, getitem_442, getitem_443, getitem_444, getitem_445, getitem_446, getitem_447, getitem_448], 1); getitem_441 = getitem_442 = getitem_443 = getitem_444 = getitem_445 = getitem_446 = getitem_447 = getitem_448 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 32, '0'); convert_element_type_301 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + view_663 = torch.ops.aten.view.default(cat_37, [16384, 4096]); cat_37 = None + mm_63 = torch.ops.aten.mm.default(view_663, permute_99); permute_99 = None + view_664 = torch.ops.aten.view.default(mm_63, [2, 8192, 512]) + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 32, '0'); convert_element_type_304 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + mm_64 = torch.ops.aten.mm.default(view_663, permute_100); permute_100 = None + view_671 = torch.ops.aten.view.default(mm_64, [2, 8192, 128]); mm_64 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 32, '0'); convert_element_type_307 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + mm_65 = torch.ops.aten.mm.default(view_663, permute_101); view_663 = permute_101 = None + view_678 = torch.ops.aten.view.default(mm_65, [2, 8192, 128]) + view_680 = torch.ops.aten.view.default(view_664, [2, 8192, -1, 128]); view_664 = None + view_681 = torch.ops.aten.view.default(view_671, [2, 8192, -1, 128]); view_671 = None + view_682 = torch.ops.aten.view.default(view_678, [2, 8192, -1, 128]); view_678 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_680, torch.float32); view_680 = None + view_683 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 4, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_683); view_683 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_681, torch.float32); view_681 = None + view_684 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 1, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_684); view_684 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_37); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_686 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 4, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_37); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_687 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 1, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_686, torch.bfloat16); view_686 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_687, torch.bfloat16); view_687 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 1, 4, 128]); unsqueeze_18 = None + view_688 = torch.ops.aten.view.default(expand_18, [2, 8192, 4, 128]); expand_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_682, 3); view_682 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 1, 4, 128]); unsqueeze_19 = None + view_689 = torch.ops.aten.view.default(expand_19, [2, 8192, 4, 128]); expand_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_688, [0, 2, 1, 3]); view_688 = None + permute_104 = torch.ops.aten.permute.default(view_689, [0, 2, 1, 3]); view_689 = None + _scaled_dot_product_cudnn_attention_9 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_102, permute_103, permute_104, None, True, 0.0, True); permute_102 = permute_103 = permute_104 = None + getitem_449 = _scaled_dot_product_cudnn_attention_9[0] + getitem_450 = _scaled_dot_product_cudnn_attention_9[1] + getitem_455 = _scaled_dot_product_cudnn_attention_9[6] + getitem_456 = _scaled_dot_product_cudnn_attention_9[7]; _scaled_dot_product_cudnn_attention_9 = None + permute_105 = torch.ops.aten.permute.default(getitem_449, [0, 2, 1, 3]) + view_690 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 32, '0'); convert_element_type_314 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + view_696 = torch.ops.aten.view.default(view_690, [16384, 512]); view_690 = None + mm_66 = torch.ops.aten.mm.default(view_696, permute_106); view_696 = permute_106 = None + view_697 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + split_46 = torch.ops.aten.split.Tensor(view_697, 1024, 1); view_697 = None + getitem_458 = split_46[0] + getitem_459 = split_46[1] + getitem_460 = split_46[2] + getitem_461 = split_46[3] + getitem_462 = split_46[4] + getitem_463 = split_46[5] + getitem_464 = split_46[6] + getitem_465 = split_46[7]; split_46 = None + cat_38 = torch.ops.aten.cat.default([getitem_458, getitem_459, getitem_460, getitem_461, getitem_462, getitem_463, getitem_464, getitem_465]); getitem_458 = getitem_459 = getitem_460 = getitem_461 = getitem_462 = getitem_463 = getitem_464 = getitem_465 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_38, 'sum', 8, '1'); cat_38 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19) + add_37 = torch.ops.aten.add.Tensor(add_35, wait_tensor_125); wait_tensor_125 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 32, '0'); convert_element_type_317 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = rsqrt_19 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_126); mul_76 = wait_tensor_126 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_319, 8, '1'); convert_element_type_319 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + split_47 = torch.ops.aten.split.Tensor(wait_tensor_127, 2); wait_tensor_127 = None + getitem_466 = split_47[0] + getitem_467 = split_47[1] + getitem_468 = split_47[2] + getitem_469 = split_47[3] + getitem_470 = split_47[4] + getitem_471 = split_47[5] + getitem_472 = split_47[6] + getitem_473 = split_47[7]; split_47 = None + cat_39 = torch.ops.aten.cat.default([getitem_466, getitem_467, getitem_468, getitem_469, getitem_470, getitem_471, getitem_472, getitem_473], 1); getitem_466 = getitem_467 = getitem_468 = getitem_469 = getitem_470 = getitem_471 = getitem_472 = getitem_473 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 32, '0'); convert_element_type_320 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + view_708 = torch.ops.aten.view.default(cat_39, [16384, 4096]); cat_39 = None + mm_67 = torch.ops.aten.mm.default(view_708, permute_107); permute_107 = None + view_709 = torch.ops.aten.view.default(mm_67, [2, 8192, 1792]) + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_709, torch.float32); view_709 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); convert_element_type_323 = sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 32, '0'); convert_element_type_325 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_68 = torch.ops.aten.mm.default(view_708, permute_108); view_708 = permute_108 = None + view_716 = torch.ops.aten.view.default(mm_68, [2, 8192, 1792]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_716); convert_element_type_324 = view_716 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 32, '0'); convert_element_type_328 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + view_723 = torch.ops.aten.view.default(mul_79, [16384, 1792]); mul_79 = None + mm_69 = torch.ops.aten.mm.default(view_723, permute_109); view_723 = permute_109 = None + view_724 = torch.ops.aten.view.default(mm_69, [2, 8192, 4096]); mm_69 = None + split_48 = torch.ops.aten.split.Tensor(view_724, 1024, 1); view_724 = None + getitem_474 = split_48[0] + getitem_475 = split_48[1] + getitem_476 = split_48[2] + getitem_477 = split_48[3] + getitem_478 = split_48[4] + getitem_479 = split_48[5] + getitem_480 = split_48[6] + getitem_481 = split_48[7]; split_48 = None + cat_40 = torch.ops.aten.cat.default([getitem_474, getitem_475, getitem_476, getitem_477, getitem_478, getitem_479, getitem_480, getitem_481]); getitem_474 = getitem_475 = getitem_476 = getitem_477 = getitem_478 = getitem_479 = getitem_480 = getitem_481 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_40, 'sum', 8, '1'); cat_40 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + add_39 = torch.ops.aten.add.Tensor(add_37, wait_tensor_131); add_37 = wait_tensor_131 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16) + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 32, '0'); convert_element_type_331 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = rsqrt_20 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_132); mul_80 = wait_tensor_132 = None + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_333, 8, '1'); convert_element_type_333 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + split_49 = torch.ops.aten.split.Tensor(wait_tensor_133, 2); wait_tensor_133 = None + getitem_482 = split_49[0] + getitem_483 = split_49[1] + getitem_484 = split_49[2] + getitem_485 = split_49[3] + getitem_486 = split_49[4] + getitem_487 = split_49[5] + getitem_488 = split_49[6] + getitem_489 = split_49[7]; split_49 = None + cat_41 = torch.ops.aten.cat.default([getitem_482, getitem_483, getitem_484, getitem_485, getitem_486, getitem_487, getitem_488, getitem_489], 1); getitem_482 = getitem_483 = getitem_484 = getitem_485 = getitem_486 = getitem_487 = getitem_488 = getitem_489 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 32, '0'); convert_element_type_334 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + view_735 = torch.ops.aten.view.default(cat_41, [16384, 4096]); cat_41 = None + mm_70 = torch.ops.aten.mm.default(view_735, permute_110); permute_110 = None + view_736 = torch.ops.aten.view.default(mm_70, [2, 8192, 512]) + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 32, '0'); convert_element_type_337 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_71 = torch.ops.aten.mm.default(view_735, permute_111); permute_111 = None + view_743 = torch.ops.aten.view.default(mm_71, [2, 8192, 128]); mm_71 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 32, '0'); convert_element_type_340 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + mm_72 = torch.ops.aten.mm.default(view_735, permute_112); view_735 = permute_112 = None + view_750 = torch.ops.aten.view.default(mm_72, [2, 8192, 128]) + view_752 = torch.ops.aten.view.default(view_736, [2, 8192, -1, 128]); view_736 = None + view_753 = torch.ops.aten.view.default(view_743, [2, 8192, -1, 128]); view_743 = None + view_754 = torch.ops.aten.view.default(view_750, [2, 8192, -1, 128]); view_750 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_752, torch.float32); view_752 = None + view_755 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 4, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_755); view_755 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_753, torch.float32); view_753 = None + view_756 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 1, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_756); view_756 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_37); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_758 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 4, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_37); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_759 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 1, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_758, torch.bfloat16); view_758 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_759, torch.bfloat16); view_759 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 1, 4, 128]); unsqueeze_20 = None + view_760 = torch.ops.aten.view.default(expand_20, [2, 8192, 4, 128]); expand_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_754, 3); view_754 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 1, 4, 128]); unsqueeze_21 = None + view_761 = torch.ops.aten.view.default(expand_21, [2, 8192, 4, 128]); expand_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_760, [0, 2, 1, 3]); view_760 = None + permute_115 = torch.ops.aten.permute.default(view_761, [0, 2, 1, 3]); view_761 = None + _scaled_dot_product_cudnn_attention_10 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_113, permute_114, permute_115, None, True, 0.0, True); permute_113 = permute_114 = permute_115 = None + getitem_490 = _scaled_dot_product_cudnn_attention_10[0] + getitem_491 = _scaled_dot_product_cudnn_attention_10[1] + getitem_496 = _scaled_dot_product_cudnn_attention_10[6] + getitem_497 = _scaled_dot_product_cudnn_attention_10[7]; _scaled_dot_product_cudnn_attention_10 = None + permute_116 = torch.ops.aten.permute.default(getitem_490, [0, 2, 1, 3]) + view_762 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 32, '0'); convert_element_type_347 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + view_768 = torch.ops.aten.view.default(view_762, [16384, 512]); view_762 = None + mm_73 = torch.ops.aten.mm.default(view_768, permute_117); view_768 = permute_117 = None + view_769 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + split_50 = torch.ops.aten.split.Tensor(view_769, 1024, 1); view_769 = None + getitem_499 = split_50[0] + getitem_500 = split_50[1] + getitem_501 = split_50[2] + getitem_502 = split_50[3] + getitem_503 = split_50[4] + getitem_504 = split_50[5] + getitem_505 = split_50[6] + getitem_506 = split_50[7]; split_50 = None + cat_42 = torch.ops.aten.cat.default([getitem_499, getitem_500, getitem_501, getitem_502, getitem_503, getitem_504, getitem_505, getitem_506]); getitem_499 = getitem_500 = getitem_501 = getitem_502 = getitem_503 = getitem_504 = getitem_505 = getitem_506 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_42, 'sum', 8, '1'); cat_42 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21) + add_41 = torch.ops.aten.add.Tensor(add_39, wait_tensor_138); wait_tensor_138 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 32, '0'); convert_element_type_350 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = rsqrt_21 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_139); mul_84 = wait_tensor_139 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_352, 8, '1'); convert_element_type_352 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + split_51 = torch.ops.aten.split.Tensor(wait_tensor_140, 2); wait_tensor_140 = None + getitem_507 = split_51[0] + getitem_508 = split_51[1] + getitem_509 = split_51[2] + getitem_510 = split_51[3] + getitem_511 = split_51[4] + getitem_512 = split_51[5] + getitem_513 = split_51[6] + getitem_514 = split_51[7]; split_51 = None + cat_43 = torch.ops.aten.cat.default([getitem_507, getitem_508, getitem_509, getitem_510, getitem_511, getitem_512, getitem_513, getitem_514], 1); getitem_507 = getitem_508 = getitem_509 = getitem_510 = getitem_511 = getitem_512 = getitem_513 = getitem_514 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 32, '0'); convert_element_type_353 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + view_780 = torch.ops.aten.view.default(cat_43, [16384, 4096]); cat_43 = None + mm_74 = torch.ops.aten.mm.default(view_780, permute_118); permute_118 = None + view_781 = torch.ops.aten.view.default(mm_74, [2, 8192, 1792]) + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_781, torch.float32); view_781 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); convert_element_type_356 = sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 32, '0'); convert_element_type_358 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + mm_75 = torch.ops.aten.mm.default(view_780, permute_119); view_780 = permute_119 = None + view_788 = torch.ops.aten.view.default(mm_75, [2, 8192, 1792]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_788); convert_element_type_357 = view_788 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 32, '0'); convert_element_type_361 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + view_795 = torch.ops.aten.view.default(mul_87, [16384, 1792]); mul_87 = None + mm_76 = torch.ops.aten.mm.default(view_795, permute_120); view_795 = permute_120 = None + view_796 = torch.ops.aten.view.default(mm_76, [2, 8192, 4096]); mm_76 = None + split_52 = torch.ops.aten.split.Tensor(view_796, 1024, 1); view_796 = None + getitem_515 = split_52[0] + getitem_516 = split_52[1] + getitem_517 = split_52[2] + getitem_518 = split_52[3] + getitem_519 = split_52[4] + getitem_520 = split_52[5] + getitem_521 = split_52[6] + getitem_522 = split_52[7]; split_52 = None + cat_44 = torch.ops.aten.cat.default([getitem_515, getitem_516, getitem_517, getitem_518, getitem_519, getitem_520, getitem_521, getitem_522]); getitem_515 = getitem_516 = getitem_517 = getitem_518 = getitem_519 = getitem_520 = getitem_521 = getitem_522 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_44, 'sum', 8, '1'); cat_44 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); reduce_scatter_tensor_22 = None + add_43 = torch.ops.aten.add.Tensor(add_41, wait_tensor_144); add_41 = wait_tensor_144 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 32, '0'); convert_element_type_364 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = rsqrt_22 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_145); mul_88 = wait_tensor_145 = None + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_366, 8, '1'); convert_element_type_366 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + split_53 = torch.ops.aten.split.Tensor(wait_tensor_146, 2); wait_tensor_146 = None + getitem_523 = split_53[0] + getitem_524 = split_53[1] + getitem_525 = split_53[2] + getitem_526 = split_53[3] + getitem_527 = split_53[4] + getitem_528 = split_53[5] + getitem_529 = split_53[6] + getitem_530 = split_53[7]; split_53 = None + cat_45 = torch.ops.aten.cat.default([getitem_523, getitem_524, getitem_525, getitem_526, getitem_527, getitem_528, getitem_529, getitem_530], 1); getitem_523 = getitem_524 = getitem_525 = getitem_526 = getitem_527 = getitem_528 = getitem_529 = getitem_530 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 32, '0'); convert_element_type_367 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + view_807 = torch.ops.aten.view.default(cat_45, [16384, 4096]); cat_45 = None + mm_77 = torch.ops.aten.mm.default(view_807, permute_121); permute_121 = None + view_808 = torch.ops.aten.view.default(mm_77, [2, 8192, 512]) + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 32, '0'); convert_element_type_370 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_78 = torch.ops.aten.mm.default(view_807, permute_122); permute_122 = None + view_815 = torch.ops.aten.view.default(mm_78, [2, 8192, 128]); mm_78 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 32, '0'); convert_element_type_373 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + mm_79 = torch.ops.aten.mm.default(view_807, permute_123); view_807 = permute_123 = None + view_822 = torch.ops.aten.view.default(mm_79, [2, 8192, 128]) + view_824 = torch.ops.aten.view.default(view_808, [2, 8192, -1, 128]); view_808 = None + view_825 = torch.ops.aten.view.default(view_815, [2, 8192, -1, 128]); view_815 = None + view_826 = torch.ops.aten.view.default(view_822, [2, 8192, -1, 128]); view_822 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_824, torch.float32); view_824 = None + view_827 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 4, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_827); view_827 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_825, torch.float32); view_825 = None + view_828 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 1, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_828); view_828 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_37); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_830 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 4, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_37); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_831 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 1, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_830, torch.bfloat16); view_830 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_831, torch.bfloat16); view_831 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 1, 4, 128]); unsqueeze_22 = None + view_832 = torch.ops.aten.view.default(expand_22, [2, 8192, 4, 128]); expand_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_826, 3); view_826 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 1, 4, 128]); unsqueeze_23 = None + view_833 = torch.ops.aten.view.default(expand_23, [2, 8192, 4, 128]); expand_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_832, [0, 2, 1, 3]); view_832 = None + permute_126 = torch.ops.aten.permute.default(view_833, [0, 2, 1, 3]); view_833 = None + _scaled_dot_product_cudnn_attention_11 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_124, permute_125, permute_126, None, True, 0.0, True); permute_124 = permute_125 = permute_126 = None + getitem_531 = _scaled_dot_product_cudnn_attention_11[0] + getitem_532 = _scaled_dot_product_cudnn_attention_11[1] + getitem_537 = _scaled_dot_product_cudnn_attention_11[6] + getitem_538 = _scaled_dot_product_cudnn_attention_11[7]; _scaled_dot_product_cudnn_attention_11 = None + permute_127 = torch.ops.aten.permute.default(getitem_531, [0, 2, 1, 3]) + view_834 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 32, '0'); convert_element_type_380 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_150, [1, 0]); wait_tensor_150 = None + view_840 = torch.ops.aten.view.default(view_834, [16384, 512]); view_834 = None + mm_80 = torch.ops.aten.mm.default(view_840, permute_128); view_840 = permute_128 = None + view_841 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + split_54 = torch.ops.aten.split.Tensor(view_841, 1024, 1); view_841 = None + getitem_540 = split_54[0] + getitem_541 = split_54[1] + getitem_542 = split_54[2] + getitem_543 = split_54[3] + getitem_544 = split_54[4] + getitem_545 = split_54[5] + getitem_546 = split_54[6] + getitem_547 = split_54[7]; split_54 = None + cat_46 = torch.ops.aten.cat.default([getitem_540, getitem_541, getitem_542, getitem_543, getitem_544, getitem_545, getitem_546, getitem_547]); getitem_540 = getitem_541 = getitem_542 = getitem_543 = getitem_544 = getitem_545 = getitem_546 = getitem_547 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_46, 'sum', 8, '1'); cat_46 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23) + add_45 = torch.ops.aten.add.Tensor(add_43, wait_tensor_151); wait_tensor_151 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 32, '0'); convert_element_type_383 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = rsqrt_23 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_152); mul_92 = wait_tensor_152 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_385, 8, '1'); convert_element_type_385 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + split_55 = torch.ops.aten.split.Tensor(wait_tensor_153, 2); wait_tensor_153 = None + getitem_548 = split_55[0] + getitem_549 = split_55[1] + getitem_550 = split_55[2] + getitem_551 = split_55[3] + getitem_552 = split_55[4] + getitem_553 = split_55[5] + getitem_554 = split_55[6] + getitem_555 = split_55[7]; split_55 = None + cat_47 = torch.ops.aten.cat.default([getitem_548, getitem_549, getitem_550, getitem_551, getitem_552, getitem_553, getitem_554, getitem_555], 1); getitem_548 = getitem_549 = getitem_550 = getitem_551 = getitem_552 = getitem_553 = getitem_554 = getitem_555 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 32, '0'); convert_element_type_386 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_154, [1, 0]); wait_tensor_154 = None + view_852 = torch.ops.aten.view.default(cat_47, [16384, 4096]); cat_47 = None + mm_81 = torch.ops.aten.mm.default(view_852, permute_129); permute_129 = None + view_853 = torch.ops.aten.view.default(mm_81, [2, 8192, 1792]) + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_853, torch.float32); view_853 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); convert_element_type_389 = sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 32, '0'); convert_element_type_391 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_82 = torch.ops.aten.mm.default(view_852, permute_130); view_852 = permute_130 = None + view_860 = torch.ops.aten.view.default(mm_82, [2, 8192, 1792]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_860); convert_element_type_390 = view_860 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 32, '0'); convert_element_type_394 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + view_867 = torch.ops.aten.view.default(mul_95, [16384, 1792]); mul_95 = None + mm_83 = torch.ops.aten.mm.default(view_867, permute_131); view_867 = permute_131 = None + view_868 = torch.ops.aten.view.default(mm_83, [2, 8192, 4096]); mm_83 = None + split_56 = torch.ops.aten.split.Tensor(view_868, 1024, 1); view_868 = None + getitem_556 = split_56[0] + getitem_557 = split_56[1] + getitem_558 = split_56[2] + getitem_559 = split_56[3] + getitem_560 = split_56[4] + getitem_561 = split_56[5] + getitem_562 = split_56[6] + getitem_563 = split_56[7]; split_56 = None + cat_48 = torch.ops.aten.cat.default([getitem_556, getitem_557, getitem_558, getitem_559, getitem_560, getitem_561, getitem_562, getitem_563]); getitem_556 = getitem_557 = getitem_558 = getitem_559 = getitem_560 = getitem_561 = getitem_562 = getitem_563 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_48, 'sum', 8, '1'); cat_48 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + add_47 = torch.ops.aten.add.Tensor(add_45, wait_tensor_157); add_45 = wait_tensor_157 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 32, '0'); convert_element_type_397 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = rsqrt_24 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_158); mul_96 = wait_tensor_158 = None + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_399, 8, '1'); convert_element_type_399 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + split_57 = torch.ops.aten.split.Tensor(wait_tensor_159, 2); wait_tensor_159 = None + getitem_564 = split_57[0] + getitem_565 = split_57[1] + getitem_566 = split_57[2] + getitem_567 = split_57[3] + getitem_568 = split_57[4] + getitem_569 = split_57[5] + getitem_570 = split_57[6] + getitem_571 = split_57[7]; split_57 = None + cat_49 = torch.ops.aten.cat.default([getitem_564, getitem_565, getitem_566, getitem_567, getitem_568, getitem_569, getitem_570, getitem_571], 1); getitem_564 = getitem_565 = getitem_566 = getitem_567 = getitem_568 = getitem_569 = getitem_570 = getitem_571 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 32, '0'); convert_element_type_400 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + view_879 = torch.ops.aten.view.default(cat_49, [16384, 4096]); cat_49 = None + mm_84 = torch.ops.aten.mm.default(view_879, permute_132); permute_132 = None + view_880 = torch.ops.aten.view.default(mm_84, [2, 8192, 512]) + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 32, '0'); convert_element_type_403 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_85 = torch.ops.aten.mm.default(view_879, permute_133); permute_133 = None + view_887 = torch.ops.aten.view.default(mm_85, [2, 8192, 128]); mm_85 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 32, '0'); convert_element_type_406 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + mm_86 = torch.ops.aten.mm.default(view_879, permute_134); view_879 = permute_134 = None + view_894 = torch.ops.aten.view.default(mm_86, [2, 8192, 128]) + view_896 = torch.ops.aten.view.default(view_880, [2, 8192, -1, 128]); view_880 = None + view_897 = torch.ops.aten.view.default(view_887, [2, 8192, -1, 128]); view_887 = None + view_898 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 4, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_897, torch.float32); view_897 = None + view_900 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 1, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_900); view_900 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_37); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_902 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 4, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_37); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_903 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 1, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_903, torch.bfloat16); view_903 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 1, 4, 128]); unsqueeze_24 = None + view_904 = torch.ops.aten.view.default(expand_24, [2, 8192, 4, 128]); expand_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_898, 3); view_898 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 1, 4, 128]); unsqueeze_25 = None + view_905 = torch.ops.aten.view.default(expand_25, [2, 8192, 4, 128]); expand_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + permute_137 = torch.ops.aten.permute.default(view_905, [0, 2, 1, 3]); view_905 = None + _scaled_dot_product_cudnn_attention_12 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_135, permute_136, permute_137, None, True, 0.0, True); permute_135 = permute_136 = permute_137 = None + getitem_572 = _scaled_dot_product_cudnn_attention_12[0] + getitem_573 = _scaled_dot_product_cudnn_attention_12[1] + getitem_578 = _scaled_dot_product_cudnn_attention_12[6] + getitem_579 = _scaled_dot_product_cudnn_attention_12[7]; _scaled_dot_product_cudnn_attention_12 = None + permute_138 = torch.ops.aten.permute.default(getitem_572, [0, 2, 1, 3]) + view_906 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 32, '0'); convert_element_type_413 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + view_912 = torch.ops.aten.view.default(view_906, [16384, 512]); view_906 = None + mm_87 = torch.ops.aten.mm.default(view_912, permute_139); view_912 = permute_139 = None + view_913 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + split_58 = torch.ops.aten.split.Tensor(view_913, 1024, 1); view_913 = None + getitem_581 = split_58[0] + getitem_582 = split_58[1] + getitem_583 = split_58[2] + getitem_584 = split_58[3] + getitem_585 = split_58[4] + getitem_586 = split_58[5] + getitem_587 = split_58[6] + getitem_588 = split_58[7]; split_58 = None + cat_50 = torch.ops.aten.cat.default([getitem_581, getitem_582, getitem_583, getitem_584, getitem_585, getitem_586, getitem_587, getitem_588]); getitem_581 = getitem_582 = getitem_583 = getitem_584 = getitem_585 = getitem_586 = getitem_587 = getitem_588 = None + reduce_scatter_tensor_25 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_50, 'sum', 8, '1'); cat_50 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25) + add_49 = torch.ops.aten.add.Tensor(add_47, wait_tensor_164); wait_tensor_164 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 32, '0'); convert_element_type_416 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = rsqrt_25 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_165); mul_100 = wait_tensor_165 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 8, '1'); convert_element_type_418 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + split_59 = torch.ops.aten.split.Tensor(wait_tensor_166, 2); wait_tensor_166 = None + getitem_589 = split_59[0] + getitem_590 = split_59[1] + getitem_591 = split_59[2] + getitem_592 = split_59[3] + getitem_593 = split_59[4] + getitem_594 = split_59[5] + getitem_595 = split_59[6] + getitem_596 = split_59[7]; split_59 = None + cat_51 = torch.ops.aten.cat.default([getitem_589, getitem_590, getitem_591, getitem_592, getitem_593, getitem_594, getitem_595, getitem_596], 1); getitem_589 = getitem_590 = getitem_591 = getitem_592 = getitem_593 = getitem_594 = getitem_595 = getitem_596 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 32, '0'); convert_element_type_419 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_924 = torch.ops.aten.view.default(cat_51, [16384, 4096]); cat_51 = None + mm_88 = torch.ops.aten.mm.default(view_924, permute_140); permute_140 = None + view_925 = torch.ops.aten.view.default(mm_88, [2, 8192, 1792]) + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_925, torch.float32); view_925 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); convert_element_type_422 = sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 32, '0'); convert_element_type_424 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_168, [1, 0]); wait_tensor_168 = None + mm_89 = torch.ops.aten.mm.default(view_924, permute_141); view_924 = permute_141 = None + view_932 = torch.ops.aten.view.default(mm_89, [2, 8192, 1792]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_932); convert_element_type_423 = view_932 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 32, '0'); convert_element_type_427 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + view_939 = torch.ops.aten.view.default(mul_103, [16384, 1792]); mul_103 = None + mm_90 = torch.ops.aten.mm.default(view_939, permute_142); view_939 = permute_142 = None + view_940 = torch.ops.aten.view.default(mm_90, [2, 8192, 4096]); mm_90 = None + split_60 = torch.ops.aten.split.Tensor(view_940, 1024, 1); view_940 = None + getitem_597 = split_60[0] + getitem_598 = split_60[1] + getitem_599 = split_60[2] + getitem_600 = split_60[3] + getitem_601 = split_60[4] + getitem_602 = split_60[5] + getitem_603 = split_60[6] + getitem_604 = split_60[7]; split_60 = None + cat_52 = torch.ops.aten.cat.default([getitem_597, getitem_598, getitem_599, getitem_600, getitem_601, getitem_602, getitem_603, getitem_604]); getitem_597 = getitem_598 = getitem_599 = getitem_600 = getitem_601 = getitem_602 = getitem_603 = getitem_604 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_52, 'sum', 8, '1'); cat_52 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + add_51 = torch.ops.aten.add.Tensor(add_49, wait_tensor_170); add_49 = wait_tensor_170 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 32, '0'); convert_element_type_430 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = rsqrt_26 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_171); mul_104 = wait_tensor_171 = None + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_432, 8, '1'); convert_element_type_432 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + split_61 = torch.ops.aten.split.Tensor(wait_tensor_172, 2); wait_tensor_172 = None + getitem_605 = split_61[0] + getitem_606 = split_61[1] + getitem_607 = split_61[2] + getitem_608 = split_61[3] + getitem_609 = split_61[4] + getitem_610 = split_61[5] + getitem_611 = split_61[6] + getitem_612 = split_61[7]; split_61 = None + cat_53 = torch.ops.aten.cat.default([getitem_605, getitem_606, getitem_607, getitem_608, getitem_609, getitem_610, getitem_611, getitem_612], 1); getitem_605 = getitem_606 = getitem_607 = getitem_608 = getitem_609 = getitem_610 = getitem_611 = getitem_612 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 32, '0'); convert_element_type_433 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + view_951 = torch.ops.aten.view.default(cat_53, [16384, 4096]); cat_53 = None + mm_91 = torch.ops.aten.mm.default(view_951, permute_143); permute_143 = None + view_952 = torch.ops.aten.view.default(mm_91, [2, 8192, 512]) + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 32, '0'); convert_element_type_436 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_92 = torch.ops.aten.mm.default(view_951, permute_144); permute_144 = None + view_959 = torch.ops.aten.view.default(mm_92, [2, 8192, 128]); mm_92 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 32, '0'); convert_element_type_439 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + mm_93 = torch.ops.aten.mm.default(view_951, permute_145); view_951 = permute_145 = None + view_966 = torch.ops.aten.view.default(mm_93, [2, 8192, 128]) + view_968 = torch.ops.aten.view.default(view_952, [2, 8192, -1, 128]); view_952 = None + view_969 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_970 = torch.ops.aten.view.default(view_966, [2, 8192, -1, 128]); view_966 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_968, torch.float32); view_968 = None + view_971 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 4, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_971); view_971 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_969, torch.float32); view_969 = None + view_972 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 1, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_972); view_972 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_37); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_974 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 4, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_37); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_975 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 1, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_974, torch.bfloat16); view_974 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_975, torch.bfloat16); view_975 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 1, 4, 128]); unsqueeze_26 = None + view_976 = torch.ops.aten.view.default(expand_26, [2, 8192, 4, 128]); expand_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_970, 3); view_970 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 1, 4, 128]); unsqueeze_27 = None + view_977 = torch.ops.aten.view.default(expand_27, [2, 8192, 4, 128]); expand_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_976, [0, 2, 1, 3]); view_976 = None + permute_148 = torch.ops.aten.permute.default(view_977, [0, 2, 1, 3]); view_977 = None + _scaled_dot_product_cudnn_attention_13 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_146, permute_147, permute_148, None, True, 0.0, True); permute_146 = permute_147 = permute_148 = None + getitem_613 = _scaled_dot_product_cudnn_attention_13[0] + getitem_614 = _scaled_dot_product_cudnn_attention_13[1] + getitem_619 = _scaled_dot_product_cudnn_attention_13[6] + getitem_620 = _scaled_dot_product_cudnn_attention_13[7]; _scaled_dot_product_cudnn_attention_13 = None + permute_149 = torch.ops.aten.permute.default(getitem_613, [0, 2, 1, 3]) + view_978 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 32, '0'); convert_element_type_446 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_984 = torch.ops.aten.view.default(view_978, [16384, 512]); view_978 = None + mm_94 = torch.ops.aten.mm.default(view_984, permute_150); view_984 = permute_150 = None + view_985 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + split_62 = torch.ops.aten.split.Tensor(view_985, 1024, 1); view_985 = None + getitem_622 = split_62[0] + getitem_623 = split_62[1] + getitem_624 = split_62[2] + getitem_625 = split_62[3] + getitem_626 = split_62[4] + getitem_627 = split_62[5] + getitem_628 = split_62[6] + getitem_629 = split_62[7]; split_62 = None + cat_54 = torch.ops.aten.cat.default([getitem_622, getitem_623, getitem_624, getitem_625, getitem_626, getitem_627, getitem_628, getitem_629]); getitem_622 = getitem_623 = getitem_624 = getitem_625 = getitem_626 = getitem_627 = getitem_628 = getitem_629 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_54, 'sum', 8, '1'); cat_54 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27) + add_53 = torch.ops.aten.add.Tensor(add_51, wait_tensor_177); wait_tensor_177 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 32, '0'); convert_element_type_449 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = rsqrt_27 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_178); mul_108 = wait_tensor_178 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '1'); convert_element_type_451 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + split_63 = torch.ops.aten.split.Tensor(wait_tensor_179, 2); wait_tensor_179 = None + getitem_630 = split_63[0] + getitem_631 = split_63[1] + getitem_632 = split_63[2] + getitem_633 = split_63[3] + getitem_634 = split_63[4] + getitem_635 = split_63[5] + getitem_636 = split_63[6] + getitem_637 = split_63[7]; split_63 = None + cat_55 = torch.ops.aten.cat.default([getitem_630, getitem_631, getitem_632, getitem_633, getitem_634, getitem_635, getitem_636, getitem_637], 1); getitem_630 = getitem_631 = getitem_632 = getitem_633 = getitem_634 = getitem_635 = getitem_636 = getitem_637 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 32, '0'); convert_element_type_452 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_996 = torch.ops.aten.view.default(cat_55, [16384, 4096]); cat_55 = None + mm_95 = torch.ops.aten.mm.default(view_996, permute_151); permute_151 = None + view_997 = torch.ops.aten.view.default(mm_95, [2, 8192, 1792]) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); convert_element_type_455 = sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16) + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 32, '0'); convert_element_type_457 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + mm_96 = torch.ops.aten.mm.default(view_996, permute_152); view_996 = permute_152 = None + view_1004 = torch.ops.aten.view.default(mm_96, [2, 8192, 1792]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_1004); convert_element_type_456 = view_1004 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 32, '0'); convert_element_type_460 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + view_1011 = torch.ops.aten.view.default(mul_111, [16384, 1792]); mul_111 = None + mm_97 = torch.ops.aten.mm.default(view_1011, permute_153); view_1011 = permute_153 = None + view_1012 = torch.ops.aten.view.default(mm_97, [2, 8192, 4096]); mm_97 = None + split_64 = torch.ops.aten.split.Tensor(view_1012, 1024, 1); view_1012 = None + getitem_638 = split_64[0] + getitem_639 = split_64[1] + getitem_640 = split_64[2] + getitem_641 = split_64[3] + getitem_642 = split_64[4] + getitem_643 = split_64[5] + getitem_644 = split_64[6] + getitem_645 = split_64[7]; split_64 = None + cat_56 = torch.ops.aten.cat.default([getitem_638, getitem_639, getitem_640, getitem_641, getitem_642, getitem_643, getitem_644, getitem_645]); getitem_638 = getitem_639 = getitem_640 = getitem_641 = getitem_642 = getitem_643 = getitem_644 = getitem_645 = None + reduce_scatter_tensor_28 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_56, 'sum', 8, '1'); cat_56 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + add_55 = torch.ops.aten.add.Tensor(add_53, wait_tensor_183); add_53 = wait_tensor_183 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 32, '0'); convert_element_type_463 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = rsqrt_28 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_184); mul_112 = wait_tensor_184 = None + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_465, 8, '1'); convert_element_type_465 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + split_65 = torch.ops.aten.split.Tensor(wait_tensor_185, 2); wait_tensor_185 = None + getitem_646 = split_65[0] + getitem_647 = split_65[1] + getitem_648 = split_65[2] + getitem_649 = split_65[3] + getitem_650 = split_65[4] + getitem_651 = split_65[5] + getitem_652 = split_65[6] + getitem_653 = split_65[7]; split_65 = None + cat_57 = torch.ops.aten.cat.default([getitem_646, getitem_647, getitem_648, getitem_649, getitem_650, getitem_651, getitem_652, getitem_653], 1); getitem_646 = getitem_647 = getitem_648 = getitem_649 = getitem_650 = getitem_651 = getitem_652 = getitem_653 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 32, '0'); convert_element_type_466 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + view_1023 = torch.ops.aten.view.default(cat_57, [16384, 4096]); cat_57 = None + mm_98 = torch.ops.aten.mm.default(view_1023, permute_154); permute_154 = None + view_1024 = torch.ops.aten.view.default(mm_98, [2, 8192, 512]) + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 32, '0'); convert_element_type_469 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + mm_99 = torch.ops.aten.mm.default(view_1023, permute_155); permute_155 = None + view_1031 = torch.ops.aten.view.default(mm_99, [2, 8192, 128]); mm_99 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 32, '0'); convert_element_type_472 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_100 = torch.ops.aten.mm.default(view_1023, permute_156); view_1023 = permute_156 = None + view_1038 = torch.ops.aten.view.default(mm_100, [2, 8192, 128]) + view_1040 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1041 = torch.ops.aten.view.default(view_1031, [2, 8192, -1, 128]); view_1031 = None + view_1042 = torch.ops.aten.view.default(view_1038, [2, 8192, -1, 128]); view_1038 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_1040, torch.float32); view_1040 = None + view_1043 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 4, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_1043); view_1043 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_1041, torch.float32); view_1041 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 1, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_37); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_1046 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 4, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_37); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_1047 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 1, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_1047, torch.bfloat16); view_1047 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 1, 4, 128]); unsqueeze_28 = None + view_1048 = torch.ops.aten.view.default(expand_28, [2, 8192, 4, 128]); expand_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_1042, 3); view_1042 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 1, 4, 128]); unsqueeze_29 = None + view_1049 = torch.ops.aten.view.default(expand_29, [2, 8192, 4, 128]); expand_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_1048, [0, 2, 1, 3]); view_1048 = None + permute_159 = torch.ops.aten.permute.default(view_1049, [0, 2, 1, 3]); view_1049 = None + _scaled_dot_product_cudnn_attention_14 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_157, permute_158, permute_159, None, True, 0.0, True); permute_157 = permute_158 = permute_159 = None + getitem_654 = _scaled_dot_product_cudnn_attention_14[0] + getitem_655 = _scaled_dot_product_cudnn_attention_14[1] + getitem_660 = _scaled_dot_product_cudnn_attention_14[6] + getitem_661 = _scaled_dot_product_cudnn_attention_14[7]; _scaled_dot_product_cudnn_attention_14 = None + permute_160 = torch.ops.aten.permute.default(getitem_654, [0, 2, 1, 3]) + view_1050 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 32, '0'); convert_element_type_479 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + view_1056 = torch.ops.aten.view.default(view_1050, [16384, 512]); view_1050 = None + mm_101 = torch.ops.aten.mm.default(view_1056, permute_161); view_1056 = permute_161 = None + view_1057 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + split_66 = torch.ops.aten.split.Tensor(view_1057, 1024, 1); view_1057 = None + getitem_663 = split_66[0] + getitem_664 = split_66[1] + getitem_665 = split_66[2] + getitem_666 = split_66[3] + getitem_667 = split_66[4] + getitem_668 = split_66[5] + getitem_669 = split_66[6] + getitem_670 = split_66[7]; split_66 = None + cat_58 = torch.ops.aten.cat.default([getitem_663, getitem_664, getitem_665, getitem_666, getitem_667, getitem_668, getitem_669, getitem_670]); getitem_663 = getitem_664 = getitem_665 = getitem_666 = getitem_667 = getitem_668 = getitem_669 = getitem_670 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_58, 'sum', 8, '1'); cat_58 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29) + add_57 = torch.ops.aten.add.Tensor(add_55, wait_tensor_190); wait_tensor_190 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 32, '0'); convert_element_type_482 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = rsqrt_29 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_191); mul_116 = wait_tensor_191 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_484, 8, '1'); convert_element_type_484 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + split_67 = torch.ops.aten.split.Tensor(wait_tensor_192, 2); wait_tensor_192 = None + getitem_671 = split_67[0] + getitem_672 = split_67[1] + getitem_673 = split_67[2] + getitem_674 = split_67[3] + getitem_675 = split_67[4] + getitem_676 = split_67[5] + getitem_677 = split_67[6] + getitem_678 = split_67[7]; split_67 = None + cat_59 = torch.ops.aten.cat.default([getitem_671, getitem_672, getitem_673, getitem_674, getitem_675, getitem_676, getitem_677, getitem_678], 1); getitem_671 = getitem_672 = getitem_673 = getitem_674 = getitem_675 = getitem_676 = getitem_677 = getitem_678 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 32, '0'); convert_element_type_485 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + view_1068 = torch.ops.aten.view.default(cat_59, [16384, 4096]); cat_59 = None + mm_102 = torch.ops.aten.mm.default(view_1068, permute_162); permute_162 = None + view_1069 = torch.ops.aten.view.default(mm_102, [2, 8192, 1792]) + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_1069, torch.float32); view_1069 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); convert_element_type_488 = sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 32, '0'); convert_element_type_490 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + mm_103 = torch.ops.aten.mm.default(view_1068, permute_163); view_1068 = permute_163 = None + view_1076 = torch.ops.aten.view.default(mm_103, [2, 8192, 1792]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_1076); convert_element_type_489 = view_1076 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 32, '0'); convert_element_type_493 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_195, [1, 0]); wait_tensor_195 = None + view_1083 = torch.ops.aten.view.default(mul_119, [16384, 1792]); mul_119 = None + mm_104 = torch.ops.aten.mm.default(view_1083, permute_164); view_1083 = permute_164 = None + view_1084 = torch.ops.aten.view.default(mm_104, [2, 8192, 4096]); mm_104 = None + split_68 = torch.ops.aten.split.Tensor(view_1084, 1024, 1); view_1084 = None + getitem_679 = split_68[0] + getitem_680 = split_68[1] + getitem_681 = split_68[2] + getitem_682 = split_68[3] + getitem_683 = split_68[4] + getitem_684 = split_68[5] + getitem_685 = split_68[6] + getitem_686 = split_68[7]; split_68 = None + cat_60 = torch.ops.aten.cat.default([getitem_679, getitem_680, getitem_681, getitem_682, getitem_683, getitem_684, getitem_685, getitem_686]); getitem_679 = getitem_680 = getitem_681 = getitem_682 = getitem_683 = getitem_684 = getitem_685 = getitem_686 = None + reduce_scatter_tensor_30 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_60, 'sum', 8, '1'); cat_60 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + add_59 = torch.ops.aten.add.Tensor(add_57, wait_tensor_196); add_57 = wait_tensor_196 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 32, '0'); convert_element_type_496 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32) + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = rsqrt_30 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_197); mul_120 = wait_tensor_197 = None + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_498, 8, '1'); convert_element_type_498 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + split_69 = torch.ops.aten.split.Tensor(wait_tensor_198, 2); wait_tensor_198 = None + getitem_687 = split_69[0] + getitem_688 = split_69[1] + getitem_689 = split_69[2] + getitem_690 = split_69[3] + getitem_691 = split_69[4] + getitem_692 = split_69[5] + getitem_693 = split_69[6] + getitem_694 = split_69[7]; split_69 = None + cat_61 = torch.ops.aten.cat.default([getitem_687, getitem_688, getitem_689, getitem_690, getitem_691, getitem_692, getitem_693, getitem_694], 1); getitem_687 = getitem_688 = getitem_689 = getitem_690 = getitem_691 = getitem_692 = getitem_693 = getitem_694 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 32, '0'); convert_element_type_499 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + view_1095 = torch.ops.aten.view.default(cat_61, [16384, 4096]); cat_61 = None + mm_105 = torch.ops.aten.mm.default(view_1095, permute_165); permute_165 = None + view_1096 = torch.ops.aten.view.default(mm_105, [2, 8192, 512]) + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 32, '0'); convert_element_type_502 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + mm_106 = torch.ops.aten.mm.default(view_1095, permute_166); permute_166 = None + view_1103 = torch.ops.aten.view.default(mm_106, [2, 8192, 128]); mm_106 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16) + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 32, '0'); convert_element_type_505 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_107 = torch.ops.aten.mm.default(view_1095, permute_167); view_1095 = permute_167 = None + view_1110 = torch.ops.aten.view.default(mm_107, [2, 8192, 128]) + view_1112 = torch.ops.aten.view.default(view_1096, [2, 8192, -1, 128]); view_1096 = None + view_1113 = torch.ops.aten.view.default(view_1103, [2, 8192, -1, 128]); view_1103 = None + view_1114 = torch.ops.aten.view.default(view_1110, [2, 8192, -1, 128]); view_1110 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_1112, torch.float32); view_1112 = None + view_1115 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 4, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_1115); view_1115 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_1113, torch.float32); view_1113 = None + view_1116 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 1, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_1116); view_1116 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_37); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_1118 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 4, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_37); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_1119 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 1, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_1118, torch.bfloat16); view_1118 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 1, 4, 128]); unsqueeze_30 = None + view_1120 = torch.ops.aten.view.default(expand_30, [2, 8192, 4, 128]); expand_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_1114, 3); view_1114 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 1, 4, 128]); unsqueeze_31 = None + view_1121 = torch.ops.aten.view.default(expand_31, [2, 8192, 4, 128]); expand_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_1120, [0, 2, 1, 3]); view_1120 = None + permute_170 = torch.ops.aten.permute.default(view_1121, [0, 2, 1, 3]); view_1121 = None + _scaled_dot_product_cudnn_attention_15 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_168, permute_169, permute_170, None, True, 0.0, True); permute_168 = permute_169 = permute_170 = None + getitem_695 = _scaled_dot_product_cudnn_attention_15[0] + getitem_696 = _scaled_dot_product_cudnn_attention_15[1] + getitem_701 = _scaled_dot_product_cudnn_attention_15[6] + getitem_702 = _scaled_dot_product_cudnn_attention_15[7]; _scaled_dot_product_cudnn_attention_15 = None + permute_171 = torch.ops.aten.permute.default(getitem_695, [0, 2, 1, 3]) + view_1122 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 32, '0'); convert_element_type_512 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + view_1128 = torch.ops.aten.view.default(view_1122, [16384, 512]); view_1122 = None + mm_108 = torch.ops.aten.mm.default(view_1128, permute_172); view_1128 = permute_172 = None + view_1129 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + split_70 = torch.ops.aten.split.Tensor(view_1129, 1024, 1); view_1129 = None + getitem_704 = split_70[0] + getitem_705 = split_70[1] + getitem_706 = split_70[2] + getitem_707 = split_70[3] + getitem_708 = split_70[4] + getitem_709 = split_70[5] + getitem_710 = split_70[6] + getitem_711 = split_70[7]; split_70 = None + cat_62 = torch.ops.aten.cat.default([getitem_704, getitem_705, getitem_706, getitem_707, getitem_708, getitem_709, getitem_710, getitem_711]); getitem_704 = getitem_705 = getitem_706 = getitem_707 = getitem_708 = getitem_709 = getitem_710 = getitem_711 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_62, 'sum', 8, '1'); cat_62 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31) + add_61 = torch.ops.aten.add.Tensor(add_59, wait_tensor_203); wait_tensor_203 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 32, '0'); convert_element_type_515 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32) + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = rsqrt_31 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_204); mul_124 = wait_tensor_204 = None + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_517, 8, '1'); convert_element_type_517 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + split_71 = torch.ops.aten.split.Tensor(wait_tensor_205, 2); wait_tensor_205 = None + getitem_712 = split_71[0] + getitem_713 = split_71[1] + getitem_714 = split_71[2] + getitem_715 = split_71[3] + getitem_716 = split_71[4] + getitem_717 = split_71[5] + getitem_718 = split_71[6] + getitem_719 = split_71[7]; split_71 = None + cat_63 = torch.ops.aten.cat.default([getitem_712, getitem_713, getitem_714, getitem_715, getitem_716, getitem_717, getitem_718, getitem_719], 1); getitem_712 = getitem_713 = getitem_714 = getitem_715 = getitem_716 = getitem_717 = getitem_718 = getitem_719 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 32, '0'); convert_element_type_518 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + view_1140 = torch.ops.aten.view.default(cat_63, [16384, 4096]); cat_63 = None + mm_109 = torch.ops.aten.mm.default(view_1140, permute_173); permute_173 = None + view_1141 = torch.ops.aten.view.default(mm_109, [2, 8192, 1792]) + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_1141, torch.float32); view_1141 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); convert_element_type_521 = sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 32, '0'); convert_element_type_523 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + mm_110 = torch.ops.aten.mm.default(view_1140, permute_174); view_1140 = permute_174 = None + view_1148 = torch.ops.aten.view.default(mm_110, [2, 8192, 1792]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_1148); convert_element_type_522 = view_1148 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 32, '0'); convert_element_type_526 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_208, [1, 0]); wait_tensor_208 = None + view_1155 = torch.ops.aten.view.default(mul_127, [16384, 1792]); mul_127 = None + mm_111 = torch.ops.aten.mm.default(view_1155, permute_175); view_1155 = permute_175 = None + view_1156 = torch.ops.aten.view.default(mm_111, [2, 8192, 4096]); mm_111 = None + split_72 = torch.ops.aten.split.Tensor(view_1156, 1024, 1); view_1156 = None + getitem_720 = split_72[0] + getitem_721 = split_72[1] + getitem_722 = split_72[2] + getitem_723 = split_72[3] + getitem_724 = split_72[4] + getitem_725 = split_72[5] + getitem_726 = split_72[6] + getitem_727 = split_72[7]; split_72 = None + cat_64 = torch.ops.aten.cat.default([getitem_720, getitem_721, getitem_722, getitem_723, getitem_724, getitem_725, getitem_726, getitem_727]); getitem_720 = getitem_721 = getitem_722 = getitem_723 = getitem_724 = getitem_725 = getitem_726 = getitem_727 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_64, 'sum', 8, '1'); cat_64 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + add_63 = torch.ops.aten.add.Tensor(add_61, wait_tensor_209); add_61 = wait_tensor_209 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 32, '0'); convert_element_type_529 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = rsqrt_32 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_210); mul_128 = wait_tensor_210 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 8, '1'); convert_element_type_531 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + split_73 = torch.ops.aten.split.Tensor(wait_tensor_211, 2); wait_tensor_211 = None + getitem_728 = split_73[0] + getitem_729 = split_73[1] + getitem_730 = split_73[2] + getitem_731 = split_73[3] + getitem_732 = split_73[4] + getitem_733 = split_73[5] + getitem_734 = split_73[6] + getitem_735 = split_73[7]; split_73 = None + cat_65 = torch.ops.aten.cat.default([getitem_728, getitem_729, getitem_730, getitem_731, getitem_732, getitem_733, getitem_734, getitem_735], 1); getitem_728 = getitem_729 = getitem_730 = getitem_731 = getitem_732 = getitem_733 = getitem_734 = getitem_735 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 32, '0'); convert_element_type_532 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_1167 = torch.ops.aten.view.default(cat_65, [16384, 4096]); cat_65 = None + mm_112 = torch.ops.aten.mm.default(view_1167, permute_176); permute_176 = None + view_1168 = torch.ops.aten.view.default(mm_112, [2, 8192, 512]) + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 32, '0'); convert_element_type_535 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_213, [1, 0]); wait_tensor_213 = None + mm_113 = torch.ops.aten.mm.default(view_1167, permute_177); permute_177 = None + view_1175 = torch.ops.aten.view.default(mm_113, [2, 8192, 128]); mm_113 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 32, '0'); convert_element_type_538 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + mm_114 = torch.ops.aten.mm.default(view_1167, permute_178); view_1167 = permute_178 = None + view_1182 = torch.ops.aten.view.default(mm_114, [2, 8192, 128]) + view_1184 = torch.ops.aten.view.default(view_1168, [2, 8192, -1, 128]); view_1168 = None + view_1185 = torch.ops.aten.view.default(view_1175, [2, 8192, -1, 128]); view_1175 = None + view_1186 = torch.ops.aten.view.default(view_1182, [2, 8192, -1, 128]); view_1182 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_1184, torch.float32); view_1184 = None + view_1187 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 4, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1187); view_1187 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_1185, torch.float32); view_1185 = None + view_1188 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 1, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1188); view_1188 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_37); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_1190 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 4, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_37); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_1191 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 1, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_1190, torch.bfloat16); view_1190 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_1191, torch.bfloat16); view_1191 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 1, 4, 128]); unsqueeze_32 = None + view_1192 = torch.ops.aten.view.default(expand_32, [2, 8192, 4, 128]); expand_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_1186, 3); view_1186 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 1, 4, 128]); unsqueeze_33 = None + view_1193 = torch.ops.aten.view.default(expand_33, [2, 8192, 4, 128]); expand_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_1192, [0, 2, 1, 3]); view_1192 = None + permute_181 = torch.ops.aten.permute.default(view_1193, [0, 2, 1, 3]); view_1193 = None + _scaled_dot_product_cudnn_attention_16 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_179, permute_180, permute_181, None, True, 0.0, True); permute_179 = permute_180 = permute_181 = None + getitem_736 = _scaled_dot_product_cudnn_attention_16[0] + getitem_737 = _scaled_dot_product_cudnn_attention_16[1] + getitem_742 = _scaled_dot_product_cudnn_attention_16[6] + getitem_743 = _scaled_dot_product_cudnn_attention_16[7]; _scaled_dot_product_cudnn_attention_16 = None + permute_182 = torch.ops.aten.permute.default(getitem_736, [0, 2, 1, 3]) + view_1194 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 32, '0'); convert_element_type_545 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + view_1200 = torch.ops.aten.view.default(view_1194, [16384, 512]); view_1194 = None + mm_115 = torch.ops.aten.mm.default(view_1200, permute_183); view_1200 = permute_183 = None + view_1201 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + split_74 = torch.ops.aten.split.Tensor(view_1201, 1024, 1); view_1201 = None + getitem_745 = split_74[0] + getitem_746 = split_74[1] + getitem_747 = split_74[2] + getitem_748 = split_74[3] + getitem_749 = split_74[4] + getitem_750 = split_74[5] + getitem_751 = split_74[6] + getitem_752 = split_74[7]; split_74 = None + cat_66 = torch.ops.aten.cat.default([getitem_745, getitem_746, getitem_747, getitem_748, getitem_749, getitem_750, getitem_751, getitem_752]); getitem_745 = getitem_746 = getitem_747 = getitem_748 = getitem_749 = getitem_750 = getitem_751 = getitem_752 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_66, 'sum', 8, '1'); cat_66 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33) + add_65 = torch.ops.aten.add.Tensor(add_63, wait_tensor_216); wait_tensor_216 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 32, '0'); convert_element_type_548 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = rsqrt_33 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_217); mul_132 = wait_tensor_217 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_550, 8, '1'); convert_element_type_550 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + split_75 = torch.ops.aten.split.Tensor(wait_tensor_218, 2); wait_tensor_218 = None + getitem_753 = split_75[0] + getitem_754 = split_75[1] + getitem_755 = split_75[2] + getitem_756 = split_75[3] + getitem_757 = split_75[4] + getitem_758 = split_75[5] + getitem_759 = split_75[6] + getitem_760 = split_75[7]; split_75 = None + cat_67 = torch.ops.aten.cat.default([getitem_753, getitem_754, getitem_755, getitem_756, getitem_757, getitem_758, getitem_759, getitem_760], 1); getitem_753 = getitem_754 = getitem_755 = getitem_756 = getitem_757 = getitem_758 = getitem_759 = getitem_760 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 32, '0'); convert_element_type_551 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + view_1212 = torch.ops.aten.view.default(cat_67, [16384, 4096]); cat_67 = None + mm_116 = torch.ops.aten.mm.default(view_1212, permute_184); permute_184 = None + view_1213 = torch.ops.aten.view.default(mm_116, [2, 8192, 1792]) + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_1213, torch.float32); view_1213 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); convert_element_type_554 = sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 32, '0'); convert_element_type_556 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_117 = torch.ops.aten.mm.default(view_1212, permute_185); view_1212 = permute_185 = None + view_1220 = torch.ops.aten.view.default(mm_117, [2, 8192, 1792]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_1220); convert_element_type_555 = view_1220 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 32, '0'); convert_element_type_559 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_1227 = torch.ops.aten.view.default(mul_135, [16384, 1792]); mul_135 = None + mm_118 = torch.ops.aten.mm.default(view_1227, permute_186); view_1227 = permute_186 = None + view_1228 = torch.ops.aten.view.default(mm_118, [2, 8192, 4096]); mm_118 = None + split_76 = torch.ops.aten.split.Tensor(view_1228, 1024, 1); view_1228 = None + getitem_761 = split_76[0] + getitem_762 = split_76[1] + getitem_763 = split_76[2] + getitem_764 = split_76[3] + getitem_765 = split_76[4] + getitem_766 = split_76[5] + getitem_767 = split_76[6] + getitem_768 = split_76[7]; split_76 = None + cat_68 = torch.ops.aten.cat.default([getitem_761, getitem_762, getitem_763, getitem_764, getitem_765, getitem_766, getitem_767, getitem_768]); getitem_761 = getitem_762 = getitem_763 = getitem_764 = getitem_765 = getitem_766 = getitem_767 = getitem_768 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_68, 'sum', 8, '1'); cat_68 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + add_67 = torch.ops.aten.add.Tensor(add_65, wait_tensor_222); add_65 = wait_tensor_222 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 32, '0'); convert_element_type_562 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = rsqrt_34 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_223); mul_136 = wait_tensor_223 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 8, '1'); convert_element_type_564 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + split_77 = torch.ops.aten.split.Tensor(wait_tensor_224, 2); wait_tensor_224 = None + getitem_769 = split_77[0] + getitem_770 = split_77[1] + getitem_771 = split_77[2] + getitem_772 = split_77[3] + getitem_773 = split_77[4] + getitem_774 = split_77[5] + getitem_775 = split_77[6] + getitem_776 = split_77[7]; split_77 = None + cat_69 = torch.ops.aten.cat.default([getitem_769, getitem_770, getitem_771, getitem_772, getitem_773, getitem_774, getitem_775, getitem_776], 1); getitem_769 = getitem_770 = getitem_771 = getitem_772 = getitem_773 = getitem_774 = getitem_775 = getitem_776 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16) + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 32, '0'); convert_element_type_565 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_1239 = torch.ops.aten.view.default(cat_69, [16384, 4096]); cat_69 = None + mm_119 = torch.ops.aten.mm.default(view_1239, permute_187); permute_187 = None + view_1240 = torch.ops.aten.view.default(mm_119, [2, 8192, 512]) + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 32, '0'); convert_element_type_568 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + mm_120 = torch.ops.aten.mm.default(view_1239, permute_188); permute_188 = None + view_1247 = torch.ops.aten.view.default(mm_120, [2, 8192, 128]); mm_120 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 32, '0'); convert_element_type_571 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + mm_121 = torch.ops.aten.mm.default(view_1239, permute_189); view_1239 = permute_189 = None + view_1254 = torch.ops.aten.view.default(mm_121, [2, 8192, 128]) + view_1256 = torch.ops.aten.view.default(view_1240, [2, 8192, -1, 128]); view_1240 = None + view_1257 = torch.ops.aten.view.default(view_1247, [2, 8192, -1, 128]); view_1247 = None + view_1258 = torch.ops.aten.view.default(view_1254, [2, 8192, -1, 128]); view_1254 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_1256, torch.float32); view_1256 = None + view_1259 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 4, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1259); view_1259 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_1257, torch.float32); view_1257 = None + view_1260 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 1, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1260); view_1260 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_37); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_1262 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 4, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_37); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_1263 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 1, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_1262, torch.bfloat16); view_1262 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_1263, torch.bfloat16); view_1263 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 1, 4, 128]); unsqueeze_34 = None + view_1264 = torch.ops.aten.view.default(expand_34, [2, 8192, 4, 128]); expand_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_1258, 3); view_1258 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 1, 4, 128]); unsqueeze_35 = None + view_1265 = torch.ops.aten.view.default(expand_35, [2, 8192, 4, 128]); expand_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_1264, [0, 2, 1, 3]); view_1264 = None + permute_192 = torch.ops.aten.permute.default(view_1265, [0, 2, 1, 3]); view_1265 = None + _scaled_dot_product_cudnn_attention_17 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_190, permute_191, permute_192, None, True, 0.0, True); permute_190 = permute_191 = permute_192 = None + getitem_777 = _scaled_dot_product_cudnn_attention_17[0] + getitem_778 = _scaled_dot_product_cudnn_attention_17[1] + getitem_783 = _scaled_dot_product_cudnn_attention_17[6] + getitem_784 = _scaled_dot_product_cudnn_attention_17[7]; _scaled_dot_product_cudnn_attention_17 = None + permute_193 = torch.ops.aten.permute.default(getitem_777, [0, 2, 1, 3]) + view_1266 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 32, '0'); convert_element_type_578 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + view_1272 = torch.ops.aten.view.default(view_1266, [16384, 512]); view_1266 = None + mm_122 = torch.ops.aten.mm.default(view_1272, permute_194); view_1272 = permute_194 = None + view_1273 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + split_78 = torch.ops.aten.split.Tensor(view_1273, 1024, 1); view_1273 = None + getitem_786 = split_78[0] + getitem_787 = split_78[1] + getitem_788 = split_78[2] + getitem_789 = split_78[3] + getitem_790 = split_78[4] + getitem_791 = split_78[5] + getitem_792 = split_78[6] + getitem_793 = split_78[7]; split_78 = None + cat_70 = torch.ops.aten.cat.default([getitem_786, getitem_787, getitem_788, getitem_789, getitem_790, getitem_791, getitem_792, getitem_793]); getitem_786 = getitem_787 = getitem_788 = getitem_789 = getitem_790 = getitem_791 = getitem_792 = getitem_793 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_70, 'sum', 8, '1'); cat_70 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35) + add_69 = torch.ops.aten.add.Tensor(add_67, wait_tensor_229); wait_tensor_229 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 32, '0'); convert_element_type_581 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = rsqrt_35 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_230); mul_140 = wait_tensor_230 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_583, 8, '1'); convert_element_type_583 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + split_79 = torch.ops.aten.split.Tensor(wait_tensor_231, 2); wait_tensor_231 = None + getitem_794 = split_79[0] + getitem_795 = split_79[1] + getitem_796 = split_79[2] + getitem_797 = split_79[3] + getitem_798 = split_79[4] + getitem_799 = split_79[5] + getitem_800 = split_79[6] + getitem_801 = split_79[7]; split_79 = None + cat_71 = torch.ops.aten.cat.default([getitem_794, getitem_795, getitem_796, getitem_797, getitem_798, getitem_799, getitem_800, getitem_801], 1); getitem_794 = getitem_795 = getitem_796 = getitem_797 = getitem_798 = getitem_799 = getitem_800 = getitem_801 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 32, '0'); convert_element_type_584 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + view_1284 = torch.ops.aten.view.default(cat_71, [16384, 4096]); cat_71 = None + mm_123 = torch.ops.aten.mm.default(view_1284, permute_195); permute_195 = None + view_1285 = torch.ops.aten.view.default(mm_123, [2, 8192, 1792]) + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_1285, torch.float32); view_1285 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); convert_element_type_587 = sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 32, '0'); convert_element_type_589 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_124 = torch.ops.aten.mm.default(view_1284, permute_196); view_1284 = permute_196 = None + view_1292 = torch.ops.aten.view.default(mm_124, [2, 8192, 1792]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_1292); convert_element_type_588 = view_1292 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 32, '0'); convert_element_type_592 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + view_1299 = torch.ops.aten.view.default(mul_143, [16384, 1792]); mul_143 = None + mm_125 = torch.ops.aten.mm.default(view_1299, permute_197); view_1299 = permute_197 = None + view_1300 = torch.ops.aten.view.default(mm_125, [2, 8192, 4096]); mm_125 = None + split_80 = torch.ops.aten.split.Tensor(view_1300, 1024, 1); view_1300 = None + getitem_802 = split_80[0] + getitem_803 = split_80[1] + getitem_804 = split_80[2] + getitem_805 = split_80[3] + getitem_806 = split_80[4] + getitem_807 = split_80[5] + getitem_808 = split_80[6] + getitem_809 = split_80[7]; split_80 = None + cat_72 = torch.ops.aten.cat.default([getitem_802, getitem_803, getitem_804, getitem_805, getitem_806, getitem_807, getitem_808, getitem_809]); getitem_802 = getitem_803 = getitem_804 = getitem_805 = getitem_806 = getitem_807 = getitem_808 = getitem_809 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_72, 'sum', 8, '1'); cat_72 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + add_71 = torch.ops.aten.add.Tensor(add_69, wait_tensor_235); add_69 = wait_tensor_235 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 32, '0'); convert_element_type_595 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32) + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = rsqrt_36 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_236); mul_144 = wait_tensor_236 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_597, 8, '1'); convert_element_type_597 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + split_81 = torch.ops.aten.split.Tensor(wait_tensor_237, 2); wait_tensor_237 = None + getitem_810 = split_81[0] + getitem_811 = split_81[1] + getitem_812 = split_81[2] + getitem_813 = split_81[3] + getitem_814 = split_81[4] + getitem_815 = split_81[5] + getitem_816 = split_81[6] + getitem_817 = split_81[7]; split_81 = None + cat_73 = torch.ops.aten.cat.default([getitem_810, getitem_811, getitem_812, getitem_813, getitem_814, getitem_815, getitem_816, getitem_817], 1); getitem_810 = getitem_811 = getitem_812 = getitem_813 = getitem_814 = getitem_815 = getitem_816 = getitem_817 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 32, '0'); convert_element_type_598 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + view_1311 = torch.ops.aten.view.default(cat_73, [16384, 4096]); cat_73 = None + mm_126 = torch.ops.aten.mm.default(view_1311, permute_198); permute_198 = None + view_1312 = torch.ops.aten.view.default(mm_126, [2, 8192, 512]) + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 32, '0'); convert_element_type_601 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + mm_127 = torch.ops.aten.mm.default(view_1311, permute_199); permute_199 = None + view_1319 = torch.ops.aten.view.default(mm_127, [2, 8192, 128]); mm_127 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 32, '0'); convert_element_type_604 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + mm_128 = torch.ops.aten.mm.default(view_1311, permute_200); view_1311 = permute_200 = None + view_1326 = torch.ops.aten.view.default(mm_128, [2, 8192, 128]) + view_1328 = torch.ops.aten.view.default(view_1312, [2, 8192, -1, 128]); view_1312 = None + view_1329 = torch.ops.aten.view.default(view_1319, [2, 8192, -1, 128]); view_1319 = None + view_1330 = torch.ops.aten.view.default(view_1326, [2, 8192, -1, 128]); view_1326 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_1328, torch.float32); view_1328 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 4, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1331); view_1331 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_1329, torch.float32); view_1329 = None + view_1332 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 1, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1332); view_1332 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_37); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_1334 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 4, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_37); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_1335 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 1, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_1334, torch.bfloat16); view_1334 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_1335, torch.bfloat16); view_1335 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 1, 4, 128]); unsqueeze_36 = None + view_1336 = torch.ops.aten.view.default(expand_36, [2, 8192, 4, 128]); expand_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_1330, 3); view_1330 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 1, 4, 128]); unsqueeze_37 = None + view_1337 = torch.ops.aten.view.default(expand_37, [2, 8192, 4, 128]); expand_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_1336, [0, 2, 1, 3]); view_1336 = None + permute_203 = torch.ops.aten.permute.default(view_1337, [0, 2, 1, 3]); view_1337 = None + _scaled_dot_product_cudnn_attention_18 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_201, permute_202, permute_203, None, True, 0.0, True); permute_201 = permute_202 = permute_203 = None + getitem_818 = _scaled_dot_product_cudnn_attention_18[0] + getitem_819 = _scaled_dot_product_cudnn_attention_18[1] + getitem_824 = _scaled_dot_product_cudnn_attention_18[6] + getitem_825 = _scaled_dot_product_cudnn_attention_18[7]; _scaled_dot_product_cudnn_attention_18 = None + permute_204 = torch.ops.aten.permute.default(getitem_818, [0, 2, 1, 3]) + view_1338 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 32, '0'); convert_element_type_611 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + view_1344 = torch.ops.aten.view.default(view_1338, [16384, 512]); view_1338 = None + mm_129 = torch.ops.aten.mm.default(view_1344, permute_205); view_1344 = permute_205 = None + view_1345 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + split_82 = torch.ops.aten.split.Tensor(view_1345, 1024, 1); view_1345 = None + getitem_827 = split_82[0] + getitem_828 = split_82[1] + getitem_829 = split_82[2] + getitem_830 = split_82[3] + getitem_831 = split_82[4] + getitem_832 = split_82[5] + getitem_833 = split_82[6] + getitem_834 = split_82[7]; split_82 = None + cat_74 = torch.ops.aten.cat.default([getitem_827, getitem_828, getitem_829, getitem_830, getitem_831, getitem_832, getitem_833, getitem_834]); getitem_827 = getitem_828 = getitem_829 = getitem_830 = getitem_831 = getitem_832 = getitem_833 = getitem_834 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_74, 'sum', 8, '1'); cat_74 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37) + add_73 = torch.ops.aten.add.Tensor(add_71, wait_tensor_242); wait_tensor_242 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 32, '0'); convert_element_type_614 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = rsqrt_37 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_243); mul_148 = wait_tensor_243 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_616, 8, '1'); convert_element_type_616 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + split_83 = torch.ops.aten.split.Tensor(wait_tensor_244, 2); wait_tensor_244 = None + getitem_835 = split_83[0] + getitem_836 = split_83[1] + getitem_837 = split_83[2] + getitem_838 = split_83[3] + getitem_839 = split_83[4] + getitem_840 = split_83[5] + getitem_841 = split_83[6] + getitem_842 = split_83[7]; split_83 = None + cat_75 = torch.ops.aten.cat.default([getitem_835, getitem_836, getitem_837, getitem_838, getitem_839, getitem_840, getitem_841, getitem_842], 1); getitem_835 = getitem_836 = getitem_837 = getitem_838 = getitem_839 = getitem_840 = getitem_841 = getitem_842 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 32, '0'); convert_element_type_617 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + view_1356 = torch.ops.aten.view.default(cat_75, [16384, 4096]); cat_75 = None + mm_130 = torch.ops.aten.mm.default(view_1356, permute_206); permute_206 = None + view_1357 = torch.ops.aten.view.default(mm_130, [2, 8192, 1792]) + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_1357, torch.float32); view_1357 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); convert_element_type_620 = sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 32, '0'); convert_element_type_622 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_131 = torch.ops.aten.mm.default(view_1356, permute_207); view_1356 = permute_207 = None + view_1364 = torch.ops.aten.view.default(mm_131, [2, 8192, 1792]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_1364); convert_element_type_621 = view_1364 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16) + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 32, '0'); convert_element_type_625 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + view_1371 = torch.ops.aten.view.default(mul_151, [16384, 1792]); mul_151 = None + mm_132 = torch.ops.aten.mm.default(view_1371, permute_208); view_1371 = permute_208 = None + view_1372 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + split_84 = torch.ops.aten.split.Tensor(view_1372, 1024, 1); view_1372 = None + getitem_843 = split_84[0] + getitem_844 = split_84[1] + getitem_845 = split_84[2] + getitem_846 = split_84[3] + getitem_847 = split_84[4] + getitem_848 = split_84[5] + getitem_849 = split_84[6] + getitem_850 = split_84[7]; split_84 = None + cat_76 = torch.ops.aten.cat.default([getitem_843, getitem_844, getitem_845, getitem_846, getitem_847, getitem_848, getitem_849, getitem_850]); getitem_843 = getitem_844 = getitem_845 = getitem_846 = getitem_847 = getitem_848 = getitem_849 = getitem_850 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_76, 'sum', 8, '1'); cat_76 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + add_75 = torch.ops.aten.add.Tensor(add_73, wait_tensor_248); add_73 = wait_tensor_248 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 32, '0'); convert_element_type_628 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32) + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = rsqrt_38 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_249); mul_152 = wait_tensor_249 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_630, 8, '1'); convert_element_type_630 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + split_85 = torch.ops.aten.split.Tensor(wait_tensor_250, 2); wait_tensor_250 = None + getitem_851 = split_85[0] + getitem_852 = split_85[1] + getitem_853 = split_85[2] + getitem_854 = split_85[3] + getitem_855 = split_85[4] + getitem_856 = split_85[5] + getitem_857 = split_85[6] + getitem_858 = split_85[7]; split_85 = None + cat_77 = torch.ops.aten.cat.default([getitem_851, getitem_852, getitem_853, getitem_854, getitem_855, getitem_856, getitem_857, getitem_858], 1); getitem_851 = getitem_852 = getitem_853 = getitem_854 = getitem_855 = getitem_856 = getitem_857 = getitem_858 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 32, '0'); convert_element_type_631 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + view_1383 = torch.ops.aten.view.default(cat_77, [16384, 4096]); cat_77 = None + mm_133 = torch.ops.aten.mm.default(view_1383, permute_209); permute_209 = None + view_1384 = torch.ops.aten.view.default(mm_133, [2, 8192, 512]) + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 32, '0'); convert_element_type_634 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + mm_134 = torch.ops.aten.mm.default(view_1383, permute_210); permute_210 = None + view_1391 = torch.ops.aten.view.default(mm_134, [2, 8192, 128]); mm_134 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 32, '0'); convert_element_type_637 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_253, [1, 0]); wait_tensor_253 = None + mm_135 = torch.ops.aten.mm.default(view_1383, permute_211); view_1383 = permute_211 = None + view_1398 = torch.ops.aten.view.default(mm_135, [2, 8192, 128]) + view_1400 = torch.ops.aten.view.default(view_1384, [2, 8192, -1, 128]); view_1384 = None + view_1401 = torch.ops.aten.view.default(view_1391, [2, 8192, -1, 128]); view_1391 = None + view_1402 = torch.ops.aten.view.default(view_1398, [2, 8192, -1, 128]); view_1398 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_1400, torch.float32); view_1400 = None + view_1403 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 4, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1403); view_1403 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_1401, torch.float32); view_1401 = None + view_1404 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 1, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1404); view_1404 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_37); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_1406 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 4, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_37); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_1407 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 1, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_1406, torch.bfloat16); view_1406 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_1407, torch.bfloat16); view_1407 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 1, 4, 128]); unsqueeze_38 = None + view_1408 = torch.ops.aten.view.default(expand_38, [2, 8192, 4, 128]); expand_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_1402, 3); view_1402 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 1, 4, 128]); unsqueeze_39 = None + view_1409 = torch.ops.aten.view.default(expand_39, [2, 8192, 4, 128]); expand_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_1408, [0, 2, 1, 3]); view_1408 = None + permute_214 = torch.ops.aten.permute.default(view_1409, [0, 2, 1, 3]); view_1409 = None + _scaled_dot_product_cudnn_attention_19 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_212, permute_213, permute_214, None, True, 0.0, True); permute_212 = permute_213 = permute_214 = None + getitem_859 = _scaled_dot_product_cudnn_attention_19[0] + getitem_860 = _scaled_dot_product_cudnn_attention_19[1] + getitem_865 = _scaled_dot_product_cudnn_attention_19[6] + getitem_866 = _scaled_dot_product_cudnn_attention_19[7]; _scaled_dot_product_cudnn_attention_19 = None + permute_215 = torch.ops.aten.permute.default(getitem_859, [0, 2, 1, 3]) + view_1410 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 32, '0'); convert_element_type_644 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + view_1416 = torch.ops.aten.view.default(view_1410, [16384, 512]); view_1410 = None + mm_136 = torch.ops.aten.mm.default(view_1416, permute_216); view_1416 = permute_216 = None + view_1417 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + split_86 = torch.ops.aten.split.Tensor(view_1417, 1024, 1); view_1417 = None + getitem_868 = split_86[0] + getitem_869 = split_86[1] + getitem_870 = split_86[2] + getitem_871 = split_86[3] + getitem_872 = split_86[4] + getitem_873 = split_86[5] + getitem_874 = split_86[6] + getitem_875 = split_86[7]; split_86 = None + cat_78 = torch.ops.aten.cat.default([getitem_868, getitem_869, getitem_870, getitem_871, getitem_872, getitem_873, getitem_874, getitem_875]); getitem_868 = getitem_869 = getitem_870 = getitem_871 = getitem_872 = getitem_873 = getitem_874 = getitem_875 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_78, 'sum', 8, '1'); cat_78 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39) + add_77 = torch.ops.aten.add.Tensor(add_75, wait_tensor_255); wait_tensor_255 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 32, '0'); convert_element_type_647 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = rsqrt_39 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_256); mul_156 = wait_tensor_256 = None + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_649, 8, '1'); convert_element_type_649 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + split_87 = torch.ops.aten.split.Tensor(wait_tensor_257, 2); wait_tensor_257 = None + getitem_876 = split_87[0] + getitem_877 = split_87[1] + getitem_878 = split_87[2] + getitem_879 = split_87[3] + getitem_880 = split_87[4] + getitem_881 = split_87[5] + getitem_882 = split_87[6] + getitem_883 = split_87[7]; split_87 = None + cat_79 = torch.ops.aten.cat.default([getitem_876, getitem_877, getitem_878, getitem_879, getitem_880, getitem_881, getitem_882, getitem_883], 1); getitem_876 = getitem_877 = getitem_878 = getitem_879 = getitem_880 = getitem_881 = getitem_882 = getitem_883 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 32, '0'); convert_element_type_650 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_258, [1, 0]); wait_tensor_258 = None + view_1428 = torch.ops.aten.view.default(cat_79, [16384, 4096]); cat_79 = None + mm_137 = torch.ops.aten.mm.default(view_1428, permute_217); permute_217 = None + view_1429 = torch.ops.aten.view.default(mm_137, [2, 8192, 1792]) + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_1429, torch.float32); view_1429 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); convert_element_type_653 = sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 32, '0'); convert_element_type_655 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + mm_138 = torch.ops.aten.mm.default(view_1428, permute_218); view_1428 = permute_218 = None + view_1436 = torch.ops.aten.view.default(mm_138, [2, 8192, 1792]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_1436); convert_element_type_654 = view_1436 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 32, '0'); convert_element_type_658 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + view_1443 = torch.ops.aten.view.default(mul_159, [16384, 1792]); mul_159 = None + mm_139 = torch.ops.aten.mm.default(view_1443, permute_219); view_1443 = permute_219 = None + view_1444 = torch.ops.aten.view.default(mm_139, [2, 8192, 4096]); mm_139 = None + split_88 = torch.ops.aten.split.Tensor(view_1444, 1024, 1); view_1444 = None + getitem_884 = split_88[0] + getitem_885 = split_88[1] + getitem_886 = split_88[2] + getitem_887 = split_88[3] + getitem_888 = split_88[4] + getitem_889 = split_88[5] + getitem_890 = split_88[6] + getitem_891 = split_88[7]; split_88 = None + cat_80 = torch.ops.aten.cat.default([getitem_884, getitem_885, getitem_886, getitem_887, getitem_888, getitem_889, getitem_890, getitem_891]); getitem_884 = getitem_885 = getitem_886 = getitem_887 = getitem_888 = getitem_889 = getitem_890 = getitem_891 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_80, 'sum', 8, '1'); cat_80 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + add_79 = torch.ops.aten.add.Tensor(add_77, wait_tensor_261); add_77 = wait_tensor_261 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 32, '0'); convert_element_type_661 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = rsqrt_40 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_262); mul_160 = wait_tensor_262 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_663, 8, '1'); convert_element_type_663 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + split_89 = torch.ops.aten.split.Tensor(wait_tensor_263, 2); wait_tensor_263 = None + getitem_892 = split_89[0] + getitem_893 = split_89[1] + getitem_894 = split_89[2] + getitem_895 = split_89[3] + getitem_896 = split_89[4] + getitem_897 = split_89[5] + getitem_898 = split_89[6] + getitem_899 = split_89[7]; split_89 = None + cat_81 = torch.ops.aten.cat.default([getitem_892, getitem_893, getitem_894, getitem_895, getitem_896, getitem_897, getitem_898, getitem_899], 1); getitem_892 = getitem_893 = getitem_894 = getitem_895 = getitem_896 = getitem_897 = getitem_898 = getitem_899 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 32, '0'); convert_element_type_664 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + view_1455 = torch.ops.aten.view.default(cat_81, [16384, 4096]); cat_81 = None + mm_140 = torch.ops.aten.mm.default(view_1455, permute_220); permute_220 = None + view_1456 = torch.ops.aten.view.default(mm_140, [2, 8192, 512]) + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 32, '0'); convert_element_type_667 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_141 = torch.ops.aten.mm.default(view_1455, permute_221); permute_221 = None + view_1463 = torch.ops.aten.view.default(mm_141, [2, 8192, 128]); mm_141 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 32, '0'); convert_element_type_670 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + mm_142 = torch.ops.aten.mm.default(view_1455, permute_222); view_1455 = permute_222 = None + view_1470 = torch.ops.aten.view.default(mm_142, [2, 8192, 128]) + view_1472 = torch.ops.aten.view.default(view_1456, [2, 8192, -1, 128]); view_1456 = None + view_1473 = torch.ops.aten.view.default(view_1463, [2, 8192, -1, 128]); view_1463 = None + view_1474 = torch.ops.aten.view.default(view_1470, [2, 8192, -1, 128]); view_1470 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_1472, torch.float32); view_1472 = None + view_1475 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 4, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1475); view_1475 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_1473, torch.float32); view_1473 = None + view_1476 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 1, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1476); view_1476 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_37); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_1478 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 4, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_37); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_1479 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 1, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_1478, torch.bfloat16); view_1478 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_1479, torch.bfloat16); view_1479 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 1, 4, 128]); unsqueeze_40 = None + view_1480 = torch.ops.aten.view.default(expand_40, [2, 8192, 4, 128]); expand_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_1474, 3); view_1474 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 1, 4, 128]); unsqueeze_41 = None + view_1481 = torch.ops.aten.view.default(expand_41, [2, 8192, 4, 128]); expand_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_1480, [0, 2, 1, 3]); view_1480 = None + permute_225 = torch.ops.aten.permute.default(view_1481, [0, 2, 1, 3]); view_1481 = None + _scaled_dot_product_cudnn_attention_20 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_223, permute_224, permute_225, None, True, 0.0, True); permute_223 = permute_224 = permute_225 = None + getitem_900 = _scaled_dot_product_cudnn_attention_20[0] + getitem_901 = _scaled_dot_product_cudnn_attention_20[1] + getitem_906 = _scaled_dot_product_cudnn_attention_20[6] + getitem_907 = _scaled_dot_product_cudnn_attention_20[7]; _scaled_dot_product_cudnn_attention_20 = None + permute_226 = torch.ops.aten.permute.default(getitem_900, [0, 2, 1, 3]) + view_1482 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 32, '0'); convert_element_type_677 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + view_1488 = torch.ops.aten.view.default(view_1482, [16384, 512]); view_1482 = None + mm_143 = torch.ops.aten.mm.default(view_1488, permute_227); view_1488 = permute_227 = None + view_1489 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + split_90 = torch.ops.aten.split.Tensor(view_1489, 1024, 1); view_1489 = None + getitem_909 = split_90[0] + getitem_910 = split_90[1] + getitem_911 = split_90[2] + getitem_912 = split_90[3] + getitem_913 = split_90[4] + getitem_914 = split_90[5] + getitem_915 = split_90[6] + getitem_916 = split_90[7]; split_90 = None + cat_82 = torch.ops.aten.cat.default([getitem_909, getitem_910, getitem_911, getitem_912, getitem_913, getitem_914, getitem_915, getitem_916]); getitem_909 = getitem_910 = getitem_911 = getitem_912 = getitem_913 = getitem_914 = getitem_915 = getitem_916 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_82, 'sum', 8, '1'); cat_82 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41) + add_81 = torch.ops.aten.add.Tensor(add_79, wait_tensor_268); wait_tensor_268 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 32, '0'); convert_element_type_680 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = rsqrt_41 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_269); mul_164 = wait_tensor_269 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_682, 8, '1'); convert_element_type_682 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + split_91 = torch.ops.aten.split.Tensor(wait_tensor_270, 2); wait_tensor_270 = None + getitem_917 = split_91[0] + getitem_918 = split_91[1] + getitem_919 = split_91[2] + getitem_920 = split_91[3] + getitem_921 = split_91[4] + getitem_922 = split_91[5] + getitem_923 = split_91[6] + getitem_924 = split_91[7]; split_91 = None + cat_83 = torch.ops.aten.cat.default([getitem_917, getitem_918, getitem_919, getitem_920, getitem_921, getitem_922, getitem_923, getitem_924], 1); getitem_917 = getitem_918 = getitem_919 = getitem_920 = getitem_921 = getitem_922 = getitem_923 = getitem_924 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 32, '0'); convert_element_type_683 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_271, [1, 0]); wait_tensor_271 = None + view_1500 = torch.ops.aten.view.default(cat_83, [16384, 4096]); cat_83 = None + mm_144 = torch.ops.aten.mm.default(view_1500, permute_228); permute_228 = None + view_1501 = torch.ops.aten.view.default(mm_144, [2, 8192, 1792]) + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_1501, torch.float32); view_1501 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); convert_element_type_686 = sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 32, '0'); convert_element_type_688 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + mm_145 = torch.ops.aten.mm.default(view_1500, permute_229); view_1500 = permute_229 = None + view_1508 = torch.ops.aten.view.default(mm_145, [2, 8192, 1792]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_1508); convert_element_type_687 = view_1508 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 32, '0'); convert_element_type_691 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + view_1515 = torch.ops.aten.view.default(mul_167, [16384, 1792]); mul_167 = None + mm_146 = torch.ops.aten.mm.default(view_1515, permute_230); view_1515 = permute_230 = None + view_1516 = torch.ops.aten.view.default(mm_146, [2, 8192, 4096]); mm_146 = None + split_92 = torch.ops.aten.split.Tensor(view_1516, 1024, 1); view_1516 = None + getitem_925 = split_92[0] + getitem_926 = split_92[1] + getitem_927 = split_92[2] + getitem_928 = split_92[3] + getitem_929 = split_92[4] + getitem_930 = split_92[5] + getitem_931 = split_92[6] + getitem_932 = split_92[7]; split_92 = None + cat_84 = torch.ops.aten.cat.default([getitem_925, getitem_926, getitem_927, getitem_928, getitem_929, getitem_930, getitem_931, getitem_932]); getitem_925 = getitem_926 = getitem_927 = getitem_928 = getitem_929 = getitem_930 = getitem_931 = getitem_932 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_84, 'sum', 8, '1'); cat_84 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + add_83 = torch.ops.aten.add.Tensor(add_81, wait_tensor_274); add_81 = wait_tensor_274 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 32, '0'); convert_element_type_694 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32) + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = rsqrt_42 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_275); mul_168 = wait_tensor_275 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_696, 8, '1'); convert_element_type_696 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + split_93 = torch.ops.aten.split.Tensor(wait_tensor_276, 2); wait_tensor_276 = None + getitem_933 = split_93[0] + getitem_934 = split_93[1] + getitem_935 = split_93[2] + getitem_936 = split_93[3] + getitem_937 = split_93[4] + getitem_938 = split_93[5] + getitem_939 = split_93[6] + getitem_940 = split_93[7]; split_93 = None + cat_85 = torch.ops.aten.cat.default([getitem_933, getitem_934, getitem_935, getitem_936, getitem_937, getitem_938, getitem_939, getitem_940], 1); getitem_933 = getitem_934 = getitem_935 = getitem_936 = getitem_937 = getitem_938 = getitem_939 = getitem_940 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 32, '0'); convert_element_type_697 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + view_1527 = torch.ops.aten.view.default(cat_85, [16384, 4096]); cat_85 = None + mm_147 = torch.ops.aten.mm.default(view_1527, permute_231); permute_231 = None + view_1528 = torch.ops.aten.view.default(mm_147, [2, 8192, 512]) + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 32, '0'); convert_element_type_700 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_148 = torch.ops.aten.mm.default(view_1527, permute_232); permute_232 = None + view_1535 = torch.ops.aten.view.default(mm_148, [2, 8192, 128]); mm_148 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 32, '0'); convert_element_type_703 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + mm_149 = torch.ops.aten.mm.default(view_1527, permute_233); view_1527 = permute_233 = None + view_1542 = torch.ops.aten.view.default(mm_149, [2, 8192, 128]) + view_1544 = torch.ops.aten.view.default(view_1528, [2, 8192, -1, 128]); view_1528 = None + view_1545 = torch.ops.aten.view.default(view_1535, [2, 8192, -1, 128]); view_1535 = None + view_1546 = torch.ops.aten.view.default(view_1542, [2, 8192, -1, 128]); view_1542 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_1544, torch.float32); view_1544 = None + view_1547 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 4, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1547); view_1547 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_1545, torch.float32); view_1545 = None + view_1548 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 1, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1548); view_1548 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_37); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_1550 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 4, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_37); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_1551 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 1, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_1550, torch.bfloat16); view_1550 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_1551, torch.bfloat16); view_1551 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 1, 4, 128]); unsqueeze_42 = None + view_1552 = torch.ops.aten.view.default(expand_42, [2, 8192, 4, 128]); expand_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_1546, 3); view_1546 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 1, 4, 128]); unsqueeze_43 = None + view_1553 = torch.ops.aten.view.default(expand_43, [2, 8192, 4, 128]); expand_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_1552, [0, 2, 1, 3]); view_1552 = None + permute_236 = torch.ops.aten.permute.default(view_1553, [0, 2, 1, 3]); view_1553 = None + _scaled_dot_product_cudnn_attention_21 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_234, permute_235, permute_236, None, True, 0.0, True); permute_234 = permute_235 = permute_236 = None + getitem_941 = _scaled_dot_product_cudnn_attention_21[0] + getitem_942 = _scaled_dot_product_cudnn_attention_21[1] + getitem_947 = _scaled_dot_product_cudnn_attention_21[6] + getitem_948 = _scaled_dot_product_cudnn_attention_21[7]; _scaled_dot_product_cudnn_attention_21 = None + permute_237 = torch.ops.aten.permute.default(getitem_941, [0, 2, 1, 3]) + view_1554 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 32, '0'); convert_element_type_710 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_280, [1, 0]); wait_tensor_280 = None + view_1560 = torch.ops.aten.view.default(view_1554, [16384, 512]); view_1554 = None + mm_150 = torch.ops.aten.mm.default(view_1560, permute_238); view_1560 = permute_238 = None + view_1561 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + split_94 = torch.ops.aten.split.Tensor(view_1561, 1024, 1); view_1561 = None + getitem_950 = split_94[0] + getitem_951 = split_94[1] + getitem_952 = split_94[2] + getitem_953 = split_94[3] + getitem_954 = split_94[4] + getitem_955 = split_94[5] + getitem_956 = split_94[6] + getitem_957 = split_94[7]; split_94 = None + cat_86 = torch.ops.aten.cat.default([getitem_950, getitem_951, getitem_952, getitem_953, getitem_954, getitem_955, getitem_956, getitem_957]); getitem_950 = getitem_951 = getitem_952 = getitem_953 = getitem_954 = getitem_955 = getitem_956 = getitem_957 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_86, 'sum', 8, '1'); cat_86 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43) + add_85 = torch.ops.aten.add.Tensor(add_83, wait_tensor_281); wait_tensor_281 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 32, '0'); convert_element_type_713 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = rsqrt_43 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_282); mul_172 = wait_tensor_282 = None + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_715, 8, '1'); convert_element_type_715 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + split_95 = torch.ops.aten.split.Tensor(wait_tensor_283, 2); wait_tensor_283 = None + getitem_958 = split_95[0] + getitem_959 = split_95[1] + getitem_960 = split_95[2] + getitem_961 = split_95[3] + getitem_962 = split_95[4] + getitem_963 = split_95[5] + getitem_964 = split_95[6] + getitem_965 = split_95[7]; split_95 = None + cat_87 = torch.ops.aten.cat.default([getitem_958, getitem_959, getitem_960, getitem_961, getitem_962, getitem_963, getitem_964, getitem_965], 1); getitem_958 = getitem_959 = getitem_960 = getitem_961 = getitem_962 = getitem_963 = getitem_964 = getitem_965 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 32, '0'); convert_element_type_716 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1572 = torch.ops.aten.view.default(cat_87, [16384, 4096]); cat_87 = None + mm_151 = torch.ops.aten.mm.default(view_1572, permute_239); permute_239 = None + view_1573 = torch.ops.aten.view.default(mm_151, [2, 8192, 1792]) + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_1573, torch.float32); view_1573 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); convert_element_type_719 = sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 32, '0'); convert_element_type_721 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + mm_152 = torch.ops.aten.mm.default(view_1572, permute_240); view_1572 = permute_240 = None + view_1580 = torch.ops.aten.view.default(mm_152, [2, 8192, 1792]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_1580); convert_element_type_720 = view_1580 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 32, '0'); convert_element_type_724 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + view_1587 = torch.ops.aten.view.default(mul_175, [16384, 1792]); mul_175 = None + mm_153 = torch.ops.aten.mm.default(view_1587, permute_241); view_1587 = permute_241 = None + view_1588 = torch.ops.aten.view.default(mm_153, [2, 8192, 4096]); mm_153 = None + split_96 = torch.ops.aten.split.Tensor(view_1588, 1024, 1); view_1588 = None + getitem_966 = split_96[0] + getitem_967 = split_96[1] + getitem_968 = split_96[2] + getitem_969 = split_96[3] + getitem_970 = split_96[4] + getitem_971 = split_96[5] + getitem_972 = split_96[6] + getitem_973 = split_96[7]; split_96 = None + cat_88 = torch.ops.aten.cat.default([getitem_966, getitem_967, getitem_968, getitem_969, getitem_970, getitem_971, getitem_972, getitem_973]); getitem_966 = getitem_967 = getitem_968 = getitem_969 = getitem_970 = getitem_971 = getitem_972 = getitem_973 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_88, 'sum', 8, '1'); cat_88 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + add_87 = torch.ops.aten.add.Tensor(add_85, wait_tensor_287); add_85 = wait_tensor_287 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 32, '0'); convert_element_type_727 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32) + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = rsqrt_44 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_288); mul_176 = wait_tensor_288 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_729, 8, '1'); convert_element_type_729 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + split_97 = torch.ops.aten.split.Tensor(wait_tensor_289, 2); wait_tensor_289 = None + getitem_974 = split_97[0] + getitem_975 = split_97[1] + getitem_976 = split_97[2] + getitem_977 = split_97[3] + getitem_978 = split_97[4] + getitem_979 = split_97[5] + getitem_980 = split_97[6] + getitem_981 = split_97[7]; split_97 = None + cat_89 = torch.ops.aten.cat.default([getitem_974, getitem_975, getitem_976, getitem_977, getitem_978, getitem_979, getitem_980, getitem_981], 1); getitem_974 = getitem_975 = getitem_976 = getitem_977 = getitem_978 = getitem_979 = getitem_980 = getitem_981 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 32, '0'); convert_element_type_730 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + view_1599 = torch.ops.aten.view.default(cat_89, [16384, 4096]); cat_89 = None + mm_154 = torch.ops.aten.mm.default(view_1599, permute_242); permute_242 = None + view_1600 = torch.ops.aten.view.default(mm_154, [2, 8192, 512]) + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 32, '0'); convert_element_type_733 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_291, [1, 0]); wait_tensor_291 = None + mm_155 = torch.ops.aten.mm.default(view_1599, permute_243); permute_243 = None + view_1607 = torch.ops.aten.view.default(mm_155, [2, 8192, 128]); mm_155 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 32, '0'); convert_element_type_736 = None + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_292, [1, 0]); wait_tensor_292 = None + mm_156 = torch.ops.aten.mm.default(view_1599, permute_244); view_1599 = permute_244 = None + view_1614 = torch.ops.aten.view.default(mm_156, [2, 8192, 128]) + view_1616 = torch.ops.aten.view.default(view_1600, [2, 8192, -1, 128]); view_1600 = None + view_1617 = torch.ops.aten.view.default(view_1607, [2, 8192, -1, 128]); view_1607 = None + view_1618 = torch.ops.aten.view.default(view_1614, [2, 8192, -1, 128]); view_1614 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_1616, torch.float32); view_1616 = None + view_1619 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 4, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1619); view_1619 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_1617, torch.float32); view_1617 = None + view_1620 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 1, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1620); view_1620 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_37); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_1622 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 4, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_37); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_1623 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 1, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_1622, torch.bfloat16); view_1622 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_1623, torch.bfloat16); view_1623 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 1, 4, 128]); unsqueeze_44 = None + view_1624 = torch.ops.aten.view.default(expand_44, [2, 8192, 4, 128]); expand_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_1618, 3); view_1618 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 1, 4, 128]); unsqueeze_45 = None + view_1625 = torch.ops.aten.view.default(expand_45, [2, 8192, 4, 128]); expand_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_1624, [0, 2, 1, 3]); view_1624 = None + permute_247 = torch.ops.aten.permute.default(view_1625, [0, 2, 1, 3]); view_1625 = None + _scaled_dot_product_cudnn_attention_22 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_245, permute_246, permute_247, None, True, 0.0, True); permute_245 = permute_246 = permute_247 = None + getitem_982 = _scaled_dot_product_cudnn_attention_22[0] + getitem_983 = _scaled_dot_product_cudnn_attention_22[1] + getitem_988 = _scaled_dot_product_cudnn_attention_22[6] + getitem_989 = _scaled_dot_product_cudnn_attention_22[7]; _scaled_dot_product_cudnn_attention_22 = None + permute_248 = torch.ops.aten.permute.default(getitem_982, [0, 2, 1, 3]) + view_1626 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 32, '0'); convert_element_type_743 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_293, [1, 0]); wait_tensor_293 = None + view_1632 = torch.ops.aten.view.default(view_1626, [16384, 512]); view_1626 = None + mm_157 = torch.ops.aten.mm.default(view_1632, permute_249); view_1632 = permute_249 = None + view_1633 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + split_98 = torch.ops.aten.split.Tensor(view_1633, 1024, 1); view_1633 = None + getitem_991 = split_98[0] + getitem_992 = split_98[1] + getitem_993 = split_98[2] + getitem_994 = split_98[3] + getitem_995 = split_98[4] + getitem_996 = split_98[5] + getitem_997 = split_98[6] + getitem_998 = split_98[7]; split_98 = None + cat_90 = torch.ops.aten.cat.default([getitem_991, getitem_992, getitem_993, getitem_994, getitem_995, getitem_996, getitem_997, getitem_998]); getitem_991 = getitem_992 = getitem_993 = getitem_994 = getitem_995 = getitem_996 = getitem_997 = getitem_998 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_90, 'sum', 8, '1'); cat_90 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45) + add_89 = torch.ops.aten.add.Tensor(add_87, wait_tensor_294); wait_tensor_294 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 32, '0'); convert_element_type_746 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = rsqrt_45 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_295); mul_180 = wait_tensor_295 = None + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_748, 8, '1'); convert_element_type_748 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + split_99 = torch.ops.aten.split.Tensor(wait_tensor_296, 2); wait_tensor_296 = None + getitem_999 = split_99[0] + getitem_1000 = split_99[1] + getitem_1001 = split_99[2] + getitem_1002 = split_99[3] + getitem_1003 = split_99[4] + getitem_1004 = split_99[5] + getitem_1005 = split_99[6] + getitem_1006 = split_99[7]; split_99 = None + cat_91 = torch.ops.aten.cat.default([getitem_999, getitem_1000, getitem_1001, getitem_1002, getitem_1003, getitem_1004, getitem_1005, getitem_1006], 1); getitem_999 = getitem_1000 = getitem_1001 = getitem_1002 = getitem_1003 = getitem_1004 = getitem_1005 = getitem_1006 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16) + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 32, '0'); convert_element_type_749 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_297, [1, 0]); wait_tensor_297 = None + view_1644 = torch.ops.aten.view.default(cat_91, [16384, 4096]); cat_91 = None + mm_158 = torch.ops.aten.mm.default(view_1644, permute_250); permute_250 = None + view_1645 = torch.ops.aten.view.default(mm_158, [2, 8192, 1792]) + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_1645, torch.float32); view_1645 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); convert_element_type_752 = sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 32, '0'); convert_element_type_754 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_298, [1, 0]); wait_tensor_298 = None + mm_159 = torch.ops.aten.mm.default(view_1644, permute_251); view_1644 = permute_251 = None + view_1652 = torch.ops.aten.view.default(mm_159, [2, 8192, 1792]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_1652); convert_element_type_753 = view_1652 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 32, '0'); convert_element_type_757 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_299, [1, 0]); wait_tensor_299 = None + view_1659 = torch.ops.aten.view.default(mul_183, [16384, 1792]); mul_183 = None + mm_160 = torch.ops.aten.mm.default(view_1659, permute_252); view_1659 = permute_252 = None + view_1660 = torch.ops.aten.view.default(mm_160, [2, 8192, 4096]); mm_160 = None + split_100 = torch.ops.aten.split.Tensor(view_1660, 1024, 1); view_1660 = None + getitem_1007 = split_100[0] + getitem_1008 = split_100[1] + getitem_1009 = split_100[2] + getitem_1010 = split_100[3] + getitem_1011 = split_100[4] + getitem_1012 = split_100[5] + getitem_1013 = split_100[6] + getitem_1014 = split_100[7]; split_100 = None + cat_92 = torch.ops.aten.cat.default([getitem_1007, getitem_1008, getitem_1009, getitem_1010, getitem_1011, getitem_1012, getitem_1013, getitem_1014]); getitem_1007 = getitem_1008 = getitem_1009 = getitem_1010 = getitem_1011 = getitem_1012 = getitem_1013 = getitem_1014 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_92, 'sum', 8, '1'); cat_92 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + add_91 = torch.ops.aten.add.Tensor(add_89, wait_tensor_300); add_89 = wait_tensor_300 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 32, '0'); convert_element_type_760 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = rsqrt_46 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_301); mul_184 = wait_tensor_301 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_762, 8, '1'); convert_element_type_762 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + split_101 = torch.ops.aten.split.Tensor(wait_tensor_302, 2); wait_tensor_302 = None + getitem_1015 = split_101[0] + getitem_1016 = split_101[1] + getitem_1017 = split_101[2] + getitem_1018 = split_101[3] + getitem_1019 = split_101[4] + getitem_1020 = split_101[5] + getitem_1021 = split_101[6] + getitem_1022 = split_101[7]; split_101 = None + cat_93 = torch.ops.aten.cat.default([getitem_1015, getitem_1016, getitem_1017, getitem_1018, getitem_1019, getitem_1020, getitem_1021, getitem_1022], 1); getitem_1015 = getitem_1016 = getitem_1017 = getitem_1018 = getitem_1019 = getitem_1020 = getitem_1021 = getitem_1022 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 32, '0'); convert_element_type_763 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + view_1671 = torch.ops.aten.view.default(cat_93, [16384, 4096]); cat_93 = None + mm_161 = torch.ops.aten.mm.default(view_1671, permute_253); permute_253 = None + view_1672 = torch.ops.aten.view.default(mm_161, [2, 8192, 512]) + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 32, '0'); convert_element_type_766 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); wait_tensor_304 = None + mm_162 = torch.ops.aten.mm.default(view_1671, permute_254); permute_254 = None + view_1679 = torch.ops.aten.view.default(mm_162, [2, 8192, 128]); mm_162 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 32, '0'); convert_element_type_769 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_305, [1, 0]); wait_tensor_305 = None + mm_163 = torch.ops.aten.mm.default(view_1671, permute_255); view_1671 = permute_255 = None + view_1686 = torch.ops.aten.view.default(mm_163, [2, 8192, 128]) + view_1688 = torch.ops.aten.view.default(view_1672, [2, 8192, -1, 128]); view_1672 = None + view_1689 = torch.ops.aten.view.default(view_1679, [2, 8192, -1, 128]); view_1679 = None + view_1690 = torch.ops.aten.view.default(view_1686, [2, 8192, -1, 128]); view_1686 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_1688, torch.float32); view_1688 = None + view_1691 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 4, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1691); view_1691 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_1689, torch.float32); view_1689 = None + view_1692 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 1, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1692); view_1692 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_37); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_1694 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 4, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_37); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_1695 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 1, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_1694, torch.bfloat16); view_1694 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_1695, torch.bfloat16); view_1695 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 1, 4, 128]); unsqueeze_46 = None + view_1696 = torch.ops.aten.view.default(expand_46, [2, 8192, 4, 128]); expand_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_1690, 3); view_1690 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 1, 4, 128]); unsqueeze_47 = None + view_1697 = torch.ops.aten.view.default(expand_47, [2, 8192, 4, 128]); expand_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_1696, [0, 2, 1, 3]); view_1696 = None + permute_258 = torch.ops.aten.permute.default(view_1697, [0, 2, 1, 3]); view_1697 = None + _scaled_dot_product_cudnn_attention_23 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_256, permute_257, permute_258, None, True, 0.0, True); permute_256 = permute_257 = permute_258 = None + getitem_1023 = _scaled_dot_product_cudnn_attention_23[0] + getitem_1024 = _scaled_dot_product_cudnn_attention_23[1] + getitem_1029 = _scaled_dot_product_cudnn_attention_23[6] + getitem_1030 = _scaled_dot_product_cudnn_attention_23[7]; _scaled_dot_product_cudnn_attention_23 = None + permute_259 = torch.ops.aten.permute.default(getitem_1023, [0, 2, 1, 3]) + view_1698 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 32, '0'); convert_element_type_776 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + view_1704 = torch.ops.aten.view.default(view_1698, [16384, 512]); view_1698 = None + mm_164 = torch.ops.aten.mm.default(view_1704, permute_260); view_1704 = permute_260 = None + view_1705 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + split_102 = torch.ops.aten.split.Tensor(view_1705, 1024, 1); view_1705 = None + getitem_1032 = split_102[0] + getitem_1033 = split_102[1] + getitem_1034 = split_102[2] + getitem_1035 = split_102[3] + getitem_1036 = split_102[4] + getitem_1037 = split_102[5] + getitem_1038 = split_102[6] + getitem_1039 = split_102[7]; split_102 = None + cat_94 = torch.ops.aten.cat.default([getitem_1032, getitem_1033, getitem_1034, getitem_1035, getitem_1036, getitem_1037, getitem_1038, getitem_1039]); getitem_1032 = getitem_1033 = getitem_1034 = getitem_1035 = getitem_1036 = getitem_1037 = getitem_1038 = getitem_1039 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_94, 'sum', 8, '1'); cat_94 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47) + add_93 = torch.ops.aten.add.Tensor(add_91, wait_tensor_307); wait_tensor_307 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 32, '0'); convert_element_type_779 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = rsqrt_47 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_308); mul_188 = wait_tensor_308 = None + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_781, 8, '1'); convert_element_type_781 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + split_103 = torch.ops.aten.split.Tensor(wait_tensor_309, 2); wait_tensor_309 = None + getitem_1040 = split_103[0] + getitem_1041 = split_103[1] + getitem_1042 = split_103[2] + getitem_1043 = split_103[3] + getitem_1044 = split_103[4] + getitem_1045 = split_103[5] + getitem_1046 = split_103[6] + getitem_1047 = split_103[7]; split_103 = None + cat_95 = torch.ops.aten.cat.default([getitem_1040, getitem_1041, getitem_1042, getitem_1043, getitem_1044, getitem_1045, getitem_1046, getitem_1047], 1); getitem_1040 = getitem_1041 = getitem_1042 = getitem_1043 = getitem_1044 = getitem_1045 = getitem_1046 = getitem_1047 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 32, '0'); convert_element_type_782 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + view_1716 = torch.ops.aten.view.default(cat_95, [16384, 4096]); cat_95 = None + mm_165 = torch.ops.aten.mm.default(view_1716, permute_261); permute_261 = None + view_1717 = torch.ops.aten.view.default(mm_165, [2, 8192, 1792]) + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_1717, torch.float32); view_1717 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); convert_element_type_785 = sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 32, '0'); convert_element_type_787 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_311, [1, 0]); wait_tensor_311 = None + mm_166 = torch.ops.aten.mm.default(view_1716, permute_262); view_1716 = permute_262 = None + view_1724 = torch.ops.aten.view.default(mm_166, [2, 8192, 1792]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_1724); convert_element_type_786 = view_1724 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 32, '0'); convert_element_type_790 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_312, [1, 0]); wait_tensor_312 = None + view_1731 = torch.ops.aten.view.default(mul_191, [16384, 1792]); mul_191 = None + mm_167 = torch.ops.aten.mm.default(view_1731, permute_263); view_1731 = permute_263 = None + view_1732 = torch.ops.aten.view.default(mm_167, [2, 8192, 4096]); mm_167 = None + split_104 = torch.ops.aten.split.Tensor(view_1732, 1024, 1); view_1732 = None + getitem_1048 = split_104[0] + getitem_1049 = split_104[1] + getitem_1050 = split_104[2] + getitem_1051 = split_104[3] + getitem_1052 = split_104[4] + getitem_1053 = split_104[5] + getitem_1054 = split_104[6] + getitem_1055 = split_104[7]; split_104 = None + cat_96 = torch.ops.aten.cat.default([getitem_1048, getitem_1049, getitem_1050, getitem_1051, getitem_1052, getitem_1053, getitem_1054, getitem_1055]); getitem_1048 = getitem_1049 = getitem_1050 = getitem_1051 = getitem_1052 = getitem_1053 = getitem_1054 = getitem_1055 = None + reduce_scatter_tensor_48 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_96, 'sum', 8, '1'); cat_96 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + add_95 = torch.ops.aten.add.Tensor(add_93, wait_tensor_313); add_93 = wait_tensor_313 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 32, '0'); convert_element_type_793 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32) + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = rsqrt_48 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_314); mul_192 = wait_tensor_314 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_795, 8, '1'); convert_element_type_795 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + split_105 = torch.ops.aten.split.Tensor(wait_tensor_315, 2); wait_tensor_315 = None + getitem_1056 = split_105[0] + getitem_1057 = split_105[1] + getitem_1058 = split_105[2] + getitem_1059 = split_105[3] + getitem_1060 = split_105[4] + getitem_1061 = split_105[5] + getitem_1062 = split_105[6] + getitem_1063 = split_105[7]; split_105 = None + cat_97 = torch.ops.aten.cat.default([getitem_1056, getitem_1057, getitem_1058, getitem_1059, getitem_1060, getitem_1061, getitem_1062, getitem_1063], 1); getitem_1056 = getitem_1057 = getitem_1058 = getitem_1059 = getitem_1060 = getitem_1061 = getitem_1062 = getitem_1063 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 32, '0'); convert_element_type_796 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_316, [1, 0]); wait_tensor_316 = None + view_1743 = torch.ops.aten.view.default(cat_97, [16384, 4096]); cat_97 = None + mm_168 = torch.ops.aten.mm.default(view_1743, permute_264); permute_264 = None + view_1744 = torch.ops.aten.view.default(mm_168, [2, 8192, 512]) + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 32, '0'); convert_element_type_799 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_317, [1, 0]); wait_tensor_317 = None + mm_169 = torch.ops.aten.mm.default(view_1743, permute_265); permute_265 = None + view_1751 = torch.ops.aten.view.default(mm_169, [2, 8192, 128]); mm_169 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 32, '0'); convert_element_type_802 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_318, [1, 0]); wait_tensor_318 = None + mm_170 = torch.ops.aten.mm.default(view_1743, permute_266); view_1743 = permute_266 = None + view_1758 = torch.ops.aten.view.default(mm_170, [2, 8192, 128]) + view_1760 = torch.ops.aten.view.default(view_1744, [2, 8192, -1, 128]); view_1744 = None + view_1761 = torch.ops.aten.view.default(view_1751, [2, 8192, -1, 128]); view_1751 = None + view_1762 = torch.ops.aten.view.default(view_1758, [2, 8192, -1, 128]); view_1758 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_1760, torch.float32); view_1760 = None + view_1763 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 4, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1763); view_1763 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_1761, torch.float32); view_1761 = None + view_1764 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 1, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1764); view_1764 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_37); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_1766 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 4, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_37); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_1767 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 1, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_1766, torch.bfloat16); view_1766 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_1767, torch.bfloat16); view_1767 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 1, 4, 128]); unsqueeze_48 = None + view_1768 = torch.ops.aten.view.default(expand_48, [2, 8192, 4, 128]); expand_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_1762, 3); view_1762 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 1, 4, 128]); unsqueeze_49 = None + view_1769 = torch.ops.aten.view.default(expand_49, [2, 8192, 4, 128]); expand_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_1768, [0, 2, 1, 3]); view_1768 = None + permute_269 = torch.ops.aten.permute.default(view_1769, [0, 2, 1, 3]); view_1769 = None + _scaled_dot_product_cudnn_attention_24 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_267, permute_268, permute_269, None, True, 0.0, True); permute_267 = permute_268 = permute_269 = None + getitem_1064 = _scaled_dot_product_cudnn_attention_24[0] + getitem_1065 = _scaled_dot_product_cudnn_attention_24[1] + getitem_1070 = _scaled_dot_product_cudnn_attention_24[6] + getitem_1071 = _scaled_dot_product_cudnn_attention_24[7]; _scaled_dot_product_cudnn_attention_24 = None + permute_270 = torch.ops.aten.permute.default(getitem_1064, [0, 2, 1, 3]) + view_1770 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 32, '0'); convert_element_type_809 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_319, [1, 0]); wait_tensor_319 = None + view_1776 = torch.ops.aten.view.default(view_1770, [16384, 512]); view_1770 = None + mm_171 = torch.ops.aten.mm.default(view_1776, permute_271); view_1776 = permute_271 = None + view_1777 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + split_106 = torch.ops.aten.split.Tensor(view_1777, 1024, 1); view_1777 = None + getitem_1073 = split_106[0] + getitem_1074 = split_106[1] + getitem_1075 = split_106[2] + getitem_1076 = split_106[3] + getitem_1077 = split_106[4] + getitem_1078 = split_106[5] + getitem_1079 = split_106[6] + getitem_1080 = split_106[7]; split_106 = None + cat_98 = torch.ops.aten.cat.default([getitem_1073, getitem_1074, getitem_1075, getitem_1076, getitem_1077, getitem_1078, getitem_1079, getitem_1080]); getitem_1073 = getitem_1074 = getitem_1075 = getitem_1076 = getitem_1077 = getitem_1078 = getitem_1079 = getitem_1080 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_98, 'sum', 8, '1'); cat_98 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49) + add_97 = torch.ops.aten.add.Tensor(add_95, wait_tensor_320); wait_tensor_320 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 32, '0'); convert_element_type_812 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = rsqrt_49 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_321); mul_196 = wait_tensor_321 = None + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_814, 8, '1'); convert_element_type_814 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + split_107 = torch.ops.aten.split.Tensor(wait_tensor_322, 2); wait_tensor_322 = None + getitem_1081 = split_107[0] + getitem_1082 = split_107[1] + getitem_1083 = split_107[2] + getitem_1084 = split_107[3] + getitem_1085 = split_107[4] + getitem_1086 = split_107[5] + getitem_1087 = split_107[6] + getitem_1088 = split_107[7]; split_107 = None + cat_99 = torch.ops.aten.cat.default([getitem_1081, getitem_1082, getitem_1083, getitem_1084, getitem_1085, getitem_1086, getitem_1087, getitem_1088], 1); getitem_1081 = getitem_1082 = getitem_1083 = getitem_1084 = getitem_1085 = getitem_1086 = getitem_1087 = getitem_1088 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 32, '0'); convert_element_type_815 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + view_1788 = torch.ops.aten.view.default(cat_99, [16384, 4096]); cat_99 = None + mm_172 = torch.ops.aten.mm.default(view_1788, permute_272); permute_272 = None + view_1789 = torch.ops.aten.view.default(mm_172, [2, 8192, 1792]) + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_1789, torch.float32); view_1789 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); convert_element_type_818 = sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 32, '0'); convert_element_type_820 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + mm_173 = torch.ops.aten.mm.default(view_1788, permute_273); view_1788 = permute_273 = None + view_1796 = torch.ops.aten.view.default(mm_173, [2, 8192, 1792]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_1796); convert_element_type_819 = view_1796 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 32, '0'); convert_element_type_823 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + view_1803 = torch.ops.aten.view.default(mul_199, [16384, 1792]); mul_199 = None + mm_174 = torch.ops.aten.mm.default(view_1803, permute_274); view_1803 = permute_274 = None + view_1804 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = None + split_108 = torch.ops.aten.split.Tensor(view_1804, 1024, 1); view_1804 = None + getitem_1089 = split_108[0] + getitem_1090 = split_108[1] + getitem_1091 = split_108[2] + getitem_1092 = split_108[3] + getitem_1093 = split_108[4] + getitem_1094 = split_108[5] + getitem_1095 = split_108[6] + getitem_1096 = split_108[7]; split_108 = None + cat_100 = torch.ops.aten.cat.default([getitem_1089, getitem_1090, getitem_1091, getitem_1092, getitem_1093, getitem_1094, getitem_1095, getitem_1096]); getitem_1089 = getitem_1090 = getitem_1091 = getitem_1092 = getitem_1093 = getitem_1094 = getitem_1095 = getitem_1096 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_100, 'sum', 8, '1'); cat_100 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + add_99 = torch.ops.aten.add.Tensor(add_97, wait_tensor_326); add_97 = wait_tensor_326 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 32, '0'); convert_element_type_826 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32) + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = rsqrt_50 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_327); mul_200 = wait_tensor_327 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_828, 8, '1'); convert_element_type_828 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + split_109 = torch.ops.aten.split.Tensor(wait_tensor_328, 2); wait_tensor_328 = None + getitem_1097 = split_109[0] + getitem_1098 = split_109[1] + getitem_1099 = split_109[2] + getitem_1100 = split_109[3] + getitem_1101 = split_109[4] + getitem_1102 = split_109[5] + getitem_1103 = split_109[6] + getitem_1104 = split_109[7]; split_109 = None + cat_101 = torch.ops.aten.cat.default([getitem_1097, getitem_1098, getitem_1099, getitem_1100, getitem_1101, getitem_1102, getitem_1103, getitem_1104], 1); getitem_1097 = getitem_1098 = getitem_1099 = getitem_1100 = getitem_1101 = getitem_1102 = getitem_1103 = getitem_1104 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 32, '0'); convert_element_type_829 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_329, [1, 0]); wait_tensor_329 = None + view_1815 = torch.ops.aten.view.default(cat_101, [16384, 4096]); cat_101 = None + mm_175 = torch.ops.aten.mm.default(view_1815, permute_275); permute_275 = None + view_1816 = torch.ops.aten.view.default(mm_175, [2, 8192, 512]) + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 32, '0'); convert_element_type_832 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + mm_176 = torch.ops.aten.mm.default(view_1815, permute_276); permute_276 = None + view_1823 = torch.ops.aten.view.default(mm_176, [2, 8192, 128]); mm_176 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 32, '0'); convert_element_type_835 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + mm_177 = torch.ops.aten.mm.default(view_1815, permute_277); view_1815 = permute_277 = None + view_1830 = torch.ops.aten.view.default(mm_177, [2, 8192, 128]) + view_1832 = torch.ops.aten.view.default(view_1816, [2, 8192, -1, 128]); view_1816 = None + view_1833 = torch.ops.aten.view.default(view_1823, [2, 8192, -1, 128]); view_1823 = None + view_1834 = torch.ops.aten.view.default(view_1830, [2, 8192, -1, 128]); view_1830 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_1832, torch.float32); view_1832 = None + view_1835 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 4, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1835); view_1835 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_1833, torch.float32); view_1833 = None + view_1836 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 1, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1836); view_1836 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_37); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_1838 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 4, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_37); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_1839 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 1, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_1838, torch.bfloat16); view_1838 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_1839, torch.bfloat16); view_1839 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 1, 4, 128]); unsqueeze_50 = None + view_1840 = torch.ops.aten.view.default(expand_50, [2, 8192, 4, 128]); expand_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_1834, 3); view_1834 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 1, 4, 128]); unsqueeze_51 = None + view_1841 = torch.ops.aten.view.default(expand_51, [2, 8192, 4, 128]); expand_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_1840, [0, 2, 1, 3]); view_1840 = None + permute_280 = torch.ops.aten.permute.default(view_1841, [0, 2, 1, 3]); view_1841 = None + _scaled_dot_product_cudnn_attention_25 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_278, permute_279, permute_280, None, True, 0.0, True); permute_278 = permute_279 = permute_280 = None + getitem_1105 = _scaled_dot_product_cudnn_attention_25[0] + getitem_1106 = _scaled_dot_product_cudnn_attention_25[1] + getitem_1111 = _scaled_dot_product_cudnn_attention_25[6] + getitem_1112 = _scaled_dot_product_cudnn_attention_25[7]; _scaled_dot_product_cudnn_attention_25 = None + permute_281 = torch.ops.aten.permute.default(getitem_1105, [0, 2, 1, 3]) + view_1842 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 32, '0'); convert_element_type_842 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_332, [1, 0]); wait_tensor_332 = None + view_1848 = torch.ops.aten.view.default(view_1842, [16384, 512]); view_1842 = None + mm_178 = torch.ops.aten.mm.default(view_1848, permute_282); view_1848 = permute_282 = None + view_1849 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + split_110 = torch.ops.aten.split.Tensor(view_1849, 1024, 1); view_1849 = None + getitem_1114 = split_110[0] + getitem_1115 = split_110[1] + getitem_1116 = split_110[2] + getitem_1117 = split_110[3] + getitem_1118 = split_110[4] + getitem_1119 = split_110[5] + getitem_1120 = split_110[6] + getitem_1121 = split_110[7]; split_110 = None + cat_102 = torch.ops.aten.cat.default([getitem_1114, getitem_1115, getitem_1116, getitem_1117, getitem_1118, getitem_1119, getitem_1120, getitem_1121]); getitem_1114 = getitem_1115 = getitem_1116 = getitem_1117 = getitem_1118 = getitem_1119 = getitem_1120 = getitem_1121 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_102, 'sum', 8, '1'); cat_102 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51) + add_101 = torch.ops.aten.add.Tensor(add_99, wait_tensor_333); wait_tensor_333 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 32, '0'); convert_element_type_845 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = rsqrt_51 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_334); mul_204 = wait_tensor_334 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 8, '1'); convert_element_type_847 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + split_111 = torch.ops.aten.split.Tensor(wait_tensor_335, 2); wait_tensor_335 = None + getitem_1122 = split_111[0] + getitem_1123 = split_111[1] + getitem_1124 = split_111[2] + getitem_1125 = split_111[3] + getitem_1126 = split_111[4] + getitem_1127 = split_111[5] + getitem_1128 = split_111[6] + getitem_1129 = split_111[7]; split_111 = None + cat_103 = torch.ops.aten.cat.default([getitem_1122, getitem_1123, getitem_1124, getitem_1125, getitem_1126, getitem_1127, getitem_1128, getitem_1129], 1); getitem_1122 = getitem_1123 = getitem_1124 = getitem_1125 = getitem_1126 = getitem_1127 = getitem_1128 = getitem_1129 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 32, '0'); convert_element_type_848 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_336, [1, 0]); wait_tensor_336 = None + view_1860 = torch.ops.aten.view.default(cat_103, [16384, 4096]); cat_103 = None + mm_179 = torch.ops.aten.mm.default(view_1860, permute_283); permute_283 = None + view_1861 = torch.ops.aten.view.default(mm_179, [2, 8192, 1792]) + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_1861, torch.float32); view_1861 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); convert_element_type_851 = sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 32, '0'); convert_element_type_853 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_337, [1, 0]); wait_tensor_337 = None + mm_180 = torch.ops.aten.mm.default(view_1860, permute_284); view_1860 = permute_284 = None + view_1868 = torch.ops.aten.view.default(mm_180, [2, 8192, 1792]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_1868); convert_element_type_852 = view_1868 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 32, '0'); convert_element_type_856 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_338, [1, 0]); wait_tensor_338 = None + view_1875 = torch.ops.aten.view.default(mul_207, [16384, 1792]); mul_207 = None + mm_181 = torch.ops.aten.mm.default(view_1875, permute_285); view_1875 = permute_285 = None + view_1876 = torch.ops.aten.view.default(mm_181, [2, 8192, 4096]); mm_181 = None + split_112 = torch.ops.aten.split.Tensor(view_1876, 1024, 1); view_1876 = None + getitem_1130 = split_112[0] + getitem_1131 = split_112[1] + getitem_1132 = split_112[2] + getitem_1133 = split_112[3] + getitem_1134 = split_112[4] + getitem_1135 = split_112[5] + getitem_1136 = split_112[6] + getitem_1137 = split_112[7]; split_112 = None + cat_104 = torch.ops.aten.cat.default([getitem_1130, getitem_1131, getitem_1132, getitem_1133, getitem_1134, getitem_1135, getitem_1136, getitem_1137]); getitem_1130 = getitem_1131 = getitem_1132 = getitem_1133 = getitem_1134 = getitem_1135 = getitem_1136 = getitem_1137 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_104, 'sum', 8, '1'); cat_104 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + add_103 = torch.ops.aten.add.Tensor(add_101, wait_tensor_339); add_101 = wait_tensor_339 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 32, '0'); convert_element_type_859 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = rsqrt_52 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_340); mul_208 = wait_tensor_340 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_861, 8, '1'); convert_element_type_861 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + split_113 = torch.ops.aten.split.Tensor(wait_tensor_341, 2); wait_tensor_341 = None + getitem_1138 = split_113[0] + getitem_1139 = split_113[1] + getitem_1140 = split_113[2] + getitem_1141 = split_113[3] + getitem_1142 = split_113[4] + getitem_1143 = split_113[5] + getitem_1144 = split_113[6] + getitem_1145 = split_113[7]; split_113 = None + cat_105 = torch.ops.aten.cat.default([getitem_1138, getitem_1139, getitem_1140, getitem_1141, getitem_1142, getitem_1143, getitem_1144, getitem_1145], 1); getitem_1138 = getitem_1139 = getitem_1140 = getitem_1141 = getitem_1142 = getitem_1143 = getitem_1144 = getitem_1145 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 32, '0'); convert_element_type_862 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_342, [1, 0]); wait_tensor_342 = None + view_1887 = torch.ops.aten.view.default(cat_105, [16384, 4096]); cat_105 = None + mm_182 = torch.ops.aten.mm.default(view_1887, permute_286); permute_286 = None + view_1888 = torch.ops.aten.view.default(mm_182, [2, 8192, 512]) + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 32, '0'); convert_element_type_865 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_343, [1, 0]); wait_tensor_343 = None + mm_183 = torch.ops.aten.mm.default(view_1887, permute_287); permute_287 = None + view_1895 = torch.ops.aten.view.default(mm_183, [2, 8192, 128]); mm_183 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_291 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 32, '0'); convert_element_type_868 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_291); all_gather_into_tensor_291 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + mm_184 = torch.ops.aten.mm.default(view_1887, permute_288); view_1887 = permute_288 = None + view_1902 = torch.ops.aten.view.default(mm_184, [2, 8192, 128]) + view_1904 = torch.ops.aten.view.default(view_1888, [2, 8192, -1, 128]); view_1888 = None + view_1905 = torch.ops.aten.view.default(view_1895, [2, 8192, -1, 128]); view_1895 = None + view_1906 = torch.ops.aten.view.default(view_1902, [2, 8192, -1, 128]); view_1902 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_1904, torch.float32); view_1904 = None + view_1907 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 4, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_1907); view_1907 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_1905, torch.float32); view_1905 = None + view_1908 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 1, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1908); view_1908 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_37); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_1910 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 4, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_37); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_1911 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 1, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_1910, torch.bfloat16); view_1910 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_1911, torch.bfloat16); view_1911 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 1, 4, 128]); unsqueeze_52 = None + view_1912 = torch.ops.aten.view.default(expand_52, [2, 8192, 4, 128]); expand_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_1906, 3); view_1906 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 1, 4, 128]); unsqueeze_53 = None + view_1913 = torch.ops.aten.view.default(expand_53, [2, 8192, 4, 128]); expand_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_1912, [0, 2, 1, 3]); view_1912 = None + permute_291 = torch.ops.aten.permute.default(view_1913, [0, 2, 1, 3]); view_1913 = None + _scaled_dot_product_cudnn_attention_26 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_289, permute_290, permute_291, None, True, 0.0, True); permute_289 = permute_290 = permute_291 = None + getitem_1146 = _scaled_dot_product_cudnn_attention_26[0] + getitem_1147 = _scaled_dot_product_cudnn_attention_26[1] + getitem_1152 = _scaled_dot_product_cudnn_attention_26[6] + getitem_1153 = _scaled_dot_product_cudnn_attention_26[7]; _scaled_dot_product_cudnn_attention_26 = None + permute_292 = torch.ops.aten.permute.default(getitem_1146, [0, 2, 1, 3]) + view_1914 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_292 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 32, '0'); convert_element_type_875 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_292); all_gather_into_tensor_292 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); wait_tensor_345 = None + view_1920 = torch.ops.aten.view.default(view_1914, [16384, 512]); view_1914 = None + mm_185 = torch.ops.aten.mm.default(view_1920, permute_293); view_1920 = permute_293 = None + view_1921 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + split_114 = torch.ops.aten.split.Tensor(view_1921, 1024, 1); view_1921 = None + getitem_1155 = split_114[0] + getitem_1156 = split_114[1] + getitem_1157 = split_114[2] + getitem_1158 = split_114[3] + getitem_1159 = split_114[4] + getitem_1160 = split_114[5] + getitem_1161 = split_114[6] + getitem_1162 = split_114[7]; split_114 = None + cat_106 = torch.ops.aten.cat.default([getitem_1155, getitem_1156, getitem_1157, getitem_1158, getitem_1159, getitem_1160, getitem_1161, getitem_1162]); getitem_1155 = getitem_1156 = getitem_1157 = getitem_1158 = getitem_1159 = getitem_1160 = getitem_1161 = getitem_1162 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_106, 'sum', 8, '1'); cat_106 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53) + add_105 = torch.ops.aten.add.Tensor(add_103, wait_tensor_346); wait_tensor_346 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_293 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 32, '0'); convert_element_type_878 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_293); all_gather_into_tensor_293 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = rsqrt_53 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_347); mul_212 = wait_tensor_347 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + all_gather_into_tensor_294 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_880, 8, '1'); convert_element_type_880 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_294); all_gather_into_tensor_294 = None + split_115 = torch.ops.aten.split.Tensor(wait_tensor_348, 2); wait_tensor_348 = None + getitem_1163 = split_115[0] + getitem_1164 = split_115[1] + getitem_1165 = split_115[2] + getitem_1166 = split_115[3] + getitem_1167 = split_115[4] + getitem_1168 = split_115[5] + getitem_1169 = split_115[6] + getitem_1170 = split_115[7]; split_115 = None + cat_107 = torch.ops.aten.cat.default([getitem_1163, getitem_1164, getitem_1165, getitem_1166, getitem_1167, getitem_1168, getitem_1169, getitem_1170], 1); getitem_1163 = getitem_1164 = getitem_1165 = getitem_1166 = getitem_1167 = getitem_1168 = getitem_1169 = getitem_1170 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_295 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 32, '0'); convert_element_type_881 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_295); all_gather_into_tensor_295 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_349, [1, 0]); wait_tensor_349 = None + view_1932 = torch.ops.aten.view.default(cat_107, [16384, 4096]); cat_107 = None + mm_186 = torch.ops.aten.mm.default(view_1932, permute_294); permute_294 = None + view_1933 = torch.ops.aten.view.default(mm_186, [2, 8192, 1792]) + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_1933, torch.float32); view_1933 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); convert_element_type_884 = sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + all_gather_into_tensor_296 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 32, '0'); convert_element_type_886 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_296); all_gather_into_tensor_296 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_350, [1, 0]); wait_tensor_350 = None + mm_187 = torch.ops.aten.mm.default(view_1932, permute_295); view_1932 = permute_295 = None + view_1940 = torch.ops.aten.view.default(mm_187, [2, 8192, 1792]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_1940); convert_element_type_885 = view_1940 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 32, '0'); convert_element_type_889 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + view_1947 = torch.ops.aten.view.default(mul_215, [16384, 1792]); mul_215 = None + mm_188 = torch.ops.aten.mm.default(view_1947, permute_296); view_1947 = permute_296 = None + view_1948 = torch.ops.aten.view.default(mm_188, [2, 8192, 4096]); mm_188 = None + split_116 = torch.ops.aten.split.Tensor(view_1948, 1024, 1); view_1948 = None + getitem_1171 = split_116[0] + getitem_1172 = split_116[1] + getitem_1173 = split_116[2] + getitem_1174 = split_116[3] + getitem_1175 = split_116[4] + getitem_1176 = split_116[5] + getitem_1177 = split_116[6] + getitem_1178 = split_116[7]; split_116 = None + cat_108 = torch.ops.aten.cat.default([getitem_1171, getitem_1172, getitem_1173, getitem_1174, getitem_1175, getitem_1176, getitem_1177, getitem_1178]); getitem_1171 = getitem_1172 = getitem_1173 = getitem_1174 = getitem_1175 = getitem_1176 = getitem_1177 = getitem_1178 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_108, 'sum', 8, '1'); cat_108 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + add_107 = torch.ops.aten.add.Tensor(add_105, wait_tensor_352); add_105 = wait_tensor_352 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 32, '0'); convert_element_type_892 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32) + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = rsqrt_54 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_353); mul_216 = wait_tensor_353 = None + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_894, 8, '1'); convert_element_type_894 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + split_117 = torch.ops.aten.split.Tensor(wait_tensor_354, 2); wait_tensor_354 = None + getitem_1179 = split_117[0] + getitem_1180 = split_117[1] + getitem_1181 = split_117[2] + getitem_1182 = split_117[3] + getitem_1183 = split_117[4] + getitem_1184 = split_117[5] + getitem_1185 = split_117[6] + getitem_1186 = split_117[7]; split_117 = None + cat_109 = torch.ops.aten.cat.default([getitem_1179, getitem_1180, getitem_1181, getitem_1182, getitem_1183, getitem_1184, getitem_1185, getitem_1186], 1); getitem_1179 = getitem_1180 = getitem_1181 = getitem_1182 = getitem_1183 = getitem_1184 = getitem_1185 = getitem_1186 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 32, '0'); convert_element_type_895 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_355, [1, 0]); wait_tensor_355 = None + view_1959 = torch.ops.aten.view.default(cat_109, [16384, 4096]); cat_109 = None + mm_189 = torch.ops.aten.mm.default(view_1959, permute_297); permute_297 = None + view_1960 = torch.ops.aten.view.default(mm_189, [2, 8192, 512]) + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 32, '0'); convert_element_type_898 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_356, [1, 0]); wait_tensor_356 = None + mm_190 = torch.ops.aten.mm.default(view_1959, permute_298); permute_298 = None + view_1967 = torch.ops.aten.view.default(mm_190, [2, 8192, 128]); mm_190 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 32, '0'); convert_element_type_901 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_357, [1, 0]); wait_tensor_357 = None + mm_191 = torch.ops.aten.mm.default(view_1959, permute_299); view_1959 = permute_299 = None + view_1974 = torch.ops.aten.view.default(mm_191, [2, 8192, 128]) + view_1976 = torch.ops.aten.view.default(view_1960, [2, 8192, -1, 128]); view_1960 = None + view_1977 = torch.ops.aten.view.default(view_1967, [2, 8192, -1, 128]); view_1967 = None + view_1978 = torch.ops.aten.view.default(view_1974, [2, 8192, -1, 128]); view_1974 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_1976, torch.float32); view_1976 = None + view_1979 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 4, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1979); view_1979 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_1977, torch.float32); view_1977 = None + view_1980 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 1, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1980); view_1980 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_37); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_1982 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 4, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_37); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_1983 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 1, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_1982, torch.bfloat16); view_1982 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_1983, torch.bfloat16); view_1983 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 1, 4, 128]); unsqueeze_54 = None + view_1984 = torch.ops.aten.view.default(expand_54, [2, 8192, 4, 128]); expand_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_1978, 3); view_1978 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 1, 4, 128]); unsqueeze_55 = None + view_1985 = torch.ops.aten.view.default(expand_55, [2, 8192, 4, 128]); expand_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_1984, [0, 2, 1, 3]); view_1984 = None + permute_302 = torch.ops.aten.permute.default(view_1985, [0, 2, 1, 3]); view_1985 = None + _scaled_dot_product_cudnn_attention_27 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_300, permute_301, permute_302, None, True, 0.0, True); permute_300 = permute_301 = permute_302 = None + getitem_1187 = _scaled_dot_product_cudnn_attention_27[0] + getitem_1188 = _scaled_dot_product_cudnn_attention_27[1] + getitem_1193 = _scaled_dot_product_cudnn_attention_27[6] + getitem_1194 = _scaled_dot_product_cudnn_attention_27[7]; _scaled_dot_product_cudnn_attention_27 = None + permute_303 = torch.ops.aten.permute.default(getitem_1187, [0, 2, 1, 3]) + view_1986 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 32, '0'); convert_element_type_908 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_358, [1, 0]); wait_tensor_358 = None + view_1992 = torch.ops.aten.view.default(view_1986, [16384, 512]); view_1986 = None + mm_192 = torch.ops.aten.mm.default(view_1992, permute_304); view_1992 = permute_304 = None + view_1993 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + split_118 = torch.ops.aten.split.Tensor(view_1993, 1024, 1); view_1993 = None + getitem_1196 = split_118[0] + getitem_1197 = split_118[1] + getitem_1198 = split_118[2] + getitem_1199 = split_118[3] + getitem_1200 = split_118[4] + getitem_1201 = split_118[5] + getitem_1202 = split_118[6] + getitem_1203 = split_118[7]; split_118 = None + cat_110 = torch.ops.aten.cat.default([getitem_1196, getitem_1197, getitem_1198, getitem_1199, getitem_1200, getitem_1201, getitem_1202, getitem_1203]); getitem_1196 = getitem_1197 = getitem_1198 = getitem_1199 = getitem_1200 = getitem_1201 = getitem_1202 = getitem_1203 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_110, 'sum', 8, '1'); cat_110 = None + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55) + add_109 = torch.ops.aten.add.Tensor(add_107, wait_tensor_359); wait_tensor_359 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 32, '0'); convert_element_type_911 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = rsqrt_55 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_360); mul_220 = wait_tensor_360 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_913, 8, '1'); convert_element_type_913 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + split_119 = torch.ops.aten.split.Tensor(wait_tensor_361, 2); wait_tensor_361 = None + getitem_1204 = split_119[0] + getitem_1205 = split_119[1] + getitem_1206 = split_119[2] + getitem_1207 = split_119[3] + getitem_1208 = split_119[4] + getitem_1209 = split_119[5] + getitem_1210 = split_119[6] + getitem_1211 = split_119[7]; split_119 = None + cat_111 = torch.ops.aten.cat.default([getitem_1204, getitem_1205, getitem_1206, getitem_1207, getitem_1208, getitem_1209, getitem_1210, getitem_1211], 1); getitem_1204 = getitem_1205 = getitem_1206 = getitem_1207 = getitem_1208 = getitem_1209 = getitem_1210 = getitem_1211 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 32, '0'); convert_element_type_914 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_362, [1, 0]); wait_tensor_362 = None + view_2004 = torch.ops.aten.view.default(cat_111, [16384, 4096]); cat_111 = None + mm_193 = torch.ops.aten.mm.default(view_2004, permute_305); permute_305 = None + view_2005 = torch.ops.aten.view.default(mm_193, [2, 8192, 1792]) + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_2005, torch.float32); view_2005 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); convert_element_type_917 = sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16) + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 32, '0'); convert_element_type_919 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_363, [1, 0]); wait_tensor_363 = None + mm_194 = torch.ops.aten.mm.default(view_2004, permute_306); view_2004 = permute_306 = None + view_2012 = torch.ops.aten.view.default(mm_194, [2, 8192, 1792]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_2012); convert_element_type_918 = view_2012 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_308 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 32, '0'); convert_element_type_922 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_308); all_gather_into_tensor_308 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_364, [1, 0]); wait_tensor_364 = None + view_2019 = torch.ops.aten.view.default(mul_223, [16384, 1792]); mul_223 = None + mm_195 = torch.ops.aten.mm.default(view_2019, permute_307); view_2019 = permute_307 = None + view_2020 = torch.ops.aten.view.default(mm_195, [2, 8192, 4096]); mm_195 = None + split_120 = torch.ops.aten.split.Tensor(view_2020, 1024, 1); view_2020 = None + getitem_1212 = split_120[0] + getitem_1213 = split_120[1] + getitem_1214 = split_120[2] + getitem_1215 = split_120[3] + getitem_1216 = split_120[4] + getitem_1217 = split_120[5] + getitem_1218 = split_120[6] + getitem_1219 = split_120[7]; split_120 = None + cat_112 = torch.ops.aten.cat.default([getitem_1212, getitem_1213, getitem_1214, getitem_1215, getitem_1216, getitem_1217, getitem_1218, getitem_1219]); getitem_1212 = getitem_1213 = getitem_1214 = getitem_1215 = getitem_1216 = getitem_1217 = getitem_1218 = getitem_1219 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_112, 'sum', 8, '1'); cat_112 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + add_111 = torch.ops.aten.add.Tensor(add_109, wait_tensor_365); add_109 = wait_tensor_365 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16) + all_gather_into_tensor_309 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 32, '0'); convert_element_type_925 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_309); all_gather_into_tensor_309 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32) + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = rsqrt_56 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_366); mul_224 = wait_tensor_366 = None + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + all_gather_into_tensor_310 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_927, 8, '1'); convert_element_type_927 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_310); all_gather_into_tensor_310 = None + split_121 = torch.ops.aten.split.Tensor(wait_tensor_367, 2); wait_tensor_367 = None + getitem_1220 = split_121[0] + getitem_1221 = split_121[1] + getitem_1222 = split_121[2] + getitem_1223 = split_121[3] + getitem_1224 = split_121[4] + getitem_1225 = split_121[5] + getitem_1226 = split_121[6] + getitem_1227 = split_121[7]; split_121 = None + cat_113 = torch.ops.aten.cat.default([getitem_1220, getitem_1221, getitem_1222, getitem_1223, getitem_1224, getitem_1225, getitem_1226, getitem_1227], 1); getitem_1220 = getitem_1221 = getitem_1222 = getitem_1223 = getitem_1224 = getitem_1225 = getitem_1226 = getitem_1227 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + all_gather_into_tensor_311 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 32, '0'); convert_element_type_928 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_311); all_gather_into_tensor_311 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_368, [1, 0]); wait_tensor_368 = None + view_2031 = torch.ops.aten.view.default(cat_113, [16384, 4096]); cat_113 = None + mm_196 = torch.ops.aten.mm.default(view_2031, permute_308); permute_308 = None + view_2032 = torch.ops.aten.view.default(mm_196, [2, 8192, 512]) + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_312 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 32, '0'); convert_element_type_931 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_312); all_gather_into_tensor_312 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_369, [1, 0]); wait_tensor_369 = None + mm_197 = torch.ops.aten.mm.default(view_2031, permute_309); permute_309 = None + view_2039 = torch.ops.aten.view.default(mm_197, [2, 8192, 128]); mm_197 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_313 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 32, '0'); convert_element_type_934 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_313); all_gather_into_tensor_313 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_370, [1, 0]); wait_tensor_370 = None + mm_198 = torch.ops.aten.mm.default(view_2031, permute_310); view_2031 = permute_310 = None + view_2046 = torch.ops.aten.view.default(mm_198, [2, 8192, 128]) + view_2048 = torch.ops.aten.view.default(view_2032, [2, 8192, -1, 128]); view_2032 = None + view_2049 = torch.ops.aten.view.default(view_2039, [2, 8192, -1, 128]); view_2039 = None + view_2050 = torch.ops.aten.view.default(view_2046, [2, 8192, -1, 128]); view_2046 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_2048, torch.float32); view_2048 = None + view_2051 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 4, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_2051); view_2051 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_2049, torch.float32); view_2049 = None + view_2052 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 1, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_2052); view_2052 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_37); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_2054 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 4, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_37); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_2055 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 1, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_2054, torch.bfloat16); view_2054 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_2055, torch.bfloat16); view_2055 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 1, 4, 128]); unsqueeze_56 = None + view_2056 = torch.ops.aten.view.default(expand_56, [2, 8192, 4, 128]); expand_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_2050, 3); view_2050 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 1, 4, 128]); unsqueeze_57 = None + view_2057 = torch.ops.aten.view.default(expand_57, [2, 8192, 4, 128]); expand_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_2056, [0, 2, 1, 3]); view_2056 = None + permute_313 = torch.ops.aten.permute.default(view_2057, [0, 2, 1, 3]); view_2057 = None + _scaled_dot_product_cudnn_attention_28 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_311, permute_312, permute_313, None, True, 0.0, True); permute_311 = permute_312 = permute_313 = None + getitem_1228 = _scaled_dot_product_cudnn_attention_28[0] + getitem_1229 = _scaled_dot_product_cudnn_attention_28[1] + getitem_1234 = _scaled_dot_product_cudnn_attention_28[6] + getitem_1235 = _scaled_dot_product_cudnn_attention_28[7]; _scaled_dot_product_cudnn_attention_28 = None + permute_314 = torch.ops.aten.permute.default(getitem_1228, [0, 2, 1, 3]) + view_2058 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 32, '0'); convert_element_type_941 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_371, [1, 0]); wait_tensor_371 = None + view_2064 = torch.ops.aten.view.default(view_2058, [16384, 512]); view_2058 = None + mm_199 = torch.ops.aten.mm.default(view_2064, permute_315); view_2064 = permute_315 = None + view_2065 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + split_122 = torch.ops.aten.split.Tensor(view_2065, 1024, 1); view_2065 = None + getitem_1237 = split_122[0] + getitem_1238 = split_122[1] + getitem_1239 = split_122[2] + getitem_1240 = split_122[3] + getitem_1241 = split_122[4] + getitem_1242 = split_122[5] + getitem_1243 = split_122[6] + getitem_1244 = split_122[7]; split_122 = None + cat_114 = torch.ops.aten.cat.default([getitem_1237, getitem_1238, getitem_1239, getitem_1240, getitem_1241, getitem_1242, getitem_1243, getitem_1244]); getitem_1237 = getitem_1238 = getitem_1239 = getitem_1240 = getitem_1241 = getitem_1242 = getitem_1243 = getitem_1244 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_114, 'sum', 8, '1'); cat_114 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57) + add_113 = torch.ops.aten.add.Tensor(add_111, wait_tensor_372); wait_tensor_372 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_315 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 32, '0'); convert_element_type_944 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = rsqrt_57 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_373); mul_228 = wait_tensor_373 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_946, 8, '1'); convert_element_type_946 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + split_123 = torch.ops.aten.split.Tensor(wait_tensor_374, 2); wait_tensor_374 = None + getitem_1245 = split_123[0] + getitem_1246 = split_123[1] + getitem_1247 = split_123[2] + getitem_1248 = split_123[3] + getitem_1249 = split_123[4] + getitem_1250 = split_123[5] + getitem_1251 = split_123[6] + getitem_1252 = split_123[7]; split_123 = None + cat_115 = torch.ops.aten.cat.default([getitem_1245, getitem_1246, getitem_1247, getitem_1248, getitem_1249, getitem_1250, getitem_1251, getitem_1252], 1); getitem_1245 = getitem_1246 = getitem_1247 = getitem_1248 = getitem_1249 = getitem_1250 = getitem_1251 = getitem_1252 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 32, '0'); convert_element_type_947 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_375, [1, 0]); wait_tensor_375 = None + view_2076 = torch.ops.aten.view.default(cat_115, [16384, 4096]); cat_115 = None + mm_200 = torch.ops.aten.mm.default(view_2076, permute_316); permute_316 = None + view_2077 = torch.ops.aten.view.default(mm_200, [2, 8192, 1792]) + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_2077, torch.float32); view_2077 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); convert_element_type_950 = sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 32, '0'); convert_element_type_952 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_376, [1, 0]); wait_tensor_376 = None + mm_201 = torch.ops.aten.mm.default(view_2076, permute_317); view_2076 = permute_317 = None + view_2084 = torch.ops.aten.view.default(mm_201, [2, 8192, 1792]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_2084); convert_element_type_951 = view_2084 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_319 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 32, '0'); convert_element_type_955 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_377, [1, 0]); wait_tensor_377 = None + view_2091 = torch.ops.aten.view.default(mul_231, [16384, 1792]); mul_231 = None + mm_202 = torch.ops.aten.mm.default(view_2091, permute_318); view_2091 = permute_318 = None + view_2092 = torch.ops.aten.view.default(mm_202, [2, 8192, 4096]); mm_202 = None + split_124 = torch.ops.aten.split.Tensor(view_2092, 1024, 1); view_2092 = None + getitem_1253 = split_124[0] + getitem_1254 = split_124[1] + getitem_1255 = split_124[2] + getitem_1256 = split_124[3] + getitem_1257 = split_124[4] + getitem_1258 = split_124[5] + getitem_1259 = split_124[6] + getitem_1260 = split_124[7]; split_124 = None + cat_116 = torch.ops.aten.cat.default([getitem_1253, getitem_1254, getitem_1255, getitem_1256, getitem_1257, getitem_1258, getitem_1259, getitem_1260]); getitem_1253 = getitem_1254 = getitem_1255 = getitem_1256 = getitem_1257 = getitem_1258 = getitem_1259 = getitem_1260 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_116, 'sum', 8, '1'); cat_116 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + add_115 = torch.ops.aten.add.Tensor(add_113, wait_tensor_378); add_113 = wait_tensor_378 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_320 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 32, '0'); convert_element_type_958 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = rsqrt_58 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_379); mul_232 = wait_tensor_379 = None + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_960, 8, '1'); convert_element_type_960 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + split_125 = torch.ops.aten.split.Tensor(wait_tensor_380, 2); wait_tensor_380 = None + getitem_1261 = split_125[0] + getitem_1262 = split_125[1] + getitem_1263 = split_125[2] + getitem_1264 = split_125[3] + getitem_1265 = split_125[4] + getitem_1266 = split_125[5] + getitem_1267 = split_125[6] + getitem_1268 = split_125[7]; split_125 = None + cat_117 = torch.ops.aten.cat.default([getitem_1261, getitem_1262, getitem_1263, getitem_1264, getitem_1265, getitem_1266, getitem_1267, getitem_1268], 1); getitem_1261 = getitem_1262 = getitem_1263 = getitem_1264 = getitem_1265 = getitem_1266 = getitem_1267 = getitem_1268 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) + all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 32, '0'); convert_element_type_961 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_381, [1, 0]); wait_tensor_381 = None + view_2103 = torch.ops.aten.view.default(cat_117, [16384, 4096]); cat_117 = None + mm_203 = torch.ops.aten.mm.default(view_2103, permute_319); permute_319 = None + view_2104 = torch.ops.aten.view.default(mm_203, [2, 8192, 512]) + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 32, '0'); convert_element_type_964 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_382, [1, 0]); wait_tensor_382 = None + mm_204 = torch.ops.aten.mm.default(view_2103, permute_320); permute_320 = None + view_2111 = torch.ops.aten.view.default(mm_204, [2, 8192, 128]); mm_204 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 32, '0'); convert_element_type_967 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_383, [1, 0]); wait_tensor_383 = None + mm_205 = torch.ops.aten.mm.default(view_2103, permute_321); view_2103 = permute_321 = None + view_2118 = torch.ops.aten.view.default(mm_205, [2, 8192, 128]) + view_2120 = torch.ops.aten.view.default(view_2104, [2, 8192, -1, 128]); view_2104 = None + view_2121 = torch.ops.aten.view.default(view_2111, [2, 8192, -1, 128]); view_2111 = None + view_2122 = torch.ops.aten.view.default(view_2118, [2, 8192, -1, 128]); view_2118 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_2120, torch.float32); view_2120 = None + view_2123 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 4, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_2123); view_2123 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_2121, torch.float32); view_2121 = None + view_2124 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 1, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_2124); view_2124 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_37); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_2126 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 4, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_37); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_2127 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 1, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_2126, torch.bfloat16); view_2126 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_2127, torch.bfloat16); view_2127 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 1, 4, 128]); unsqueeze_58 = None + view_2128 = torch.ops.aten.view.default(expand_58, [2, 8192, 4, 128]); expand_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_2122, 3); view_2122 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 1, 4, 128]); unsqueeze_59 = None + view_2129 = torch.ops.aten.view.default(expand_59, [2, 8192, 4, 128]); expand_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_2128, [0, 2, 1, 3]); view_2128 = None + permute_324 = torch.ops.aten.permute.default(view_2129, [0, 2, 1, 3]); view_2129 = None + _scaled_dot_product_cudnn_attention_29 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_322, permute_323, permute_324, None, True, 0.0, True); permute_322 = permute_323 = permute_324 = None + getitem_1269 = _scaled_dot_product_cudnn_attention_29[0] + getitem_1270 = _scaled_dot_product_cudnn_attention_29[1] + getitem_1275 = _scaled_dot_product_cudnn_attention_29[6] + getitem_1276 = _scaled_dot_product_cudnn_attention_29[7]; _scaled_dot_product_cudnn_attention_29 = None + permute_325 = torch.ops.aten.permute.default(getitem_1269, [0, 2, 1, 3]) + view_2130 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_325 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 32, '0'); convert_element_type_974 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_325); all_gather_into_tensor_325 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_384, [1, 0]); wait_tensor_384 = None + view_2136 = torch.ops.aten.view.default(view_2130, [16384, 512]); view_2130 = None + mm_206 = torch.ops.aten.mm.default(view_2136, permute_326); view_2136 = permute_326 = None + view_2137 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + split_126 = torch.ops.aten.split.Tensor(view_2137, 1024, 1); view_2137 = None + getitem_1278 = split_126[0] + getitem_1279 = split_126[1] + getitem_1280 = split_126[2] + getitem_1281 = split_126[3] + getitem_1282 = split_126[4] + getitem_1283 = split_126[5] + getitem_1284 = split_126[6] + getitem_1285 = split_126[7]; split_126 = None + cat_118 = torch.ops.aten.cat.default([getitem_1278, getitem_1279, getitem_1280, getitem_1281, getitem_1282, getitem_1283, getitem_1284, getitem_1285]); getitem_1278 = getitem_1279 = getitem_1280 = getitem_1281 = getitem_1282 = getitem_1283 = getitem_1284 = getitem_1285 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_118, 'sum', 8, '1'); cat_118 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59) + add_117 = torch.ops.aten.add.Tensor(add_115, wait_tensor_385); wait_tensor_385 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16) + all_gather_into_tensor_326 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 32, '0'); convert_element_type_977 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_326); all_gather_into_tensor_326 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = rsqrt_59 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_386); mul_236 = wait_tensor_386 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + all_gather_into_tensor_327 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_979, 8, '1'); convert_element_type_979 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_327); all_gather_into_tensor_327 = None + split_127 = torch.ops.aten.split.Tensor(wait_tensor_387, 2); wait_tensor_387 = None + getitem_1286 = split_127[0] + getitem_1287 = split_127[1] + getitem_1288 = split_127[2] + getitem_1289 = split_127[3] + getitem_1290 = split_127[4] + getitem_1291 = split_127[5] + getitem_1292 = split_127[6] + getitem_1293 = split_127[7]; split_127 = None + cat_119 = torch.ops.aten.cat.default([getitem_1286, getitem_1287, getitem_1288, getitem_1289, getitem_1290, getitem_1291, getitem_1292, getitem_1293], 1); getitem_1286 = getitem_1287 = getitem_1288 = getitem_1289 = getitem_1290 = getitem_1291 = getitem_1292 = getitem_1293 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_328 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 32, '0'); convert_element_type_980 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_328); all_gather_into_tensor_328 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + view_2148 = torch.ops.aten.view.default(cat_119, [16384, 4096]); cat_119 = None + mm_207 = torch.ops.aten.mm.default(view_2148, permute_327); permute_327 = None + view_2149 = torch.ops.aten.view.default(mm_207, [2, 8192, 1792]) + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_2149, torch.float32); view_2149 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); convert_element_type_983 = sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16) + all_gather_into_tensor_329 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 32, '0'); convert_element_type_985 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_329); all_gather_into_tensor_329 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_389, [1, 0]); wait_tensor_389 = None + mm_208 = torch.ops.aten.mm.default(view_2148, permute_328); view_2148 = permute_328 = None + view_2156 = torch.ops.aten.view.default(mm_208, [2, 8192, 1792]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_2156); convert_element_type_984 = view_2156 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_330 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 32, '0'); convert_element_type_988 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_330); all_gather_into_tensor_330 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + view_2163 = torch.ops.aten.view.default(mul_239, [16384, 1792]); mul_239 = None + mm_209 = torch.ops.aten.mm.default(view_2163, permute_329); view_2163 = permute_329 = None + view_2164 = torch.ops.aten.view.default(mm_209, [2, 8192, 4096]); mm_209 = None + split_128 = torch.ops.aten.split.Tensor(view_2164, 1024, 1); view_2164 = None + getitem_1294 = split_128[0] + getitem_1295 = split_128[1] + getitem_1296 = split_128[2] + getitem_1297 = split_128[3] + getitem_1298 = split_128[4] + getitem_1299 = split_128[5] + getitem_1300 = split_128[6] + getitem_1301 = split_128[7]; split_128 = None + cat_120 = torch.ops.aten.cat.default([getitem_1294, getitem_1295, getitem_1296, getitem_1297, getitem_1298, getitem_1299, getitem_1300, getitem_1301]); getitem_1294 = getitem_1295 = getitem_1296 = getitem_1297 = getitem_1298 = getitem_1299 = getitem_1300 = getitem_1301 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_120, 'sum', 8, '1'); cat_120 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + add_119 = torch.ops.aten.add.Tensor(add_117, wait_tensor_391); add_117 = wait_tensor_391 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_331 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 32, '0'); convert_element_type_991 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32) + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = rsqrt_60 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_392); mul_240 = wait_tensor_392 = None + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_993, 8, '1'); convert_element_type_993 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + split_129 = torch.ops.aten.split.Tensor(wait_tensor_393, 2); wait_tensor_393 = None + getitem_1302 = split_129[0] + getitem_1303 = split_129[1] + getitem_1304 = split_129[2] + getitem_1305 = split_129[3] + getitem_1306 = split_129[4] + getitem_1307 = split_129[5] + getitem_1308 = split_129[6] + getitem_1309 = split_129[7]; split_129 = None + cat_121 = torch.ops.aten.cat.default([getitem_1302, getitem_1303, getitem_1304, getitem_1305, getitem_1306, getitem_1307, getitem_1308, getitem_1309], 1); getitem_1302 = getitem_1303 = getitem_1304 = getitem_1305 = getitem_1306 = getitem_1307 = getitem_1308 = getitem_1309 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 32, '0'); convert_element_type_994 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + view_2175 = torch.ops.aten.view.default(cat_121, [16384, 4096]); cat_121 = None + mm_210 = torch.ops.aten.mm.default(view_2175, permute_330); permute_330 = None + view_2176 = torch.ops.aten.view.default(mm_210, [2, 8192, 512]) + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 32, '0'); convert_element_type_997 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_395, [1, 0]); wait_tensor_395 = None + mm_211 = torch.ops.aten.mm.default(view_2175, permute_331); permute_331 = None + view_2183 = torch.ops.aten.view.default(mm_211, [2, 8192, 128]); mm_211 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16) + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 32, '0'); convert_element_type_1000 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_396, [1, 0]); wait_tensor_396 = None + mm_212 = torch.ops.aten.mm.default(view_2175, permute_332); view_2175 = permute_332 = None + view_2190 = torch.ops.aten.view.default(mm_212, [2, 8192, 128]) + view_2192 = torch.ops.aten.view.default(view_2176, [2, 8192, -1, 128]); view_2176 = None + view_2193 = torch.ops.aten.view.default(view_2183, [2, 8192, -1, 128]); view_2183 = None + view_2194 = torch.ops.aten.view.default(view_2190, [2, 8192, -1, 128]); view_2190 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_2192, torch.float32); view_2192 = None + view_2195 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 4, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_2195); view_2195 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_2193, torch.float32); view_2193 = None + view_2196 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 1, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_2196); view_2196 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_37); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_2198 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 4, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_37); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_2199 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 1, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_2198, torch.bfloat16); view_2198 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_2199, torch.bfloat16); view_2199 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 1, 4, 128]); unsqueeze_60 = None + view_2200 = torch.ops.aten.view.default(expand_60, [2, 8192, 4, 128]); expand_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_2194, 3); view_2194 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 1, 4, 128]); unsqueeze_61 = None + view_2201 = torch.ops.aten.view.default(expand_61, [2, 8192, 4, 128]); expand_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_2200, [0, 2, 1, 3]); view_2200 = None + permute_335 = torch.ops.aten.permute.default(view_2201, [0, 2, 1, 3]); view_2201 = None + _scaled_dot_product_cudnn_attention_30 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_333, permute_334, permute_335, None, True, 0.0, True); permute_333 = permute_334 = permute_335 = None + getitem_1310 = _scaled_dot_product_cudnn_attention_30[0] + getitem_1311 = _scaled_dot_product_cudnn_attention_30[1] + getitem_1316 = _scaled_dot_product_cudnn_attention_30[6] + getitem_1317 = _scaled_dot_product_cudnn_attention_30[7]; _scaled_dot_product_cudnn_attention_30 = None + permute_336 = torch.ops.aten.permute.default(getitem_1310, [0, 2, 1, 3]) + view_2202 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 32, '0'); convert_element_type_1007 = None + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_397, [1, 0]); wait_tensor_397 = None + view_2208 = torch.ops.aten.view.default(view_2202, [16384, 512]); view_2202 = None + mm_213 = torch.ops.aten.mm.default(view_2208, permute_337); view_2208 = permute_337 = None + view_2209 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + split_130 = torch.ops.aten.split.Tensor(view_2209, 1024, 1); view_2209 = None + getitem_1319 = split_130[0] + getitem_1320 = split_130[1] + getitem_1321 = split_130[2] + getitem_1322 = split_130[3] + getitem_1323 = split_130[4] + getitem_1324 = split_130[5] + getitem_1325 = split_130[6] + getitem_1326 = split_130[7]; split_130 = None + cat_122 = torch.ops.aten.cat.default([getitem_1319, getitem_1320, getitem_1321, getitem_1322, getitem_1323, getitem_1324, getitem_1325, getitem_1326]); getitem_1319 = getitem_1320 = getitem_1321 = getitem_1322 = getitem_1323 = getitem_1324 = getitem_1325 = getitem_1326 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_122, 'sum', 8, '1'); cat_122 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61) + add_121 = torch.ops.aten.add.Tensor(add_119, wait_tensor_398); wait_tensor_398 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 32, '0'); convert_element_type_1010 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = rsqrt_61 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_399); mul_244 = wait_tensor_399 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 8, '1'); convert_element_type_1012 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + split_131 = torch.ops.aten.split.Tensor(wait_tensor_400, 2); wait_tensor_400 = None + getitem_1327 = split_131[0] + getitem_1328 = split_131[1] + getitem_1329 = split_131[2] + getitem_1330 = split_131[3] + getitem_1331 = split_131[4] + getitem_1332 = split_131[5] + getitem_1333 = split_131[6] + getitem_1334 = split_131[7]; split_131 = None + cat_123 = torch.ops.aten.cat.default([getitem_1327, getitem_1328, getitem_1329, getitem_1330, getitem_1331, getitem_1332, getitem_1333, getitem_1334], 1); getitem_1327 = getitem_1328 = getitem_1329 = getitem_1330 = getitem_1331 = getitem_1332 = getitem_1333 = getitem_1334 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 32, '0'); convert_element_type_1013 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_401, [1, 0]); wait_tensor_401 = None + view_2220 = torch.ops.aten.view.default(cat_123, [16384, 4096]); cat_123 = None + mm_214 = torch.ops.aten.mm.default(view_2220, permute_338); permute_338 = None + view_2221 = torch.ops.aten.view.default(mm_214, [2, 8192, 1792]) + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_2221, torch.float32); view_2221 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); convert_element_type_1016 = sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 32, '0'); convert_element_type_1018 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_402, [1, 0]); wait_tensor_402 = None + mm_215 = torch.ops.aten.mm.default(view_2220, permute_339); view_2220 = permute_339 = None + view_2228 = torch.ops.aten.view.default(mm_215, [2, 8192, 1792]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_2228); convert_element_type_1017 = view_2228 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 32, '0'); convert_element_type_1021 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_403, [1, 0]); wait_tensor_403 = None + view_2235 = torch.ops.aten.view.default(mul_247, [16384, 1792]); mul_247 = None + mm_216 = torch.ops.aten.mm.default(view_2235, permute_340); view_2235 = permute_340 = None + view_2236 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + split_132 = torch.ops.aten.split.Tensor(view_2236, 1024, 1); view_2236 = None + getitem_1335 = split_132[0] + getitem_1336 = split_132[1] + getitem_1337 = split_132[2] + getitem_1338 = split_132[3] + getitem_1339 = split_132[4] + getitem_1340 = split_132[5] + getitem_1341 = split_132[6] + getitem_1342 = split_132[7]; split_132 = None + cat_124 = torch.ops.aten.cat.default([getitem_1335, getitem_1336, getitem_1337, getitem_1338, getitem_1339, getitem_1340, getitem_1341, getitem_1342]); getitem_1335 = getitem_1336 = getitem_1337 = getitem_1338 = getitem_1339 = getitem_1340 = getitem_1341 = getitem_1342 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_124, 'sum', 8, '1'); cat_124 = None + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + add_123 = torch.ops.aten.add.Tensor(add_121, wait_tensor_404); add_121 = wait_tensor_404 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_342 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 32, '0'); convert_element_type_1024 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_342); all_gather_into_tensor_342 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = rsqrt_62 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_405); mul_248 = wait_tensor_405 = None + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + all_gather_into_tensor_343 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1026, 8, '1'); convert_element_type_1026 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_343); all_gather_into_tensor_343 = None + split_133 = torch.ops.aten.split.Tensor(wait_tensor_406, 2); wait_tensor_406 = None + getitem_1343 = split_133[0] + getitem_1344 = split_133[1] + getitem_1345 = split_133[2] + getitem_1346 = split_133[3] + getitem_1347 = split_133[4] + getitem_1348 = split_133[5] + getitem_1349 = split_133[6] + getitem_1350 = split_133[7]; split_133 = None + cat_125 = torch.ops.aten.cat.default([getitem_1343, getitem_1344, getitem_1345, getitem_1346, getitem_1347, getitem_1348, getitem_1349, getitem_1350], 1); getitem_1343 = getitem_1344 = getitem_1345 = getitem_1346 = getitem_1347 = getitem_1348 = getitem_1349 = getitem_1350 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_344 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 32, '0'); convert_element_type_1027 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_344); all_gather_into_tensor_344 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + view_2247 = torch.ops.aten.view.default(cat_125, [16384, 4096]); cat_125 = None + mm_217 = torch.ops.aten.mm.default(view_2247, permute_341); permute_341 = None + view_2248 = torch.ops.aten.view.default(mm_217, [2, 8192, 512]) + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_345 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 32, '0'); convert_element_type_1030 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_345); all_gather_into_tensor_345 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + mm_218 = torch.ops.aten.mm.default(view_2247, permute_342); permute_342 = None + view_2255 = torch.ops.aten.view.default(mm_218, [2, 8192, 128]); mm_218 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16) + all_gather_into_tensor_346 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 32, '0'); convert_element_type_1033 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_346); all_gather_into_tensor_346 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + mm_219 = torch.ops.aten.mm.default(view_2247, permute_343); view_2247 = permute_343 = None + view_2262 = torch.ops.aten.view.default(mm_219, [2, 8192, 128]) + view_2264 = torch.ops.aten.view.default(view_2248, [2, 8192, -1, 128]); view_2248 = None + view_2265 = torch.ops.aten.view.default(view_2255, [2, 8192, -1, 128]); view_2255 = None + view_2266 = torch.ops.aten.view.default(view_2262, [2, 8192, -1, 128]); view_2262 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_2264, torch.float32); view_2264 = None + view_2267 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 4, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_2267); view_2267 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_2265, torch.float32); view_2265 = None + view_2268 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 1, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_2268); view_2268 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_37); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_2270 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 4, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_37); view_as_complex_63 = view_37 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_2271 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 1, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_2270, torch.bfloat16); view_2270 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_2271, torch.bfloat16); view_2271 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 1, 4, 128]); unsqueeze_62 = None + view_2272 = torch.ops.aten.view.default(expand_62, [2, 8192, 4, 128]); expand_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_2266, 3); view_2266 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 1, 4, 128]); unsqueeze_63 = None + view_2273 = torch.ops.aten.view.default(expand_63, [2, 8192, 4, 128]); expand_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_2272, [0, 2, 1, 3]); view_2272 = None + permute_346 = torch.ops.aten.permute.default(view_2273, [0, 2, 1, 3]); view_2273 = None + _scaled_dot_product_cudnn_attention_31 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_344, permute_345, permute_346, None, True, 0.0, True); permute_344 = permute_345 = permute_346 = None + getitem_1351 = _scaled_dot_product_cudnn_attention_31[0] + getitem_1352 = _scaled_dot_product_cudnn_attention_31[1] + getitem_1357 = _scaled_dot_product_cudnn_attention_31[6] + getitem_1358 = _scaled_dot_product_cudnn_attention_31[7]; _scaled_dot_product_cudnn_attention_31 = None + permute_347 = torch.ops.aten.permute.default(getitem_1351, [0, 2, 1, 3]) + view_2274 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_347 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 32, '0'); convert_element_type_1040 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_347); all_gather_into_tensor_347 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_410, [1, 0]); wait_tensor_410 = None + view_2280 = torch.ops.aten.view.default(view_2274, [16384, 512]); view_2274 = None + mm_220 = torch.ops.aten.mm.default(view_2280, permute_348); view_2280 = permute_348 = None + view_2281 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + split_134 = torch.ops.aten.split.Tensor(view_2281, 1024, 1); view_2281 = None + getitem_1360 = split_134[0] + getitem_1361 = split_134[1] + getitem_1362 = split_134[2] + getitem_1363 = split_134[3] + getitem_1364 = split_134[4] + getitem_1365 = split_134[5] + getitem_1366 = split_134[6] + getitem_1367 = split_134[7]; split_134 = None + cat_126 = torch.ops.aten.cat.default([getitem_1360, getitem_1361, getitem_1362, getitem_1363, getitem_1364, getitem_1365, getitem_1366, getitem_1367]); getitem_1360 = getitem_1361 = getitem_1362 = getitem_1363 = getitem_1364 = getitem_1365 = getitem_1366 = getitem_1367 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_126, 'sum', 8, '1'); cat_126 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63) + add_125 = torch.ops.aten.add.Tensor(add_123, wait_tensor_411); wait_tensor_411 = None + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16) + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 32, '0'); convert_element_type_1043 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = rsqrt_63 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_412); mul_252 = wait_tensor_412 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1045, 8, '1'); convert_element_type_1045 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + split_135 = torch.ops.aten.split.Tensor(wait_tensor_413, 2); wait_tensor_413 = None + getitem_1368 = split_135[0] + getitem_1369 = split_135[1] + getitem_1370 = split_135[2] + getitem_1371 = split_135[3] + getitem_1372 = split_135[4] + getitem_1373 = split_135[5] + getitem_1374 = split_135[6] + getitem_1375 = split_135[7]; split_135 = None + cat_127 = torch.ops.aten.cat.default([getitem_1368, getitem_1369, getitem_1370, getitem_1371, getitem_1372, getitem_1373, getitem_1374, getitem_1375], 1); getitem_1368 = getitem_1369 = getitem_1370 = getitem_1371 = getitem_1372 = getitem_1373 = getitem_1374 = getitem_1375 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 32, '0'); convert_element_type_1046 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + view_2292 = torch.ops.aten.view.default(cat_127, [16384, 4096]); cat_127 = None + mm_221 = torch.ops.aten.mm.default(view_2292, permute_349); permute_349 = None + view_2293 = torch.ops.aten.view.default(mm_221, [2, 8192, 1792]) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_2293, torch.float32); view_2293 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); convert_element_type_1049 = sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 32, '0'); convert_element_type_1051 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); all_gather_into_tensor_351 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); wait_tensor_415 = None + mm_222 = torch.ops.aten.mm.default(view_2292, permute_350); view_2292 = permute_350 = None + view_2300 = torch.ops.aten.view.default(mm_222, [2, 8192, 1792]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_2300); convert_element_type_1050 = view_2300 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 32, '0'); convert_element_type_1054 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_416, [1, 0]); wait_tensor_416 = None + view_2307 = torch.ops.aten.view.default(mul_255, [16384, 1792]); mul_255 = None + mm_223 = torch.ops.aten.mm.default(view_2307, permute_351); view_2307 = permute_351 = None + view_2308 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]); mm_223 = None + split_136 = torch.ops.aten.split.Tensor(view_2308, 1024, 1); view_2308 = None + getitem_1376 = split_136[0] + getitem_1377 = split_136[1] + getitem_1378 = split_136[2] + getitem_1379 = split_136[3] + getitem_1380 = split_136[4] + getitem_1381 = split_136[5] + getitem_1382 = split_136[6] + getitem_1383 = split_136[7]; split_136 = None + cat_128 = torch.ops.aten.cat.default([getitem_1376, getitem_1377, getitem_1378, getitem_1379, getitem_1380, getitem_1381, getitem_1382, getitem_1383]); getitem_1376 = getitem_1377 = getitem_1378 = getitem_1379 = getitem_1380 = getitem_1381 = getitem_1382 = getitem_1383 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_128, 'sum', 8, '1'); cat_128 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64) + add_127 = torch.ops.aten.add.Tensor(add_125, wait_tensor_417); add_125 = wait_tensor_417 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_353 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 32, '0'); convert_element_type_1057 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1058, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_128 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_128); add_128 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_257 = torch.ops.aten.mul.Tensor(mul_256, wait_tensor_418); mul_256 = wait_tensor_418 = None + convert_element_type_1059 = torch.ops.prims.convert_element_type.default(mul_257, torch.bfloat16); mul_257 = None + all_gather_into_tensor_354 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1059, 8, '1'); convert_element_type_1059 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_354); all_gather_into_tensor_354 = None + split_137 = torch.ops.aten.split.Tensor(wait_tensor_419, 2); wait_tensor_419 = None + getitem_1384 = split_137[0] + getitem_1385 = split_137[1] + getitem_1386 = split_137[2] + getitem_1387 = split_137[3] + getitem_1388 = split_137[4] + getitem_1389 = split_137[5] + getitem_1390 = split_137[6] + getitem_1391 = split_137[7]; split_137 = None + cat_129 = torch.ops.aten.cat.default([getitem_1384, getitem_1385, getitem_1386, getitem_1387, getitem_1388, getitem_1389, getitem_1390, getitem_1391], 1); getitem_1384 = getitem_1385 = getitem_1386 = getitem_1387 = getitem_1388 = getitem_1389 = getitem_1390 = getitem_1391 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 32, '0'); convert_element_type_1060 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_420, [1, 0]); wait_tensor_420 = None + view_2319 = torch.ops.aten.view.default(cat_129, [16384, 4096]); cat_129 = None + mm_224 = torch.ops.aten.mm.default(view_2319, permute_352); permute_352 = None + view_2320 = torch.ops.aten.view.default(mm_224, [2, 8192, 16032]); mm_224 = None + return (view_2320, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, wait_tensor_1, mm, mm_2, getitem_80, getitem_81, getitem_86, getitem_87, reduce_scatter_tensor_1, mm_4, add_3, mm_7, mm_9, getitem_121, getitem_122, getitem_127, getitem_128, reduce_scatter_tensor_3, mm_11, add_7, mm_14, mm_16, getitem_162, getitem_163, getitem_168, getitem_169, reduce_scatter_tensor_5, mm_18, add_11, mm_21, mm_23, getitem_203, getitem_204, getitem_209, getitem_210, reduce_scatter_tensor_7, mm_25, add_15, mm_28, mm_30, getitem_244, getitem_245, getitem_250, getitem_251, reduce_scatter_tensor_9, mm_32, add_19, mm_35, mm_37, getitem_285, getitem_286, getitem_291, getitem_292, reduce_scatter_tensor_11, mm_39, add_23, mm_42, mm_44, getitem_326, getitem_327, getitem_332, getitem_333, reduce_scatter_tensor_13, mm_46, add_27, mm_49, mm_51, getitem_367, getitem_368, getitem_373, getitem_374, reduce_scatter_tensor_15, mm_53, add_31, mm_56, mm_58, getitem_408, getitem_409, getitem_414, getitem_415, reduce_scatter_tensor_17, mm_60, add_35, mm_63, mm_65, getitem_449, getitem_450, getitem_455, getitem_456, reduce_scatter_tensor_19, mm_67, add_39, mm_70, mm_72, getitem_490, getitem_491, getitem_496, getitem_497, reduce_scatter_tensor_21, mm_74, add_43, mm_77, mm_79, getitem_531, getitem_532, getitem_537, getitem_538, reduce_scatter_tensor_23, mm_81, add_47, mm_84, mm_86, getitem_572, getitem_573, getitem_578, getitem_579, reduce_scatter_tensor_25, mm_88, add_51, mm_91, mm_93, getitem_613, getitem_614, getitem_619, getitem_620, reduce_scatter_tensor_27, mm_95, add_55, mm_98, mm_100, getitem_654, getitem_655, getitem_660, getitem_661, reduce_scatter_tensor_29, mm_102, add_59, mm_105, mm_107, getitem_695, getitem_696, getitem_701, getitem_702, reduce_scatter_tensor_31, mm_109, add_63, mm_112, mm_114, getitem_736, getitem_737, getitem_742, getitem_743, reduce_scatter_tensor_33, mm_116, add_67, mm_119, mm_121, getitem_777, getitem_778, getitem_783, getitem_784, reduce_scatter_tensor_35, mm_123, add_71, mm_126, mm_128, getitem_818, getitem_819, getitem_824, getitem_825, reduce_scatter_tensor_37, mm_130, add_75, mm_133, mm_135, getitem_859, getitem_860, getitem_865, getitem_866, reduce_scatter_tensor_39, mm_137, add_79, mm_140, mm_142, getitem_900, getitem_901, getitem_906, getitem_907, reduce_scatter_tensor_41, mm_144, add_83, mm_147, mm_149, getitem_941, getitem_942, getitem_947, getitem_948, reduce_scatter_tensor_43, mm_151, add_87, mm_154, mm_156, getitem_982, getitem_983, getitem_988, getitem_989, reduce_scatter_tensor_45, mm_158, add_91, mm_161, mm_163, getitem_1023, getitem_1024, getitem_1029, getitem_1030, reduce_scatter_tensor_47, mm_165, add_95, mm_168, mm_170, getitem_1064, getitem_1065, getitem_1070, getitem_1071, reduce_scatter_tensor_49, mm_172, add_99, mm_175, mm_177, getitem_1105, getitem_1106, getitem_1111, getitem_1112, reduce_scatter_tensor_51, mm_179, add_103, mm_182, mm_184, getitem_1146, getitem_1147, getitem_1152, getitem_1153, reduce_scatter_tensor_53, mm_186, add_107, mm_189, mm_191, getitem_1187, getitem_1188, getitem_1193, getitem_1194, reduce_scatter_tensor_55, mm_193, add_111, mm_196, mm_198, getitem_1228, getitem_1229, getitem_1234, getitem_1235, reduce_scatter_tensor_57, mm_200, add_115, mm_203, mm_205, getitem_1269, getitem_1270, getitem_1275, getitem_1276, reduce_scatter_tensor_59, mm_207, add_119, mm_210, mm_212, getitem_1310, getitem_1311, getitem_1316, getitem_1317, reduce_scatter_tensor_61, mm_214, add_123, mm_217, mm_219, getitem_1351, getitem_1352, getitem_1357, getitem_1358, reduce_scatter_tensor_63, mm_221, reduce_scatter_tensor_64, rsqrt_64, view_2319) + +def load_args(reader): + buf0 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf0, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_1 + buf1 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf1, (501, 4096), is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf3, (128,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf4, (16, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf5, (4, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf6, (4, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf7, (128, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf8, (128,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf9, (56, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf10, (56, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf11, (128, 1792), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf12, (128,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf13, (16, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf14, (4, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf15, (4, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf16, (128, 512), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf17, (128,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf18, (56, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf19, (56, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf20, (128, 1792), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf21, (128,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf22, (16, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf23, (4, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf24, (4, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf25, (128, 512), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf26, (128,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf27, (56, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf28, (56, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf29, (128, 1792), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf30, (128,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf31, (16, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf32, (4, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf33, (4, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf34, (128, 512), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf35, (128,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf36, (56, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf37, (56, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf38, (128, 1792), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf39, (128,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf40, (16, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf41, (4, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf42, (4, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf43, (128, 512), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf44, (128,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf45, (56, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf46, (56, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf47, (128, 1792), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf48, (128,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf49, (16, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf50, (4, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf51, (4, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf52, (128, 512), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf53, (128,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf54, (56, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf55, (56, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf56, (128, 1792), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf57, (128,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf58, (16, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf59, (4, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf60, (4, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf61, (128, 512), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf62, (128,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf63, (56, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf64, (56, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf65, (128, 1792), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf66, (128,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf67, (16, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf68, (4, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf69, (4, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf70, (128, 512), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf71, (128,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf72, (56, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf73, (56, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf74, (128, 1792), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf75, (128,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf76, (16, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf77, (4, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf78, (4, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf79, (128, 512), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf80, (128,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf81, (56, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf82, (56, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf83, (128, 1792), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf84, (128,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf85, (16, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf86, (4, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf87, (4, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf88, (128, 512), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf89, (128,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf90, (56, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf91, (56, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf92, (128, 1792), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf93, (128,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf94, (16, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf95, (4, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf96, (4, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf97, (128, 512), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf98, (128,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf99, (56, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf100, (56, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf101, (128, 1792), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf102, (128,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf103, (16, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf104, (4, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf105, (4, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf106, (128, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf107, (128,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf108, (56, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf109, (56, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf110, (128, 1792), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf111, (128,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf112, (16, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf113, (4, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf114, (4, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf115, (128, 512), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf116, (128,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf117, (56, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf118, (56, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf119, (128, 1792), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf120, (128,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf121, (16, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf122, (4, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf123, (4, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf124, (128, 512), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf125, (128,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf126, (56, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf127, (56, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf128, (128, 1792), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf129, (128,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf130, (16, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf131, (4, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf132, (4, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf133, (128, 512), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf134, (128,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf135, (56, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf136, (56, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf137, (128, 1792), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf138, (128,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf139, (16, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf140, (4, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf141, (4, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf142, (128, 512), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf143, (128,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf144, (56, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf145, (56, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf146, (128, 1792), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf147, (128,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf148, (16, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf149, (4, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf150, (4, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf151, (128, 512), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf152, (128,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf153, (56, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf154, (56, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf155, (128, 1792), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf156, (128,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf157, (16, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf158, (4, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf159, (4, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf160, (128, 512), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf161, (128,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf162, (56, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf163, (56, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf164, (128, 1792), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf165, (128,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf166, (16, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf167, (4, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf168, (4, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf169, (128, 512), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf170, (128,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf171, (56, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf172, (56, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf173, (128, 1792), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf174, (128,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf175, (16, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf176, (4, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf177, (4, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf178, (128, 512), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf179, (128,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf180, (56, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf181, (56, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf182, (128, 1792), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf183, (128,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf184, (16, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf185, (4, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf186, (4, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf187, (128, 512), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf188, (128,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf189, (56, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf190, (56, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf191, (128, 1792), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf192, (128,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf193, (16, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf194, (4, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf195, (4, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf196, (128, 512), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf197, (128,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf198, (56, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf199, (56, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf200, (128, 1792), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf201, (128,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf202, (16, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf203, (4, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf204, (4, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf205, (128, 512), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf206, (128,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf207, (56, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf208, (56, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf209, (128, 1792), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf210, (128,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf211, (16, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf212, (4, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf213, (4, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf214, (128, 512), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf215, (128,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf216, (56, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf217, (56, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf218, (128, 1792), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf219, (128,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf220, (16, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf221, (4, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf222, (4, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf223, (128, 512), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf224, (128,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf225, (56, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf226, (56, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf227, (128, 1792), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf228, (128,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf229, (16, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf230, (4, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf231, (4, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf232, (128, 512), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf233, (128,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf234, (56, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf235, (56, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf236, (128, 1792), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf237, (128,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf238, (16, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf239, (4, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf240, (4, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf241, (128, 512), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf242, (128,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf243, (56, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf244, (56, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf245, (128, 1792), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf246, (128,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf247, (16, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf248, (4, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf249, (4, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf250, (128, 512), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf251, (128,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf252, (56, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf253, (56, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf254, (128, 1792), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf255, (128,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf256, (16, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf257, (4, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf258, (4, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf259, (128, 512), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf260, (128,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf261, (56, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf262, (56, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf263, (128, 1792), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf264, (128,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf265, (16, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf266, (4, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf267, (4, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf268, (128, 512), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf269, (128,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf270, (56, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf271, (56, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf272, (128, 1792), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf273, (128,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf274, (16, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf275, (4, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf276, (4, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf277, (128, 512), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf278, (128,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf279, (56, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf280, (56, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf281, (128, 1792), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf282, (128,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf283, (16, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf284, (4, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 65536, device=device(type='cuda', index=0)) + reader.tensor(buf285, (4, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf286, (128, 512), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf287, (128,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf288, (56, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf289, (56, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 917504, device=device(type='cuda', index=0)) + reader.tensor(buf290, (128, 1792), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 512, device=device(type='cuda', index=0)) + reader.tensor(buf291, (128,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 8208384, device=device(type='cuda', index=0)) + reader.tensor(buf292, (501, 4096), is_leaf=True) # primals_293 + +load_args._version = 0 + +def get_pg_config(): + return {'0': {'size': 32, 'rank': 0}, '1': {'size': 8, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls32_8.table" + diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_1d.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_1d.py new file mode 100644 index 00000000..2482cafe --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_1d.py @@ -0,0 +1,4153 @@ +# fmt: off +# flake8: noqa +# isort: skip_file +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_1, torch.bfloat16) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 64, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + embedding = torch.ops.aten.embedding.default(wait_tensor, primals_2); wait_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 64, '0'); convert_element_type_1 = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(embedding, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = rsqrt = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_1); mul = wait_tensor_1 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 64, '0'); convert_element_type_4 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + permute = torch.ops.aten.permute.default(wait_tensor_2, [1, 0]); wait_tensor_2 = None + view_3 = torch.ops.aten.view.default(convert_element_type_3, [16384, 4096]); convert_element_type_3 = None + mm = torch.ops.aten.mm.default(view_3, permute); permute = None + view_4 = torch.ops.aten.view.default(mm, [2, 8192, 4096]) + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 64, '0'); convert_element_type_7 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_3, [1, 0]); wait_tensor_3 = None + mm_1 = torch.ops.aten.mm.default(view_3, permute_1); permute_1 = None + view_7 = torch.ops.aten.view.default(mm_1, [2, 8192, 1024]); mm_1 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 64, '0'); convert_element_type_10 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + mm_2 = torch.ops.aten.mm.default(view_3, permute_2); view_3 = permute_2 = None + view_10 = torch.ops.aten.view.default(mm_2, [2, 8192, 1024]) + view_11 = torch.ops.aten.view.default(view_4, [2, 8192, -1, 128]); view_4 = None + view_12 = torch.ops.aten.view.default(view_7, [2, 8192, -1, 128]); view_7 = None + view_13 = torch.ops.aten.view.default(view_10, [2, 8192, -1, 128]); view_10 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None + view_14 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 32, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_14); view_14 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_12, torch.float32); view_12 = None + view_15 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 8, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_15); view_15 = None + view_16 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_16); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_17 = torch.ops.aten.view.default(view_as_real, [2, 8192, 32, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_16); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_18 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 8, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_17, torch.bfloat16); view_17 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_18, torch.bfloat16); view_18 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 8, 4, 128]); unsqueeze = None + clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None + view_19 = torch.ops.aten.view.default(clone, [2, 8192, 32, 128]); clone = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_13, 3); view_13 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 8, 4, 128]); unsqueeze_1 = None + clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None + view_20 = torch.ops.aten.view.default(clone_1, [2, 8192, 32, 128]); clone_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]); view_19 = None + permute_5 = torch.ops.aten.permute.default(view_20, [0, 2, 1, 3]); view_20 = None + _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_3, permute_4, permute_5, None, True, 0.0, True); permute_3 = permute_4 = permute_5 = None + getitem = _scaled_dot_product_cudnn_attention[0] + getitem_1 = _scaled_dot_product_cudnn_attention[1] + getitem_6 = _scaled_dot_product_cudnn_attention[6] + getitem_7 = _scaled_dot_product_cudnn_attention[7]; _scaled_dot_product_cudnn_attention = None + permute_6 = torch.ops.aten.permute.default(getitem, [0, 2, 1, 3]) + view_21 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 64, '0'); convert_element_type_17 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + view_23 = torch.ops.aten.view.default(view_21, [16384, 4096]); view_21 = None + mm_3 = torch.ops.aten.mm.default(view_23, permute_7); view_23 = permute_7 = None + view_24 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + add_1 = torch.ops.aten.add.Tensor(embedding, view_24); view_24 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 64, '0'); convert_element_type_20 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = rsqrt_1 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_6); mul_4 = wait_tensor_6 = None + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16) + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 64, '0'); convert_element_type_23 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + view_27 = torch.ops.aten.view.default(convert_element_type_22, [16384, 4096]); convert_element_type_22 = None + mm_4 = torch.ops.aten.mm.default(view_27, permute_8); permute_8 = None + view_28 = torch.ops.aten.view.default(mm_4, [2, 8192, 14336]) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_28, torch.float32); view_28 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); convert_element_type_26 = sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16) + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 64, '0'); convert_element_type_28 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_8, [1, 0]); wait_tensor_8 = None + mm_5 = torch.ops.aten.mm.default(view_27, permute_9); view_27 = permute_9 = None + view_31 = torch.ops.aten.view.default(mm_5, [2, 8192, 14336]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_31); convert_element_type_27 = view_31 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 64, '0'); convert_element_type_31 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_9, [1, 0]); wait_tensor_9 = None + view_33 = torch.ops.aten.view.default(mul_7, [16384, 14336]); mul_7 = None + mm_6 = torch.ops.aten.mm.default(view_33, permute_10); view_33 = permute_10 = None + view_34 = torch.ops.aten.view.default(mm_6, [2, 8192, 4096]); mm_6 = None + add_3 = torch.ops.aten.add.Tensor(add_1, view_34); add_1 = view_34 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16) + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 64, '0'); convert_element_type_34 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = rsqrt_2 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_10); mul_8 = wait_tensor_10 = None + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 64, '0'); convert_element_type_37 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + view_37 = torch.ops.aten.view.default(convert_element_type_36, [16384, 4096]); convert_element_type_36 = None + mm_7 = torch.ops.aten.mm.default(view_37, permute_11); permute_11 = None + view_38 = torch.ops.aten.view.default(mm_7, [2, 8192, 4096]) + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 64, '0'); convert_element_type_40 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_8 = torch.ops.aten.mm.default(view_37, permute_12); permute_12 = None + view_41 = torch.ops.aten.view.default(mm_8, [2, 8192, 1024]); mm_8 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16) + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 64, '0'); convert_element_type_43 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + mm_9 = torch.ops.aten.mm.default(view_37, permute_13); view_37 = permute_13 = None + view_44 = torch.ops.aten.view.default(mm_9, [2, 8192, 1024]) + view_45 = torch.ops.aten.view.default(view_38, [2, 8192, -1, 128]); view_38 = None + view_46 = torch.ops.aten.view.default(view_41, [2, 8192, -1, 128]); view_41 = None + view_47 = torch.ops.aten.view.default(view_44, [2, 8192, -1, 128]); view_44 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_45, torch.float32); view_45 = None + view_48 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 32, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_48); view_48 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_46, torch.float32); view_46 = None + view_49 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 8, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_49); view_49 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_16); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_51 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 32, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_16); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_52 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 8, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_51, torch.bfloat16); view_51 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_52, torch.bfloat16); view_52 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 8, 4, 128]); unsqueeze_2 = None + clone_2 = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format); expand_2 = None + view_53 = torch.ops.aten.view.default(clone_2, [2, 8192, 32, 128]); clone_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_47, 3); view_47 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 8, 4, 128]); unsqueeze_3 = None + clone_3 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None + view_54 = torch.ops.aten.view.default(clone_3, [2, 8192, 32, 128]); clone_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_53, [0, 2, 1, 3]); view_53 = None + permute_16 = torch.ops.aten.permute.default(view_54, [0, 2, 1, 3]); view_54 = None + _scaled_dot_product_cudnn_attention_1 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_14, permute_15, permute_16, None, True, 0.0, True); permute_14 = permute_15 = permute_16 = None + getitem_9 = _scaled_dot_product_cudnn_attention_1[0] + getitem_10 = _scaled_dot_product_cudnn_attention_1[1] + getitem_15 = _scaled_dot_product_cudnn_attention_1[6] + getitem_16 = _scaled_dot_product_cudnn_attention_1[7]; _scaled_dot_product_cudnn_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_9, [0, 2, 1, 3]) + view_55 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 64, '0'); convert_element_type_50 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_14, [1, 0]); wait_tensor_14 = None + view_57 = torch.ops.aten.view.default(view_55, [16384, 4096]); view_55 = None + mm_10 = torch.ops.aten.mm.default(view_57, permute_18); view_57 = permute_18 = None + view_58 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + add_5 = torch.ops.aten.add.Tensor(add_3, view_58); view_58 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 64, '0'); convert_element_type_53 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = rsqrt_3 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_15); mul_12 = wait_tensor_15 = None + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 64, '0'); convert_element_type_56 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_16, [1, 0]); wait_tensor_16 = None + view_61 = torch.ops.aten.view.default(convert_element_type_55, [16384, 4096]); convert_element_type_55 = None + mm_11 = torch.ops.aten.mm.default(view_61, permute_19); permute_19 = None + view_62 = torch.ops.aten.view.default(mm_11, [2, 8192, 14336]) + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_62, torch.float32); view_62 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); convert_element_type_59 = sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 64, '0'); convert_element_type_61 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + mm_12 = torch.ops.aten.mm.default(view_61, permute_20); view_61 = permute_20 = None + view_65 = torch.ops.aten.view.default(mm_12, [2, 8192, 14336]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_65); convert_element_type_60 = view_65 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 64, '0'); convert_element_type_64 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + view_67 = torch.ops.aten.view.default(mul_15, [16384, 14336]); mul_15 = None + mm_13 = torch.ops.aten.mm.default(view_67, permute_21); view_67 = permute_21 = None + view_68 = torch.ops.aten.view.default(mm_13, [2, 8192, 4096]); mm_13 = None + add_7 = torch.ops.aten.add.Tensor(add_5, view_68); add_5 = view_68 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 64, '0'); convert_element_type_67 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = rsqrt_4 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_19); mul_16 = wait_tensor_19 = None + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 64, '0'); convert_element_type_70 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + view_71 = torch.ops.aten.view.default(convert_element_type_69, [16384, 4096]); convert_element_type_69 = None + mm_14 = torch.ops.aten.mm.default(view_71, permute_22); permute_22 = None + view_72 = torch.ops.aten.view.default(mm_14, [2, 8192, 4096]) + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 64, '0'); convert_element_type_73 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_21, [1, 0]); wait_tensor_21 = None + mm_15 = torch.ops.aten.mm.default(view_71, permute_23); permute_23 = None + view_75 = torch.ops.aten.view.default(mm_15, [2, 8192, 1024]); mm_15 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 64, '0'); convert_element_type_76 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_22, [1, 0]); wait_tensor_22 = None + mm_16 = torch.ops.aten.mm.default(view_71, permute_24); view_71 = permute_24 = None + view_78 = torch.ops.aten.view.default(mm_16, [2, 8192, 1024]) + view_79 = torch.ops.aten.view.default(view_72, [2, 8192, -1, 128]); view_72 = None + view_80 = torch.ops.aten.view.default(view_75, [2, 8192, -1, 128]); view_75 = None + view_81 = torch.ops.aten.view.default(view_78, [2, 8192, -1, 128]); view_78 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_79, torch.float32); view_79 = None + view_82 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 32, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_82); view_82 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_80, torch.float32); view_80 = None + view_83 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 8, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_83); view_83 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_16); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_85 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 32, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_16); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_86 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 8, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_85, torch.bfloat16); view_85 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_86, torch.bfloat16); view_86 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 8, 4, 128]); unsqueeze_4 = None + clone_4 = torch.ops.aten.clone.default(expand_4, memory_format = torch.contiguous_format); expand_4 = None + view_87 = torch.ops.aten.view.default(clone_4, [2, 8192, 32, 128]); clone_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_81, 3); view_81 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 8, 4, 128]); unsqueeze_5 = None + clone_5 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format); expand_5 = None + view_88 = torch.ops.aten.view.default(clone_5, [2, 8192, 32, 128]); clone_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_87, [0, 2, 1, 3]); view_87 = None + permute_27 = torch.ops.aten.permute.default(view_88, [0, 2, 1, 3]); view_88 = None + _scaled_dot_product_cudnn_attention_2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_25, permute_26, permute_27, None, True, 0.0, True); permute_25 = permute_26 = permute_27 = None + getitem_18 = _scaled_dot_product_cudnn_attention_2[0] + getitem_19 = _scaled_dot_product_cudnn_attention_2[1] + getitem_24 = _scaled_dot_product_cudnn_attention_2[6] + getitem_25 = _scaled_dot_product_cudnn_attention_2[7]; _scaled_dot_product_cudnn_attention_2 = None + permute_28 = torch.ops.aten.permute.default(getitem_18, [0, 2, 1, 3]) + view_89 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 64, '0'); convert_element_type_83 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_23, [1, 0]); wait_tensor_23 = None + view_91 = torch.ops.aten.view.default(view_89, [16384, 4096]); view_89 = None + mm_17 = torch.ops.aten.mm.default(view_91, permute_29); view_91 = permute_29 = None + view_92 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + add_9 = torch.ops.aten.add.Tensor(add_7, view_92); view_92 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 64, '0'); convert_element_type_86 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = rsqrt_5 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_24); mul_20 = wait_tensor_24 = None + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 64, '0'); convert_element_type_89 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + view_95 = torch.ops.aten.view.default(convert_element_type_88, [16384, 4096]); convert_element_type_88 = None + mm_18 = torch.ops.aten.mm.default(view_95, permute_30); permute_30 = None + view_96 = torch.ops.aten.view.default(mm_18, [2, 8192, 14336]) + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_96, torch.float32); view_96 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); convert_element_type_92 = sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 64, '0'); convert_element_type_94 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + mm_19 = torch.ops.aten.mm.default(view_95, permute_31); view_95 = permute_31 = None + view_99 = torch.ops.aten.view.default(mm_19, [2, 8192, 14336]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_99); convert_element_type_93 = view_99 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 64, '0'); convert_element_type_97 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_27, [1, 0]); wait_tensor_27 = None + view_101 = torch.ops.aten.view.default(mul_23, [16384, 14336]); mul_23 = None + mm_20 = torch.ops.aten.mm.default(view_101, permute_32); view_101 = permute_32 = None + view_102 = torch.ops.aten.view.default(mm_20, [2, 8192, 4096]); mm_20 = None + add_11 = torch.ops.aten.add.Tensor(add_9, view_102); add_9 = view_102 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 64, '0'); convert_element_type_100 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = rsqrt_6 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_28); mul_24 = wait_tensor_28 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 64, '0'); convert_element_type_103 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_29, [1, 0]); wait_tensor_29 = None + view_105 = torch.ops.aten.view.default(convert_element_type_102, [16384, 4096]); convert_element_type_102 = None + mm_21 = torch.ops.aten.mm.default(view_105, permute_33); permute_33 = None + view_106 = torch.ops.aten.view.default(mm_21, [2, 8192, 4096]) + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 64, '0'); convert_element_type_106 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + mm_22 = torch.ops.aten.mm.default(view_105, permute_34); permute_34 = None + view_109 = torch.ops.aten.view.default(mm_22, [2, 8192, 1024]); mm_22 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 64, '0'); convert_element_type_109 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_23 = torch.ops.aten.mm.default(view_105, permute_35); view_105 = permute_35 = None + view_112 = torch.ops.aten.view.default(mm_23, [2, 8192, 1024]) + view_113 = torch.ops.aten.view.default(view_106, [2, 8192, -1, 128]); view_106 = None + view_114 = torch.ops.aten.view.default(view_109, [2, 8192, -1, 128]); view_109 = None + view_115 = torch.ops.aten.view.default(view_112, [2, 8192, -1, 128]); view_112 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_113, torch.float32); view_113 = None + view_116 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 32, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_116); view_116 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_114, torch.float32); view_114 = None + view_117 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 8, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_117); view_117 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_16); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_119 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 32, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_16); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_120 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 8, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_119, torch.bfloat16); view_119 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_120, torch.bfloat16); view_120 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 8, 4, 128]); unsqueeze_6 = None + clone_6 = torch.ops.aten.clone.default(expand_6, memory_format = torch.contiguous_format); expand_6 = None + view_121 = torch.ops.aten.view.default(clone_6, [2, 8192, 32, 128]); clone_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_115, 3); view_115 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 8, 4, 128]); unsqueeze_7 = None + clone_7 = torch.ops.aten.clone.default(expand_7, memory_format = torch.contiguous_format); expand_7 = None + view_122 = torch.ops.aten.view.default(clone_7, [2, 8192, 32, 128]); clone_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_121, [0, 2, 1, 3]); view_121 = None + permute_38 = torch.ops.aten.permute.default(view_122, [0, 2, 1, 3]); view_122 = None + _scaled_dot_product_cudnn_attention_3 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_36, permute_37, permute_38, None, True, 0.0, True); permute_36 = permute_37 = permute_38 = None + getitem_27 = _scaled_dot_product_cudnn_attention_3[0] + getitem_28 = _scaled_dot_product_cudnn_attention_3[1] + getitem_33 = _scaled_dot_product_cudnn_attention_3[6] + getitem_34 = _scaled_dot_product_cudnn_attention_3[7]; _scaled_dot_product_cudnn_attention_3 = None + permute_39 = torch.ops.aten.permute.default(getitem_27, [0, 2, 1, 3]) + view_123 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 64, '0'); convert_element_type_116 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + view_125 = torch.ops.aten.view.default(view_123, [16384, 4096]); view_123 = None + mm_24 = torch.ops.aten.mm.default(view_125, permute_40); view_125 = permute_40 = None + view_126 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + add_13 = torch.ops.aten.add.Tensor(add_11, view_126); view_126 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 64, '0'); convert_element_type_119 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = rsqrt_7 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_33); mul_28 = wait_tensor_33 = None + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 64, '0'); convert_element_type_122 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_34, [1, 0]); wait_tensor_34 = None + view_129 = torch.ops.aten.view.default(convert_element_type_121, [16384, 4096]); convert_element_type_121 = None + mm_25 = torch.ops.aten.mm.default(view_129, permute_41); permute_41 = None + view_130 = torch.ops.aten.view.default(mm_25, [2, 8192, 14336]) + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); convert_element_type_125 = sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 64, '0'); convert_element_type_127 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_35, [1, 0]); wait_tensor_35 = None + mm_26 = torch.ops.aten.mm.default(view_129, permute_42); view_129 = permute_42 = None + view_133 = torch.ops.aten.view.default(mm_26, [2, 8192, 14336]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_133); convert_element_type_126 = view_133 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 64, '0'); convert_element_type_130 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_36, [1, 0]); wait_tensor_36 = None + view_135 = torch.ops.aten.view.default(mul_31, [16384, 14336]); mul_31 = None + mm_27 = torch.ops.aten.mm.default(view_135, permute_43); view_135 = permute_43 = None + view_136 = torch.ops.aten.view.default(mm_27, [2, 8192, 4096]); mm_27 = None + add_15 = torch.ops.aten.add.Tensor(add_13, view_136); add_13 = view_136 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 64, '0'); convert_element_type_133 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = rsqrt_8 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_37); mul_32 = wait_tensor_37 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 64, '0'); convert_element_type_136 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + view_139 = torch.ops.aten.view.default(convert_element_type_135, [16384, 4096]); convert_element_type_135 = None + mm_28 = torch.ops.aten.mm.default(view_139, permute_44); permute_44 = None + view_140 = torch.ops.aten.view.default(mm_28, [2, 8192, 4096]) + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 64, '0'); convert_element_type_139 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + mm_29 = torch.ops.aten.mm.default(view_139, permute_45); permute_45 = None + view_143 = torch.ops.aten.view.default(mm_29, [2, 8192, 1024]); mm_29 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 64, '0'); convert_element_type_142 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_40, [1, 0]); wait_tensor_40 = None + mm_30 = torch.ops.aten.mm.default(view_139, permute_46); view_139 = permute_46 = None + view_146 = torch.ops.aten.view.default(mm_30, [2, 8192, 1024]) + view_147 = torch.ops.aten.view.default(view_140, [2, 8192, -1, 128]); view_140 = None + view_148 = torch.ops.aten.view.default(view_143, [2, 8192, -1, 128]); view_143 = None + view_149 = torch.ops.aten.view.default(view_146, [2, 8192, -1, 128]); view_146 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_147, torch.float32); view_147 = None + view_150 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 32, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_150); view_150 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None + view_151 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 8, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_151); view_151 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_16); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_153 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 32, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_16); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_154 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 8, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_153, torch.bfloat16); view_153 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 8, 4, 128]); unsqueeze_8 = None + clone_8 = torch.ops.aten.clone.default(expand_8, memory_format = torch.contiguous_format); expand_8 = None + view_155 = torch.ops.aten.view.default(clone_8, [2, 8192, 32, 128]); clone_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_149, 3); view_149 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 8, 4, 128]); unsqueeze_9 = None + clone_9 = torch.ops.aten.clone.default(expand_9, memory_format = torch.contiguous_format); expand_9 = None + view_156 = torch.ops.aten.view.default(clone_9, [2, 8192, 32, 128]); clone_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_155, [0, 2, 1, 3]); view_155 = None + permute_49 = torch.ops.aten.permute.default(view_156, [0, 2, 1, 3]); view_156 = None + _scaled_dot_product_cudnn_attention_4 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_47, permute_48, permute_49, None, True, 0.0, True); permute_47 = permute_48 = permute_49 = None + getitem_36 = _scaled_dot_product_cudnn_attention_4[0] + getitem_37 = _scaled_dot_product_cudnn_attention_4[1] + getitem_42 = _scaled_dot_product_cudnn_attention_4[6] + getitem_43 = _scaled_dot_product_cudnn_attention_4[7]; _scaled_dot_product_cudnn_attention_4 = None + permute_50 = torch.ops.aten.permute.default(getitem_36, [0, 2, 1, 3]) + view_157 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 64, '0'); convert_element_type_149 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_41, [1, 0]); wait_tensor_41 = None + view_159 = torch.ops.aten.view.default(view_157, [16384, 4096]); view_157 = None + mm_31 = torch.ops.aten.mm.default(view_159, permute_51); view_159 = permute_51 = None + view_160 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + add_17 = torch.ops.aten.add.Tensor(add_15, view_160); view_160 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 64, '0'); convert_element_type_152 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = rsqrt_9 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_42); mul_36 = wait_tensor_42 = None + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 64, '0'); convert_element_type_155 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + view_163 = torch.ops.aten.view.default(convert_element_type_154, [16384, 4096]); convert_element_type_154 = None + mm_32 = torch.ops.aten.mm.default(view_163, permute_52); permute_52 = None + view_164 = torch.ops.aten.view.default(mm_32, [2, 8192, 14336]) + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_164, torch.float32); view_164 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); convert_element_type_158 = sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 64, '0'); convert_element_type_160 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_33 = torch.ops.aten.mm.default(view_163, permute_53); view_163 = permute_53 = None + view_167 = torch.ops.aten.view.default(mm_33, [2, 8192, 14336]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_167); convert_element_type_159 = view_167 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 64, '0'); convert_element_type_163 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + view_169 = torch.ops.aten.view.default(mul_39, [16384, 14336]); mul_39 = None + mm_34 = torch.ops.aten.mm.default(view_169, permute_54); view_169 = permute_54 = None + view_170 = torch.ops.aten.view.default(mm_34, [2, 8192, 4096]); mm_34 = None + add_19 = torch.ops.aten.add.Tensor(add_17, view_170); add_17 = view_170 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 64, '0'); convert_element_type_166 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = rsqrt_10 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_46); mul_40 = wait_tensor_46 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 64, '0'); convert_element_type_169 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_47, [1, 0]); wait_tensor_47 = None + view_173 = torch.ops.aten.view.default(convert_element_type_168, [16384, 4096]); convert_element_type_168 = None + mm_35 = torch.ops.aten.mm.default(view_173, permute_55); permute_55 = None + view_174 = torch.ops.aten.view.default(mm_35, [2, 8192, 4096]) + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 64, '0'); convert_element_type_172 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_48, [1, 0]); wait_tensor_48 = None + mm_36 = torch.ops.aten.mm.default(view_173, permute_56); permute_56 = None + view_177 = torch.ops.aten.view.default(mm_36, [2, 8192, 1024]); mm_36 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 64, '0'); convert_element_type_175 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_49, [1, 0]); wait_tensor_49 = None + mm_37 = torch.ops.aten.mm.default(view_173, permute_57); view_173 = permute_57 = None + view_180 = torch.ops.aten.view.default(mm_37, [2, 8192, 1024]) + view_181 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + view_182 = torch.ops.aten.view.default(view_177, [2, 8192, -1, 128]); view_177 = None + view_183 = torch.ops.aten.view.default(view_180, [2, 8192, -1, 128]); view_180 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_181, torch.float32); view_181 = None + view_184 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 32, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_184); view_184 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_182, torch.float32); view_182 = None + view_185 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 8, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_185); view_185 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_16); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_187 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 32, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_16); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_188 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 8, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_187, torch.bfloat16); view_187 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_188, torch.bfloat16); view_188 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 8, 4, 128]); unsqueeze_10 = None + clone_10 = torch.ops.aten.clone.default(expand_10, memory_format = torch.contiguous_format); expand_10 = None + view_189 = torch.ops.aten.view.default(clone_10, [2, 8192, 32, 128]); clone_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_183, 3); view_183 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 8, 4, 128]); unsqueeze_11 = None + clone_11 = torch.ops.aten.clone.default(expand_11, memory_format = torch.contiguous_format); expand_11 = None + view_190 = torch.ops.aten.view.default(clone_11, [2, 8192, 32, 128]); clone_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_189, [0, 2, 1, 3]); view_189 = None + permute_60 = torch.ops.aten.permute.default(view_190, [0, 2, 1, 3]); view_190 = None + _scaled_dot_product_cudnn_attention_5 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_58, permute_59, permute_60, None, True, 0.0, True); permute_58 = permute_59 = permute_60 = None + getitem_45 = _scaled_dot_product_cudnn_attention_5[0] + getitem_46 = _scaled_dot_product_cudnn_attention_5[1] + getitem_51 = _scaled_dot_product_cudnn_attention_5[6] + getitem_52 = _scaled_dot_product_cudnn_attention_5[7]; _scaled_dot_product_cudnn_attention_5 = None + permute_61 = torch.ops.aten.permute.default(getitem_45, [0, 2, 1, 3]) + view_191 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 64, '0'); convert_element_type_182 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_193 = torch.ops.aten.view.default(view_191, [16384, 4096]); view_191 = None + mm_38 = torch.ops.aten.mm.default(view_193, permute_62); view_193 = permute_62 = None + view_194 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + add_21 = torch.ops.aten.add.Tensor(add_19, view_194); view_194 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 64, '0'); convert_element_type_185 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = rsqrt_11 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_51); mul_44 = wait_tensor_51 = None + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 64, '0'); convert_element_type_188 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + view_197 = torch.ops.aten.view.default(convert_element_type_187, [16384, 4096]); convert_element_type_187 = None + mm_39 = torch.ops.aten.mm.default(view_197, permute_63); permute_63 = None + view_198 = torch.ops.aten.view.default(mm_39, [2, 8192, 14336]) + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_198, torch.float32); view_198 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); convert_element_type_191 = sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 64, '0'); convert_element_type_193 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_53, [1, 0]); wait_tensor_53 = None + mm_40 = torch.ops.aten.mm.default(view_197, permute_64); view_197 = permute_64 = None + view_201 = torch.ops.aten.view.default(mm_40, [2, 8192, 14336]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_201); convert_element_type_192 = view_201 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 64, '0'); convert_element_type_196 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_54, [1, 0]); wait_tensor_54 = None + view_203 = torch.ops.aten.view.default(mul_47, [16384, 14336]); mul_47 = None + mm_41 = torch.ops.aten.mm.default(view_203, permute_65); view_203 = permute_65 = None + view_204 = torch.ops.aten.view.default(mm_41, [2, 8192, 4096]); mm_41 = None + add_23 = torch.ops.aten.add.Tensor(add_21, view_204); add_21 = view_204 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 64, '0'); convert_element_type_199 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32) + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = rsqrt_12 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_55); mul_48 = wait_tensor_55 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 64, '0'); convert_element_type_202 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + view_207 = torch.ops.aten.view.default(convert_element_type_201, [16384, 4096]); convert_element_type_201 = None + mm_42 = torch.ops.aten.mm.default(view_207, permute_66); permute_66 = None + view_208 = torch.ops.aten.view.default(mm_42, [2, 8192, 4096]) + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 64, '0'); convert_element_type_205 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_43 = torch.ops.aten.mm.default(view_207, permute_67); permute_67 = None + view_211 = torch.ops.aten.view.default(mm_43, [2, 8192, 1024]); mm_43 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 64, '0'); convert_element_type_208 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + mm_44 = torch.ops.aten.mm.default(view_207, permute_68); view_207 = permute_68 = None + view_214 = torch.ops.aten.view.default(mm_44, [2, 8192, 1024]) + view_215 = torch.ops.aten.view.default(view_208, [2, 8192, -1, 128]); view_208 = None + view_216 = torch.ops.aten.view.default(view_211, [2, 8192, -1, 128]); view_211 = None + view_217 = torch.ops.aten.view.default(view_214, [2, 8192, -1, 128]); view_214 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_215, torch.float32); view_215 = None + view_218 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 32, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_218); view_218 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_216, torch.float32); view_216 = None + view_219 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 8, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_219); view_219 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_16); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_221 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 32, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_16); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_222 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 8, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_221, torch.bfloat16); view_221 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_222, torch.bfloat16); view_222 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 8, 4, 128]); unsqueeze_12 = None + clone_12 = torch.ops.aten.clone.default(expand_12, memory_format = torch.contiguous_format); expand_12 = None + view_223 = torch.ops.aten.view.default(clone_12, [2, 8192, 32, 128]); clone_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_217, 3); view_217 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 8, 4, 128]); unsqueeze_13 = None + clone_13 = torch.ops.aten.clone.default(expand_13, memory_format = torch.contiguous_format); expand_13 = None + view_224 = torch.ops.aten.view.default(clone_13, [2, 8192, 32, 128]); clone_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_223, [0, 2, 1, 3]); view_223 = None + permute_71 = torch.ops.aten.permute.default(view_224, [0, 2, 1, 3]); view_224 = None + _scaled_dot_product_cudnn_attention_6 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_69, permute_70, permute_71, None, True, 0.0, True); permute_69 = permute_70 = permute_71 = None + getitem_54 = _scaled_dot_product_cudnn_attention_6[0] + getitem_55 = _scaled_dot_product_cudnn_attention_6[1] + getitem_60 = _scaled_dot_product_cudnn_attention_6[6] + getitem_61 = _scaled_dot_product_cudnn_attention_6[7]; _scaled_dot_product_cudnn_attention_6 = None + permute_72 = torch.ops.aten.permute.default(getitem_54, [0, 2, 1, 3]) + view_225 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16) + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 64, '0'); convert_element_type_215 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_227 = torch.ops.aten.view.default(view_225, [16384, 4096]); view_225 = None + mm_45 = torch.ops.aten.mm.default(view_227, permute_73); view_227 = permute_73 = None + view_228 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + add_25 = torch.ops.aten.add.Tensor(add_23, view_228); view_228 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 64, '0'); convert_element_type_218 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = rsqrt_13 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_60); mul_52 = wait_tensor_60 = None + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 64, '0'); convert_element_type_221 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_61, [1, 0]); wait_tensor_61 = None + view_231 = torch.ops.aten.view.default(convert_element_type_220, [16384, 4096]); convert_element_type_220 = None + mm_46 = torch.ops.aten.mm.default(view_231, permute_74); permute_74 = None + view_232 = torch.ops.aten.view.default(mm_46, [2, 8192, 14336]) + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_232, torch.float32); view_232 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); convert_element_type_224 = sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 64, '0'); convert_element_type_226 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_62, [1, 0]); wait_tensor_62 = None + mm_47 = torch.ops.aten.mm.default(view_231, permute_75); view_231 = permute_75 = None + view_235 = torch.ops.aten.view.default(mm_47, [2, 8192, 14336]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_235); convert_element_type_225 = view_235 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 64, '0'); convert_element_type_229 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + view_237 = torch.ops.aten.view.default(mul_55, [16384, 14336]); mul_55 = None + mm_48 = torch.ops.aten.mm.default(view_237, permute_76); view_237 = permute_76 = None + view_238 = torch.ops.aten.view.default(mm_48, [2, 8192, 4096]); mm_48 = None + add_27 = torch.ops.aten.add.Tensor(add_25, view_238); add_25 = view_238 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 64, '0'); convert_element_type_232 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = rsqrt_14 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_64); mul_56 = wait_tensor_64 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 64, '0'); convert_element_type_235 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + view_241 = torch.ops.aten.view.default(convert_element_type_234, [16384, 4096]); convert_element_type_234 = None + mm_49 = torch.ops.aten.mm.default(view_241, permute_77); permute_77 = None + view_242 = torch.ops.aten.view.default(mm_49, [2, 8192, 4096]) + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 64, '0'); convert_element_type_238 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_66, [1, 0]); wait_tensor_66 = None + mm_50 = torch.ops.aten.mm.default(view_241, permute_78); permute_78 = None + view_245 = torch.ops.aten.view.default(mm_50, [2, 8192, 1024]); mm_50 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 64, '0'); convert_element_type_241 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_67, [1, 0]); wait_tensor_67 = None + mm_51 = torch.ops.aten.mm.default(view_241, permute_79); view_241 = permute_79 = None + view_248 = torch.ops.aten.view.default(mm_51, [2, 8192, 1024]) + view_249 = torch.ops.aten.view.default(view_242, [2, 8192, -1, 128]); view_242 = None + view_250 = torch.ops.aten.view.default(view_245, [2, 8192, -1, 128]); view_245 = None + view_251 = torch.ops.aten.view.default(view_248, [2, 8192, -1, 128]); view_248 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 32, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_250, torch.float32); view_250 = None + view_253 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 8, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_253); view_253 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_16); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_255 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 32, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_16); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_256 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 8, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_256, torch.bfloat16); view_256 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 8, 4, 128]); unsqueeze_14 = None + clone_14 = torch.ops.aten.clone.default(expand_14, memory_format = torch.contiguous_format); expand_14 = None + view_257 = torch.ops.aten.view.default(clone_14, [2, 8192, 32, 128]); clone_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_251, 3); view_251 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 8, 4, 128]); unsqueeze_15 = None + clone_15 = torch.ops.aten.clone.default(expand_15, memory_format = torch.contiguous_format); expand_15 = None + view_258 = torch.ops.aten.view.default(clone_15, [2, 8192, 32, 128]); clone_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + permute_82 = torch.ops.aten.permute.default(view_258, [0, 2, 1, 3]); view_258 = None + _scaled_dot_product_cudnn_attention_7 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_80, permute_81, permute_82, None, True, 0.0, True); permute_80 = permute_81 = permute_82 = None + getitem_63 = _scaled_dot_product_cudnn_attention_7[0] + getitem_64 = _scaled_dot_product_cudnn_attention_7[1] + getitem_69 = _scaled_dot_product_cudnn_attention_7[6] + getitem_70 = _scaled_dot_product_cudnn_attention_7[7]; _scaled_dot_product_cudnn_attention_7 = None + permute_83 = torch.ops.aten.permute.default(getitem_63, [0, 2, 1, 3]) + view_259 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 64, '0'); convert_element_type_248 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_68, [1, 0]); wait_tensor_68 = None + view_261 = torch.ops.aten.view.default(view_259, [16384, 4096]); view_259 = None + mm_52 = torch.ops.aten.mm.default(view_261, permute_84); view_261 = permute_84 = None + view_262 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + add_29 = torch.ops.aten.add.Tensor(add_27, view_262); view_262 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 64, '0'); convert_element_type_251 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = rsqrt_15 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_69); mul_60 = wait_tensor_69 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 64, '0'); convert_element_type_254 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + view_265 = torch.ops.aten.view.default(convert_element_type_253, [16384, 4096]); convert_element_type_253 = None + mm_53 = torch.ops.aten.mm.default(view_265, permute_85); permute_85 = None + view_266 = torch.ops.aten.view.default(mm_53, [2, 8192, 14336]) + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_266, torch.float32); view_266 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); convert_element_type_257 = sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 64, '0'); convert_element_type_259 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_54 = torch.ops.aten.mm.default(view_265, permute_86); view_265 = permute_86 = None + view_269 = torch.ops.aten.view.default(mm_54, [2, 8192, 14336]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_269); convert_element_type_258 = view_269 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 64, '0'); convert_element_type_262 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + view_271 = torch.ops.aten.view.default(mul_63, [16384, 14336]); mul_63 = None + mm_55 = torch.ops.aten.mm.default(view_271, permute_87); view_271 = permute_87 = None + view_272 = torch.ops.aten.view.default(mm_55, [2, 8192, 4096]); mm_55 = None + add_31 = torch.ops.aten.add.Tensor(add_29, view_272); add_29 = view_272 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 64, '0'); convert_element_type_265 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = rsqrt_16 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_73); mul_64 = wait_tensor_73 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 64, '0'); convert_element_type_268 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_74, [1, 0]); wait_tensor_74 = None + view_275 = torch.ops.aten.view.default(convert_element_type_267, [16384, 4096]); convert_element_type_267 = None + mm_56 = torch.ops.aten.mm.default(view_275, permute_88); permute_88 = None + view_276 = torch.ops.aten.view.default(mm_56, [2, 8192, 4096]) + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16) + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 64, '0'); convert_element_type_271 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_75, [1, 0]); wait_tensor_75 = None + mm_57 = torch.ops.aten.mm.default(view_275, permute_89); permute_89 = None + view_279 = torch.ops.aten.view.default(mm_57, [2, 8192, 1024]); mm_57 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 64, '0'); convert_element_type_274 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + mm_58 = torch.ops.aten.mm.default(view_275, permute_90); view_275 = permute_90 = None + view_282 = torch.ops.aten.view.default(mm_58, [2, 8192, 1024]) + view_283 = torch.ops.aten.view.default(view_276, [2, 8192, -1, 128]); view_276 = None + view_284 = torch.ops.aten.view.default(view_279, [2, 8192, -1, 128]); view_279 = None + view_285 = torch.ops.aten.view.default(view_282, [2, 8192, -1, 128]); view_282 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_283, torch.float32); view_283 = None + view_286 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 32, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_286); view_286 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_284, torch.float32); view_284 = None + view_287 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 8, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_287); view_287 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_16); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_289 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 32, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_16); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_290 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 8, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_289, torch.bfloat16); view_289 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_290, torch.bfloat16); view_290 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 8, 4, 128]); unsqueeze_16 = None + clone_16 = torch.ops.aten.clone.default(expand_16, memory_format = torch.contiguous_format); expand_16 = None + view_291 = torch.ops.aten.view.default(clone_16, [2, 8192, 32, 128]); clone_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_285, 3); view_285 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 8, 4, 128]); unsqueeze_17 = None + clone_17 = torch.ops.aten.clone.default(expand_17, memory_format = torch.contiguous_format); expand_17 = None + view_292 = torch.ops.aten.view.default(clone_17, [2, 8192, 32, 128]); clone_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_291, [0, 2, 1, 3]); view_291 = None + permute_93 = torch.ops.aten.permute.default(view_292, [0, 2, 1, 3]); view_292 = None + _scaled_dot_product_cudnn_attention_8 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_91, permute_92, permute_93, None, True, 0.0, True); permute_91 = permute_92 = permute_93 = None + getitem_72 = _scaled_dot_product_cudnn_attention_8[0] + getitem_73 = _scaled_dot_product_cudnn_attention_8[1] + getitem_78 = _scaled_dot_product_cudnn_attention_8[6] + getitem_79 = _scaled_dot_product_cudnn_attention_8[7]; _scaled_dot_product_cudnn_attention_8 = None + permute_94 = torch.ops.aten.permute.default(getitem_72, [0, 2, 1, 3]) + view_293 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 64, '0'); convert_element_type_281 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + view_295 = torch.ops.aten.view.default(view_293, [16384, 4096]); view_293 = None + mm_59 = torch.ops.aten.mm.default(view_295, permute_95); view_295 = permute_95 = None + view_296 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + add_33 = torch.ops.aten.add.Tensor(add_31, view_296); view_296 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 64, '0'); convert_element_type_284 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = rsqrt_17 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_78); mul_68 = wait_tensor_78 = None + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 64, '0'); convert_element_type_287 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_79, [1, 0]); wait_tensor_79 = None + view_299 = torch.ops.aten.view.default(convert_element_type_286, [16384, 4096]); convert_element_type_286 = None + mm_60 = torch.ops.aten.mm.default(view_299, permute_96); permute_96 = None + view_300 = torch.ops.aten.view.default(mm_60, [2, 8192, 14336]) + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); convert_element_type_290 = sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 64, '0'); convert_element_type_292 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_80, [1, 0]); wait_tensor_80 = None + mm_61 = torch.ops.aten.mm.default(view_299, permute_97); view_299 = permute_97 = None + view_303 = torch.ops.aten.view.default(mm_61, [2, 8192, 14336]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_303); convert_element_type_291 = view_303 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 64, '0'); convert_element_type_295 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_81, [1, 0]); wait_tensor_81 = None + view_305 = torch.ops.aten.view.default(mul_71, [16384, 14336]); mul_71 = None + mm_62 = torch.ops.aten.mm.default(view_305, permute_98); view_305 = permute_98 = None + view_306 = torch.ops.aten.view.default(mm_62, [2, 8192, 4096]); mm_62 = None + add_35 = torch.ops.aten.add.Tensor(add_33, view_306); add_33 = view_306 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 64, '0'); convert_element_type_298 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = rsqrt_18 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_82); mul_72 = wait_tensor_82 = None + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 64, '0'); convert_element_type_301 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + view_309 = torch.ops.aten.view.default(convert_element_type_300, [16384, 4096]); convert_element_type_300 = None + mm_63 = torch.ops.aten.mm.default(view_309, permute_99); permute_99 = None + view_310 = torch.ops.aten.view.default(mm_63, [2, 8192, 4096]) + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 64, '0'); convert_element_type_304 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_64 = torch.ops.aten.mm.default(view_309, permute_100); permute_100 = None + view_313 = torch.ops.aten.view.default(mm_64, [2, 8192, 1024]); mm_64 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 64, '0'); convert_element_type_307 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + mm_65 = torch.ops.aten.mm.default(view_309, permute_101); view_309 = permute_101 = None + view_316 = torch.ops.aten.view.default(mm_65, [2, 8192, 1024]) + view_317 = torch.ops.aten.view.default(view_310, [2, 8192, -1, 128]); view_310 = None + view_318 = torch.ops.aten.view.default(view_313, [2, 8192, -1, 128]); view_313 = None + view_319 = torch.ops.aten.view.default(view_316, [2, 8192, -1, 128]); view_316 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_317, torch.float32); view_317 = None + view_320 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 32, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_320); view_320 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_318, torch.float32); view_318 = None + view_321 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 8, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_321); view_321 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_16); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_323 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 32, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_16); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_324 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 8, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_323, torch.bfloat16); view_323 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_324, torch.bfloat16); view_324 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 8, 4, 128]); unsqueeze_18 = None + clone_18 = torch.ops.aten.clone.default(expand_18, memory_format = torch.contiguous_format); expand_18 = None + view_325 = torch.ops.aten.view.default(clone_18, [2, 8192, 32, 128]); clone_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_319, 3); view_319 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 8, 4, 128]); unsqueeze_19 = None + clone_19 = torch.ops.aten.clone.default(expand_19, memory_format = torch.contiguous_format); expand_19 = None + view_326 = torch.ops.aten.view.default(clone_19, [2, 8192, 32, 128]); clone_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_325, [0, 2, 1, 3]); view_325 = None + permute_104 = torch.ops.aten.permute.default(view_326, [0, 2, 1, 3]); view_326 = None + _scaled_dot_product_cudnn_attention_9 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_102, permute_103, permute_104, None, True, 0.0, True); permute_102 = permute_103 = permute_104 = None + getitem_81 = _scaled_dot_product_cudnn_attention_9[0] + getitem_82 = _scaled_dot_product_cudnn_attention_9[1] + getitem_87 = _scaled_dot_product_cudnn_attention_9[6] + getitem_88 = _scaled_dot_product_cudnn_attention_9[7]; _scaled_dot_product_cudnn_attention_9 = None + permute_105 = torch.ops.aten.permute.default(getitem_81, [0, 2, 1, 3]) + view_327 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 64, '0'); convert_element_type_314 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_86, [1, 0]); wait_tensor_86 = None + view_329 = torch.ops.aten.view.default(view_327, [16384, 4096]); view_327 = None + mm_66 = torch.ops.aten.mm.default(view_329, permute_106); view_329 = permute_106 = None + view_330 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + add_37 = torch.ops.aten.add.Tensor(add_35, view_330); view_330 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 64, '0'); convert_element_type_317 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = rsqrt_19 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_87); mul_76 = wait_tensor_87 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 64, '0'); convert_element_type_320 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_88, [1, 0]); wait_tensor_88 = None + view_333 = torch.ops.aten.view.default(convert_element_type_319, [16384, 4096]); convert_element_type_319 = None + mm_67 = torch.ops.aten.mm.default(view_333, permute_107); permute_107 = None + view_334 = torch.ops.aten.view.default(mm_67, [2, 8192, 14336]) + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_334, torch.float32); view_334 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); convert_element_type_323 = sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 64, '0'); convert_element_type_325 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + mm_68 = torch.ops.aten.mm.default(view_333, permute_108); view_333 = permute_108 = None + view_337 = torch.ops.aten.view.default(mm_68, [2, 8192, 14336]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_337); convert_element_type_324 = view_337 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 64, '0'); convert_element_type_328 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + view_339 = torch.ops.aten.view.default(mul_79, [16384, 14336]); mul_79 = None + mm_69 = torch.ops.aten.mm.default(view_339, permute_109); view_339 = permute_109 = None + view_340 = torch.ops.aten.view.default(mm_69, [2, 8192, 4096]); mm_69 = None + add_39 = torch.ops.aten.add.Tensor(add_37, view_340); add_37 = view_340 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16) + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 64, '0'); convert_element_type_331 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = rsqrt_20 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_91); mul_80 = wait_tensor_91 = None + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 64, '0'); convert_element_type_334 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_92, [1, 0]); wait_tensor_92 = None + view_343 = torch.ops.aten.view.default(convert_element_type_333, [16384, 4096]); convert_element_type_333 = None + mm_70 = torch.ops.aten.mm.default(view_343, permute_110); permute_110 = None + view_344 = torch.ops.aten.view.default(mm_70, [2, 8192, 4096]) + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 64, '0'); convert_element_type_337 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_93, [1, 0]); wait_tensor_93 = None + mm_71 = torch.ops.aten.mm.default(view_343, permute_111); permute_111 = None + view_347 = torch.ops.aten.view.default(mm_71, [2, 8192, 1024]); mm_71 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 64, '0'); convert_element_type_340 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_94, [1, 0]); wait_tensor_94 = None + mm_72 = torch.ops.aten.mm.default(view_343, permute_112); view_343 = permute_112 = None + view_350 = torch.ops.aten.view.default(mm_72, [2, 8192, 1024]) + view_351 = torch.ops.aten.view.default(view_344, [2, 8192, -1, 128]); view_344 = None + view_352 = torch.ops.aten.view.default(view_347, [2, 8192, -1, 128]); view_347 = None + view_353 = torch.ops.aten.view.default(view_350, [2, 8192, -1, 128]); view_350 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_351, torch.float32); view_351 = None + view_354 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 32, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_354); view_354 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_352, torch.float32); view_352 = None + view_355 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 8, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_355); view_355 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_16); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_357 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 32, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_16); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_358 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 8, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_357, torch.bfloat16); view_357 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_358, torch.bfloat16); view_358 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 8, 4, 128]); unsqueeze_20 = None + clone_20 = torch.ops.aten.clone.default(expand_20, memory_format = torch.contiguous_format); expand_20 = None + view_359 = torch.ops.aten.view.default(clone_20, [2, 8192, 32, 128]); clone_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_353, 3); view_353 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 8, 4, 128]); unsqueeze_21 = None + clone_21 = torch.ops.aten.clone.default(expand_21, memory_format = torch.contiguous_format); expand_21 = None + view_360 = torch.ops.aten.view.default(clone_21, [2, 8192, 32, 128]); clone_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_359, [0, 2, 1, 3]); view_359 = None + permute_115 = torch.ops.aten.permute.default(view_360, [0, 2, 1, 3]); view_360 = None + _scaled_dot_product_cudnn_attention_10 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_113, permute_114, permute_115, None, True, 0.0, True); permute_113 = permute_114 = permute_115 = None + getitem_90 = _scaled_dot_product_cudnn_attention_10[0] + getitem_91 = _scaled_dot_product_cudnn_attention_10[1] + getitem_96 = _scaled_dot_product_cudnn_attention_10[6] + getitem_97 = _scaled_dot_product_cudnn_attention_10[7]; _scaled_dot_product_cudnn_attention_10 = None + permute_116 = torch.ops.aten.permute.default(getitem_90, [0, 2, 1, 3]) + view_361 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 64, '0'); convert_element_type_347 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_363 = torch.ops.aten.view.default(view_361, [16384, 4096]); view_361 = None + mm_73 = torch.ops.aten.mm.default(view_363, permute_117); view_363 = permute_117 = None + view_364 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + add_41 = torch.ops.aten.add.Tensor(add_39, view_364); view_364 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 64, '0'); convert_element_type_350 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = rsqrt_21 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_96); mul_84 = wait_tensor_96 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 64, '0'); convert_element_type_353 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + view_367 = torch.ops.aten.view.default(convert_element_type_352, [16384, 4096]); convert_element_type_352 = None + mm_74 = torch.ops.aten.mm.default(view_367, permute_118); permute_118 = None + view_368 = torch.ops.aten.view.default(mm_74, [2, 8192, 14336]) + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_368, torch.float32); view_368 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); convert_element_type_356 = sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 64, '0'); convert_element_type_358 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + mm_75 = torch.ops.aten.mm.default(view_367, permute_119); view_367 = permute_119 = None + view_371 = torch.ops.aten.view.default(mm_75, [2, 8192, 14336]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_371); convert_element_type_357 = view_371 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 64, '0'); convert_element_type_361 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_99, [1, 0]); wait_tensor_99 = None + view_373 = torch.ops.aten.view.default(mul_87, [16384, 14336]); mul_87 = None + mm_76 = torch.ops.aten.mm.default(view_373, permute_120); view_373 = permute_120 = None + view_374 = torch.ops.aten.view.default(mm_76, [2, 8192, 4096]); mm_76 = None + add_43 = torch.ops.aten.add.Tensor(add_41, view_374); add_41 = view_374 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 64, '0'); convert_element_type_364 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = rsqrt_22 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_100); mul_88 = wait_tensor_100 = None + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 64, '0'); convert_element_type_367 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_101, [1, 0]); wait_tensor_101 = None + view_377 = torch.ops.aten.view.default(convert_element_type_366, [16384, 4096]); convert_element_type_366 = None + mm_77 = torch.ops.aten.mm.default(view_377, permute_121); permute_121 = None + view_378 = torch.ops.aten.view.default(mm_77, [2, 8192, 4096]) + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 64, '0'); convert_element_type_370 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + mm_78 = torch.ops.aten.mm.default(view_377, permute_122); permute_122 = None + view_381 = torch.ops.aten.view.default(mm_78, [2, 8192, 1024]); mm_78 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 64, '0'); convert_element_type_373 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_79 = torch.ops.aten.mm.default(view_377, permute_123); view_377 = permute_123 = None + view_384 = torch.ops.aten.view.default(mm_79, [2, 8192, 1024]) + view_385 = torch.ops.aten.view.default(view_378, [2, 8192, -1, 128]); view_378 = None + view_386 = torch.ops.aten.view.default(view_381, [2, 8192, -1, 128]); view_381 = None + view_387 = torch.ops.aten.view.default(view_384, [2, 8192, -1, 128]); view_384 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_385, torch.float32); view_385 = None + view_388 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 32, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_388); view_388 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_386, torch.float32); view_386 = None + view_389 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 8, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_389); view_389 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_16); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_391 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 32, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_16); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_392 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 8, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_391, torch.bfloat16); view_391 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_392, torch.bfloat16); view_392 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 8, 4, 128]); unsqueeze_22 = None + clone_22 = torch.ops.aten.clone.default(expand_22, memory_format = torch.contiguous_format); expand_22 = None + view_393 = torch.ops.aten.view.default(clone_22, [2, 8192, 32, 128]); clone_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_387, 3); view_387 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 8, 4, 128]); unsqueeze_23 = None + clone_23 = torch.ops.aten.clone.default(expand_23, memory_format = torch.contiguous_format); expand_23 = None + view_394 = torch.ops.aten.view.default(clone_23, [2, 8192, 32, 128]); clone_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_393, [0, 2, 1, 3]); view_393 = None + permute_126 = torch.ops.aten.permute.default(view_394, [0, 2, 1, 3]); view_394 = None + _scaled_dot_product_cudnn_attention_11 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_124, permute_125, permute_126, None, True, 0.0, True); permute_124 = permute_125 = permute_126 = None + getitem_99 = _scaled_dot_product_cudnn_attention_11[0] + getitem_100 = _scaled_dot_product_cudnn_attention_11[1] + getitem_105 = _scaled_dot_product_cudnn_attention_11[6] + getitem_106 = _scaled_dot_product_cudnn_attention_11[7]; _scaled_dot_product_cudnn_attention_11 = None + permute_127 = torch.ops.aten.permute.default(getitem_99, [0, 2, 1, 3]) + view_395 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 64, '0'); convert_element_type_380 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_397 = torch.ops.aten.view.default(view_395, [16384, 4096]); view_395 = None + mm_80 = torch.ops.aten.mm.default(view_397, permute_128); view_397 = permute_128 = None + view_398 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + add_45 = torch.ops.aten.add.Tensor(add_43, view_398); view_398 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 64, '0'); convert_element_type_383 = None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = rsqrt_23 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_105); mul_92 = wait_tensor_105 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 64, '0'); convert_element_type_386 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_106, [1, 0]); wait_tensor_106 = None + view_401 = torch.ops.aten.view.default(convert_element_type_385, [16384, 4096]); convert_element_type_385 = None + mm_81 = torch.ops.aten.mm.default(view_401, permute_129); permute_129 = None + view_402 = torch.ops.aten.view.default(mm_81, [2, 8192, 14336]) + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_402, torch.float32); view_402 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); convert_element_type_389 = sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16) + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 64, '0'); convert_element_type_391 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_107, [1, 0]); wait_tensor_107 = None + mm_82 = torch.ops.aten.mm.default(view_401, permute_130); view_401 = permute_130 = None + view_405 = torch.ops.aten.view.default(mm_82, [2, 8192, 14336]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_405); convert_element_type_390 = view_405 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 64, '0'); convert_element_type_394 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + view_407 = torch.ops.aten.view.default(mul_95, [16384, 14336]); mul_95 = None + mm_83 = torch.ops.aten.mm.default(view_407, permute_131); view_407 = permute_131 = None + view_408 = torch.ops.aten.view.default(mm_83, [2, 8192, 4096]); mm_83 = None + add_47 = torch.ops.aten.add.Tensor(add_45, view_408); add_45 = view_408 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16) + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 64, '0'); convert_element_type_397 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = rsqrt_24 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_109); mul_96 = wait_tensor_109 = None + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 64, '0'); convert_element_type_400 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + view_411 = torch.ops.aten.view.default(convert_element_type_399, [16384, 4096]); convert_element_type_399 = None + mm_84 = torch.ops.aten.mm.default(view_411, permute_132); permute_132 = None + view_412 = torch.ops.aten.view.default(mm_84, [2, 8192, 4096]) + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 64, '0'); convert_element_type_403 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + mm_85 = torch.ops.aten.mm.default(view_411, permute_133); permute_133 = None + view_415 = torch.ops.aten.view.default(mm_85, [2, 8192, 1024]); mm_85 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 64, '0'); convert_element_type_406 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_112, [1, 0]); wait_tensor_112 = None + mm_86 = torch.ops.aten.mm.default(view_411, permute_134); view_411 = permute_134 = None + view_418 = torch.ops.aten.view.default(mm_86, [2, 8192, 1024]) + view_419 = torch.ops.aten.view.default(view_412, [2, 8192, -1, 128]); view_412 = None + view_420 = torch.ops.aten.view.default(view_415, [2, 8192, -1, 128]); view_415 = None + view_421 = torch.ops.aten.view.default(view_418, [2, 8192, -1, 128]); view_418 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_419, torch.float32); view_419 = None + view_422 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 32, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_422); view_422 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_420, torch.float32); view_420 = None + view_423 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 8, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_423); view_423 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_16); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_425 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 32, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_16); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_426 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 8, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_425, torch.bfloat16); view_425 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_426, torch.bfloat16); view_426 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 8, 4, 128]); unsqueeze_24 = None + clone_24 = torch.ops.aten.clone.default(expand_24, memory_format = torch.contiguous_format); expand_24 = None + view_427 = torch.ops.aten.view.default(clone_24, [2, 8192, 32, 128]); clone_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_421, 3); view_421 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 8, 4, 128]); unsqueeze_25 = None + clone_25 = torch.ops.aten.clone.default(expand_25, memory_format = torch.contiguous_format); expand_25 = None + view_428 = torch.ops.aten.view.default(clone_25, [2, 8192, 32, 128]); clone_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_427, [0, 2, 1, 3]); view_427 = None + permute_137 = torch.ops.aten.permute.default(view_428, [0, 2, 1, 3]); view_428 = None + _scaled_dot_product_cudnn_attention_12 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_135, permute_136, permute_137, None, True, 0.0, True); permute_135 = permute_136 = permute_137 = None + getitem_108 = _scaled_dot_product_cudnn_attention_12[0] + getitem_109 = _scaled_dot_product_cudnn_attention_12[1] + getitem_114 = _scaled_dot_product_cudnn_attention_12[6] + getitem_115 = _scaled_dot_product_cudnn_attention_12[7]; _scaled_dot_product_cudnn_attention_12 = None + permute_138 = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]) + view_429 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 64, '0'); convert_element_type_413 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_113, [1, 0]); wait_tensor_113 = None + view_431 = torch.ops.aten.view.default(view_429, [16384, 4096]); view_429 = None + mm_87 = torch.ops.aten.mm.default(view_431, permute_139); view_431 = permute_139 = None + view_432 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + add_49 = torch.ops.aten.add.Tensor(add_47, view_432); view_432 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 64, '0'); convert_element_type_416 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = rsqrt_25 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_114); mul_100 = wait_tensor_114 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 64, '0'); convert_element_type_419 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + view_435 = torch.ops.aten.view.default(convert_element_type_418, [16384, 4096]); convert_element_type_418 = None + mm_88 = torch.ops.aten.mm.default(view_435, permute_140); permute_140 = None + view_436 = torch.ops.aten.view.default(mm_88, [2, 8192, 14336]) + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_436, torch.float32); view_436 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); convert_element_type_422 = sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 64, '0'); convert_element_type_424 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_89 = torch.ops.aten.mm.default(view_435, permute_141); view_435 = permute_141 = None + view_439 = torch.ops.aten.view.default(mm_89, [2, 8192, 14336]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_439); convert_element_type_423 = view_439 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 64, '0'); convert_element_type_427 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_441 = torch.ops.aten.view.default(mul_103, [16384, 14336]); mul_103 = None + mm_90 = torch.ops.aten.mm.default(view_441, permute_142); view_441 = permute_142 = None + view_442 = torch.ops.aten.view.default(mm_90, [2, 8192, 4096]); mm_90 = None + add_51 = torch.ops.aten.add.Tensor(add_49, view_442); add_49 = view_442 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 64, '0'); convert_element_type_430 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = rsqrt_26 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_118); mul_104 = wait_tensor_118 = None + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 64, '0'); convert_element_type_433 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_119, [1, 0]); wait_tensor_119 = None + view_445 = torch.ops.aten.view.default(convert_element_type_432, [16384, 4096]); convert_element_type_432 = None + mm_91 = torch.ops.aten.mm.default(view_445, permute_143); permute_143 = None + view_446 = torch.ops.aten.view.default(mm_91, [2, 8192, 4096]) + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 64, '0'); convert_element_type_436 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_120, [1, 0]); wait_tensor_120 = None + mm_92 = torch.ops.aten.mm.default(view_445, permute_144); permute_144 = None + view_449 = torch.ops.aten.view.default(mm_92, [2, 8192, 1024]); mm_92 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 64, '0'); convert_element_type_439 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + mm_93 = torch.ops.aten.mm.default(view_445, permute_145); view_445 = permute_145 = None + view_452 = torch.ops.aten.view.default(mm_93, [2, 8192, 1024]) + view_453 = torch.ops.aten.view.default(view_446, [2, 8192, -1, 128]); view_446 = None + view_454 = torch.ops.aten.view.default(view_449, [2, 8192, -1, 128]); view_449 = None + view_455 = torch.ops.aten.view.default(view_452, [2, 8192, -1, 128]); view_452 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_453, torch.float32); view_453 = None + view_456 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 32, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_456); view_456 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_454, torch.float32); view_454 = None + view_457 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 8, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_457); view_457 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_16); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_459 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 32, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_16); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_460 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 8, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_459, torch.bfloat16); view_459 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_460, torch.bfloat16); view_460 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 8, 4, 128]); unsqueeze_26 = None + clone_26 = torch.ops.aten.clone.default(expand_26, memory_format = torch.contiguous_format); expand_26 = None + view_461 = torch.ops.aten.view.default(clone_26, [2, 8192, 32, 128]); clone_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_455, 3); view_455 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 8, 4, 128]); unsqueeze_27 = None + clone_27 = torch.ops.aten.clone.default(expand_27, memory_format = torch.contiguous_format); expand_27 = None + view_462 = torch.ops.aten.view.default(clone_27, [2, 8192, 32, 128]); clone_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_461, [0, 2, 1, 3]); view_461 = None + permute_148 = torch.ops.aten.permute.default(view_462, [0, 2, 1, 3]); view_462 = None + _scaled_dot_product_cudnn_attention_13 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_146, permute_147, permute_148, None, True, 0.0, True); permute_146 = permute_147 = permute_148 = None + getitem_117 = _scaled_dot_product_cudnn_attention_13[0] + getitem_118 = _scaled_dot_product_cudnn_attention_13[1] + getitem_123 = _scaled_dot_product_cudnn_attention_13[6] + getitem_124 = _scaled_dot_product_cudnn_attention_13[7]; _scaled_dot_product_cudnn_attention_13 = None + permute_149 = torch.ops.aten.permute.default(getitem_117, [0, 2, 1, 3]) + view_463 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 64, '0'); convert_element_type_446 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + view_465 = torch.ops.aten.view.default(view_463, [16384, 4096]); view_463 = None + mm_94 = torch.ops.aten.mm.default(view_465, permute_150); view_465 = permute_150 = None + view_466 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + add_53 = torch.ops.aten.add.Tensor(add_51, view_466); view_466 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16) + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 64, '0'); convert_element_type_449 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = rsqrt_27 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_123); mul_108 = wait_tensor_123 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 64, '0'); convert_element_type_452 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + view_469 = torch.ops.aten.view.default(convert_element_type_451, [16384, 4096]); convert_element_type_451 = None + mm_95 = torch.ops.aten.mm.default(view_469, permute_151); permute_151 = None + view_470 = torch.ops.aten.view.default(mm_95, [2, 8192, 14336]) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_470, torch.float32); view_470 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); convert_element_type_455 = sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16) + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 64, '0'); convert_element_type_457 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_125, [1, 0]); wait_tensor_125 = None + mm_96 = torch.ops.aten.mm.default(view_469, permute_152); view_469 = permute_152 = None + view_473 = torch.ops.aten.view.default(mm_96, [2, 8192, 14336]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_473); convert_element_type_456 = view_473 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 64, '0'); convert_element_type_460 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_126, [1, 0]); wait_tensor_126 = None + view_475 = torch.ops.aten.view.default(mul_111, [16384, 14336]); mul_111 = None + mm_97 = torch.ops.aten.mm.default(view_475, permute_153); view_475 = permute_153 = None + view_476 = torch.ops.aten.view.default(mm_97, [2, 8192, 4096]); mm_97 = None + add_55 = torch.ops.aten.add.Tensor(add_53, view_476); add_53 = view_476 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 64, '0'); convert_element_type_463 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = rsqrt_28 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_127); mul_112 = wait_tensor_127 = None + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 64, '0'); convert_element_type_466 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + view_479 = torch.ops.aten.view.default(convert_element_type_465, [16384, 4096]); convert_element_type_465 = None + mm_98 = torch.ops.aten.mm.default(view_479, permute_154); permute_154 = None + view_480 = torch.ops.aten.view.default(mm_98, [2, 8192, 4096]) + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 64, '0'); convert_element_type_469 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_99 = torch.ops.aten.mm.default(view_479, permute_155); permute_155 = None + view_483 = torch.ops.aten.view.default(mm_99, [2, 8192, 1024]); mm_99 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 64, '0'); convert_element_type_472 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + mm_100 = torch.ops.aten.mm.default(view_479, permute_156); view_479 = permute_156 = None + view_486 = torch.ops.aten.view.default(mm_100, [2, 8192, 1024]) + view_487 = torch.ops.aten.view.default(view_480, [2, 8192, -1, 128]); view_480 = None + view_488 = torch.ops.aten.view.default(view_483, [2, 8192, -1, 128]); view_483 = None + view_489 = torch.ops.aten.view.default(view_486, [2, 8192, -1, 128]); view_486 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_487, torch.float32); view_487 = None + view_490 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 32, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_490); view_490 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_488, torch.float32); view_488 = None + view_491 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 8, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_491); view_491 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_16); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_493 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 32, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_16); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_494 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 8, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_493, torch.bfloat16); view_493 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_494, torch.bfloat16); view_494 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 8, 4, 128]); unsqueeze_28 = None + clone_28 = torch.ops.aten.clone.default(expand_28, memory_format = torch.contiguous_format); expand_28 = None + view_495 = torch.ops.aten.view.default(clone_28, [2, 8192, 32, 128]); clone_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_489, 3); view_489 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 8, 4, 128]); unsqueeze_29 = None + clone_29 = torch.ops.aten.clone.default(expand_29, memory_format = torch.contiguous_format); expand_29 = None + view_496 = torch.ops.aten.view.default(clone_29, [2, 8192, 32, 128]); clone_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_495, [0, 2, 1, 3]); view_495 = None + permute_159 = torch.ops.aten.permute.default(view_496, [0, 2, 1, 3]); view_496 = None + _scaled_dot_product_cudnn_attention_14 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_157, permute_158, permute_159, None, True, 0.0, True); permute_157 = permute_158 = permute_159 = None + getitem_126 = _scaled_dot_product_cudnn_attention_14[0] + getitem_127 = _scaled_dot_product_cudnn_attention_14[1] + getitem_132 = _scaled_dot_product_cudnn_attention_14[6] + getitem_133 = _scaled_dot_product_cudnn_attention_14[7]; _scaled_dot_product_cudnn_attention_14 = None + permute_160 = torch.ops.aten.permute.default(getitem_126, [0, 2, 1, 3]) + view_497 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 64, '0'); convert_element_type_479 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_131, [1, 0]); wait_tensor_131 = None + view_499 = torch.ops.aten.view.default(view_497, [16384, 4096]); view_497 = None + mm_101 = torch.ops.aten.mm.default(view_499, permute_161); view_499 = permute_161 = None + view_500 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + add_57 = torch.ops.aten.add.Tensor(add_55, view_500); view_500 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 64, '0'); convert_element_type_482 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = rsqrt_29 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_132); mul_116 = wait_tensor_132 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 64, '0'); convert_element_type_485 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_133, [1, 0]); wait_tensor_133 = None + view_503 = torch.ops.aten.view.default(convert_element_type_484, [16384, 4096]); convert_element_type_484 = None + mm_102 = torch.ops.aten.mm.default(view_503, permute_162); permute_162 = None + view_504 = torch.ops.aten.view.default(mm_102, [2, 8192, 14336]) + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_504, torch.float32); view_504 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); convert_element_type_488 = sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16) + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 64, '0'); convert_element_type_490 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + mm_103 = torch.ops.aten.mm.default(view_503, permute_163); view_503 = permute_163 = None + view_507 = torch.ops.aten.view.default(mm_103, [2, 8192, 14336]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_507); convert_element_type_489 = view_507 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 64, '0'); convert_element_type_493 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + view_509 = torch.ops.aten.view.default(mul_119, [16384, 14336]); mul_119 = None + mm_104 = torch.ops.aten.mm.default(view_509, permute_164); view_509 = permute_164 = None + view_510 = torch.ops.aten.view.default(mm_104, [2, 8192, 4096]); mm_104 = None + add_59 = torch.ops.aten.add.Tensor(add_57, view_510); add_57 = view_510 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 64, '0'); convert_element_type_496 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32) + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = rsqrt_30 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_136); mul_120 = wait_tensor_136 = None + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 64, '0'); convert_element_type_499 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + view_513 = torch.ops.aten.view.default(convert_element_type_498, [16384, 4096]); convert_element_type_498 = None + mm_105 = torch.ops.aten.mm.default(view_513, permute_165); permute_165 = None + view_514 = torch.ops.aten.view.default(mm_105, [2, 8192, 4096]) + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 64, '0'); convert_element_type_502 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_138, [1, 0]); wait_tensor_138 = None + mm_106 = torch.ops.aten.mm.default(view_513, permute_166); permute_166 = None + view_517 = torch.ops.aten.view.default(mm_106, [2, 8192, 1024]); mm_106 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16) + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 64, '0'); convert_element_type_505 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_139, [1, 0]); wait_tensor_139 = None + mm_107 = torch.ops.aten.mm.default(view_513, permute_167); view_513 = permute_167 = None + view_520 = torch.ops.aten.view.default(mm_107, [2, 8192, 1024]) + view_521 = torch.ops.aten.view.default(view_514, [2, 8192, -1, 128]); view_514 = None + view_522 = torch.ops.aten.view.default(view_517, [2, 8192, -1, 128]); view_517 = None + view_523 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_521, torch.float32); view_521 = None + view_524 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 32, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_524); view_524 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_522, torch.float32); view_522 = None + view_525 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 8, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_525); view_525 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_16); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_527 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 32, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_16); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_528 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 8, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_527, torch.bfloat16); view_527 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_528, torch.bfloat16); view_528 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 8, 4, 128]); unsqueeze_30 = None + clone_30 = torch.ops.aten.clone.default(expand_30, memory_format = torch.contiguous_format); expand_30 = None + view_529 = torch.ops.aten.view.default(clone_30, [2, 8192, 32, 128]); clone_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_523, 3); view_523 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 8, 4, 128]); unsqueeze_31 = None + clone_31 = torch.ops.aten.clone.default(expand_31, memory_format = torch.contiguous_format); expand_31 = None + view_530 = torch.ops.aten.view.default(clone_31, [2, 8192, 32, 128]); clone_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_529, [0, 2, 1, 3]); view_529 = None + permute_170 = torch.ops.aten.permute.default(view_530, [0, 2, 1, 3]); view_530 = None + _scaled_dot_product_cudnn_attention_15 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_168, permute_169, permute_170, None, True, 0.0, True); permute_168 = permute_169 = permute_170 = None + getitem_135 = _scaled_dot_product_cudnn_attention_15[0] + getitem_136 = _scaled_dot_product_cudnn_attention_15[1] + getitem_141 = _scaled_dot_product_cudnn_attention_15[6] + getitem_142 = _scaled_dot_product_cudnn_attention_15[7]; _scaled_dot_product_cudnn_attention_15 = None + permute_171 = torch.ops.aten.permute.default(getitem_135, [0, 2, 1, 3]) + view_531 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 64, '0'); convert_element_type_512 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_140, [1, 0]); wait_tensor_140 = None + view_533 = torch.ops.aten.view.default(view_531, [16384, 4096]); view_531 = None + mm_108 = torch.ops.aten.mm.default(view_533, permute_172); view_533 = permute_172 = None + view_534 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + add_61 = torch.ops.aten.add.Tensor(add_59, view_534); view_534 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 64, '0'); convert_element_type_515 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32) + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = rsqrt_31 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_141); mul_124 = wait_tensor_141 = None + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 64, '0'); convert_element_type_518 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + view_537 = torch.ops.aten.view.default(convert_element_type_517, [16384, 4096]); convert_element_type_517 = None + mm_109 = torch.ops.aten.mm.default(view_537, permute_173); permute_173 = None + view_538 = torch.ops.aten.view.default(mm_109, [2, 8192, 14336]) + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_538, torch.float32); view_538 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); convert_element_type_521 = sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 64, '0'); convert_element_type_523 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + mm_110 = torch.ops.aten.mm.default(view_537, permute_174); view_537 = permute_174 = None + view_541 = torch.ops.aten.view.default(mm_110, [2, 8192, 14336]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_541); convert_element_type_522 = view_541 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 64, '0'); convert_element_type_526 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_144, [1, 0]); wait_tensor_144 = None + view_543 = torch.ops.aten.view.default(mul_127, [16384, 14336]); mul_127 = None + mm_111 = torch.ops.aten.mm.default(view_543, permute_175); view_543 = permute_175 = None + view_544 = torch.ops.aten.view.default(mm_111, [2, 8192, 4096]); mm_111 = None + add_63 = torch.ops.aten.add.Tensor(add_61, view_544); add_61 = view_544 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 64, '0'); convert_element_type_529 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = rsqrt_32 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_145); mul_128 = wait_tensor_145 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 64, '0'); convert_element_type_532 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_146, [1, 0]); wait_tensor_146 = None + view_547 = torch.ops.aten.view.default(convert_element_type_531, [16384, 4096]); convert_element_type_531 = None + mm_112 = torch.ops.aten.mm.default(view_547, permute_176); permute_176 = None + view_548 = torch.ops.aten.view.default(mm_112, [2, 8192, 4096]) + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 64, '0'); convert_element_type_535 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + mm_113 = torch.ops.aten.mm.default(view_547, permute_177); permute_177 = None + view_551 = torch.ops.aten.view.default(mm_113, [2, 8192, 1024]); mm_113 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 64, '0'); convert_element_type_538 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_114 = torch.ops.aten.mm.default(view_547, permute_178); view_547 = permute_178 = None + view_554 = torch.ops.aten.view.default(mm_114, [2, 8192, 1024]) + view_555 = torch.ops.aten.view.default(view_548, [2, 8192, -1, 128]); view_548 = None + view_556 = torch.ops.aten.view.default(view_551, [2, 8192, -1, 128]); view_551 = None + view_557 = torch.ops.aten.view.default(view_554, [2, 8192, -1, 128]); view_554 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_555, torch.float32); view_555 = None + view_558 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 32, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_558); view_558 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_556, torch.float32); view_556 = None + view_559 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 8, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_559); view_559 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_16); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_561 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 32, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_16); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_562 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 8, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_561, torch.bfloat16); view_561 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_562, torch.bfloat16); view_562 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 8, 4, 128]); unsqueeze_32 = None + clone_32 = torch.ops.aten.clone.default(expand_32, memory_format = torch.contiguous_format); expand_32 = None + view_563 = torch.ops.aten.view.default(clone_32, [2, 8192, 32, 128]); clone_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_557, 3); view_557 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 8, 4, 128]); unsqueeze_33 = None + clone_33 = torch.ops.aten.clone.default(expand_33, memory_format = torch.contiguous_format); expand_33 = None + view_564 = torch.ops.aten.view.default(clone_33, [2, 8192, 32, 128]); clone_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_563, [0, 2, 1, 3]); view_563 = None + permute_181 = torch.ops.aten.permute.default(view_564, [0, 2, 1, 3]); view_564 = None + _scaled_dot_product_cudnn_attention_16 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_179, permute_180, permute_181, None, True, 0.0, True); permute_179 = permute_180 = permute_181 = None + getitem_144 = _scaled_dot_product_cudnn_attention_16[0] + getitem_145 = _scaled_dot_product_cudnn_attention_16[1] + getitem_150 = _scaled_dot_product_cudnn_attention_16[6] + getitem_151 = _scaled_dot_product_cudnn_attention_16[7]; _scaled_dot_product_cudnn_attention_16 = None + permute_182 = torch.ops.aten.permute.default(getitem_144, [0, 2, 1, 3]) + view_565 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 64, '0'); convert_element_type_545 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + view_567 = torch.ops.aten.view.default(view_565, [16384, 4096]); view_565 = None + mm_115 = torch.ops.aten.mm.default(view_567, permute_183); view_567 = permute_183 = None + view_568 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + add_65 = torch.ops.aten.add.Tensor(add_63, view_568); view_568 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 64, '0'); convert_element_type_548 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = rsqrt_33 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_150); mul_132 = wait_tensor_150 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16) + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 64, '0'); convert_element_type_551 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_151, [1, 0]); wait_tensor_151 = None + view_571 = torch.ops.aten.view.default(convert_element_type_550, [16384, 4096]); convert_element_type_550 = None + mm_116 = torch.ops.aten.mm.default(view_571, permute_184); permute_184 = None + view_572 = torch.ops.aten.view.default(mm_116, [2, 8192, 14336]) + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_572, torch.float32); view_572 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); convert_element_type_554 = sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 64, '0'); convert_element_type_556 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_152, [1, 0]); wait_tensor_152 = None + mm_117 = torch.ops.aten.mm.default(view_571, permute_185); view_571 = permute_185 = None + view_575 = torch.ops.aten.view.default(mm_117, [2, 8192, 14336]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_575); convert_element_type_555 = view_575 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 64, '0'); convert_element_type_559 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_153, [1, 0]); wait_tensor_153 = None + view_577 = torch.ops.aten.view.default(mul_135, [16384, 14336]); mul_135 = None + mm_118 = torch.ops.aten.mm.default(view_577, permute_186); view_577 = permute_186 = None + view_578 = torch.ops.aten.view.default(mm_118, [2, 8192, 4096]); mm_118 = None + add_67 = torch.ops.aten.add.Tensor(add_65, view_578); add_65 = view_578 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 64, '0'); convert_element_type_562 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = rsqrt_34 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_154); mul_136 = wait_tensor_154 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16) + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 64, '0'); convert_element_type_565 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + view_581 = torch.ops.aten.view.default(convert_element_type_564, [16384, 4096]); convert_element_type_564 = None + mm_119 = torch.ops.aten.mm.default(view_581, permute_187); permute_187 = None + view_582 = torch.ops.aten.view.default(mm_119, [2, 8192, 4096]) + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 64, '0'); convert_element_type_568 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + mm_120 = torch.ops.aten.mm.default(view_581, permute_188); permute_188 = None + view_585 = torch.ops.aten.view.default(mm_120, [2, 8192, 1024]); mm_120 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 64, '0'); convert_element_type_571 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_157, [1, 0]); wait_tensor_157 = None + mm_121 = torch.ops.aten.mm.default(view_581, permute_189); view_581 = permute_189 = None + view_588 = torch.ops.aten.view.default(mm_121, [2, 8192, 1024]) + view_589 = torch.ops.aten.view.default(view_582, [2, 8192, -1, 128]); view_582 = None + view_590 = torch.ops.aten.view.default(view_585, [2, 8192, -1, 128]); view_585 = None + view_591 = torch.ops.aten.view.default(view_588, [2, 8192, -1, 128]); view_588 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_589, torch.float32); view_589 = None + view_592 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 32, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_592); view_592 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_590, torch.float32); view_590 = None + view_593 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 8, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_593); view_593 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_16); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_595 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 32, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_16); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_596 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 8, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_595, torch.bfloat16); view_595 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_596, torch.bfloat16); view_596 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 8, 4, 128]); unsqueeze_34 = None + clone_34 = torch.ops.aten.clone.default(expand_34, memory_format = torch.contiguous_format); expand_34 = None + view_597 = torch.ops.aten.view.default(clone_34, [2, 8192, 32, 128]); clone_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_591, 3); view_591 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 8, 4, 128]); unsqueeze_35 = None + clone_35 = torch.ops.aten.clone.default(expand_35, memory_format = torch.contiguous_format); expand_35 = None + view_598 = torch.ops.aten.view.default(clone_35, [2, 8192, 32, 128]); clone_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_597, [0, 2, 1, 3]); view_597 = None + permute_192 = torch.ops.aten.permute.default(view_598, [0, 2, 1, 3]); view_598 = None + _scaled_dot_product_cudnn_attention_17 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_190, permute_191, permute_192, None, True, 0.0, True); permute_190 = permute_191 = permute_192 = None + getitem_153 = _scaled_dot_product_cudnn_attention_17[0] + getitem_154 = _scaled_dot_product_cudnn_attention_17[1] + getitem_159 = _scaled_dot_product_cudnn_attention_17[6] + getitem_160 = _scaled_dot_product_cudnn_attention_17[7]; _scaled_dot_product_cudnn_attention_17 = None + permute_193 = torch.ops.aten.permute.default(getitem_153, [0, 2, 1, 3]) + view_599 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 64, '0'); convert_element_type_578 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_158, [1, 0]); wait_tensor_158 = None + view_601 = torch.ops.aten.view.default(view_599, [16384, 4096]); view_599 = None + mm_122 = torch.ops.aten.mm.default(view_601, permute_194); view_601 = permute_194 = None + view_602 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + add_69 = torch.ops.aten.add.Tensor(add_67, view_602); view_602 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 64, '0'); convert_element_type_581 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = rsqrt_35 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_159); mul_140 = wait_tensor_159 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 64, '0'); convert_element_type_584 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + view_605 = torch.ops.aten.view.default(convert_element_type_583, [16384, 4096]); convert_element_type_583 = None + mm_123 = torch.ops.aten.mm.default(view_605, permute_195); permute_195 = None + view_606 = torch.ops.aten.view.default(mm_123, [2, 8192, 14336]) + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_606, torch.float32); view_606 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); convert_element_type_587 = sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 64, '0'); convert_element_type_589 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_124 = torch.ops.aten.mm.default(view_605, permute_196); view_605 = permute_196 = None + view_609 = torch.ops.aten.view.default(mm_124, [2, 8192, 14336]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_609); convert_element_type_588 = view_609 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 64, '0'); convert_element_type_592 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + view_611 = torch.ops.aten.view.default(mul_143, [16384, 14336]); mul_143 = None + mm_125 = torch.ops.aten.mm.default(view_611, permute_197); view_611 = permute_197 = None + view_612 = torch.ops.aten.view.default(mm_125, [2, 8192, 4096]); mm_125 = None + add_71 = torch.ops.aten.add.Tensor(add_69, view_612); add_69 = view_612 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 64, '0'); convert_element_type_595 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32) + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = rsqrt_36 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_163); mul_144 = wait_tensor_163 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 64, '0'); convert_element_type_598 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_164, [1, 0]); wait_tensor_164 = None + view_615 = torch.ops.aten.view.default(convert_element_type_597, [16384, 4096]); convert_element_type_597 = None + mm_126 = torch.ops.aten.mm.default(view_615, permute_198); permute_198 = None + view_616 = torch.ops.aten.view.default(mm_126, [2, 8192, 4096]) + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 64, '0'); convert_element_type_601 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_165, [1, 0]); wait_tensor_165 = None + mm_127 = torch.ops.aten.mm.default(view_615, permute_199); permute_199 = None + view_619 = torch.ops.aten.view.default(mm_127, [2, 8192, 1024]); mm_127 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 64, '0'); convert_element_type_604 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_166, [1, 0]); wait_tensor_166 = None + mm_128 = torch.ops.aten.mm.default(view_615, permute_200); view_615 = permute_200 = None + view_622 = torch.ops.aten.view.default(mm_128, [2, 8192, 1024]) + view_623 = torch.ops.aten.view.default(view_616, [2, 8192, -1, 128]); view_616 = None + view_624 = torch.ops.aten.view.default(view_619, [2, 8192, -1, 128]); view_619 = None + view_625 = torch.ops.aten.view.default(view_622, [2, 8192, -1, 128]); view_622 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_623, torch.float32); view_623 = None + view_626 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 32, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_626); view_626 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_624, torch.float32); view_624 = None + view_627 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 8, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_627); view_627 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_16); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_629 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 32, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_16); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_630 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 8, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_629, torch.bfloat16); view_629 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_630, torch.bfloat16); view_630 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 8, 4, 128]); unsqueeze_36 = None + clone_36 = torch.ops.aten.clone.default(expand_36, memory_format = torch.contiguous_format); expand_36 = None + view_631 = torch.ops.aten.view.default(clone_36, [2, 8192, 32, 128]); clone_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_625, 3); view_625 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 8, 4, 128]); unsqueeze_37 = None + clone_37 = torch.ops.aten.clone.default(expand_37, memory_format = torch.contiguous_format); expand_37 = None + view_632 = torch.ops.aten.view.default(clone_37, [2, 8192, 32, 128]); clone_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_631, [0, 2, 1, 3]); view_631 = None + permute_203 = torch.ops.aten.permute.default(view_632, [0, 2, 1, 3]); view_632 = None + _scaled_dot_product_cudnn_attention_18 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_201, permute_202, permute_203, None, True, 0.0, True); permute_201 = permute_202 = permute_203 = None + getitem_162 = _scaled_dot_product_cudnn_attention_18[0] + getitem_163 = _scaled_dot_product_cudnn_attention_18[1] + getitem_168 = _scaled_dot_product_cudnn_attention_18[6] + getitem_169 = _scaled_dot_product_cudnn_attention_18[7]; _scaled_dot_product_cudnn_attention_18 = None + permute_204 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_633 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 64, '0'); convert_element_type_611 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_635 = torch.ops.aten.view.default(view_633, [16384, 4096]); view_633 = None + mm_129 = torch.ops.aten.mm.default(view_635, permute_205); view_635 = permute_205 = None + view_636 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + add_73 = torch.ops.aten.add.Tensor(add_71, view_636); view_636 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 64, '0'); convert_element_type_614 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = rsqrt_37 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_168); mul_148 = wait_tensor_168 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 64, '0'); convert_element_type_617 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + view_639 = torch.ops.aten.view.default(convert_element_type_616, [16384, 4096]); convert_element_type_616 = None + mm_130 = torch.ops.aten.mm.default(view_639, permute_206); permute_206 = None + view_640 = torch.ops.aten.view.default(mm_130, [2, 8192, 14336]) + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_640, torch.float32); view_640 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); convert_element_type_620 = sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 64, '0'); convert_element_type_622 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_170, [1, 0]); wait_tensor_170 = None + mm_131 = torch.ops.aten.mm.default(view_639, permute_207); view_639 = permute_207 = None + view_643 = torch.ops.aten.view.default(mm_131, [2, 8192, 14336]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_643); convert_element_type_621 = view_643 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 64, '0'); convert_element_type_625 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_171, [1, 0]); wait_tensor_171 = None + view_645 = torch.ops.aten.view.default(mul_151, [16384, 14336]); mul_151 = None + mm_132 = torch.ops.aten.mm.default(view_645, permute_208); view_645 = permute_208 = None + view_646 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + add_75 = torch.ops.aten.add.Tensor(add_73, view_646); add_73 = view_646 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 64, '0'); convert_element_type_628 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32) + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = rsqrt_38 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_172); mul_152 = wait_tensor_172 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16) + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 64, '0'); convert_element_type_631 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + view_649 = torch.ops.aten.view.default(convert_element_type_630, [16384, 4096]); convert_element_type_630 = None + mm_133 = torch.ops.aten.mm.default(view_649, permute_209); permute_209 = None + view_650 = torch.ops.aten.view.default(mm_133, [2, 8192, 4096]) + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 64, '0'); convert_element_type_634 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_134 = torch.ops.aten.mm.default(view_649, permute_210); permute_210 = None + view_653 = torch.ops.aten.view.default(mm_134, [2, 8192, 1024]); mm_134 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 64, '0'); convert_element_type_637 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + mm_135 = torch.ops.aten.mm.default(view_649, permute_211); view_649 = permute_211 = None + view_656 = torch.ops.aten.view.default(mm_135, [2, 8192, 1024]) + view_657 = torch.ops.aten.view.default(view_650, [2, 8192, -1, 128]); view_650 = None + view_658 = torch.ops.aten.view.default(view_653, [2, 8192, -1, 128]); view_653 = None + view_659 = torch.ops.aten.view.default(view_656, [2, 8192, -1, 128]); view_656 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_657, torch.float32); view_657 = None + view_660 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 32, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_660); view_660 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_658, torch.float32); view_658 = None + view_661 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 8, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_661); view_661 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_16); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_663 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 32, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_16); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_664 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 8, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_663, torch.bfloat16); view_663 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_664, torch.bfloat16); view_664 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 8, 4, 128]); unsqueeze_38 = None + clone_38 = torch.ops.aten.clone.default(expand_38, memory_format = torch.contiguous_format); expand_38 = None + view_665 = torch.ops.aten.view.default(clone_38, [2, 8192, 32, 128]); clone_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_659, 3); view_659 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 8, 4, 128]); unsqueeze_39 = None + clone_39 = torch.ops.aten.clone.default(expand_39, memory_format = torch.contiguous_format); expand_39 = None + view_666 = torch.ops.aten.view.default(clone_39, [2, 8192, 32, 128]); clone_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_665, [0, 2, 1, 3]); view_665 = None + permute_214 = torch.ops.aten.permute.default(view_666, [0, 2, 1, 3]); view_666 = None + _scaled_dot_product_cudnn_attention_19 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_212, permute_213, permute_214, None, True, 0.0, True); permute_212 = permute_213 = permute_214 = None + getitem_171 = _scaled_dot_product_cudnn_attention_19[0] + getitem_172 = _scaled_dot_product_cudnn_attention_19[1] + getitem_177 = _scaled_dot_product_cudnn_attention_19[6] + getitem_178 = _scaled_dot_product_cudnn_attention_19[7]; _scaled_dot_product_cudnn_attention_19 = None + permute_215 = torch.ops.aten.permute.default(getitem_171, [0, 2, 1, 3]) + view_667 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 64, '0'); convert_element_type_644 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_669 = torch.ops.aten.view.default(view_667, [16384, 4096]); view_667 = None + mm_136 = torch.ops.aten.mm.default(view_669, permute_216); view_669 = permute_216 = None + view_670 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + add_77 = torch.ops.aten.add.Tensor(add_75, view_670); view_670 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 64, '0'); convert_element_type_647 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = rsqrt_39 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_177); mul_156 = wait_tensor_177 = None + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 64, '0'); convert_element_type_650 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_178, [1, 0]); wait_tensor_178 = None + view_673 = torch.ops.aten.view.default(convert_element_type_649, [16384, 4096]); convert_element_type_649 = None + mm_137 = torch.ops.aten.mm.default(view_673, permute_217); permute_217 = None + view_674 = torch.ops.aten.view.default(mm_137, [2, 8192, 14336]) + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_674, torch.float32); view_674 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); convert_element_type_653 = sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 64, '0'); convert_element_type_655 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_179, [1, 0]); wait_tensor_179 = None + mm_138 = torch.ops.aten.mm.default(view_673, permute_218); view_673 = permute_218 = None + view_677 = torch.ops.aten.view.default(mm_138, [2, 8192, 14336]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_677); convert_element_type_654 = view_677 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 64, '0'); convert_element_type_658 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_679 = torch.ops.aten.view.default(mul_159, [16384, 14336]); mul_159 = None + mm_139 = torch.ops.aten.mm.default(view_679, permute_219); view_679 = permute_219 = None + view_680 = torch.ops.aten.view.default(mm_139, [2, 8192, 4096]); mm_139 = None + add_79 = torch.ops.aten.add.Tensor(add_77, view_680); add_77 = view_680 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 64, '0'); convert_element_type_661 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = rsqrt_40 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_181); mul_160 = wait_tensor_181 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 64, '0'); convert_element_type_664 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + view_683 = torch.ops.aten.view.default(convert_element_type_663, [16384, 4096]); convert_element_type_663 = None + mm_140 = torch.ops.aten.mm.default(view_683, permute_220); permute_220 = None + view_684 = torch.ops.aten.view.default(mm_140, [2, 8192, 4096]) + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 64, '0'); convert_element_type_667 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_183, [1, 0]); wait_tensor_183 = None + mm_141 = torch.ops.aten.mm.default(view_683, permute_221); permute_221 = None + view_687 = torch.ops.aten.view.default(mm_141, [2, 8192, 1024]); mm_141 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 64, '0'); convert_element_type_670 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_184, [1, 0]); wait_tensor_184 = None + mm_142 = torch.ops.aten.mm.default(view_683, permute_222); view_683 = permute_222 = None + view_690 = torch.ops.aten.view.default(mm_142, [2, 8192, 1024]) + view_691 = torch.ops.aten.view.default(view_684, [2, 8192, -1, 128]); view_684 = None + view_692 = torch.ops.aten.view.default(view_687, [2, 8192, -1, 128]); view_687 = None + view_693 = torch.ops.aten.view.default(view_690, [2, 8192, -1, 128]); view_690 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_691, torch.float32); view_691 = None + view_694 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 32, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_694); view_694 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_692, torch.float32); view_692 = None + view_695 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 8, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_695); view_695 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_16); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_697 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 32, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_16); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_698 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 8, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_697, torch.bfloat16); view_697 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_698, torch.bfloat16); view_698 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 8, 4, 128]); unsqueeze_40 = None + clone_40 = torch.ops.aten.clone.default(expand_40, memory_format = torch.contiguous_format); expand_40 = None + view_699 = torch.ops.aten.view.default(clone_40, [2, 8192, 32, 128]); clone_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_693, 3); view_693 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 8, 4, 128]); unsqueeze_41 = None + clone_41 = torch.ops.aten.clone.default(expand_41, memory_format = torch.contiguous_format); expand_41 = None + view_700 = torch.ops.aten.view.default(clone_41, [2, 8192, 32, 128]); clone_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_699, [0, 2, 1, 3]); view_699 = None + permute_225 = torch.ops.aten.permute.default(view_700, [0, 2, 1, 3]); view_700 = None + _scaled_dot_product_cudnn_attention_20 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_223, permute_224, permute_225, None, True, 0.0, True); permute_223 = permute_224 = permute_225 = None + getitem_180 = _scaled_dot_product_cudnn_attention_20[0] + getitem_181 = _scaled_dot_product_cudnn_attention_20[1] + getitem_186 = _scaled_dot_product_cudnn_attention_20[6] + getitem_187 = _scaled_dot_product_cudnn_attention_20[7]; _scaled_dot_product_cudnn_attention_20 = None + permute_226 = torch.ops.aten.permute.default(getitem_180, [0, 2, 1, 3]) + view_701 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 64, '0'); convert_element_type_677 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_185, [1, 0]); wait_tensor_185 = None + view_703 = torch.ops.aten.view.default(view_701, [16384, 4096]); view_701 = None + mm_143 = torch.ops.aten.mm.default(view_703, permute_227); view_703 = permute_227 = None + view_704 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + add_81 = torch.ops.aten.add.Tensor(add_79, view_704); view_704 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 64, '0'); convert_element_type_680 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = rsqrt_41 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_186); mul_164 = wait_tensor_186 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 64, '0'); convert_element_type_683 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + view_707 = torch.ops.aten.view.default(convert_element_type_682, [16384, 4096]); convert_element_type_682 = None + mm_144 = torch.ops.aten.mm.default(view_707, permute_228); permute_228 = None + view_708 = torch.ops.aten.view.default(mm_144, [2, 8192, 14336]) + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_708, torch.float32); view_708 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); convert_element_type_686 = sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 64, '0'); convert_element_type_688 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_145 = torch.ops.aten.mm.default(view_707, permute_229); view_707 = permute_229 = None + view_711 = torch.ops.aten.view.default(mm_145, [2, 8192, 14336]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_711); convert_element_type_687 = view_711 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16) + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 64, '0'); convert_element_type_691 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + view_713 = torch.ops.aten.view.default(mul_167, [16384, 14336]); mul_167 = None + mm_146 = torch.ops.aten.mm.default(view_713, permute_230); view_713 = permute_230 = None + view_714 = torch.ops.aten.view.default(mm_146, [2, 8192, 4096]); mm_146 = None + add_83 = torch.ops.aten.add.Tensor(add_81, view_714); add_81 = view_714 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 64, '0'); convert_element_type_694 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32) + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = rsqrt_42 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_190); mul_168 = wait_tensor_190 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 64, '0'); convert_element_type_697 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_191, [1, 0]); wait_tensor_191 = None + view_717 = torch.ops.aten.view.default(convert_element_type_696, [16384, 4096]); convert_element_type_696 = None + mm_147 = torch.ops.aten.mm.default(view_717, permute_231); permute_231 = None + view_718 = torch.ops.aten.view.default(mm_147, [2, 8192, 4096]) + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 64, '0'); convert_element_type_700 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_192, [1, 0]); wait_tensor_192 = None + mm_148 = torch.ops.aten.mm.default(view_717, permute_232); permute_232 = None + view_721 = torch.ops.aten.view.default(mm_148, [2, 8192, 1024]); mm_148 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 64, '0'); convert_element_type_703 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + mm_149 = torch.ops.aten.mm.default(view_717, permute_233); view_717 = permute_233 = None + view_724 = torch.ops.aten.view.default(mm_149, [2, 8192, 1024]) + view_725 = torch.ops.aten.view.default(view_718, [2, 8192, -1, 128]); view_718 = None + view_726 = torch.ops.aten.view.default(view_721, [2, 8192, -1, 128]); view_721 = None + view_727 = torch.ops.aten.view.default(view_724, [2, 8192, -1, 128]); view_724 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_725, torch.float32); view_725 = None + view_728 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 32, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_728); view_728 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_726, torch.float32); view_726 = None + view_729 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 8, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_729); view_729 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_16); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_731 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 32, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_16); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_732 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 8, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_731, torch.bfloat16); view_731 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_732, torch.bfloat16); view_732 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 8, 4, 128]); unsqueeze_42 = None + clone_42 = torch.ops.aten.clone.default(expand_42, memory_format = torch.contiguous_format); expand_42 = None + view_733 = torch.ops.aten.view.default(clone_42, [2, 8192, 32, 128]); clone_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_727, 3); view_727 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 8, 4, 128]); unsqueeze_43 = None + clone_43 = torch.ops.aten.clone.default(expand_43, memory_format = torch.contiguous_format); expand_43 = None + view_734 = torch.ops.aten.view.default(clone_43, [2, 8192, 32, 128]); clone_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_733, [0, 2, 1, 3]); view_733 = None + permute_236 = torch.ops.aten.permute.default(view_734, [0, 2, 1, 3]); view_734 = None + _scaled_dot_product_cudnn_attention_21 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_234, permute_235, permute_236, None, True, 0.0, True); permute_234 = permute_235 = permute_236 = None + getitem_189 = _scaled_dot_product_cudnn_attention_21[0] + getitem_190 = _scaled_dot_product_cudnn_attention_21[1] + getitem_195 = _scaled_dot_product_cudnn_attention_21[6] + getitem_196 = _scaled_dot_product_cudnn_attention_21[7]; _scaled_dot_product_cudnn_attention_21 = None + permute_237 = torch.ops.aten.permute.default(getitem_189, [0, 2, 1, 3]) + view_735 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 64, '0'); convert_element_type_710 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + view_737 = torch.ops.aten.view.default(view_735, [16384, 4096]); view_735 = None + mm_150 = torch.ops.aten.mm.default(view_737, permute_238); view_737 = permute_238 = None + view_738 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + add_85 = torch.ops.aten.add.Tensor(add_83, view_738); view_738 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 64, '0'); convert_element_type_713 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = rsqrt_43 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_195); mul_172 = wait_tensor_195 = None + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 64, '0'); convert_element_type_716 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_196, [1, 0]); wait_tensor_196 = None + view_741 = torch.ops.aten.view.default(convert_element_type_715, [16384, 4096]); convert_element_type_715 = None + mm_151 = torch.ops.aten.mm.default(view_741, permute_239); permute_239 = None + view_742 = torch.ops.aten.view.default(mm_151, [2, 8192, 14336]) + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_742, torch.float32); view_742 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); convert_element_type_719 = sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 64, '0'); convert_element_type_721 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_197, [1, 0]); wait_tensor_197 = None + mm_152 = torch.ops.aten.mm.default(view_741, permute_240); view_741 = permute_240 = None + view_745 = torch.ops.aten.view.default(mm_152, [2, 8192, 14336]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_745); convert_element_type_720 = view_745 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 64, '0'); convert_element_type_724 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_198, [1, 0]); wait_tensor_198 = None + view_747 = torch.ops.aten.view.default(mul_175, [16384, 14336]); mul_175 = None + mm_153 = torch.ops.aten.mm.default(view_747, permute_241); view_747 = permute_241 = None + view_748 = torch.ops.aten.view.default(mm_153, [2, 8192, 4096]); mm_153 = None + add_87 = torch.ops.aten.add.Tensor(add_85, view_748); add_85 = view_748 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 64, '0'); convert_element_type_727 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32) + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = rsqrt_44 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_199); mul_176 = wait_tensor_199 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 64, '0'); convert_element_type_730 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + view_751 = torch.ops.aten.view.default(convert_element_type_729, [16384, 4096]); convert_element_type_729 = None + mm_154 = torch.ops.aten.mm.default(view_751, permute_242); permute_242 = None + view_752 = torch.ops.aten.view.default(mm_154, [2, 8192, 4096]) + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 64, '0'); convert_element_type_733 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_155 = torch.ops.aten.mm.default(view_751, permute_243); permute_243 = None + view_755 = torch.ops.aten.view.default(mm_155, [2, 8192, 1024]); mm_155 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 64, '0'); convert_element_type_736 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + mm_156 = torch.ops.aten.mm.default(view_751, permute_244); view_751 = permute_244 = None + view_758 = torch.ops.aten.view.default(mm_156, [2, 8192, 1024]) + view_759 = torch.ops.aten.view.default(view_752, [2, 8192, -1, 128]); view_752 = None + view_760 = torch.ops.aten.view.default(view_755, [2, 8192, -1, 128]); view_755 = None + view_761 = torch.ops.aten.view.default(view_758, [2, 8192, -1, 128]); view_758 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_759, torch.float32); view_759 = None + view_762 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 32, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_762); view_762 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_760, torch.float32); view_760 = None + view_763 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 8, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_763); view_763 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_16); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_765 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 32, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_16); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_766 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 8, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_765, torch.bfloat16); view_765 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_766, torch.bfloat16); view_766 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 8, 4, 128]); unsqueeze_44 = None + clone_44 = torch.ops.aten.clone.default(expand_44, memory_format = torch.contiguous_format); expand_44 = None + view_767 = torch.ops.aten.view.default(clone_44, [2, 8192, 32, 128]); clone_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_761, 3); view_761 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 8, 4, 128]); unsqueeze_45 = None + clone_45 = torch.ops.aten.clone.default(expand_45, memory_format = torch.contiguous_format); expand_45 = None + view_768 = torch.ops.aten.view.default(clone_45, [2, 8192, 32, 128]); clone_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_767, [0, 2, 1, 3]); view_767 = None + permute_247 = torch.ops.aten.permute.default(view_768, [0, 2, 1, 3]); view_768 = None + _scaled_dot_product_cudnn_attention_22 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_245, permute_246, permute_247, None, True, 0.0, True); permute_245 = permute_246 = permute_247 = None + getitem_198 = _scaled_dot_product_cudnn_attention_22[0] + getitem_199 = _scaled_dot_product_cudnn_attention_22[1] + getitem_204 = _scaled_dot_product_cudnn_attention_22[6] + getitem_205 = _scaled_dot_product_cudnn_attention_22[7]; _scaled_dot_product_cudnn_attention_22 = None + permute_248 = torch.ops.aten.permute.default(getitem_198, [0, 2, 1, 3]) + view_769 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 64, '0'); convert_element_type_743 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_203, [1, 0]); wait_tensor_203 = None + view_771 = torch.ops.aten.view.default(view_769, [16384, 4096]); view_769 = None + mm_157 = torch.ops.aten.mm.default(view_771, permute_249); view_771 = permute_249 = None + view_772 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + add_89 = torch.ops.aten.add.Tensor(add_87, view_772); view_772 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 64, '0'); convert_element_type_746 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = rsqrt_45 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_204); mul_180 = wait_tensor_204 = None + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 64, '0'); convert_element_type_749 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_205, [1, 0]); wait_tensor_205 = None + view_775 = torch.ops.aten.view.default(convert_element_type_748, [16384, 4096]); convert_element_type_748 = None + mm_158 = torch.ops.aten.mm.default(view_775, permute_250); permute_250 = None + view_776 = torch.ops.aten.view.default(mm_158, [2, 8192, 14336]) + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_776, torch.float32); view_776 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); convert_element_type_752 = sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 64, '0'); convert_element_type_754 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + mm_159 = torch.ops.aten.mm.default(view_775, permute_251); view_775 = permute_251 = None + view_779 = torch.ops.aten.view.default(mm_159, [2, 8192, 14336]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_779); convert_element_type_753 = view_779 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 64, '0'); convert_element_type_757 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + view_781 = torch.ops.aten.view.default(mul_183, [16384, 14336]); mul_183 = None + mm_160 = torch.ops.aten.mm.default(view_781, permute_252); view_781 = permute_252 = None + view_782 = torch.ops.aten.view.default(mm_160, [2, 8192, 4096]); mm_160 = None + add_91 = torch.ops.aten.add.Tensor(add_89, view_782); add_89 = view_782 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 64, '0'); convert_element_type_760 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = rsqrt_46 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_208); mul_184 = wait_tensor_208 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 64, '0'); convert_element_type_763 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_209, [1, 0]); wait_tensor_209 = None + view_785 = torch.ops.aten.view.default(convert_element_type_762, [16384, 4096]); convert_element_type_762 = None + mm_161 = torch.ops.aten.mm.default(view_785, permute_253); permute_253 = None + view_786 = torch.ops.aten.view.default(mm_161, [2, 8192, 4096]) + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 64, '0'); convert_element_type_766 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_210, [1, 0]); wait_tensor_210 = None + mm_162 = torch.ops.aten.mm.default(view_785, permute_254); permute_254 = None + view_789 = torch.ops.aten.view.default(mm_162, [2, 8192, 1024]); mm_162 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 64, '0'); convert_element_type_769 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_211, [1, 0]); wait_tensor_211 = None + mm_163 = torch.ops.aten.mm.default(view_785, permute_255); view_785 = permute_255 = None + view_792 = torch.ops.aten.view.default(mm_163, [2, 8192, 1024]) + view_793 = torch.ops.aten.view.default(view_786, [2, 8192, -1, 128]); view_786 = None + view_794 = torch.ops.aten.view.default(view_789, [2, 8192, -1, 128]); view_789 = None + view_795 = torch.ops.aten.view.default(view_792, [2, 8192, -1, 128]); view_792 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_793, torch.float32); view_793 = None + view_796 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 32, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_796); view_796 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_794, torch.float32); view_794 = None + view_797 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 8, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_797); view_797 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_16); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_799 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 32, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_16); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_800 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 8, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_799, torch.bfloat16); view_799 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_800, torch.bfloat16); view_800 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 8, 4, 128]); unsqueeze_46 = None + clone_46 = torch.ops.aten.clone.default(expand_46, memory_format = torch.contiguous_format); expand_46 = None + view_801 = torch.ops.aten.view.default(clone_46, [2, 8192, 32, 128]); clone_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_795, 3); view_795 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 8, 4, 128]); unsqueeze_47 = None + clone_47 = torch.ops.aten.clone.default(expand_47, memory_format = torch.contiguous_format); expand_47 = None + view_802 = torch.ops.aten.view.default(clone_47, [2, 8192, 32, 128]); clone_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_801, [0, 2, 1, 3]); view_801 = None + permute_258 = torch.ops.aten.permute.default(view_802, [0, 2, 1, 3]); view_802 = None + _scaled_dot_product_cudnn_attention_23 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_256, permute_257, permute_258, None, True, 0.0, True); permute_256 = permute_257 = permute_258 = None + getitem_207 = _scaled_dot_product_cudnn_attention_23[0] + getitem_208 = _scaled_dot_product_cudnn_attention_23[1] + getitem_213 = _scaled_dot_product_cudnn_attention_23[6] + getitem_214 = _scaled_dot_product_cudnn_attention_23[7]; _scaled_dot_product_cudnn_attention_23 = None + permute_259 = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]) + view_803 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 64, '0'); convert_element_type_776 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_805 = torch.ops.aten.view.default(view_803, [16384, 4096]); view_803 = None + mm_164 = torch.ops.aten.mm.default(view_805, permute_260); view_805 = permute_260 = None + view_806 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + add_93 = torch.ops.aten.add.Tensor(add_91, view_806); view_806 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 64, '0'); convert_element_type_779 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = rsqrt_47 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_213); mul_188 = wait_tensor_213 = None + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 64, '0'); convert_element_type_782 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + view_809 = torch.ops.aten.view.default(convert_element_type_781, [16384, 4096]); convert_element_type_781 = None + mm_165 = torch.ops.aten.mm.default(view_809, permute_261); permute_261 = None + view_810 = torch.ops.aten.view.default(mm_165, [2, 8192, 14336]) + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_810, torch.float32); view_810 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); convert_element_type_785 = sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 64, '0'); convert_element_type_787 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + mm_166 = torch.ops.aten.mm.default(view_809, permute_262); view_809 = permute_262 = None + view_813 = torch.ops.aten.view.default(mm_166, [2, 8192, 14336]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_813); convert_element_type_786 = view_813 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 64, '0'); convert_element_type_790 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_216, [1, 0]); wait_tensor_216 = None + view_815 = torch.ops.aten.view.default(mul_191, [16384, 14336]); mul_191 = None + mm_167 = torch.ops.aten.mm.default(view_815, permute_263); view_815 = permute_263 = None + view_816 = torch.ops.aten.view.default(mm_167, [2, 8192, 4096]); mm_167 = None + add_95 = torch.ops.aten.add.Tensor(add_93, view_816); add_93 = view_816 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 64, '0'); convert_element_type_793 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32) + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = rsqrt_48 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_217); mul_192 = wait_tensor_217 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 64, '0'); convert_element_type_796 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_218, [1, 0]); wait_tensor_218 = None + view_819 = torch.ops.aten.view.default(convert_element_type_795, [16384, 4096]); convert_element_type_795 = None + mm_168 = torch.ops.aten.mm.default(view_819, permute_264); permute_264 = None + view_820 = torch.ops.aten.view.default(mm_168, [2, 8192, 4096]) + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 64, '0'); convert_element_type_799 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + mm_169 = torch.ops.aten.mm.default(view_819, permute_265); permute_265 = None + view_823 = torch.ops.aten.view.default(mm_169, [2, 8192, 1024]); mm_169 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 64, '0'); convert_element_type_802 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_170 = torch.ops.aten.mm.default(view_819, permute_266); view_819 = permute_266 = None + view_826 = torch.ops.aten.view.default(mm_170, [2, 8192, 1024]) + view_827 = torch.ops.aten.view.default(view_820, [2, 8192, -1, 128]); view_820 = None + view_828 = torch.ops.aten.view.default(view_823, [2, 8192, -1, 128]); view_823 = None + view_829 = torch.ops.aten.view.default(view_826, [2, 8192, -1, 128]); view_826 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_827, torch.float32); view_827 = None + view_830 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 32, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_830); view_830 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_828, torch.float32); view_828 = None + view_831 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 8, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_831); view_831 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_16); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_833 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 32, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_16); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_834 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 8, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_833, torch.bfloat16); view_833 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_834, torch.bfloat16); view_834 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 8, 4, 128]); unsqueeze_48 = None + clone_48 = torch.ops.aten.clone.default(expand_48, memory_format = torch.contiguous_format); expand_48 = None + view_835 = torch.ops.aten.view.default(clone_48, [2, 8192, 32, 128]); clone_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_829, 3); view_829 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 8, 4, 128]); unsqueeze_49 = None + clone_49 = torch.ops.aten.clone.default(expand_49, memory_format = torch.contiguous_format); expand_49 = None + view_836 = torch.ops.aten.view.default(clone_49, [2, 8192, 32, 128]); clone_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_835, [0, 2, 1, 3]); view_835 = None + permute_269 = torch.ops.aten.permute.default(view_836, [0, 2, 1, 3]); view_836 = None + _scaled_dot_product_cudnn_attention_24 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_267, permute_268, permute_269, None, True, 0.0, True); permute_267 = permute_268 = permute_269 = None + getitem_216 = _scaled_dot_product_cudnn_attention_24[0] + getitem_217 = _scaled_dot_product_cudnn_attention_24[1] + getitem_222 = _scaled_dot_product_cudnn_attention_24[6] + getitem_223 = _scaled_dot_product_cudnn_attention_24[7]; _scaled_dot_product_cudnn_attention_24 = None + permute_270 = torch.ops.aten.permute.default(getitem_216, [0, 2, 1, 3]) + view_837 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 64, '0'); convert_element_type_809 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_839 = torch.ops.aten.view.default(view_837, [16384, 4096]); view_837 = None + mm_171 = torch.ops.aten.mm.default(view_839, permute_271); view_839 = permute_271 = None + view_840 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + add_97 = torch.ops.aten.add.Tensor(add_95, view_840); view_840 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 64, '0'); convert_element_type_812 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = rsqrt_49 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_222); mul_196 = wait_tensor_222 = None + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 64, '0'); convert_element_type_815 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_223, [1, 0]); wait_tensor_223 = None + view_843 = torch.ops.aten.view.default(convert_element_type_814, [16384, 4096]); convert_element_type_814 = None + mm_172 = torch.ops.aten.mm.default(view_843, permute_272); permute_272 = None + view_844 = torch.ops.aten.view.default(mm_172, [2, 8192, 14336]) + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_844, torch.float32); view_844 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); convert_element_type_818 = sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 64, '0'); convert_element_type_820 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_224, [1, 0]); wait_tensor_224 = None + mm_173 = torch.ops.aten.mm.default(view_843, permute_273); view_843 = permute_273 = None + view_847 = torch.ops.aten.view.default(mm_173, [2, 8192, 14336]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_847); convert_element_type_819 = view_847 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 64, '0'); convert_element_type_823 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_849 = torch.ops.aten.view.default(mul_199, [16384, 14336]); mul_199 = None + mm_174 = torch.ops.aten.mm.default(view_849, permute_274); view_849 = permute_274 = None + view_850 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = None + add_99 = torch.ops.aten.add.Tensor(add_97, view_850); add_97 = view_850 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 64, '0'); convert_element_type_826 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32) + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = rsqrt_50 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_226); mul_200 = wait_tensor_226 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 64, '0'); convert_element_type_829 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + view_853 = torch.ops.aten.view.default(convert_element_type_828, [16384, 4096]); convert_element_type_828 = None + mm_175 = torch.ops.aten.mm.default(view_853, permute_275); permute_275 = None + view_854 = torch.ops.aten.view.default(mm_175, [2, 8192, 4096]) + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 64, '0'); convert_element_type_832 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + mm_176 = torch.ops.aten.mm.default(view_853, permute_276); permute_276 = None + view_857 = torch.ops.aten.view.default(mm_176, [2, 8192, 1024]); mm_176 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 64, '0'); convert_element_type_835 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_229, [1, 0]); wait_tensor_229 = None + mm_177 = torch.ops.aten.mm.default(view_853, permute_277); view_853 = permute_277 = None + view_860 = torch.ops.aten.view.default(mm_177, [2, 8192, 1024]) + view_861 = torch.ops.aten.view.default(view_854, [2, 8192, -1, 128]); view_854 = None + view_862 = torch.ops.aten.view.default(view_857, [2, 8192, -1, 128]); view_857 = None + view_863 = torch.ops.aten.view.default(view_860, [2, 8192, -1, 128]); view_860 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_861, torch.float32); view_861 = None + view_864 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 32, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_864); view_864 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_862, torch.float32); view_862 = None + view_865 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 8, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_865); view_865 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_16); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_867 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 32, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_16); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_868 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 8, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_867, torch.bfloat16); view_867 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_868, torch.bfloat16); view_868 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 8, 4, 128]); unsqueeze_50 = None + clone_50 = torch.ops.aten.clone.default(expand_50, memory_format = torch.contiguous_format); expand_50 = None + view_869 = torch.ops.aten.view.default(clone_50, [2, 8192, 32, 128]); clone_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_863, 3); view_863 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 8, 4, 128]); unsqueeze_51 = None + clone_51 = torch.ops.aten.clone.default(expand_51, memory_format = torch.contiguous_format); expand_51 = None + view_870 = torch.ops.aten.view.default(clone_51, [2, 8192, 32, 128]); clone_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_869, [0, 2, 1, 3]); view_869 = None + permute_280 = torch.ops.aten.permute.default(view_870, [0, 2, 1, 3]); view_870 = None + _scaled_dot_product_cudnn_attention_25 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_278, permute_279, permute_280, None, True, 0.0, True); permute_278 = permute_279 = permute_280 = None + getitem_225 = _scaled_dot_product_cudnn_attention_25[0] + getitem_226 = _scaled_dot_product_cudnn_attention_25[1] + getitem_231 = _scaled_dot_product_cudnn_attention_25[6] + getitem_232 = _scaled_dot_product_cudnn_attention_25[7]; _scaled_dot_product_cudnn_attention_25 = None + permute_281 = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]) + view_871 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 64, '0'); convert_element_type_842 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_230, [1, 0]); wait_tensor_230 = None + view_873 = torch.ops.aten.view.default(view_871, [16384, 4096]); view_871 = None + mm_178 = torch.ops.aten.mm.default(view_873, permute_282); view_873 = permute_282 = None + view_874 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + add_101 = torch.ops.aten.add.Tensor(add_99, view_874); view_874 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 64, '0'); convert_element_type_845 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = rsqrt_51 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_231); mul_204 = wait_tensor_231 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 64, '0'); convert_element_type_848 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + view_877 = torch.ops.aten.view.default(convert_element_type_847, [16384, 4096]); convert_element_type_847 = None + mm_179 = torch.ops.aten.mm.default(view_877, permute_283); permute_283 = None + view_878 = torch.ops.aten.view.default(mm_179, [2, 8192, 14336]) + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_878, torch.float32); view_878 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); convert_element_type_851 = sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 64, '0'); convert_element_type_853 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_180 = torch.ops.aten.mm.default(view_877, permute_284); view_877 = permute_284 = None + view_881 = torch.ops.aten.view.default(mm_180, [2, 8192, 14336]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_881); convert_element_type_852 = view_881 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 64, '0'); convert_element_type_856 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + view_883 = torch.ops.aten.view.default(mul_207, [16384, 14336]); mul_207 = None + mm_181 = torch.ops.aten.mm.default(view_883, permute_285); view_883 = permute_285 = None + view_884 = torch.ops.aten.view.default(mm_181, [2, 8192, 4096]); mm_181 = None + add_103 = torch.ops.aten.add.Tensor(add_101, view_884); add_101 = view_884 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 64, '0'); convert_element_type_859 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = rsqrt_52 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_235); mul_208 = wait_tensor_235 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 64, '0'); convert_element_type_862 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_236, [1, 0]); wait_tensor_236 = None + view_887 = torch.ops.aten.view.default(convert_element_type_861, [16384, 4096]); convert_element_type_861 = None + mm_182 = torch.ops.aten.mm.default(view_887, permute_286); permute_286 = None + view_888 = torch.ops.aten.view.default(mm_182, [2, 8192, 4096]) + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 64, '0'); convert_element_type_865 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_237, [1, 0]); wait_tensor_237 = None + mm_183 = torch.ops.aten.mm.default(view_887, permute_287); permute_287 = None + view_891 = torch.ops.aten.view.default(mm_183, [2, 8192, 1024]); mm_183 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 64, '0'); convert_element_type_868 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + mm_184 = torch.ops.aten.mm.default(view_887, permute_288); view_887 = permute_288 = None + view_894 = torch.ops.aten.view.default(mm_184, [2, 8192, 1024]) + view_895 = torch.ops.aten.view.default(view_888, [2, 8192, -1, 128]); view_888 = None + view_896 = torch.ops.aten.view.default(view_891, [2, 8192, -1, 128]); view_891 = None + view_897 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_895, torch.float32); view_895 = None + view_898 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 32, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_898); view_898 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 8, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_16); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_901 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 32, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_16); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_902 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 8, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_901, torch.bfloat16); view_901 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 8, 4, 128]); unsqueeze_52 = None + clone_52 = torch.ops.aten.clone.default(expand_52, memory_format = torch.contiguous_format); expand_52 = None + view_903 = torch.ops.aten.view.default(clone_52, [2, 8192, 32, 128]); clone_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_897, 3); view_897 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 8, 4, 128]); unsqueeze_53 = None + clone_53 = torch.ops.aten.clone.default(expand_53, memory_format = torch.contiguous_format); expand_53 = None + view_904 = torch.ops.aten.view.default(clone_53, [2, 8192, 32, 128]); clone_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_903, [0, 2, 1, 3]); view_903 = None + permute_291 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + _scaled_dot_product_cudnn_attention_26 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_289, permute_290, permute_291, None, True, 0.0, True); permute_289 = permute_290 = permute_291 = None + getitem_234 = _scaled_dot_product_cudnn_attention_26[0] + getitem_235 = _scaled_dot_product_cudnn_attention_26[1] + getitem_240 = _scaled_dot_product_cudnn_attention_26[6] + getitem_241 = _scaled_dot_product_cudnn_attention_26[7]; _scaled_dot_product_cudnn_attention_26 = None + permute_292 = torch.ops.aten.permute.default(getitem_234, [0, 2, 1, 3]) + view_905 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 64, '0'); convert_element_type_875 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + view_907 = torch.ops.aten.view.default(view_905, [16384, 4096]); view_905 = None + mm_185 = torch.ops.aten.mm.default(view_907, permute_293); view_907 = permute_293 = None + view_908 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + add_105 = torch.ops.aten.add.Tensor(add_103, view_908); view_908 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 64, '0'); convert_element_type_878 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = rsqrt_53 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_240); mul_212 = wait_tensor_240 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 64, '0'); convert_element_type_881 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + view_911 = torch.ops.aten.view.default(convert_element_type_880, [16384, 4096]); convert_element_type_880 = None + mm_186 = torch.ops.aten.mm.default(view_911, permute_294); permute_294 = None + view_912 = torch.ops.aten.view.default(mm_186, [2, 8192, 14336]) + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_912, torch.float32); view_912 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); convert_element_type_884 = sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 64, '0'); convert_element_type_886 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_242, [1, 0]); wait_tensor_242 = None + mm_187 = torch.ops.aten.mm.default(view_911, permute_295); view_911 = permute_295 = None + view_915 = torch.ops.aten.view.default(mm_187, [2, 8192, 14336]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_915); convert_element_type_885 = view_915 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 64, '0'); convert_element_type_889 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_243, [1, 0]); wait_tensor_243 = None + view_917 = torch.ops.aten.view.default(mul_215, [16384, 14336]); mul_215 = None + mm_188 = torch.ops.aten.mm.default(view_917, permute_296); view_917 = permute_296 = None + view_918 = torch.ops.aten.view.default(mm_188, [2, 8192, 4096]); mm_188 = None + add_107 = torch.ops.aten.add.Tensor(add_105, view_918); add_105 = view_918 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 64, '0'); convert_element_type_892 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32) + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = rsqrt_54 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_244); mul_216 = wait_tensor_244 = None + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 64, '0'); convert_element_type_895 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + view_921 = torch.ops.aten.view.default(convert_element_type_894, [16384, 4096]); convert_element_type_894 = None + mm_189 = torch.ops.aten.mm.default(view_921, permute_297); permute_297 = None + view_922 = torch.ops.aten.view.default(mm_189, [2, 8192, 4096]) + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 64, '0'); convert_element_type_898 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_190 = torch.ops.aten.mm.default(view_921, permute_298); permute_298 = None + view_925 = torch.ops.aten.view.default(mm_190, [2, 8192, 1024]); mm_190 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 64, '0'); convert_element_type_901 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + mm_191 = torch.ops.aten.mm.default(view_921, permute_299); view_921 = permute_299 = None + view_928 = torch.ops.aten.view.default(mm_191, [2, 8192, 1024]) + view_929 = torch.ops.aten.view.default(view_922, [2, 8192, -1, 128]); view_922 = None + view_930 = torch.ops.aten.view.default(view_925, [2, 8192, -1, 128]); view_925 = None + view_931 = torch.ops.aten.view.default(view_928, [2, 8192, -1, 128]); view_928 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_929, torch.float32); view_929 = None + view_932 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 32, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_932); view_932 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_930, torch.float32); view_930 = None + view_933 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 8, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_933); view_933 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_16); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_935 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 32, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_16); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_936 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 8, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_935, torch.bfloat16); view_935 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_936, torch.bfloat16); view_936 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 8, 4, 128]); unsqueeze_54 = None + clone_54 = torch.ops.aten.clone.default(expand_54, memory_format = torch.contiguous_format); expand_54 = None + view_937 = torch.ops.aten.view.default(clone_54, [2, 8192, 32, 128]); clone_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_931, 3); view_931 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 8, 4, 128]); unsqueeze_55 = None + clone_55 = torch.ops.aten.clone.default(expand_55, memory_format = torch.contiguous_format); expand_55 = None + view_938 = torch.ops.aten.view.default(clone_55, [2, 8192, 32, 128]); clone_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_937, [0, 2, 1, 3]); view_937 = None + permute_302 = torch.ops.aten.permute.default(view_938, [0, 2, 1, 3]); view_938 = None + _scaled_dot_product_cudnn_attention_27 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_300, permute_301, permute_302, None, True, 0.0, True); permute_300 = permute_301 = permute_302 = None + getitem_243 = _scaled_dot_product_cudnn_attention_27[0] + getitem_244 = _scaled_dot_product_cudnn_attention_27[1] + getitem_249 = _scaled_dot_product_cudnn_attention_27[6] + getitem_250 = _scaled_dot_product_cudnn_attention_27[7]; _scaled_dot_product_cudnn_attention_27 = None + permute_303 = torch.ops.aten.permute.default(getitem_243, [0, 2, 1, 3]) + view_939 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 64, '0'); convert_element_type_908 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_248, [1, 0]); wait_tensor_248 = None + view_941 = torch.ops.aten.view.default(view_939, [16384, 4096]); view_939 = None + mm_192 = torch.ops.aten.mm.default(view_941, permute_304); view_941 = permute_304 = None + view_942 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + add_109 = torch.ops.aten.add.Tensor(add_107, view_942); view_942 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 64, '0'); convert_element_type_911 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = rsqrt_55 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_249); mul_220 = wait_tensor_249 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 64, '0'); convert_element_type_914 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_250, [1, 0]); wait_tensor_250 = None + view_945 = torch.ops.aten.view.default(convert_element_type_913, [16384, 4096]); convert_element_type_913 = None + mm_193 = torch.ops.aten.mm.default(view_945, permute_305); permute_305 = None + view_946 = torch.ops.aten.view.default(mm_193, [2, 8192, 14336]) + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_946, torch.float32); view_946 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); convert_element_type_917 = sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16) + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 64, '0'); convert_element_type_919 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + mm_194 = torch.ops.aten.mm.default(view_945, permute_306); view_945 = permute_306 = None + view_949 = torch.ops.aten.view.default(mm_194, [2, 8192, 14336]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_949); convert_element_type_918 = view_949 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 64, '0'); convert_element_type_922 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + view_951 = torch.ops.aten.view.default(mul_223, [16384, 14336]); mul_223 = None + mm_195 = torch.ops.aten.mm.default(view_951, permute_307); view_951 = permute_307 = None + view_952 = torch.ops.aten.view.default(mm_195, [2, 8192, 4096]); mm_195 = None + add_111 = torch.ops.aten.add.Tensor(add_109, view_952); add_109 = view_952 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16) + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 64, '0'); convert_element_type_925 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32) + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = rsqrt_56 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_253); mul_224 = wait_tensor_253 = None + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 64, '0'); convert_element_type_928 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + view_955 = torch.ops.aten.view.default(convert_element_type_927, [16384, 4096]); convert_element_type_927 = None + mm_196 = torch.ops.aten.mm.default(view_955, permute_308); permute_308 = None + view_956 = torch.ops.aten.view.default(mm_196, [2, 8192, 4096]) + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 64, '0'); convert_element_type_931 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_255, [1, 0]); wait_tensor_255 = None + mm_197 = torch.ops.aten.mm.default(view_955, permute_309); permute_309 = None + view_959 = torch.ops.aten.view.default(mm_197, [2, 8192, 1024]); mm_197 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 64, '0'); convert_element_type_934 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_256, [1, 0]); wait_tensor_256 = None + mm_198 = torch.ops.aten.mm.default(view_955, permute_310); view_955 = permute_310 = None + view_962 = torch.ops.aten.view.default(mm_198, [2, 8192, 1024]) + view_963 = torch.ops.aten.view.default(view_956, [2, 8192, -1, 128]); view_956 = None + view_964 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_965 = torch.ops.aten.view.default(view_962, [2, 8192, -1, 128]); view_962 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_963, torch.float32); view_963 = None + view_966 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 32, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_966); view_966 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_964, torch.float32); view_964 = None + view_967 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 8, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_967); view_967 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_16); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_969 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 32, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_16); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_970 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 8, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_969, torch.bfloat16); view_969 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_970, torch.bfloat16); view_970 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 8, 4, 128]); unsqueeze_56 = None + clone_56 = torch.ops.aten.clone.default(expand_56, memory_format = torch.contiguous_format); expand_56 = None + view_971 = torch.ops.aten.view.default(clone_56, [2, 8192, 32, 128]); clone_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_965, 3); view_965 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 8, 4, 128]); unsqueeze_57 = None + clone_57 = torch.ops.aten.clone.default(expand_57, memory_format = torch.contiguous_format); expand_57 = None + view_972 = torch.ops.aten.view.default(clone_57, [2, 8192, 32, 128]); clone_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_971, [0, 2, 1, 3]); view_971 = None + permute_313 = torch.ops.aten.permute.default(view_972, [0, 2, 1, 3]); view_972 = None + _scaled_dot_product_cudnn_attention_28 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_311, permute_312, permute_313, None, True, 0.0, True); permute_311 = permute_312 = permute_313 = None + getitem_252 = _scaled_dot_product_cudnn_attention_28[0] + getitem_253 = _scaled_dot_product_cudnn_attention_28[1] + getitem_258 = _scaled_dot_product_cudnn_attention_28[6] + getitem_259 = _scaled_dot_product_cudnn_attention_28[7]; _scaled_dot_product_cudnn_attention_28 = None + permute_314 = torch.ops.aten.permute.default(getitem_252, [0, 2, 1, 3]) + view_973 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 64, '0'); convert_element_type_941 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_257, [1, 0]); wait_tensor_257 = None + view_975 = torch.ops.aten.view.default(view_973, [16384, 4096]); view_973 = None + mm_199 = torch.ops.aten.mm.default(view_975, permute_315); view_975 = permute_315 = None + view_976 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + add_113 = torch.ops.aten.add.Tensor(add_111, view_976); view_976 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 64, '0'); convert_element_type_944 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = rsqrt_57 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_258); mul_228 = wait_tensor_258 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 64, '0'); convert_element_type_947 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + view_979 = torch.ops.aten.view.default(convert_element_type_946, [16384, 4096]); convert_element_type_946 = None + mm_200 = torch.ops.aten.mm.default(view_979, permute_316); permute_316 = None + view_980 = torch.ops.aten.view.default(mm_200, [2, 8192, 14336]) + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_980, torch.float32); view_980 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); convert_element_type_950 = sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 64, '0'); convert_element_type_952 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + mm_201 = torch.ops.aten.mm.default(view_979, permute_317); view_979 = permute_317 = None + view_983 = torch.ops.aten.view.default(mm_201, [2, 8192, 14336]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_983); convert_element_type_951 = view_983 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 64, '0'); convert_element_type_955 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_261, [1, 0]); wait_tensor_261 = None + view_985 = torch.ops.aten.view.default(mul_231, [16384, 14336]); mul_231 = None + mm_202 = torch.ops.aten.mm.default(view_985, permute_318); view_985 = permute_318 = None + view_986 = torch.ops.aten.view.default(mm_202, [2, 8192, 4096]); mm_202 = None + add_115 = torch.ops.aten.add.Tensor(add_113, view_986); add_113 = view_986 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 64, '0'); convert_element_type_958 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = rsqrt_58 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_262); mul_232 = wait_tensor_262 = None + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 64, '0'); convert_element_type_961 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_263, [1, 0]); wait_tensor_263 = None + view_989 = torch.ops.aten.view.default(convert_element_type_960, [16384, 4096]); convert_element_type_960 = None + mm_203 = torch.ops.aten.mm.default(view_989, permute_319); permute_319 = None + view_990 = torch.ops.aten.view.default(mm_203, [2, 8192, 4096]) + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 64, '0'); convert_element_type_964 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + mm_204 = torch.ops.aten.mm.default(view_989, permute_320); permute_320 = None + view_993 = torch.ops.aten.view.default(mm_204, [2, 8192, 1024]); mm_204 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 64, '0'); convert_element_type_967 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_205 = torch.ops.aten.mm.default(view_989, permute_321); view_989 = permute_321 = None + view_996 = torch.ops.aten.view.default(mm_205, [2, 8192, 1024]) + view_997 = torch.ops.aten.view.default(view_990, [2, 8192, -1, 128]); view_990 = None + view_998 = torch.ops.aten.view.default(view_993, [2, 8192, -1, 128]); view_993 = None + view_999 = torch.ops.aten.view.default(view_996, [2, 8192, -1, 128]); view_996 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + view_1000 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 32, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_1000); view_1000 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_998, torch.float32); view_998 = None + view_1001 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 8, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_1001); view_1001 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_16); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_1003 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 32, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_16); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_1004 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 8, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_1003, torch.bfloat16); view_1003 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_1004, torch.bfloat16); view_1004 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 8, 4, 128]); unsqueeze_58 = None + clone_58 = torch.ops.aten.clone.default(expand_58, memory_format = torch.contiguous_format); expand_58 = None + view_1005 = torch.ops.aten.view.default(clone_58, [2, 8192, 32, 128]); clone_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_999, 3); view_999 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 8, 4, 128]); unsqueeze_59 = None + clone_59 = torch.ops.aten.clone.default(expand_59, memory_format = torch.contiguous_format); expand_59 = None + view_1006 = torch.ops.aten.view.default(clone_59, [2, 8192, 32, 128]); clone_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_1005, [0, 2, 1, 3]); view_1005 = None + permute_324 = torch.ops.aten.permute.default(view_1006, [0, 2, 1, 3]); view_1006 = None + _scaled_dot_product_cudnn_attention_29 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_322, permute_323, permute_324, None, True, 0.0, True); permute_322 = permute_323 = permute_324 = None + getitem_261 = _scaled_dot_product_cudnn_attention_29[0] + getitem_262 = _scaled_dot_product_cudnn_attention_29[1] + getitem_267 = _scaled_dot_product_cudnn_attention_29[6] + getitem_268 = _scaled_dot_product_cudnn_attention_29[7]; _scaled_dot_product_cudnn_attention_29 = None + permute_325 = torch.ops.aten.permute.default(getitem_261, [0, 2, 1, 3]) + view_1007 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 64, '0'); convert_element_type_974 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + view_1009 = torch.ops.aten.view.default(view_1007, [16384, 4096]); view_1007 = None + mm_206 = torch.ops.aten.mm.default(view_1009, permute_326); view_1009 = permute_326 = None + view_1010 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + add_117 = torch.ops.aten.add.Tensor(add_115, view_1010); view_1010 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 64, '0'); convert_element_type_977 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = rsqrt_59 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_267); mul_236 = wait_tensor_267 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 64, '0'); convert_element_type_980 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_268, [1, 0]); wait_tensor_268 = None + view_1013 = torch.ops.aten.view.default(convert_element_type_979, [16384, 4096]); convert_element_type_979 = None + mm_207 = torch.ops.aten.mm.default(view_1013, permute_327); permute_327 = None + view_1014 = torch.ops.aten.view.default(mm_207, [2, 8192, 14336]) + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_1014, torch.float32); view_1014 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); convert_element_type_983 = sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16) + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 64, '0'); convert_element_type_985 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_269, [1, 0]); wait_tensor_269 = None + mm_208 = torch.ops.aten.mm.default(view_1013, permute_328); view_1013 = permute_328 = None + view_1017 = torch.ops.aten.view.default(mm_208, [2, 8192, 14336]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_1017); convert_element_type_984 = view_1017 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 64, '0'); convert_element_type_988 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_270, [1, 0]); wait_tensor_270 = None + view_1019 = torch.ops.aten.view.default(mul_239, [16384, 14336]); mul_239 = None + mm_209 = torch.ops.aten.mm.default(view_1019, permute_329); view_1019 = permute_329 = None + view_1020 = torch.ops.aten.view.default(mm_209, [2, 8192, 4096]); mm_209 = None + add_119 = torch.ops.aten.add.Tensor(add_117, view_1020); add_117 = view_1020 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 64, '0'); convert_element_type_991 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32) + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = rsqrt_60 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_271); mul_240 = wait_tensor_271 = None + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 64, '0'); convert_element_type_994 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + view_1023 = torch.ops.aten.view.default(convert_element_type_993, [16384, 4096]); convert_element_type_993 = None + mm_210 = torch.ops.aten.mm.default(view_1023, permute_330); permute_330 = None + view_1024 = torch.ops.aten.view.default(mm_210, [2, 8192, 4096]) + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 64, '0'); convert_element_type_997 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + mm_211 = torch.ops.aten.mm.default(view_1023, permute_331); permute_331 = None + view_1027 = torch.ops.aten.view.default(mm_211, [2, 8192, 1024]); mm_211 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 64, '0'); convert_element_type_1000 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_274, [1, 0]); wait_tensor_274 = None + mm_212 = torch.ops.aten.mm.default(view_1023, permute_332); view_1023 = permute_332 = None + view_1030 = torch.ops.aten.view.default(mm_212, [2, 8192, 1024]) + view_1031 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1032 = torch.ops.aten.view.default(view_1027, [2, 8192, -1, 128]); view_1027 = None + view_1033 = torch.ops.aten.view.default(view_1030, [2, 8192, -1, 128]); view_1030 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_1031, torch.float32); view_1031 = None + view_1034 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 32, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_1034); view_1034 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_1032, torch.float32); view_1032 = None + view_1035 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 8, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_1035); view_1035 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_16); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_1037 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 32, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_16); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_1038 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 8, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_1037, torch.bfloat16); view_1037 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_1038, torch.bfloat16); view_1038 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 8, 4, 128]); unsqueeze_60 = None + clone_60 = torch.ops.aten.clone.default(expand_60, memory_format = torch.contiguous_format); expand_60 = None + view_1039 = torch.ops.aten.view.default(clone_60, [2, 8192, 32, 128]); clone_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_1033, 3); view_1033 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 8, 4, 128]); unsqueeze_61 = None + clone_61 = torch.ops.aten.clone.default(expand_61, memory_format = torch.contiguous_format); expand_61 = None + view_1040 = torch.ops.aten.view.default(clone_61, [2, 8192, 32, 128]); clone_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_1039, [0, 2, 1, 3]); view_1039 = None + permute_335 = torch.ops.aten.permute.default(view_1040, [0, 2, 1, 3]); view_1040 = None + _scaled_dot_product_cudnn_attention_30 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_333, permute_334, permute_335, None, True, 0.0, True); permute_333 = permute_334 = permute_335 = None + getitem_270 = _scaled_dot_product_cudnn_attention_30[0] + getitem_271 = _scaled_dot_product_cudnn_attention_30[1] + getitem_276 = _scaled_dot_product_cudnn_attention_30[6] + getitem_277 = _scaled_dot_product_cudnn_attention_30[7]; _scaled_dot_product_cudnn_attention_30 = None + permute_336 = torch.ops.aten.permute.default(getitem_270, [0, 2, 1, 3]) + view_1041 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 64, '0'); convert_element_type_1007 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_275, [1, 0]); wait_tensor_275 = None + view_1043 = torch.ops.aten.view.default(view_1041, [16384, 4096]); view_1041 = None + mm_213 = torch.ops.aten.mm.default(view_1043, permute_337); view_1043 = permute_337 = None + view_1044 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + add_121 = torch.ops.aten.add.Tensor(add_119, view_1044); view_1044 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 64, '0'); convert_element_type_1010 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = rsqrt_61 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_276); mul_244 = wait_tensor_276 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 64, '0'); convert_element_type_1013 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + view_1047 = torch.ops.aten.view.default(convert_element_type_1012, [16384, 4096]); convert_element_type_1012 = None + mm_214 = torch.ops.aten.mm.default(view_1047, permute_338); permute_338 = None + view_1048 = torch.ops.aten.view.default(mm_214, [2, 8192, 14336]) + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_1048, torch.float32); view_1048 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); convert_element_type_1016 = sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 64, '0'); convert_element_type_1018 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_215 = torch.ops.aten.mm.default(view_1047, permute_339); view_1047 = permute_339 = None + view_1051 = torch.ops.aten.view.default(mm_215, [2, 8192, 14336]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_1051); convert_element_type_1017 = view_1051 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 64, '0'); convert_element_type_1021 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + view_1053 = torch.ops.aten.view.default(mul_247, [16384, 14336]); mul_247 = None + mm_216 = torch.ops.aten.mm.default(view_1053, permute_340); view_1053 = permute_340 = None + view_1054 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + add_123 = torch.ops.aten.add.Tensor(add_121, view_1054); add_121 = view_1054 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 64, '0'); convert_element_type_1024 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = rsqrt_62 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_280); mul_248 = wait_tensor_280 = None + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 64, '0'); convert_element_type_1027 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_281, [1, 0]); wait_tensor_281 = None + view_1057 = torch.ops.aten.view.default(convert_element_type_1026, [16384, 4096]); convert_element_type_1026 = None + mm_217 = torch.ops.aten.mm.default(view_1057, permute_341); permute_341 = None + view_1058 = torch.ops.aten.view.default(mm_217, [2, 8192, 4096]) + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 64, '0'); convert_element_type_1030 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_282, [1, 0]); wait_tensor_282 = None + mm_218 = torch.ops.aten.mm.default(view_1057, permute_342); permute_342 = None + view_1061 = torch.ops.aten.view.default(mm_218, [2, 8192, 1024]); mm_218 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16) + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 64, '0'); convert_element_type_1033 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_283, [1, 0]); wait_tensor_283 = None + mm_219 = torch.ops.aten.mm.default(view_1057, permute_343); view_1057 = permute_343 = None + view_1064 = torch.ops.aten.view.default(mm_219, [2, 8192, 1024]) + view_1065 = torch.ops.aten.view.default(view_1058, [2, 8192, -1, 128]); view_1058 = None + view_1066 = torch.ops.aten.view.default(view_1061, [2, 8192, -1, 128]); view_1061 = None + view_1067 = torch.ops.aten.view.default(view_1064, [2, 8192, -1, 128]); view_1064 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_1065, torch.float32); view_1065 = None + view_1068 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 32, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_1068); view_1068 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_1066, torch.float32); view_1066 = None + view_1069 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 8, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_1069); view_1069 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_16); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_1071 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 32, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_16); view_as_complex_63 = view_16 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_1072 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 8, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_1071, torch.bfloat16); view_1071 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_1072, torch.bfloat16); view_1072 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 8, 4, 128]); unsqueeze_62 = None + clone_62 = torch.ops.aten.clone.default(expand_62, memory_format = torch.contiguous_format); expand_62 = None + view_1073 = torch.ops.aten.view.default(clone_62, [2, 8192, 32, 128]); clone_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_1067, 3); view_1067 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 8, 4, 128]); unsqueeze_63 = None + clone_63 = torch.ops.aten.clone.default(expand_63, memory_format = torch.contiguous_format); expand_63 = None + view_1074 = torch.ops.aten.view.default(clone_63, [2, 8192, 32, 128]); clone_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_1073, [0, 2, 1, 3]); view_1073 = None + permute_346 = torch.ops.aten.permute.default(view_1074, [0, 2, 1, 3]); view_1074 = None + _scaled_dot_product_cudnn_attention_31 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_344, permute_345, permute_346, None, True, 0.0, True); permute_344 = permute_345 = permute_346 = None + getitem_279 = _scaled_dot_product_cudnn_attention_31[0] + getitem_280 = _scaled_dot_product_cudnn_attention_31[1] + getitem_285 = _scaled_dot_product_cudnn_attention_31[6] + getitem_286 = _scaled_dot_product_cudnn_attention_31[7]; _scaled_dot_product_cudnn_attention_31 = None + permute_347 = torch.ops.aten.permute.default(getitem_279, [0, 2, 1, 3]) + view_1075 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 64, '0'); convert_element_type_1040 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1077 = torch.ops.aten.view.default(view_1075, [16384, 4096]); view_1075 = None + mm_220 = torch.ops.aten.mm.default(view_1077, permute_348); view_1077 = permute_348 = None + view_1078 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + add_125 = torch.ops.aten.add.Tensor(add_123, view_1078); view_1078 = None + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 64, '0'); convert_element_type_1043 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = rsqrt_63 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_285); mul_252 = wait_tensor_285 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 64, '0'); convert_element_type_1046 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + view_1081 = torch.ops.aten.view.default(convert_element_type_1045, [16384, 4096]); convert_element_type_1045 = None + mm_221 = torch.ops.aten.mm.default(view_1081, permute_349); permute_349 = None + view_1082 = torch.ops.aten.view.default(mm_221, [2, 8192, 14336]) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_1082, torch.float32); view_1082 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); convert_element_type_1049 = sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 64, '0'); convert_element_type_1051 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_287, [1, 0]); wait_tensor_287 = None + mm_222 = torch.ops.aten.mm.default(view_1081, permute_350); view_1081 = permute_350 = None + view_1085 = torch.ops.aten.view.default(mm_222, [2, 8192, 14336]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_1085); convert_element_type_1050 = view_1085 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 64, '0'); convert_element_type_1054 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_288, [1, 0]); wait_tensor_288 = None + view_1087 = torch.ops.aten.view.default(mul_255, [16384, 14336]); mul_255 = None + mm_223 = torch.ops.aten.mm.default(view_1087, permute_351); view_1087 = permute_351 = None + view_1088 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]) + add_127 = torch.ops.aten.add.Tensor(add_125, view_1088); add_125 = view_1088 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 64, '0'); convert_element_type_1057 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1058, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_128 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_128); add_128 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_257 = torch.ops.aten.mul.Tensor(mul_256, wait_tensor_289); mul_256 = wait_tensor_289 = None + convert_element_type_1059 = torch.ops.prims.convert_element_type.default(mul_257, torch.bfloat16); mul_257 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 64, '0'); convert_element_type_1060 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + view_1091 = torch.ops.aten.view.default(convert_element_type_1059, [16384, 4096]); convert_element_type_1059 = None + mm_224 = torch.ops.aten.mm.default(view_1091, permute_352); permute_352 = None + view_1092 = torch.ops.aten.view.default(mm_224, [2, 8192, 128256]); mm_224 = None + return (view_1092, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, embedding, mm, mm_2, getitem, getitem_1, getitem_6, getitem_7, mm_4, add_3, mm_7, mm_9, getitem_9, getitem_10, getitem_15, getitem_16, mm_11, add_7, mm_14, mm_16, getitem_18, getitem_19, getitem_24, getitem_25, mm_18, add_11, mm_21, mm_23, getitem_27, getitem_28, getitem_33, getitem_34, mm_25, add_15, mm_28, mm_30, getitem_36, getitem_37, getitem_42, getitem_43, mm_32, add_19, mm_35, mm_37, getitem_45, getitem_46, getitem_51, getitem_52, mm_39, add_23, mm_42, mm_44, getitem_54, getitem_55, getitem_60, getitem_61, mm_46, add_27, mm_49, mm_51, getitem_63, getitem_64, getitem_69, getitem_70, mm_53, add_31, mm_56, mm_58, getitem_72, getitem_73, getitem_78, getitem_79, mm_60, add_35, mm_63, mm_65, getitem_81, getitem_82, getitem_87, getitem_88, mm_67, add_39, mm_70, mm_72, getitem_90, getitem_91, getitem_96, getitem_97, mm_74, add_43, mm_77, mm_79, getitem_99, getitem_100, getitem_105, getitem_106, mm_81, add_47, mm_84, mm_86, getitem_108, getitem_109, getitem_114, getitem_115, mm_88, add_51, mm_91, mm_93, getitem_117, getitem_118, getitem_123, getitem_124, mm_95, add_55, mm_98, mm_100, getitem_126, getitem_127, getitem_132, getitem_133, mm_102, add_59, mm_105, mm_107, getitem_135, getitem_136, getitem_141, getitem_142, mm_109, add_63, mm_112, mm_114, getitem_144, getitem_145, getitem_150, getitem_151, mm_116, add_67, mm_119, mm_121, getitem_153, getitem_154, getitem_159, getitem_160, mm_123, add_71, mm_126, mm_128, getitem_162, getitem_163, getitem_168, getitem_169, mm_130, add_75, mm_133, mm_135, getitem_171, getitem_172, getitem_177, getitem_178, mm_137, add_79, mm_140, mm_142, getitem_180, getitem_181, getitem_186, getitem_187, mm_144, add_83, mm_147, mm_149, getitem_189, getitem_190, getitem_195, getitem_196, mm_151, add_87, mm_154, mm_156, getitem_198, getitem_199, getitem_204, getitem_205, mm_158, add_91, mm_161, mm_163, getitem_207, getitem_208, getitem_213, getitem_214, mm_165, add_95, mm_168, mm_170, getitem_216, getitem_217, getitem_222, getitem_223, mm_172, add_99, mm_175, mm_177, getitem_225, getitem_226, getitem_231, getitem_232, mm_179, add_103, mm_182, mm_184, getitem_234, getitem_235, getitem_240, getitem_241, mm_186, add_107, mm_189, mm_191, getitem_243, getitem_244, getitem_249, getitem_250, mm_193, add_111, mm_196, mm_198, getitem_252, getitem_253, getitem_258, getitem_259, mm_200, add_115, mm_203, mm_205, getitem_261, getitem_262, getitem_267, getitem_268, mm_207, add_119, mm_210, mm_212, getitem_270, getitem_271, getitem_276, getitem_277, mm_214, add_123, mm_217, mm_219, getitem_279, getitem_280, getitem_285, getitem_286, mm_221, mm_223, rsqrt_64, view_1091) + +def load_args(reader): + buf0 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf0, (2004, 4096), is_leaf=True) # primals_1 + buf1 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf1, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf3, (64,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf4, (64, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf5, (16, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf6, (16, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf7, (64, 4096), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf8, (64,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf9, (224, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf10, (224, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf11, (64, 14336), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf12, (64,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf13, (64, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf14, (16, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf15, (16, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf16, (64, 4096), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf17, (64,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf18, (224, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf19, (224, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf20, (64, 14336), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf21, (64,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf22, (64, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf23, (16, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf24, (16, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf25, (64, 4096), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf26, (64,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf27, (224, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf28, (224, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf29, (64, 14336), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf30, (64,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf31, (64, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf32, (16, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf33, (16, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf34, (64, 4096), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf35, (64,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf36, (224, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf37, (224, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf38, (64, 14336), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf39, (64,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf40, (64, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf41, (16, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf42, (16, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf43, (64, 4096), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf44, (64,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf45, (224, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf46, (224, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf47, (64, 14336), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf48, (64,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf49, (64, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf50, (16, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf51, (16, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf52, (64, 4096), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf53, (64,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf54, (224, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf55, (224, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf56, (64, 14336), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf57, (64,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf58, (64, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf59, (16, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf60, (16, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf61, (64, 4096), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf62, (64,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf63, (224, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf64, (224, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf65, (64, 14336), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf66, (64,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf67, (64, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf68, (16, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf70, (64, 4096), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf71, (64,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf72, (224, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf73, (224, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf74, (64, 14336), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf75, (64,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf76, (64, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf77, (16, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf78, (16, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf79, (64, 4096), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf80, (64,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf81, (224, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf82, (224, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf83, (64, 14336), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf84, (64,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf85, (64, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf86, (16, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf87, (16, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf88, (64, 4096), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf89, (64,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf90, (224, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf91, (224, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf92, (64, 14336), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf93, (64,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf94, (64, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf95, (16, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf96, (16, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf97, (64, 4096), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf98, (64,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf99, (224, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf100, (224, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf101, (64, 14336), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf102, (64,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf103, (64, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf104, (16, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf105, (16, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf106, (64, 4096), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf107, (64,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf108, (224, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf109, (224, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf110, (64, 14336), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf111, (64,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf112, (64, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf113, (16, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf114, (16, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf115, (64, 4096), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf116, (64,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf117, (224, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf118, (224, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf119, (64, 14336), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf120, (64,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf121, (64, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf122, (16, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf123, (16, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf124, (64, 4096), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf125, (64,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf126, (224, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf127, (224, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf128, (64, 14336), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf129, (64,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf130, (64, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf131, (16, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf132, (16, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf133, (64, 4096), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf134, (64,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf135, (224, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf136, (224, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf137, (64, 14336), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf138, (64,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf139, (64, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf141, (16, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf142, (64, 4096), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf143, (64,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf144, (224, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf145, (224, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf146, (64, 14336), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf147, (64,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf148, (64, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf149, (16, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf150, (16, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf151, (64, 4096), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf152, (64,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf153, (224, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf154, (224, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf155, (64, 14336), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf156, (64,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf157, (64, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf158, (16, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf159, (16, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf160, (64, 4096), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf161, (64,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf162, (224, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf163, (224, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf164, (64, 14336), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf165, (64,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf166, (64, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf167, (16, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf168, (16, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf169, (64, 4096), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf170, (64,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf171, (224, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf172, (224, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf173, (64, 14336), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf174, (64,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf175, (64, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf176, (16, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf177, (16, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf178, (64, 4096), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf179, (64,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf180, (224, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf181, (224, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf182, (64, 14336), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf183, (64,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf184, (64, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf185, (16, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf186, (16, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf187, (64, 4096), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf188, (64,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf189, (224, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf190, (224, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf191, (64, 14336), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf192, (64,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf193, (64, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf194, (16, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf195, (16, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf196, (64, 4096), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf197, (64,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf198, (224, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf199, (224, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf200, (64, 14336), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf201, (64,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf202, (64, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf203, (16, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf204, (16, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf205, (64, 4096), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf206, (64,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf207, (224, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf208, (224, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf209, (64, 14336), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf210, (64,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf211, (64, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf212, (16, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf213, (16, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf214, (64, 4096), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf215, (64,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf216, (224, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf217, (224, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf218, (64, 14336), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf219, (64,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf220, (64, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf221, (16, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf222, (16, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf223, (64, 4096), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf224, (64,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf225, (224, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf226, (224, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf227, (64, 14336), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf228, (64,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf229, (64, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf230, (16, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf231, (16, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf232, (64, 4096), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf233, (64,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf234, (224, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf235, (224, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf236, (64, 14336), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf237, (64,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf238, (64, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf239, (16, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf240, (16, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf241, (64, 4096), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf242, (64,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf243, (224, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf244, (224, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf245, (64, 14336), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf246, (64,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf247, (64, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf248, (16, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf249, (16, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf250, (64, 4096), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf251, (64,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf252, (224, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf253, (224, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf254, (64, 14336), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf255, (64,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf256, (64, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf257, (16, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf258, (16, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf259, (64, 4096), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf260, (64,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf261, (224, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf262, (224, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf263, (64, 14336), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf264, (64,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf265, (64, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf266, (16, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf267, (16, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf268, (64, 4096), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf269, (64,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf270, (224, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf271, (224, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf272, (64, 14336), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf273, (64,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf274, (64, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf275, (16, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf276, (16, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf277, (64, 4096), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf278, (64,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf279, (224, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf280, (224, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf281, (64, 14336), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf282, (64,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf283, (64, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf284, (16, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf285, (16, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf286, (64, 4096), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf287, (64,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf288, (224, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf289, (224, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf290, (64, 14336), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 256, device=device(type='cuda', index=0)) + reader.tensor(buf291, (64,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf292, (2004, 4096), is_leaf=True) # primals_293 + +load_args._version = 0 + +def get_pg_config(): + return {'0': {'size': 64, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls8_8.table" diff --git a/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_2d.py b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_2d.py new file mode 100644 index 00000000..d82a452f --- /dev/null +++ b/autoparallel/tools/overlap_simulator/repro_llama3_8b_fw_64_2d.py @@ -0,0 +1,5657 @@ +# fmt: off +# flake8: noqa +# isort: skip_file +import torch +from torch.nn import * +from torch import tensor, device + + +class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293): + convert_element_type = torch.ops.prims.convert_element_type.default(primals_2, torch.bfloat16) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type, 8, '0'); convert_element_type = None + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None + lt = torch.ops.aten.lt.Scalar(primals_1, 0) + ge = torch.ops.aten.ge.Scalar(primals_1, 16032) + bitwise_or = torch.ops.aten.bitwise_or.Tensor(lt, ge); lt = ge = None + sub = torch.ops.aten.sub.Tensor(primals_1, 0) + full_default = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + index_put = torch.ops.aten.index_put.default(sub, [bitwise_or], full_default); sub = full_default = None + embedding = torch.ops.aten.embedding.default(wait_tensor, index_put); wait_tensor = index_put = None + full_default_1 = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + index_put_1 = torch.ops.aten.index_put.default(embedding, [bitwise_or], full_default_1); embedding = bitwise_or = full_default_1 = None + split_1 = torch.ops.aten.split.Tensor(index_put_1, 1024, 1); index_put_1 = None + getitem_8 = split_1[0] + getitem_17 = split_1[1] + getitem_26 = split_1[2] + getitem_35 = split_1[3] + getitem_44 = split_1[4] + getitem_53 = split_1[5] + getitem_62 = split_1[6] + getitem_71 = split_1[7]; split_1 = None + cat = torch.ops.aten.cat.default([getitem_8, getitem_17, getitem_26, getitem_35, getitem_44, getitem_53, getitem_62, getitem_71]); getitem_8 = getitem_17 = getitem_26 = getitem_35 = getitem_44 = getitem_53 = getitem_62 = getitem_71 = None + reduce_scatter_tensor = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat, 'sum', 8, '1'); cat = None + wait_tensor_1 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None + convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16) + all_gather_into_tensor_1 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1, 8, '0'); convert_element_type_1 = None + wait_tensor_2 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None + convert_element_type_2 = torch.ops.prims.convert_element_type.default(wait_tensor_1, torch.float32) + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_2, 2) + mean = torch.ops.aten.mean.dim(pow_1, [2], True); pow_1 = None + add = torch.ops.aten.add.Scalar(mean, 1e-05); mean = None + rsqrt = torch.ops.aten.rsqrt.default(add); add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_2, rsqrt); convert_element_type_2 = rsqrt = None + mul_1 = torch.ops.aten.mul.Tensor(mul, wait_tensor_2); mul = wait_tensor_2 = None + convert_element_type_3 = torch.ops.prims.convert_element_type.default(mul_1, torch.bfloat16); mul_1 = None + all_gather_into_tensor_2 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_3, 8, '1'); convert_element_type_3 = None + wait_tensor_3 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None + split_9 = torch.ops.aten.split.Tensor(wait_tensor_3, 2); wait_tensor_3 = None + getitem_72 = split_9[0] + getitem_73 = split_9[1] + getitem_74 = split_9[2] + getitem_75 = split_9[3] + getitem_76 = split_9[4] + getitem_77 = split_9[5] + getitem_78 = split_9[6] + getitem_79 = split_9[7]; split_9 = None + cat_1 = torch.ops.aten.cat.default([getitem_72, getitem_73, getitem_74, getitem_75, getitem_76, getitem_77, getitem_78, getitem_79], 1); getitem_72 = getitem_73 = getitem_74 = getitem_75 = getitem_76 = getitem_77 = getitem_78 = getitem_79 = None + convert_element_type_4 = torch.ops.prims.convert_element_type.default(primals_5, torch.bfloat16) + all_gather_into_tensor_3 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_4, 8, '0'); convert_element_type_4 = None + wait_tensor_4 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None + permute = torch.ops.aten.permute.default(wait_tensor_4, [1, 0]); wait_tensor_4 = None + view_15 = torch.ops.aten.view.default(cat_1, [16384, 4096]); cat_1 = None + mm = torch.ops.aten.mm.default(view_15, permute); permute = None + view_16 = torch.ops.aten.view.default(mm, [2, 8192, 512]) + convert_element_type_7 = torch.ops.prims.convert_element_type.default(primals_6, torch.bfloat16) + all_gather_into_tensor_4 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_7, 8, '0'); convert_element_type_7 = None + wait_tensor_5 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None + permute_1 = torch.ops.aten.permute.default(wait_tensor_5, [1, 0]); wait_tensor_5 = None + mm_1 = torch.ops.aten.mm.default(view_15, permute_1); permute_1 = None + view_23 = torch.ops.aten.view.default(mm_1, [2, 8192, 128]); mm_1 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default(primals_7, torch.bfloat16) + all_gather_into_tensor_5 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_10, 8, '0'); convert_element_type_10 = None + wait_tensor_6 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None + permute_2 = torch.ops.aten.permute.default(wait_tensor_6, [1, 0]); wait_tensor_6 = None + mm_2 = torch.ops.aten.mm.default(view_15, permute_2); view_15 = permute_2 = None + view_30 = torch.ops.aten.view.default(mm_2, [2, 8192, 128]) + view_32 = torch.ops.aten.view.default(view_16, [2, 8192, -1, 128]); view_16 = None + view_33 = torch.ops.aten.view.default(view_23, [2, 8192, -1, 128]); view_23 = None + view_34 = torch.ops.aten.view.default(view_30, [2, 8192, -1, 128]); view_30 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default(view_32, torch.float32); view_32 = None + view_35 = torch.ops.aten.view.default(convert_element_type_13, [2, 8192, 4, -1, 2]); convert_element_type_13 = None + view_as_complex = torch.ops.aten.view_as_complex.default(view_35); view_35 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default(view_33, torch.float32); view_33 = None + view_36 = torch.ops.aten.view.default(convert_element_type_14, [2, 8192, 1, -1, 2]); convert_element_type_14 = None + view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_36); view_36 = None + view_37 = torch.ops.aten.view.default(primals_3, [1, 8192, 1, 64]) + mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_37); view_as_complex = None + view_as_real = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None + view_38 = torch.ops.aten.view.default(view_as_real, [2, 8192, 4, 128]); view_as_real = None + mul_3 = torch.ops.aten.mul.Tensor(view_as_complex_1, view_37); view_as_complex_1 = None + view_as_real_1 = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None + view_39 = torch.ops.aten.view.default(view_as_real_1, [2, 8192, 1, 128]); view_as_real_1 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default(view_38, torch.bfloat16); view_38 = None + convert_element_type_16 = torch.ops.prims.convert_element_type.default(view_39, torch.bfloat16); view_39 = None + unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type_16, 3); convert_element_type_16 = None + expand = torch.ops.aten.expand.default(unsqueeze, [2, 8192, 1, 4, 128]); unsqueeze = None + view_40 = torch.ops.aten.view.default(expand, [2, 8192, 4, 128]); expand = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(view_34, 3); view_34 = None + expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [2, 8192, 1, 4, 128]); unsqueeze_1 = None + view_41 = torch.ops.aten.view.default(expand_1, [2, 8192, 4, 128]); expand_1 = None + permute_3 = torch.ops.aten.permute.default(convert_element_type_15, [0, 2, 1, 3]); convert_element_type_15 = None + permute_4 = torch.ops.aten.permute.default(view_40, [0, 2, 1, 3]); view_40 = None + permute_5 = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None + _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_3, permute_4, permute_5, None, True, 0.0, True); permute_3 = permute_4 = permute_5 = None + getitem_80 = _scaled_dot_product_cudnn_attention[0] + getitem_81 = _scaled_dot_product_cudnn_attention[1] + getitem_86 = _scaled_dot_product_cudnn_attention[6] + getitem_87 = _scaled_dot_product_cudnn_attention[7]; _scaled_dot_product_cudnn_attention = None + permute_6 = torch.ops.aten.permute.default(getitem_80, [0, 2, 1, 3]) + view_42 = torch.ops.aten.view.default(permute_6, [2, 8192, -1]); permute_6 = None + convert_element_type_17 = torch.ops.prims.convert_element_type.default(primals_8, torch.bfloat16) + all_gather_into_tensor_6 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_17, 8, '0'); convert_element_type_17 = None + wait_tensor_7 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None + permute_7 = torch.ops.aten.permute.default(wait_tensor_7, [1, 0]); wait_tensor_7 = None + view_48 = torch.ops.aten.view.default(view_42, [16384, 512]); view_42 = None + mm_3 = torch.ops.aten.mm.default(view_48, permute_7); view_48 = permute_7 = None + view_49 = torch.ops.aten.view.default(mm_3, [2, 8192, 4096]); mm_3 = None + split_10 = torch.ops.aten.split.Tensor(view_49, 1024, 1); view_49 = None + getitem_89 = split_10[0] + getitem_90 = split_10[1] + getitem_91 = split_10[2] + getitem_92 = split_10[3] + getitem_93 = split_10[4] + getitem_94 = split_10[5] + getitem_95 = split_10[6] + getitem_96 = split_10[7]; split_10 = None + cat_2 = torch.ops.aten.cat.default([getitem_89, getitem_90, getitem_91, getitem_92, getitem_93, getitem_94, getitem_95, getitem_96]); getitem_89 = getitem_90 = getitem_91 = getitem_92 = getitem_93 = getitem_94 = getitem_95 = getitem_96 = None + reduce_scatter_tensor_1 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_2, 'sum', 8, '1'); cat_2 = None + wait_tensor_8 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1) + add_1 = torch.ops.aten.add.Tensor(wait_tensor_1, wait_tensor_8); wait_tensor_8 = None + convert_element_type_20 = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16) + all_gather_into_tensor_7 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_20, 8, '0'); convert_element_type_20 = None + wait_tensor_9 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None + convert_element_type_21 = torch.ops.prims.convert_element_type.default(add_1, torch.float32) + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_21, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [2], True); pow_2 = None + add_2 = torch.ops.aten.add.Scalar(mean_1, 1e-05); mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_2); add_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_21, rsqrt_1); convert_element_type_21 = rsqrt_1 = None + mul_5 = torch.ops.aten.mul.Tensor(mul_4, wait_tensor_9); mul_4 = wait_tensor_9 = None + convert_element_type_22 = torch.ops.prims.convert_element_type.default(mul_5, torch.bfloat16); mul_5 = None + all_gather_into_tensor_8 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_22, 8, '1'); convert_element_type_22 = None + wait_tensor_10 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None + split_11 = torch.ops.aten.split.Tensor(wait_tensor_10, 2); wait_tensor_10 = None + getitem_97 = split_11[0] + getitem_98 = split_11[1] + getitem_99 = split_11[2] + getitem_100 = split_11[3] + getitem_101 = split_11[4] + getitem_102 = split_11[5] + getitem_103 = split_11[6] + getitem_104 = split_11[7]; split_11 = None + cat_3 = torch.ops.aten.cat.default([getitem_97, getitem_98, getitem_99, getitem_100, getitem_101, getitem_102, getitem_103, getitem_104], 1); getitem_97 = getitem_98 = getitem_99 = getitem_100 = getitem_101 = getitem_102 = getitem_103 = getitem_104 = None + convert_element_type_23 = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16) + all_gather_into_tensor_9 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_23, 8, '0'); convert_element_type_23 = None + wait_tensor_11 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None + permute_8 = torch.ops.aten.permute.default(wait_tensor_11, [1, 0]); wait_tensor_11 = None + view_60 = torch.ops.aten.view.default(cat_3, [16384, 4096]); cat_3 = None + mm_4 = torch.ops.aten.mm.default(view_60, permute_8); permute_8 = None + view_61 = torch.ops.aten.view.default(mm_4, [2, 8192, 1792]) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(view_61, torch.float32); view_61 = None + sigmoid = torch.ops.aten.sigmoid.default(convert_element_type_26) + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_26, sigmoid); convert_element_type_26 = sigmoid = None + convert_element_type_27 = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None + convert_element_type_28 = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16) + all_gather_into_tensor_10 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_28, 8, '0'); convert_element_type_28 = None + wait_tensor_12 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None + permute_9 = torch.ops.aten.permute.default(wait_tensor_12, [1, 0]); wait_tensor_12 = None + mm_5 = torch.ops.aten.mm.default(view_60, permute_9); view_60 = permute_9 = None + view_68 = torch.ops.aten.view.default(mm_5, [2, 8192, 1792]); mm_5 = None + mul_7 = torch.ops.aten.mul.Tensor(convert_element_type_27, view_68); convert_element_type_27 = view_68 = None + convert_element_type_31 = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16) + all_gather_into_tensor_11 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_31, 8, '0'); convert_element_type_31 = None + wait_tensor_13 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None + permute_10 = torch.ops.aten.permute.default(wait_tensor_13, [1, 0]); wait_tensor_13 = None + view_75 = torch.ops.aten.view.default(mul_7, [16384, 1792]); mul_7 = None + mm_6 = torch.ops.aten.mm.default(view_75, permute_10); view_75 = permute_10 = None + view_76 = torch.ops.aten.view.default(mm_6, [2, 8192, 4096]); mm_6 = None + split_12 = torch.ops.aten.split.Tensor(view_76, 1024, 1); view_76 = None + getitem_105 = split_12[0] + getitem_106 = split_12[1] + getitem_107 = split_12[2] + getitem_108 = split_12[3] + getitem_109 = split_12[4] + getitem_110 = split_12[5] + getitem_111 = split_12[6] + getitem_112 = split_12[7]; split_12 = None + cat_4 = torch.ops.aten.cat.default([getitem_105, getitem_106, getitem_107, getitem_108, getitem_109, getitem_110, getitem_111, getitem_112]); getitem_105 = getitem_106 = getitem_107 = getitem_108 = getitem_109 = getitem_110 = getitem_111 = getitem_112 = None + reduce_scatter_tensor_2 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_4, 'sum', 8, '1'); cat_4 = None + wait_tensor_14 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None + add_3 = torch.ops.aten.add.Tensor(add_1, wait_tensor_14); add_1 = wait_tensor_14 = None + convert_element_type_34 = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16) + all_gather_into_tensor_12 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_34, 8, '0'); convert_element_type_34 = None + wait_tensor_15 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None + convert_element_type_35 = torch.ops.prims.convert_element_type.default(add_3, torch.float32) + pow_3 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_35, 2) + mean_2 = torch.ops.aten.mean.dim(pow_3, [2], True); pow_3 = None + add_4 = torch.ops.aten.add.Scalar(mean_2, 1e-05); mean_2 = None + rsqrt_2 = torch.ops.aten.rsqrt.default(add_4); add_4 = None + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_35, rsqrt_2); convert_element_type_35 = rsqrt_2 = None + mul_9 = torch.ops.aten.mul.Tensor(mul_8, wait_tensor_15); mul_8 = wait_tensor_15 = None + convert_element_type_36 = torch.ops.prims.convert_element_type.default(mul_9, torch.bfloat16); mul_9 = None + all_gather_into_tensor_13 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_36, 8, '1'); convert_element_type_36 = None + wait_tensor_16 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None + split_13 = torch.ops.aten.split.Tensor(wait_tensor_16, 2); wait_tensor_16 = None + getitem_113 = split_13[0] + getitem_114 = split_13[1] + getitem_115 = split_13[2] + getitem_116 = split_13[3] + getitem_117 = split_13[4] + getitem_118 = split_13[5] + getitem_119 = split_13[6] + getitem_120 = split_13[7]; split_13 = None + cat_5 = torch.ops.aten.cat.default([getitem_113, getitem_114, getitem_115, getitem_116, getitem_117, getitem_118, getitem_119, getitem_120], 1); getitem_113 = getitem_114 = getitem_115 = getitem_116 = getitem_117 = getitem_118 = getitem_119 = getitem_120 = None + convert_element_type_37 = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16) + all_gather_into_tensor_14 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_37, 8, '0'); convert_element_type_37 = None + wait_tensor_17 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None + permute_11 = torch.ops.aten.permute.default(wait_tensor_17, [1, 0]); wait_tensor_17 = None + view_87 = torch.ops.aten.view.default(cat_5, [16384, 4096]); cat_5 = None + mm_7 = torch.ops.aten.mm.default(view_87, permute_11); permute_11 = None + view_88 = torch.ops.aten.view.default(mm_7, [2, 8192, 512]) + convert_element_type_40 = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16) + all_gather_into_tensor_15 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_40, 8, '0'); convert_element_type_40 = None + wait_tensor_18 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None + permute_12 = torch.ops.aten.permute.default(wait_tensor_18, [1, 0]); wait_tensor_18 = None + mm_8 = torch.ops.aten.mm.default(view_87, permute_12); permute_12 = None + view_95 = torch.ops.aten.view.default(mm_8, [2, 8192, 128]); mm_8 = None + convert_element_type_43 = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16) + all_gather_into_tensor_16 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_43, 8, '0'); convert_element_type_43 = None + wait_tensor_19 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None + permute_13 = torch.ops.aten.permute.default(wait_tensor_19, [1, 0]); wait_tensor_19 = None + mm_9 = torch.ops.aten.mm.default(view_87, permute_13); view_87 = permute_13 = None + view_102 = torch.ops.aten.view.default(mm_9, [2, 8192, 128]) + view_104 = torch.ops.aten.view.default(view_88, [2, 8192, -1, 128]); view_88 = None + view_105 = torch.ops.aten.view.default(view_95, [2, 8192, -1, 128]); view_95 = None + view_106 = torch.ops.aten.view.default(view_102, [2, 8192, -1, 128]); view_102 = None + convert_element_type_46 = torch.ops.prims.convert_element_type.default(view_104, torch.float32); view_104 = None + view_107 = torch.ops.aten.view.default(convert_element_type_46, [2, 8192, 4, -1, 2]); convert_element_type_46 = None + view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_107); view_107 = None + convert_element_type_47 = torch.ops.prims.convert_element_type.default(view_105, torch.float32); view_105 = None + view_108 = torch.ops.aten.view.default(convert_element_type_47, [2, 8192, 1, -1, 2]); convert_element_type_47 = None + view_as_complex_3 = torch.ops.aten.view_as_complex.default(view_108); view_108 = None + mul_10 = torch.ops.aten.mul.Tensor(view_as_complex_2, view_37); view_as_complex_2 = None + view_as_real_2 = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None + view_110 = torch.ops.aten.view.default(view_as_real_2, [2, 8192, 4, 128]); view_as_real_2 = None + mul_11 = torch.ops.aten.mul.Tensor(view_as_complex_3, view_37); view_as_complex_3 = None + view_as_real_3 = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None + view_111 = torch.ops.aten.view.default(view_as_real_3, [2, 8192, 1, 128]); view_as_real_3 = None + convert_element_type_48 = torch.ops.prims.convert_element_type.default(view_110, torch.bfloat16); view_110 = None + convert_element_type_49 = torch.ops.prims.convert_element_type.default(view_111, torch.bfloat16); view_111 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(convert_element_type_49, 3); convert_element_type_49 = None + expand_2 = torch.ops.aten.expand.default(unsqueeze_2, [2, 8192, 1, 4, 128]); unsqueeze_2 = None + view_112 = torch.ops.aten.view.default(expand_2, [2, 8192, 4, 128]); expand_2 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view_106, 3); view_106 = None + expand_3 = torch.ops.aten.expand.default(unsqueeze_3, [2, 8192, 1, 4, 128]); unsqueeze_3 = None + view_113 = torch.ops.aten.view.default(expand_3, [2, 8192, 4, 128]); expand_3 = None + permute_14 = torch.ops.aten.permute.default(convert_element_type_48, [0, 2, 1, 3]); convert_element_type_48 = None + permute_15 = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None + permute_16 = torch.ops.aten.permute.default(view_113, [0, 2, 1, 3]); view_113 = None + _scaled_dot_product_cudnn_attention_1 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_14, permute_15, permute_16, None, True, 0.0, True); permute_14 = permute_15 = permute_16 = None + getitem_121 = _scaled_dot_product_cudnn_attention_1[0] + getitem_122 = _scaled_dot_product_cudnn_attention_1[1] + getitem_127 = _scaled_dot_product_cudnn_attention_1[6] + getitem_128 = _scaled_dot_product_cudnn_attention_1[7]; _scaled_dot_product_cudnn_attention_1 = None + permute_17 = torch.ops.aten.permute.default(getitem_121, [0, 2, 1, 3]) + view_114 = torch.ops.aten.view.default(permute_17, [2, 8192, -1]); permute_17 = None + convert_element_type_50 = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16) + all_gather_into_tensor_17 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_50, 8, '0'); convert_element_type_50 = None + wait_tensor_20 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None + permute_18 = torch.ops.aten.permute.default(wait_tensor_20, [1, 0]); wait_tensor_20 = None + view_120 = torch.ops.aten.view.default(view_114, [16384, 512]); view_114 = None + mm_10 = torch.ops.aten.mm.default(view_120, permute_18); view_120 = permute_18 = None + view_121 = torch.ops.aten.view.default(mm_10, [2, 8192, 4096]); mm_10 = None + split_14 = torch.ops.aten.split.Tensor(view_121, 1024, 1); view_121 = None + getitem_130 = split_14[0] + getitem_131 = split_14[1] + getitem_132 = split_14[2] + getitem_133 = split_14[3] + getitem_134 = split_14[4] + getitem_135 = split_14[5] + getitem_136 = split_14[6] + getitem_137 = split_14[7]; split_14 = None + cat_6 = torch.ops.aten.cat.default([getitem_130, getitem_131, getitem_132, getitem_133, getitem_134, getitem_135, getitem_136, getitem_137]); getitem_130 = getitem_131 = getitem_132 = getitem_133 = getitem_134 = getitem_135 = getitem_136 = getitem_137 = None + reduce_scatter_tensor_3 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_6, 'sum', 8, '1'); cat_6 = None + wait_tensor_21 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_3) + add_5 = torch.ops.aten.add.Tensor(add_3, wait_tensor_21); wait_tensor_21 = None + convert_element_type_53 = torch.ops.prims.convert_element_type.default(primals_18, torch.bfloat16) + all_gather_into_tensor_18 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_53, 8, '0'); convert_element_type_53 = None + wait_tensor_22 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None + convert_element_type_54 = torch.ops.prims.convert_element_type.default(add_5, torch.float32) + pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_54, 2) + mean_3 = torch.ops.aten.mean.dim(pow_4, [2], True); pow_4 = None + add_6 = torch.ops.aten.add.Scalar(mean_3, 1e-05); mean_3 = None + rsqrt_3 = torch.ops.aten.rsqrt.default(add_6); add_6 = None + mul_12 = torch.ops.aten.mul.Tensor(convert_element_type_54, rsqrt_3); convert_element_type_54 = rsqrt_3 = None + mul_13 = torch.ops.aten.mul.Tensor(mul_12, wait_tensor_22); mul_12 = wait_tensor_22 = None + convert_element_type_55 = torch.ops.prims.convert_element_type.default(mul_13, torch.bfloat16); mul_13 = None + all_gather_into_tensor_19 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_55, 8, '1'); convert_element_type_55 = None + wait_tensor_23 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None + split_15 = torch.ops.aten.split.Tensor(wait_tensor_23, 2); wait_tensor_23 = None + getitem_138 = split_15[0] + getitem_139 = split_15[1] + getitem_140 = split_15[2] + getitem_141 = split_15[3] + getitem_142 = split_15[4] + getitem_143 = split_15[5] + getitem_144 = split_15[6] + getitem_145 = split_15[7]; split_15 = None + cat_7 = torch.ops.aten.cat.default([getitem_138, getitem_139, getitem_140, getitem_141, getitem_142, getitem_143, getitem_144, getitem_145], 1); getitem_138 = getitem_139 = getitem_140 = getitem_141 = getitem_142 = getitem_143 = getitem_144 = getitem_145 = None + convert_element_type_56 = torch.ops.prims.convert_element_type.default(primals_19, torch.bfloat16) + all_gather_into_tensor_20 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_56, 8, '0'); convert_element_type_56 = None + wait_tensor_24 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_20); all_gather_into_tensor_20 = None + permute_19 = torch.ops.aten.permute.default(wait_tensor_24, [1, 0]); wait_tensor_24 = None + view_132 = torch.ops.aten.view.default(cat_7, [16384, 4096]); cat_7 = None + mm_11 = torch.ops.aten.mm.default(view_132, permute_19); permute_19 = None + view_133 = torch.ops.aten.view.default(mm_11, [2, 8192, 1792]) + convert_element_type_59 = torch.ops.prims.convert_element_type.default(view_133, torch.float32); view_133 = None + sigmoid_1 = torch.ops.aten.sigmoid.default(convert_element_type_59) + mul_14 = torch.ops.aten.mul.Tensor(convert_element_type_59, sigmoid_1); convert_element_type_59 = sigmoid_1 = None + convert_element_type_60 = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None + convert_element_type_61 = torch.ops.prims.convert_element_type.default(primals_20, torch.bfloat16) + all_gather_into_tensor_21 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_61, 8, '0'); convert_element_type_61 = None + wait_tensor_25 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_21); all_gather_into_tensor_21 = None + permute_20 = torch.ops.aten.permute.default(wait_tensor_25, [1, 0]); wait_tensor_25 = None + mm_12 = torch.ops.aten.mm.default(view_132, permute_20); view_132 = permute_20 = None + view_140 = torch.ops.aten.view.default(mm_12, [2, 8192, 1792]); mm_12 = None + mul_15 = torch.ops.aten.mul.Tensor(convert_element_type_60, view_140); convert_element_type_60 = view_140 = None + convert_element_type_64 = torch.ops.prims.convert_element_type.default(primals_21, torch.bfloat16) + all_gather_into_tensor_22 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_64, 8, '0'); convert_element_type_64 = None + wait_tensor_26 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_22); all_gather_into_tensor_22 = None + permute_21 = torch.ops.aten.permute.default(wait_tensor_26, [1, 0]); wait_tensor_26 = None + view_147 = torch.ops.aten.view.default(mul_15, [16384, 1792]); mul_15 = None + mm_13 = torch.ops.aten.mm.default(view_147, permute_21); view_147 = permute_21 = None + view_148 = torch.ops.aten.view.default(mm_13, [2, 8192, 4096]); mm_13 = None + split_16 = torch.ops.aten.split.Tensor(view_148, 1024, 1); view_148 = None + getitem_146 = split_16[0] + getitem_147 = split_16[1] + getitem_148 = split_16[2] + getitem_149 = split_16[3] + getitem_150 = split_16[4] + getitem_151 = split_16[5] + getitem_152 = split_16[6] + getitem_153 = split_16[7]; split_16 = None + cat_8 = torch.ops.aten.cat.default([getitem_146, getitem_147, getitem_148, getitem_149, getitem_150, getitem_151, getitem_152, getitem_153]); getitem_146 = getitem_147 = getitem_148 = getitem_149 = getitem_150 = getitem_151 = getitem_152 = getitem_153 = None + reduce_scatter_tensor_4 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_8, 'sum', 8, '1'); cat_8 = None + wait_tensor_27 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_4); reduce_scatter_tensor_4 = None + add_7 = torch.ops.aten.add.Tensor(add_5, wait_tensor_27); add_5 = wait_tensor_27 = None + convert_element_type_67 = torch.ops.prims.convert_element_type.default(primals_22, torch.bfloat16) + all_gather_into_tensor_23 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_67, 8, '0'); convert_element_type_67 = None + wait_tensor_28 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_23); all_gather_into_tensor_23 = None + convert_element_type_68 = torch.ops.prims.convert_element_type.default(add_7, torch.float32) + pow_5 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_68, 2) + mean_4 = torch.ops.aten.mean.dim(pow_5, [2], True); pow_5 = None + add_8 = torch.ops.aten.add.Scalar(mean_4, 1e-05); mean_4 = None + rsqrt_4 = torch.ops.aten.rsqrt.default(add_8); add_8 = None + mul_16 = torch.ops.aten.mul.Tensor(convert_element_type_68, rsqrt_4); convert_element_type_68 = rsqrt_4 = None + mul_17 = torch.ops.aten.mul.Tensor(mul_16, wait_tensor_28); mul_16 = wait_tensor_28 = None + convert_element_type_69 = torch.ops.prims.convert_element_type.default(mul_17, torch.bfloat16); mul_17 = None + all_gather_into_tensor_24 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_69, 8, '1'); convert_element_type_69 = None + wait_tensor_29 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_24); all_gather_into_tensor_24 = None + split_17 = torch.ops.aten.split.Tensor(wait_tensor_29, 2); wait_tensor_29 = None + getitem_154 = split_17[0] + getitem_155 = split_17[1] + getitem_156 = split_17[2] + getitem_157 = split_17[3] + getitem_158 = split_17[4] + getitem_159 = split_17[5] + getitem_160 = split_17[6] + getitem_161 = split_17[7]; split_17 = None + cat_9 = torch.ops.aten.cat.default([getitem_154, getitem_155, getitem_156, getitem_157, getitem_158, getitem_159, getitem_160, getitem_161], 1); getitem_154 = getitem_155 = getitem_156 = getitem_157 = getitem_158 = getitem_159 = getitem_160 = getitem_161 = None + convert_element_type_70 = torch.ops.prims.convert_element_type.default(primals_23, torch.bfloat16) + all_gather_into_tensor_25 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_70, 8, '0'); convert_element_type_70 = None + wait_tensor_30 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_25); all_gather_into_tensor_25 = None + permute_22 = torch.ops.aten.permute.default(wait_tensor_30, [1, 0]); wait_tensor_30 = None + view_159 = torch.ops.aten.view.default(cat_9, [16384, 4096]); cat_9 = None + mm_14 = torch.ops.aten.mm.default(view_159, permute_22); permute_22 = None + view_160 = torch.ops.aten.view.default(mm_14, [2, 8192, 512]) + convert_element_type_73 = torch.ops.prims.convert_element_type.default(primals_24, torch.bfloat16) + all_gather_into_tensor_26 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_73, 8, '0'); convert_element_type_73 = None + wait_tensor_31 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_26); all_gather_into_tensor_26 = None + permute_23 = torch.ops.aten.permute.default(wait_tensor_31, [1, 0]); wait_tensor_31 = None + mm_15 = torch.ops.aten.mm.default(view_159, permute_23); permute_23 = None + view_167 = torch.ops.aten.view.default(mm_15, [2, 8192, 128]); mm_15 = None + convert_element_type_76 = torch.ops.prims.convert_element_type.default(primals_25, torch.bfloat16) + all_gather_into_tensor_27 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_76, 8, '0'); convert_element_type_76 = None + wait_tensor_32 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_27); all_gather_into_tensor_27 = None + permute_24 = torch.ops.aten.permute.default(wait_tensor_32, [1, 0]); wait_tensor_32 = None + mm_16 = torch.ops.aten.mm.default(view_159, permute_24); view_159 = permute_24 = None + view_174 = torch.ops.aten.view.default(mm_16, [2, 8192, 128]) + view_176 = torch.ops.aten.view.default(view_160, [2, 8192, -1, 128]); view_160 = None + view_177 = torch.ops.aten.view.default(view_167, [2, 8192, -1, 128]); view_167 = None + view_178 = torch.ops.aten.view.default(view_174, [2, 8192, -1, 128]); view_174 = None + convert_element_type_79 = torch.ops.prims.convert_element_type.default(view_176, torch.float32); view_176 = None + view_179 = torch.ops.aten.view.default(convert_element_type_79, [2, 8192, 4, -1, 2]); convert_element_type_79 = None + view_as_complex_4 = torch.ops.aten.view_as_complex.default(view_179); view_179 = None + convert_element_type_80 = torch.ops.prims.convert_element_type.default(view_177, torch.float32); view_177 = None + view_180 = torch.ops.aten.view.default(convert_element_type_80, [2, 8192, 1, -1, 2]); convert_element_type_80 = None + view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_180); view_180 = None + mul_18 = torch.ops.aten.mul.Tensor(view_as_complex_4, view_37); view_as_complex_4 = None + view_as_real_4 = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None + view_182 = torch.ops.aten.view.default(view_as_real_4, [2, 8192, 4, 128]); view_as_real_4 = None + mul_19 = torch.ops.aten.mul.Tensor(view_as_complex_5, view_37); view_as_complex_5 = None + view_as_real_5 = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None + view_183 = torch.ops.aten.view.default(view_as_real_5, [2, 8192, 1, 128]); view_as_real_5 = None + convert_element_type_81 = torch.ops.prims.convert_element_type.default(view_182, torch.bfloat16); view_182 = None + convert_element_type_82 = torch.ops.prims.convert_element_type.default(view_183, torch.bfloat16); view_183 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(convert_element_type_82, 3); convert_element_type_82 = None + expand_4 = torch.ops.aten.expand.default(unsqueeze_4, [2, 8192, 1, 4, 128]); unsqueeze_4 = None + view_184 = torch.ops.aten.view.default(expand_4, [2, 8192, 4, 128]); expand_4 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(view_178, 3); view_178 = None + expand_5 = torch.ops.aten.expand.default(unsqueeze_5, [2, 8192, 1, 4, 128]); unsqueeze_5 = None + view_185 = torch.ops.aten.view.default(expand_5, [2, 8192, 4, 128]); expand_5 = None + permute_25 = torch.ops.aten.permute.default(convert_element_type_81, [0, 2, 1, 3]); convert_element_type_81 = None + permute_26 = torch.ops.aten.permute.default(view_184, [0, 2, 1, 3]); view_184 = None + permute_27 = torch.ops.aten.permute.default(view_185, [0, 2, 1, 3]); view_185 = None + _scaled_dot_product_cudnn_attention_2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_25, permute_26, permute_27, None, True, 0.0, True); permute_25 = permute_26 = permute_27 = None + getitem_162 = _scaled_dot_product_cudnn_attention_2[0] + getitem_163 = _scaled_dot_product_cudnn_attention_2[1] + getitem_168 = _scaled_dot_product_cudnn_attention_2[6] + getitem_169 = _scaled_dot_product_cudnn_attention_2[7]; _scaled_dot_product_cudnn_attention_2 = None + permute_28 = torch.ops.aten.permute.default(getitem_162, [0, 2, 1, 3]) + view_186 = torch.ops.aten.view.default(permute_28, [2, 8192, -1]); permute_28 = None + convert_element_type_83 = torch.ops.prims.convert_element_type.default(primals_26, torch.bfloat16) + all_gather_into_tensor_28 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_83, 8, '0'); convert_element_type_83 = None + wait_tensor_33 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_28); all_gather_into_tensor_28 = None + permute_29 = torch.ops.aten.permute.default(wait_tensor_33, [1, 0]); wait_tensor_33 = None + view_192 = torch.ops.aten.view.default(view_186, [16384, 512]); view_186 = None + mm_17 = torch.ops.aten.mm.default(view_192, permute_29); view_192 = permute_29 = None + view_193 = torch.ops.aten.view.default(mm_17, [2, 8192, 4096]); mm_17 = None + split_18 = torch.ops.aten.split.Tensor(view_193, 1024, 1); view_193 = None + getitem_171 = split_18[0] + getitem_172 = split_18[1] + getitem_173 = split_18[2] + getitem_174 = split_18[3] + getitem_175 = split_18[4] + getitem_176 = split_18[5] + getitem_177 = split_18[6] + getitem_178 = split_18[7]; split_18 = None + cat_10 = torch.ops.aten.cat.default([getitem_171, getitem_172, getitem_173, getitem_174, getitem_175, getitem_176, getitem_177, getitem_178]); getitem_171 = getitem_172 = getitem_173 = getitem_174 = getitem_175 = getitem_176 = getitem_177 = getitem_178 = None + reduce_scatter_tensor_5 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_10, 'sum', 8, '1'); cat_10 = None + wait_tensor_34 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_5) + add_9 = torch.ops.aten.add.Tensor(add_7, wait_tensor_34); wait_tensor_34 = None + convert_element_type_86 = torch.ops.prims.convert_element_type.default(primals_27, torch.bfloat16) + all_gather_into_tensor_29 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_86, 8, '0'); convert_element_type_86 = None + wait_tensor_35 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_29); all_gather_into_tensor_29 = None + convert_element_type_87 = torch.ops.prims.convert_element_type.default(add_9, torch.float32) + pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_87, 2) + mean_5 = torch.ops.aten.mean.dim(pow_6, [2], True); pow_6 = None + add_10 = torch.ops.aten.add.Scalar(mean_5, 1e-05); mean_5 = None + rsqrt_5 = torch.ops.aten.rsqrt.default(add_10); add_10 = None + mul_20 = torch.ops.aten.mul.Tensor(convert_element_type_87, rsqrt_5); convert_element_type_87 = rsqrt_5 = None + mul_21 = torch.ops.aten.mul.Tensor(mul_20, wait_tensor_35); mul_20 = wait_tensor_35 = None + convert_element_type_88 = torch.ops.prims.convert_element_type.default(mul_21, torch.bfloat16); mul_21 = None + all_gather_into_tensor_30 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_88, 8, '1'); convert_element_type_88 = None + wait_tensor_36 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_30); all_gather_into_tensor_30 = None + split_19 = torch.ops.aten.split.Tensor(wait_tensor_36, 2); wait_tensor_36 = None + getitem_179 = split_19[0] + getitem_180 = split_19[1] + getitem_181 = split_19[2] + getitem_182 = split_19[3] + getitem_183 = split_19[4] + getitem_184 = split_19[5] + getitem_185 = split_19[6] + getitem_186 = split_19[7]; split_19 = None + cat_11 = torch.ops.aten.cat.default([getitem_179, getitem_180, getitem_181, getitem_182, getitem_183, getitem_184, getitem_185, getitem_186], 1); getitem_179 = getitem_180 = getitem_181 = getitem_182 = getitem_183 = getitem_184 = getitem_185 = getitem_186 = None + convert_element_type_89 = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16) + all_gather_into_tensor_31 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_89, 8, '0'); convert_element_type_89 = None + wait_tensor_37 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_31); all_gather_into_tensor_31 = None + permute_30 = torch.ops.aten.permute.default(wait_tensor_37, [1, 0]); wait_tensor_37 = None + view_204 = torch.ops.aten.view.default(cat_11, [16384, 4096]); cat_11 = None + mm_18 = torch.ops.aten.mm.default(view_204, permute_30); permute_30 = None + view_205 = torch.ops.aten.view.default(mm_18, [2, 8192, 1792]) + convert_element_type_92 = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None + sigmoid_2 = torch.ops.aten.sigmoid.default(convert_element_type_92) + mul_22 = torch.ops.aten.mul.Tensor(convert_element_type_92, sigmoid_2); convert_element_type_92 = sigmoid_2 = None + convert_element_type_93 = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None + convert_element_type_94 = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16) + all_gather_into_tensor_32 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_94, 8, '0'); convert_element_type_94 = None + wait_tensor_38 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_32); all_gather_into_tensor_32 = None + permute_31 = torch.ops.aten.permute.default(wait_tensor_38, [1, 0]); wait_tensor_38 = None + mm_19 = torch.ops.aten.mm.default(view_204, permute_31); view_204 = permute_31 = None + view_212 = torch.ops.aten.view.default(mm_19, [2, 8192, 1792]); mm_19 = None + mul_23 = torch.ops.aten.mul.Tensor(convert_element_type_93, view_212); convert_element_type_93 = view_212 = None + convert_element_type_97 = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16) + all_gather_into_tensor_33 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_97, 8, '0'); convert_element_type_97 = None + wait_tensor_39 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_33); all_gather_into_tensor_33 = None + permute_32 = torch.ops.aten.permute.default(wait_tensor_39, [1, 0]); wait_tensor_39 = None + view_219 = torch.ops.aten.view.default(mul_23, [16384, 1792]); mul_23 = None + mm_20 = torch.ops.aten.mm.default(view_219, permute_32); view_219 = permute_32 = None + view_220 = torch.ops.aten.view.default(mm_20, [2, 8192, 4096]); mm_20 = None + split_20 = torch.ops.aten.split.Tensor(view_220, 1024, 1); view_220 = None + getitem_187 = split_20[0] + getitem_188 = split_20[1] + getitem_189 = split_20[2] + getitem_190 = split_20[3] + getitem_191 = split_20[4] + getitem_192 = split_20[5] + getitem_193 = split_20[6] + getitem_194 = split_20[7]; split_20 = None + cat_12 = torch.ops.aten.cat.default([getitem_187, getitem_188, getitem_189, getitem_190, getitem_191, getitem_192, getitem_193, getitem_194]); getitem_187 = getitem_188 = getitem_189 = getitem_190 = getitem_191 = getitem_192 = getitem_193 = getitem_194 = None + reduce_scatter_tensor_6 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_12, 'sum', 8, '1'); cat_12 = None + wait_tensor_40 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_6); reduce_scatter_tensor_6 = None + add_11 = torch.ops.aten.add.Tensor(add_9, wait_tensor_40); add_9 = wait_tensor_40 = None + convert_element_type_100 = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16) + all_gather_into_tensor_34 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_100, 8, '0'); convert_element_type_100 = None + wait_tensor_41 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_34); all_gather_into_tensor_34 = None + convert_element_type_101 = torch.ops.prims.convert_element_type.default(add_11, torch.float32) + pow_7 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_101, 2) + mean_6 = torch.ops.aten.mean.dim(pow_7, [2], True); pow_7 = None + add_12 = torch.ops.aten.add.Scalar(mean_6, 1e-05); mean_6 = None + rsqrt_6 = torch.ops.aten.rsqrt.default(add_12); add_12 = None + mul_24 = torch.ops.aten.mul.Tensor(convert_element_type_101, rsqrt_6); convert_element_type_101 = rsqrt_6 = None + mul_25 = torch.ops.aten.mul.Tensor(mul_24, wait_tensor_41); mul_24 = wait_tensor_41 = None + convert_element_type_102 = torch.ops.prims.convert_element_type.default(mul_25, torch.bfloat16); mul_25 = None + all_gather_into_tensor_35 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_102, 8, '1'); convert_element_type_102 = None + wait_tensor_42 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_35); all_gather_into_tensor_35 = None + split_21 = torch.ops.aten.split.Tensor(wait_tensor_42, 2); wait_tensor_42 = None + getitem_195 = split_21[0] + getitem_196 = split_21[1] + getitem_197 = split_21[2] + getitem_198 = split_21[3] + getitem_199 = split_21[4] + getitem_200 = split_21[5] + getitem_201 = split_21[6] + getitem_202 = split_21[7]; split_21 = None + cat_13 = torch.ops.aten.cat.default([getitem_195, getitem_196, getitem_197, getitem_198, getitem_199, getitem_200, getitem_201, getitem_202], 1); getitem_195 = getitem_196 = getitem_197 = getitem_198 = getitem_199 = getitem_200 = getitem_201 = getitem_202 = None + convert_element_type_103 = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16) + all_gather_into_tensor_36 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_103, 8, '0'); convert_element_type_103 = None + wait_tensor_43 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_36); all_gather_into_tensor_36 = None + permute_33 = torch.ops.aten.permute.default(wait_tensor_43, [1, 0]); wait_tensor_43 = None + view_231 = torch.ops.aten.view.default(cat_13, [16384, 4096]); cat_13 = None + mm_21 = torch.ops.aten.mm.default(view_231, permute_33); permute_33 = None + view_232 = torch.ops.aten.view.default(mm_21, [2, 8192, 512]) + convert_element_type_106 = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16) + all_gather_into_tensor_37 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_106, 8, '0'); convert_element_type_106 = None + wait_tensor_44 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_37); all_gather_into_tensor_37 = None + permute_34 = torch.ops.aten.permute.default(wait_tensor_44, [1, 0]); wait_tensor_44 = None + mm_22 = torch.ops.aten.mm.default(view_231, permute_34); permute_34 = None + view_239 = torch.ops.aten.view.default(mm_22, [2, 8192, 128]); mm_22 = None + convert_element_type_109 = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16) + all_gather_into_tensor_38 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_109, 8, '0'); convert_element_type_109 = None + wait_tensor_45 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_38); all_gather_into_tensor_38 = None + permute_35 = torch.ops.aten.permute.default(wait_tensor_45, [1, 0]); wait_tensor_45 = None + mm_23 = torch.ops.aten.mm.default(view_231, permute_35); view_231 = permute_35 = None + view_246 = torch.ops.aten.view.default(mm_23, [2, 8192, 128]) + view_248 = torch.ops.aten.view.default(view_232, [2, 8192, -1, 128]); view_232 = None + view_249 = torch.ops.aten.view.default(view_239, [2, 8192, -1, 128]); view_239 = None + view_250 = torch.ops.aten.view.default(view_246, [2, 8192, -1, 128]); view_246 = None + convert_element_type_112 = torch.ops.prims.convert_element_type.default(view_248, torch.float32); view_248 = None + view_251 = torch.ops.aten.view.default(convert_element_type_112, [2, 8192, 4, -1, 2]); convert_element_type_112 = None + view_as_complex_6 = torch.ops.aten.view_as_complex.default(view_251); view_251 = None + convert_element_type_113 = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None + view_252 = torch.ops.aten.view.default(convert_element_type_113, [2, 8192, 1, -1, 2]); convert_element_type_113 = None + view_as_complex_7 = torch.ops.aten.view_as_complex.default(view_252); view_252 = None + mul_26 = torch.ops.aten.mul.Tensor(view_as_complex_6, view_37); view_as_complex_6 = None + view_as_real_6 = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None + view_254 = torch.ops.aten.view.default(view_as_real_6, [2, 8192, 4, 128]); view_as_real_6 = None + mul_27 = torch.ops.aten.mul.Tensor(view_as_complex_7, view_37); view_as_complex_7 = None + view_as_real_7 = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None + view_255 = torch.ops.aten.view.default(view_as_real_7, [2, 8192, 1, 128]); view_as_real_7 = None + convert_element_type_114 = torch.ops.prims.convert_element_type.default(view_254, torch.bfloat16); view_254 = None + convert_element_type_115 = torch.ops.prims.convert_element_type.default(view_255, torch.bfloat16); view_255 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(convert_element_type_115, 3); convert_element_type_115 = None + expand_6 = torch.ops.aten.expand.default(unsqueeze_6, [2, 8192, 1, 4, 128]); unsqueeze_6 = None + view_256 = torch.ops.aten.view.default(expand_6, [2, 8192, 4, 128]); expand_6 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(view_250, 3); view_250 = None + expand_7 = torch.ops.aten.expand.default(unsqueeze_7, [2, 8192, 1, 4, 128]); unsqueeze_7 = None + view_257 = torch.ops.aten.view.default(expand_7, [2, 8192, 4, 128]); expand_7 = None + permute_36 = torch.ops.aten.permute.default(convert_element_type_114, [0, 2, 1, 3]); convert_element_type_114 = None + permute_37 = torch.ops.aten.permute.default(view_256, [0, 2, 1, 3]); view_256 = None + permute_38 = torch.ops.aten.permute.default(view_257, [0, 2, 1, 3]); view_257 = None + _scaled_dot_product_cudnn_attention_3 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_36, permute_37, permute_38, None, True, 0.0, True); permute_36 = permute_37 = permute_38 = None + getitem_203 = _scaled_dot_product_cudnn_attention_3[0] + getitem_204 = _scaled_dot_product_cudnn_attention_3[1] + getitem_209 = _scaled_dot_product_cudnn_attention_3[6] + getitem_210 = _scaled_dot_product_cudnn_attention_3[7]; _scaled_dot_product_cudnn_attention_3 = None + permute_39 = torch.ops.aten.permute.default(getitem_203, [0, 2, 1, 3]) + view_258 = torch.ops.aten.view.default(permute_39, [2, 8192, -1]); permute_39 = None + convert_element_type_116 = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16) + all_gather_into_tensor_39 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_116, 8, '0'); convert_element_type_116 = None + wait_tensor_46 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_39); all_gather_into_tensor_39 = None + permute_40 = torch.ops.aten.permute.default(wait_tensor_46, [1, 0]); wait_tensor_46 = None + view_264 = torch.ops.aten.view.default(view_258, [16384, 512]); view_258 = None + mm_24 = torch.ops.aten.mm.default(view_264, permute_40); view_264 = permute_40 = None + view_265 = torch.ops.aten.view.default(mm_24, [2, 8192, 4096]); mm_24 = None + split_22 = torch.ops.aten.split.Tensor(view_265, 1024, 1); view_265 = None + getitem_212 = split_22[0] + getitem_213 = split_22[1] + getitem_214 = split_22[2] + getitem_215 = split_22[3] + getitem_216 = split_22[4] + getitem_217 = split_22[5] + getitem_218 = split_22[6] + getitem_219 = split_22[7]; split_22 = None + cat_14 = torch.ops.aten.cat.default([getitem_212, getitem_213, getitem_214, getitem_215, getitem_216, getitem_217, getitem_218, getitem_219]); getitem_212 = getitem_213 = getitem_214 = getitem_215 = getitem_216 = getitem_217 = getitem_218 = getitem_219 = None + reduce_scatter_tensor_7 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_14, 'sum', 8, '1'); cat_14 = None + wait_tensor_47 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_7) + add_13 = torch.ops.aten.add.Tensor(add_11, wait_tensor_47); wait_tensor_47 = None + convert_element_type_119 = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16) + all_gather_into_tensor_40 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_119, 8, '0'); convert_element_type_119 = None + wait_tensor_48 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_40); all_gather_into_tensor_40 = None + convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_13, torch.float32) + pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_120, 2) + mean_7 = torch.ops.aten.mean.dim(pow_8, [2], True); pow_8 = None + add_14 = torch.ops.aten.add.Scalar(mean_7, 1e-05); mean_7 = None + rsqrt_7 = torch.ops.aten.rsqrt.default(add_14); add_14 = None + mul_28 = torch.ops.aten.mul.Tensor(convert_element_type_120, rsqrt_7); convert_element_type_120 = rsqrt_7 = None + mul_29 = torch.ops.aten.mul.Tensor(mul_28, wait_tensor_48); mul_28 = wait_tensor_48 = None + convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_29, torch.bfloat16); mul_29 = None + all_gather_into_tensor_41 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_121, 8, '1'); convert_element_type_121 = None + wait_tensor_49 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_41); all_gather_into_tensor_41 = None + split_23 = torch.ops.aten.split.Tensor(wait_tensor_49, 2); wait_tensor_49 = None + getitem_220 = split_23[0] + getitem_221 = split_23[1] + getitem_222 = split_23[2] + getitem_223 = split_23[3] + getitem_224 = split_23[4] + getitem_225 = split_23[5] + getitem_226 = split_23[6] + getitem_227 = split_23[7]; split_23 = None + cat_15 = torch.ops.aten.cat.default([getitem_220, getitem_221, getitem_222, getitem_223, getitem_224, getitem_225, getitem_226, getitem_227], 1); getitem_220 = getitem_221 = getitem_222 = getitem_223 = getitem_224 = getitem_225 = getitem_226 = getitem_227 = None + convert_element_type_122 = torch.ops.prims.convert_element_type.default(primals_37, torch.bfloat16) + all_gather_into_tensor_42 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_122, 8, '0'); convert_element_type_122 = None + wait_tensor_50 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_42); all_gather_into_tensor_42 = None + permute_41 = torch.ops.aten.permute.default(wait_tensor_50, [1, 0]); wait_tensor_50 = None + view_276 = torch.ops.aten.view.default(cat_15, [16384, 4096]); cat_15 = None + mm_25 = torch.ops.aten.mm.default(view_276, permute_41); permute_41 = None + view_277 = torch.ops.aten.view.default(mm_25, [2, 8192, 1792]) + convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_277, torch.float32); view_277 = None + sigmoid_3 = torch.ops.aten.sigmoid.default(convert_element_type_125) + mul_30 = torch.ops.aten.mul.Tensor(convert_element_type_125, sigmoid_3); convert_element_type_125 = sigmoid_3 = None + convert_element_type_126 = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None + convert_element_type_127 = torch.ops.prims.convert_element_type.default(primals_38, torch.bfloat16) + all_gather_into_tensor_43 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_127, 8, '0'); convert_element_type_127 = None + wait_tensor_51 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_43); all_gather_into_tensor_43 = None + permute_42 = torch.ops.aten.permute.default(wait_tensor_51, [1, 0]); wait_tensor_51 = None + mm_26 = torch.ops.aten.mm.default(view_276, permute_42); view_276 = permute_42 = None + view_284 = torch.ops.aten.view.default(mm_26, [2, 8192, 1792]); mm_26 = None + mul_31 = torch.ops.aten.mul.Tensor(convert_element_type_126, view_284); convert_element_type_126 = view_284 = None + convert_element_type_130 = torch.ops.prims.convert_element_type.default(primals_39, torch.bfloat16) + all_gather_into_tensor_44 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_130, 8, '0'); convert_element_type_130 = None + wait_tensor_52 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_44); all_gather_into_tensor_44 = None + permute_43 = torch.ops.aten.permute.default(wait_tensor_52, [1, 0]); wait_tensor_52 = None + view_291 = torch.ops.aten.view.default(mul_31, [16384, 1792]); mul_31 = None + mm_27 = torch.ops.aten.mm.default(view_291, permute_43); view_291 = permute_43 = None + view_292 = torch.ops.aten.view.default(mm_27, [2, 8192, 4096]); mm_27 = None + split_24 = torch.ops.aten.split.Tensor(view_292, 1024, 1); view_292 = None + getitem_228 = split_24[0] + getitem_229 = split_24[1] + getitem_230 = split_24[2] + getitem_231 = split_24[3] + getitem_232 = split_24[4] + getitem_233 = split_24[5] + getitem_234 = split_24[6] + getitem_235 = split_24[7]; split_24 = None + cat_16 = torch.ops.aten.cat.default([getitem_228, getitem_229, getitem_230, getitem_231, getitem_232, getitem_233, getitem_234, getitem_235]); getitem_228 = getitem_229 = getitem_230 = getitem_231 = getitem_232 = getitem_233 = getitem_234 = getitem_235 = None + reduce_scatter_tensor_8 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_16, 'sum', 8, '1'); cat_16 = None + wait_tensor_53 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_8); reduce_scatter_tensor_8 = None + add_15 = torch.ops.aten.add.Tensor(add_13, wait_tensor_53); add_13 = wait_tensor_53 = None + convert_element_type_133 = torch.ops.prims.convert_element_type.default(primals_40, torch.bfloat16) + all_gather_into_tensor_45 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_133, 8, '0'); convert_element_type_133 = None + wait_tensor_54 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_45); all_gather_into_tensor_45 = None + convert_element_type_134 = torch.ops.prims.convert_element_type.default(add_15, torch.float32) + pow_9 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_134, 2) + mean_8 = torch.ops.aten.mean.dim(pow_9, [2], True); pow_9 = None + add_16 = torch.ops.aten.add.Scalar(mean_8, 1e-05); mean_8 = None + rsqrt_8 = torch.ops.aten.rsqrt.default(add_16); add_16 = None + mul_32 = torch.ops.aten.mul.Tensor(convert_element_type_134, rsqrt_8); convert_element_type_134 = rsqrt_8 = None + mul_33 = torch.ops.aten.mul.Tensor(mul_32, wait_tensor_54); mul_32 = wait_tensor_54 = None + convert_element_type_135 = torch.ops.prims.convert_element_type.default(mul_33, torch.bfloat16); mul_33 = None + all_gather_into_tensor_46 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_135, 8, '1'); convert_element_type_135 = None + wait_tensor_55 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_46); all_gather_into_tensor_46 = None + split_25 = torch.ops.aten.split.Tensor(wait_tensor_55, 2); wait_tensor_55 = None + getitem_236 = split_25[0] + getitem_237 = split_25[1] + getitem_238 = split_25[2] + getitem_239 = split_25[3] + getitem_240 = split_25[4] + getitem_241 = split_25[5] + getitem_242 = split_25[6] + getitem_243 = split_25[7]; split_25 = None + cat_17 = torch.ops.aten.cat.default([getitem_236, getitem_237, getitem_238, getitem_239, getitem_240, getitem_241, getitem_242, getitem_243], 1); getitem_236 = getitem_237 = getitem_238 = getitem_239 = getitem_240 = getitem_241 = getitem_242 = getitem_243 = None + convert_element_type_136 = torch.ops.prims.convert_element_type.default(primals_41, torch.bfloat16) + all_gather_into_tensor_47 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_136, 8, '0'); convert_element_type_136 = None + wait_tensor_56 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_47); all_gather_into_tensor_47 = None + permute_44 = torch.ops.aten.permute.default(wait_tensor_56, [1, 0]); wait_tensor_56 = None + view_303 = torch.ops.aten.view.default(cat_17, [16384, 4096]); cat_17 = None + mm_28 = torch.ops.aten.mm.default(view_303, permute_44); permute_44 = None + view_304 = torch.ops.aten.view.default(mm_28, [2, 8192, 512]) + convert_element_type_139 = torch.ops.prims.convert_element_type.default(primals_42, torch.bfloat16) + all_gather_into_tensor_48 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_139, 8, '0'); convert_element_type_139 = None + wait_tensor_57 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_48); all_gather_into_tensor_48 = None + permute_45 = torch.ops.aten.permute.default(wait_tensor_57, [1, 0]); wait_tensor_57 = None + mm_29 = torch.ops.aten.mm.default(view_303, permute_45); permute_45 = None + view_311 = torch.ops.aten.view.default(mm_29, [2, 8192, 128]); mm_29 = None + convert_element_type_142 = torch.ops.prims.convert_element_type.default(primals_43, torch.bfloat16) + all_gather_into_tensor_49 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_142, 8, '0'); convert_element_type_142 = None + wait_tensor_58 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_49); all_gather_into_tensor_49 = None + permute_46 = torch.ops.aten.permute.default(wait_tensor_58, [1, 0]); wait_tensor_58 = None + mm_30 = torch.ops.aten.mm.default(view_303, permute_46); view_303 = permute_46 = None + view_318 = torch.ops.aten.view.default(mm_30, [2, 8192, 128]) + view_320 = torch.ops.aten.view.default(view_304, [2, 8192, -1, 128]); view_304 = None + view_321 = torch.ops.aten.view.default(view_311, [2, 8192, -1, 128]); view_311 = None + view_322 = torch.ops.aten.view.default(view_318, [2, 8192, -1, 128]); view_318 = None + convert_element_type_145 = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None + view_323 = torch.ops.aten.view.default(convert_element_type_145, [2, 8192, 4, -1, 2]); convert_element_type_145 = None + view_as_complex_8 = torch.ops.aten.view_as_complex.default(view_323); view_323 = None + convert_element_type_146 = torch.ops.prims.convert_element_type.default(view_321, torch.float32); view_321 = None + view_324 = torch.ops.aten.view.default(convert_element_type_146, [2, 8192, 1, -1, 2]); convert_element_type_146 = None + view_as_complex_9 = torch.ops.aten.view_as_complex.default(view_324); view_324 = None + mul_34 = torch.ops.aten.mul.Tensor(view_as_complex_8, view_37); view_as_complex_8 = None + view_as_real_8 = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None + view_326 = torch.ops.aten.view.default(view_as_real_8, [2, 8192, 4, 128]); view_as_real_8 = None + mul_35 = torch.ops.aten.mul.Tensor(view_as_complex_9, view_37); view_as_complex_9 = None + view_as_real_9 = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None + view_327 = torch.ops.aten.view.default(view_as_real_9, [2, 8192, 1, 128]); view_as_real_9 = None + convert_element_type_147 = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None + convert_element_type_148 = torch.ops.prims.convert_element_type.default(view_327, torch.bfloat16); view_327 = None + unsqueeze_8 = torch.ops.aten.unsqueeze.default(convert_element_type_148, 3); convert_element_type_148 = None + expand_8 = torch.ops.aten.expand.default(unsqueeze_8, [2, 8192, 1, 4, 128]); unsqueeze_8 = None + view_328 = torch.ops.aten.view.default(expand_8, [2, 8192, 4, 128]); expand_8 = None + unsqueeze_9 = torch.ops.aten.unsqueeze.default(view_322, 3); view_322 = None + expand_9 = torch.ops.aten.expand.default(unsqueeze_9, [2, 8192, 1, 4, 128]); unsqueeze_9 = None + view_329 = torch.ops.aten.view.default(expand_9, [2, 8192, 4, 128]); expand_9 = None + permute_47 = torch.ops.aten.permute.default(convert_element_type_147, [0, 2, 1, 3]); convert_element_type_147 = None + permute_48 = torch.ops.aten.permute.default(view_328, [0, 2, 1, 3]); view_328 = None + permute_49 = torch.ops.aten.permute.default(view_329, [0, 2, 1, 3]); view_329 = None + _scaled_dot_product_cudnn_attention_4 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_47, permute_48, permute_49, None, True, 0.0, True); permute_47 = permute_48 = permute_49 = None + getitem_244 = _scaled_dot_product_cudnn_attention_4[0] + getitem_245 = _scaled_dot_product_cudnn_attention_4[1] + getitem_250 = _scaled_dot_product_cudnn_attention_4[6] + getitem_251 = _scaled_dot_product_cudnn_attention_4[7]; _scaled_dot_product_cudnn_attention_4 = None + permute_50 = torch.ops.aten.permute.default(getitem_244, [0, 2, 1, 3]) + view_330 = torch.ops.aten.view.default(permute_50, [2, 8192, -1]); permute_50 = None + convert_element_type_149 = torch.ops.prims.convert_element_type.default(primals_44, torch.bfloat16) + all_gather_into_tensor_50 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_149, 8, '0'); convert_element_type_149 = None + wait_tensor_59 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_50); all_gather_into_tensor_50 = None + permute_51 = torch.ops.aten.permute.default(wait_tensor_59, [1, 0]); wait_tensor_59 = None + view_336 = torch.ops.aten.view.default(view_330, [16384, 512]); view_330 = None + mm_31 = torch.ops.aten.mm.default(view_336, permute_51); view_336 = permute_51 = None + view_337 = torch.ops.aten.view.default(mm_31, [2, 8192, 4096]); mm_31 = None + split_26 = torch.ops.aten.split.Tensor(view_337, 1024, 1); view_337 = None + getitem_253 = split_26[0] + getitem_254 = split_26[1] + getitem_255 = split_26[2] + getitem_256 = split_26[3] + getitem_257 = split_26[4] + getitem_258 = split_26[5] + getitem_259 = split_26[6] + getitem_260 = split_26[7]; split_26 = None + cat_18 = torch.ops.aten.cat.default([getitem_253, getitem_254, getitem_255, getitem_256, getitem_257, getitem_258, getitem_259, getitem_260]); getitem_253 = getitem_254 = getitem_255 = getitem_256 = getitem_257 = getitem_258 = getitem_259 = getitem_260 = None + reduce_scatter_tensor_9 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_18, 'sum', 8, '1'); cat_18 = None + wait_tensor_60 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_9) + add_17 = torch.ops.aten.add.Tensor(add_15, wait_tensor_60); wait_tensor_60 = None + convert_element_type_152 = torch.ops.prims.convert_element_type.default(primals_45, torch.bfloat16) + all_gather_into_tensor_51 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_152, 8, '0'); convert_element_type_152 = None + wait_tensor_61 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_51); all_gather_into_tensor_51 = None + convert_element_type_153 = torch.ops.prims.convert_element_type.default(add_17, torch.float32) + pow_10 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_153, 2) + mean_9 = torch.ops.aten.mean.dim(pow_10, [2], True); pow_10 = None + add_18 = torch.ops.aten.add.Scalar(mean_9, 1e-05); mean_9 = None + rsqrt_9 = torch.ops.aten.rsqrt.default(add_18); add_18 = None + mul_36 = torch.ops.aten.mul.Tensor(convert_element_type_153, rsqrt_9); convert_element_type_153 = rsqrt_9 = None + mul_37 = torch.ops.aten.mul.Tensor(mul_36, wait_tensor_61); mul_36 = wait_tensor_61 = None + convert_element_type_154 = torch.ops.prims.convert_element_type.default(mul_37, torch.bfloat16); mul_37 = None + all_gather_into_tensor_52 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_154, 8, '1'); convert_element_type_154 = None + wait_tensor_62 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_52); all_gather_into_tensor_52 = None + split_27 = torch.ops.aten.split.Tensor(wait_tensor_62, 2); wait_tensor_62 = None + getitem_261 = split_27[0] + getitem_262 = split_27[1] + getitem_263 = split_27[2] + getitem_264 = split_27[3] + getitem_265 = split_27[4] + getitem_266 = split_27[5] + getitem_267 = split_27[6] + getitem_268 = split_27[7]; split_27 = None + cat_19 = torch.ops.aten.cat.default([getitem_261, getitem_262, getitem_263, getitem_264, getitem_265, getitem_266, getitem_267, getitem_268], 1); getitem_261 = getitem_262 = getitem_263 = getitem_264 = getitem_265 = getitem_266 = getitem_267 = getitem_268 = None + convert_element_type_155 = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16) + all_gather_into_tensor_53 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_155, 8, '0'); convert_element_type_155 = None + wait_tensor_63 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_53); all_gather_into_tensor_53 = None + permute_52 = torch.ops.aten.permute.default(wait_tensor_63, [1, 0]); wait_tensor_63 = None + view_348 = torch.ops.aten.view.default(cat_19, [16384, 4096]); cat_19 = None + mm_32 = torch.ops.aten.mm.default(view_348, permute_52); permute_52 = None + view_349 = torch.ops.aten.view.default(mm_32, [2, 8192, 1792]) + convert_element_type_158 = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None + sigmoid_4 = torch.ops.aten.sigmoid.default(convert_element_type_158) + mul_38 = torch.ops.aten.mul.Tensor(convert_element_type_158, sigmoid_4); convert_element_type_158 = sigmoid_4 = None + convert_element_type_159 = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None + convert_element_type_160 = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16) + all_gather_into_tensor_54 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_160, 8, '0'); convert_element_type_160 = None + wait_tensor_64 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_54); all_gather_into_tensor_54 = None + permute_53 = torch.ops.aten.permute.default(wait_tensor_64, [1, 0]); wait_tensor_64 = None + mm_33 = torch.ops.aten.mm.default(view_348, permute_53); view_348 = permute_53 = None + view_356 = torch.ops.aten.view.default(mm_33, [2, 8192, 1792]); mm_33 = None + mul_39 = torch.ops.aten.mul.Tensor(convert_element_type_159, view_356); convert_element_type_159 = view_356 = None + convert_element_type_163 = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16) + all_gather_into_tensor_55 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_163, 8, '0'); convert_element_type_163 = None + wait_tensor_65 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_55); all_gather_into_tensor_55 = None + permute_54 = torch.ops.aten.permute.default(wait_tensor_65, [1, 0]); wait_tensor_65 = None + view_363 = torch.ops.aten.view.default(mul_39, [16384, 1792]); mul_39 = None + mm_34 = torch.ops.aten.mm.default(view_363, permute_54); view_363 = permute_54 = None + view_364 = torch.ops.aten.view.default(mm_34, [2, 8192, 4096]); mm_34 = None + split_28 = torch.ops.aten.split.Tensor(view_364, 1024, 1); view_364 = None + getitem_269 = split_28[0] + getitem_270 = split_28[1] + getitem_271 = split_28[2] + getitem_272 = split_28[3] + getitem_273 = split_28[4] + getitem_274 = split_28[5] + getitem_275 = split_28[6] + getitem_276 = split_28[7]; split_28 = None + cat_20 = torch.ops.aten.cat.default([getitem_269, getitem_270, getitem_271, getitem_272, getitem_273, getitem_274, getitem_275, getitem_276]); getitem_269 = getitem_270 = getitem_271 = getitem_272 = getitem_273 = getitem_274 = getitem_275 = getitem_276 = None + reduce_scatter_tensor_10 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_20, 'sum', 8, '1'); cat_20 = None + wait_tensor_66 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_10); reduce_scatter_tensor_10 = None + add_19 = torch.ops.aten.add.Tensor(add_17, wait_tensor_66); add_17 = wait_tensor_66 = None + convert_element_type_166 = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16) + all_gather_into_tensor_56 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_166, 8, '0'); convert_element_type_166 = None + wait_tensor_67 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_56); all_gather_into_tensor_56 = None + convert_element_type_167 = torch.ops.prims.convert_element_type.default(add_19, torch.float32) + pow_11 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_167, 2) + mean_10 = torch.ops.aten.mean.dim(pow_11, [2], True); pow_11 = None + add_20 = torch.ops.aten.add.Scalar(mean_10, 1e-05); mean_10 = None + rsqrt_10 = torch.ops.aten.rsqrt.default(add_20); add_20 = None + mul_40 = torch.ops.aten.mul.Tensor(convert_element_type_167, rsqrt_10); convert_element_type_167 = rsqrt_10 = None + mul_41 = torch.ops.aten.mul.Tensor(mul_40, wait_tensor_67); mul_40 = wait_tensor_67 = None + convert_element_type_168 = torch.ops.prims.convert_element_type.default(mul_41, torch.bfloat16); mul_41 = None + all_gather_into_tensor_57 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_168, 8, '1'); convert_element_type_168 = None + wait_tensor_68 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_57); all_gather_into_tensor_57 = None + split_29 = torch.ops.aten.split.Tensor(wait_tensor_68, 2); wait_tensor_68 = None + getitem_277 = split_29[0] + getitem_278 = split_29[1] + getitem_279 = split_29[2] + getitem_280 = split_29[3] + getitem_281 = split_29[4] + getitem_282 = split_29[5] + getitem_283 = split_29[6] + getitem_284 = split_29[7]; split_29 = None + cat_21 = torch.ops.aten.cat.default([getitem_277, getitem_278, getitem_279, getitem_280, getitem_281, getitem_282, getitem_283, getitem_284], 1); getitem_277 = getitem_278 = getitem_279 = getitem_280 = getitem_281 = getitem_282 = getitem_283 = getitem_284 = None + convert_element_type_169 = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16) + all_gather_into_tensor_58 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_169, 8, '0'); convert_element_type_169 = None + wait_tensor_69 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_58); all_gather_into_tensor_58 = None + permute_55 = torch.ops.aten.permute.default(wait_tensor_69, [1, 0]); wait_tensor_69 = None + view_375 = torch.ops.aten.view.default(cat_21, [16384, 4096]); cat_21 = None + mm_35 = torch.ops.aten.mm.default(view_375, permute_55); permute_55 = None + view_376 = torch.ops.aten.view.default(mm_35, [2, 8192, 512]) + convert_element_type_172 = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16) + all_gather_into_tensor_59 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_172, 8, '0'); convert_element_type_172 = None + wait_tensor_70 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_59); all_gather_into_tensor_59 = None + permute_56 = torch.ops.aten.permute.default(wait_tensor_70, [1, 0]); wait_tensor_70 = None + mm_36 = torch.ops.aten.mm.default(view_375, permute_56); permute_56 = None + view_383 = torch.ops.aten.view.default(mm_36, [2, 8192, 128]); mm_36 = None + convert_element_type_175 = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16) + all_gather_into_tensor_60 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_175, 8, '0'); convert_element_type_175 = None + wait_tensor_71 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_60); all_gather_into_tensor_60 = None + permute_57 = torch.ops.aten.permute.default(wait_tensor_71, [1, 0]); wait_tensor_71 = None + mm_37 = torch.ops.aten.mm.default(view_375, permute_57); view_375 = permute_57 = None + view_390 = torch.ops.aten.view.default(mm_37, [2, 8192, 128]) + view_392 = torch.ops.aten.view.default(view_376, [2, 8192, -1, 128]); view_376 = None + view_393 = torch.ops.aten.view.default(view_383, [2, 8192, -1, 128]); view_383 = None + view_394 = torch.ops.aten.view.default(view_390, [2, 8192, -1, 128]); view_390 = None + convert_element_type_178 = torch.ops.prims.convert_element_type.default(view_392, torch.float32); view_392 = None + view_395 = torch.ops.aten.view.default(convert_element_type_178, [2, 8192, 4, -1, 2]); convert_element_type_178 = None + view_as_complex_10 = torch.ops.aten.view_as_complex.default(view_395); view_395 = None + convert_element_type_179 = torch.ops.prims.convert_element_type.default(view_393, torch.float32); view_393 = None + view_396 = torch.ops.aten.view.default(convert_element_type_179, [2, 8192, 1, -1, 2]); convert_element_type_179 = None + view_as_complex_11 = torch.ops.aten.view_as_complex.default(view_396); view_396 = None + mul_42 = torch.ops.aten.mul.Tensor(view_as_complex_10, view_37); view_as_complex_10 = None + view_as_real_10 = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None + view_398 = torch.ops.aten.view.default(view_as_real_10, [2, 8192, 4, 128]); view_as_real_10 = None + mul_43 = torch.ops.aten.mul.Tensor(view_as_complex_11, view_37); view_as_complex_11 = None + view_as_real_11 = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None + view_399 = torch.ops.aten.view.default(view_as_real_11, [2, 8192, 1, 128]); view_as_real_11 = None + convert_element_type_180 = torch.ops.prims.convert_element_type.default(view_398, torch.bfloat16); view_398 = None + convert_element_type_181 = torch.ops.prims.convert_element_type.default(view_399, torch.bfloat16); view_399 = None + unsqueeze_10 = torch.ops.aten.unsqueeze.default(convert_element_type_181, 3); convert_element_type_181 = None + expand_10 = torch.ops.aten.expand.default(unsqueeze_10, [2, 8192, 1, 4, 128]); unsqueeze_10 = None + view_400 = torch.ops.aten.view.default(expand_10, [2, 8192, 4, 128]); expand_10 = None + unsqueeze_11 = torch.ops.aten.unsqueeze.default(view_394, 3); view_394 = None + expand_11 = torch.ops.aten.expand.default(unsqueeze_11, [2, 8192, 1, 4, 128]); unsqueeze_11 = None + view_401 = torch.ops.aten.view.default(expand_11, [2, 8192, 4, 128]); expand_11 = None + permute_58 = torch.ops.aten.permute.default(convert_element_type_180, [0, 2, 1, 3]); convert_element_type_180 = None + permute_59 = torch.ops.aten.permute.default(view_400, [0, 2, 1, 3]); view_400 = None + permute_60 = torch.ops.aten.permute.default(view_401, [0, 2, 1, 3]); view_401 = None + _scaled_dot_product_cudnn_attention_5 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_58, permute_59, permute_60, None, True, 0.0, True); permute_58 = permute_59 = permute_60 = None + getitem_285 = _scaled_dot_product_cudnn_attention_5[0] + getitem_286 = _scaled_dot_product_cudnn_attention_5[1] + getitem_291 = _scaled_dot_product_cudnn_attention_5[6] + getitem_292 = _scaled_dot_product_cudnn_attention_5[7]; _scaled_dot_product_cudnn_attention_5 = None + permute_61 = torch.ops.aten.permute.default(getitem_285, [0, 2, 1, 3]) + view_402 = torch.ops.aten.view.default(permute_61, [2, 8192, -1]); permute_61 = None + convert_element_type_182 = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16) + all_gather_into_tensor_61 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_182, 8, '0'); convert_element_type_182 = None + wait_tensor_72 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_61); all_gather_into_tensor_61 = None + permute_62 = torch.ops.aten.permute.default(wait_tensor_72, [1, 0]); wait_tensor_72 = None + view_408 = torch.ops.aten.view.default(view_402, [16384, 512]); view_402 = None + mm_38 = torch.ops.aten.mm.default(view_408, permute_62); view_408 = permute_62 = None + view_409 = torch.ops.aten.view.default(mm_38, [2, 8192, 4096]); mm_38 = None + split_30 = torch.ops.aten.split.Tensor(view_409, 1024, 1); view_409 = None + getitem_294 = split_30[0] + getitem_295 = split_30[1] + getitem_296 = split_30[2] + getitem_297 = split_30[3] + getitem_298 = split_30[4] + getitem_299 = split_30[5] + getitem_300 = split_30[6] + getitem_301 = split_30[7]; split_30 = None + cat_22 = torch.ops.aten.cat.default([getitem_294, getitem_295, getitem_296, getitem_297, getitem_298, getitem_299, getitem_300, getitem_301]); getitem_294 = getitem_295 = getitem_296 = getitem_297 = getitem_298 = getitem_299 = getitem_300 = getitem_301 = None + reduce_scatter_tensor_11 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_22, 'sum', 8, '1'); cat_22 = None + wait_tensor_73 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_11) + add_21 = torch.ops.aten.add.Tensor(add_19, wait_tensor_73); wait_tensor_73 = None + convert_element_type_185 = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16) + all_gather_into_tensor_62 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_185, 8, '0'); convert_element_type_185 = None + wait_tensor_74 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_62); all_gather_into_tensor_62 = None + convert_element_type_186 = torch.ops.prims.convert_element_type.default(add_21, torch.float32) + pow_12 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_186, 2) + mean_11 = torch.ops.aten.mean.dim(pow_12, [2], True); pow_12 = None + add_22 = torch.ops.aten.add.Scalar(mean_11, 1e-05); mean_11 = None + rsqrt_11 = torch.ops.aten.rsqrt.default(add_22); add_22 = None + mul_44 = torch.ops.aten.mul.Tensor(convert_element_type_186, rsqrt_11); convert_element_type_186 = rsqrt_11 = None + mul_45 = torch.ops.aten.mul.Tensor(mul_44, wait_tensor_74); mul_44 = wait_tensor_74 = None + convert_element_type_187 = torch.ops.prims.convert_element_type.default(mul_45, torch.bfloat16); mul_45 = None + all_gather_into_tensor_63 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_187, 8, '1'); convert_element_type_187 = None + wait_tensor_75 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_63); all_gather_into_tensor_63 = None + split_31 = torch.ops.aten.split.Tensor(wait_tensor_75, 2); wait_tensor_75 = None + getitem_302 = split_31[0] + getitem_303 = split_31[1] + getitem_304 = split_31[2] + getitem_305 = split_31[3] + getitem_306 = split_31[4] + getitem_307 = split_31[5] + getitem_308 = split_31[6] + getitem_309 = split_31[7]; split_31 = None + cat_23 = torch.ops.aten.cat.default([getitem_302, getitem_303, getitem_304, getitem_305, getitem_306, getitem_307, getitem_308, getitem_309], 1); getitem_302 = getitem_303 = getitem_304 = getitem_305 = getitem_306 = getitem_307 = getitem_308 = getitem_309 = None + convert_element_type_188 = torch.ops.prims.convert_element_type.default(primals_55, torch.bfloat16) + all_gather_into_tensor_64 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_188, 8, '0'); convert_element_type_188 = None + wait_tensor_76 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_64); all_gather_into_tensor_64 = None + permute_63 = torch.ops.aten.permute.default(wait_tensor_76, [1, 0]); wait_tensor_76 = None + view_420 = torch.ops.aten.view.default(cat_23, [16384, 4096]); cat_23 = None + mm_39 = torch.ops.aten.mm.default(view_420, permute_63); permute_63 = None + view_421 = torch.ops.aten.view.default(mm_39, [2, 8192, 1792]) + convert_element_type_191 = torch.ops.prims.convert_element_type.default(view_421, torch.float32); view_421 = None + sigmoid_5 = torch.ops.aten.sigmoid.default(convert_element_type_191) + mul_46 = torch.ops.aten.mul.Tensor(convert_element_type_191, sigmoid_5); convert_element_type_191 = sigmoid_5 = None + convert_element_type_192 = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None + convert_element_type_193 = torch.ops.prims.convert_element_type.default(primals_56, torch.bfloat16) + all_gather_into_tensor_65 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_193, 8, '0'); convert_element_type_193 = None + wait_tensor_77 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_65); all_gather_into_tensor_65 = None + permute_64 = torch.ops.aten.permute.default(wait_tensor_77, [1, 0]); wait_tensor_77 = None + mm_40 = torch.ops.aten.mm.default(view_420, permute_64); view_420 = permute_64 = None + view_428 = torch.ops.aten.view.default(mm_40, [2, 8192, 1792]); mm_40 = None + mul_47 = torch.ops.aten.mul.Tensor(convert_element_type_192, view_428); convert_element_type_192 = view_428 = None + convert_element_type_196 = torch.ops.prims.convert_element_type.default(primals_57, torch.bfloat16) + all_gather_into_tensor_66 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_196, 8, '0'); convert_element_type_196 = None + wait_tensor_78 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_66); all_gather_into_tensor_66 = None + permute_65 = torch.ops.aten.permute.default(wait_tensor_78, [1, 0]); wait_tensor_78 = None + view_435 = torch.ops.aten.view.default(mul_47, [16384, 1792]); mul_47 = None + mm_41 = torch.ops.aten.mm.default(view_435, permute_65); view_435 = permute_65 = None + view_436 = torch.ops.aten.view.default(mm_41, [2, 8192, 4096]); mm_41 = None + split_32 = torch.ops.aten.split.Tensor(view_436, 1024, 1); view_436 = None + getitem_310 = split_32[0] + getitem_311 = split_32[1] + getitem_312 = split_32[2] + getitem_313 = split_32[3] + getitem_314 = split_32[4] + getitem_315 = split_32[5] + getitem_316 = split_32[6] + getitem_317 = split_32[7]; split_32 = None + cat_24 = torch.ops.aten.cat.default([getitem_310, getitem_311, getitem_312, getitem_313, getitem_314, getitem_315, getitem_316, getitem_317]); getitem_310 = getitem_311 = getitem_312 = getitem_313 = getitem_314 = getitem_315 = getitem_316 = getitem_317 = None + reduce_scatter_tensor_12 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_24, 'sum', 8, '1'); cat_24 = None + wait_tensor_79 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_12); reduce_scatter_tensor_12 = None + add_23 = torch.ops.aten.add.Tensor(add_21, wait_tensor_79); add_21 = wait_tensor_79 = None + convert_element_type_199 = torch.ops.prims.convert_element_type.default(primals_58, torch.bfloat16) + all_gather_into_tensor_67 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_199, 8, '0'); convert_element_type_199 = None + wait_tensor_80 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_67); all_gather_into_tensor_67 = None + convert_element_type_200 = torch.ops.prims.convert_element_type.default(add_23, torch.float32) + pow_13 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_200, 2) + mean_12 = torch.ops.aten.mean.dim(pow_13, [2], True); pow_13 = None + add_24 = torch.ops.aten.add.Scalar(mean_12, 1e-05); mean_12 = None + rsqrt_12 = torch.ops.aten.rsqrt.default(add_24); add_24 = None + mul_48 = torch.ops.aten.mul.Tensor(convert_element_type_200, rsqrt_12); convert_element_type_200 = rsqrt_12 = None + mul_49 = torch.ops.aten.mul.Tensor(mul_48, wait_tensor_80); mul_48 = wait_tensor_80 = None + convert_element_type_201 = torch.ops.prims.convert_element_type.default(mul_49, torch.bfloat16); mul_49 = None + all_gather_into_tensor_68 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_201, 8, '1'); convert_element_type_201 = None + wait_tensor_81 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_68); all_gather_into_tensor_68 = None + split_33 = torch.ops.aten.split.Tensor(wait_tensor_81, 2); wait_tensor_81 = None + getitem_318 = split_33[0] + getitem_319 = split_33[1] + getitem_320 = split_33[2] + getitem_321 = split_33[3] + getitem_322 = split_33[4] + getitem_323 = split_33[5] + getitem_324 = split_33[6] + getitem_325 = split_33[7]; split_33 = None + cat_25 = torch.ops.aten.cat.default([getitem_318, getitem_319, getitem_320, getitem_321, getitem_322, getitem_323, getitem_324, getitem_325], 1); getitem_318 = getitem_319 = getitem_320 = getitem_321 = getitem_322 = getitem_323 = getitem_324 = getitem_325 = None + convert_element_type_202 = torch.ops.prims.convert_element_type.default(primals_59, torch.bfloat16) + all_gather_into_tensor_69 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_202, 8, '0'); convert_element_type_202 = None + wait_tensor_82 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_69); all_gather_into_tensor_69 = None + permute_66 = torch.ops.aten.permute.default(wait_tensor_82, [1, 0]); wait_tensor_82 = None + view_447 = torch.ops.aten.view.default(cat_25, [16384, 4096]); cat_25 = None + mm_42 = torch.ops.aten.mm.default(view_447, permute_66); permute_66 = None + view_448 = torch.ops.aten.view.default(mm_42, [2, 8192, 512]) + convert_element_type_205 = torch.ops.prims.convert_element_type.default(primals_60, torch.bfloat16) + all_gather_into_tensor_70 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_205, 8, '0'); convert_element_type_205 = None + wait_tensor_83 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_70); all_gather_into_tensor_70 = None + permute_67 = torch.ops.aten.permute.default(wait_tensor_83, [1, 0]); wait_tensor_83 = None + mm_43 = torch.ops.aten.mm.default(view_447, permute_67); permute_67 = None + view_455 = torch.ops.aten.view.default(mm_43, [2, 8192, 128]); mm_43 = None + convert_element_type_208 = torch.ops.prims.convert_element_type.default(primals_61, torch.bfloat16) + all_gather_into_tensor_71 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_208, 8, '0'); convert_element_type_208 = None + wait_tensor_84 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_71); all_gather_into_tensor_71 = None + permute_68 = torch.ops.aten.permute.default(wait_tensor_84, [1, 0]); wait_tensor_84 = None + mm_44 = torch.ops.aten.mm.default(view_447, permute_68); view_447 = permute_68 = None + view_462 = torch.ops.aten.view.default(mm_44, [2, 8192, 128]) + view_464 = torch.ops.aten.view.default(view_448, [2, 8192, -1, 128]); view_448 = None + view_465 = torch.ops.aten.view.default(view_455, [2, 8192, -1, 128]); view_455 = None + view_466 = torch.ops.aten.view.default(view_462, [2, 8192, -1, 128]); view_462 = None + convert_element_type_211 = torch.ops.prims.convert_element_type.default(view_464, torch.float32); view_464 = None + view_467 = torch.ops.aten.view.default(convert_element_type_211, [2, 8192, 4, -1, 2]); convert_element_type_211 = None + view_as_complex_12 = torch.ops.aten.view_as_complex.default(view_467); view_467 = None + convert_element_type_212 = torch.ops.prims.convert_element_type.default(view_465, torch.float32); view_465 = None + view_468 = torch.ops.aten.view.default(convert_element_type_212, [2, 8192, 1, -1, 2]); convert_element_type_212 = None + view_as_complex_13 = torch.ops.aten.view_as_complex.default(view_468); view_468 = None + mul_50 = torch.ops.aten.mul.Tensor(view_as_complex_12, view_37); view_as_complex_12 = None + view_as_real_12 = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None + view_470 = torch.ops.aten.view.default(view_as_real_12, [2, 8192, 4, 128]); view_as_real_12 = None + mul_51 = torch.ops.aten.mul.Tensor(view_as_complex_13, view_37); view_as_complex_13 = None + view_as_real_13 = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None + view_471 = torch.ops.aten.view.default(view_as_real_13, [2, 8192, 1, 128]); view_as_real_13 = None + convert_element_type_213 = torch.ops.prims.convert_element_type.default(view_470, torch.bfloat16); view_470 = None + convert_element_type_214 = torch.ops.prims.convert_element_type.default(view_471, torch.bfloat16); view_471 = None + unsqueeze_12 = torch.ops.aten.unsqueeze.default(convert_element_type_214, 3); convert_element_type_214 = None + expand_12 = torch.ops.aten.expand.default(unsqueeze_12, [2, 8192, 1, 4, 128]); unsqueeze_12 = None + view_472 = torch.ops.aten.view.default(expand_12, [2, 8192, 4, 128]); expand_12 = None + unsqueeze_13 = torch.ops.aten.unsqueeze.default(view_466, 3); view_466 = None + expand_13 = torch.ops.aten.expand.default(unsqueeze_13, [2, 8192, 1, 4, 128]); unsqueeze_13 = None + view_473 = torch.ops.aten.view.default(expand_13, [2, 8192, 4, 128]); expand_13 = None + permute_69 = torch.ops.aten.permute.default(convert_element_type_213, [0, 2, 1, 3]); convert_element_type_213 = None + permute_70 = torch.ops.aten.permute.default(view_472, [0, 2, 1, 3]); view_472 = None + permute_71 = torch.ops.aten.permute.default(view_473, [0, 2, 1, 3]); view_473 = None + _scaled_dot_product_cudnn_attention_6 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_69, permute_70, permute_71, None, True, 0.0, True); permute_69 = permute_70 = permute_71 = None + getitem_326 = _scaled_dot_product_cudnn_attention_6[0] + getitem_327 = _scaled_dot_product_cudnn_attention_6[1] + getitem_332 = _scaled_dot_product_cudnn_attention_6[6] + getitem_333 = _scaled_dot_product_cudnn_attention_6[7]; _scaled_dot_product_cudnn_attention_6 = None + permute_72 = torch.ops.aten.permute.default(getitem_326, [0, 2, 1, 3]) + view_474 = torch.ops.aten.view.default(permute_72, [2, 8192, -1]); permute_72 = None + convert_element_type_215 = torch.ops.prims.convert_element_type.default(primals_62, torch.bfloat16) + all_gather_into_tensor_72 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_215, 8, '0'); convert_element_type_215 = None + wait_tensor_85 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_72); all_gather_into_tensor_72 = None + permute_73 = torch.ops.aten.permute.default(wait_tensor_85, [1, 0]); wait_tensor_85 = None + view_480 = torch.ops.aten.view.default(view_474, [16384, 512]); view_474 = None + mm_45 = torch.ops.aten.mm.default(view_480, permute_73); view_480 = permute_73 = None + view_481 = torch.ops.aten.view.default(mm_45, [2, 8192, 4096]); mm_45 = None + split_34 = torch.ops.aten.split.Tensor(view_481, 1024, 1); view_481 = None + getitem_335 = split_34[0] + getitem_336 = split_34[1] + getitem_337 = split_34[2] + getitem_338 = split_34[3] + getitem_339 = split_34[4] + getitem_340 = split_34[5] + getitem_341 = split_34[6] + getitem_342 = split_34[7]; split_34 = None + cat_26 = torch.ops.aten.cat.default([getitem_335, getitem_336, getitem_337, getitem_338, getitem_339, getitem_340, getitem_341, getitem_342]); getitem_335 = getitem_336 = getitem_337 = getitem_338 = getitem_339 = getitem_340 = getitem_341 = getitem_342 = None + reduce_scatter_tensor_13 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_26, 'sum', 8, '1'); cat_26 = None + wait_tensor_86 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_13) + add_25 = torch.ops.aten.add.Tensor(add_23, wait_tensor_86); wait_tensor_86 = None + convert_element_type_218 = torch.ops.prims.convert_element_type.default(primals_63, torch.bfloat16) + all_gather_into_tensor_73 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_218, 8, '0'); convert_element_type_218 = None + wait_tensor_87 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_73); all_gather_into_tensor_73 = None + convert_element_type_219 = torch.ops.prims.convert_element_type.default(add_25, torch.float32) + pow_14 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_219, 2) + mean_13 = torch.ops.aten.mean.dim(pow_14, [2], True); pow_14 = None + add_26 = torch.ops.aten.add.Scalar(mean_13, 1e-05); mean_13 = None + rsqrt_13 = torch.ops.aten.rsqrt.default(add_26); add_26 = None + mul_52 = torch.ops.aten.mul.Tensor(convert_element_type_219, rsqrt_13); convert_element_type_219 = rsqrt_13 = None + mul_53 = torch.ops.aten.mul.Tensor(mul_52, wait_tensor_87); mul_52 = wait_tensor_87 = None + convert_element_type_220 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None + all_gather_into_tensor_74 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_220, 8, '1'); convert_element_type_220 = None + wait_tensor_88 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_74); all_gather_into_tensor_74 = None + split_35 = torch.ops.aten.split.Tensor(wait_tensor_88, 2); wait_tensor_88 = None + getitem_343 = split_35[0] + getitem_344 = split_35[1] + getitem_345 = split_35[2] + getitem_346 = split_35[3] + getitem_347 = split_35[4] + getitem_348 = split_35[5] + getitem_349 = split_35[6] + getitem_350 = split_35[7]; split_35 = None + cat_27 = torch.ops.aten.cat.default([getitem_343, getitem_344, getitem_345, getitem_346, getitem_347, getitem_348, getitem_349, getitem_350], 1); getitem_343 = getitem_344 = getitem_345 = getitem_346 = getitem_347 = getitem_348 = getitem_349 = getitem_350 = None + convert_element_type_221 = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16) + all_gather_into_tensor_75 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_221, 8, '0'); convert_element_type_221 = None + wait_tensor_89 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_75); all_gather_into_tensor_75 = None + permute_74 = torch.ops.aten.permute.default(wait_tensor_89, [1, 0]); wait_tensor_89 = None + view_492 = torch.ops.aten.view.default(cat_27, [16384, 4096]); cat_27 = None + mm_46 = torch.ops.aten.mm.default(view_492, permute_74); permute_74 = None + view_493 = torch.ops.aten.view.default(mm_46, [2, 8192, 1792]) + convert_element_type_224 = torch.ops.prims.convert_element_type.default(view_493, torch.float32); view_493 = None + sigmoid_6 = torch.ops.aten.sigmoid.default(convert_element_type_224) + mul_54 = torch.ops.aten.mul.Tensor(convert_element_type_224, sigmoid_6); convert_element_type_224 = sigmoid_6 = None + convert_element_type_225 = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None + convert_element_type_226 = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16) + all_gather_into_tensor_76 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_226, 8, '0'); convert_element_type_226 = None + wait_tensor_90 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_76); all_gather_into_tensor_76 = None + permute_75 = torch.ops.aten.permute.default(wait_tensor_90, [1, 0]); wait_tensor_90 = None + mm_47 = torch.ops.aten.mm.default(view_492, permute_75); view_492 = permute_75 = None + view_500 = torch.ops.aten.view.default(mm_47, [2, 8192, 1792]); mm_47 = None + mul_55 = torch.ops.aten.mul.Tensor(convert_element_type_225, view_500); convert_element_type_225 = view_500 = None + convert_element_type_229 = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16) + all_gather_into_tensor_77 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_229, 8, '0'); convert_element_type_229 = None + wait_tensor_91 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_77); all_gather_into_tensor_77 = None + permute_76 = torch.ops.aten.permute.default(wait_tensor_91, [1, 0]); wait_tensor_91 = None + view_507 = torch.ops.aten.view.default(mul_55, [16384, 1792]); mul_55 = None + mm_48 = torch.ops.aten.mm.default(view_507, permute_76); view_507 = permute_76 = None + view_508 = torch.ops.aten.view.default(mm_48, [2, 8192, 4096]); mm_48 = None + split_36 = torch.ops.aten.split.Tensor(view_508, 1024, 1); view_508 = None + getitem_351 = split_36[0] + getitem_352 = split_36[1] + getitem_353 = split_36[2] + getitem_354 = split_36[3] + getitem_355 = split_36[4] + getitem_356 = split_36[5] + getitem_357 = split_36[6] + getitem_358 = split_36[7]; split_36 = None + cat_28 = torch.ops.aten.cat.default([getitem_351, getitem_352, getitem_353, getitem_354, getitem_355, getitem_356, getitem_357, getitem_358]); getitem_351 = getitem_352 = getitem_353 = getitem_354 = getitem_355 = getitem_356 = getitem_357 = getitem_358 = None + reduce_scatter_tensor_14 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_28, 'sum', 8, '1'); cat_28 = None + wait_tensor_92 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_14); reduce_scatter_tensor_14 = None + add_27 = torch.ops.aten.add.Tensor(add_25, wait_tensor_92); add_25 = wait_tensor_92 = None + convert_element_type_232 = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16) + all_gather_into_tensor_78 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_232, 8, '0'); convert_element_type_232 = None + wait_tensor_93 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_78); all_gather_into_tensor_78 = None + convert_element_type_233 = torch.ops.prims.convert_element_type.default(add_27, torch.float32) + pow_15 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2) + mean_14 = torch.ops.aten.mean.dim(pow_15, [2], True); pow_15 = None + add_28 = torch.ops.aten.add.Scalar(mean_14, 1e-05); mean_14 = None + rsqrt_14 = torch.ops.aten.rsqrt.default(add_28); add_28 = None + mul_56 = torch.ops.aten.mul.Tensor(convert_element_type_233, rsqrt_14); convert_element_type_233 = rsqrt_14 = None + mul_57 = torch.ops.aten.mul.Tensor(mul_56, wait_tensor_93); mul_56 = wait_tensor_93 = None + convert_element_type_234 = torch.ops.prims.convert_element_type.default(mul_57, torch.bfloat16); mul_57 = None + all_gather_into_tensor_79 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_234, 8, '1'); convert_element_type_234 = None + wait_tensor_94 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_79); all_gather_into_tensor_79 = None + split_37 = torch.ops.aten.split.Tensor(wait_tensor_94, 2); wait_tensor_94 = None + getitem_359 = split_37[0] + getitem_360 = split_37[1] + getitem_361 = split_37[2] + getitem_362 = split_37[3] + getitem_363 = split_37[4] + getitem_364 = split_37[5] + getitem_365 = split_37[6] + getitem_366 = split_37[7]; split_37 = None + cat_29 = torch.ops.aten.cat.default([getitem_359, getitem_360, getitem_361, getitem_362, getitem_363, getitem_364, getitem_365, getitem_366], 1); getitem_359 = getitem_360 = getitem_361 = getitem_362 = getitem_363 = getitem_364 = getitem_365 = getitem_366 = None + convert_element_type_235 = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16) + all_gather_into_tensor_80 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_235, 8, '0'); convert_element_type_235 = None + wait_tensor_95 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_80); all_gather_into_tensor_80 = None + permute_77 = torch.ops.aten.permute.default(wait_tensor_95, [1, 0]); wait_tensor_95 = None + view_519 = torch.ops.aten.view.default(cat_29, [16384, 4096]); cat_29 = None + mm_49 = torch.ops.aten.mm.default(view_519, permute_77); permute_77 = None + view_520 = torch.ops.aten.view.default(mm_49, [2, 8192, 512]) + convert_element_type_238 = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16) + all_gather_into_tensor_81 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_238, 8, '0'); convert_element_type_238 = None + wait_tensor_96 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_81); all_gather_into_tensor_81 = None + permute_78 = torch.ops.aten.permute.default(wait_tensor_96, [1, 0]); wait_tensor_96 = None + mm_50 = torch.ops.aten.mm.default(view_519, permute_78); permute_78 = None + view_527 = torch.ops.aten.view.default(mm_50, [2, 8192, 128]); mm_50 = None + convert_element_type_241 = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16) + all_gather_into_tensor_82 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_241, 8, '0'); convert_element_type_241 = None + wait_tensor_97 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_82); all_gather_into_tensor_82 = None + permute_79 = torch.ops.aten.permute.default(wait_tensor_97, [1, 0]); wait_tensor_97 = None + mm_51 = torch.ops.aten.mm.default(view_519, permute_79); view_519 = permute_79 = None + view_534 = torch.ops.aten.view.default(mm_51, [2, 8192, 128]) + view_536 = torch.ops.aten.view.default(view_520, [2, 8192, -1, 128]); view_520 = None + view_537 = torch.ops.aten.view.default(view_527, [2, 8192, -1, 128]); view_527 = None + view_538 = torch.ops.aten.view.default(view_534, [2, 8192, -1, 128]); view_534 = None + convert_element_type_244 = torch.ops.prims.convert_element_type.default(view_536, torch.float32); view_536 = None + view_539 = torch.ops.aten.view.default(convert_element_type_244, [2, 8192, 4, -1, 2]); convert_element_type_244 = None + view_as_complex_14 = torch.ops.aten.view_as_complex.default(view_539); view_539 = None + convert_element_type_245 = torch.ops.prims.convert_element_type.default(view_537, torch.float32); view_537 = None + view_540 = torch.ops.aten.view.default(convert_element_type_245, [2, 8192, 1, -1, 2]); convert_element_type_245 = None + view_as_complex_15 = torch.ops.aten.view_as_complex.default(view_540); view_540 = None + mul_58 = torch.ops.aten.mul.Tensor(view_as_complex_14, view_37); view_as_complex_14 = None + view_as_real_14 = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None + view_542 = torch.ops.aten.view.default(view_as_real_14, [2, 8192, 4, 128]); view_as_real_14 = None + mul_59 = torch.ops.aten.mul.Tensor(view_as_complex_15, view_37); view_as_complex_15 = None + view_as_real_15 = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None + view_543 = torch.ops.aten.view.default(view_as_real_15, [2, 8192, 1, 128]); view_as_real_15 = None + convert_element_type_246 = torch.ops.prims.convert_element_type.default(view_542, torch.bfloat16); view_542 = None + convert_element_type_247 = torch.ops.prims.convert_element_type.default(view_543, torch.bfloat16); view_543 = None + unsqueeze_14 = torch.ops.aten.unsqueeze.default(convert_element_type_247, 3); convert_element_type_247 = None + expand_14 = torch.ops.aten.expand.default(unsqueeze_14, [2, 8192, 1, 4, 128]); unsqueeze_14 = None + view_544 = torch.ops.aten.view.default(expand_14, [2, 8192, 4, 128]); expand_14 = None + unsqueeze_15 = torch.ops.aten.unsqueeze.default(view_538, 3); view_538 = None + expand_15 = torch.ops.aten.expand.default(unsqueeze_15, [2, 8192, 1, 4, 128]); unsqueeze_15 = None + view_545 = torch.ops.aten.view.default(expand_15, [2, 8192, 4, 128]); expand_15 = None + permute_80 = torch.ops.aten.permute.default(convert_element_type_246, [0, 2, 1, 3]); convert_element_type_246 = None + permute_81 = torch.ops.aten.permute.default(view_544, [0, 2, 1, 3]); view_544 = None + permute_82 = torch.ops.aten.permute.default(view_545, [0, 2, 1, 3]); view_545 = None + _scaled_dot_product_cudnn_attention_7 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_80, permute_81, permute_82, None, True, 0.0, True); permute_80 = permute_81 = permute_82 = None + getitem_367 = _scaled_dot_product_cudnn_attention_7[0] + getitem_368 = _scaled_dot_product_cudnn_attention_7[1] + getitem_373 = _scaled_dot_product_cudnn_attention_7[6] + getitem_374 = _scaled_dot_product_cudnn_attention_7[7]; _scaled_dot_product_cudnn_attention_7 = None + permute_83 = torch.ops.aten.permute.default(getitem_367, [0, 2, 1, 3]) + view_546 = torch.ops.aten.view.default(permute_83, [2, 8192, -1]); permute_83 = None + convert_element_type_248 = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16) + all_gather_into_tensor_83 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_248, 8, '0'); convert_element_type_248 = None + wait_tensor_98 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_83); all_gather_into_tensor_83 = None + permute_84 = torch.ops.aten.permute.default(wait_tensor_98, [1, 0]); wait_tensor_98 = None + view_552 = torch.ops.aten.view.default(view_546, [16384, 512]); view_546 = None + mm_52 = torch.ops.aten.mm.default(view_552, permute_84); view_552 = permute_84 = None + view_553 = torch.ops.aten.view.default(mm_52, [2, 8192, 4096]); mm_52 = None + split_38 = torch.ops.aten.split.Tensor(view_553, 1024, 1); view_553 = None + getitem_376 = split_38[0] + getitem_377 = split_38[1] + getitem_378 = split_38[2] + getitem_379 = split_38[3] + getitem_380 = split_38[4] + getitem_381 = split_38[5] + getitem_382 = split_38[6] + getitem_383 = split_38[7]; split_38 = None + cat_30 = torch.ops.aten.cat.default([getitem_376, getitem_377, getitem_378, getitem_379, getitem_380, getitem_381, getitem_382, getitem_383]); getitem_376 = getitem_377 = getitem_378 = getitem_379 = getitem_380 = getitem_381 = getitem_382 = getitem_383 = None + reduce_scatter_tensor_15 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_30, 'sum', 8, '1'); cat_30 = None + wait_tensor_99 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_15) + add_29 = torch.ops.aten.add.Tensor(add_27, wait_tensor_99); wait_tensor_99 = None + convert_element_type_251 = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16) + all_gather_into_tensor_84 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_251, 8, '0'); convert_element_type_251 = None + wait_tensor_100 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_84); all_gather_into_tensor_84 = None + convert_element_type_252 = torch.ops.prims.convert_element_type.default(add_29, torch.float32) + pow_16 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_252, 2) + mean_15 = torch.ops.aten.mean.dim(pow_16, [2], True); pow_16 = None + add_30 = torch.ops.aten.add.Scalar(mean_15, 1e-05); mean_15 = None + rsqrt_15 = torch.ops.aten.rsqrt.default(add_30); add_30 = None + mul_60 = torch.ops.aten.mul.Tensor(convert_element_type_252, rsqrt_15); convert_element_type_252 = rsqrt_15 = None + mul_61 = torch.ops.aten.mul.Tensor(mul_60, wait_tensor_100); mul_60 = wait_tensor_100 = None + convert_element_type_253 = torch.ops.prims.convert_element_type.default(mul_61, torch.bfloat16); mul_61 = None + all_gather_into_tensor_85 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_253, 8, '1'); convert_element_type_253 = None + wait_tensor_101 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_85); all_gather_into_tensor_85 = None + split_39 = torch.ops.aten.split.Tensor(wait_tensor_101, 2); wait_tensor_101 = None + getitem_384 = split_39[0] + getitem_385 = split_39[1] + getitem_386 = split_39[2] + getitem_387 = split_39[3] + getitem_388 = split_39[4] + getitem_389 = split_39[5] + getitem_390 = split_39[6] + getitem_391 = split_39[7]; split_39 = None + cat_31 = torch.ops.aten.cat.default([getitem_384, getitem_385, getitem_386, getitem_387, getitem_388, getitem_389, getitem_390, getitem_391], 1); getitem_384 = getitem_385 = getitem_386 = getitem_387 = getitem_388 = getitem_389 = getitem_390 = getitem_391 = None + convert_element_type_254 = torch.ops.prims.convert_element_type.default(primals_73, torch.bfloat16) + all_gather_into_tensor_86 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_254, 8, '0'); convert_element_type_254 = None + wait_tensor_102 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_86); all_gather_into_tensor_86 = None + permute_85 = torch.ops.aten.permute.default(wait_tensor_102, [1, 0]); wait_tensor_102 = None + view_564 = torch.ops.aten.view.default(cat_31, [16384, 4096]); cat_31 = None + mm_53 = torch.ops.aten.mm.default(view_564, permute_85); permute_85 = None + view_565 = torch.ops.aten.view.default(mm_53, [2, 8192, 1792]) + convert_element_type_257 = torch.ops.prims.convert_element_type.default(view_565, torch.float32); view_565 = None + sigmoid_7 = torch.ops.aten.sigmoid.default(convert_element_type_257) + mul_62 = torch.ops.aten.mul.Tensor(convert_element_type_257, sigmoid_7); convert_element_type_257 = sigmoid_7 = None + convert_element_type_258 = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None + convert_element_type_259 = torch.ops.prims.convert_element_type.default(primals_74, torch.bfloat16) + all_gather_into_tensor_87 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_259, 8, '0'); convert_element_type_259 = None + wait_tensor_103 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_87); all_gather_into_tensor_87 = None + permute_86 = torch.ops.aten.permute.default(wait_tensor_103, [1, 0]); wait_tensor_103 = None + mm_54 = torch.ops.aten.mm.default(view_564, permute_86); view_564 = permute_86 = None + view_572 = torch.ops.aten.view.default(mm_54, [2, 8192, 1792]); mm_54 = None + mul_63 = torch.ops.aten.mul.Tensor(convert_element_type_258, view_572); convert_element_type_258 = view_572 = None + convert_element_type_262 = torch.ops.prims.convert_element_type.default(primals_75, torch.bfloat16) + all_gather_into_tensor_88 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_262, 8, '0'); convert_element_type_262 = None + wait_tensor_104 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_88); all_gather_into_tensor_88 = None + permute_87 = torch.ops.aten.permute.default(wait_tensor_104, [1, 0]); wait_tensor_104 = None + view_579 = torch.ops.aten.view.default(mul_63, [16384, 1792]); mul_63 = None + mm_55 = torch.ops.aten.mm.default(view_579, permute_87); view_579 = permute_87 = None + view_580 = torch.ops.aten.view.default(mm_55, [2, 8192, 4096]); mm_55 = None + split_40 = torch.ops.aten.split.Tensor(view_580, 1024, 1); view_580 = None + getitem_392 = split_40[0] + getitem_393 = split_40[1] + getitem_394 = split_40[2] + getitem_395 = split_40[3] + getitem_396 = split_40[4] + getitem_397 = split_40[5] + getitem_398 = split_40[6] + getitem_399 = split_40[7]; split_40 = None + cat_32 = torch.ops.aten.cat.default([getitem_392, getitem_393, getitem_394, getitem_395, getitem_396, getitem_397, getitem_398, getitem_399]); getitem_392 = getitem_393 = getitem_394 = getitem_395 = getitem_396 = getitem_397 = getitem_398 = getitem_399 = None + reduce_scatter_tensor_16 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_32, 'sum', 8, '1'); cat_32 = None + wait_tensor_105 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_16); reduce_scatter_tensor_16 = None + add_31 = torch.ops.aten.add.Tensor(add_29, wait_tensor_105); add_29 = wait_tensor_105 = None + convert_element_type_265 = torch.ops.prims.convert_element_type.default(primals_76, torch.bfloat16) + all_gather_into_tensor_89 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_265, 8, '0'); convert_element_type_265 = None + wait_tensor_106 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_89); all_gather_into_tensor_89 = None + convert_element_type_266 = torch.ops.prims.convert_element_type.default(add_31, torch.float32) + pow_17 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_266, 2) + mean_16 = torch.ops.aten.mean.dim(pow_17, [2], True); pow_17 = None + add_32 = torch.ops.aten.add.Scalar(mean_16, 1e-05); mean_16 = None + rsqrt_16 = torch.ops.aten.rsqrt.default(add_32); add_32 = None + mul_64 = torch.ops.aten.mul.Tensor(convert_element_type_266, rsqrt_16); convert_element_type_266 = rsqrt_16 = None + mul_65 = torch.ops.aten.mul.Tensor(mul_64, wait_tensor_106); mul_64 = wait_tensor_106 = None + convert_element_type_267 = torch.ops.prims.convert_element_type.default(mul_65, torch.bfloat16); mul_65 = None + all_gather_into_tensor_90 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_267, 8, '1'); convert_element_type_267 = None + wait_tensor_107 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_90); all_gather_into_tensor_90 = None + split_41 = torch.ops.aten.split.Tensor(wait_tensor_107, 2); wait_tensor_107 = None + getitem_400 = split_41[0] + getitem_401 = split_41[1] + getitem_402 = split_41[2] + getitem_403 = split_41[3] + getitem_404 = split_41[4] + getitem_405 = split_41[5] + getitem_406 = split_41[6] + getitem_407 = split_41[7]; split_41 = None + cat_33 = torch.ops.aten.cat.default([getitem_400, getitem_401, getitem_402, getitem_403, getitem_404, getitem_405, getitem_406, getitem_407], 1); getitem_400 = getitem_401 = getitem_402 = getitem_403 = getitem_404 = getitem_405 = getitem_406 = getitem_407 = None + convert_element_type_268 = torch.ops.prims.convert_element_type.default(primals_77, torch.bfloat16) + all_gather_into_tensor_91 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_268, 8, '0'); convert_element_type_268 = None + wait_tensor_108 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_91); all_gather_into_tensor_91 = None + permute_88 = torch.ops.aten.permute.default(wait_tensor_108, [1, 0]); wait_tensor_108 = None + view_591 = torch.ops.aten.view.default(cat_33, [16384, 4096]); cat_33 = None + mm_56 = torch.ops.aten.mm.default(view_591, permute_88); permute_88 = None + view_592 = torch.ops.aten.view.default(mm_56, [2, 8192, 512]) + convert_element_type_271 = torch.ops.prims.convert_element_type.default(primals_78, torch.bfloat16) + all_gather_into_tensor_92 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_271, 8, '0'); convert_element_type_271 = None + wait_tensor_109 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_92); all_gather_into_tensor_92 = None + permute_89 = torch.ops.aten.permute.default(wait_tensor_109, [1, 0]); wait_tensor_109 = None + mm_57 = torch.ops.aten.mm.default(view_591, permute_89); permute_89 = None + view_599 = torch.ops.aten.view.default(mm_57, [2, 8192, 128]); mm_57 = None + convert_element_type_274 = torch.ops.prims.convert_element_type.default(primals_79, torch.bfloat16) + all_gather_into_tensor_93 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_274, 8, '0'); convert_element_type_274 = None + wait_tensor_110 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_93); all_gather_into_tensor_93 = None + permute_90 = torch.ops.aten.permute.default(wait_tensor_110, [1, 0]); wait_tensor_110 = None + mm_58 = torch.ops.aten.mm.default(view_591, permute_90); view_591 = permute_90 = None + view_606 = torch.ops.aten.view.default(mm_58, [2, 8192, 128]) + view_608 = torch.ops.aten.view.default(view_592, [2, 8192, -1, 128]); view_592 = None + view_609 = torch.ops.aten.view.default(view_599, [2, 8192, -1, 128]); view_599 = None + view_610 = torch.ops.aten.view.default(view_606, [2, 8192, -1, 128]); view_606 = None + convert_element_type_277 = torch.ops.prims.convert_element_type.default(view_608, torch.float32); view_608 = None + view_611 = torch.ops.aten.view.default(convert_element_type_277, [2, 8192, 4, -1, 2]); convert_element_type_277 = None + view_as_complex_16 = torch.ops.aten.view_as_complex.default(view_611); view_611 = None + convert_element_type_278 = torch.ops.prims.convert_element_type.default(view_609, torch.float32); view_609 = None + view_612 = torch.ops.aten.view.default(convert_element_type_278, [2, 8192, 1, -1, 2]); convert_element_type_278 = None + view_as_complex_17 = torch.ops.aten.view_as_complex.default(view_612); view_612 = None + mul_66 = torch.ops.aten.mul.Tensor(view_as_complex_16, view_37); view_as_complex_16 = None + view_as_real_16 = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None + view_614 = torch.ops.aten.view.default(view_as_real_16, [2, 8192, 4, 128]); view_as_real_16 = None + mul_67 = torch.ops.aten.mul.Tensor(view_as_complex_17, view_37); view_as_complex_17 = None + view_as_real_17 = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None + view_615 = torch.ops.aten.view.default(view_as_real_17, [2, 8192, 1, 128]); view_as_real_17 = None + convert_element_type_279 = torch.ops.prims.convert_element_type.default(view_614, torch.bfloat16); view_614 = None + convert_element_type_280 = torch.ops.prims.convert_element_type.default(view_615, torch.bfloat16); view_615 = None + unsqueeze_16 = torch.ops.aten.unsqueeze.default(convert_element_type_280, 3); convert_element_type_280 = None + expand_16 = torch.ops.aten.expand.default(unsqueeze_16, [2, 8192, 1, 4, 128]); unsqueeze_16 = None + view_616 = torch.ops.aten.view.default(expand_16, [2, 8192, 4, 128]); expand_16 = None + unsqueeze_17 = torch.ops.aten.unsqueeze.default(view_610, 3); view_610 = None + expand_17 = torch.ops.aten.expand.default(unsqueeze_17, [2, 8192, 1, 4, 128]); unsqueeze_17 = None + view_617 = torch.ops.aten.view.default(expand_17, [2, 8192, 4, 128]); expand_17 = None + permute_91 = torch.ops.aten.permute.default(convert_element_type_279, [0, 2, 1, 3]); convert_element_type_279 = None + permute_92 = torch.ops.aten.permute.default(view_616, [0, 2, 1, 3]); view_616 = None + permute_93 = torch.ops.aten.permute.default(view_617, [0, 2, 1, 3]); view_617 = None + _scaled_dot_product_cudnn_attention_8 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_91, permute_92, permute_93, None, True, 0.0, True); permute_91 = permute_92 = permute_93 = None + getitem_408 = _scaled_dot_product_cudnn_attention_8[0] + getitem_409 = _scaled_dot_product_cudnn_attention_8[1] + getitem_414 = _scaled_dot_product_cudnn_attention_8[6] + getitem_415 = _scaled_dot_product_cudnn_attention_8[7]; _scaled_dot_product_cudnn_attention_8 = None + permute_94 = torch.ops.aten.permute.default(getitem_408, [0, 2, 1, 3]) + view_618 = torch.ops.aten.view.default(permute_94, [2, 8192, -1]); permute_94 = None + convert_element_type_281 = torch.ops.prims.convert_element_type.default(primals_80, torch.bfloat16) + all_gather_into_tensor_94 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_281, 8, '0'); convert_element_type_281 = None + wait_tensor_111 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_94); all_gather_into_tensor_94 = None + permute_95 = torch.ops.aten.permute.default(wait_tensor_111, [1, 0]); wait_tensor_111 = None + view_624 = torch.ops.aten.view.default(view_618, [16384, 512]); view_618 = None + mm_59 = torch.ops.aten.mm.default(view_624, permute_95); view_624 = permute_95 = None + view_625 = torch.ops.aten.view.default(mm_59, [2, 8192, 4096]); mm_59 = None + split_42 = torch.ops.aten.split.Tensor(view_625, 1024, 1); view_625 = None + getitem_417 = split_42[0] + getitem_418 = split_42[1] + getitem_419 = split_42[2] + getitem_420 = split_42[3] + getitem_421 = split_42[4] + getitem_422 = split_42[5] + getitem_423 = split_42[6] + getitem_424 = split_42[7]; split_42 = None + cat_34 = torch.ops.aten.cat.default([getitem_417, getitem_418, getitem_419, getitem_420, getitem_421, getitem_422, getitem_423, getitem_424]); getitem_417 = getitem_418 = getitem_419 = getitem_420 = getitem_421 = getitem_422 = getitem_423 = getitem_424 = None + reduce_scatter_tensor_17 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_34, 'sum', 8, '1'); cat_34 = None + wait_tensor_112 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_17) + add_33 = torch.ops.aten.add.Tensor(add_31, wait_tensor_112); wait_tensor_112 = None + convert_element_type_284 = torch.ops.prims.convert_element_type.default(primals_81, torch.bfloat16) + all_gather_into_tensor_95 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_284, 8, '0'); convert_element_type_284 = None + wait_tensor_113 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_95); all_gather_into_tensor_95 = None + convert_element_type_285 = torch.ops.prims.convert_element_type.default(add_33, torch.float32) + pow_18 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_285, 2) + mean_17 = torch.ops.aten.mean.dim(pow_18, [2], True); pow_18 = None + add_34 = torch.ops.aten.add.Scalar(mean_17, 1e-05); mean_17 = None + rsqrt_17 = torch.ops.aten.rsqrt.default(add_34); add_34 = None + mul_68 = torch.ops.aten.mul.Tensor(convert_element_type_285, rsqrt_17); convert_element_type_285 = rsqrt_17 = None + mul_69 = torch.ops.aten.mul.Tensor(mul_68, wait_tensor_113); mul_68 = wait_tensor_113 = None + convert_element_type_286 = torch.ops.prims.convert_element_type.default(mul_69, torch.bfloat16); mul_69 = None + all_gather_into_tensor_96 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_286, 8, '1'); convert_element_type_286 = None + wait_tensor_114 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_96); all_gather_into_tensor_96 = None + split_43 = torch.ops.aten.split.Tensor(wait_tensor_114, 2); wait_tensor_114 = None + getitem_425 = split_43[0] + getitem_426 = split_43[1] + getitem_427 = split_43[2] + getitem_428 = split_43[3] + getitem_429 = split_43[4] + getitem_430 = split_43[5] + getitem_431 = split_43[6] + getitem_432 = split_43[7]; split_43 = None + cat_35 = torch.ops.aten.cat.default([getitem_425, getitem_426, getitem_427, getitem_428, getitem_429, getitem_430, getitem_431, getitem_432], 1); getitem_425 = getitem_426 = getitem_427 = getitem_428 = getitem_429 = getitem_430 = getitem_431 = getitem_432 = None + convert_element_type_287 = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16) + all_gather_into_tensor_97 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_287, 8, '0'); convert_element_type_287 = None + wait_tensor_115 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_97); all_gather_into_tensor_97 = None + permute_96 = torch.ops.aten.permute.default(wait_tensor_115, [1, 0]); wait_tensor_115 = None + view_636 = torch.ops.aten.view.default(cat_35, [16384, 4096]); cat_35 = None + mm_60 = torch.ops.aten.mm.default(view_636, permute_96); permute_96 = None + view_637 = torch.ops.aten.view.default(mm_60, [2, 8192, 1792]) + convert_element_type_290 = torch.ops.prims.convert_element_type.default(view_637, torch.float32); view_637 = None + sigmoid_8 = torch.ops.aten.sigmoid.default(convert_element_type_290) + mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_290, sigmoid_8); convert_element_type_290 = sigmoid_8 = None + convert_element_type_291 = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None + convert_element_type_292 = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16) + all_gather_into_tensor_98 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_292, 8, '0'); convert_element_type_292 = None + wait_tensor_116 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_98); all_gather_into_tensor_98 = None + permute_97 = torch.ops.aten.permute.default(wait_tensor_116, [1, 0]); wait_tensor_116 = None + mm_61 = torch.ops.aten.mm.default(view_636, permute_97); view_636 = permute_97 = None + view_644 = torch.ops.aten.view.default(mm_61, [2, 8192, 1792]); mm_61 = None + mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_291, view_644); convert_element_type_291 = view_644 = None + convert_element_type_295 = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16) + all_gather_into_tensor_99 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_295, 8, '0'); convert_element_type_295 = None + wait_tensor_117 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_99); all_gather_into_tensor_99 = None + permute_98 = torch.ops.aten.permute.default(wait_tensor_117, [1, 0]); wait_tensor_117 = None + view_651 = torch.ops.aten.view.default(mul_71, [16384, 1792]); mul_71 = None + mm_62 = torch.ops.aten.mm.default(view_651, permute_98); view_651 = permute_98 = None + view_652 = torch.ops.aten.view.default(mm_62, [2, 8192, 4096]); mm_62 = None + split_44 = torch.ops.aten.split.Tensor(view_652, 1024, 1); view_652 = None + getitem_433 = split_44[0] + getitem_434 = split_44[1] + getitem_435 = split_44[2] + getitem_436 = split_44[3] + getitem_437 = split_44[4] + getitem_438 = split_44[5] + getitem_439 = split_44[6] + getitem_440 = split_44[7]; split_44 = None + cat_36 = torch.ops.aten.cat.default([getitem_433, getitem_434, getitem_435, getitem_436, getitem_437, getitem_438, getitem_439, getitem_440]); getitem_433 = getitem_434 = getitem_435 = getitem_436 = getitem_437 = getitem_438 = getitem_439 = getitem_440 = None + reduce_scatter_tensor_18 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_36, 'sum', 8, '1'); cat_36 = None + wait_tensor_118 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_18); reduce_scatter_tensor_18 = None + add_35 = torch.ops.aten.add.Tensor(add_33, wait_tensor_118); add_33 = wait_tensor_118 = None + convert_element_type_298 = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16) + all_gather_into_tensor_100 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_298, 8, '0'); convert_element_type_298 = None + wait_tensor_119 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_100); all_gather_into_tensor_100 = None + convert_element_type_299 = torch.ops.prims.convert_element_type.default(add_35, torch.float32) + pow_19 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_299, 2) + mean_18 = torch.ops.aten.mean.dim(pow_19, [2], True); pow_19 = None + add_36 = torch.ops.aten.add.Scalar(mean_18, 1e-05); mean_18 = None + rsqrt_18 = torch.ops.aten.rsqrt.default(add_36); add_36 = None + mul_72 = torch.ops.aten.mul.Tensor(convert_element_type_299, rsqrt_18); convert_element_type_299 = rsqrt_18 = None + mul_73 = torch.ops.aten.mul.Tensor(mul_72, wait_tensor_119); mul_72 = wait_tensor_119 = None + convert_element_type_300 = torch.ops.prims.convert_element_type.default(mul_73, torch.bfloat16); mul_73 = None + all_gather_into_tensor_101 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_300, 8, '1'); convert_element_type_300 = None + wait_tensor_120 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_101); all_gather_into_tensor_101 = None + split_45 = torch.ops.aten.split.Tensor(wait_tensor_120, 2); wait_tensor_120 = None + getitem_441 = split_45[0] + getitem_442 = split_45[1] + getitem_443 = split_45[2] + getitem_444 = split_45[3] + getitem_445 = split_45[4] + getitem_446 = split_45[5] + getitem_447 = split_45[6] + getitem_448 = split_45[7]; split_45 = None + cat_37 = torch.ops.aten.cat.default([getitem_441, getitem_442, getitem_443, getitem_444, getitem_445, getitem_446, getitem_447, getitem_448], 1); getitem_441 = getitem_442 = getitem_443 = getitem_444 = getitem_445 = getitem_446 = getitem_447 = getitem_448 = None + convert_element_type_301 = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16) + all_gather_into_tensor_102 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_301, 8, '0'); convert_element_type_301 = None + wait_tensor_121 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_102); all_gather_into_tensor_102 = None + permute_99 = torch.ops.aten.permute.default(wait_tensor_121, [1, 0]); wait_tensor_121 = None + view_663 = torch.ops.aten.view.default(cat_37, [16384, 4096]); cat_37 = None + mm_63 = torch.ops.aten.mm.default(view_663, permute_99); permute_99 = None + view_664 = torch.ops.aten.view.default(mm_63, [2, 8192, 512]) + convert_element_type_304 = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16) + all_gather_into_tensor_103 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_304, 8, '0'); convert_element_type_304 = None + wait_tensor_122 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_103); all_gather_into_tensor_103 = None + permute_100 = torch.ops.aten.permute.default(wait_tensor_122, [1, 0]); wait_tensor_122 = None + mm_64 = torch.ops.aten.mm.default(view_663, permute_100); permute_100 = None + view_671 = torch.ops.aten.view.default(mm_64, [2, 8192, 128]); mm_64 = None + convert_element_type_307 = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16) + all_gather_into_tensor_104 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_307, 8, '0'); convert_element_type_307 = None + wait_tensor_123 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_104); all_gather_into_tensor_104 = None + permute_101 = torch.ops.aten.permute.default(wait_tensor_123, [1, 0]); wait_tensor_123 = None + mm_65 = torch.ops.aten.mm.default(view_663, permute_101); view_663 = permute_101 = None + view_678 = torch.ops.aten.view.default(mm_65, [2, 8192, 128]) + view_680 = torch.ops.aten.view.default(view_664, [2, 8192, -1, 128]); view_664 = None + view_681 = torch.ops.aten.view.default(view_671, [2, 8192, -1, 128]); view_671 = None + view_682 = torch.ops.aten.view.default(view_678, [2, 8192, -1, 128]); view_678 = None + convert_element_type_310 = torch.ops.prims.convert_element_type.default(view_680, torch.float32); view_680 = None + view_683 = torch.ops.aten.view.default(convert_element_type_310, [2, 8192, 4, -1, 2]); convert_element_type_310 = None + view_as_complex_18 = torch.ops.aten.view_as_complex.default(view_683); view_683 = None + convert_element_type_311 = torch.ops.prims.convert_element_type.default(view_681, torch.float32); view_681 = None + view_684 = torch.ops.aten.view.default(convert_element_type_311, [2, 8192, 1, -1, 2]); convert_element_type_311 = None + view_as_complex_19 = torch.ops.aten.view_as_complex.default(view_684); view_684 = None + mul_74 = torch.ops.aten.mul.Tensor(view_as_complex_18, view_37); view_as_complex_18 = None + view_as_real_18 = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None + view_686 = torch.ops.aten.view.default(view_as_real_18, [2, 8192, 4, 128]); view_as_real_18 = None + mul_75 = torch.ops.aten.mul.Tensor(view_as_complex_19, view_37); view_as_complex_19 = None + view_as_real_19 = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None + view_687 = torch.ops.aten.view.default(view_as_real_19, [2, 8192, 1, 128]); view_as_real_19 = None + convert_element_type_312 = torch.ops.prims.convert_element_type.default(view_686, torch.bfloat16); view_686 = None + convert_element_type_313 = torch.ops.prims.convert_element_type.default(view_687, torch.bfloat16); view_687 = None + unsqueeze_18 = torch.ops.aten.unsqueeze.default(convert_element_type_313, 3); convert_element_type_313 = None + expand_18 = torch.ops.aten.expand.default(unsqueeze_18, [2, 8192, 1, 4, 128]); unsqueeze_18 = None + view_688 = torch.ops.aten.view.default(expand_18, [2, 8192, 4, 128]); expand_18 = None + unsqueeze_19 = torch.ops.aten.unsqueeze.default(view_682, 3); view_682 = None + expand_19 = torch.ops.aten.expand.default(unsqueeze_19, [2, 8192, 1, 4, 128]); unsqueeze_19 = None + view_689 = torch.ops.aten.view.default(expand_19, [2, 8192, 4, 128]); expand_19 = None + permute_102 = torch.ops.aten.permute.default(convert_element_type_312, [0, 2, 1, 3]); convert_element_type_312 = None + permute_103 = torch.ops.aten.permute.default(view_688, [0, 2, 1, 3]); view_688 = None + permute_104 = torch.ops.aten.permute.default(view_689, [0, 2, 1, 3]); view_689 = None + _scaled_dot_product_cudnn_attention_9 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_102, permute_103, permute_104, None, True, 0.0, True); permute_102 = permute_103 = permute_104 = None + getitem_449 = _scaled_dot_product_cudnn_attention_9[0] + getitem_450 = _scaled_dot_product_cudnn_attention_9[1] + getitem_455 = _scaled_dot_product_cudnn_attention_9[6] + getitem_456 = _scaled_dot_product_cudnn_attention_9[7]; _scaled_dot_product_cudnn_attention_9 = None + permute_105 = torch.ops.aten.permute.default(getitem_449, [0, 2, 1, 3]) + view_690 = torch.ops.aten.view.default(permute_105, [2, 8192, -1]); permute_105 = None + convert_element_type_314 = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16) + all_gather_into_tensor_105 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_314, 8, '0'); convert_element_type_314 = None + wait_tensor_124 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_105); all_gather_into_tensor_105 = None + permute_106 = torch.ops.aten.permute.default(wait_tensor_124, [1, 0]); wait_tensor_124 = None + view_696 = torch.ops.aten.view.default(view_690, [16384, 512]); view_690 = None + mm_66 = torch.ops.aten.mm.default(view_696, permute_106); view_696 = permute_106 = None + view_697 = torch.ops.aten.view.default(mm_66, [2, 8192, 4096]); mm_66 = None + split_46 = torch.ops.aten.split.Tensor(view_697, 1024, 1); view_697 = None + getitem_458 = split_46[0] + getitem_459 = split_46[1] + getitem_460 = split_46[2] + getitem_461 = split_46[3] + getitem_462 = split_46[4] + getitem_463 = split_46[5] + getitem_464 = split_46[6] + getitem_465 = split_46[7]; split_46 = None + cat_38 = torch.ops.aten.cat.default([getitem_458, getitem_459, getitem_460, getitem_461, getitem_462, getitem_463, getitem_464, getitem_465]); getitem_458 = getitem_459 = getitem_460 = getitem_461 = getitem_462 = getitem_463 = getitem_464 = getitem_465 = None + reduce_scatter_tensor_19 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_38, 'sum', 8, '1'); cat_38 = None + wait_tensor_125 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_19) + add_37 = torch.ops.aten.add.Tensor(add_35, wait_tensor_125); wait_tensor_125 = None + convert_element_type_317 = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16) + all_gather_into_tensor_106 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_317, 8, '0'); convert_element_type_317 = None + wait_tensor_126 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_106); all_gather_into_tensor_106 = None + convert_element_type_318 = torch.ops.prims.convert_element_type.default(add_37, torch.float32) + pow_20 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_318, 2) + mean_19 = torch.ops.aten.mean.dim(pow_20, [2], True); pow_20 = None + add_38 = torch.ops.aten.add.Scalar(mean_19, 1e-05); mean_19 = None + rsqrt_19 = torch.ops.aten.rsqrt.default(add_38); add_38 = None + mul_76 = torch.ops.aten.mul.Tensor(convert_element_type_318, rsqrt_19); convert_element_type_318 = rsqrt_19 = None + mul_77 = torch.ops.aten.mul.Tensor(mul_76, wait_tensor_126); mul_76 = wait_tensor_126 = None + convert_element_type_319 = torch.ops.prims.convert_element_type.default(mul_77, torch.bfloat16); mul_77 = None + all_gather_into_tensor_107 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_319, 8, '1'); convert_element_type_319 = None + wait_tensor_127 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_107); all_gather_into_tensor_107 = None + split_47 = torch.ops.aten.split.Tensor(wait_tensor_127, 2); wait_tensor_127 = None + getitem_466 = split_47[0] + getitem_467 = split_47[1] + getitem_468 = split_47[2] + getitem_469 = split_47[3] + getitem_470 = split_47[4] + getitem_471 = split_47[5] + getitem_472 = split_47[6] + getitem_473 = split_47[7]; split_47 = None + cat_39 = torch.ops.aten.cat.default([getitem_466, getitem_467, getitem_468, getitem_469, getitem_470, getitem_471, getitem_472, getitem_473], 1); getitem_466 = getitem_467 = getitem_468 = getitem_469 = getitem_470 = getitem_471 = getitem_472 = getitem_473 = None + convert_element_type_320 = torch.ops.prims.convert_element_type.default(primals_91, torch.bfloat16) + all_gather_into_tensor_108 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_320, 8, '0'); convert_element_type_320 = None + wait_tensor_128 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_108); all_gather_into_tensor_108 = None + permute_107 = torch.ops.aten.permute.default(wait_tensor_128, [1, 0]); wait_tensor_128 = None + view_708 = torch.ops.aten.view.default(cat_39, [16384, 4096]); cat_39 = None + mm_67 = torch.ops.aten.mm.default(view_708, permute_107); permute_107 = None + view_709 = torch.ops.aten.view.default(mm_67, [2, 8192, 1792]) + convert_element_type_323 = torch.ops.prims.convert_element_type.default(view_709, torch.float32); view_709 = None + sigmoid_9 = torch.ops.aten.sigmoid.default(convert_element_type_323) + mul_78 = torch.ops.aten.mul.Tensor(convert_element_type_323, sigmoid_9); convert_element_type_323 = sigmoid_9 = None + convert_element_type_324 = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None + convert_element_type_325 = torch.ops.prims.convert_element_type.default(primals_92, torch.bfloat16) + all_gather_into_tensor_109 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_325, 8, '0'); convert_element_type_325 = None + wait_tensor_129 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_109); all_gather_into_tensor_109 = None + permute_108 = torch.ops.aten.permute.default(wait_tensor_129, [1, 0]); wait_tensor_129 = None + mm_68 = torch.ops.aten.mm.default(view_708, permute_108); view_708 = permute_108 = None + view_716 = torch.ops.aten.view.default(mm_68, [2, 8192, 1792]); mm_68 = None + mul_79 = torch.ops.aten.mul.Tensor(convert_element_type_324, view_716); convert_element_type_324 = view_716 = None + convert_element_type_328 = torch.ops.prims.convert_element_type.default(primals_93, torch.bfloat16) + all_gather_into_tensor_110 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_328, 8, '0'); convert_element_type_328 = None + wait_tensor_130 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_110); all_gather_into_tensor_110 = None + permute_109 = torch.ops.aten.permute.default(wait_tensor_130, [1, 0]); wait_tensor_130 = None + view_723 = torch.ops.aten.view.default(mul_79, [16384, 1792]); mul_79 = None + mm_69 = torch.ops.aten.mm.default(view_723, permute_109); view_723 = permute_109 = None + view_724 = torch.ops.aten.view.default(mm_69, [2, 8192, 4096]); mm_69 = None + split_48 = torch.ops.aten.split.Tensor(view_724, 1024, 1); view_724 = None + getitem_474 = split_48[0] + getitem_475 = split_48[1] + getitem_476 = split_48[2] + getitem_477 = split_48[3] + getitem_478 = split_48[4] + getitem_479 = split_48[5] + getitem_480 = split_48[6] + getitem_481 = split_48[7]; split_48 = None + cat_40 = torch.ops.aten.cat.default([getitem_474, getitem_475, getitem_476, getitem_477, getitem_478, getitem_479, getitem_480, getitem_481]); getitem_474 = getitem_475 = getitem_476 = getitem_477 = getitem_478 = getitem_479 = getitem_480 = getitem_481 = None + reduce_scatter_tensor_20 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_40, 'sum', 8, '1'); cat_40 = None + wait_tensor_131 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_20); reduce_scatter_tensor_20 = None + add_39 = torch.ops.aten.add.Tensor(add_37, wait_tensor_131); add_37 = wait_tensor_131 = None + convert_element_type_331 = torch.ops.prims.convert_element_type.default(primals_94, torch.bfloat16) + all_gather_into_tensor_111 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_331, 8, '0'); convert_element_type_331 = None + wait_tensor_132 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_111); all_gather_into_tensor_111 = None + convert_element_type_332 = torch.ops.prims.convert_element_type.default(add_39, torch.float32) + pow_21 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_332, 2) + mean_20 = torch.ops.aten.mean.dim(pow_21, [2], True); pow_21 = None + add_40 = torch.ops.aten.add.Scalar(mean_20, 1e-05); mean_20 = None + rsqrt_20 = torch.ops.aten.rsqrt.default(add_40); add_40 = None + mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_332, rsqrt_20); convert_element_type_332 = rsqrt_20 = None + mul_81 = torch.ops.aten.mul.Tensor(mul_80, wait_tensor_132); mul_80 = wait_tensor_132 = None + convert_element_type_333 = torch.ops.prims.convert_element_type.default(mul_81, torch.bfloat16); mul_81 = None + all_gather_into_tensor_112 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_333, 8, '1'); convert_element_type_333 = None + wait_tensor_133 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_112); all_gather_into_tensor_112 = None + split_49 = torch.ops.aten.split.Tensor(wait_tensor_133, 2); wait_tensor_133 = None + getitem_482 = split_49[0] + getitem_483 = split_49[1] + getitem_484 = split_49[2] + getitem_485 = split_49[3] + getitem_486 = split_49[4] + getitem_487 = split_49[5] + getitem_488 = split_49[6] + getitem_489 = split_49[7]; split_49 = None + cat_41 = torch.ops.aten.cat.default([getitem_482, getitem_483, getitem_484, getitem_485, getitem_486, getitem_487, getitem_488, getitem_489], 1); getitem_482 = getitem_483 = getitem_484 = getitem_485 = getitem_486 = getitem_487 = getitem_488 = getitem_489 = None + convert_element_type_334 = torch.ops.prims.convert_element_type.default(primals_95, torch.bfloat16) + all_gather_into_tensor_113 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_334, 8, '0'); convert_element_type_334 = None + wait_tensor_134 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_113); all_gather_into_tensor_113 = None + permute_110 = torch.ops.aten.permute.default(wait_tensor_134, [1, 0]); wait_tensor_134 = None + view_735 = torch.ops.aten.view.default(cat_41, [16384, 4096]); cat_41 = None + mm_70 = torch.ops.aten.mm.default(view_735, permute_110); permute_110 = None + view_736 = torch.ops.aten.view.default(mm_70, [2, 8192, 512]) + convert_element_type_337 = torch.ops.prims.convert_element_type.default(primals_96, torch.bfloat16) + all_gather_into_tensor_114 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_337, 8, '0'); convert_element_type_337 = None + wait_tensor_135 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_114); all_gather_into_tensor_114 = None + permute_111 = torch.ops.aten.permute.default(wait_tensor_135, [1, 0]); wait_tensor_135 = None + mm_71 = torch.ops.aten.mm.default(view_735, permute_111); permute_111 = None + view_743 = torch.ops.aten.view.default(mm_71, [2, 8192, 128]); mm_71 = None + convert_element_type_340 = torch.ops.prims.convert_element_type.default(primals_97, torch.bfloat16) + all_gather_into_tensor_115 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_340, 8, '0'); convert_element_type_340 = None + wait_tensor_136 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_115); all_gather_into_tensor_115 = None + permute_112 = torch.ops.aten.permute.default(wait_tensor_136, [1, 0]); wait_tensor_136 = None + mm_72 = torch.ops.aten.mm.default(view_735, permute_112); view_735 = permute_112 = None + view_750 = torch.ops.aten.view.default(mm_72, [2, 8192, 128]) + view_752 = torch.ops.aten.view.default(view_736, [2, 8192, -1, 128]); view_736 = None + view_753 = torch.ops.aten.view.default(view_743, [2, 8192, -1, 128]); view_743 = None + view_754 = torch.ops.aten.view.default(view_750, [2, 8192, -1, 128]); view_750 = None + convert_element_type_343 = torch.ops.prims.convert_element_type.default(view_752, torch.float32); view_752 = None + view_755 = torch.ops.aten.view.default(convert_element_type_343, [2, 8192, 4, -1, 2]); convert_element_type_343 = None + view_as_complex_20 = torch.ops.aten.view_as_complex.default(view_755); view_755 = None + convert_element_type_344 = torch.ops.prims.convert_element_type.default(view_753, torch.float32); view_753 = None + view_756 = torch.ops.aten.view.default(convert_element_type_344, [2, 8192, 1, -1, 2]); convert_element_type_344 = None + view_as_complex_21 = torch.ops.aten.view_as_complex.default(view_756); view_756 = None + mul_82 = torch.ops.aten.mul.Tensor(view_as_complex_20, view_37); view_as_complex_20 = None + view_as_real_20 = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None + view_758 = torch.ops.aten.view.default(view_as_real_20, [2, 8192, 4, 128]); view_as_real_20 = None + mul_83 = torch.ops.aten.mul.Tensor(view_as_complex_21, view_37); view_as_complex_21 = None + view_as_real_21 = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None + view_759 = torch.ops.aten.view.default(view_as_real_21, [2, 8192, 1, 128]); view_as_real_21 = None + convert_element_type_345 = torch.ops.prims.convert_element_type.default(view_758, torch.bfloat16); view_758 = None + convert_element_type_346 = torch.ops.prims.convert_element_type.default(view_759, torch.bfloat16); view_759 = None + unsqueeze_20 = torch.ops.aten.unsqueeze.default(convert_element_type_346, 3); convert_element_type_346 = None + expand_20 = torch.ops.aten.expand.default(unsqueeze_20, [2, 8192, 1, 4, 128]); unsqueeze_20 = None + view_760 = torch.ops.aten.view.default(expand_20, [2, 8192, 4, 128]); expand_20 = None + unsqueeze_21 = torch.ops.aten.unsqueeze.default(view_754, 3); view_754 = None + expand_21 = torch.ops.aten.expand.default(unsqueeze_21, [2, 8192, 1, 4, 128]); unsqueeze_21 = None + view_761 = torch.ops.aten.view.default(expand_21, [2, 8192, 4, 128]); expand_21 = None + permute_113 = torch.ops.aten.permute.default(convert_element_type_345, [0, 2, 1, 3]); convert_element_type_345 = None + permute_114 = torch.ops.aten.permute.default(view_760, [0, 2, 1, 3]); view_760 = None + permute_115 = torch.ops.aten.permute.default(view_761, [0, 2, 1, 3]); view_761 = None + _scaled_dot_product_cudnn_attention_10 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_113, permute_114, permute_115, None, True, 0.0, True); permute_113 = permute_114 = permute_115 = None + getitem_490 = _scaled_dot_product_cudnn_attention_10[0] + getitem_491 = _scaled_dot_product_cudnn_attention_10[1] + getitem_496 = _scaled_dot_product_cudnn_attention_10[6] + getitem_497 = _scaled_dot_product_cudnn_attention_10[7]; _scaled_dot_product_cudnn_attention_10 = None + permute_116 = torch.ops.aten.permute.default(getitem_490, [0, 2, 1, 3]) + view_762 = torch.ops.aten.view.default(permute_116, [2, 8192, -1]); permute_116 = None + convert_element_type_347 = torch.ops.prims.convert_element_type.default(primals_98, torch.bfloat16) + all_gather_into_tensor_116 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_347, 8, '0'); convert_element_type_347 = None + wait_tensor_137 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_116); all_gather_into_tensor_116 = None + permute_117 = torch.ops.aten.permute.default(wait_tensor_137, [1, 0]); wait_tensor_137 = None + view_768 = torch.ops.aten.view.default(view_762, [16384, 512]); view_762 = None + mm_73 = torch.ops.aten.mm.default(view_768, permute_117); view_768 = permute_117 = None + view_769 = torch.ops.aten.view.default(mm_73, [2, 8192, 4096]); mm_73 = None + split_50 = torch.ops.aten.split.Tensor(view_769, 1024, 1); view_769 = None + getitem_499 = split_50[0] + getitem_500 = split_50[1] + getitem_501 = split_50[2] + getitem_502 = split_50[3] + getitem_503 = split_50[4] + getitem_504 = split_50[5] + getitem_505 = split_50[6] + getitem_506 = split_50[7]; split_50 = None + cat_42 = torch.ops.aten.cat.default([getitem_499, getitem_500, getitem_501, getitem_502, getitem_503, getitem_504, getitem_505, getitem_506]); getitem_499 = getitem_500 = getitem_501 = getitem_502 = getitem_503 = getitem_504 = getitem_505 = getitem_506 = None + reduce_scatter_tensor_21 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_42, 'sum', 8, '1'); cat_42 = None + wait_tensor_138 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_21) + add_41 = torch.ops.aten.add.Tensor(add_39, wait_tensor_138); wait_tensor_138 = None + convert_element_type_350 = torch.ops.prims.convert_element_type.default(primals_99, torch.bfloat16) + all_gather_into_tensor_117 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_350, 8, '0'); convert_element_type_350 = None + wait_tensor_139 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_117); all_gather_into_tensor_117 = None + convert_element_type_351 = torch.ops.prims.convert_element_type.default(add_41, torch.float32) + pow_22 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_351, 2) + mean_21 = torch.ops.aten.mean.dim(pow_22, [2], True); pow_22 = None + add_42 = torch.ops.aten.add.Scalar(mean_21, 1e-05); mean_21 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_42); add_42 = None + mul_84 = torch.ops.aten.mul.Tensor(convert_element_type_351, rsqrt_21); convert_element_type_351 = rsqrt_21 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, wait_tensor_139); mul_84 = wait_tensor_139 = None + convert_element_type_352 = torch.ops.prims.convert_element_type.default(mul_85, torch.bfloat16); mul_85 = None + all_gather_into_tensor_118 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_352, 8, '1'); convert_element_type_352 = None + wait_tensor_140 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_118); all_gather_into_tensor_118 = None + split_51 = torch.ops.aten.split.Tensor(wait_tensor_140, 2); wait_tensor_140 = None + getitem_507 = split_51[0] + getitem_508 = split_51[1] + getitem_509 = split_51[2] + getitem_510 = split_51[3] + getitem_511 = split_51[4] + getitem_512 = split_51[5] + getitem_513 = split_51[6] + getitem_514 = split_51[7]; split_51 = None + cat_43 = torch.ops.aten.cat.default([getitem_507, getitem_508, getitem_509, getitem_510, getitem_511, getitem_512, getitem_513, getitem_514], 1); getitem_507 = getitem_508 = getitem_509 = getitem_510 = getitem_511 = getitem_512 = getitem_513 = getitem_514 = None + convert_element_type_353 = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16) + all_gather_into_tensor_119 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_353, 8, '0'); convert_element_type_353 = None + wait_tensor_141 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_119); all_gather_into_tensor_119 = None + permute_118 = torch.ops.aten.permute.default(wait_tensor_141, [1, 0]); wait_tensor_141 = None + view_780 = torch.ops.aten.view.default(cat_43, [16384, 4096]); cat_43 = None + mm_74 = torch.ops.aten.mm.default(view_780, permute_118); permute_118 = None + view_781 = torch.ops.aten.view.default(mm_74, [2, 8192, 1792]) + convert_element_type_356 = torch.ops.prims.convert_element_type.default(view_781, torch.float32); view_781 = None + sigmoid_10 = torch.ops.aten.sigmoid.default(convert_element_type_356) + mul_86 = torch.ops.aten.mul.Tensor(convert_element_type_356, sigmoid_10); convert_element_type_356 = sigmoid_10 = None + convert_element_type_357 = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None + convert_element_type_358 = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16) + all_gather_into_tensor_120 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_358, 8, '0'); convert_element_type_358 = None + wait_tensor_142 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_120); all_gather_into_tensor_120 = None + permute_119 = torch.ops.aten.permute.default(wait_tensor_142, [1, 0]); wait_tensor_142 = None + mm_75 = torch.ops.aten.mm.default(view_780, permute_119); view_780 = permute_119 = None + view_788 = torch.ops.aten.view.default(mm_75, [2, 8192, 1792]); mm_75 = None + mul_87 = torch.ops.aten.mul.Tensor(convert_element_type_357, view_788); convert_element_type_357 = view_788 = None + convert_element_type_361 = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16) + all_gather_into_tensor_121 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_361, 8, '0'); convert_element_type_361 = None + wait_tensor_143 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_121); all_gather_into_tensor_121 = None + permute_120 = torch.ops.aten.permute.default(wait_tensor_143, [1, 0]); wait_tensor_143 = None + view_795 = torch.ops.aten.view.default(mul_87, [16384, 1792]); mul_87 = None + mm_76 = torch.ops.aten.mm.default(view_795, permute_120); view_795 = permute_120 = None + view_796 = torch.ops.aten.view.default(mm_76, [2, 8192, 4096]); mm_76 = None + split_52 = torch.ops.aten.split.Tensor(view_796, 1024, 1); view_796 = None + getitem_515 = split_52[0] + getitem_516 = split_52[1] + getitem_517 = split_52[2] + getitem_518 = split_52[3] + getitem_519 = split_52[4] + getitem_520 = split_52[5] + getitem_521 = split_52[6] + getitem_522 = split_52[7]; split_52 = None + cat_44 = torch.ops.aten.cat.default([getitem_515, getitem_516, getitem_517, getitem_518, getitem_519, getitem_520, getitem_521, getitem_522]); getitem_515 = getitem_516 = getitem_517 = getitem_518 = getitem_519 = getitem_520 = getitem_521 = getitem_522 = None + reduce_scatter_tensor_22 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_44, 'sum', 8, '1'); cat_44 = None + wait_tensor_144 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_22); reduce_scatter_tensor_22 = None + add_43 = torch.ops.aten.add.Tensor(add_41, wait_tensor_144); add_41 = wait_tensor_144 = None + convert_element_type_364 = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16) + all_gather_into_tensor_122 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_364, 8, '0'); convert_element_type_364 = None + wait_tensor_145 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_122); all_gather_into_tensor_122 = None + convert_element_type_365 = torch.ops.prims.convert_element_type.default(add_43, torch.float32) + pow_23 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_365, 2) + mean_22 = torch.ops.aten.mean.dim(pow_23, [2], True); pow_23 = None + add_44 = torch.ops.aten.add.Scalar(mean_22, 1e-05); mean_22 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_44); add_44 = None + mul_88 = torch.ops.aten.mul.Tensor(convert_element_type_365, rsqrt_22); convert_element_type_365 = rsqrt_22 = None + mul_89 = torch.ops.aten.mul.Tensor(mul_88, wait_tensor_145); mul_88 = wait_tensor_145 = None + convert_element_type_366 = torch.ops.prims.convert_element_type.default(mul_89, torch.bfloat16); mul_89 = None + all_gather_into_tensor_123 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_366, 8, '1'); convert_element_type_366 = None + wait_tensor_146 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_123); all_gather_into_tensor_123 = None + split_53 = torch.ops.aten.split.Tensor(wait_tensor_146, 2); wait_tensor_146 = None + getitem_523 = split_53[0] + getitem_524 = split_53[1] + getitem_525 = split_53[2] + getitem_526 = split_53[3] + getitem_527 = split_53[4] + getitem_528 = split_53[5] + getitem_529 = split_53[6] + getitem_530 = split_53[7]; split_53 = None + cat_45 = torch.ops.aten.cat.default([getitem_523, getitem_524, getitem_525, getitem_526, getitem_527, getitem_528, getitem_529, getitem_530], 1); getitem_523 = getitem_524 = getitem_525 = getitem_526 = getitem_527 = getitem_528 = getitem_529 = getitem_530 = None + convert_element_type_367 = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16) + all_gather_into_tensor_124 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_367, 8, '0'); convert_element_type_367 = None + wait_tensor_147 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_124); all_gather_into_tensor_124 = None + permute_121 = torch.ops.aten.permute.default(wait_tensor_147, [1, 0]); wait_tensor_147 = None + view_807 = torch.ops.aten.view.default(cat_45, [16384, 4096]); cat_45 = None + mm_77 = torch.ops.aten.mm.default(view_807, permute_121); permute_121 = None + view_808 = torch.ops.aten.view.default(mm_77, [2, 8192, 512]) + convert_element_type_370 = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16) + all_gather_into_tensor_125 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_370, 8, '0'); convert_element_type_370 = None + wait_tensor_148 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_125); all_gather_into_tensor_125 = None + permute_122 = torch.ops.aten.permute.default(wait_tensor_148, [1, 0]); wait_tensor_148 = None + mm_78 = torch.ops.aten.mm.default(view_807, permute_122); permute_122 = None + view_815 = torch.ops.aten.view.default(mm_78, [2, 8192, 128]); mm_78 = None + convert_element_type_373 = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16) + all_gather_into_tensor_126 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_373, 8, '0'); convert_element_type_373 = None + wait_tensor_149 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_126); all_gather_into_tensor_126 = None + permute_123 = torch.ops.aten.permute.default(wait_tensor_149, [1, 0]); wait_tensor_149 = None + mm_79 = torch.ops.aten.mm.default(view_807, permute_123); view_807 = permute_123 = None + view_822 = torch.ops.aten.view.default(mm_79, [2, 8192, 128]) + view_824 = torch.ops.aten.view.default(view_808, [2, 8192, -1, 128]); view_808 = None + view_825 = torch.ops.aten.view.default(view_815, [2, 8192, -1, 128]); view_815 = None + view_826 = torch.ops.aten.view.default(view_822, [2, 8192, -1, 128]); view_822 = None + convert_element_type_376 = torch.ops.prims.convert_element_type.default(view_824, torch.float32); view_824 = None + view_827 = torch.ops.aten.view.default(convert_element_type_376, [2, 8192, 4, -1, 2]); convert_element_type_376 = None + view_as_complex_22 = torch.ops.aten.view_as_complex.default(view_827); view_827 = None + convert_element_type_377 = torch.ops.prims.convert_element_type.default(view_825, torch.float32); view_825 = None + view_828 = torch.ops.aten.view.default(convert_element_type_377, [2, 8192, 1, -1, 2]); convert_element_type_377 = None + view_as_complex_23 = torch.ops.aten.view_as_complex.default(view_828); view_828 = None + mul_90 = torch.ops.aten.mul.Tensor(view_as_complex_22, view_37); view_as_complex_22 = None + view_as_real_22 = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None + view_830 = torch.ops.aten.view.default(view_as_real_22, [2, 8192, 4, 128]); view_as_real_22 = None + mul_91 = torch.ops.aten.mul.Tensor(view_as_complex_23, view_37); view_as_complex_23 = None + view_as_real_23 = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None + view_831 = torch.ops.aten.view.default(view_as_real_23, [2, 8192, 1, 128]); view_as_real_23 = None + convert_element_type_378 = torch.ops.prims.convert_element_type.default(view_830, torch.bfloat16); view_830 = None + convert_element_type_379 = torch.ops.prims.convert_element_type.default(view_831, torch.bfloat16); view_831 = None + unsqueeze_22 = torch.ops.aten.unsqueeze.default(convert_element_type_379, 3); convert_element_type_379 = None + expand_22 = torch.ops.aten.expand.default(unsqueeze_22, [2, 8192, 1, 4, 128]); unsqueeze_22 = None + view_832 = torch.ops.aten.view.default(expand_22, [2, 8192, 4, 128]); expand_22 = None + unsqueeze_23 = torch.ops.aten.unsqueeze.default(view_826, 3); view_826 = None + expand_23 = torch.ops.aten.expand.default(unsqueeze_23, [2, 8192, 1, 4, 128]); unsqueeze_23 = None + view_833 = torch.ops.aten.view.default(expand_23, [2, 8192, 4, 128]); expand_23 = None + permute_124 = torch.ops.aten.permute.default(convert_element_type_378, [0, 2, 1, 3]); convert_element_type_378 = None + permute_125 = torch.ops.aten.permute.default(view_832, [0, 2, 1, 3]); view_832 = None + permute_126 = torch.ops.aten.permute.default(view_833, [0, 2, 1, 3]); view_833 = None + _scaled_dot_product_cudnn_attention_11 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_124, permute_125, permute_126, None, True, 0.0, True); permute_124 = permute_125 = permute_126 = None + getitem_531 = _scaled_dot_product_cudnn_attention_11[0] + getitem_532 = _scaled_dot_product_cudnn_attention_11[1] + getitem_537 = _scaled_dot_product_cudnn_attention_11[6] + getitem_538 = _scaled_dot_product_cudnn_attention_11[7]; _scaled_dot_product_cudnn_attention_11 = None + permute_127 = torch.ops.aten.permute.default(getitem_531, [0, 2, 1, 3]) + view_834 = torch.ops.aten.view.default(permute_127, [2, 8192, -1]); permute_127 = None + convert_element_type_380 = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16) + all_gather_into_tensor_127 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_380, 8, '0'); convert_element_type_380 = None + wait_tensor_150 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_127); all_gather_into_tensor_127 = None + permute_128 = torch.ops.aten.permute.default(wait_tensor_150, [1, 0]); wait_tensor_150 = None + view_840 = torch.ops.aten.view.default(view_834, [16384, 512]); view_834 = None + mm_80 = torch.ops.aten.mm.default(view_840, permute_128); view_840 = permute_128 = None + view_841 = torch.ops.aten.view.default(mm_80, [2, 8192, 4096]); mm_80 = None + split_54 = torch.ops.aten.split.Tensor(view_841, 1024, 1); view_841 = None + getitem_540 = split_54[0] + getitem_541 = split_54[1] + getitem_542 = split_54[2] + getitem_543 = split_54[3] + getitem_544 = split_54[4] + getitem_545 = split_54[5] + getitem_546 = split_54[6] + getitem_547 = split_54[7]; split_54 = None + cat_46 = torch.ops.aten.cat.default([getitem_540, getitem_541, getitem_542, getitem_543, getitem_544, getitem_545, getitem_546, getitem_547]); getitem_540 = getitem_541 = getitem_542 = getitem_543 = getitem_544 = getitem_545 = getitem_546 = getitem_547 = None + reduce_scatter_tensor_23 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_46, 'sum', 8, '1'); cat_46 = None + wait_tensor_151 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_23) + add_45 = torch.ops.aten.add.Tensor(add_43, wait_tensor_151); wait_tensor_151 = None + convert_element_type_383 = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16) + all_gather_into_tensor_128 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_383, 8, '0'); convert_element_type_383 = None + wait_tensor_152 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_128); all_gather_into_tensor_128 = None + convert_element_type_384 = torch.ops.prims.convert_element_type.default(add_45, torch.float32) + pow_24 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_384, 2) + mean_23 = torch.ops.aten.mean.dim(pow_24, [2], True); pow_24 = None + add_46 = torch.ops.aten.add.Scalar(mean_23, 1e-05); mean_23 = None + rsqrt_23 = torch.ops.aten.rsqrt.default(add_46); add_46 = None + mul_92 = torch.ops.aten.mul.Tensor(convert_element_type_384, rsqrt_23); convert_element_type_384 = rsqrt_23 = None + mul_93 = torch.ops.aten.mul.Tensor(mul_92, wait_tensor_152); mul_92 = wait_tensor_152 = None + convert_element_type_385 = torch.ops.prims.convert_element_type.default(mul_93, torch.bfloat16); mul_93 = None + all_gather_into_tensor_129 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_385, 8, '1'); convert_element_type_385 = None + wait_tensor_153 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_129); all_gather_into_tensor_129 = None + split_55 = torch.ops.aten.split.Tensor(wait_tensor_153, 2); wait_tensor_153 = None + getitem_548 = split_55[0] + getitem_549 = split_55[1] + getitem_550 = split_55[2] + getitem_551 = split_55[3] + getitem_552 = split_55[4] + getitem_553 = split_55[5] + getitem_554 = split_55[6] + getitem_555 = split_55[7]; split_55 = None + cat_47 = torch.ops.aten.cat.default([getitem_548, getitem_549, getitem_550, getitem_551, getitem_552, getitem_553, getitem_554, getitem_555], 1); getitem_548 = getitem_549 = getitem_550 = getitem_551 = getitem_552 = getitem_553 = getitem_554 = getitem_555 = None + convert_element_type_386 = torch.ops.prims.convert_element_type.default(primals_109, torch.bfloat16) + all_gather_into_tensor_130 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_386, 8, '0'); convert_element_type_386 = None + wait_tensor_154 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_130); all_gather_into_tensor_130 = None + permute_129 = torch.ops.aten.permute.default(wait_tensor_154, [1, 0]); wait_tensor_154 = None + view_852 = torch.ops.aten.view.default(cat_47, [16384, 4096]); cat_47 = None + mm_81 = torch.ops.aten.mm.default(view_852, permute_129); permute_129 = None + view_853 = torch.ops.aten.view.default(mm_81, [2, 8192, 1792]) + convert_element_type_389 = torch.ops.prims.convert_element_type.default(view_853, torch.float32); view_853 = None + sigmoid_11 = torch.ops.aten.sigmoid.default(convert_element_type_389) + mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_389, sigmoid_11); convert_element_type_389 = sigmoid_11 = None + convert_element_type_390 = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None + convert_element_type_391 = torch.ops.prims.convert_element_type.default(primals_110, torch.bfloat16) + all_gather_into_tensor_131 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_391, 8, '0'); convert_element_type_391 = None + wait_tensor_155 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_131); all_gather_into_tensor_131 = None + permute_130 = torch.ops.aten.permute.default(wait_tensor_155, [1, 0]); wait_tensor_155 = None + mm_82 = torch.ops.aten.mm.default(view_852, permute_130); view_852 = permute_130 = None + view_860 = torch.ops.aten.view.default(mm_82, [2, 8192, 1792]); mm_82 = None + mul_95 = torch.ops.aten.mul.Tensor(convert_element_type_390, view_860); convert_element_type_390 = view_860 = None + convert_element_type_394 = torch.ops.prims.convert_element_type.default(primals_111, torch.bfloat16) + all_gather_into_tensor_132 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_394, 8, '0'); convert_element_type_394 = None + wait_tensor_156 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_132); all_gather_into_tensor_132 = None + permute_131 = torch.ops.aten.permute.default(wait_tensor_156, [1, 0]); wait_tensor_156 = None + view_867 = torch.ops.aten.view.default(mul_95, [16384, 1792]); mul_95 = None + mm_83 = torch.ops.aten.mm.default(view_867, permute_131); view_867 = permute_131 = None + view_868 = torch.ops.aten.view.default(mm_83, [2, 8192, 4096]); mm_83 = None + split_56 = torch.ops.aten.split.Tensor(view_868, 1024, 1); view_868 = None + getitem_556 = split_56[0] + getitem_557 = split_56[1] + getitem_558 = split_56[2] + getitem_559 = split_56[3] + getitem_560 = split_56[4] + getitem_561 = split_56[5] + getitem_562 = split_56[6] + getitem_563 = split_56[7]; split_56 = None + cat_48 = torch.ops.aten.cat.default([getitem_556, getitem_557, getitem_558, getitem_559, getitem_560, getitem_561, getitem_562, getitem_563]); getitem_556 = getitem_557 = getitem_558 = getitem_559 = getitem_560 = getitem_561 = getitem_562 = getitem_563 = None + reduce_scatter_tensor_24 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_48, 'sum', 8, '1'); cat_48 = None + wait_tensor_157 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_24); reduce_scatter_tensor_24 = None + add_47 = torch.ops.aten.add.Tensor(add_45, wait_tensor_157); add_45 = wait_tensor_157 = None + convert_element_type_397 = torch.ops.prims.convert_element_type.default(primals_112, torch.bfloat16) + all_gather_into_tensor_133 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_397, 8, '0'); convert_element_type_397 = None + wait_tensor_158 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_133); all_gather_into_tensor_133 = None + convert_element_type_398 = torch.ops.prims.convert_element_type.default(add_47, torch.float32) + pow_25 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_398, 2) + mean_24 = torch.ops.aten.mean.dim(pow_25, [2], True); pow_25 = None + add_48 = torch.ops.aten.add.Scalar(mean_24, 1e-05); mean_24 = None + rsqrt_24 = torch.ops.aten.rsqrt.default(add_48); add_48 = None + mul_96 = torch.ops.aten.mul.Tensor(convert_element_type_398, rsqrt_24); convert_element_type_398 = rsqrt_24 = None + mul_97 = torch.ops.aten.mul.Tensor(mul_96, wait_tensor_158); mul_96 = wait_tensor_158 = None + convert_element_type_399 = torch.ops.prims.convert_element_type.default(mul_97, torch.bfloat16); mul_97 = None + all_gather_into_tensor_134 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_399, 8, '1'); convert_element_type_399 = None + wait_tensor_159 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_134); all_gather_into_tensor_134 = None + split_57 = torch.ops.aten.split.Tensor(wait_tensor_159, 2); wait_tensor_159 = None + getitem_564 = split_57[0] + getitem_565 = split_57[1] + getitem_566 = split_57[2] + getitem_567 = split_57[3] + getitem_568 = split_57[4] + getitem_569 = split_57[5] + getitem_570 = split_57[6] + getitem_571 = split_57[7]; split_57 = None + cat_49 = torch.ops.aten.cat.default([getitem_564, getitem_565, getitem_566, getitem_567, getitem_568, getitem_569, getitem_570, getitem_571], 1); getitem_564 = getitem_565 = getitem_566 = getitem_567 = getitem_568 = getitem_569 = getitem_570 = getitem_571 = None + convert_element_type_400 = torch.ops.prims.convert_element_type.default(primals_113, torch.bfloat16) + all_gather_into_tensor_135 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_400, 8, '0'); convert_element_type_400 = None + wait_tensor_160 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_135); all_gather_into_tensor_135 = None + permute_132 = torch.ops.aten.permute.default(wait_tensor_160, [1, 0]); wait_tensor_160 = None + view_879 = torch.ops.aten.view.default(cat_49, [16384, 4096]); cat_49 = None + mm_84 = torch.ops.aten.mm.default(view_879, permute_132); permute_132 = None + view_880 = torch.ops.aten.view.default(mm_84, [2, 8192, 512]) + convert_element_type_403 = torch.ops.prims.convert_element_type.default(primals_114, torch.bfloat16) + all_gather_into_tensor_136 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_403, 8, '0'); convert_element_type_403 = None + wait_tensor_161 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_136); all_gather_into_tensor_136 = None + permute_133 = torch.ops.aten.permute.default(wait_tensor_161, [1, 0]); wait_tensor_161 = None + mm_85 = torch.ops.aten.mm.default(view_879, permute_133); permute_133 = None + view_887 = torch.ops.aten.view.default(mm_85, [2, 8192, 128]); mm_85 = None + convert_element_type_406 = torch.ops.prims.convert_element_type.default(primals_115, torch.bfloat16) + all_gather_into_tensor_137 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_406, 8, '0'); convert_element_type_406 = None + wait_tensor_162 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_137); all_gather_into_tensor_137 = None + permute_134 = torch.ops.aten.permute.default(wait_tensor_162, [1, 0]); wait_tensor_162 = None + mm_86 = torch.ops.aten.mm.default(view_879, permute_134); view_879 = permute_134 = None + view_894 = torch.ops.aten.view.default(mm_86, [2, 8192, 128]) + view_896 = torch.ops.aten.view.default(view_880, [2, 8192, -1, 128]); view_880 = None + view_897 = torch.ops.aten.view.default(view_887, [2, 8192, -1, 128]); view_887 = None + view_898 = torch.ops.aten.view.default(view_894, [2, 8192, -1, 128]); view_894 = None + convert_element_type_409 = torch.ops.prims.convert_element_type.default(view_896, torch.float32); view_896 = None + view_899 = torch.ops.aten.view.default(convert_element_type_409, [2, 8192, 4, -1, 2]); convert_element_type_409 = None + view_as_complex_24 = torch.ops.aten.view_as_complex.default(view_899); view_899 = None + convert_element_type_410 = torch.ops.prims.convert_element_type.default(view_897, torch.float32); view_897 = None + view_900 = torch.ops.aten.view.default(convert_element_type_410, [2, 8192, 1, -1, 2]); convert_element_type_410 = None + view_as_complex_25 = torch.ops.aten.view_as_complex.default(view_900); view_900 = None + mul_98 = torch.ops.aten.mul.Tensor(view_as_complex_24, view_37); view_as_complex_24 = None + view_as_real_24 = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None + view_902 = torch.ops.aten.view.default(view_as_real_24, [2, 8192, 4, 128]); view_as_real_24 = None + mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_25, view_37); view_as_complex_25 = None + view_as_real_25 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None + view_903 = torch.ops.aten.view.default(view_as_real_25, [2, 8192, 1, 128]); view_as_real_25 = None + convert_element_type_411 = torch.ops.prims.convert_element_type.default(view_902, torch.bfloat16); view_902 = None + convert_element_type_412 = torch.ops.prims.convert_element_type.default(view_903, torch.bfloat16); view_903 = None + unsqueeze_24 = torch.ops.aten.unsqueeze.default(convert_element_type_412, 3); convert_element_type_412 = None + expand_24 = torch.ops.aten.expand.default(unsqueeze_24, [2, 8192, 1, 4, 128]); unsqueeze_24 = None + view_904 = torch.ops.aten.view.default(expand_24, [2, 8192, 4, 128]); expand_24 = None + unsqueeze_25 = torch.ops.aten.unsqueeze.default(view_898, 3); view_898 = None + expand_25 = torch.ops.aten.expand.default(unsqueeze_25, [2, 8192, 1, 4, 128]); unsqueeze_25 = None + view_905 = torch.ops.aten.view.default(expand_25, [2, 8192, 4, 128]); expand_25 = None + permute_135 = torch.ops.aten.permute.default(convert_element_type_411, [0, 2, 1, 3]); convert_element_type_411 = None + permute_136 = torch.ops.aten.permute.default(view_904, [0, 2, 1, 3]); view_904 = None + permute_137 = torch.ops.aten.permute.default(view_905, [0, 2, 1, 3]); view_905 = None + _scaled_dot_product_cudnn_attention_12 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_135, permute_136, permute_137, None, True, 0.0, True); permute_135 = permute_136 = permute_137 = None + getitem_572 = _scaled_dot_product_cudnn_attention_12[0] + getitem_573 = _scaled_dot_product_cudnn_attention_12[1] + getitem_578 = _scaled_dot_product_cudnn_attention_12[6] + getitem_579 = _scaled_dot_product_cudnn_attention_12[7]; _scaled_dot_product_cudnn_attention_12 = None + permute_138 = torch.ops.aten.permute.default(getitem_572, [0, 2, 1, 3]) + view_906 = torch.ops.aten.view.default(permute_138, [2, 8192, -1]); permute_138 = None + convert_element_type_413 = torch.ops.prims.convert_element_type.default(primals_116, torch.bfloat16) + all_gather_into_tensor_138 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_413, 8, '0'); convert_element_type_413 = None + wait_tensor_163 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_138); all_gather_into_tensor_138 = None + permute_139 = torch.ops.aten.permute.default(wait_tensor_163, [1, 0]); wait_tensor_163 = None + view_912 = torch.ops.aten.view.default(view_906, [16384, 512]); view_906 = None + mm_87 = torch.ops.aten.mm.default(view_912, permute_139); view_912 = permute_139 = None + view_913 = torch.ops.aten.view.default(mm_87, [2, 8192, 4096]); mm_87 = None + split_58 = torch.ops.aten.split.Tensor(view_913, 1024, 1); view_913 = None + getitem_581 = split_58[0] + getitem_582 = split_58[1] + getitem_583 = split_58[2] + getitem_584 = split_58[3] + getitem_585 = split_58[4] + getitem_586 = split_58[5] + getitem_587 = split_58[6] + getitem_588 = split_58[7]; split_58 = None + cat_50 = torch.ops.aten.cat.default([getitem_581, getitem_582, getitem_583, getitem_584, getitem_585, getitem_586, getitem_587, getitem_588]); getitem_581 = getitem_582 = getitem_583 = getitem_584 = getitem_585 = getitem_586 = getitem_587 = getitem_588 = None + reduce_scatter_tensor_25 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_50, 'sum', 8, '1'); cat_50 = None + wait_tensor_164 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_25) + add_49 = torch.ops.aten.add.Tensor(add_47, wait_tensor_164); wait_tensor_164 = None + convert_element_type_416 = torch.ops.prims.convert_element_type.default(primals_117, torch.bfloat16) + all_gather_into_tensor_139 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_416, 8, '0'); convert_element_type_416 = None + wait_tensor_165 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_139); all_gather_into_tensor_139 = None + convert_element_type_417 = torch.ops.prims.convert_element_type.default(add_49, torch.float32) + pow_26 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_417, 2) + mean_25 = torch.ops.aten.mean.dim(pow_26, [2], True); pow_26 = None + add_50 = torch.ops.aten.add.Scalar(mean_25, 1e-05); mean_25 = None + rsqrt_25 = torch.ops.aten.rsqrt.default(add_50); add_50 = None + mul_100 = torch.ops.aten.mul.Tensor(convert_element_type_417, rsqrt_25); convert_element_type_417 = rsqrt_25 = None + mul_101 = torch.ops.aten.mul.Tensor(mul_100, wait_tensor_165); mul_100 = wait_tensor_165 = None + convert_element_type_418 = torch.ops.prims.convert_element_type.default(mul_101, torch.bfloat16); mul_101 = None + all_gather_into_tensor_140 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_418, 8, '1'); convert_element_type_418 = None + wait_tensor_166 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_140); all_gather_into_tensor_140 = None + split_59 = torch.ops.aten.split.Tensor(wait_tensor_166, 2); wait_tensor_166 = None + getitem_589 = split_59[0] + getitem_590 = split_59[1] + getitem_591 = split_59[2] + getitem_592 = split_59[3] + getitem_593 = split_59[4] + getitem_594 = split_59[5] + getitem_595 = split_59[6] + getitem_596 = split_59[7]; split_59 = None + cat_51 = torch.ops.aten.cat.default([getitem_589, getitem_590, getitem_591, getitem_592, getitem_593, getitem_594, getitem_595, getitem_596], 1); getitem_589 = getitem_590 = getitem_591 = getitem_592 = getitem_593 = getitem_594 = getitem_595 = getitem_596 = None + convert_element_type_419 = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16) + all_gather_into_tensor_141 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_419, 8, '0'); convert_element_type_419 = None + wait_tensor_167 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_141); all_gather_into_tensor_141 = None + permute_140 = torch.ops.aten.permute.default(wait_tensor_167, [1, 0]); wait_tensor_167 = None + view_924 = torch.ops.aten.view.default(cat_51, [16384, 4096]); cat_51 = None + mm_88 = torch.ops.aten.mm.default(view_924, permute_140); permute_140 = None + view_925 = torch.ops.aten.view.default(mm_88, [2, 8192, 1792]) + convert_element_type_422 = torch.ops.prims.convert_element_type.default(view_925, torch.float32); view_925 = None + sigmoid_12 = torch.ops.aten.sigmoid.default(convert_element_type_422) + mul_102 = torch.ops.aten.mul.Tensor(convert_element_type_422, sigmoid_12); convert_element_type_422 = sigmoid_12 = None + convert_element_type_423 = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None + convert_element_type_424 = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16) + all_gather_into_tensor_142 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_424, 8, '0'); convert_element_type_424 = None + wait_tensor_168 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_142); all_gather_into_tensor_142 = None + permute_141 = torch.ops.aten.permute.default(wait_tensor_168, [1, 0]); wait_tensor_168 = None + mm_89 = torch.ops.aten.mm.default(view_924, permute_141); view_924 = permute_141 = None + view_932 = torch.ops.aten.view.default(mm_89, [2, 8192, 1792]); mm_89 = None + mul_103 = torch.ops.aten.mul.Tensor(convert_element_type_423, view_932); convert_element_type_423 = view_932 = None + convert_element_type_427 = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16) + all_gather_into_tensor_143 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_427, 8, '0'); convert_element_type_427 = None + wait_tensor_169 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_143); all_gather_into_tensor_143 = None + permute_142 = torch.ops.aten.permute.default(wait_tensor_169, [1, 0]); wait_tensor_169 = None + view_939 = torch.ops.aten.view.default(mul_103, [16384, 1792]); mul_103 = None + mm_90 = torch.ops.aten.mm.default(view_939, permute_142); view_939 = permute_142 = None + view_940 = torch.ops.aten.view.default(mm_90, [2, 8192, 4096]); mm_90 = None + split_60 = torch.ops.aten.split.Tensor(view_940, 1024, 1); view_940 = None + getitem_597 = split_60[0] + getitem_598 = split_60[1] + getitem_599 = split_60[2] + getitem_600 = split_60[3] + getitem_601 = split_60[4] + getitem_602 = split_60[5] + getitem_603 = split_60[6] + getitem_604 = split_60[7]; split_60 = None + cat_52 = torch.ops.aten.cat.default([getitem_597, getitem_598, getitem_599, getitem_600, getitem_601, getitem_602, getitem_603, getitem_604]); getitem_597 = getitem_598 = getitem_599 = getitem_600 = getitem_601 = getitem_602 = getitem_603 = getitem_604 = None + reduce_scatter_tensor_26 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_52, 'sum', 8, '1'); cat_52 = None + wait_tensor_170 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_26); reduce_scatter_tensor_26 = None + add_51 = torch.ops.aten.add.Tensor(add_49, wait_tensor_170); add_49 = wait_tensor_170 = None + convert_element_type_430 = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16) + all_gather_into_tensor_144 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_430, 8, '0'); convert_element_type_430 = None + wait_tensor_171 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_144); all_gather_into_tensor_144 = None + convert_element_type_431 = torch.ops.prims.convert_element_type.default(add_51, torch.float32) + pow_27 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_431, 2) + mean_26 = torch.ops.aten.mean.dim(pow_27, [2], True); pow_27 = None + add_52 = torch.ops.aten.add.Scalar(mean_26, 1e-05); mean_26 = None + rsqrt_26 = torch.ops.aten.rsqrt.default(add_52); add_52 = None + mul_104 = torch.ops.aten.mul.Tensor(convert_element_type_431, rsqrt_26); convert_element_type_431 = rsqrt_26 = None + mul_105 = torch.ops.aten.mul.Tensor(mul_104, wait_tensor_171); mul_104 = wait_tensor_171 = None + convert_element_type_432 = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16); mul_105 = None + all_gather_into_tensor_145 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_432, 8, '1'); convert_element_type_432 = None + wait_tensor_172 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_145); all_gather_into_tensor_145 = None + split_61 = torch.ops.aten.split.Tensor(wait_tensor_172, 2); wait_tensor_172 = None + getitem_605 = split_61[0] + getitem_606 = split_61[1] + getitem_607 = split_61[2] + getitem_608 = split_61[3] + getitem_609 = split_61[4] + getitem_610 = split_61[5] + getitem_611 = split_61[6] + getitem_612 = split_61[7]; split_61 = None + cat_53 = torch.ops.aten.cat.default([getitem_605, getitem_606, getitem_607, getitem_608, getitem_609, getitem_610, getitem_611, getitem_612], 1); getitem_605 = getitem_606 = getitem_607 = getitem_608 = getitem_609 = getitem_610 = getitem_611 = getitem_612 = None + convert_element_type_433 = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16) + all_gather_into_tensor_146 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_433, 8, '0'); convert_element_type_433 = None + wait_tensor_173 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_146); all_gather_into_tensor_146 = None + permute_143 = torch.ops.aten.permute.default(wait_tensor_173, [1, 0]); wait_tensor_173 = None + view_951 = torch.ops.aten.view.default(cat_53, [16384, 4096]); cat_53 = None + mm_91 = torch.ops.aten.mm.default(view_951, permute_143); permute_143 = None + view_952 = torch.ops.aten.view.default(mm_91, [2, 8192, 512]) + convert_element_type_436 = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16) + all_gather_into_tensor_147 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_436, 8, '0'); convert_element_type_436 = None + wait_tensor_174 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_147); all_gather_into_tensor_147 = None + permute_144 = torch.ops.aten.permute.default(wait_tensor_174, [1, 0]); wait_tensor_174 = None + mm_92 = torch.ops.aten.mm.default(view_951, permute_144); permute_144 = None + view_959 = torch.ops.aten.view.default(mm_92, [2, 8192, 128]); mm_92 = None + convert_element_type_439 = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16) + all_gather_into_tensor_148 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_439, 8, '0'); convert_element_type_439 = None + wait_tensor_175 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_148); all_gather_into_tensor_148 = None + permute_145 = torch.ops.aten.permute.default(wait_tensor_175, [1, 0]); wait_tensor_175 = None + mm_93 = torch.ops.aten.mm.default(view_951, permute_145); view_951 = permute_145 = None + view_966 = torch.ops.aten.view.default(mm_93, [2, 8192, 128]) + view_968 = torch.ops.aten.view.default(view_952, [2, 8192, -1, 128]); view_952 = None + view_969 = torch.ops.aten.view.default(view_959, [2, 8192, -1, 128]); view_959 = None + view_970 = torch.ops.aten.view.default(view_966, [2, 8192, -1, 128]); view_966 = None + convert_element_type_442 = torch.ops.prims.convert_element_type.default(view_968, torch.float32); view_968 = None + view_971 = torch.ops.aten.view.default(convert_element_type_442, [2, 8192, 4, -1, 2]); convert_element_type_442 = None + view_as_complex_26 = torch.ops.aten.view_as_complex.default(view_971); view_971 = None + convert_element_type_443 = torch.ops.prims.convert_element_type.default(view_969, torch.float32); view_969 = None + view_972 = torch.ops.aten.view.default(convert_element_type_443, [2, 8192, 1, -1, 2]); convert_element_type_443 = None + view_as_complex_27 = torch.ops.aten.view_as_complex.default(view_972); view_972 = None + mul_106 = torch.ops.aten.mul.Tensor(view_as_complex_26, view_37); view_as_complex_26 = None + view_as_real_26 = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None + view_974 = torch.ops.aten.view.default(view_as_real_26, [2, 8192, 4, 128]); view_as_real_26 = None + mul_107 = torch.ops.aten.mul.Tensor(view_as_complex_27, view_37); view_as_complex_27 = None + view_as_real_27 = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None + view_975 = torch.ops.aten.view.default(view_as_real_27, [2, 8192, 1, 128]); view_as_real_27 = None + convert_element_type_444 = torch.ops.prims.convert_element_type.default(view_974, torch.bfloat16); view_974 = None + convert_element_type_445 = torch.ops.prims.convert_element_type.default(view_975, torch.bfloat16); view_975 = None + unsqueeze_26 = torch.ops.aten.unsqueeze.default(convert_element_type_445, 3); convert_element_type_445 = None + expand_26 = torch.ops.aten.expand.default(unsqueeze_26, [2, 8192, 1, 4, 128]); unsqueeze_26 = None + view_976 = torch.ops.aten.view.default(expand_26, [2, 8192, 4, 128]); expand_26 = None + unsqueeze_27 = torch.ops.aten.unsqueeze.default(view_970, 3); view_970 = None + expand_27 = torch.ops.aten.expand.default(unsqueeze_27, [2, 8192, 1, 4, 128]); unsqueeze_27 = None + view_977 = torch.ops.aten.view.default(expand_27, [2, 8192, 4, 128]); expand_27 = None + permute_146 = torch.ops.aten.permute.default(convert_element_type_444, [0, 2, 1, 3]); convert_element_type_444 = None + permute_147 = torch.ops.aten.permute.default(view_976, [0, 2, 1, 3]); view_976 = None + permute_148 = torch.ops.aten.permute.default(view_977, [0, 2, 1, 3]); view_977 = None + _scaled_dot_product_cudnn_attention_13 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_146, permute_147, permute_148, None, True, 0.0, True); permute_146 = permute_147 = permute_148 = None + getitem_613 = _scaled_dot_product_cudnn_attention_13[0] + getitem_614 = _scaled_dot_product_cudnn_attention_13[1] + getitem_619 = _scaled_dot_product_cudnn_attention_13[6] + getitem_620 = _scaled_dot_product_cudnn_attention_13[7]; _scaled_dot_product_cudnn_attention_13 = None + permute_149 = torch.ops.aten.permute.default(getitem_613, [0, 2, 1, 3]) + view_978 = torch.ops.aten.view.default(permute_149, [2, 8192, -1]); permute_149 = None + convert_element_type_446 = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16) + all_gather_into_tensor_149 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_446, 8, '0'); convert_element_type_446 = None + wait_tensor_176 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_149); all_gather_into_tensor_149 = None + permute_150 = torch.ops.aten.permute.default(wait_tensor_176, [1, 0]); wait_tensor_176 = None + view_984 = torch.ops.aten.view.default(view_978, [16384, 512]); view_978 = None + mm_94 = torch.ops.aten.mm.default(view_984, permute_150); view_984 = permute_150 = None + view_985 = torch.ops.aten.view.default(mm_94, [2, 8192, 4096]); mm_94 = None + split_62 = torch.ops.aten.split.Tensor(view_985, 1024, 1); view_985 = None + getitem_622 = split_62[0] + getitem_623 = split_62[1] + getitem_624 = split_62[2] + getitem_625 = split_62[3] + getitem_626 = split_62[4] + getitem_627 = split_62[5] + getitem_628 = split_62[6] + getitem_629 = split_62[7]; split_62 = None + cat_54 = torch.ops.aten.cat.default([getitem_622, getitem_623, getitem_624, getitem_625, getitem_626, getitem_627, getitem_628, getitem_629]); getitem_622 = getitem_623 = getitem_624 = getitem_625 = getitem_626 = getitem_627 = getitem_628 = getitem_629 = None + reduce_scatter_tensor_27 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_54, 'sum', 8, '1'); cat_54 = None + wait_tensor_177 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_27) + add_53 = torch.ops.aten.add.Tensor(add_51, wait_tensor_177); wait_tensor_177 = None + convert_element_type_449 = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16) + all_gather_into_tensor_150 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_449, 8, '0'); convert_element_type_449 = None + wait_tensor_178 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_150); all_gather_into_tensor_150 = None + convert_element_type_450 = torch.ops.prims.convert_element_type.default(add_53, torch.float32) + pow_28 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_450, 2) + mean_27 = torch.ops.aten.mean.dim(pow_28, [2], True); pow_28 = None + add_54 = torch.ops.aten.add.Scalar(mean_27, 1e-05); mean_27 = None + rsqrt_27 = torch.ops.aten.rsqrt.default(add_54); add_54 = None + mul_108 = torch.ops.aten.mul.Tensor(convert_element_type_450, rsqrt_27); convert_element_type_450 = rsqrt_27 = None + mul_109 = torch.ops.aten.mul.Tensor(mul_108, wait_tensor_178); mul_108 = wait_tensor_178 = None + convert_element_type_451 = torch.ops.prims.convert_element_type.default(mul_109, torch.bfloat16); mul_109 = None + all_gather_into_tensor_151 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_451, 8, '1'); convert_element_type_451 = None + wait_tensor_179 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_151); all_gather_into_tensor_151 = None + split_63 = torch.ops.aten.split.Tensor(wait_tensor_179, 2); wait_tensor_179 = None + getitem_630 = split_63[0] + getitem_631 = split_63[1] + getitem_632 = split_63[2] + getitem_633 = split_63[3] + getitem_634 = split_63[4] + getitem_635 = split_63[5] + getitem_636 = split_63[6] + getitem_637 = split_63[7]; split_63 = None + cat_55 = torch.ops.aten.cat.default([getitem_630, getitem_631, getitem_632, getitem_633, getitem_634, getitem_635, getitem_636, getitem_637], 1); getitem_630 = getitem_631 = getitem_632 = getitem_633 = getitem_634 = getitem_635 = getitem_636 = getitem_637 = None + convert_element_type_452 = torch.ops.prims.convert_element_type.default(primals_127, torch.bfloat16) + all_gather_into_tensor_152 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_452, 8, '0'); convert_element_type_452 = None + wait_tensor_180 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_152); all_gather_into_tensor_152 = None + permute_151 = torch.ops.aten.permute.default(wait_tensor_180, [1, 0]); wait_tensor_180 = None + view_996 = torch.ops.aten.view.default(cat_55, [16384, 4096]); cat_55 = None + mm_95 = torch.ops.aten.mm.default(view_996, permute_151); permute_151 = None + view_997 = torch.ops.aten.view.default(mm_95, [2, 8192, 1792]) + convert_element_type_455 = torch.ops.prims.convert_element_type.default(view_997, torch.float32); view_997 = None + sigmoid_13 = torch.ops.aten.sigmoid.default(convert_element_type_455) + mul_110 = torch.ops.aten.mul.Tensor(convert_element_type_455, sigmoid_13); convert_element_type_455 = sigmoid_13 = None + convert_element_type_456 = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None + convert_element_type_457 = torch.ops.prims.convert_element_type.default(primals_128, torch.bfloat16) + all_gather_into_tensor_153 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_457, 8, '0'); convert_element_type_457 = None + wait_tensor_181 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_153); all_gather_into_tensor_153 = None + permute_152 = torch.ops.aten.permute.default(wait_tensor_181, [1, 0]); wait_tensor_181 = None + mm_96 = torch.ops.aten.mm.default(view_996, permute_152); view_996 = permute_152 = None + view_1004 = torch.ops.aten.view.default(mm_96, [2, 8192, 1792]); mm_96 = None + mul_111 = torch.ops.aten.mul.Tensor(convert_element_type_456, view_1004); convert_element_type_456 = view_1004 = None + convert_element_type_460 = torch.ops.prims.convert_element_type.default(primals_129, torch.bfloat16) + all_gather_into_tensor_154 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_460, 8, '0'); convert_element_type_460 = None + wait_tensor_182 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_154); all_gather_into_tensor_154 = None + permute_153 = torch.ops.aten.permute.default(wait_tensor_182, [1, 0]); wait_tensor_182 = None + view_1011 = torch.ops.aten.view.default(mul_111, [16384, 1792]); mul_111 = None + mm_97 = torch.ops.aten.mm.default(view_1011, permute_153); view_1011 = permute_153 = None + view_1012 = torch.ops.aten.view.default(mm_97, [2, 8192, 4096]); mm_97 = None + split_64 = torch.ops.aten.split.Tensor(view_1012, 1024, 1); view_1012 = None + getitem_638 = split_64[0] + getitem_639 = split_64[1] + getitem_640 = split_64[2] + getitem_641 = split_64[3] + getitem_642 = split_64[4] + getitem_643 = split_64[5] + getitem_644 = split_64[6] + getitem_645 = split_64[7]; split_64 = None + cat_56 = torch.ops.aten.cat.default([getitem_638, getitem_639, getitem_640, getitem_641, getitem_642, getitem_643, getitem_644, getitem_645]); getitem_638 = getitem_639 = getitem_640 = getitem_641 = getitem_642 = getitem_643 = getitem_644 = getitem_645 = None + reduce_scatter_tensor_28 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_56, 'sum', 8, '1'); cat_56 = None + wait_tensor_183 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_28); reduce_scatter_tensor_28 = None + add_55 = torch.ops.aten.add.Tensor(add_53, wait_tensor_183); add_53 = wait_tensor_183 = None + convert_element_type_463 = torch.ops.prims.convert_element_type.default(primals_130, torch.bfloat16) + all_gather_into_tensor_155 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_463, 8, '0'); convert_element_type_463 = None + wait_tensor_184 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_155); all_gather_into_tensor_155 = None + convert_element_type_464 = torch.ops.prims.convert_element_type.default(add_55, torch.float32) + pow_29 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_464, 2) + mean_28 = torch.ops.aten.mean.dim(pow_29, [2], True); pow_29 = None + add_56 = torch.ops.aten.add.Scalar(mean_28, 1e-05); mean_28 = None + rsqrt_28 = torch.ops.aten.rsqrt.default(add_56); add_56 = None + mul_112 = torch.ops.aten.mul.Tensor(convert_element_type_464, rsqrt_28); convert_element_type_464 = rsqrt_28 = None + mul_113 = torch.ops.aten.mul.Tensor(mul_112, wait_tensor_184); mul_112 = wait_tensor_184 = None + convert_element_type_465 = torch.ops.prims.convert_element_type.default(mul_113, torch.bfloat16); mul_113 = None + all_gather_into_tensor_156 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_465, 8, '1'); convert_element_type_465 = None + wait_tensor_185 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_156); all_gather_into_tensor_156 = None + split_65 = torch.ops.aten.split.Tensor(wait_tensor_185, 2); wait_tensor_185 = None + getitem_646 = split_65[0] + getitem_647 = split_65[1] + getitem_648 = split_65[2] + getitem_649 = split_65[3] + getitem_650 = split_65[4] + getitem_651 = split_65[5] + getitem_652 = split_65[6] + getitem_653 = split_65[7]; split_65 = None + cat_57 = torch.ops.aten.cat.default([getitem_646, getitem_647, getitem_648, getitem_649, getitem_650, getitem_651, getitem_652, getitem_653], 1); getitem_646 = getitem_647 = getitem_648 = getitem_649 = getitem_650 = getitem_651 = getitem_652 = getitem_653 = None + convert_element_type_466 = torch.ops.prims.convert_element_type.default(primals_131, torch.bfloat16) + all_gather_into_tensor_157 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_466, 8, '0'); convert_element_type_466 = None + wait_tensor_186 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_157); all_gather_into_tensor_157 = None + permute_154 = torch.ops.aten.permute.default(wait_tensor_186, [1, 0]); wait_tensor_186 = None + view_1023 = torch.ops.aten.view.default(cat_57, [16384, 4096]); cat_57 = None + mm_98 = torch.ops.aten.mm.default(view_1023, permute_154); permute_154 = None + view_1024 = torch.ops.aten.view.default(mm_98, [2, 8192, 512]) + convert_element_type_469 = torch.ops.prims.convert_element_type.default(primals_132, torch.bfloat16) + all_gather_into_tensor_158 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_469, 8, '0'); convert_element_type_469 = None + wait_tensor_187 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_158); all_gather_into_tensor_158 = None + permute_155 = torch.ops.aten.permute.default(wait_tensor_187, [1, 0]); wait_tensor_187 = None + mm_99 = torch.ops.aten.mm.default(view_1023, permute_155); permute_155 = None + view_1031 = torch.ops.aten.view.default(mm_99, [2, 8192, 128]); mm_99 = None + convert_element_type_472 = torch.ops.prims.convert_element_type.default(primals_133, torch.bfloat16) + all_gather_into_tensor_159 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_472, 8, '0'); convert_element_type_472 = None + wait_tensor_188 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_159); all_gather_into_tensor_159 = None + permute_156 = torch.ops.aten.permute.default(wait_tensor_188, [1, 0]); wait_tensor_188 = None + mm_100 = torch.ops.aten.mm.default(view_1023, permute_156); view_1023 = permute_156 = None + view_1038 = torch.ops.aten.view.default(mm_100, [2, 8192, 128]) + view_1040 = torch.ops.aten.view.default(view_1024, [2, 8192, -1, 128]); view_1024 = None + view_1041 = torch.ops.aten.view.default(view_1031, [2, 8192, -1, 128]); view_1031 = None + view_1042 = torch.ops.aten.view.default(view_1038, [2, 8192, -1, 128]); view_1038 = None + convert_element_type_475 = torch.ops.prims.convert_element_type.default(view_1040, torch.float32); view_1040 = None + view_1043 = torch.ops.aten.view.default(convert_element_type_475, [2, 8192, 4, -1, 2]); convert_element_type_475 = None + view_as_complex_28 = torch.ops.aten.view_as_complex.default(view_1043); view_1043 = None + convert_element_type_476 = torch.ops.prims.convert_element_type.default(view_1041, torch.float32); view_1041 = None + view_1044 = torch.ops.aten.view.default(convert_element_type_476, [2, 8192, 1, -1, 2]); convert_element_type_476 = None + view_as_complex_29 = torch.ops.aten.view_as_complex.default(view_1044); view_1044 = None + mul_114 = torch.ops.aten.mul.Tensor(view_as_complex_28, view_37); view_as_complex_28 = None + view_as_real_28 = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None + view_1046 = torch.ops.aten.view.default(view_as_real_28, [2, 8192, 4, 128]); view_as_real_28 = None + mul_115 = torch.ops.aten.mul.Tensor(view_as_complex_29, view_37); view_as_complex_29 = None + view_as_real_29 = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None + view_1047 = torch.ops.aten.view.default(view_as_real_29, [2, 8192, 1, 128]); view_as_real_29 = None + convert_element_type_477 = torch.ops.prims.convert_element_type.default(view_1046, torch.bfloat16); view_1046 = None + convert_element_type_478 = torch.ops.prims.convert_element_type.default(view_1047, torch.bfloat16); view_1047 = None + unsqueeze_28 = torch.ops.aten.unsqueeze.default(convert_element_type_478, 3); convert_element_type_478 = None + expand_28 = torch.ops.aten.expand.default(unsqueeze_28, [2, 8192, 1, 4, 128]); unsqueeze_28 = None + view_1048 = torch.ops.aten.view.default(expand_28, [2, 8192, 4, 128]); expand_28 = None + unsqueeze_29 = torch.ops.aten.unsqueeze.default(view_1042, 3); view_1042 = None + expand_29 = torch.ops.aten.expand.default(unsqueeze_29, [2, 8192, 1, 4, 128]); unsqueeze_29 = None + view_1049 = torch.ops.aten.view.default(expand_29, [2, 8192, 4, 128]); expand_29 = None + permute_157 = torch.ops.aten.permute.default(convert_element_type_477, [0, 2, 1, 3]); convert_element_type_477 = None + permute_158 = torch.ops.aten.permute.default(view_1048, [0, 2, 1, 3]); view_1048 = None + permute_159 = torch.ops.aten.permute.default(view_1049, [0, 2, 1, 3]); view_1049 = None + _scaled_dot_product_cudnn_attention_14 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_157, permute_158, permute_159, None, True, 0.0, True); permute_157 = permute_158 = permute_159 = None + getitem_654 = _scaled_dot_product_cudnn_attention_14[0] + getitem_655 = _scaled_dot_product_cudnn_attention_14[1] + getitem_660 = _scaled_dot_product_cudnn_attention_14[6] + getitem_661 = _scaled_dot_product_cudnn_attention_14[7]; _scaled_dot_product_cudnn_attention_14 = None + permute_160 = torch.ops.aten.permute.default(getitem_654, [0, 2, 1, 3]) + view_1050 = torch.ops.aten.view.default(permute_160, [2, 8192, -1]); permute_160 = None + convert_element_type_479 = torch.ops.prims.convert_element_type.default(primals_134, torch.bfloat16) + all_gather_into_tensor_160 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_479, 8, '0'); convert_element_type_479 = None + wait_tensor_189 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_160); all_gather_into_tensor_160 = None + permute_161 = torch.ops.aten.permute.default(wait_tensor_189, [1, 0]); wait_tensor_189 = None + view_1056 = torch.ops.aten.view.default(view_1050, [16384, 512]); view_1050 = None + mm_101 = torch.ops.aten.mm.default(view_1056, permute_161); view_1056 = permute_161 = None + view_1057 = torch.ops.aten.view.default(mm_101, [2, 8192, 4096]); mm_101 = None + split_66 = torch.ops.aten.split.Tensor(view_1057, 1024, 1); view_1057 = None + getitem_663 = split_66[0] + getitem_664 = split_66[1] + getitem_665 = split_66[2] + getitem_666 = split_66[3] + getitem_667 = split_66[4] + getitem_668 = split_66[5] + getitem_669 = split_66[6] + getitem_670 = split_66[7]; split_66 = None + cat_58 = torch.ops.aten.cat.default([getitem_663, getitem_664, getitem_665, getitem_666, getitem_667, getitem_668, getitem_669, getitem_670]); getitem_663 = getitem_664 = getitem_665 = getitem_666 = getitem_667 = getitem_668 = getitem_669 = getitem_670 = None + reduce_scatter_tensor_29 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_58, 'sum', 8, '1'); cat_58 = None + wait_tensor_190 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_29) + add_57 = torch.ops.aten.add.Tensor(add_55, wait_tensor_190); wait_tensor_190 = None + convert_element_type_482 = torch.ops.prims.convert_element_type.default(primals_135, torch.bfloat16) + all_gather_into_tensor_161 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_482, 8, '0'); convert_element_type_482 = None + wait_tensor_191 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_161); all_gather_into_tensor_161 = None + convert_element_type_483 = torch.ops.prims.convert_element_type.default(add_57, torch.float32) + pow_30 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_483, 2) + mean_29 = torch.ops.aten.mean.dim(pow_30, [2], True); pow_30 = None + add_58 = torch.ops.aten.add.Scalar(mean_29, 1e-05); mean_29 = None + rsqrt_29 = torch.ops.aten.rsqrt.default(add_58); add_58 = None + mul_116 = torch.ops.aten.mul.Tensor(convert_element_type_483, rsqrt_29); convert_element_type_483 = rsqrt_29 = None + mul_117 = torch.ops.aten.mul.Tensor(mul_116, wait_tensor_191); mul_116 = wait_tensor_191 = None + convert_element_type_484 = torch.ops.prims.convert_element_type.default(mul_117, torch.bfloat16); mul_117 = None + all_gather_into_tensor_162 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_484, 8, '1'); convert_element_type_484 = None + wait_tensor_192 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_162); all_gather_into_tensor_162 = None + split_67 = torch.ops.aten.split.Tensor(wait_tensor_192, 2); wait_tensor_192 = None + getitem_671 = split_67[0] + getitem_672 = split_67[1] + getitem_673 = split_67[2] + getitem_674 = split_67[3] + getitem_675 = split_67[4] + getitem_676 = split_67[5] + getitem_677 = split_67[6] + getitem_678 = split_67[7]; split_67 = None + cat_59 = torch.ops.aten.cat.default([getitem_671, getitem_672, getitem_673, getitem_674, getitem_675, getitem_676, getitem_677, getitem_678], 1); getitem_671 = getitem_672 = getitem_673 = getitem_674 = getitem_675 = getitem_676 = getitem_677 = getitem_678 = None + convert_element_type_485 = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16) + all_gather_into_tensor_163 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_485, 8, '0'); convert_element_type_485 = None + wait_tensor_193 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_163); all_gather_into_tensor_163 = None + permute_162 = torch.ops.aten.permute.default(wait_tensor_193, [1, 0]); wait_tensor_193 = None + view_1068 = torch.ops.aten.view.default(cat_59, [16384, 4096]); cat_59 = None + mm_102 = torch.ops.aten.mm.default(view_1068, permute_162); permute_162 = None + view_1069 = torch.ops.aten.view.default(mm_102, [2, 8192, 1792]) + convert_element_type_488 = torch.ops.prims.convert_element_type.default(view_1069, torch.float32); view_1069 = None + sigmoid_14 = torch.ops.aten.sigmoid.default(convert_element_type_488) + mul_118 = torch.ops.aten.mul.Tensor(convert_element_type_488, sigmoid_14); convert_element_type_488 = sigmoid_14 = None + convert_element_type_489 = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None + convert_element_type_490 = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16) + all_gather_into_tensor_164 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_490, 8, '0'); convert_element_type_490 = None + wait_tensor_194 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_164); all_gather_into_tensor_164 = None + permute_163 = torch.ops.aten.permute.default(wait_tensor_194, [1, 0]); wait_tensor_194 = None + mm_103 = torch.ops.aten.mm.default(view_1068, permute_163); view_1068 = permute_163 = None + view_1076 = torch.ops.aten.view.default(mm_103, [2, 8192, 1792]); mm_103 = None + mul_119 = torch.ops.aten.mul.Tensor(convert_element_type_489, view_1076); convert_element_type_489 = view_1076 = None + convert_element_type_493 = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16) + all_gather_into_tensor_165 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_493, 8, '0'); convert_element_type_493 = None + wait_tensor_195 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_165); all_gather_into_tensor_165 = None + permute_164 = torch.ops.aten.permute.default(wait_tensor_195, [1, 0]); wait_tensor_195 = None + view_1083 = torch.ops.aten.view.default(mul_119, [16384, 1792]); mul_119 = None + mm_104 = torch.ops.aten.mm.default(view_1083, permute_164); view_1083 = permute_164 = None + view_1084 = torch.ops.aten.view.default(mm_104, [2, 8192, 4096]); mm_104 = None + split_68 = torch.ops.aten.split.Tensor(view_1084, 1024, 1); view_1084 = None + getitem_679 = split_68[0] + getitem_680 = split_68[1] + getitem_681 = split_68[2] + getitem_682 = split_68[3] + getitem_683 = split_68[4] + getitem_684 = split_68[5] + getitem_685 = split_68[6] + getitem_686 = split_68[7]; split_68 = None + cat_60 = torch.ops.aten.cat.default([getitem_679, getitem_680, getitem_681, getitem_682, getitem_683, getitem_684, getitem_685, getitem_686]); getitem_679 = getitem_680 = getitem_681 = getitem_682 = getitem_683 = getitem_684 = getitem_685 = getitem_686 = None + reduce_scatter_tensor_30 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_60, 'sum', 8, '1'); cat_60 = None + wait_tensor_196 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_30); reduce_scatter_tensor_30 = None + add_59 = torch.ops.aten.add.Tensor(add_57, wait_tensor_196); add_57 = wait_tensor_196 = None + convert_element_type_496 = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16) + all_gather_into_tensor_166 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_496, 8, '0'); convert_element_type_496 = None + wait_tensor_197 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_166); all_gather_into_tensor_166 = None + convert_element_type_497 = torch.ops.prims.convert_element_type.default(add_59, torch.float32) + pow_31 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_497, 2) + mean_30 = torch.ops.aten.mean.dim(pow_31, [2], True); pow_31 = None + add_60 = torch.ops.aten.add.Scalar(mean_30, 1e-05); mean_30 = None + rsqrt_30 = torch.ops.aten.rsqrt.default(add_60); add_60 = None + mul_120 = torch.ops.aten.mul.Tensor(convert_element_type_497, rsqrt_30); convert_element_type_497 = rsqrt_30 = None + mul_121 = torch.ops.aten.mul.Tensor(mul_120, wait_tensor_197); mul_120 = wait_tensor_197 = None + convert_element_type_498 = torch.ops.prims.convert_element_type.default(mul_121, torch.bfloat16); mul_121 = None + all_gather_into_tensor_167 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_498, 8, '1'); convert_element_type_498 = None + wait_tensor_198 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_167); all_gather_into_tensor_167 = None + split_69 = torch.ops.aten.split.Tensor(wait_tensor_198, 2); wait_tensor_198 = None + getitem_687 = split_69[0] + getitem_688 = split_69[1] + getitem_689 = split_69[2] + getitem_690 = split_69[3] + getitem_691 = split_69[4] + getitem_692 = split_69[5] + getitem_693 = split_69[6] + getitem_694 = split_69[7]; split_69 = None + cat_61 = torch.ops.aten.cat.default([getitem_687, getitem_688, getitem_689, getitem_690, getitem_691, getitem_692, getitem_693, getitem_694], 1); getitem_687 = getitem_688 = getitem_689 = getitem_690 = getitem_691 = getitem_692 = getitem_693 = getitem_694 = None + convert_element_type_499 = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16) + all_gather_into_tensor_168 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_499, 8, '0'); convert_element_type_499 = None + wait_tensor_199 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_168); all_gather_into_tensor_168 = None + permute_165 = torch.ops.aten.permute.default(wait_tensor_199, [1, 0]); wait_tensor_199 = None + view_1095 = torch.ops.aten.view.default(cat_61, [16384, 4096]); cat_61 = None + mm_105 = torch.ops.aten.mm.default(view_1095, permute_165); permute_165 = None + view_1096 = torch.ops.aten.view.default(mm_105, [2, 8192, 512]) + convert_element_type_502 = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16) + all_gather_into_tensor_169 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_502, 8, '0'); convert_element_type_502 = None + wait_tensor_200 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_169); all_gather_into_tensor_169 = None + permute_166 = torch.ops.aten.permute.default(wait_tensor_200, [1, 0]); wait_tensor_200 = None + mm_106 = torch.ops.aten.mm.default(view_1095, permute_166); permute_166 = None + view_1103 = torch.ops.aten.view.default(mm_106, [2, 8192, 128]); mm_106 = None + convert_element_type_505 = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16) + all_gather_into_tensor_170 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_505, 8, '0'); convert_element_type_505 = None + wait_tensor_201 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_170); all_gather_into_tensor_170 = None + permute_167 = torch.ops.aten.permute.default(wait_tensor_201, [1, 0]); wait_tensor_201 = None + mm_107 = torch.ops.aten.mm.default(view_1095, permute_167); view_1095 = permute_167 = None + view_1110 = torch.ops.aten.view.default(mm_107, [2, 8192, 128]) + view_1112 = torch.ops.aten.view.default(view_1096, [2, 8192, -1, 128]); view_1096 = None + view_1113 = torch.ops.aten.view.default(view_1103, [2, 8192, -1, 128]); view_1103 = None + view_1114 = torch.ops.aten.view.default(view_1110, [2, 8192, -1, 128]); view_1110 = None + convert_element_type_508 = torch.ops.prims.convert_element_type.default(view_1112, torch.float32); view_1112 = None + view_1115 = torch.ops.aten.view.default(convert_element_type_508, [2, 8192, 4, -1, 2]); convert_element_type_508 = None + view_as_complex_30 = torch.ops.aten.view_as_complex.default(view_1115); view_1115 = None + convert_element_type_509 = torch.ops.prims.convert_element_type.default(view_1113, torch.float32); view_1113 = None + view_1116 = torch.ops.aten.view.default(convert_element_type_509, [2, 8192, 1, -1, 2]); convert_element_type_509 = None + view_as_complex_31 = torch.ops.aten.view_as_complex.default(view_1116); view_1116 = None + mul_122 = torch.ops.aten.mul.Tensor(view_as_complex_30, view_37); view_as_complex_30 = None + view_as_real_30 = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None + view_1118 = torch.ops.aten.view.default(view_as_real_30, [2, 8192, 4, 128]); view_as_real_30 = None + mul_123 = torch.ops.aten.mul.Tensor(view_as_complex_31, view_37); view_as_complex_31 = None + view_as_real_31 = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None + view_1119 = torch.ops.aten.view.default(view_as_real_31, [2, 8192, 1, 128]); view_as_real_31 = None + convert_element_type_510 = torch.ops.prims.convert_element_type.default(view_1118, torch.bfloat16); view_1118 = None + convert_element_type_511 = torch.ops.prims.convert_element_type.default(view_1119, torch.bfloat16); view_1119 = None + unsqueeze_30 = torch.ops.aten.unsqueeze.default(convert_element_type_511, 3); convert_element_type_511 = None + expand_30 = torch.ops.aten.expand.default(unsqueeze_30, [2, 8192, 1, 4, 128]); unsqueeze_30 = None + view_1120 = torch.ops.aten.view.default(expand_30, [2, 8192, 4, 128]); expand_30 = None + unsqueeze_31 = torch.ops.aten.unsqueeze.default(view_1114, 3); view_1114 = None + expand_31 = torch.ops.aten.expand.default(unsqueeze_31, [2, 8192, 1, 4, 128]); unsqueeze_31 = None + view_1121 = torch.ops.aten.view.default(expand_31, [2, 8192, 4, 128]); expand_31 = None + permute_168 = torch.ops.aten.permute.default(convert_element_type_510, [0, 2, 1, 3]); convert_element_type_510 = None + permute_169 = torch.ops.aten.permute.default(view_1120, [0, 2, 1, 3]); view_1120 = None + permute_170 = torch.ops.aten.permute.default(view_1121, [0, 2, 1, 3]); view_1121 = None + _scaled_dot_product_cudnn_attention_15 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_168, permute_169, permute_170, None, True, 0.0, True); permute_168 = permute_169 = permute_170 = None + getitem_695 = _scaled_dot_product_cudnn_attention_15[0] + getitem_696 = _scaled_dot_product_cudnn_attention_15[1] + getitem_701 = _scaled_dot_product_cudnn_attention_15[6] + getitem_702 = _scaled_dot_product_cudnn_attention_15[7]; _scaled_dot_product_cudnn_attention_15 = None + permute_171 = torch.ops.aten.permute.default(getitem_695, [0, 2, 1, 3]) + view_1122 = torch.ops.aten.view.default(permute_171, [2, 8192, -1]); permute_171 = None + convert_element_type_512 = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16) + all_gather_into_tensor_171 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_512, 8, '0'); convert_element_type_512 = None + wait_tensor_202 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_171); all_gather_into_tensor_171 = None + permute_172 = torch.ops.aten.permute.default(wait_tensor_202, [1, 0]); wait_tensor_202 = None + view_1128 = torch.ops.aten.view.default(view_1122, [16384, 512]); view_1122 = None + mm_108 = torch.ops.aten.mm.default(view_1128, permute_172); view_1128 = permute_172 = None + view_1129 = torch.ops.aten.view.default(mm_108, [2, 8192, 4096]); mm_108 = None + split_70 = torch.ops.aten.split.Tensor(view_1129, 1024, 1); view_1129 = None + getitem_704 = split_70[0] + getitem_705 = split_70[1] + getitem_706 = split_70[2] + getitem_707 = split_70[3] + getitem_708 = split_70[4] + getitem_709 = split_70[5] + getitem_710 = split_70[6] + getitem_711 = split_70[7]; split_70 = None + cat_62 = torch.ops.aten.cat.default([getitem_704, getitem_705, getitem_706, getitem_707, getitem_708, getitem_709, getitem_710, getitem_711]); getitem_704 = getitem_705 = getitem_706 = getitem_707 = getitem_708 = getitem_709 = getitem_710 = getitem_711 = None + reduce_scatter_tensor_31 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_62, 'sum', 8, '1'); cat_62 = None + wait_tensor_203 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_31) + add_61 = torch.ops.aten.add.Tensor(add_59, wait_tensor_203); wait_tensor_203 = None + convert_element_type_515 = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16) + all_gather_into_tensor_172 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_515, 8, '0'); convert_element_type_515 = None + wait_tensor_204 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_172); all_gather_into_tensor_172 = None + convert_element_type_516 = torch.ops.prims.convert_element_type.default(add_61, torch.float32) + pow_32 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_516, 2) + mean_31 = torch.ops.aten.mean.dim(pow_32, [2], True); pow_32 = None + add_62 = torch.ops.aten.add.Scalar(mean_31, 1e-05); mean_31 = None + rsqrt_31 = torch.ops.aten.rsqrt.default(add_62); add_62 = None + mul_124 = torch.ops.aten.mul.Tensor(convert_element_type_516, rsqrt_31); convert_element_type_516 = rsqrt_31 = None + mul_125 = torch.ops.aten.mul.Tensor(mul_124, wait_tensor_204); mul_124 = wait_tensor_204 = None + convert_element_type_517 = torch.ops.prims.convert_element_type.default(mul_125, torch.bfloat16); mul_125 = None + all_gather_into_tensor_173 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_517, 8, '1'); convert_element_type_517 = None + wait_tensor_205 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_173); all_gather_into_tensor_173 = None + split_71 = torch.ops.aten.split.Tensor(wait_tensor_205, 2); wait_tensor_205 = None + getitem_712 = split_71[0] + getitem_713 = split_71[1] + getitem_714 = split_71[2] + getitem_715 = split_71[3] + getitem_716 = split_71[4] + getitem_717 = split_71[5] + getitem_718 = split_71[6] + getitem_719 = split_71[7]; split_71 = None + cat_63 = torch.ops.aten.cat.default([getitem_712, getitem_713, getitem_714, getitem_715, getitem_716, getitem_717, getitem_718, getitem_719], 1); getitem_712 = getitem_713 = getitem_714 = getitem_715 = getitem_716 = getitem_717 = getitem_718 = getitem_719 = None + convert_element_type_518 = torch.ops.prims.convert_element_type.default(primals_145, torch.bfloat16) + all_gather_into_tensor_174 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_518, 8, '0'); convert_element_type_518 = None + wait_tensor_206 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_174); all_gather_into_tensor_174 = None + permute_173 = torch.ops.aten.permute.default(wait_tensor_206, [1, 0]); wait_tensor_206 = None + view_1140 = torch.ops.aten.view.default(cat_63, [16384, 4096]); cat_63 = None + mm_109 = torch.ops.aten.mm.default(view_1140, permute_173); permute_173 = None + view_1141 = torch.ops.aten.view.default(mm_109, [2, 8192, 1792]) + convert_element_type_521 = torch.ops.prims.convert_element_type.default(view_1141, torch.float32); view_1141 = None + sigmoid_15 = torch.ops.aten.sigmoid.default(convert_element_type_521) + mul_126 = torch.ops.aten.mul.Tensor(convert_element_type_521, sigmoid_15); convert_element_type_521 = sigmoid_15 = None + convert_element_type_522 = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None + convert_element_type_523 = torch.ops.prims.convert_element_type.default(primals_146, torch.bfloat16) + all_gather_into_tensor_175 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_523, 8, '0'); convert_element_type_523 = None + wait_tensor_207 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_175); all_gather_into_tensor_175 = None + permute_174 = torch.ops.aten.permute.default(wait_tensor_207, [1, 0]); wait_tensor_207 = None + mm_110 = torch.ops.aten.mm.default(view_1140, permute_174); view_1140 = permute_174 = None + view_1148 = torch.ops.aten.view.default(mm_110, [2, 8192, 1792]); mm_110 = None + mul_127 = torch.ops.aten.mul.Tensor(convert_element_type_522, view_1148); convert_element_type_522 = view_1148 = None + convert_element_type_526 = torch.ops.prims.convert_element_type.default(primals_147, torch.bfloat16) + all_gather_into_tensor_176 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_526, 8, '0'); convert_element_type_526 = None + wait_tensor_208 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_176); all_gather_into_tensor_176 = None + permute_175 = torch.ops.aten.permute.default(wait_tensor_208, [1, 0]); wait_tensor_208 = None + view_1155 = torch.ops.aten.view.default(mul_127, [16384, 1792]); mul_127 = None + mm_111 = torch.ops.aten.mm.default(view_1155, permute_175); view_1155 = permute_175 = None + view_1156 = torch.ops.aten.view.default(mm_111, [2, 8192, 4096]); mm_111 = None + split_72 = torch.ops.aten.split.Tensor(view_1156, 1024, 1); view_1156 = None + getitem_720 = split_72[0] + getitem_721 = split_72[1] + getitem_722 = split_72[2] + getitem_723 = split_72[3] + getitem_724 = split_72[4] + getitem_725 = split_72[5] + getitem_726 = split_72[6] + getitem_727 = split_72[7]; split_72 = None + cat_64 = torch.ops.aten.cat.default([getitem_720, getitem_721, getitem_722, getitem_723, getitem_724, getitem_725, getitem_726, getitem_727]); getitem_720 = getitem_721 = getitem_722 = getitem_723 = getitem_724 = getitem_725 = getitem_726 = getitem_727 = None + reduce_scatter_tensor_32 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_64, 'sum', 8, '1'); cat_64 = None + wait_tensor_209 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_32); reduce_scatter_tensor_32 = None + add_63 = torch.ops.aten.add.Tensor(add_61, wait_tensor_209); add_61 = wait_tensor_209 = None + convert_element_type_529 = torch.ops.prims.convert_element_type.default(primals_148, torch.bfloat16) + all_gather_into_tensor_177 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_529, 8, '0'); convert_element_type_529 = None + wait_tensor_210 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_177); all_gather_into_tensor_177 = None + convert_element_type_530 = torch.ops.prims.convert_element_type.default(add_63, torch.float32) + pow_33 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_530, 2) + mean_32 = torch.ops.aten.mean.dim(pow_33, [2], True); pow_33 = None + add_64 = torch.ops.aten.add.Scalar(mean_32, 1e-05); mean_32 = None + rsqrt_32 = torch.ops.aten.rsqrt.default(add_64); add_64 = None + mul_128 = torch.ops.aten.mul.Tensor(convert_element_type_530, rsqrt_32); convert_element_type_530 = rsqrt_32 = None + mul_129 = torch.ops.aten.mul.Tensor(mul_128, wait_tensor_210); mul_128 = wait_tensor_210 = None + convert_element_type_531 = torch.ops.prims.convert_element_type.default(mul_129, torch.bfloat16); mul_129 = None + all_gather_into_tensor_178 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_531, 8, '1'); convert_element_type_531 = None + wait_tensor_211 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_178); all_gather_into_tensor_178 = None + split_73 = torch.ops.aten.split.Tensor(wait_tensor_211, 2); wait_tensor_211 = None + getitem_728 = split_73[0] + getitem_729 = split_73[1] + getitem_730 = split_73[2] + getitem_731 = split_73[3] + getitem_732 = split_73[4] + getitem_733 = split_73[5] + getitem_734 = split_73[6] + getitem_735 = split_73[7]; split_73 = None + cat_65 = torch.ops.aten.cat.default([getitem_728, getitem_729, getitem_730, getitem_731, getitem_732, getitem_733, getitem_734, getitem_735], 1); getitem_728 = getitem_729 = getitem_730 = getitem_731 = getitem_732 = getitem_733 = getitem_734 = getitem_735 = None + convert_element_type_532 = torch.ops.prims.convert_element_type.default(primals_149, torch.bfloat16) + all_gather_into_tensor_179 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_532, 8, '0'); convert_element_type_532 = None + wait_tensor_212 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_179); all_gather_into_tensor_179 = None + permute_176 = torch.ops.aten.permute.default(wait_tensor_212, [1, 0]); wait_tensor_212 = None + view_1167 = torch.ops.aten.view.default(cat_65, [16384, 4096]); cat_65 = None + mm_112 = torch.ops.aten.mm.default(view_1167, permute_176); permute_176 = None + view_1168 = torch.ops.aten.view.default(mm_112, [2, 8192, 512]) + convert_element_type_535 = torch.ops.prims.convert_element_type.default(primals_150, torch.bfloat16) + all_gather_into_tensor_180 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_535, 8, '0'); convert_element_type_535 = None + wait_tensor_213 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_180); all_gather_into_tensor_180 = None + permute_177 = torch.ops.aten.permute.default(wait_tensor_213, [1, 0]); wait_tensor_213 = None + mm_113 = torch.ops.aten.mm.default(view_1167, permute_177); permute_177 = None + view_1175 = torch.ops.aten.view.default(mm_113, [2, 8192, 128]); mm_113 = None + convert_element_type_538 = torch.ops.prims.convert_element_type.default(primals_151, torch.bfloat16) + all_gather_into_tensor_181 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_538, 8, '0'); convert_element_type_538 = None + wait_tensor_214 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_181); all_gather_into_tensor_181 = None + permute_178 = torch.ops.aten.permute.default(wait_tensor_214, [1, 0]); wait_tensor_214 = None + mm_114 = torch.ops.aten.mm.default(view_1167, permute_178); view_1167 = permute_178 = None + view_1182 = torch.ops.aten.view.default(mm_114, [2, 8192, 128]) + view_1184 = torch.ops.aten.view.default(view_1168, [2, 8192, -1, 128]); view_1168 = None + view_1185 = torch.ops.aten.view.default(view_1175, [2, 8192, -1, 128]); view_1175 = None + view_1186 = torch.ops.aten.view.default(view_1182, [2, 8192, -1, 128]); view_1182 = None + convert_element_type_541 = torch.ops.prims.convert_element_type.default(view_1184, torch.float32); view_1184 = None + view_1187 = torch.ops.aten.view.default(convert_element_type_541, [2, 8192, 4, -1, 2]); convert_element_type_541 = None + view_as_complex_32 = torch.ops.aten.view_as_complex.default(view_1187); view_1187 = None + convert_element_type_542 = torch.ops.prims.convert_element_type.default(view_1185, torch.float32); view_1185 = None + view_1188 = torch.ops.aten.view.default(convert_element_type_542, [2, 8192, 1, -1, 2]); convert_element_type_542 = None + view_as_complex_33 = torch.ops.aten.view_as_complex.default(view_1188); view_1188 = None + mul_130 = torch.ops.aten.mul.Tensor(view_as_complex_32, view_37); view_as_complex_32 = None + view_as_real_32 = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None + view_1190 = torch.ops.aten.view.default(view_as_real_32, [2, 8192, 4, 128]); view_as_real_32 = None + mul_131 = torch.ops.aten.mul.Tensor(view_as_complex_33, view_37); view_as_complex_33 = None + view_as_real_33 = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None + view_1191 = torch.ops.aten.view.default(view_as_real_33, [2, 8192, 1, 128]); view_as_real_33 = None + convert_element_type_543 = torch.ops.prims.convert_element_type.default(view_1190, torch.bfloat16); view_1190 = None + convert_element_type_544 = torch.ops.prims.convert_element_type.default(view_1191, torch.bfloat16); view_1191 = None + unsqueeze_32 = torch.ops.aten.unsqueeze.default(convert_element_type_544, 3); convert_element_type_544 = None + expand_32 = torch.ops.aten.expand.default(unsqueeze_32, [2, 8192, 1, 4, 128]); unsqueeze_32 = None + view_1192 = torch.ops.aten.view.default(expand_32, [2, 8192, 4, 128]); expand_32 = None + unsqueeze_33 = torch.ops.aten.unsqueeze.default(view_1186, 3); view_1186 = None + expand_33 = torch.ops.aten.expand.default(unsqueeze_33, [2, 8192, 1, 4, 128]); unsqueeze_33 = None + view_1193 = torch.ops.aten.view.default(expand_33, [2, 8192, 4, 128]); expand_33 = None + permute_179 = torch.ops.aten.permute.default(convert_element_type_543, [0, 2, 1, 3]); convert_element_type_543 = None + permute_180 = torch.ops.aten.permute.default(view_1192, [0, 2, 1, 3]); view_1192 = None + permute_181 = torch.ops.aten.permute.default(view_1193, [0, 2, 1, 3]); view_1193 = None + _scaled_dot_product_cudnn_attention_16 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_179, permute_180, permute_181, None, True, 0.0, True); permute_179 = permute_180 = permute_181 = None + getitem_736 = _scaled_dot_product_cudnn_attention_16[0] + getitem_737 = _scaled_dot_product_cudnn_attention_16[1] + getitem_742 = _scaled_dot_product_cudnn_attention_16[6] + getitem_743 = _scaled_dot_product_cudnn_attention_16[7]; _scaled_dot_product_cudnn_attention_16 = None + permute_182 = torch.ops.aten.permute.default(getitem_736, [0, 2, 1, 3]) + view_1194 = torch.ops.aten.view.default(permute_182, [2, 8192, -1]); permute_182 = None + convert_element_type_545 = torch.ops.prims.convert_element_type.default(primals_152, torch.bfloat16) + all_gather_into_tensor_182 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_545, 8, '0'); convert_element_type_545 = None + wait_tensor_215 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_182); all_gather_into_tensor_182 = None + permute_183 = torch.ops.aten.permute.default(wait_tensor_215, [1, 0]); wait_tensor_215 = None + view_1200 = torch.ops.aten.view.default(view_1194, [16384, 512]); view_1194 = None + mm_115 = torch.ops.aten.mm.default(view_1200, permute_183); view_1200 = permute_183 = None + view_1201 = torch.ops.aten.view.default(mm_115, [2, 8192, 4096]); mm_115 = None + split_74 = torch.ops.aten.split.Tensor(view_1201, 1024, 1); view_1201 = None + getitem_745 = split_74[0] + getitem_746 = split_74[1] + getitem_747 = split_74[2] + getitem_748 = split_74[3] + getitem_749 = split_74[4] + getitem_750 = split_74[5] + getitem_751 = split_74[6] + getitem_752 = split_74[7]; split_74 = None + cat_66 = torch.ops.aten.cat.default([getitem_745, getitem_746, getitem_747, getitem_748, getitem_749, getitem_750, getitem_751, getitem_752]); getitem_745 = getitem_746 = getitem_747 = getitem_748 = getitem_749 = getitem_750 = getitem_751 = getitem_752 = None + reduce_scatter_tensor_33 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_66, 'sum', 8, '1'); cat_66 = None + wait_tensor_216 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_33) + add_65 = torch.ops.aten.add.Tensor(add_63, wait_tensor_216); wait_tensor_216 = None + convert_element_type_548 = torch.ops.prims.convert_element_type.default(primals_153, torch.bfloat16) + all_gather_into_tensor_183 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_548, 8, '0'); convert_element_type_548 = None + wait_tensor_217 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_183); all_gather_into_tensor_183 = None + convert_element_type_549 = torch.ops.prims.convert_element_type.default(add_65, torch.float32) + pow_34 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_549, 2) + mean_33 = torch.ops.aten.mean.dim(pow_34, [2], True); pow_34 = None + add_66 = torch.ops.aten.add.Scalar(mean_33, 1e-05); mean_33 = None + rsqrt_33 = torch.ops.aten.rsqrt.default(add_66); add_66 = None + mul_132 = torch.ops.aten.mul.Tensor(convert_element_type_549, rsqrt_33); convert_element_type_549 = rsqrt_33 = None + mul_133 = torch.ops.aten.mul.Tensor(mul_132, wait_tensor_217); mul_132 = wait_tensor_217 = None + convert_element_type_550 = torch.ops.prims.convert_element_type.default(mul_133, torch.bfloat16); mul_133 = None + all_gather_into_tensor_184 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_550, 8, '1'); convert_element_type_550 = None + wait_tensor_218 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_184); all_gather_into_tensor_184 = None + split_75 = torch.ops.aten.split.Tensor(wait_tensor_218, 2); wait_tensor_218 = None + getitem_753 = split_75[0] + getitem_754 = split_75[1] + getitem_755 = split_75[2] + getitem_756 = split_75[3] + getitem_757 = split_75[4] + getitem_758 = split_75[5] + getitem_759 = split_75[6] + getitem_760 = split_75[7]; split_75 = None + cat_67 = torch.ops.aten.cat.default([getitem_753, getitem_754, getitem_755, getitem_756, getitem_757, getitem_758, getitem_759, getitem_760], 1); getitem_753 = getitem_754 = getitem_755 = getitem_756 = getitem_757 = getitem_758 = getitem_759 = getitem_760 = None + convert_element_type_551 = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16) + all_gather_into_tensor_185 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_551, 8, '0'); convert_element_type_551 = None + wait_tensor_219 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_185); all_gather_into_tensor_185 = None + permute_184 = torch.ops.aten.permute.default(wait_tensor_219, [1, 0]); wait_tensor_219 = None + view_1212 = torch.ops.aten.view.default(cat_67, [16384, 4096]); cat_67 = None + mm_116 = torch.ops.aten.mm.default(view_1212, permute_184); permute_184 = None + view_1213 = torch.ops.aten.view.default(mm_116, [2, 8192, 1792]) + convert_element_type_554 = torch.ops.prims.convert_element_type.default(view_1213, torch.float32); view_1213 = None + sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_554) + mul_134 = torch.ops.aten.mul.Tensor(convert_element_type_554, sigmoid_16); convert_element_type_554 = sigmoid_16 = None + convert_element_type_555 = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None + convert_element_type_556 = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16) + all_gather_into_tensor_186 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_556, 8, '0'); convert_element_type_556 = None + wait_tensor_220 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_186); all_gather_into_tensor_186 = None + permute_185 = torch.ops.aten.permute.default(wait_tensor_220, [1, 0]); wait_tensor_220 = None + mm_117 = torch.ops.aten.mm.default(view_1212, permute_185); view_1212 = permute_185 = None + view_1220 = torch.ops.aten.view.default(mm_117, [2, 8192, 1792]); mm_117 = None + mul_135 = torch.ops.aten.mul.Tensor(convert_element_type_555, view_1220); convert_element_type_555 = view_1220 = None + convert_element_type_559 = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16) + all_gather_into_tensor_187 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_559, 8, '0'); convert_element_type_559 = None + wait_tensor_221 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_187); all_gather_into_tensor_187 = None + permute_186 = torch.ops.aten.permute.default(wait_tensor_221, [1, 0]); wait_tensor_221 = None + view_1227 = torch.ops.aten.view.default(mul_135, [16384, 1792]); mul_135 = None + mm_118 = torch.ops.aten.mm.default(view_1227, permute_186); view_1227 = permute_186 = None + view_1228 = torch.ops.aten.view.default(mm_118, [2, 8192, 4096]); mm_118 = None + split_76 = torch.ops.aten.split.Tensor(view_1228, 1024, 1); view_1228 = None + getitem_761 = split_76[0] + getitem_762 = split_76[1] + getitem_763 = split_76[2] + getitem_764 = split_76[3] + getitem_765 = split_76[4] + getitem_766 = split_76[5] + getitem_767 = split_76[6] + getitem_768 = split_76[7]; split_76 = None + cat_68 = torch.ops.aten.cat.default([getitem_761, getitem_762, getitem_763, getitem_764, getitem_765, getitem_766, getitem_767, getitem_768]); getitem_761 = getitem_762 = getitem_763 = getitem_764 = getitem_765 = getitem_766 = getitem_767 = getitem_768 = None + reduce_scatter_tensor_34 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_68, 'sum', 8, '1'); cat_68 = None + wait_tensor_222 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_34); reduce_scatter_tensor_34 = None + add_67 = torch.ops.aten.add.Tensor(add_65, wait_tensor_222); add_65 = wait_tensor_222 = None + convert_element_type_562 = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16) + all_gather_into_tensor_188 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_562, 8, '0'); convert_element_type_562 = None + wait_tensor_223 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_188); all_gather_into_tensor_188 = None + convert_element_type_563 = torch.ops.prims.convert_element_type.default(add_67, torch.float32) + pow_35 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_563, 2) + mean_34 = torch.ops.aten.mean.dim(pow_35, [2], True); pow_35 = None + add_68 = torch.ops.aten.add.Scalar(mean_34, 1e-05); mean_34 = None + rsqrt_34 = torch.ops.aten.rsqrt.default(add_68); add_68 = None + mul_136 = torch.ops.aten.mul.Tensor(convert_element_type_563, rsqrt_34); convert_element_type_563 = rsqrt_34 = None + mul_137 = torch.ops.aten.mul.Tensor(mul_136, wait_tensor_223); mul_136 = wait_tensor_223 = None + convert_element_type_564 = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None + all_gather_into_tensor_189 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_564, 8, '1'); convert_element_type_564 = None + wait_tensor_224 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_189); all_gather_into_tensor_189 = None + split_77 = torch.ops.aten.split.Tensor(wait_tensor_224, 2); wait_tensor_224 = None + getitem_769 = split_77[0] + getitem_770 = split_77[1] + getitem_771 = split_77[2] + getitem_772 = split_77[3] + getitem_773 = split_77[4] + getitem_774 = split_77[5] + getitem_775 = split_77[6] + getitem_776 = split_77[7]; split_77 = None + cat_69 = torch.ops.aten.cat.default([getitem_769, getitem_770, getitem_771, getitem_772, getitem_773, getitem_774, getitem_775, getitem_776], 1); getitem_769 = getitem_770 = getitem_771 = getitem_772 = getitem_773 = getitem_774 = getitem_775 = getitem_776 = None + convert_element_type_565 = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16) + all_gather_into_tensor_190 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_565, 8, '0'); convert_element_type_565 = None + wait_tensor_225 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_190); all_gather_into_tensor_190 = None + permute_187 = torch.ops.aten.permute.default(wait_tensor_225, [1, 0]); wait_tensor_225 = None + view_1239 = torch.ops.aten.view.default(cat_69, [16384, 4096]); cat_69 = None + mm_119 = torch.ops.aten.mm.default(view_1239, permute_187); permute_187 = None + view_1240 = torch.ops.aten.view.default(mm_119, [2, 8192, 512]) + convert_element_type_568 = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16) + all_gather_into_tensor_191 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_568, 8, '0'); convert_element_type_568 = None + wait_tensor_226 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_191); all_gather_into_tensor_191 = None + permute_188 = torch.ops.aten.permute.default(wait_tensor_226, [1, 0]); wait_tensor_226 = None + mm_120 = torch.ops.aten.mm.default(view_1239, permute_188); permute_188 = None + view_1247 = torch.ops.aten.view.default(mm_120, [2, 8192, 128]); mm_120 = None + convert_element_type_571 = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16) + all_gather_into_tensor_192 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_571, 8, '0'); convert_element_type_571 = None + wait_tensor_227 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_192); all_gather_into_tensor_192 = None + permute_189 = torch.ops.aten.permute.default(wait_tensor_227, [1, 0]); wait_tensor_227 = None + mm_121 = torch.ops.aten.mm.default(view_1239, permute_189); view_1239 = permute_189 = None + view_1254 = torch.ops.aten.view.default(mm_121, [2, 8192, 128]) + view_1256 = torch.ops.aten.view.default(view_1240, [2, 8192, -1, 128]); view_1240 = None + view_1257 = torch.ops.aten.view.default(view_1247, [2, 8192, -1, 128]); view_1247 = None + view_1258 = torch.ops.aten.view.default(view_1254, [2, 8192, -1, 128]); view_1254 = None + convert_element_type_574 = torch.ops.prims.convert_element_type.default(view_1256, torch.float32); view_1256 = None + view_1259 = torch.ops.aten.view.default(convert_element_type_574, [2, 8192, 4, -1, 2]); convert_element_type_574 = None + view_as_complex_34 = torch.ops.aten.view_as_complex.default(view_1259); view_1259 = None + convert_element_type_575 = torch.ops.prims.convert_element_type.default(view_1257, torch.float32); view_1257 = None + view_1260 = torch.ops.aten.view.default(convert_element_type_575, [2, 8192, 1, -1, 2]); convert_element_type_575 = None + view_as_complex_35 = torch.ops.aten.view_as_complex.default(view_1260); view_1260 = None + mul_138 = torch.ops.aten.mul.Tensor(view_as_complex_34, view_37); view_as_complex_34 = None + view_as_real_34 = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None + view_1262 = torch.ops.aten.view.default(view_as_real_34, [2, 8192, 4, 128]); view_as_real_34 = None + mul_139 = torch.ops.aten.mul.Tensor(view_as_complex_35, view_37); view_as_complex_35 = None + view_as_real_35 = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None + view_1263 = torch.ops.aten.view.default(view_as_real_35, [2, 8192, 1, 128]); view_as_real_35 = None + convert_element_type_576 = torch.ops.prims.convert_element_type.default(view_1262, torch.bfloat16); view_1262 = None + convert_element_type_577 = torch.ops.prims.convert_element_type.default(view_1263, torch.bfloat16); view_1263 = None + unsqueeze_34 = torch.ops.aten.unsqueeze.default(convert_element_type_577, 3); convert_element_type_577 = None + expand_34 = torch.ops.aten.expand.default(unsqueeze_34, [2, 8192, 1, 4, 128]); unsqueeze_34 = None + view_1264 = torch.ops.aten.view.default(expand_34, [2, 8192, 4, 128]); expand_34 = None + unsqueeze_35 = torch.ops.aten.unsqueeze.default(view_1258, 3); view_1258 = None + expand_35 = torch.ops.aten.expand.default(unsqueeze_35, [2, 8192, 1, 4, 128]); unsqueeze_35 = None + view_1265 = torch.ops.aten.view.default(expand_35, [2, 8192, 4, 128]); expand_35 = None + permute_190 = torch.ops.aten.permute.default(convert_element_type_576, [0, 2, 1, 3]); convert_element_type_576 = None + permute_191 = torch.ops.aten.permute.default(view_1264, [0, 2, 1, 3]); view_1264 = None + permute_192 = torch.ops.aten.permute.default(view_1265, [0, 2, 1, 3]); view_1265 = None + _scaled_dot_product_cudnn_attention_17 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_190, permute_191, permute_192, None, True, 0.0, True); permute_190 = permute_191 = permute_192 = None + getitem_777 = _scaled_dot_product_cudnn_attention_17[0] + getitem_778 = _scaled_dot_product_cudnn_attention_17[1] + getitem_783 = _scaled_dot_product_cudnn_attention_17[6] + getitem_784 = _scaled_dot_product_cudnn_attention_17[7]; _scaled_dot_product_cudnn_attention_17 = None + permute_193 = torch.ops.aten.permute.default(getitem_777, [0, 2, 1, 3]) + view_1266 = torch.ops.aten.view.default(permute_193, [2, 8192, -1]); permute_193 = None + convert_element_type_578 = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16) + all_gather_into_tensor_193 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_578, 8, '0'); convert_element_type_578 = None + wait_tensor_228 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_193); all_gather_into_tensor_193 = None + permute_194 = torch.ops.aten.permute.default(wait_tensor_228, [1, 0]); wait_tensor_228 = None + view_1272 = torch.ops.aten.view.default(view_1266, [16384, 512]); view_1266 = None + mm_122 = torch.ops.aten.mm.default(view_1272, permute_194); view_1272 = permute_194 = None + view_1273 = torch.ops.aten.view.default(mm_122, [2, 8192, 4096]); mm_122 = None + split_78 = torch.ops.aten.split.Tensor(view_1273, 1024, 1); view_1273 = None + getitem_786 = split_78[0] + getitem_787 = split_78[1] + getitem_788 = split_78[2] + getitem_789 = split_78[3] + getitem_790 = split_78[4] + getitem_791 = split_78[5] + getitem_792 = split_78[6] + getitem_793 = split_78[7]; split_78 = None + cat_70 = torch.ops.aten.cat.default([getitem_786, getitem_787, getitem_788, getitem_789, getitem_790, getitem_791, getitem_792, getitem_793]); getitem_786 = getitem_787 = getitem_788 = getitem_789 = getitem_790 = getitem_791 = getitem_792 = getitem_793 = None + reduce_scatter_tensor_35 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_70, 'sum', 8, '1'); cat_70 = None + wait_tensor_229 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_35) + add_69 = torch.ops.aten.add.Tensor(add_67, wait_tensor_229); wait_tensor_229 = None + convert_element_type_581 = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16) + all_gather_into_tensor_194 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_581, 8, '0'); convert_element_type_581 = None + wait_tensor_230 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_194); all_gather_into_tensor_194 = None + convert_element_type_582 = torch.ops.prims.convert_element_type.default(add_69, torch.float32) + pow_36 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_582, 2) + mean_35 = torch.ops.aten.mean.dim(pow_36, [2], True); pow_36 = None + add_70 = torch.ops.aten.add.Scalar(mean_35, 1e-05); mean_35 = None + rsqrt_35 = torch.ops.aten.rsqrt.default(add_70); add_70 = None + mul_140 = torch.ops.aten.mul.Tensor(convert_element_type_582, rsqrt_35); convert_element_type_582 = rsqrt_35 = None + mul_141 = torch.ops.aten.mul.Tensor(mul_140, wait_tensor_230); mul_140 = wait_tensor_230 = None + convert_element_type_583 = torch.ops.prims.convert_element_type.default(mul_141, torch.bfloat16); mul_141 = None + all_gather_into_tensor_195 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_583, 8, '1'); convert_element_type_583 = None + wait_tensor_231 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_195); all_gather_into_tensor_195 = None + split_79 = torch.ops.aten.split.Tensor(wait_tensor_231, 2); wait_tensor_231 = None + getitem_794 = split_79[0] + getitem_795 = split_79[1] + getitem_796 = split_79[2] + getitem_797 = split_79[3] + getitem_798 = split_79[4] + getitem_799 = split_79[5] + getitem_800 = split_79[6] + getitem_801 = split_79[7]; split_79 = None + cat_71 = torch.ops.aten.cat.default([getitem_794, getitem_795, getitem_796, getitem_797, getitem_798, getitem_799, getitem_800, getitem_801], 1); getitem_794 = getitem_795 = getitem_796 = getitem_797 = getitem_798 = getitem_799 = getitem_800 = getitem_801 = None + convert_element_type_584 = torch.ops.prims.convert_element_type.default(primals_163, torch.bfloat16) + all_gather_into_tensor_196 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_584, 8, '0'); convert_element_type_584 = None + wait_tensor_232 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_196); all_gather_into_tensor_196 = None + permute_195 = torch.ops.aten.permute.default(wait_tensor_232, [1, 0]); wait_tensor_232 = None + view_1284 = torch.ops.aten.view.default(cat_71, [16384, 4096]); cat_71 = None + mm_123 = torch.ops.aten.mm.default(view_1284, permute_195); permute_195 = None + view_1285 = torch.ops.aten.view.default(mm_123, [2, 8192, 1792]) + convert_element_type_587 = torch.ops.prims.convert_element_type.default(view_1285, torch.float32); view_1285 = None + sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_587) + mul_142 = torch.ops.aten.mul.Tensor(convert_element_type_587, sigmoid_17); convert_element_type_587 = sigmoid_17 = None + convert_element_type_588 = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None + convert_element_type_589 = torch.ops.prims.convert_element_type.default(primals_164, torch.bfloat16) + all_gather_into_tensor_197 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_589, 8, '0'); convert_element_type_589 = None + wait_tensor_233 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_197); all_gather_into_tensor_197 = None + permute_196 = torch.ops.aten.permute.default(wait_tensor_233, [1, 0]); wait_tensor_233 = None + mm_124 = torch.ops.aten.mm.default(view_1284, permute_196); view_1284 = permute_196 = None + view_1292 = torch.ops.aten.view.default(mm_124, [2, 8192, 1792]); mm_124 = None + mul_143 = torch.ops.aten.mul.Tensor(convert_element_type_588, view_1292); convert_element_type_588 = view_1292 = None + convert_element_type_592 = torch.ops.prims.convert_element_type.default(primals_165, torch.bfloat16) + all_gather_into_tensor_198 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_592, 8, '0'); convert_element_type_592 = None + wait_tensor_234 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_198); all_gather_into_tensor_198 = None + permute_197 = torch.ops.aten.permute.default(wait_tensor_234, [1, 0]); wait_tensor_234 = None + view_1299 = torch.ops.aten.view.default(mul_143, [16384, 1792]); mul_143 = None + mm_125 = torch.ops.aten.mm.default(view_1299, permute_197); view_1299 = permute_197 = None + view_1300 = torch.ops.aten.view.default(mm_125, [2, 8192, 4096]); mm_125 = None + split_80 = torch.ops.aten.split.Tensor(view_1300, 1024, 1); view_1300 = None + getitem_802 = split_80[0] + getitem_803 = split_80[1] + getitem_804 = split_80[2] + getitem_805 = split_80[3] + getitem_806 = split_80[4] + getitem_807 = split_80[5] + getitem_808 = split_80[6] + getitem_809 = split_80[7]; split_80 = None + cat_72 = torch.ops.aten.cat.default([getitem_802, getitem_803, getitem_804, getitem_805, getitem_806, getitem_807, getitem_808, getitem_809]); getitem_802 = getitem_803 = getitem_804 = getitem_805 = getitem_806 = getitem_807 = getitem_808 = getitem_809 = None + reduce_scatter_tensor_36 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_72, 'sum', 8, '1'); cat_72 = None + wait_tensor_235 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_36); reduce_scatter_tensor_36 = None + add_71 = torch.ops.aten.add.Tensor(add_69, wait_tensor_235); add_69 = wait_tensor_235 = None + convert_element_type_595 = torch.ops.prims.convert_element_type.default(primals_166, torch.bfloat16) + all_gather_into_tensor_199 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_595, 8, '0'); convert_element_type_595 = None + wait_tensor_236 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_199); all_gather_into_tensor_199 = None + convert_element_type_596 = torch.ops.prims.convert_element_type.default(add_71, torch.float32) + pow_37 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_596, 2) + mean_36 = torch.ops.aten.mean.dim(pow_37, [2], True); pow_37 = None + add_72 = torch.ops.aten.add.Scalar(mean_36, 1e-05); mean_36 = None + rsqrt_36 = torch.ops.aten.rsqrt.default(add_72); add_72 = None + mul_144 = torch.ops.aten.mul.Tensor(convert_element_type_596, rsqrt_36); convert_element_type_596 = rsqrt_36 = None + mul_145 = torch.ops.aten.mul.Tensor(mul_144, wait_tensor_236); mul_144 = wait_tensor_236 = None + convert_element_type_597 = torch.ops.prims.convert_element_type.default(mul_145, torch.bfloat16); mul_145 = None + all_gather_into_tensor_200 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_597, 8, '1'); convert_element_type_597 = None + wait_tensor_237 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_200); all_gather_into_tensor_200 = None + split_81 = torch.ops.aten.split.Tensor(wait_tensor_237, 2); wait_tensor_237 = None + getitem_810 = split_81[0] + getitem_811 = split_81[1] + getitem_812 = split_81[2] + getitem_813 = split_81[3] + getitem_814 = split_81[4] + getitem_815 = split_81[5] + getitem_816 = split_81[6] + getitem_817 = split_81[7]; split_81 = None + cat_73 = torch.ops.aten.cat.default([getitem_810, getitem_811, getitem_812, getitem_813, getitem_814, getitem_815, getitem_816, getitem_817], 1); getitem_810 = getitem_811 = getitem_812 = getitem_813 = getitem_814 = getitem_815 = getitem_816 = getitem_817 = None + convert_element_type_598 = torch.ops.prims.convert_element_type.default(primals_167, torch.bfloat16) + all_gather_into_tensor_201 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_598, 8, '0'); convert_element_type_598 = None + wait_tensor_238 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_201); all_gather_into_tensor_201 = None + permute_198 = torch.ops.aten.permute.default(wait_tensor_238, [1, 0]); wait_tensor_238 = None + view_1311 = torch.ops.aten.view.default(cat_73, [16384, 4096]); cat_73 = None + mm_126 = torch.ops.aten.mm.default(view_1311, permute_198); permute_198 = None + view_1312 = torch.ops.aten.view.default(mm_126, [2, 8192, 512]) + convert_element_type_601 = torch.ops.prims.convert_element_type.default(primals_168, torch.bfloat16) + all_gather_into_tensor_202 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_601, 8, '0'); convert_element_type_601 = None + wait_tensor_239 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_202); all_gather_into_tensor_202 = None + permute_199 = torch.ops.aten.permute.default(wait_tensor_239, [1, 0]); wait_tensor_239 = None + mm_127 = torch.ops.aten.mm.default(view_1311, permute_199); permute_199 = None + view_1319 = torch.ops.aten.view.default(mm_127, [2, 8192, 128]); mm_127 = None + convert_element_type_604 = torch.ops.prims.convert_element_type.default(primals_169, torch.bfloat16) + all_gather_into_tensor_203 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_604, 8, '0'); convert_element_type_604 = None + wait_tensor_240 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_203); all_gather_into_tensor_203 = None + permute_200 = torch.ops.aten.permute.default(wait_tensor_240, [1, 0]); wait_tensor_240 = None + mm_128 = torch.ops.aten.mm.default(view_1311, permute_200); view_1311 = permute_200 = None + view_1326 = torch.ops.aten.view.default(mm_128, [2, 8192, 128]) + view_1328 = torch.ops.aten.view.default(view_1312, [2, 8192, -1, 128]); view_1312 = None + view_1329 = torch.ops.aten.view.default(view_1319, [2, 8192, -1, 128]); view_1319 = None + view_1330 = torch.ops.aten.view.default(view_1326, [2, 8192, -1, 128]); view_1326 = None + convert_element_type_607 = torch.ops.prims.convert_element_type.default(view_1328, torch.float32); view_1328 = None + view_1331 = torch.ops.aten.view.default(convert_element_type_607, [2, 8192, 4, -1, 2]); convert_element_type_607 = None + view_as_complex_36 = torch.ops.aten.view_as_complex.default(view_1331); view_1331 = None + convert_element_type_608 = torch.ops.prims.convert_element_type.default(view_1329, torch.float32); view_1329 = None + view_1332 = torch.ops.aten.view.default(convert_element_type_608, [2, 8192, 1, -1, 2]); convert_element_type_608 = None + view_as_complex_37 = torch.ops.aten.view_as_complex.default(view_1332); view_1332 = None + mul_146 = torch.ops.aten.mul.Tensor(view_as_complex_36, view_37); view_as_complex_36 = None + view_as_real_36 = torch.ops.aten.view_as_real.default(mul_146); mul_146 = None + view_1334 = torch.ops.aten.view.default(view_as_real_36, [2, 8192, 4, 128]); view_as_real_36 = None + mul_147 = torch.ops.aten.mul.Tensor(view_as_complex_37, view_37); view_as_complex_37 = None + view_as_real_37 = torch.ops.aten.view_as_real.default(mul_147); mul_147 = None + view_1335 = torch.ops.aten.view.default(view_as_real_37, [2, 8192, 1, 128]); view_as_real_37 = None + convert_element_type_609 = torch.ops.prims.convert_element_type.default(view_1334, torch.bfloat16); view_1334 = None + convert_element_type_610 = torch.ops.prims.convert_element_type.default(view_1335, torch.bfloat16); view_1335 = None + unsqueeze_36 = torch.ops.aten.unsqueeze.default(convert_element_type_610, 3); convert_element_type_610 = None + expand_36 = torch.ops.aten.expand.default(unsqueeze_36, [2, 8192, 1, 4, 128]); unsqueeze_36 = None + view_1336 = torch.ops.aten.view.default(expand_36, [2, 8192, 4, 128]); expand_36 = None + unsqueeze_37 = torch.ops.aten.unsqueeze.default(view_1330, 3); view_1330 = None + expand_37 = torch.ops.aten.expand.default(unsqueeze_37, [2, 8192, 1, 4, 128]); unsqueeze_37 = None + view_1337 = torch.ops.aten.view.default(expand_37, [2, 8192, 4, 128]); expand_37 = None + permute_201 = torch.ops.aten.permute.default(convert_element_type_609, [0, 2, 1, 3]); convert_element_type_609 = None + permute_202 = torch.ops.aten.permute.default(view_1336, [0, 2, 1, 3]); view_1336 = None + permute_203 = torch.ops.aten.permute.default(view_1337, [0, 2, 1, 3]); view_1337 = None + _scaled_dot_product_cudnn_attention_18 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_201, permute_202, permute_203, None, True, 0.0, True); permute_201 = permute_202 = permute_203 = None + getitem_818 = _scaled_dot_product_cudnn_attention_18[0] + getitem_819 = _scaled_dot_product_cudnn_attention_18[1] + getitem_824 = _scaled_dot_product_cudnn_attention_18[6] + getitem_825 = _scaled_dot_product_cudnn_attention_18[7]; _scaled_dot_product_cudnn_attention_18 = None + permute_204 = torch.ops.aten.permute.default(getitem_818, [0, 2, 1, 3]) + view_1338 = torch.ops.aten.view.default(permute_204, [2, 8192, -1]); permute_204 = None + convert_element_type_611 = torch.ops.prims.convert_element_type.default(primals_170, torch.bfloat16) + all_gather_into_tensor_204 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_611, 8, '0'); convert_element_type_611 = None + wait_tensor_241 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_204); all_gather_into_tensor_204 = None + permute_205 = torch.ops.aten.permute.default(wait_tensor_241, [1, 0]); wait_tensor_241 = None + view_1344 = torch.ops.aten.view.default(view_1338, [16384, 512]); view_1338 = None + mm_129 = torch.ops.aten.mm.default(view_1344, permute_205); view_1344 = permute_205 = None + view_1345 = torch.ops.aten.view.default(mm_129, [2, 8192, 4096]); mm_129 = None + split_82 = torch.ops.aten.split.Tensor(view_1345, 1024, 1); view_1345 = None + getitem_827 = split_82[0] + getitem_828 = split_82[1] + getitem_829 = split_82[2] + getitem_830 = split_82[3] + getitem_831 = split_82[4] + getitem_832 = split_82[5] + getitem_833 = split_82[6] + getitem_834 = split_82[7]; split_82 = None + cat_74 = torch.ops.aten.cat.default([getitem_827, getitem_828, getitem_829, getitem_830, getitem_831, getitem_832, getitem_833, getitem_834]); getitem_827 = getitem_828 = getitem_829 = getitem_830 = getitem_831 = getitem_832 = getitem_833 = getitem_834 = None + reduce_scatter_tensor_37 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_74, 'sum', 8, '1'); cat_74 = None + wait_tensor_242 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_37) + add_73 = torch.ops.aten.add.Tensor(add_71, wait_tensor_242); wait_tensor_242 = None + convert_element_type_614 = torch.ops.prims.convert_element_type.default(primals_171, torch.bfloat16) + all_gather_into_tensor_205 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_614, 8, '0'); convert_element_type_614 = None + wait_tensor_243 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_205); all_gather_into_tensor_205 = None + convert_element_type_615 = torch.ops.prims.convert_element_type.default(add_73, torch.float32) + pow_38 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_615, 2) + mean_37 = torch.ops.aten.mean.dim(pow_38, [2], True); pow_38 = None + add_74 = torch.ops.aten.add.Scalar(mean_37, 1e-05); mean_37 = None + rsqrt_37 = torch.ops.aten.rsqrt.default(add_74); add_74 = None + mul_148 = torch.ops.aten.mul.Tensor(convert_element_type_615, rsqrt_37); convert_element_type_615 = rsqrt_37 = None + mul_149 = torch.ops.aten.mul.Tensor(mul_148, wait_tensor_243); mul_148 = wait_tensor_243 = None + convert_element_type_616 = torch.ops.prims.convert_element_type.default(mul_149, torch.bfloat16); mul_149 = None + all_gather_into_tensor_206 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_616, 8, '1'); convert_element_type_616 = None + wait_tensor_244 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_206); all_gather_into_tensor_206 = None + split_83 = torch.ops.aten.split.Tensor(wait_tensor_244, 2); wait_tensor_244 = None + getitem_835 = split_83[0] + getitem_836 = split_83[1] + getitem_837 = split_83[2] + getitem_838 = split_83[3] + getitem_839 = split_83[4] + getitem_840 = split_83[5] + getitem_841 = split_83[6] + getitem_842 = split_83[7]; split_83 = None + cat_75 = torch.ops.aten.cat.default([getitem_835, getitem_836, getitem_837, getitem_838, getitem_839, getitem_840, getitem_841, getitem_842], 1); getitem_835 = getitem_836 = getitem_837 = getitem_838 = getitem_839 = getitem_840 = getitem_841 = getitem_842 = None + convert_element_type_617 = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16) + all_gather_into_tensor_207 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_617, 8, '0'); convert_element_type_617 = None + wait_tensor_245 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_207); all_gather_into_tensor_207 = None + permute_206 = torch.ops.aten.permute.default(wait_tensor_245, [1, 0]); wait_tensor_245 = None + view_1356 = torch.ops.aten.view.default(cat_75, [16384, 4096]); cat_75 = None + mm_130 = torch.ops.aten.mm.default(view_1356, permute_206); permute_206 = None + view_1357 = torch.ops.aten.view.default(mm_130, [2, 8192, 1792]) + convert_element_type_620 = torch.ops.prims.convert_element_type.default(view_1357, torch.float32); view_1357 = None + sigmoid_18 = torch.ops.aten.sigmoid.default(convert_element_type_620) + mul_150 = torch.ops.aten.mul.Tensor(convert_element_type_620, sigmoid_18); convert_element_type_620 = sigmoid_18 = None + convert_element_type_621 = torch.ops.prims.convert_element_type.default(mul_150, torch.bfloat16); mul_150 = None + convert_element_type_622 = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16) + all_gather_into_tensor_208 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_622, 8, '0'); convert_element_type_622 = None + wait_tensor_246 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_208); all_gather_into_tensor_208 = None + permute_207 = torch.ops.aten.permute.default(wait_tensor_246, [1, 0]); wait_tensor_246 = None + mm_131 = torch.ops.aten.mm.default(view_1356, permute_207); view_1356 = permute_207 = None + view_1364 = torch.ops.aten.view.default(mm_131, [2, 8192, 1792]); mm_131 = None + mul_151 = torch.ops.aten.mul.Tensor(convert_element_type_621, view_1364); convert_element_type_621 = view_1364 = None + convert_element_type_625 = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16) + all_gather_into_tensor_209 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_625, 8, '0'); convert_element_type_625 = None + wait_tensor_247 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_209); all_gather_into_tensor_209 = None + permute_208 = torch.ops.aten.permute.default(wait_tensor_247, [1, 0]); wait_tensor_247 = None + view_1371 = torch.ops.aten.view.default(mul_151, [16384, 1792]); mul_151 = None + mm_132 = torch.ops.aten.mm.default(view_1371, permute_208); view_1371 = permute_208 = None + view_1372 = torch.ops.aten.view.default(mm_132, [2, 8192, 4096]); mm_132 = None + split_84 = torch.ops.aten.split.Tensor(view_1372, 1024, 1); view_1372 = None + getitem_843 = split_84[0] + getitem_844 = split_84[1] + getitem_845 = split_84[2] + getitem_846 = split_84[3] + getitem_847 = split_84[4] + getitem_848 = split_84[5] + getitem_849 = split_84[6] + getitem_850 = split_84[7]; split_84 = None + cat_76 = torch.ops.aten.cat.default([getitem_843, getitem_844, getitem_845, getitem_846, getitem_847, getitem_848, getitem_849, getitem_850]); getitem_843 = getitem_844 = getitem_845 = getitem_846 = getitem_847 = getitem_848 = getitem_849 = getitem_850 = None + reduce_scatter_tensor_38 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_76, 'sum', 8, '1'); cat_76 = None + wait_tensor_248 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_38); reduce_scatter_tensor_38 = None + add_75 = torch.ops.aten.add.Tensor(add_73, wait_tensor_248); add_73 = wait_tensor_248 = None + convert_element_type_628 = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16) + all_gather_into_tensor_210 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_628, 8, '0'); convert_element_type_628 = None + wait_tensor_249 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_210); all_gather_into_tensor_210 = None + convert_element_type_629 = torch.ops.prims.convert_element_type.default(add_75, torch.float32) + pow_39 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_629, 2) + mean_38 = torch.ops.aten.mean.dim(pow_39, [2], True); pow_39 = None + add_76 = torch.ops.aten.add.Scalar(mean_38, 1e-05); mean_38 = None + rsqrt_38 = torch.ops.aten.rsqrt.default(add_76); add_76 = None + mul_152 = torch.ops.aten.mul.Tensor(convert_element_type_629, rsqrt_38); convert_element_type_629 = rsqrt_38 = None + mul_153 = torch.ops.aten.mul.Tensor(mul_152, wait_tensor_249); mul_152 = wait_tensor_249 = None + convert_element_type_630 = torch.ops.prims.convert_element_type.default(mul_153, torch.bfloat16); mul_153 = None + all_gather_into_tensor_211 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_630, 8, '1'); convert_element_type_630 = None + wait_tensor_250 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_211); all_gather_into_tensor_211 = None + split_85 = torch.ops.aten.split.Tensor(wait_tensor_250, 2); wait_tensor_250 = None + getitem_851 = split_85[0] + getitem_852 = split_85[1] + getitem_853 = split_85[2] + getitem_854 = split_85[3] + getitem_855 = split_85[4] + getitem_856 = split_85[5] + getitem_857 = split_85[6] + getitem_858 = split_85[7]; split_85 = None + cat_77 = torch.ops.aten.cat.default([getitem_851, getitem_852, getitem_853, getitem_854, getitem_855, getitem_856, getitem_857, getitem_858], 1); getitem_851 = getitem_852 = getitem_853 = getitem_854 = getitem_855 = getitem_856 = getitem_857 = getitem_858 = None + convert_element_type_631 = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16) + all_gather_into_tensor_212 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_631, 8, '0'); convert_element_type_631 = None + wait_tensor_251 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_212); all_gather_into_tensor_212 = None + permute_209 = torch.ops.aten.permute.default(wait_tensor_251, [1, 0]); wait_tensor_251 = None + view_1383 = torch.ops.aten.view.default(cat_77, [16384, 4096]); cat_77 = None + mm_133 = torch.ops.aten.mm.default(view_1383, permute_209); permute_209 = None + view_1384 = torch.ops.aten.view.default(mm_133, [2, 8192, 512]) + convert_element_type_634 = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16) + all_gather_into_tensor_213 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_634, 8, '0'); convert_element_type_634 = None + wait_tensor_252 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_213); all_gather_into_tensor_213 = None + permute_210 = torch.ops.aten.permute.default(wait_tensor_252, [1, 0]); wait_tensor_252 = None + mm_134 = torch.ops.aten.mm.default(view_1383, permute_210); permute_210 = None + view_1391 = torch.ops.aten.view.default(mm_134, [2, 8192, 128]); mm_134 = None + convert_element_type_637 = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16) + all_gather_into_tensor_214 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_637, 8, '0'); convert_element_type_637 = None + wait_tensor_253 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_214); all_gather_into_tensor_214 = None + permute_211 = torch.ops.aten.permute.default(wait_tensor_253, [1, 0]); wait_tensor_253 = None + mm_135 = torch.ops.aten.mm.default(view_1383, permute_211); view_1383 = permute_211 = None + view_1398 = torch.ops.aten.view.default(mm_135, [2, 8192, 128]) + view_1400 = torch.ops.aten.view.default(view_1384, [2, 8192, -1, 128]); view_1384 = None + view_1401 = torch.ops.aten.view.default(view_1391, [2, 8192, -1, 128]); view_1391 = None + view_1402 = torch.ops.aten.view.default(view_1398, [2, 8192, -1, 128]); view_1398 = None + convert_element_type_640 = torch.ops.prims.convert_element_type.default(view_1400, torch.float32); view_1400 = None + view_1403 = torch.ops.aten.view.default(convert_element_type_640, [2, 8192, 4, -1, 2]); convert_element_type_640 = None + view_as_complex_38 = torch.ops.aten.view_as_complex.default(view_1403); view_1403 = None + convert_element_type_641 = torch.ops.prims.convert_element_type.default(view_1401, torch.float32); view_1401 = None + view_1404 = torch.ops.aten.view.default(convert_element_type_641, [2, 8192, 1, -1, 2]); convert_element_type_641 = None + view_as_complex_39 = torch.ops.aten.view_as_complex.default(view_1404); view_1404 = None + mul_154 = torch.ops.aten.mul.Tensor(view_as_complex_38, view_37); view_as_complex_38 = None + view_as_real_38 = torch.ops.aten.view_as_real.default(mul_154); mul_154 = None + view_1406 = torch.ops.aten.view.default(view_as_real_38, [2, 8192, 4, 128]); view_as_real_38 = None + mul_155 = torch.ops.aten.mul.Tensor(view_as_complex_39, view_37); view_as_complex_39 = None + view_as_real_39 = torch.ops.aten.view_as_real.default(mul_155); mul_155 = None + view_1407 = torch.ops.aten.view.default(view_as_real_39, [2, 8192, 1, 128]); view_as_real_39 = None + convert_element_type_642 = torch.ops.prims.convert_element_type.default(view_1406, torch.bfloat16); view_1406 = None + convert_element_type_643 = torch.ops.prims.convert_element_type.default(view_1407, torch.bfloat16); view_1407 = None + unsqueeze_38 = torch.ops.aten.unsqueeze.default(convert_element_type_643, 3); convert_element_type_643 = None + expand_38 = torch.ops.aten.expand.default(unsqueeze_38, [2, 8192, 1, 4, 128]); unsqueeze_38 = None + view_1408 = torch.ops.aten.view.default(expand_38, [2, 8192, 4, 128]); expand_38 = None + unsqueeze_39 = torch.ops.aten.unsqueeze.default(view_1402, 3); view_1402 = None + expand_39 = torch.ops.aten.expand.default(unsqueeze_39, [2, 8192, 1, 4, 128]); unsqueeze_39 = None + view_1409 = torch.ops.aten.view.default(expand_39, [2, 8192, 4, 128]); expand_39 = None + permute_212 = torch.ops.aten.permute.default(convert_element_type_642, [0, 2, 1, 3]); convert_element_type_642 = None + permute_213 = torch.ops.aten.permute.default(view_1408, [0, 2, 1, 3]); view_1408 = None + permute_214 = torch.ops.aten.permute.default(view_1409, [0, 2, 1, 3]); view_1409 = None + _scaled_dot_product_cudnn_attention_19 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_212, permute_213, permute_214, None, True, 0.0, True); permute_212 = permute_213 = permute_214 = None + getitem_859 = _scaled_dot_product_cudnn_attention_19[0] + getitem_860 = _scaled_dot_product_cudnn_attention_19[1] + getitem_865 = _scaled_dot_product_cudnn_attention_19[6] + getitem_866 = _scaled_dot_product_cudnn_attention_19[7]; _scaled_dot_product_cudnn_attention_19 = None + permute_215 = torch.ops.aten.permute.default(getitem_859, [0, 2, 1, 3]) + view_1410 = torch.ops.aten.view.default(permute_215, [2, 8192, -1]); permute_215 = None + convert_element_type_644 = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16) + all_gather_into_tensor_215 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_644, 8, '0'); convert_element_type_644 = None + wait_tensor_254 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_215); all_gather_into_tensor_215 = None + permute_216 = torch.ops.aten.permute.default(wait_tensor_254, [1, 0]); wait_tensor_254 = None + view_1416 = torch.ops.aten.view.default(view_1410, [16384, 512]); view_1410 = None + mm_136 = torch.ops.aten.mm.default(view_1416, permute_216); view_1416 = permute_216 = None + view_1417 = torch.ops.aten.view.default(mm_136, [2, 8192, 4096]); mm_136 = None + split_86 = torch.ops.aten.split.Tensor(view_1417, 1024, 1); view_1417 = None + getitem_868 = split_86[0] + getitem_869 = split_86[1] + getitem_870 = split_86[2] + getitem_871 = split_86[3] + getitem_872 = split_86[4] + getitem_873 = split_86[5] + getitem_874 = split_86[6] + getitem_875 = split_86[7]; split_86 = None + cat_78 = torch.ops.aten.cat.default([getitem_868, getitem_869, getitem_870, getitem_871, getitem_872, getitem_873, getitem_874, getitem_875]); getitem_868 = getitem_869 = getitem_870 = getitem_871 = getitem_872 = getitem_873 = getitem_874 = getitem_875 = None + reduce_scatter_tensor_39 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_78, 'sum', 8, '1'); cat_78 = None + wait_tensor_255 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_39) + add_77 = torch.ops.aten.add.Tensor(add_75, wait_tensor_255); wait_tensor_255 = None + convert_element_type_647 = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16) + all_gather_into_tensor_216 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_647, 8, '0'); convert_element_type_647 = None + wait_tensor_256 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_216); all_gather_into_tensor_216 = None + convert_element_type_648 = torch.ops.prims.convert_element_type.default(add_77, torch.float32) + pow_40 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_648, 2) + mean_39 = torch.ops.aten.mean.dim(pow_40, [2], True); pow_40 = None + add_78 = torch.ops.aten.add.Scalar(mean_39, 1e-05); mean_39 = None + rsqrt_39 = torch.ops.aten.rsqrt.default(add_78); add_78 = None + mul_156 = torch.ops.aten.mul.Tensor(convert_element_type_648, rsqrt_39); convert_element_type_648 = rsqrt_39 = None + mul_157 = torch.ops.aten.mul.Tensor(mul_156, wait_tensor_256); mul_156 = wait_tensor_256 = None + convert_element_type_649 = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16); mul_157 = None + all_gather_into_tensor_217 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_649, 8, '1'); convert_element_type_649 = None + wait_tensor_257 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_217); all_gather_into_tensor_217 = None + split_87 = torch.ops.aten.split.Tensor(wait_tensor_257, 2); wait_tensor_257 = None + getitem_876 = split_87[0] + getitem_877 = split_87[1] + getitem_878 = split_87[2] + getitem_879 = split_87[3] + getitem_880 = split_87[4] + getitem_881 = split_87[5] + getitem_882 = split_87[6] + getitem_883 = split_87[7]; split_87 = None + cat_79 = torch.ops.aten.cat.default([getitem_876, getitem_877, getitem_878, getitem_879, getitem_880, getitem_881, getitem_882, getitem_883], 1); getitem_876 = getitem_877 = getitem_878 = getitem_879 = getitem_880 = getitem_881 = getitem_882 = getitem_883 = None + convert_element_type_650 = torch.ops.prims.convert_element_type.default(primals_181, torch.bfloat16) + all_gather_into_tensor_218 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_650, 8, '0'); convert_element_type_650 = None + wait_tensor_258 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_218); all_gather_into_tensor_218 = None + permute_217 = torch.ops.aten.permute.default(wait_tensor_258, [1, 0]); wait_tensor_258 = None + view_1428 = torch.ops.aten.view.default(cat_79, [16384, 4096]); cat_79 = None + mm_137 = torch.ops.aten.mm.default(view_1428, permute_217); permute_217 = None + view_1429 = torch.ops.aten.view.default(mm_137, [2, 8192, 1792]) + convert_element_type_653 = torch.ops.prims.convert_element_type.default(view_1429, torch.float32); view_1429 = None + sigmoid_19 = torch.ops.aten.sigmoid.default(convert_element_type_653) + mul_158 = torch.ops.aten.mul.Tensor(convert_element_type_653, sigmoid_19); convert_element_type_653 = sigmoid_19 = None + convert_element_type_654 = torch.ops.prims.convert_element_type.default(mul_158, torch.bfloat16); mul_158 = None + convert_element_type_655 = torch.ops.prims.convert_element_type.default(primals_182, torch.bfloat16) + all_gather_into_tensor_219 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_655, 8, '0'); convert_element_type_655 = None + wait_tensor_259 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_219); all_gather_into_tensor_219 = None + permute_218 = torch.ops.aten.permute.default(wait_tensor_259, [1, 0]); wait_tensor_259 = None + mm_138 = torch.ops.aten.mm.default(view_1428, permute_218); view_1428 = permute_218 = None + view_1436 = torch.ops.aten.view.default(mm_138, [2, 8192, 1792]); mm_138 = None + mul_159 = torch.ops.aten.mul.Tensor(convert_element_type_654, view_1436); convert_element_type_654 = view_1436 = None + convert_element_type_658 = torch.ops.prims.convert_element_type.default(primals_183, torch.bfloat16) + all_gather_into_tensor_220 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_658, 8, '0'); convert_element_type_658 = None + wait_tensor_260 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_220); all_gather_into_tensor_220 = None + permute_219 = torch.ops.aten.permute.default(wait_tensor_260, [1, 0]); wait_tensor_260 = None + view_1443 = torch.ops.aten.view.default(mul_159, [16384, 1792]); mul_159 = None + mm_139 = torch.ops.aten.mm.default(view_1443, permute_219); view_1443 = permute_219 = None + view_1444 = torch.ops.aten.view.default(mm_139, [2, 8192, 4096]); mm_139 = None + split_88 = torch.ops.aten.split.Tensor(view_1444, 1024, 1); view_1444 = None + getitem_884 = split_88[0] + getitem_885 = split_88[1] + getitem_886 = split_88[2] + getitem_887 = split_88[3] + getitem_888 = split_88[4] + getitem_889 = split_88[5] + getitem_890 = split_88[6] + getitem_891 = split_88[7]; split_88 = None + cat_80 = torch.ops.aten.cat.default([getitem_884, getitem_885, getitem_886, getitem_887, getitem_888, getitem_889, getitem_890, getitem_891]); getitem_884 = getitem_885 = getitem_886 = getitem_887 = getitem_888 = getitem_889 = getitem_890 = getitem_891 = None + reduce_scatter_tensor_40 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_80, 'sum', 8, '1'); cat_80 = None + wait_tensor_261 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_40); reduce_scatter_tensor_40 = None + add_79 = torch.ops.aten.add.Tensor(add_77, wait_tensor_261); add_77 = wait_tensor_261 = None + convert_element_type_661 = torch.ops.prims.convert_element_type.default(primals_184, torch.bfloat16) + all_gather_into_tensor_221 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_661, 8, '0'); convert_element_type_661 = None + wait_tensor_262 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_221); all_gather_into_tensor_221 = None + convert_element_type_662 = torch.ops.prims.convert_element_type.default(add_79, torch.float32) + pow_41 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_662, 2) + mean_40 = torch.ops.aten.mean.dim(pow_41, [2], True); pow_41 = None + add_80 = torch.ops.aten.add.Scalar(mean_40, 1e-05); mean_40 = None + rsqrt_40 = torch.ops.aten.rsqrt.default(add_80); add_80 = None + mul_160 = torch.ops.aten.mul.Tensor(convert_element_type_662, rsqrt_40); convert_element_type_662 = rsqrt_40 = None + mul_161 = torch.ops.aten.mul.Tensor(mul_160, wait_tensor_262); mul_160 = wait_tensor_262 = None + convert_element_type_663 = torch.ops.prims.convert_element_type.default(mul_161, torch.bfloat16); mul_161 = None + all_gather_into_tensor_222 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_663, 8, '1'); convert_element_type_663 = None + wait_tensor_263 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_222); all_gather_into_tensor_222 = None + split_89 = torch.ops.aten.split.Tensor(wait_tensor_263, 2); wait_tensor_263 = None + getitem_892 = split_89[0] + getitem_893 = split_89[1] + getitem_894 = split_89[2] + getitem_895 = split_89[3] + getitem_896 = split_89[4] + getitem_897 = split_89[5] + getitem_898 = split_89[6] + getitem_899 = split_89[7]; split_89 = None + cat_81 = torch.ops.aten.cat.default([getitem_892, getitem_893, getitem_894, getitem_895, getitem_896, getitem_897, getitem_898, getitem_899], 1); getitem_892 = getitem_893 = getitem_894 = getitem_895 = getitem_896 = getitem_897 = getitem_898 = getitem_899 = None + convert_element_type_664 = torch.ops.prims.convert_element_type.default(primals_185, torch.bfloat16) + all_gather_into_tensor_223 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_664, 8, '0'); convert_element_type_664 = None + wait_tensor_264 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_223); all_gather_into_tensor_223 = None + permute_220 = torch.ops.aten.permute.default(wait_tensor_264, [1, 0]); wait_tensor_264 = None + view_1455 = torch.ops.aten.view.default(cat_81, [16384, 4096]); cat_81 = None + mm_140 = torch.ops.aten.mm.default(view_1455, permute_220); permute_220 = None + view_1456 = torch.ops.aten.view.default(mm_140, [2, 8192, 512]) + convert_element_type_667 = torch.ops.prims.convert_element_type.default(primals_186, torch.bfloat16) + all_gather_into_tensor_224 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_667, 8, '0'); convert_element_type_667 = None + wait_tensor_265 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_224); all_gather_into_tensor_224 = None + permute_221 = torch.ops.aten.permute.default(wait_tensor_265, [1, 0]); wait_tensor_265 = None + mm_141 = torch.ops.aten.mm.default(view_1455, permute_221); permute_221 = None + view_1463 = torch.ops.aten.view.default(mm_141, [2, 8192, 128]); mm_141 = None + convert_element_type_670 = torch.ops.prims.convert_element_type.default(primals_187, torch.bfloat16) + all_gather_into_tensor_225 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_670, 8, '0'); convert_element_type_670 = None + wait_tensor_266 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_225); all_gather_into_tensor_225 = None + permute_222 = torch.ops.aten.permute.default(wait_tensor_266, [1, 0]); wait_tensor_266 = None + mm_142 = torch.ops.aten.mm.default(view_1455, permute_222); view_1455 = permute_222 = None + view_1470 = torch.ops.aten.view.default(mm_142, [2, 8192, 128]) + view_1472 = torch.ops.aten.view.default(view_1456, [2, 8192, -1, 128]); view_1456 = None + view_1473 = torch.ops.aten.view.default(view_1463, [2, 8192, -1, 128]); view_1463 = None + view_1474 = torch.ops.aten.view.default(view_1470, [2, 8192, -1, 128]); view_1470 = None + convert_element_type_673 = torch.ops.prims.convert_element_type.default(view_1472, torch.float32); view_1472 = None + view_1475 = torch.ops.aten.view.default(convert_element_type_673, [2, 8192, 4, -1, 2]); convert_element_type_673 = None + view_as_complex_40 = torch.ops.aten.view_as_complex.default(view_1475); view_1475 = None + convert_element_type_674 = torch.ops.prims.convert_element_type.default(view_1473, torch.float32); view_1473 = None + view_1476 = torch.ops.aten.view.default(convert_element_type_674, [2, 8192, 1, -1, 2]); convert_element_type_674 = None + view_as_complex_41 = torch.ops.aten.view_as_complex.default(view_1476); view_1476 = None + mul_162 = torch.ops.aten.mul.Tensor(view_as_complex_40, view_37); view_as_complex_40 = None + view_as_real_40 = torch.ops.aten.view_as_real.default(mul_162); mul_162 = None + view_1478 = torch.ops.aten.view.default(view_as_real_40, [2, 8192, 4, 128]); view_as_real_40 = None + mul_163 = torch.ops.aten.mul.Tensor(view_as_complex_41, view_37); view_as_complex_41 = None + view_as_real_41 = torch.ops.aten.view_as_real.default(mul_163); mul_163 = None + view_1479 = torch.ops.aten.view.default(view_as_real_41, [2, 8192, 1, 128]); view_as_real_41 = None + convert_element_type_675 = torch.ops.prims.convert_element_type.default(view_1478, torch.bfloat16); view_1478 = None + convert_element_type_676 = torch.ops.prims.convert_element_type.default(view_1479, torch.bfloat16); view_1479 = None + unsqueeze_40 = torch.ops.aten.unsqueeze.default(convert_element_type_676, 3); convert_element_type_676 = None + expand_40 = torch.ops.aten.expand.default(unsqueeze_40, [2, 8192, 1, 4, 128]); unsqueeze_40 = None + view_1480 = torch.ops.aten.view.default(expand_40, [2, 8192, 4, 128]); expand_40 = None + unsqueeze_41 = torch.ops.aten.unsqueeze.default(view_1474, 3); view_1474 = None + expand_41 = torch.ops.aten.expand.default(unsqueeze_41, [2, 8192, 1, 4, 128]); unsqueeze_41 = None + view_1481 = torch.ops.aten.view.default(expand_41, [2, 8192, 4, 128]); expand_41 = None + permute_223 = torch.ops.aten.permute.default(convert_element_type_675, [0, 2, 1, 3]); convert_element_type_675 = None + permute_224 = torch.ops.aten.permute.default(view_1480, [0, 2, 1, 3]); view_1480 = None + permute_225 = torch.ops.aten.permute.default(view_1481, [0, 2, 1, 3]); view_1481 = None + _scaled_dot_product_cudnn_attention_20 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_223, permute_224, permute_225, None, True, 0.0, True); permute_223 = permute_224 = permute_225 = None + getitem_900 = _scaled_dot_product_cudnn_attention_20[0] + getitem_901 = _scaled_dot_product_cudnn_attention_20[1] + getitem_906 = _scaled_dot_product_cudnn_attention_20[6] + getitem_907 = _scaled_dot_product_cudnn_attention_20[7]; _scaled_dot_product_cudnn_attention_20 = None + permute_226 = torch.ops.aten.permute.default(getitem_900, [0, 2, 1, 3]) + view_1482 = torch.ops.aten.view.default(permute_226, [2, 8192, -1]); permute_226 = None + convert_element_type_677 = torch.ops.prims.convert_element_type.default(primals_188, torch.bfloat16) + all_gather_into_tensor_226 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_677, 8, '0'); convert_element_type_677 = None + wait_tensor_267 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_226); all_gather_into_tensor_226 = None + permute_227 = torch.ops.aten.permute.default(wait_tensor_267, [1, 0]); wait_tensor_267 = None + view_1488 = torch.ops.aten.view.default(view_1482, [16384, 512]); view_1482 = None + mm_143 = torch.ops.aten.mm.default(view_1488, permute_227); view_1488 = permute_227 = None + view_1489 = torch.ops.aten.view.default(mm_143, [2, 8192, 4096]); mm_143 = None + split_90 = torch.ops.aten.split.Tensor(view_1489, 1024, 1); view_1489 = None + getitem_909 = split_90[0] + getitem_910 = split_90[1] + getitem_911 = split_90[2] + getitem_912 = split_90[3] + getitem_913 = split_90[4] + getitem_914 = split_90[5] + getitem_915 = split_90[6] + getitem_916 = split_90[7]; split_90 = None + cat_82 = torch.ops.aten.cat.default([getitem_909, getitem_910, getitem_911, getitem_912, getitem_913, getitem_914, getitem_915, getitem_916]); getitem_909 = getitem_910 = getitem_911 = getitem_912 = getitem_913 = getitem_914 = getitem_915 = getitem_916 = None + reduce_scatter_tensor_41 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_82, 'sum', 8, '1'); cat_82 = None + wait_tensor_268 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_41) + add_81 = torch.ops.aten.add.Tensor(add_79, wait_tensor_268); wait_tensor_268 = None + convert_element_type_680 = torch.ops.prims.convert_element_type.default(primals_189, torch.bfloat16) + all_gather_into_tensor_227 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_680, 8, '0'); convert_element_type_680 = None + wait_tensor_269 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_227); all_gather_into_tensor_227 = None + convert_element_type_681 = torch.ops.prims.convert_element_type.default(add_81, torch.float32) + pow_42 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_681, 2) + mean_41 = torch.ops.aten.mean.dim(pow_42, [2], True); pow_42 = None + add_82 = torch.ops.aten.add.Scalar(mean_41, 1e-05); mean_41 = None + rsqrt_41 = torch.ops.aten.rsqrt.default(add_82); add_82 = None + mul_164 = torch.ops.aten.mul.Tensor(convert_element_type_681, rsqrt_41); convert_element_type_681 = rsqrt_41 = None + mul_165 = torch.ops.aten.mul.Tensor(mul_164, wait_tensor_269); mul_164 = wait_tensor_269 = None + convert_element_type_682 = torch.ops.prims.convert_element_type.default(mul_165, torch.bfloat16); mul_165 = None + all_gather_into_tensor_228 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_682, 8, '1'); convert_element_type_682 = None + wait_tensor_270 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_228); all_gather_into_tensor_228 = None + split_91 = torch.ops.aten.split.Tensor(wait_tensor_270, 2); wait_tensor_270 = None + getitem_917 = split_91[0] + getitem_918 = split_91[1] + getitem_919 = split_91[2] + getitem_920 = split_91[3] + getitem_921 = split_91[4] + getitem_922 = split_91[5] + getitem_923 = split_91[6] + getitem_924 = split_91[7]; split_91 = None + cat_83 = torch.ops.aten.cat.default([getitem_917, getitem_918, getitem_919, getitem_920, getitem_921, getitem_922, getitem_923, getitem_924], 1); getitem_917 = getitem_918 = getitem_919 = getitem_920 = getitem_921 = getitem_922 = getitem_923 = getitem_924 = None + convert_element_type_683 = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16) + all_gather_into_tensor_229 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_683, 8, '0'); convert_element_type_683 = None + wait_tensor_271 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_229); all_gather_into_tensor_229 = None + permute_228 = torch.ops.aten.permute.default(wait_tensor_271, [1, 0]); wait_tensor_271 = None + view_1500 = torch.ops.aten.view.default(cat_83, [16384, 4096]); cat_83 = None + mm_144 = torch.ops.aten.mm.default(view_1500, permute_228); permute_228 = None + view_1501 = torch.ops.aten.view.default(mm_144, [2, 8192, 1792]) + convert_element_type_686 = torch.ops.prims.convert_element_type.default(view_1501, torch.float32); view_1501 = None + sigmoid_20 = torch.ops.aten.sigmoid.default(convert_element_type_686) + mul_166 = torch.ops.aten.mul.Tensor(convert_element_type_686, sigmoid_20); convert_element_type_686 = sigmoid_20 = None + convert_element_type_687 = torch.ops.prims.convert_element_type.default(mul_166, torch.bfloat16); mul_166 = None + convert_element_type_688 = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16) + all_gather_into_tensor_230 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_688, 8, '0'); convert_element_type_688 = None + wait_tensor_272 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_230); all_gather_into_tensor_230 = None + permute_229 = torch.ops.aten.permute.default(wait_tensor_272, [1, 0]); wait_tensor_272 = None + mm_145 = torch.ops.aten.mm.default(view_1500, permute_229); view_1500 = permute_229 = None + view_1508 = torch.ops.aten.view.default(mm_145, [2, 8192, 1792]); mm_145 = None + mul_167 = torch.ops.aten.mul.Tensor(convert_element_type_687, view_1508); convert_element_type_687 = view_1508 = None + convert_element_type_691 = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16) + all_gather_into_tensor_231 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_691, 8, '0'); convert_element_type_691 = None + wait_tensor_273 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_231); all_gather_into_tensor_231 = None + permute_230 = torch.ops.aten.permute.default(wait_tensor_273, [1, 0]); wait_tensor_273 = None + view_1515 = torch.ops.aten.view.default(mul_167, [16384, 1792]); mul_167 = None + mm_146 = torch.ops.aten.mm.default(view_1515, permute_230); view_1515 = permute_230 = None + view_1516 = torch.ops.aten.view.default(mm_146, [2, 8192, 4096]); mm_146 = None + split_92 = torch.ops.aten.split.Tensor(view_1516, 1024, 1); view_1516 = None + getitem_925 = split_92[0] + getitem_926 = split_92[1] + getitem_927 = split_92[2] + getitem_928 = split_92[3] + getitem_929 = split_92[4] + getitem_930 = split_92[5] + getitem_931 = split_92[6] + getitem_932 = split_92[7]; split_92 = None + cat_84 = torch.ops.aten.cat.default([getitem_925, getitem_926, getitem_927, getitem_928, getitem_929, getitem_930, getitem_931, getitem_932]); getitem_925 = getitem_926 = getitem_927 = getitem_928 = getitem_929 = getitem_930 = getitem_931 = getitem_932 = None + reduce_scatter_tensor_42 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_84, 'sum', 8, '1'); cat_84 = None + wait_tensor_274 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_42); reduce_scatter_tensor_42 = None + add_83 = torch.ops.aten.add.Tensor(add_81, wait_tensor_274); add_81 = wait_tensor_274 = None + convert_element_type_694 = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16) + all_gather_into_tensor_232 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_694, 8, '0'); convert_element_type_694 = None + wait_tensor_275 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_232); all_gather_into_tensor_232 = None + convert_element_type_695 = torch.ops.prims.convert_element_type.default(add_83, torch.float32) + pow_43 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_695, 2) + mean_42 = torch.ops.aten.mean.dim(pow_43, [2], True); pow_43 = None + add_84 = torch.ops.aten.add.Scalar(mean_42, 1e-05); mean_42 = None + rsqrt_42 = torch.ops.aten.rsqrt.default(add_84); add_84 = None + mul_168 = torch.ops.aten.mul.Tensor(convert_element_type_695, rsqrt_42); convert_element_type_695 = rsqrt_42 = None + mul_169 = torch.ops.aten.mul.Tensor(mul_168, wait_tensor_275); mul_168 = wait_tensor_275 = None + convert_element_type_696 = torch.ops.prims.convert_element_type.default(mul_169, torch.bfloat16); mul_169 = None + all_gather_into_tensor_233 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_696, 8, '1'); convert_element_type_696 = None + wait_tensor_276 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_233); all_gather_into_tensor_233 = None + split_93 = torch.ops.aten.split.Tensor(wait_tensor_276, 2); wait_tensor_276 = None + getitem_933 = split_93[0] + getitem_934 = split_93[1] + getitem_935 = split_93[2] + getitem_936 = split_93[3] + getitem_937 = split_93[4] + getitem_938 = split_93[5] + getitem_939 = split_93[6] + getitem_940 = split_93[7]; split_93 = None + cat_85 = torch.ops.aten.cat.default([getitem_933, getitem_934, getitem_935, getitem_936, getitem_937, getitem_938, getitem_939, getitem_940], 1); getitem_933 = getitem_934 = getitem_935 = getitem_936 = getitem_937 = getitem_938 = getitem_939 = getitem_940 = None + convert_element_type_697 = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16) + all_gather_into_tensor_234 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_697, 8, '0'); convert_element_type_697 = None + wait_tensor_277 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_234); all_gather_into_tensor_234 = None + permute_231 = torch.ops.aten.permute.default(wait_tensor_277, [1, 0]); wait_tensor_277 = None + view_1527 = torch.ops.aten.view.default(cat_85, [16384, 4096]); cat_85 = None + mm_147 = torch.ops.aten.mm.default(view_1527, permute_231); permute_231 = None + view_1528 = torch.ops.aten.view.default(mm_147, [2, 8192, 512]) + convert_element_type_700 = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16) + all_gather_into_tensor_235 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_700, 8, '0'); convert_element_type_700 = None + wait_tensor_278 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_235); all_gather_into_tensor_235 = None + permute_232 = torch.ops.aten.permute.default(wait_tensor_278, [1, 0]); wait_tensor_278 = None + mm_148 = torch.ops.aten.mm.default(view_1527, permute_232); permute_232 = None + view_1535 = torch.ops.aten.view.default(mm_148, [2, 8192, 128]); mm_148 = None + convert_element_type_703 = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16) + all_gather_into_tensor_236 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_703, 8, '0'); convert_element_type_703 = None + wait_tensor_279 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_236); all_gather_into_tensor_236 = None + permute_233 = torch.ops.aten.permute.default(wait_tensor_279, [1, 0]); wait_tensor_279 = None + mm_149 = torch.ops.aten.mm.default(view_1527, permute_233); view_1527 = permute_233 = None + view_1542 = torch.ops.aten.view.default(mm_149, [2, 8192, 128]) + view_1544 = torch.ops.aten.view.default(view_1528, [2, 8192, -1, 128]); view_1528 = None + view_1545 = torch.ops.aten.view.default(view_1535, [2, 8192, -1, 128]); view_1535 = None + view_1546 = torch.ops.aten.view.default(view_1542, [2, 8192, -1, 128]); view_1542 = None + convert_element_type_706 = torch.ops.prims.convert_element_type.default(view_1544, torch.float32); view_1544 = None + view_1547 = torch.ops.aten.view.default(convert_element_type_706, [2, 8192, 4, -1, 2]); convert_element_type_706 = None + view_as_complex_42 = torch.ops.aten.view_as_complex.default(view_1547); view_1547 = None + convert_element_type_707 = torch.ops.prims.convert_element_type.default(view_1545, torch.float32); view_1545 = None + view_1548 = torch.ops.aten.view.default(convert_element_type_707, [2, 8192, 1, -1, 2]); convert_element_type_707 = None + view_as_complex_43 = torch.ops.aten.view_as_complex.default(view_1548); view_1548 = None + mul_170 = torch.ops.aten.mul.Tensor(view_as_complex_42, view_37); view_as_complex_42 = None + view_as_real_42 = torch.ops.aten.view_as_real.default(mul_170); mul_170 = None + view_1550 = torch.ops.aten.view.default(view_as_real_42, [2, 8192, 4, 128]); view_as_real_42 = None + mul_171 = torch.ops.aten.mul.Tensor(view_as_complex_43, view_37); view_as_complex_43 = None + view_as_real_43 = torch.ops.aten.view_as_real.default(mul_171); mul_171 = None + view_1551 = torch.ops.aten.view.default(view_as_real_43, [2, 8192, 1, 128]); view_as_real_43 = None + convert_element_type_708 = torch.ops.prims.convert_element_type.default(view_1550, torch.bfloat16); view_1550 = None + convert_element_type_709 = torch.ops.prims.convert_element_type.default(view_1551, torch.bfloat16); view_1551 = None + unsqueeze_42 = torch.ops.aten.unsqueeze.default(convert_element_type_709, 3); convert_element_type_709 = None + expand_42 = torch.ops.aten.expand.default(unsqueeze_42, [2, 8192, 1, 4, 128]); unsqueeze_42 = None + view_1552 = torch.ops.aten.view.default(expand_42, [2, 8192, 4, 128]); expand_42 = None + unsqueeze_43 = torch.ops.aten.unsqueeze.default(view_1546, 3); view_1546 = None + expand_43 = torch.ops.aten.expand.default(unsqueeze_43, [2, 8192, 1, 4, 128]); unsqueeze_43 = None + view_1553 = torch.ops.aten.view.default(expand_43, [2, 8192, 4, 128]); expand_43 = None + permute_234 = torch.ops.aten.permute.default(convert_element_type_708, [0, 2, 1, 3]); convert_element_type_708 = None + permute_235 = torch.ops.aten.permute.default(view_1552, [0, 2, 1, 3]); view_1552 = None + permute_236 = torch.ops.aten.permute.default(view_1553, [0, 2, 1, 3]); view_1553 = None + _scaled_dot_product_cudnn_attention_21 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_234, permute_235, permute_236, None, True, 0.0, True); permute_234 = permute_235 = permute_236 = None + getitem_941 = _scaled_dot_product_cudnn_attention_21[0] + getitem_942 = _scaled_dot_product_cudnn_attention_21[1] + getitem_947 = _scaled_dot_product_cudnn_attention_21[6] + getitem_948 = _scaled_dot_product_cudnn_attention_21[7]; _scaled_dot_product_cudnn_attention_21 = None + permute_237 = torch.ops.aten.permute.default(getitem_941, [0, 2, 1, 3]) + view_1554 = torch.ops.aten.view.default(permute_237, [2, 8192, -1]); permute_237 = None + convert_element_type_710 = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16) + all_gather_into_tensor_237 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_710, 8, '0'); convert_element_type_710 = None + wait_tensor_280 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_237); all_gather_into_tensor_237 = None + permute_238 = torch.ops.aten.permute.default(wait_tensor_280, [1, 0]); wait_tensor_280 = None + view_1560 = torch.ops.aten.view.default(view_1554, [16384, 512]); view_1554 = None + mm_150 = torch.ops.aten.mm.default(view_1560, permute_238); view_1560 = permute_238 = None + view_1561 = torch.ops.aten.view.default(mm_150, [2, 8192, 4096]); mm_150 = None + split_94 = torch.ops.aten.split.Tensor(view_1561, 1024, 1); view_1561 = None + getitem_950 = split_94[0] + getitem_951 = split_94[1] + getitem_952 = split_94[2] + getitem_953 = split_94[3] + getitem_954 = split_94[4] + getitem_955 = split_94[5] + getitem_956 = split_94[6] + getitem_957 = split_94[7]; split_94 = None + cat_86 = torch.ops.aten.cat.default([getitem_950, getitem_951, getitem_952, getitem_953, getitem_954, getitem_955, getitem_956, getitem_957]); getitem_950 = getitem_951 = getitem_952 = getitem_953 = getitem_954 = getitem_955 = getitem_956 = getitem_957 = None + reduce_scatter_tensor_43 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_86, 'sum', 8, '1'); cat_86 = None + wait_tensor_281 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_43) + add_85 = torch.ops.aten.add.Tensor(add_83, wait_tensor_281); wait_tensor_281 = None + convert_element_type_713 = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16) + all_gather_into_tensor_238 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_713, 8, '0'); convert_element_type_713 = None + wait_tensor_282 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_238); all_gather_into_tensor_238 = None + convert_element_type_714 = torch.ops.prims.convert_element_type.default(add_85, torch.float32) + pow_44 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_714, 2) + mean_43 = torch.ops.aten.mean.dim(pow_44, [2], True); pow_44 = None + add_86 = torch.ops.aten.add.Scalar(mean_43, 1e-05); mean_43 = None + rsqrt_43 = torch.ops.aten.rsqrt.default(add_86); add_86 = None + mul_172 = torch.ops.aten.mul.Tensor(convert_element_type_714, rsqrt_43); convert_element_type_714 = rsqrt_43 = None + mul_173 = torch.ops.aten.mul.Tensor(mul_172, wait_tensor_282); mul_172 = wait_tensor_282 = None + convert_element_type_715 = torch.ops.prims.convert_element_type.default(mul_173, torch.bfloat16); mul_173 = None + all_gather_into_tensor_239 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_715, 8, '1'); convert_element_type_715 = None + wait_tensor_283 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_239); all_gather_into_tensor_239 = None + split_95 = torch.ops.aten.split.Tensor(wait_tensor_283, 2); wait_tensor_283 = None + getitem_958 = split_95[0] + getitem_959 = split_95[1] + getitem_960 = split_95[2] + getitem_961 = split_95[3] + getitem_962 = split_95[4] + getitem_963 = split_95[5] + getitem_964 = split_95[6] + getitem_965 = split_95[7]; split_95 = None + cat_87 = torch.ops.aten.cat.default([getitem_958, getitem_959, getitem_960, getitem_961, getitem_962, getitem_963, getitem_964, getitem_965], 1); getitem_958 = getitem_959 = getitem_960 = getitem_961 = getitem_962 = getitem_963 = getitem_964 = getitem_965 = None + convert_element_type_716 = torch.ops.prims.convert_element_type.default(primals_199, torch.bfloat16) + all_gather_into_tensor_240 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_716, 8, '0'); convert_element_type_716 = None + wait_tensor_284 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_240); all_gather_into_tensor_240 = None + permute_239 = torch.ops.aten.permute.default(wait_tensor_284, [1, 0]); wait_tensor_284 = None + view_1572 = torch.ops.aten.view.default(cat_87, [16384, 4096]); cat_87 = None + mm_151 = torch.ops.aten.mm.default(view_1572, permute_239); permute_239 = None + view_1573 = torch.ops.aten.view.default(mm_151, [2, 8192, 1792]) + convert_element_type_719 = torch.ops.prims.convert_element_type.default(view_1573, torch.float32); view_1573 = None + sigmoid_21 = torch.ops.aten.sigmoid.default(convert_element_type_719) + mul_174 = torch.ops.aten.mul.Tensor(convert_element_type_719, sigmoid_21); convert_element_type_719 = sigmoid_21 = None + convert_element_type_720 = torch.ops.prims.convert_element_type.default(mul_174, torch.bfloat16); mul_174 = None + convert_element_type_721 = torch.ops.prims.convert_element_type.default(primals_200, torch.bfloat16) + all_gather_into_tensor_241 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_721, 8, '0'); convert_element_type_721 = None + wait_tensor_285 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_241); all_gather_into_tensor_241 = None + permute_240 = torch.ops.aten.permute.default(wait_tensor_285, [1, 0]); wait_tensor_285 = None + mm_152 = torch.ops.aten.mm.default(view_1572, permute_240); view_1572 = permute_240 = None + view_1580 = torch.ops.aten.view.default(mm_152, [2, 8192, 1792]); mm_152 = None + mul_175 = torch.ops.aten.mul.Tensor(convert_element_type_720, view_1580); convert_element_type_720 = view_1580 = None + convert_element_type_724 = torch.ops.prims.convert_element_type.default(primals_201, torch.bfloat16) + all_gather_into_tensor_242 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_724, 8, '0'); convert_element_type_724 = None + wait_tensor_286 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_242); all_gather_into_tensor_242 = None + permute_241 = torch.ops.aten.permute.default(wait_tensor_286, [1, 0]); wait_tensor_286 = None + view_1587 = torch.ops.aten.view.default(mul_175, [16384, 1792]); mul_175 = None + mm_153 = torch.ops.aten.mm.default(view_1587, permute_241); view_1587 = permute_241 = None + view_1588 = torch.ops.aten.view.default(mm_153, [2, 8192, 4096]); mm_153 = None + split_96 = torch.ops.aten.split.Tensor(view_1588, 1024, 1); view_1588 = None + getitem_966 = split_96[0] + getitem_967 = split_96[1] + getitem_968 = split_96[2] + getitem_969 = split_96[3] + getitem_970 = split_96[4] + getitem_971 = split_96[5] + getitem_972 = split_96[6] + getitem_973 = split_96[7]; split_96 = None + cat_88 = torch.ops.aten.cat.default([getitem_966, getitem_967, getitem_968, getitem_969, getitem_970, getitem_971, getitem_972, getitem_973]); getitem_966 = getitem_967 = getitem_968 = getitem_969 = getitem_970 = getitem_971 = getitem_972 = getitem_973 = None + reduce_scatter_tensor_44 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_88, 'sum', 8, '1'); cat_88 = None + wait_tensor_287 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_44); reduce_scatter_tensor_44 = None + add_87 = torch.ops.aten.add.Tensor(add_85, wait_tensor_287); add_85 = wait_tensor_287 = None + convert_element_type_727 = torch.ops.prims.convert_element_type.default(primals_202, torch.bfloat16) + all_gather_into_tensor_243 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_727, 8, '0'); convert_element_type_727 = None + wait_tensor_288 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_243); all_gather_into_tensor_243 = None + convert_element_type_728 = torch.ops.prims.convert_element_type.default(add_87, torch.float32) + pow_45 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_728, 2) + mean_44 = torch.ops.aten.mean.dim(pow_45, [2], True); pow_45 = None + add_88 = torch.ops.aten.add.Scalar(mean_44, 1e-05); mean_44 = None + rsqrt_44 = torch.ops.aten.rsqrt.default(add_88); add_88 = None + mul_176 = torch.ops.aten.mul.Tensor(convert_element_type_728, rsqrt_44); convert_element_type_728 = rsqrt_44 = None + mul_177 = torch.ops.aten.mul.Tensor(mul_176, wait_tensor_288); mul_176 = wait_tensor_288 = None + convert_element_type_729 = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None + all_gather_into_tensor_244 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_729, 8, '1'); convert_element_type_729 = None + wait_tensor_289 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_244); all_gather_into_tensor_244 = None + split_97 = torch.ops.aten.split.Tensor(wait_tensor_289, 2); wait_tensor_289 = None + getitem_974 = split_97[0] + getitem_975 = split_97[1] + getitem_976 = split_97[2] + getitem_977 = split_97[3] + getitem_978 = split_97[4] + getitem_979 = split_97[5] + getitem_980 = split_97[6] + getitem_981 = split_97[7]; split_97 = None + cat_89 = torch.ops.aten.cat.default([getitem_974, getitem_975, getitem_976, getitem_977, getitem_978, getitem_979, getitem_980, getitem_981], 1); getitem_974 = getitem_975 = getitem_976 = getitem_977 = getitem_978 = getitem_979 = getitem_980 = getitem_981 = None + convert_element_type_730 = torch.ops.prims.convert_element_type.default(primals_203, torch.bfloat16) + all_gather_into_tensor_245 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_730, 8, '0'); convert_element_type_730 = None + wait_tensor_290 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_245); all_gather_into_tensor_245 = None + permute_242 = torch.ops.aten.permute.default(wait_tensor_290, [1, 0]); wait_tensor_290 = None + view_1599 = torch.ops.aten.view.default(cat_89, [16384, 4096]); cat_89 = None + mm_154 = torch.ops.aten.mm.default(view_1599, permute_242); permute_242 = None + view_1600 = torch.ops.aten.view.default(mm_154, [2, 8192, 512]) + convert_element_type_733 = torch.ops.prims.convert_element_type.default(primals_204, torch.bfloat16) + all_gather_into_tensor_246 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_733, 8, '0'); convert_element_type_733 = None + wait_tensor_291 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_246); all_gather_into_tensor_246 = None + permute_243 = torch.ops.aten.permute.default(wait_tensor_291, [1, 0]); wait_tensor_291 = None + mm_155 = torch.ops.aten.mm.default(view_1599, permute_243); permute_243 = None + view_1607 = torch.ops.aten.view.default(mm_155, [2, 8192, 128]); mm_155 = None + convert_element_type_736 = torch.ops.prims.convert_element_type.default(primals_205, torch.bfloat16) + all_gather_into_tensor_247 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_736, 8, '0'); convert_element_type_736 = None + wait_tensor_292 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_247); all_gather_into_tensor_247 = None + permute_244 = torch.ops.aten.permute.default(wait_tensor_292, [1, 0]); wait_tensor_292 = None + mm_156 = torch.ops.aten.mm.default(view_1599, permute_244); view_1599 = permute_244 = None + view_1614 = torch.ops.aten.view.default(mm_156, [2, 8192, 128]) + view_1616 = torch.ops.aten.view.default(view_1600, [2, 8192, -1, 128]); view_1600 = None + view_1617 = torch.ops.aten.view.default(view_1607, [2, 8192, -1, 128]); view_1607 = None + view_1618 = torch.ops.aten.view.default(view_1614, [2, 8192, -1, 128]); view_1614 = None + convert_element_type_739 = torch.ops.prims.convert_element_type.default(view_1616, torch.float32); view_1616 = None + view_1619 = torch.ops.aten.view.default(convert_element_type_739, [2, 8192, 4, -1, 2]); convert_element_type_739 = None + view_as_complex_44 = torch.ops.aten.view_as_complex.default(view_1619); view_1619 = None + convert_element_type_740 = torch.ops.prims.convert_element_type.default(view_1617, torch.float32); view_1617 = None + view_1620 = torch.ops.aten.view.default(convert_element_type_740, [2, 8192, 1, -1, 2]); convert_element_type_740 = None + view_as_complex_45 = torch.ops.aten.view_as_complex.default(view_1620); view_1620 = None + mul_178 = torch.ops.aten.mul.Tensor(view_as_complex_44, view_37); view_as_complex_44 = None + view_as_real_44 = torch.ops.aten.view_as_real.default(mul_178); mul_178 = None + view_1622 = torch.ops.aten.view.default(view_as_real_44, [2, 8192, 4, 128]); view_as_real_44 = None + mul_179 = torch.ops.aten.mul.Tensor(view_as_complex_45, view_37); view_as_complex_45 = None + view_as_real_45 = torch.ops.aten.view_as_real.default(mul_179); mul_179 = None + view_1623 = torch.ops.aten.view.default(view_as_real_45, [2, 8192, 1, 128]); view_as_real_45 = None + convert_element_type_741 = torch.ops.prims.convert_element_type.default(view_1622, torch.bfloat16); view_1622 = None + convert_element_type_742 = torch.ops.prims.convert_element_type.default(view_1623, torch.bfloat16); view_1623 = None + unsqueeze_44 = torch.ops.aten.unsqueeze.default(convert_element_type_742, 3); convert_element_type_742 = None + expand_44 = torch.ops.aten.expand.default(unsqueeze_44, [2, 8192, 1, 4, 128]); unsqueeze_44 = None + view_1624 = torch.ops.aten.view.default(expand_44, [2, 8192, 4, 128]); expand_44 = None + unsqueeze_45 = torch.ops.aten.unsqueeze.default(view_1618, 3); view_1618 = None + expand_45 = torch.ops.aten.expand.default(unsqueeze_45, [2, 8192, 1, 4, 128]); unsqueeze_45 = None + view_1625 = torch.ops.aten.view.default(expand_45, [2, 8192, 4, 128]); expand_45 = None + permute_245 = torch.ops.aten.permute.default(convert_element_type_741, [0, 2, 1, 3]); convert_element_type_741 = None + permute_246 = torch.ops.aten.permute.default(view_1624, [0, 2, 1, 3]); view_1624 = None + permute_247 = torch.ops.aten.permute.default(view_1625, [0, 2, 1, 3]); view_1625 = None + _scaled_dot_product_cudnn_attention_22 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_245, permute_246, permute_247, None, True, 0.0, True); permute_245 = permute_246 = permute_247 = None + getitem_982 = _scaled_dot_product_cudnn_attention_22[0] + getitem_983 = _scaled_dot_product_cudnn_attention_22[1] + getitem_988 = _scaled_dot_product_cudnn_attention_22[6] + getitem_989 = _scaled_dot_product_cudnn_attention_22[7]; _scaled_dot_product_cudnn_attention_22 = None + permute_248 = torch.ops.aten.permute.default(getitem_982, [0, 2, 1, 3]) + view_1626 = torch.ops.aten.view.default(permute_248, [2, 8192, -1]); permute_248 = None + convert_element_type_743 = torch.ops.prims.convert_element_type.default(primals_206, torch.bfloat16) + all_gather_into_tensor_248 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_743, 8, '0'); convert_element_type_743 = None + wait_tensor_293 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_248); all_gather_into_tensor_248 = None + permute_249 = torch.ops.aten.permute.default(wait_tensor_293, [1, 0]); wait_tensor_293 = None + view_1632 = torch.ops.aten.view.default(view_1626, [16384, 512]); view_1626 = None + mm_157 = torch.ops.aten.mm.default(view_1632, permute_249); view_1632 = permute_249 = None + view_1633 = torch.ops.aten.view.default(mm_157, [2, 8192, 4096]); mm_157 = None + split_98 = torch.ops.aten.split.Tensor(view_1633, 1024, 1); view_1633 = None + getitem_991 = split_98[0] + getitem_992 = split_98[1] + getitem_993 = split_98[2] + getitem_994 = split_98[3] + getitem_995 = split_98[4] + getitem_996 = split_98[5] + getitem_997 = split_98[6] + getitem_998 = split_98[7]; split_98 = None + cat_90 = torch.ops.aten.cat.default([getitem_991, getitem_992, getitem_993, getitem_994, getitem_995, getitem_996, getitem_997, getitem_998]); getitem_991 = getitem_992 = getitem_993 = getitem_994 = getitem_995 = getitem_996 = getitem_997 = getitem_998 = None + reduce_scatter_tensor_45 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_90, 'sum', 8, '1'); cat_90 = None + wait_tensor_294 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_45) + add_89 = torch.ops.aten.add.Tensor(add_87, wait_tensor_294); wait_tensor_294 = None + convert_element_type_746 = torch.ops.prims.convert_element_type.default(primals_207, torch.bfloat16) + all_gather_into_tensor_249 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_746, 8, '0'); convert_element_type_746 = None + wait_tensor_295 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_249); all_gather_into_tensor_249 = None + convert_element_type_747 = torch.ops.prims.convert_element_type.default(add_89, torch.float32) + pow_46 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_747, 2) + mean_45 = torch.ops.aten.mean.dim(pow_46, [2], True); pow_46 = None + add_90 = torch.ops.aten.add.Scalar(mean_45, 1e-05); mean_45 = None + rsqrt_45 = torch.ops.aten.rsqrt.default(add_90); add_90 = None + mul_180 = torch.ops.aten.mul.Tensor(convert_element_type_747, rsqrt_45); convert_element_type_747 = rsqrt_45 = None + mul_181 = torch.ops.aten.mul.Tensor(mul_180, wait_tensor_295); mul_180 = wait_tensor_295 = None + convert_element_type_748 = torch.ops.prims.convert_element_type.default(mul_181, torch.bfloat16); mul_181 = None + all_gather_into_tensor_250 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_748, 8, '1'); convert_element_type_748 = None + wait_tensor_296 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_250); all_gather_into_tensor_250 = None + split_99 = torch.ops.aten.split.Tensor(wait_tensor_296, 2); wait_tensor_296 = None + getitem_999 = split_99[0] + getitem_1000 = split_99[1] + getitem_1001 = split_99[2] + getitem_1002 = split_99[3] + getitem_1003 = split_99[4] + getitem_1004 = split_99[5] + getitem_1005 = split_99[6] + getitem_1006 = split_99[7]; split_99 = None + cat_91 = torch.ops.aten.cat.default([getitem_999, getitem_1000, getitem_1001, getitem_1002, getitem_1003, getitem_1004, getitem_1005, getitem_1006], 1); getitem_999 = getitem_1000 = getitem_1001 = getitem_1002 = getitem_1003 = getitem_1004 = getitem_1005 = getitem_1006 = None + convert_element_type_749 = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16) + all_gather_into_tensor_251 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_749, 8, '0'); convert_element_type_749 = None + wait_tensor_297 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_251); all_gather_into_tensor_251 = None + permute_250 = torch.ops.aten.permute.default(wait_tensor_297, [1, 0]); wait_tensor_297 = None + view_1644 = torch.ops.aten.view.default(cat_91, [16384, 4096]); cat_91 = None + mm_158 = torch.ops.aten.mm.default(view_1644, permute_250); permute_250 = None + view_1645 = torch.ops.aten.view.default(mm_158, [2, 8192, 1792]) + convert_element_type_752 = torch.ops.prims.convert_element_type.default(view_1645, torch.float32); view_1645 = None + sigmoid_22 = torch.ops.aten.sigmoid.default(convert_element_type_752) + mul_182 = torch.ops.aten.mul.Tensor(convert_element_type_752, sigmoid_22); convert_element_type_752 = sigmoid_22 = None + convert_element_type_753 = torch.ops.prims.convert_element_type.default(mul_182, torch.bfloat16); mul_182 = None + convert_element_type_754 = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16) + all_gather_into_tensor_252 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_754, 8, '0'); convert_element_type_754 = None + wait_tensor_298 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_252); all_gather_into_tensor_252 = None + permute_251 = torch.ops.aten.permute.default(wait_tensor_298, [1, 0]); wait_tensor_298 = None + mm_159 = torch.ops.aten.mm.default(view_1644, permute_251); view_1644 = permute_251 = None + view_1652 = torch.ops.aten.view.default(mm_159, [2, 8192, 1792]); mm_159 = None + mul_183 = torch.ops.aten.mul.Tensor(convert_element_type_753, view_1652); convert_element_type_753 = view_1652 = None + convert_element_type_757 = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16) + all_gather_into_tensor_253 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_757, 8, '0'); convert_element_type_757 = None + wait_tensor_299 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_253); all_gather_into_tensor_253 = None + permute_252 = torch.ops.aten.permute.default(wait_tensor_299, [1, 0]); wait_tensor_299 = None + view_1659 = torch.ops.aten.view.default(mul_183, [16384, 1792]); mul_183 = None + mm_160 = torch.ops.aten.mm.default(view_1659, permute_252); view_1659 = permute_252 = None + view_1660 = torch.ops.aten.view.default(mm_160, [2, 8192, 4096]); mm_160 = None + split_100 = torch.ops.aten.split.Tensor(view_1660, 1024, 1); view_1660 = None + getitem_1007 = split_100[0] + getitem_1008 = split_100[1] + getitem_1009 = split_100[2] + getitem_1010 = split_100[3] + getitem_1011 = split_100[4] + getitem_1012 = split_100[5] + getitem_1013 = split_100[6] + getitem_1014 = split_100[7]; split_100 = None + cat_92 = torch.ops.aten.cat.default([getitem_1007, getitem_1008, getitem_1009, getitem_1010, getitem_1011, getitem_1012, getitem_1013, getitem_1014]); getitem_1007 = getitem_1008 = getitem_1009 = getitem_1010 = getitem_1011 = getitem_1012 = getitem_1013 = getitem_1014 = None + reduce_scatter_tensor_46 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_92, 'sum', 8, '1'); cat_92 = None + wait_tensor_300 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_46); reduce_scatter_tensor_46 = None + add_91 = torch.ops.aten.add.Tensor(add_89, wait_tensor_300); add_89 = wait_tensor_300 = None + convert_element_type_760 = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16) + all_gather_into_tensor_254 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_760, 8, '0'); convert_element_type_760 = None + wait_tensor_301 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_254); all_gather_into_tensor_254 = None + convert_element_type_761 = torch.ops.prims.convert_element_type.default(add_91, torch.float32) + pow_47 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_761, 2) + mean_46 = torch.ops.aten.mean.dim(pow_47, [2], True); pow_47 = None + add_92 = torch.ops.aten.add.Scalar(mean_46, 1e-05); mean_46 = None + rsqrt_46 = torch.ops.aten.rsqrt.default(add_92); add_92 = None + mul_184 = torch.ops.aten.mul.Tensor(convert_element_type_761, rsqrt_46); convert_element_type_761 = rsqrt_46 = None + mul_185 = torch.ops.aten.mul.Tensor(mul_184, wait_tensor_301); mul_184 = wait_tensor_301 = None + convert_element_type_762 = torch.ops.prims.convert_element_type.default(mul_185, torch.bfloat16); mul_185 = None + all_gather_into_tensor_255 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_762, 8, '1'); convert_element_type_762 = None + wait_tensor_302 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_255); all_gather_into_tensor_255 = None + split_101 = torch.ops.aten.split.Tensor(wait_tensor_302, 2); wait_tensor_302 = None + getitem_1015 = split_101[0] + getitem_1016 = split_101[1] + getitem_1017 = split_101[2] + getitem_1018 = split_101[3] + getitem_1019 = split_101[4] + getitem_1020 = split_101[5] + getitem_1021 = split_101[6] + getitem_1022 = split_101[7]; split_101 = None + cat_93 = torch.ops.aten.cat.default([getitem_1015, getitem_1016, getitem_1017, getitem_1018, getitem_1019, getitem_1020, getitem_1021, getitem_1022], 1); getitem_1015 = getitem_1016 = getitem_1017 = getitem_1018 = getitem_1019 = getitem_1020 = getitem_1021 = getitem_1022 = None + convert_element_type_763 = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16) + all_gather_into_tensor_256 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_763, 8, '0'); convert_element_type_763 = None + wait_tensor_303 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_256); all_gather_into_tensor_256 = None + permute_253 = torch.ops.aten.permute.default(wait_tensor_303, [1, 0]); wait_tensor_303 = None + view_1671 = torch.ops.aten.view.default(cat_93, [16384, 4096]); cat_93 = None + mm_161 = torch.ops.aten.mm.default(view_1671, permute_253); permute_253 = None + view_1672 = torch.ops.aten.view.default(mm_161, [2, 8192, 512]) + convert_element_type_766 = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16) + all_gather_into_tensor_257 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_766, 8, '0'); convert_element_type_766 = None + wait_tensor_304 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_257); all_gather_into_tensor_257 = None + permute_254 = torch.ops.aten.permute.default(wait_tensor_304, [1, 0]); wait_tensor_304 = None + mm_162 = torch.ops.aten.mm.default(view_1671, permute_254); permute_254 = None + view_1679 = torch.ops.aten.view.default(mm_162, [2, 8192, 128]); mm_162 = None + convert_element_type_769 = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16) + all_gather_into_tensor_258 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_769, 8, '0'); convert_element_type_769 = None + wait_tensor_305 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_258); all_gather_into_tensor_258 = None + permute_255 = torch.ops.aten.permute.default(wait_tensor_305, [1, 0]); wait_tensor_305 = None + mm_163 = torch.ops.aten.mm.default(view_1671, permute_255); view_1671 = permute_255 = None + view_1686 = torch.ops.aten.view.default(mm_163, [2, 8192, 128]) + view_1688 = torch.ops.aten.view.default(view_1672, [2, 8192, -1, 128]); view_1672 = None + view_1689 = torch.ops.aten.view.default(view_1679, [2, 8192, -1, 128]); view_1679 = None + view_1690 = torch.ops.aten.view.default(view_1686, [2, 8192, -1, 128]); view_1686 = None + convert_element_type_772 = torch.ops.prims.convert_element_type.default(view_1688, torch.float32); view_1688 = None + view_1691 = torch.ops.aten.view.default(convert_element_type_772, [2, 8192, 4, -1, 2]); convert_element_type_772 = None + view_as_complex_46 = torch.ops.aten.view_as_complex.default(view_1691); view_1691 = None + convert_element_type_773 = torch.ops.prims.convert_element_type.default(view_1689, torch.float32); view_1689 = None + view_1692 = torch.ops.aten.view.default(convert_element_type_773, [2, 8192, 1, -1, 2]); convert_element_type_773 = None + view_as_complex_47 = torch.ops.aten.view_as_complex.default(view_1692); view_1692 = None + mul_186 = torch.ops.aten.mul.Tensor(view_as_complex_46, view_37); view_as_complex_46 = None + view_as_real_46 = torch.ops.aten.view_as_real.default(mul_186); mul_186 = None + view_1694 = torch.ops.aten.view.default(view_as_real_46, [2, 8192, 4, 128]); view_as_real_46 = None + mul_187 = torch.ops.aten.mul.Tensor(view_as_complex_47, view_37); view_as_complex_47 = None + view_as_real_47 = torch.ops.aten.view_as_real.default(mul_187); mul_187 = None + view_1695 = torch.ops.aten.view.default(view_as_real_47, [2, 8192, 1, 128]); view_as_real_47 = None + convert_element_type_774 = torch.ops.prims.convert_element_type.default(view_1694, torch.bfloat16); view_1694 = None + convert_element_type_775 = torch.ops.prims.convert_element_type.default(view_1695, torch.bfloat16); view_1695 = None + unsqueeze_46 = torch.ops.aten.unsqueeze.default(convert_element_type_775, 3); convert_element_type_775 = None + expand_46 = torch.ops.aten.expand.default(unsqueeze_46, [2, 8192, 1, 4, 128]); unsqueeze_46 = None + view_1696 = torch.ops.aten.view.default(expand_46, [2, 8192, 4, 128]); expand_46 = None + unsqueeze_47 = torch.ops.aten.unsqueeze.default(view_1690, 3); view_1690 = None + expand_47 = torch.ops.aten.expand.default(unsqueeze_47, [2, 8192, 1, 4, 128]); unsqueeze_47 = None + view_1697 = torch.ops.aten.view.default(expand_47, [2, 8192, 4, 128]); expand_47 = None + permute_256 = torch.ops.aten.permute.default(convert_element_type_774, [0, 2, 1, 3]); convert_element_type_774 = None + permute_257 = torch.ops.aten.permute.default(view_1696, [0, 2, 1, 3]); view_1696 = None + permute_258 = torch.ops.aten.permute.default(view_1697, [0, 2, 1, 3]); view_1697 = None + _scaled_dot_product_cudnn_attention_23 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_256, permute_257, permute_258, None, True, 0.0, True); permute_256 = permute_257 = permute_258 = None + getitem_1023 = _scaled_dot_product_cudnn_attention_23[0] + getitem_1024 = _scaled_dot_product_cudnn_attention_23[1] + getitem_1029 = _scaled_dot_product_cudnn_attention_23[6] + getitem_1030 = _scaled_dot_product_cudnn_attention_23[7]; _scaled_dot_product_cudnn_attention_23 = None + permute_259 = torch.ops.aten.permute.default(getitem_1023, [0, 2, 1, 3]) + view_1698 = torch.ops.aten.view.default(permute_259, [2, 8192, -1]); permute_259 = None + convert_element_type_776 = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16) + all_gather_into_tensor_259 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_776, 8, '0'); convert_element_type_776 = None + wait_tensor_306 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_259); all_gather_into_tensor_259 = None + permute_260 = torch.ops.aten.permute.default(wait_tensor_306, [1, 0]); wait_tensor_306 = None + view_1704 = torch.ops.aten.view.default(view_1698, [16384, 512]); view_1698 = None + mm_164 = torch.ops.aten.mm.default(view_1704, permute_260); view_1704 = permute_260 = None + view_1705 = torch.ops.aten.view.default(mm_164, [2, 8192, 4096]); mm_164 = None + split_102 = torch.ops.aten.split.Tensor(view_1705, 1024, 1); view_1705 = None + getitem_1032 = split_102[0] + getitem_1033 = split_102[1] + getitem_1034 = split_102[2] + getitem_1035 = split_102[3] + getitem_1036 = split_102[4] + getitem_1037 = split_102[5] + getitem_1038 = split_102[6] + getitem_1039 = split_102[7]; split_102 = None + cat_94 = torch.ops.aten.cat.default([getitem_1032, getitem_1033, getitem_1034, getitem_1035, getitem_1036, getitem_1037, getitem_1038, getitem_1039]); getitem_1032 = getitem_1033 = getitem_1034 = getitem_1035 = getitem_1036 = getitem_1037 = getitem_1038 = getitem_1039 = None + reduce_scatter_tensor_47 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_94, 'sum', 8, '1'); cat_94 = None + wait_tensor_307 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_47) + add_93 = torch.ops.aten.add.Tensor(add_91, wait_tensor_307); wait_tensor_307 = None + convert_element_type_779 = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16) + all_gather_into_tensor_260 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_779, 8, '0'); convert_element_type_779 = None + wait_tensor_308 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_260); all_gather_into_tensor_260 = None + convert_element_type_780 = torch.ops.prims.convert_element_type.default(add_93, torch.float32) + pow_48 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_780, 2) + mean_47 = torch.ops.aten.mean.dim(pow_48, [2], True); pow_48 = None + add_94 = torch.ops.aten.add.Scalar(mean_47, 1e-05); mean_47 = None + rsqrt_47 = torch.ops.aten.rsqrt.default(add_94); add_94 = None + mul_188 = torch.ops.aten.mul.Tensor(convert_element_type_780, rsqrt_47); convert_element_type_780 = rsqrt_47 = None + mul_189 = torch.ops.aten.mul.Tensor(mul_188, wait_tensor_308); mul_188 = wait_tensor_308 = None + convert_element_type_781 = torch.ops.prims.convert_element_type.default(mul_189, torch.bfloat16); mul_189 = None + all_gather_into_tensor_261 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_781, 8, '1'); convert_element_type_781 = None + wait_tensor_309 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_261); all_gather_into_tensor_261 = None + split_103 = torch.ops.aten.split.Tensor(wait_tensor_309, 2); wait_tensor_309 = None + getitem_1040 = split_103[0] + getitem_1041 = split_103[1] + getitem_1042 = split_103[2] + getitem_1043 = split_103[3] + getitem_1044 = split_103[4] + getitem_1045 = split_103[5] + getitem_1046 = split_103[6] + getitem_1047 = split_103[7]; split_103 = None + cat_95 = torch.ops.aten.cat.default([getitem_1040, getitem_1041, getitem_1042, getitem_1043, getitem_1044, getitem_1045, getitem_1046, getitem_1047], 1); getitem_1040 = getitem_1041 = getitem_1042 = getitem_1043 = getitem_1044 = getitem_1045 = getitem_1046 = getitem_1047 = None + convert_element_type_782 = torch.ops.prims.convert_element_type.default(primals_217, torch.bfloat16) + all_gather_into_tensor_262 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_782, 8, '0'); convert_element_type_782 = None + wait_tensor_310 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_262); all_gather_into_tensor_262 = None + permute_261 = torch.ops.aten.permute.default(wait_tensor_310, [1, 0]); wait_tensor_310 = None + view_1716 = torch.ops.aten.view.default(cat_95, [16384, 4096]); cat_95 = None + mm_165 = torch.ops.aten.mm.default(view_1716, permute_261); permute_261 = None + view_1717 = torch.ops.aten.view.default(mm_165, [2, 8192, 1792]) + convert_element_type_785 = torch.ops.prims.convert_element_type.default(view_1717, torch.float32); view_1717 = None + sigmoid_23 = torch.ops.aten.sigmoid.default(convert_element_type_785) + mul_190 = torch.ops.aten.mul.Tensor(convert_element_type_785, sigmoid_23); convert_element_type_785 = sigmoid_23 = None + convert_element_type_786 = torch.ops.prims.convert_element_type.default(mul_190, torch.bfloat16); mul_190 = None + convert_element_type_787 = torch.ops.prims.convert_element_type.default(primals_218, torch.bfloat16) + all_gather_into_tensor_263 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_787, 8, '0'); convert_element_type_787 = None + wait_tensor_311 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_263); all_gather_into_tensor_263 = None + permute_262 = torch.ops.aten.permute.default(wait_tensor_311, [1, 0]); wait_tensor_311 = None + mm_166 = torch.ops.aten.mm.default(view_1716, permute_262); view_1716 = permute_262 = None + view_1724 = torch.ops.aten.view.default(mm_166, [2, 8192, 1792]); mm_166 = None + mul_191 = torch.ops.aten.mul.Tensor(convert_element_type_786, view_1724); convert_element_type_786 = view_1724 = None + convert_element_type_790 = torch.ops.prims.convert_element_type.default(primals_219, torch.bfloat16) + all_gather_into_tensor_264 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_790, 8, '0'); convert_element_type_790 = None + wait_tensor_312 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_264); all_gather_into_tensor_264 = None + permute_263 = torch.ops.aten.permute.default(wait_tensor_312, [1, 0]); wait_tensor_312 = None + view_1731 = torch.ops.aten.view.default(mul_191, [16384, 1792]); mul_191 = None + mm_167 = torch.ops.aten.mm.default(view_1731, permute_263); view_1731 = permute_263 = None + view_1732 = torch.ops.aten.view.default(mm_167, [2, 8192, 4096]); mm_167 = None + split_104 = torch.ops.aten.split.Tensor(view_1732, 1024, 1); view_1732 = None + getitem_1048 = split_104[0] + getitem_1049 = split_104[1] + getitem_1050 = split_104[2] + getitem_1051 = split_104[3] + getitem_1052 = split_104[4] + getitem_1053 = split_104[5] + getitem_1054 = split_104[6] + getitem_1055 = split_104[7]; split_104 = None + cat_96 = torch.ops.aten.cat.default([getitem_1048, getitem_1049, getitem_1050, getitem_1051, getitem_1052, getitem_1053, getitem_1054, getitem_1055]); getitem_1048 = getitem_1049 = getitem_1050 = getitem_1051 = getitem_1052 = getitem_1053 = getitem_1054 = getitem_1055 = None + reduce_scatter_tensor_48 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_96, 'sum', 8, '1'); cat_96 = None + wait_tensor_313 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_48); reduce_scatter_tensor_48 = None + add_95 = torch.ops.aten.add.Tensor(add_93, wait_tensor_313); add_93 = wait_tensor_313 = None + convert_element_type_793 = torch.ops.prims.convert_element_type.default(primals_220, torch.bfloat16) + all_gather_into_tensor_265 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_793, 8, '0'); convert_element_type_793 = None + wait_tensor_314 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_265); all_gather_into_tensor_265 = None + convert_element_type_794 = torch.ops.prims.convert_element_type.default(add_95, torch.float32) + pow_49 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_794, 2) + mean_48 = torch.ops.aten.mean.dim(pow_49, [2], True); pow_49 = None + add_96 = torch.ops.aten.add.Scalar(mean_48, 1e-05); mean_48 = None + rsqrt_48 = torch.ops.aten.rsqrt.default(add_96); add_96 = None + mul_192 = torch.ops.aten.mul.Tensor(convert_element_type_794, rsqrt_48); convert_element_type_794 = rsqrt_48 = None + mul_193 = torch.ops.aten.mul.Tensor(mul_192, wait_tensor_314); mul_192 = wait_tensor_314 = None + convert_element_type_795 = torch.ops.prims.convert_element_type.default(mul_193, torch.bfloat16); mul_193 = None + all_gather_into_tensor_266 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_795, 8, '1'); convert_element_type_795 = None + wait_tensor_315 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_266); all_gather_into_tensor_266 = None + split_105 = torch.ops.aten.split.Tensor(wait_tensor_315, 2); wait_tensor_315 = None + getitem_1056 = split_105[0] + getitem_1057 = split_105[1] + getitem_1058 = split_105[2] + getitem_1059 = split_105[3] + getitem_1060 = split_105[4] + getitem_1061 = split_105[5] + getitem_1062 = split_105[6] + getitem_1063 = split_105[7]; split_105 = None + cat_97 = torch.ops.aten.cat.default([getitem_1056, getitem_1057, getitem_1058, getitem_1059, getitem_1060, getitem_1061, getitem_1062, getitem_1063], 1); getitem_1056 = getitem_1057 = getitem_1058 = getitem_1059 = getitem_1060 = getitem_1061 = getitem_1062 = getitem_1063 = None + convert_element_type_796 = torch.ops.prims.convert_element_type.default(primals_221, torch.bfloat16) + all_gather_into_tensor_267 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_796, 8, '0'); convert_element_type_796 = None + wait_tensor_316 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_267); all_gather_into_tensor_267 = None + permute_264 = torch.ops.aten.permute.default(wait_tensor_316, [1, 0]); wait_tensor_316 = None + view_1743 = torch.ops.aten.view.default(cat_97, [16384, 4096]); cat_97 = None + mm_168 = torch.ops.aten.mm.default(view_1743, permute_264); permute_264 = None + view_1744 = torch.ops.aten.view.default(mm_168, [2, 8192, 512]) + convert_element_type_799 = torch.ops.prims.convert_element_type.default(primals_222, torch.bfloat16) + all_gather_into_tensor_268 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_799, 8, '0'); convert_element_type_799 = None + wait_tensor_317 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_268); all_gather_into_tensor_268 = None + permute_265 = torch.ops.aten.permute.default(wait_tensor_317, [1, 0]); wait_tensor_317 = None + mm_169 = torch.ops.aten.mm.default(view_1743, permute_265); permute_265 = None + view_1751 = torch.ops.aten.view.default(mm_169, [2, 8192, 128]); mm_169 = None + convert_element_type_802 = torch.ops.prims.convert_element_type.default(primals_223, torch.bfloat16) + all_gather_into_tensor_269 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_802, 8, '0'); convert_element_type_802 = None + wait_tensor_318 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_269); all_gather_into_tensor_269 = None + permute_266 = torch.ops.aten.permute.default(wait_tensor_318, [1, 0]); wait_tensor_318 = None + mm_170 = torch.ops.aten.mm.default(view_1743, permute_266); view_1743 = permute_266 = None + view_1758 = torch.ops.aten.view.default(mm_170, [2, 8192, 128]) + view_1760 = torch.ops.aten.view.default(view_1744, [2, 8192, -1, 128]); view_1744 = None + view_1761 = torch.ops.aten.view.default(view_1751, [2, 8192, -1, 128]); view_1751 = None + view_1762 = torch.ops.aten.view.default(view_1758, [2, 8192, -1, 128]); view_1758 = None + convert_element_type_805 = torch.ops.prims.convert_element_type.default(view_1760, torch.float32); view_1760 = None + view_1763 = torch.ops.aten.view.default(convert_element_type_805, [2, 8192, 4, -1, 2]); convert_element_type_805 = None + view_as_complex_48 = torch.ops.aten.view_as_complex.default(view_1763); view_1763 = None + convert_element_type_806 = torch.ops.prims.convert_element_type.default(view_1761, torch.float32); view_1761 = None + view_1764 = torch.ops.aten.view.default(convert_element_type_806, [2, 8192, 1, -1, 2]); convert_element_type_806 = None + view_as_complex_49 = torch.ops.aten.view_as_complex.default(view_1764); view_1764 = None + mul_194 = torch.ops.aten.mul.Tensor(view_as_complex_48, view_37); view_as_complex_48 = None + view_as_real_48 = torch.ops.aten.view_as_real.default(mul_194); mul_194 = None + view_1766 = torch.ops.aten.view.default(view_as_real_48, [2, 8192, 4, 128]); view_as_real_48 = None + mul_195 = torch.ops.aten.mul.Tensor(view_as_complex_49, view_37); view_as_complex_49 = None + view_as_real_49 = torch.ops.aten.view_as_real.default(mul_195); mul_195 = None + view_1767 = torch.ops.aten.view.default(view_as_real_49, [2, 8192, 1, 128]); view_as_real_49 = None + convert_element_type_807 = torch.ops.prims.convert_element_type.default(view_1766, torch.bfloat16); view_1766 = None + convert_element_type_808 = torch.ops.prims.convert_element_type.default(view_1767, torch.bfloat16); view_1767 = None + unsqueeze_48 = torch.ops.aten.unsqueeze.default(convert_element_type_808, 3); convert_element_type_808 = None + expand_48 = torch.ops.aten.expand.default(unsqueeze_48, [2, 8192, 1, 4, 128]); unsqueeze_48 = None + view_1768 = torch.ops.aten.view.default(expand_48, [2, 8192, 4, 128]); expand_48 = None + unsqueeze_49 = torch.ops.aten.unsqueeze.default(view_1762, 3); view_1762 = None + expand_49 = torch.ops.aten.expand.default(unsqueeze_49, [2, 8192, 1, 4, 128]); unsqueeze_49 = None + view_1769 = torch.ops.aten.view.default(expand_49, [2, 8192, 4, 128]); expand_49 = None + permute_267 = torch.ops.aten.permute.default(convert_element_type_807, [0, 2, 1, 3]); convert_element_type_807 = None + permute_268 = torch.ops.aten.permute.default(view_1768, [0, 2, 1, 3]); view_1768 = None + permute_269 = torch.ops.aten.permute.default(view_1769, [0, 2, 1, 3]); view_1769 = None + _scaled_dot_product_cudnn_attention_24 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_267, permute_268, permute_269, None, True, 0.0, True); permute_267 = permute_268 = permute_269 = None + getitem_1064 = _scaled_dot_product_cudnn_attention_24[0] + getitem_1065 = _scaled_dot_product_cudnn_attention_24[1] + getitem_1070 = _scaled_dot_product_cudnn_attention_24[6] + getitem_1071 = _scaled_dot_product_cudnn_attention_24[7]; _scaled_dot_product_cudnn_attention_24 = None + permute_270 = torch.ops.aten.permute.default(getitem_1064, [0, 2, 1, 3]) + view_1770 = torch.ops.aten.view.default(permute_270, [2, 8192, -1]); permute_270 = None + convert_element_type_809 = torch.ops.prims.convert_element_type.default(primals_224, torch.bfloat16) + all_gather_into_tensor_270 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_809, 8, '0'); convert_element_type_809 = None + wait_tensor_319 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_270); all_gather_into_tensor_270 = None + permute_271 = torch.ops.aten.permute.default(wait_tensor_319, [1, 0]); wait_tensor_319 = None + view_1776 = torch.ops.aten.view.default(view_1770, [16384, 512]); view_1770 = None + mm_171 = torch.ops.aten.mm.default(view_1776, permute_271); view_1776 = permute_271 = None + view_1777 = torch.ops.aten.view.default(mm_171, [2, 8192, 4096]); mm_171 = None + split_106 = torch.ops.aten.split.Tensor(view_1777, 1024, 1); view_1777 = None + getitem_1073 = split_106[0] + getitem_1074 = split_106[1] + getitem_1075 = split_106[2] + getitem_1076 = split_106[3] + getitem_1077 = split_106[4] + getitem_1078 = split_106[5] + getitem_1079 = split_106[6] + getitem_1080 = split_106[7]; split_106 = None + cat_98 = torch.ops.aten.cat.default([getitem_1073, getitem_1074, getitem_1075, getitem_1076, getitem_1077, getitem_1078, getitem_1079, getitem_1080]); getitem_1073 = getitem_1074 = getitem_1075 = getitem_1076 = getitem_1077 = getitem_1078 = getitem_1079 = getitem_1080 = None + reduce_scatter_tensor_49 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_98, 'sum', 8, '1'); cat_98 = None + wait_tensor_320 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_49) + add_97 = torch.ops.aten.add.Tensor(add_95, wait_tensor_320); wait_tensor_320 = None + convert_element_type_812 = torch.ops.prims.convert_element_type.default(primals_225, torch.bfloat16) + all_gather_into_tensor_271 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_812, 8, '0'); convert_element_type_812 = None + wait_tensor_321 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_271); all_gather_into_tensor_271 = None + convert_element_type_813 = torch.ops.prims.convert_element_type.default(add_97, torch.float32) + pow_50 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_813, 2) + mean_49 = torch.ops.aten.mean.dim(pow_50, [2], True); pow_50 = None + add_98 = torch.ops.aten.add.Scalar(mean_49, 1e-05); mean_49 = None + rsqrt_49 = torch.ops.aten.rsqrt.default(add_98); add_98 = None + mul_196 = torch.ops.aten.mul.Tensor(convert_element_type_813, rsqrt_49); convert_element_type_813 = rsqrt_49 = None + mul_197 = torch.ops.aten.mul.Tensor(mul_196, wait_tensor_321); mul_196 = wait_tensor_321 = None + convert_element_type_814 = torch.ops.prims.convert_element_type.default(mul_197, torch.bfloat16); mul_197 = None + all_gather_into_tensor_272 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_814, 8, '1'); convert_element_type_814 = None + wait_tensor_322 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_272); all_gather_into_tensor_272 = None + split_107 = torch.ops.aten.split.Tensor(wait_tensor_322, 2); wait_tensor_322 = None + getitem_1081 = split_107[0] + getitem_1082 = split_107[1] + getitem_1083 = split_107[2] + getitem_1084 = split_107[3] + getitem_1085 = split_107[4] + getitem_1086 = split_107[5] + getitem_1087 = split_107[6] + getitem_1088 = split_107[7]; split_107 = None + cat_99 = torch.ops.aten.cat.default([getitem_1081, getitem_1082, getitem_1083, getitem_1084, getitem_1085, getitem_1086, getitem_1087, getitem_1088], 1); getitem_1081 = getitem_1082 = getitem_1083 = getitem_1084 = getitem_1085 = getitem_1086 = getitem_1087 = getitem_1088 = None + convert_element_type_815 = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16) + all_gather_into_tensor_273 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_815, 8, '0'); convert_element_type_815 = None + wait_tensor_323 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_273); all_gather_into_tensor_273 = None + permute_272 = torch.ops.aten.permute.default(wait_tensor_323, [1, 0]); wait_tensor_323 = None + view_1788 = torch.ops.aten.view.default(cat_99, [16384, 4096]); cat_99 = None + mm_172 = torch.ops.aten.mm.default(view_1788, permute_272); permute_272 = None + view_1789 = torch.ops.aten.view.default(mm_172, [2, 8192, 1792]) + convert_element_type_818 = torch.ops.prims.convert_element_type.default(view_1789, torch.float32); view_1789 = None + sigmoid_24 = torch.ops.aten.sigmoid.default(convert_element_type_818) + mul_198 = torch.ops.aten.mul.Tensor(convert_element_type_818, sigmoid_24); convert_element_type_818 = sigmoid_24 = None + convert_element_type_819 = torch.ops.prims.convert_element_type.default(mul_198, torch.bfloat16); mul_198 = None + convert_element_type_820 = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16) + all_gather_into_tensor_274 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_820, 8, '0'); convert_element_type_820 = None + wait_tensor_324 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_274); all_gather_into_tensor_274 = None + permute_273 = torch.ops.aten.permute.default(wait_tensor_324, [1, 0]); wait_tensor_324 = None + mm_173 = torch.ops.aten.mm.default(view_1788, permute_273); view_1788 = permute_273 = None + view_1796 = torch.ops.aten.view.default(mm_173, [2, 8192, 1792]); mm_173 = None + mul_199 = torch.ops.aten.mul.Tensor(convert_element_type_819, view_1796); convert_element_type_819 = view_1796 = None + convert_element_type_823 = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16) + all_gather_into_tensor_275 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_823, 8, '0'); convert_element_type_823 = None + wait_tensor_325 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_275); all_gather_into_tensor_275 = None + permute_274 = torch.ops.aten.permute.default(wait_tensor_325, [1, 0]); wait_tensor_325 = None + view_1803 = torch.ops.aten.view.default(mul_199, [16384, 1792]); mul_199 = None + mm_174 = torch.ops.aten.mm.default(view_1803, permute_274); view_1803 = permute_274 = None + view_1804 = torch.ops.aten.view.default(mm_174, [2, 8192, 4096]); mm_174 = None + split_108 = torch.ops.aten.split.Tensor(view_1804, 1024, 1); view_1804 = None + getitem_1089 = split_108[0] + getitem_1090 = split_108[1] + getitem_1091 = split_108[2] + getitem_1092 = split_108[3] + getitem_1093 = split_108[4] + getitem_1094 = split_108[5] + getitem_1095 = split_108[6] + getitem_1096 = split_108[7]; split_108 = None + cat_100 = torch.ops.aten.cat.default([getitem_1089, getitem_1090, getitem_1091, getitem_1092, getitem_1093, getitem_1094, getitem_1095, getitem_1096]); getitem_1089 = getitem_1090 = getitem_1091 = getitem_1092 = getitem_1093 = getitem_1094 = getitem_1095 = getitem_1096 = None + reduce_scatter_tensor_50 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_100, 'sum', 8, '1'); cat_100 = None + wait_tensor_326 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_50); reduce_scatter_tensor_50 = None + add_99 = torch.ops.aten.add.Tensor(add_97, wait_tensor_326); add_97 = wait_tensor_326 = None + convert_element_type_826 = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16) + all_gather_into_tensor_276 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_826, 8, '0'); convert_element_type_826 = None + wait_tensor_327 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_276); all_gather_into_tensor_276 = None + convert_element_type_827 = torch.ops.prims.convert_element_type.default(add_99, torch.float32) + pow_51 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_827, 2) + mean_50 = torch.ops.aten.mean.dim(pow_51, [2], True); pow_51 = None + add_100 = torch.ops.aten.add.Scalar(mean_50, 1e-05); mean_50 = None + rsqrt_50 = torch.ops.aten.rsqrt.default(add_100); add_100 = None + mul_200 = torch.ops.aten.mul.Tensor(convert_element_type_827, rsqrt_50); convert_element_type_827 = rsqrt_50 = None + mul_201 = torch.ops.aten.mul.Tensor(mul_200, wait_tensor_327); mul_200 = wait_tensor_327 = None + convert_element_type_828 = torch.ops.prims.convert_element_type.default(mul_201, torch.bfloat16); mul_201 = None + all_gather_into_tensor_277 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_828, 8, '1'); convert_element_type_828 = None + wait_tensor_328 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_277); all_gather_into_tensor_277 = None + split_109 = torch.ops.aten.split.Tensor(wait_tensor_328, 2); wait_tensor_328 = None + getitem_1097 = split_109[0] + getitem_1098 = split_109[1] + getitem_1099 = split_109[2] + getitem_1100 = split_109[3] + getitem_1101 = split_109[4] + getitem_1102 = split_109[5] + getitem_1103 = split_109[6] + getitem_1104 = split_109[7]; split_109 = None + cat_101 = torch.ops.aten.cat.default([getitem_1097, getitem_1098, getitem_1099, getitem_1100, getitem_1101, getitem_1102, getitem_1103, getitem_1104], 1); getitem_1097 = getitem_1098 = getitem_1099 = getitem_1100 = getitem_1101 = getitem_1102 = getitem_1103 = getitem_1104 = None + convert_element_type_829 = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16) + all_gather_into_tensor_278 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_829, 8, '0'); convert_element_type_829 = None + wait_tensor_329 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_278); all_gather_into_tensor_278 = None + permute_275 = torch.ops.aten.permute.default(wait_tensor_329, [1, 0]); wait_tensor_329 = None + view_1815 = torch.ops.aten.view.default(cat_101, [16384, 4096]); cat_101 = None + mm_175 = torch.ops.aten.mm.default(view_1815, permute_275); permute_275 = None + view_1816 = torch.ops.aten.view.default(mm_175, [2, 8192, 512]) + convert_element_type_832 = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16) + all_gather_into_tensor_279 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_832, 8, '0'); convert_element_type_832 = None + wait_tensor_330 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_279); all_gather_into_tensor_279 = None + permute_276 = torch.ops.aten.permute.default(wait_tensor_330, [1, 0]); wait_tensor_330 = None + mm_176 = torch.ops.aten.mm.default(view_1815, permute_276); permute_276 = None + view_1823 = torch.ops.aten.view.default(mm_176, [2, 8192, 128]); mm_176 = None + convert_element_type_835 = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16) + all_gather_into_tensor_280 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_835, 8, '0'); convert_element_type_835 = None + wait_tensor_331 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_280); all_gather_into_tensor_280 = None + permute_277 = torch.ops.aten.permute.default(wait_tensor_331, [1, 0]); wait_tensor_331 = None + mm_177 = torch.ops.aten.mm.default(view_1815, permute_277); view_1815 = permute_277 = None + view_1830 = torch.ops.aten.view.default(mm_177, [2, 8192, 128]) + view_1832 = torch.ops.aten.view.default(view_1816, [2, 8192, -1, 128]); view_1816 = None + view_1833 = torch.ops.aten.view.default(view_1823, [2, 8192, -1, 128]); view_1823 = None + view_1834 = torch.ops.aten.view.default(view_1830, [2, 8192, -1, 128]); view_1830 = None + convert_element_type_838 = torch.ops.prims.convert_element_type.default(view_1832, torch.float32); view_1832 = None + view_1835 = torch.ops.aten.view.default(convert_element_type_838, [2, 8192, 4, -1, 2]); convert_element_type_838 = None + view_as_complex_50 = torch.ops.aten.view_as_complex.default(view_1835); view_1835 = None + convert_element_type_839 = torch.ops.prims.convert_element_type.default(view_1833, torch.float32); view_1833 = None + view_1836 = torch.ops.aten.view.default(convert_element_type_839, [2, 8192, 1, -1, 2]); convert_element_type_839 = None + view_as_complex_51 = torch.ops.aten.view_as_complex.default(view_1836); view_1836 = None + mul_202 = torch.ops.aten.mul.Tensor(view_as_complex_50, view_37); view_as_complex_50 = None + view_as_real_50 = torch.ops.aten.view_as_real.default(mul_202); mul_202 = None + view_1838 = torch.ops.aten.view.default(view_as_real_50, [2, 8192, 4, 128]); view_as_real_50 = None + mul_203 = torch.ops.aten.mul.Tensor(view_as_complex_51, view_37); view_as_complex_51 = None + view_as_real_51 = torch.ops.aten.view_as_real.default(mul_203); mul_203 = None + view_1839 = torch.ops.aten.view.default(view_as_real_51, [2, 8192, 1, 128]); view_as_real_51 = None + convert_element_type_840 = torch.ops.prims.convert_element_type.default(view_1838, torch.bfloat16); view_1838 = None + convert_element_type_841 = torch.ops.prims.convert_element_type.default(view_1839, torch.bfloat16); view_1839 = None + unsqueeze_50 = torch.ops.aten.unsqueeze.default(convert_element_type_841, 3); convert_element_type_841 = None + expand_50 = torch.ops.aten.expand.default(unsqueeze_50, [2, 8192, 1, 4, 128]); unsqueeze_50 = None + view_1840 = torch.ops.aten.view.default(expand_50, [2, 8192, 4, 128]); expand_50 = None + unsqueeze_51 = torch.ops.aten.unsqueeze.default(view_1834, 3); view_1834 = None + expand_51 = torch.ops.aten.expand.default(unsqueeze_51, [2, 8192, 1, 4, 128]); unsqueeze_51 = None + view_1841 = torch.ops.aten.view.default(expand_51, [2, 8192, 4, 128]); expand_51 = None + permute_278 = torch.ops.aten.permute.default(convert_element_type_840, [0, 2, 1, 3]); convert_element_type_840 = None + permute_279 = torch.ops.aten.permute.default(view_1840, [0, 2, 1, 3]); view_1840 = None + permute_280 = torch.ops.aten.permute.default(view_1841, [0, 2, 1, 3]); view_1841 = None + _scaled_dot_product_cudnn_attention_25 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_278, permute_279, permute_280, None, True, 0.0, True); permute_278 = permute_279 = permute_280 = None + getitem_1105 = _scaled_dot_product_cudnn_attention_25[0] + getitem_1106 = _scaled_dot_product_cudnn_attention_25[1] + getitem_1111 = _scaled_dot_product_cudnn_attention_25[6] + getitem_1112 = _scaled_dot_product_cudnn_attention_25[7]; _scaled_dot_product_cudnn_attention_25 = None + permute_281 = torch.ops.aten.permute.default(getitem_1105, [0, 2, 1, 3]) + view_1842 = torch.ops.aten.view.default(permute_281, [2, 8192, -1]); permute_281 = None + convert_element_type_842 = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16) + all_gather_into_tensor_281 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_842, 8, '0'); convert_element_type_842 = None + wait_tensor_332 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_281); all_gather_into_tensor_281 = None + permute_282 = torch.ops.aten.permute.default(wait_tensor_332, [1, 0]); wait_tensor_332 = None + view_1848 = torch.ops.aten.view.default(view_1842, [16384, 512]); view_1842 = None + mm_178 = torch.ops.aten.mm.default(view_1848, permute_282); view_1848 = permute_282 = None + view_1849 = torch.ops.aten.view.default(mm_178, [2, 8192, 4096]); mm_178 = None + split_110 = torch.ops.aten.split.Tensor(view_1849, 1024, 1); view_1849 = None + getitem_1114 = split_110[0] + getitem_1115 = split_110[1] + getitem_1116 = split_110[2] + getitem_1117 = split_110[3] + getitem_1118 = split_110[4] + getitem_1119 = split_110[5] + getitem_1120 = split_110[6] + getitem_1121 = split_110[7]; split_110 = None + cat_102 = torch.ops.aten.cat.default([getitem_1114, getitem_1115, getitem_1116, getitem_1117, getitem_1118, getitem_1119, getitem_1120, getitem_1121]); getitem_1114 = getitem_1115 = getitem_1116 = getitem_1117 = getitem_1118 = getitem_1119 = getitem_1120 = getitem_1121 = None + reduce_scatter_tensor_51 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_102, 'sum', 8, '1'); cat_102 = None + wait_tensor_333 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_51) + add_101 = torch.ops.aten.add.Tensor(add_99, wait_tensor_333); wait_tensor_333 = None + convert_element_type_845 = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16) + all_gather_into_tensor_282 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_845, 8, '0'); convert_element_type_845 = None + wait_tensor_334 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_282); all_gather_into_tensor_282 = None + convert_element_type_846 = torch.ops.prims.convert_element_type.default(add_101, torch.float32) + pow_52 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_846, 2) + mean_51 = torch.ops.aten.mean.dim(pow_52, [2], True); pow_52 = None + add_102 = torch.ops.aten.add.Scalar(mean_51, 1e-05); mean_51 = None + rsqrt_51 = torch.ops.aten.rsqrt.default(add_102); add_102 = None + mul_204 = torch.ops.aten.mul.Tensor(convert_element_type_846, rsqrt_51); convert_element_type_846 = rsqrt_51 = None + mul_205 = torch.ops.aten.mul.Tensor(mul_204, wait_tensor_334); mul_204 = wait_tensor_334 = None + convert_element_type_847 = torch.ops.prims.convert_element_type.default(mul_205, torch.bfloat16); mul_205 = None + all_gather_into_tensor_283 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_847, 8, '1'); convert_element_type_847 = None + wait_tensor_335 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_283); all_gather_into_tensor_283 = None + split_111 = torch.ops.aten.split.Tensor(wait_tensor_335, 2); wait_tensor_335 = None + getitem_1122 = split_111[0] + getitem_1123 = split_111[1] + getitem_1124 = split_111[2] + getitem_1125 = split_111[3] + getitem_1126 = split_111[4] + getitem_1127 = split_111[5] + getitem_1128 = split_111[6] + getitem_1129 = split_111[7]; split_111 = None + cat_103 = torch.ops.aten.cat.default([getitem_1122, getitem_1123, getitem_1124, getitem_1125, getitem_1126, getitem_1127, getitem_1128, getitem_1129], 1); getitem_1122 = getitem_1123 = getitem_1124 = getitem_1125 = getitem_1126 = getitem_1127 = getitem_1128 = getitem_1129 = None + convert_element_type_848 = torch.ops.prims.convert_element_type.default(primals_235, torch.bfloat16) + all_gather_into_tensor_284 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_848, 8, '0'); convert_element_type_848 = None + wait_tensor_336 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_284); all_gather_into_tensor_284 = None + permute_283 = torch.ops.aten.permute.default(wait_tensor_336, [1, 0]); wait_tensor_336 = None + view_1860 = torch.ops.aten.view.default(cat_103, [16384, 4096]); cat_103 = None + mm_179 = torch.ops.aten.mm.default(view_1860, permute_283); permute_283 = None + view_1861 = torch.ops.aten.view.default(mm_179, [2, 8192, 1792]) + convert_element_type_851 = torch.ops.prims.convert_element_type.default(view_1861, torch.float32); view_1861 = None + sigmoid_25 = torch.ops.aten.sigmoid.default(convert_element_type_851) + mul_206 = torch.ops.aten.mul.Tensor(convert_element_type_851, sigmoid_25); convert_element_type_851 = sigmoid_25 = None + convert_element_type_852 = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None + convert_element_type_853 = torch.ops.prims.convert_element_type.default(primals_236, torch.bfloat16) + all_gather_into_tensor_285 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_853, 8, '0'); convert_element_type_853 = None + wait_tensor_337 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_285); all_gather_into_tensor_285 = None + permute_284 = torch.ops.aten.permute.default(wait_tensor_337, [1, 0]); wait_tensor_337 = None + mm_180 = torch.ops.aten.mm.default(view_1860, permute_284); view_1860 = permute_284 = None + view_1868 = torch.ops.aten.view.default(mm_180, [2, 8192, 1792]); mm_180 = None + mul_207 = torch.ops.aten.mul.Tensor(convert_element_type_852, view_1868); convert_element_type_852 = view_1868 = None + convert_element_type_856 = torch.ops.prims.convert_element_type.default(primals_237, torch.bfloat16) + all_gather_into_tensor_286 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_856, 8, '0'); convert_element_type_856 = None + wait_tensor_338 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_286); all_gather_into_tensor_286 = None + permute_285 = torch.ops.aten.permute.default(wait_tensor_338, [1, 0]); wait_tensor_338 = None + view_1875 = torch.ops.aten.view.default(mul_207, [16384, 1792]); mul_207 = None + mm_181 = torch.ops.aten.mm.default(view_1875, permute_285); view_1875 = permute_285 = None + view_1876 = torch.ops.aten.view.default(mm_181, [2, 8192, 4096]); mm_181 = None + split_112 = torch.ops.aten.split.Tensor(view_1876, 1024, 1); view_1876 = None + getitem_1130 = split_112[0] + getitem_1131 = split_112[1] + getitem_1132 = split_112[2] + getitem_1133 = split_112[3] + getitem_1134 = split_112[4] + getitem_1135 = split_112[5] + getitem_1136 = split_112[6] + getitem_1137 = split_112[7]; split_112 = None + cat_104 = torch.ops.aten.cat.default([getitem_1130, getitem_1131, getitem_1132, getitem_1133, getitem_1134, getitem_1135, getitem_1136, getitem_1137]); getitem_1130 = getitem_1131 = getitem_1132 = getitem_1133 = getitem_1134 = getitem_1135 = getitem_1136 = getitem_1137 = None + reduce_scatter_tensor_52 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_104, 'sum', 8, '1'); cat_104 = None + wait_tensor_339 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_52); reduce_scatter_tensor_52 = None + add_103 = torch.ops.aten.add.Tensor(add_101, wait_tensor_339); add_101 = wait_tensor_339 = None + convert_element_type_859 = torch.ops.prims.convert_element_type.default(primals_238, torch.bfloat16) + all_gather_into_tensor_287 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_859, 8, '0'); convert_element_type_859 = None + wait_tensor_340 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_287); all_gather_into_tensor_287 = None + convert_element_type_860 = torch.ops.prims.convert_element_type.default(add_103, torch.float32) + pow_53 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_860, 2) + mean_52 = torch.ops.aten.mean.dim(pow_53, [2], True); pow_53 = None + add_104 = torch.ops.aten.add.Scalar(mean_52, 1e-05); mean_52 = None + rsqrt_52 = torch.ops.aten.rsqrt.default(add_104); add_104 = None + mul_208 = torch.ops.aten.mul.Tensor(convert_element_type_860, rsqrt_52); convert_element_type_860 = rsqrt_52 = None + mul_209 = torch.ops.aten.mul.Tensor(mul_208, wait_tensor_340); mul_208 = wait_tensor_340 = None + convert_element_type_861 = torch.ops.prims.convert_element_type.default(mul_209, torch.bfloat16); mul_209 = None + all_gather_into_tensor_288 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_861, 8, '1'); convert_element_type_861 = None + wait_tensor_341 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_288); all_gather_into_tensor_288 = None + split_113 = torch.ops.aten.split.Tensor(wait_tensor_341, 2); wait_tensor_341 = None + getitem_1138 = split_113[0] + getitem_1139 = split_113[1] + getitem_1140 = split_113[2] + getitem_1141 = split_113[3] + getitem_1142 = split_113[4] + getitem_1143 = split_113[5] + getitem_1144 = split_113[6] + getitem_1145 = split_113[7]; split_113 = None + cat_105 = torch.ops.aten.cat.default([getitem_1138, getitem_1139, getitem_1140, getitem_1141, getitem_1142, getitem_1143, getitem_1144, getitem_1145], 1); getitem_1138 = getitem_1139 = getitem_1140 = getitem_1141 = getitem_1142 = getitem_1143 = getitem_1144 = getitem_1145 = None + convert_element_type_862 = torch.ops.prims.convert_element_type.default(primals_239, torch.bfloat16) + all_gather_into_tensor_289 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_862, 8, '0'); convert_element_type_862 = None + wait_tensor_342 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_289); all_gather_into_tensor_289 = None + permute_286 = torch.ops.aten.permute.default(wait_tensor_342, [1, 0]); wait_tensor_342 = None + view_1887 = torch.ops.aten.view.default(cat_105, [16384, 4096]); cat_105 = None + mm_182 = torch.ops.aten.mm.default(view_1887, permute_286); permute_286 = None + view_1888 = torch.ops.aten.view.default(mm_182, [2, 8192, 512]) + convert_element_type_865 = torch.ops.prims.convert_element_type.default(primals_240, torch.bfloat16) + all_gather_into_tensor_290 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_865, 8, '0'); convert_element_type_865 = None + wait_tensor_343 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_290); all_gather_into_tensor_290 = None + permute_287 = torch.ops.aten.permute.default(wait_tensor_343, [1, 0]); wait_tensor_343 = None + mm_183 = torch.ops.aten.mm.default(view_1887, permute_287); permute_287 = None + view_1895 = torch.ops.aten.view.default(mm_183, [2, 8192, 128]); mm_183 = None + convert_element_type_868 = torch.ops.prims.convert_element_type.default(primals_241, torch.bfloat16) + all_gather_into_tensor_291 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_868, 8, '0'); convert_element_type_868 = None + wait_tensor_344 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_291); all_gather_into_tensor_291 = None + permute_288 = torch.ops.aten.permute.default(wait_tensor_344, [1, 0]); wait_tensor_344 = None + mm_184 = torch.ops.aten.mm.default(view_1887, permute_288); view_1887 = permute_288 = None + view_1902 = torch.ops.aten.view.default(mm_184, [2, 8192, 128]) + view_1904 = torch.ops.aten.view.default(view_1888, [2, 8192, -1, 128]); view_1888 = None + view_1905 = torch.ops.aten.view.default(view_1895, [2, 8192, -1, 128]); view_1895 = None + view_1906 = torch.ops.aten.view.default(view_1902, [2, 8192, -1, 128]); view_1902 = None + convert_element_type_871 = torch.ops.prims.convert_element_type.default(view_1904, torch.float32); view_1904 = None + view_1907 = torch.ops.aten.view.default(convert_element_type_871, [2, 8192, 4, -1, 2]); convert_element_type_871 = None + view_as_complex_52 = torch.ops.aten.view_as_complex.default(view_1907); view_1907 = None + convert_element_type_872 = torch.ops.prims.convert_element_type.default(view_1905, torch.float32); view_1905 = None + view_1908 = torch.ops.aten.view.default(convert_element_type_872, [2, 8192, 1, -1, 2]); convert_element_type_872 = None + view_as_complex_53 = torch.ops.aten.view_as_complex.default(view_1908); view_1908 = None + mul_210 = torch.ops.aten.mul.Tensor(view_as_complex_52, view_37); view_as_complex_52 = None + view_as_real_52 = torch.ops.aten.view_as_real.default(mul_210); mul_210 = None + view_1910 = torch.ops.aten.view.default(view_as_real_52, [2, 8192, 4, 128]); view_as_real_52 = None + mul_211 = torch.ops.aten.mul.Tensor(view_as_complex_53, view_37); view_as_complex_53 = None + view_as_real_53 = torch.ops.aten.view_as_real.default(mul_211); mul_211 = None + view_1911 = torch.ops.aten.view.default(view_as_real_53, [2, 8192, 1, 128]); view_as_real_53 = None + convert_element_type_873 = torch.ops.prims.convert_element_type.default(view_1910, torch.bfloat16); view_1910 = None + convert_element_type_874 = torch.ops.prims.convert_element_type.default(view_1911, torch.bfloat16); view_1911 = None + unsqueeze_52 = torch.ops.aten.unsqueeze.default(convert_element_type_874, 3); convert_element_type_874 = None + expand_52 = torch.ops.aten.expand.default(unsqueeze_52, [2, 8192, 1, 4, 128]); unsqueeze_52 = None + view_1912 = torch.ops.aten.view.default(expand_52, [2, 8192, 4, 128]); expand_52 = None + unsqueeze_53 = torch.ops.aten.unsqueeze.default(view_1906, 3); view_1906 = None + expand_53 = torch.ops.aten.expand.default(unsqueeze_53, [2, 8192, 1, 4, 128]); unsqueeze_53 = None + view_1913 = torch.ops.aten.view.default(expand_53, [2, 8192, 4, 128]); expand_53 = None + permute_289 = torch.ops.aten.permute.default(convert_element_type_873, [0, 2, 1, 3]); convert_element_type_873 = None + permute_290 = torch.ops.aten.permute.default(view_1912, [0, 2, 1, 3]); view_1912 = None + permute_291 = torch.ops.aten.permute.default(view_1913, [0, 2, 1, 3]); view_1913 = None + _scaled_dot_product_cudnn_attention_26 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_289, permute_290, permute_291, None, True, 0.0, True); permute_289 = permute_290 = permute_291 = None + getitem_1146 = _scaled_dot_product_cudnn_attention_26[0] + getitem_1147 = _scaled_dot_product_cudnn_attention_26[1] + getitem_1152 = _scaled_dot_product_cudnn_attention_26[6] + getitem_1153 = _scaled_dot_product_cudnn_attention_26[7]; _scaled_dot_product_cudnn_attention_26 = None + permute_292 = torch.ops.aten.permute.default(getitem_1146, [0, 2, 1, 3]) + view_1914 = torch.ops.aten.view.default(permute_292, [2, 8192, -1]); permute_292 = None + convert_element_type_875 = torch.ops.prims.convert_element_type.default(primals_242, torch.bfloat16) + all_gather_into_tensor_292 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_875, 8, '0'); convert_element_type_875 = None + wait_tensor_345 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_292); all_gather_into_tensor_292 = None + permute_293 = torch.ops.aten.permute.default(wait_tensor_345, [1, 0]); wait_tensor_345 = None + view_1920 = torch.ops.aten.view.default(view_1914, [16384, 512]); view_1914 = None + mm_185 = torch.ops.aten.mm.default(view_1920, permute_293); view_1920 = permute_293 = None + view_1921 = torch.ops.aten.view.default(mm_185, [2, 8192, 4096]); mm_185 = None + split_114 = torch.ops.aten.split.Tensor(view_1921, 1024, 1); view_1921 = None + getitem_1155 = split_114[0] + getitem_1156 = split_114[1] + getitem_1157 = split_114[2] + getitem_1158 = split_114[3] + getitem_1159 = split_114[4] + getitem_1160 = split_114[5] + getitem_1161 = split_114[6] + getitem_1162 = split_114[7]; split_114 = None + cat_106 = torch.ops.aten.cat.default([getitem_1155, getitem_1156, getitem_1157, getitem_1158, getitem_1159, getitem_1160, getitem_1161, getitem_1162]); getitem_1155 = getitem_1156 = getitem_1157 = getitem_1158 = getitem_1159 = getitem_1160 = getitem_1161 = getitem_1162 = None + reduce_scatter_tensor_53 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_106, 'sum', 8, '1'); cat_106 = None + wait_tensor_346 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_53) + add_105 = torch.ops.aten.add.Tensor(add_103, wait_tensor_346); wait_tensor_346 = None + convert_element_type_878 = torch.ops.prims.convert_element_type.default(primals_243, torch.bfloat16) + all_gather_into_tensor_293 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_878, 8, '0'); convert_element_type_878 = None + wait_tensor_347 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_293); all_gather_into_tensor_293 = None + convert_element_type_879 = torch.ops.prims.convert_element_type.default(add_105, torch.float32) + pow_54 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_879, 2) + mean_53 = torch.ops.aten.mean.dim(pow_54, [2], True); pow_54 = None + add_106 = torch.ops.aten.add.Scalar(mean_53, 1e-05); mean_53 = None + rsqrt_53 = torch.ops.aten.rsqrt.default(add_106); add_106 = None + mul_212 = torch.ops.aten.mul.Tensor(convert_element_type_879, rsqrt_53); convert_element_type_879 = rsqrt_53 = None + mul_213 = torch.ops.aten.mul.Tensor(mul_212, wait_tensor_347); mul_212 = wait_tensor_347 = None + convert_element_type_880 = torch.ops.prims.convert_element_type.default(mul_213, torch.bfloat16); mul_213 = None + all_gather_into_tensor_294 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_880, 8, '1'); convert_element_type_880 = None + wait_tensor_348 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_294); all_gather_into_tensor_294 = None + split_115 = torch.ops.aten.split.Tensor(wait_tensor_348, 2); wait_tensor_348 = None + getitem_1163 = split_115[0] + getitem_1164 = split_115[1] + getitem_1165 = split_115[2] + getitem_1166 = split_115[3] + getitem_1167 = split_115[4] + getitem_1168 = split_115[5] + getitem_1169 = split_115[6] + getitem_1170 = split_115[7]; split_115 = None + cat_107 = torch.ops.aten.cat.default([getitem_1163, getitem_1164, getitem_1165, getitem_1166, getitem_1167, getitem_1168, getitem_1169, getitem_1170], 1); getitem_1163 = getitem_1164 = getitem_1165 = getitem_1166 = getitem_1167 = getitem_1168 = getitem_1169 = getitem_1170 = None + convert_element_type_881 = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16) + all_gather_into_tensor_295 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_881, 8, '0'); convert_element_type_881 = None + wait_tensor_349 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_295); all_gather_into_tensor_295 = None + permute_294 = torch.ops.aten.permute.default(wait_tensor_349, [1, 0]); wait_tensor_349 = None + view_1932 = torch.ops.aten.view.default(cat_107, [16384, 4096]); cat_107 = None + mm_186 = torch.ops.aten.mm.default(view_1932, permute_294); permute_294 = None + view_1933 = torch.ops.aten.view.default(mm_186, [2, 8192, 1792]) + convert_element_type_884 = torch.ops.prims.convert_element_type.default(view_1933, torch.float32); view_1933 = None + sigmoid_26 = torch.ops.aten.sigmoid.default(convert_element_type_884) + mul_214 = torch.ops.aten.mul.Tensor(convert_element_type_884, sigmoid_26); convert_element_type_884 = sigmoid_26 = None + convert_element_type_885 = torch.ops.prims.convert_element_type.default(mul_214, torch.bfloat16); mul_214 = None + convert_element_type_886 = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16) + all_gather_into_tensor_296 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_886, 8, '0'); convert_element_type_886 = None + wait_tensor_350 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_296); all_gather_into_tensor_296 = None + permute_295 = torch.ops.aten.permute.default(wait_tensor_350, [1, 0]); wait_tensor_350 = None + mm_187 = torch.ops.aten.mm.default(view_1932, permute_295); view_1932 = permute_295 = None + view_1940 = torch.ops.aten.view.default(mm_187, [2, 8192, 1792]); mm_187 = None + mul_215 = torch.ops.aten.mul.Tensor(convert_element_type_885, view_1940); convert_element_type_885 = view_1940 = None + convert_element_type_889 = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16) + all_gather_into_tensor_297 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_889, 8, '0'); convert_element_type_889 = None + wait_tensor_351 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_297); all_gather_into_tensor_297 = None + permute_296 = torch.ops.aten.permute.default(wait_tensor_351, [1, 0]); wait_tensor_351 = None + view_1947 = torch.ops.aten.view.default(mul_215, [16384, 1792]); mul_215 = None + mm_188 = torch.ops.aten.mm.default(view_1947, permute_296); view_1947 = permute_296 = None + view_1948 = torch.ops.aten.view.default(mm_188, [2, 8192, 4096]); mm_188 = None + split_116 = torch.ops.aten.split.Tensor(view_1948, 1024, 1); view_1948 = None + getitem_1171 = split_116[0] + getitem_1172 = split_116[1] + getitem_1173 = split_116[2] + getitem_1174 = split_116[3] + getitem_1175 = split_116[4] + getitem_1176 = split_116[5] + getitem_1177 = split_116[6] + getitem_1178 = split_116[7]; split_116 = None + cat_108 = torch.ops.aten.cat.default([getitem_1171, getitem_1172, getitem_1173, getitem_1174, getitem_1175, getitem_1176, getitem_1177, getitem_1178]); getitem_1171 = getitem_1172 = getitem_1173 = getitem_1174 = getitem_1175 = getitem_1176 = getitem_1177 = getitem_1178 = None + reduce_scatter_tensor_54 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_108, 'sum', 8, '1'); cat_108 = None + wait_tensor_352 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_54); reduce_scatter_tensor_54 = None + add_107 = torch.ops.aten.add.Tensor(add_105, wait_tensor_352); add_105 = wait_tensor_352 = None + convert_element_type_892 = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16) + all_gather_into_tensor_298 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_892, 8, '0'); convert_element_type_892 = None + wait_tensor_353 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_298); all_gather_into_tensor_298 = None + convert_element_type_893 = torch.ops.prims.convert_element_type.default(add_107, torch.float32) + pow_55 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_893, 2) + mean_54 = torch.ops.aten.mean.dim(pow_55, [2], True); pow_55 = None + add_108 = torch.ops.aten.add.Scalar(mean_54, 1e-05); mean_54 = None + rsqrt_54 = torch.ops.aten.rsqrt.default(add_108); add_108 = None + mul_216 = torch.ops.aten.mul.Tensor(convert_element_type_893, rsqrt_54); convert_element_type_893 = rsqrt_54 = None + mul_217 = torch.ops.aten.mul.Tensor(mul_216, wait_tensor_353); mul_216 = wait_tensor_353 = None + convert_element_type_894 = torch.ops.prims.convert_element_type.default(mul_217, torch.bfloat16); mul_217 = None + all_gather_into_tensor_299 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_894, 8, '1'); convert_element_type_894 = None + wait_tensor_354 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_299); all_gather_into_tensor_299 = None + split_117 = torch.ops.aten.split.Tensor(wait_tensor_354, 2); wait_tensor_354 = None + getitem_1179 = split_117[0] + getitem_1180 = split_117[1] + getitem_1181 = split_117[2] + getitem_1182 = split_117[3] + getitem_1183 = split_117[4] + getitem_1184 = split_117[5] + getitem_1185 = split_117[6] + getitem_1186 = split_117[7]; split_117 = None + cat_109 = torch.ops.aten.cat.default([getitem_1179, getitem_1180, getitem_1181, getitem_1182, getitem_1183, getitem_1184, getitem_1185, getitem_1186], 1); getitem_1179 = getitem_1180 = getitem_1181 = getitem_1182 = getitem_1183 = getitem_1184 = getitem_1185 = getitem_1186 = None + convert_element_type_895 = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16) + all_gather_into_tensor_300 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_895, 8, '0'); convert_element_type_895 = None + wait_tensor_355 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_300); all_gather_into_tensor_300 = None + permute_297 = torch.ops.aten.permute.default(wait_tensor_355, [1, 0]); wait_tensor_355 = None + view_1959 = torch.ops.aten.view.default(cat_109, [16384, 4096]); cat_109 = None + mm_189 = torch.ops.aten.mm.default(view_1959, permute_297); permute_297 = None + view_1960 = torch.ops.aten.view.default(mm_189, [2, 8192, 512]) + convert_element_type_898 = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16) + all_gather_into_tensor_301 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_898, 8, '0'); convert_element_type_898 = None + wait_tensor_356 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_301); all_gather_into_tensor_301 = None + permute_298 = torch.ops.aten.permute.default(wait_tensor_356, [1, 0]); wait_tensor_356 = None + mm_190 = torch.ops.aten.mm.default(view_1959, permute_298); permute_298 = None + view_1967 = torch.ops.aten.view.default(mm_190, [2, 8192, 128]); mm_190 = None + convert_element_type_901 = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16) + all_gather_into_tensor_302 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_901, 8, '0'); convert_element_type_901 = None + wait_tensor_357 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_302); all_gather_into_tensor_302 = None + permute_299 = torch.ops.aten.permute.default(wait_tensor_357, [1, 0]); wait_tensor_357 = None + mm_191 = torch.ops.aten.mm.default(view_1959, permute_299); view_1959 = permute_299 = None + view_1974 = torch.ops.aten.view.default(mm_191, [2, 8192, 128]) + view_1976 = torch.ops.aten.view.default(view_1960, [2, 8192, -1, 128]); view_1960 = None + view_1977 = torch.ops.aten.view.default(view_1967, [2, 8192, -1, 128]); view_1967 = None + view_1978 = torch.ops.aten.view.default(view_1974, [2, 8192, -1, 128]); view_1974 = None + convert_element_type_904 = torch.ops.prims.convert_element_type.default(view_1976, torch.float32); view_1976 = None + view_1979 = torch.ops.aten.view.default(convert_element_type_904, [2, 8192, 4, -1, 2]); convert_element_type_904 = None + view_as_complex_54 = torch.ops.aten.view_as_complex.default(view_1979); view_1979 = None + convert_element_type_905 = torch.ops.prims.convert_element_type.default(view_1977, torch.float32); view_1977 = None + view_1980 = torch.ops.aten.view.default(convert_element_type_905, [2, 8192, 1, -1, 2]); convert_element_type_905 = None + view_as_complex_55 = torch.ops.aten.view_as_complex.default(view_1980); view_1980 = None + mul_218 = torch.ops.aten.mul.Tensor(view_as_complex_54, view_37); view_as_complex_54 = None + view_as_real_54 = torch.ops.aten.view_as_real.default(mul_218); mul_218 = None + view_1982 = torch.ops.aten.view.default(view_as_real_54, [2, 8192, 4, 128]); view_as_real_54 = None + mul_219 = torch.ops.aten.mul.Tensor(view_as_complex_55, view_37); view_as_complex_55 = None + view_as_real_55 = torch.ops.aten.view_as_real.default(mul_219); mul_219 = None + view_1983 = torch.ops.aten.view.default(view_as_real_55, [2, 8192, 1, 128]); view_as_real_55 = None + convert_element_type_906 = torch.ops.prims.convert_element_type.default(view_1982, torch.bfloat16); view_1982 = None + convert_element_type_907 = torch.ops.prims.convert_element_type.default(view_1983, torch.bfloat16); view_1983 = None + unsqueeze_54 = torch.ops.aten.unsqueeze.default(convert_element_type_907, 3); convert_element_type_907 = None + expand_54 = torch.ops.aten.expand.default(unsqueeze_54, [2, 8192, 1, 4, 128]); unsqueeze_54 = None + view_1984 = torch.ops.aten.view.default(expand_54, [2, 8192, 4, 128]); expand_54 = None + unsqueeze_55 = torch.ops.aten.unsqueeze.default(view_1978, 3); view_1978 = None + expand_55 = torch.ops.aten.expand.default(unsqueeze_55, [2, 8192, 1, 4, 128]); unsqueeze_55 = None + view_1985 = torch.ops.aten.view.default(expand_55, [2, 8192, 4, 128]); expand_55 = None + permute_300 = torch.ops.aten.permute.default(convert_element_type_906, [0, 2, 1, 3]); convert_element_type_906 = None + permute_301 = torch.ops.aten.permute.default(view_1984, [0, 2, 1, 3]); view_1984 = None + permute_302 = torch.ops.aten.permute.default(view_1985, [0, 2, 1, 3]); view_1985 = None + _scaled_dot_product_cudnn_attention_27 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_300, permute_301, permute_302, None, True, 0.0, True); permute_300 = permute_301 = permute_302 = None + getitem_1187 = _scaled_dot_product_cudnn_attention_27[0] + getitem_1188 = _scaled_dot_product_cudnn_attention_27[1] + getitem_1193 = _scaled_dot_product_cudnn_attention_27[6] + getitem_1194 = _scaled_dot_product_cudnn_attention_27[7]; _scaled_dot_product_cudnn_attention_27 = None + permute_303 = torch.ops.aten.permute.default(getitem_1187, [0, 2, 1, 3]) + view_1986 = torch.ops.aten.view.default(permute_303, [2, 8192, -1]); permute_303 = None + convert_element_type_908 = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16) + all_gather_into_tensor_303 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_908, 8, '0'); convert_element_type_908 = None + wait_tensor_358 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_303); all_gather_into_tensor_303 = None + permute_304 = torch.ops.aten.permute.default(wait_tensor_358, [1, 0]); wait_tensor_358 = None + view_1992 = torch.ops.aten.view.default(view_1986, [16384, 512]); view_1986 = None + mm_192 = torch.ops.aten.mm.default(view_1992, permute_304); view_1992 = permute_304 = None + view_1993 = torch.ops.aten.view.default(mm_192, [2, 8192, 4096]); mm_192 = None + split_118 = torch.ops.aten.split.Tensor(view_1993, 1024, 1); view_1993 = None + getitem_1196 = split_118[0] + getitem_1197 = split_118[1] + getitem_1198 = split_118[2] + getitem_1199 = split_118[3] + getitem_1200 = split_118[4] + getitem_1201 = split_118[5] + getitem_1202 = split_118[6] + getitem_1203 = split_118[7]; split_118 = None + cat_110 = torch.ops.aten.cat.default([getitem_1196, getitem_1197, getitem_1198, getitem_1199, getitem_1200, getitem_1201, getitem_1202, getitem_1203]); getitem_1196 = getitem_1197 = getitem_1198 = getitem_1199 = getitem_1200 = getitem_1201 = getitem_1202 = getitem_1203 = None + reduce_scatter_tensor_55 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_110, 'sum', 8, '1'); cat_110 = None + wait_tensor_359 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_55) + add_109 = torch.ops.aten.add.Tensor(add_107, wait_tensor_359); wait_tensor_359 = None + convert_element_type_911 = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16) + all_gather_into_tensor_304 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_911, 8, '0'); convert_element_type_911 = None + wait_tensor_360 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_304); all_gather_into_tensor_304 = None + convert_element_type_912 = torch.ops.prims.convert_element_type.default(add_109, torch.float32) + pow_56 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_912, 2) + mean_55 = torch.ops.aten.mean.dim(pow_56, [2], True); pow_56 = None + add_110 = torch.ops.aten.add.Scalar(mean_55, 1e-05); mean_55 = None + rsqrt_55 = torch.ops.aten.rsqrt.default(add_110); add_110 = None + mul_220 = torch.ops.aten.mul.Tensor(convert_element_type_912, rsqrt_55); convert_element_type_912 = rsqrt_55 = None + mul_221 = torch.ops.aten.mul.Tensor(mul_220, wait_tensor_360); mul_220 = wait_tensor_360 = None + convert_element_type_913 = torch.ops.prims.convert_element_type.default(mul_221, torch.bfloat16); mul_221 = None + all_gather_into_tensor_305 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_913, 8, '1'); convert_element_type_913 = None + wait_tensor_361 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_305); all_gather_into_tensor_305 = None + split_119 = torch.ops.aten.split.Tensor(wait_tensor_361, 2); wait_tensor_361 = None + getitem_1204 = split_119[0] + getitem_1205 = split_119[1] + getitem_1206 = split_119[2] + getitem_1207 = split_119[3] + getitem_1208 = split_119[4] + getitem_1209 = split_119[5] + getitem_1210 = split_119[6] + getitem_1211 = split_119[7]; split_119 = None + cat_111 = torch.ops.aten.cat.default([getitem_1204, getitem_1205, getitem_1206, getitem_1207, getitem_1208, getitem_1209, getitem_1210, getitem_1211], 1); getitem_1204 = getitem_1205 = getitem_1206 = getitem_1207 = getitem_1208 = getitem_1209 = getitem_1210 = getitem_1211 = None + convert_element_type_914 = torch.ops.prims.convert_element_type.default(primals_253, torch.bfloat16) + all_gather_into_tensor_306 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_914, 8, '0'); convert_element_type_914 = None + wait_tensor_362 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_306); all_gather_into_tensor_306 = None + permute_305 = torch.ops.aten.permute.default(wait_tensor_362, [1, 0]); wait_tensor_362 = None + view_2004 = torch.ops.aten.view.default(cat_111, [16384, 4096]); cat_111 = None + mm_193 = torch.ops.aten.mm.default(view_2004, permute_305); permute_305 = None + view_2005 = torch.ops.aten.view.default(mm_193, [2, 8192, 1792]) + convert_element_type_917 = torch.ops.prims.convert_element_type.default(view_2005, torch.float32); view_2005 = None + sigmoid_27 = torch.ops.aten.sigmoid.default(convert_element_type_917) + mul_222 = torch.ops.aten.mul.Tensor(convert_element_type_917, sigmoid_27); convert_element_type_917 = sigmoid_27 = None + convert_element_type_918 = torch.ops.prims.convert_element_type.default(mul_222, torch.bfloat16); mul_222 = None + convert_element_type_919 = torch.ops.prims.convert_element_type.default(primals_254, torch.bfloat16) + all_gather_into_tensor_307 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_919, 8, '0'); convert_element_type_919 = None + wait_tensor_363 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_307); all_gather_into_tensor_307 = None + permute_306 = torch.ops.aten.permute.default(wait_tensor_363, [1, 0]); wait_tensor_363 = None + mm_194 = torch.ops.aten.mm.default(view_2004, permute_306); view_2004 = permute_306 = None + view_2012 = torch.ops.aten.view.default(mm_194, [2, 8192, 1792]); mm_194 = None + mul_223 = torch.ops.aten.mul.Tensor(convert_element_type_918, view_2012); convert_element_type_918 = view_2012 = None + convert_element_type_922 = torch.ops.prims.convert_element_type.default(primals_255, torch.bfloat16) + all_gather_into_tensor_308 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_922, 8, '0'); convert_element_type_922 = None + wait_tensor_364 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_308); all_gather_into_tensor_308 = None + permute_307 = torch.ops.aten.permute.default(wait_tensor_364, [1, 0]); wait_tensor_364 = None + view_2019 = torch.ops.aten.view.default(mul_223, [16384, 1792]); mul_223 = None + mm_195 = torch.ops.aten.mm.default(view_2019, permute_307); view_2019 = permute_307 = None + view_2020 = torch.ops.aten.view.default(mm_195, [2, 8192, 4096]); mm_195 = None + split_120 = torch.ops.aten.split.Tensor(view_2020, 1024, 1); view_2020 = None + getitem_1212 = split_120[0] + getitem_1213 = split_120[1] + getitem_1214 = split_120[2] + getitem_1215 = split_120[3] + getitem_1216 = split_120[4] + getitem_1217 = split_120[5] + getitem_1218 = split_120[6] + getitem_1219 = split_120[7]; split_120 = None + cat_112 = torch.ops.aten.cat.default([getitem_1212, getitem_1213, getitem_1214, getitem_1215, getitem_1216, getitem_1217, getitem_1218, getitem_1219]); getitem_1212 = getitem_1213 = getitem_1214 = getitem_1215 = getitem_1216 = getitem_1217 = getitem_1218 = getitem_1219 = None + reduce_scatter_tensor_56 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_112, 'sum', 8, '1'); cat_112 = None + wait_tensor_365 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_56); reduce_scatter_tensor_56 = None + add_111 = torch.ops.aten.add.Tensor(add_109, wait_tensor_365); add_109 = wait_tensor_365 = None + convert_element_type_925 = torch.ops.prims.convert_element_type.default(primals_256, torch.bfloat16) + all_gather_into_tensor_309 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_925, 8, '0'); convert_element_type_925 = None + wait_tensor_366 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_309); all_gather_into_tensor_309 = None + convert_element_type_926 = torch.ops.prims.convert_element_type.default(add_111, torch.float32) + pow_57 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_926, 2) + mean_56 = torch.ops.aten.mean.dim(pow_57, [2], True); pow_57 = None + add_112 = torch.ops.aten.add.Scalar(mean_56, 1e-05); mean_56 = None + rsqrt_56 = torch.ops.aten.rsqrt.default(add_112); add_112 = None + mul_224 = torch.ops.aten.mul.Tensor(convert_element_type_926, rsqrt_56); convert_element_type_926 = rsqrt_56 = None + mul_225 = torch.ops.aten.mul.Tensor(mul_224, wait_tensor_366); mul_224 = wait_tensor_366 = None + convert_element_type_927 = torch.ops.prims.convert_element_type.default(mul_225, torch.bfloat16); mul_225 = None + all_gather_into_tensor_310 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_927, 8, '1'); convert_element_type_927 = None + wait_tensor_367 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_310); all_gather_into_tensor_310 = None + split_121 = torch.ops.aten.split.Tensor(wait_tensor_367, 2); wait_tensor_367 = None + getitem_1220 = split_121[0] + getitem_1221 = split_121[1] + getitem_1222 = split_121[2] + getitem_1223 = split_121[3] + getitem_1224 = split_121[4] + getitem_1225 = split_121[5] + getitem_1226 = split_121[6] + getitem_1227 = split_121[7]; split_121 = None + cat_113 = torch.ops.aten.cat.default([getitem_1220, getitem_1221, getitem_1222, getitem_1223, getitem_1224, getitem_1225, getitem_1226, getitem_1227], 1); getitem_1220 = getitem_1221 = getitem_1222 = getitem_1223 = getitem_1224 = getitem_1225 = getitem_1226 = getitem_1227 = None + convert_element_type_928 = torch.ops.prims.convert_element_type.default(primals_257, torch.bfloat16) + all_gather_into_tensor_311 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_928, 8, '0'); convert_element_type_928 = None + wait_tensor_368 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_311); all_gather_into_tensor_311 = None + permute_308 = torch.ops.aten.permute.default(wait_tensor_368, [1, 0]); wait_tensor_368 = None + view_2031 = torch.ops.aten.view.default(cat_113, [16384, 4096]); cat_113 = None + mm_196 = torch.ops.aten.mm.default(view_2031, permute_308); permute_308 = None + view_2032 = torch.ops.aten.view.default(mm_196, [2, 8192, 512]) + convert_element_type_931 = torch.ops.prims.convert_element_type.default(primals_258, torch.bfloat16) + all_gather_into_tensor_312 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_931, 8, '0'); convert_element_type_931 = None + wait_tensor_369 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_312); all_gather_into_tensor_312 = None + permute_309 = torch.ops.aten.permute.default(wait_tensor_369, [1, 0]); wait_tensor_369 = None + mm_197 = torch.ops.aten.mm.default(view_2031, permute_309); permute_309 = None + view_2039 = torch.ops.aten.view.default(mm_197, [2, 8192, 128]); mm_197 = None + convert_element_type_934 = torch.ops.prims.convert_element_type.default(primals_259, torch.bfloat16) + all_gather_into_tensor_313 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_934, 8, '0'); convert_element_type_934 = None + wait_tensor_370 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_313); all_gather_into_tensor_313 = None + permute_310 = torch.ops.aten.permute.default(wait_tensor_370, [1, 0]); wait_tensor_370 = None + mm_198 = torch.ops.aten.mm.default(view_2031, permute_310); view_2031 = permute_310 = None + view_2046 = torch.ops.aten.view.default(mm_198, [2, 8192, 128]) + view_2048 = torch.ops.aten.view.default(view_2032, [2, 8192, -1, 128]); view_2032 = None + view_2049 = torch.ops.aten.view.default(view_2039, [2, 8192, -1, 128]); view_2039 = None + view_2050 = torch.ops.aten.view.default(view_2046, [2, 8192, -1, 128]); view_2046 = None + convert_element_type_937 = torch.ops.prims.convert_element_type.default(view_2048, torch.float32); view_2048 = None + view_2051 = torch.ops.aten.view.default(convert_element_type_937, [2, 8192, 4, -1, 2]); convert_element_type_937 = None + view_as_complex_56 = torch.ops.aten.view_as_complex.default(view_2051); view_2051 = None + convert_element_type_938 = torch.ops.prims.convert_element_type.default(view_2049, torch.float32); view_2049 = None + view_2052 = torch.ops.aten.view.default(convert_element_type_938, [2, 8192, 1, -1, 2]); convert_element_type_938 = None + view_as_complex_57 = torch.ops.aten.view_as_complex.default(view_2052); view_2052 = None + mul_226 = torch.ops.aten.mul.Tensor(view_as_complex_56, view_37); view_as_complex_56 = None + view_as_real_56 = torch.ops.aten.view_as_real.default(mul_226); mul_226 = None + view_2054 = torch.ops.aten.view.default(view_as_real_56, [2, 8192, 4, 128]); view_as_real_56 = None + mul_227 = torch.ops.aten.mul.Tensor(view_as_complex_57, view_37); view_as_complex_57 = None + view_as_real_57 = torch.ops.aten.view_as_real.default(mul_227); mul_227 = None + view_2055 = torch.ops.aten.view.default(view_as_real_57, [2, 8192, 1, 128]); view_as_real_57 = None + convert_element_type_939 = torch.ops.prims.convert_element_type.default(view_2054, torch.bfloat16); view_2054 = None + convert_element_type_940 = torch.ops.prims.convert_element_type.default(view_2055, torch.bfloat16); view_2055 = None + unsqueeze_56 = torch.ops.aten.unsqueeze.default(convert_element_type_940, 3); convert_element_type_940 = None + expand_56 = torch.ops.aten.expand.default(unsqueeze_56, [2, 8192, 1, 4, 128]); unsqueeze_56 = None + view_2056 = torch.ops.aten.view.default(expand_56, [2, 8192, 4, 128]); expand_56 = None + unsqueeze_57 = torch.ops.aten.unsqueeze.default(view_2050, 3); view_2050 = None + expand_57 = torch.ops.aten.expand.default(unsqueeze_57, [2, 8192, 1, 4, 128]); unsqueeze_57 = None + view_2057 = torch.ops.aten.view.default(expand_57, [2, 8192, 4, 128]); expand_57 = None + permute_311 = torch.ops.aten.permute.default(convert_element_type_939, [0, 2, 1, 3]); convert_element_type_939 = None + permute_312 = torch.ops.aten.permute.default(view_2056, [0, 2, 1, 3]); view_2056 = None + permute_313 = torch.ops.aten.permute.default(view_2057, [0, 2, 1, 3]); view_2057 = None + _scaled_dot_product_cudnn_attention_28 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_311, permute_312, permute_313, None, True, 0.0, True); permute_311 = permute_312 = permute_313 = None + getitem_1228 = _scaled_dot_product_cudnn_attention_28[0] + getitem_1229 = _scaled_dot_product_cudnn_attention_28[1] + getitem_1234 = _scaled_dot_product_cudnn_attention_28[6] + getitem_1235 = _scaled_dot_product_cudnn_attention_28[7]; _scaled_dot_product_cudnn_attention_28 = None + permute_314 = torch.ops.aten.permute.default(getitem_1228, [0, 2, 1, 3]) + view_2058 = torch.ops.aten.view.default(permute_314, [2, 8192, -1]); permute_314 = None + convert_element_type_941 = torch.ops.prims.convert_element_type.default(primals_260, torch.bfloat16) + all_gather_into_tensor_314 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_941, 8, '0'); convert_element_type_941 = None + wait_tensor_371 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_314); all_gather_into_tensor_314 = None + permute_315 = torch.ops.aten.permute.default(wait_tensor_371, [1, 0]); wait_tensor_371 = None + view_2064 = torch.ops.aten.view.default(view_2058, [16384, 512]); view_2058 = None + mm_199 = torch.ops.aten.mm.default(view_2064, permute_315); view_2064 = permute_315 = None + view_2065 = torch.ops.aten.view.default(mm_199, [2, 8192, 4096]); mm_199 = None + split_122 = torch.ops.aten.split.Tensor(view_2065, 1024, 1); view_2065 = None + getitem_1237 = split_122[0] + getitem_1238 = split_122[1] + getitem_1239 = split_122[2] + getitem_1240 = split_122[3] + getitem_1241 = split_122[4] + getitem_1242 = split_122[5] + getitem_1243 = split_122[6] + getitem_1244 = split_122[7]; split_122 = None + cat_114 = torch.ops.aten.cat.default([getitem_1237, getitem_1238, getitem_1239, getitem_1240, getitem_1241, getitem_1242, getitem_1243, getitem_1244]); getitem_1237 = getitem_1238 = getitem_1239 = getitem_1240 = getitem_1241 = getitem_1242 = getitem_1243 = getitem_1244 = None + reduce_scatter_tensor_57 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_114, 'sum', 8, '1'); cat_114 = None + wait_tensor_372 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_57) + add_113 = torch.ops.aten.add.Tensor(add_111, wait_tensor_372); wait_tensor_372 = None + convert_element_type_944 = torch.ops.prims.convert_element_type.default(primals_261, torch.bfloat16) + all_gather_into_tensor_315 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_944, 8, '0'); convert_element_type_944 = None + wait_tensor_373 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_315); all_gather_into_tensor_315 = None + convert_element_type_945 = torch.ops.prims.convert_element_type.default(add_113, torch.float32) + pow_58 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_945, 2) + mean_57 = torch.ops.aten.mean.dim(pow_58, [2], True); pow_58 = None + add_114 = torch.ops.aten.add.Scalar(mean_57, 1e-05); mean_57 = None + rsqrt_57 = torch.ops.aten.rsqrt.default(add_114); add_114 = None + mul_228 = torch.ops.aten.mul.Tensor(convert_element_type_945, rsqrt_57); convert_element_type_945 = rsqrt_57 = None + mul_229 = torch.ops.aten.mul.Tensor(mul_228, wait_tensor_373); mul_228 = wait_tensor_373 = None + convert_element_type_946 = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None + all_gather_into_tensor_316 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_946, 8, '1'); convert_element_type_946 = None + wait_tensor_374 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_316); all_gather_into_tensor_316 = None + split_123 = torch.ops.aten.split.Tensor(wait_tensor_374, 2); wait_tensor_374 = None + getitem_1245 = split_123[0] + getitem_1246 = split_123[1] + getitem_1247 = split_123[2] + getitem_1248 = split_123[3] + getitem_1249 = split_123[4] + getitem_1250 = split_123[5] + getitem_1251 = split_123[6] + getitem_1252 = split_123[7]; split_123 = None + cat_115 = torch.ops.aten.cat.default([getitem_1245, getitem_1246, getitem_1247, getitem_1248, getitem_1249, getitem_1250, getitem_1251, getitem_1252], 1); getitem_1245 = getitem_1246 = getitem_1247 = getitem_1248 = getitem_1249 = getitem_1250 = getitem_1251 = getitem_1252 = None + convert_element_type_947 = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16) + all_gather_into_tensor_317 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_947, 8, '0'); convert_element_type_947 = None + wait_tensor_375 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_317); all_gather_into_tensor_317 = None + permute_316 = torch.ops.aten.permute.default(wait_tensor_375, [1, 0]); wait_tensor_375 = None + view_2076 = torch.ops.aten.view.default(cat_115, [16384, 4096]); cat_115 = None + mm_200 = torch.ops.aten.mm.default(view_2076, permute_316); permute_316 = None + view_2077 = torch.ops.aten.view.default(mm_200, [2, 8192, 1792]) + convert_element_type_950 = torch.ops.prims.convert_element_type.default(view_2077, torch.float32); view_2077 = None + sigmoid_28 = torch.ops.aten.sigmoid.default(convert_element_type_950) + mul_230 = torch.ops.aten.mul.Tensor(convert_element_type_950, sigmoid_28); convert_element_type_950 = sigmoid_28 = None + convert_element_type_951 = torch.ops.prims.convert_element_type.default(mul_230, torch.bfloat16); mul_230 = None + convert_element_type_952 = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16) + all_gather_into_tensor_318 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_952, 8, '0'); convert_element_type_952 = None + wait_tensor_376 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_318); all_gather_into_tensor_318 = None + permute_317 = torch.ops.aten.permute.default(wait_tensor_376, [1, 0]); wait_tensor_376 = None + mm_201 = torch.ops.aten.mm.default(view_2076, permute_317); view_2076 = permute_317 = None + view_2084 = torch.ops.aten.view.default(mm_201, [2, 8192, 1792]); mm_201 = None + mul_231 = torch.ops.aten.mul.Tensor(convert_element_type_951, view_2084); convert_element_type_951 = view_2084 = None + convert_element_type_955 = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16) + all_gather_into_tensor_319 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_955, 8, '0'); convert_element_type_955 = None + wait_tensor_377 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_319); all_gather_into_tensor_319 = None + permute_318 = torch.ops.aten.permute.default(wait_tensor_377, [1, 0]); wait_tensor_377 = None + view_2091 = torch.ops.aten.view.default(mul_231, [16384, 1792]); mul_231 = None + mm_202 = torch.ops.aten.mm.default(view_2091, permute_318); view_2091 = permute_318 = None + view_2092 = torch.ops.aten.view.default(mm_202, [2, 8192, 4096]); mm_202 = None + split_124 = torch.ops.aten.split.Tensor(view_2092, 1024, 1); view_2092 = None + getitem_1253 = split_124[0] + getitem_1254 = split_124[1] + getitem_1255 = split_124[2] + getitem_1256 = split_124[3] + getitem_1257 = split_124[4] + getitem_1258 = split_124[5] + getitem_1259 = split_124[6] + getitem_1260 = split_124[7]; split_124 = None + cat_116 = torch.ops.aten.cat.default([getitem_1253, getitem_1254, getitem_1255, getitem_1256, getitem_1257, getitem_1258, getitem_1259, getitem_1260]); getitem_1253 = getitem_1254 = getitem_1255 = getitem_1256 = getitem_1257 = getitem_1258 = getitem_1259 = getitem_1260 = None + reduce_scatter_tensor_58 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_116, 'sum', 8, '1'); cat_116 = None + wait_tensor_378 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_58); reduce_scatter_tensor_58 = None + add_115 = torch.ops.aten.add.Tensor(add_113, wait_tensor_378); add_113 = wait_tensor_378 = None + convert_element_type_958 = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16) + all_gather_into_tensor_320 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_958, 8, '0'); convert_element_type_958 = None + wait_tensor_379 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_320); all_gather_into_tensor_320 = None + convert_element_type_959 = torch.ops.prims.convert_element_type.default(add_115, torch.float32) + pow_59 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_959, 2) + mean_58 = torch.ops.aten.mean.dim(pow_59, [2], True); pow_59 = None + add_116 = torch.ops.aten.add.Scalar(mean_58, 1e-05); mean_58 = None + rsqrt_58 = torch.ops.aten.rsqrt.default(add_116); add_116 = None + mul_232 = torch.ops.aten.mul.Tensor(convert_element_type_959, rsqrt_58); convert_element_type_959 = rsqrt_58 = None + mul_233 = torch.ops.aten.mul.Tensor(mul_232, wait_tensor_379); mul_232 = wait_tensor_379 = None + convert_element_type_960 = torch.ops.prims.convert_element_type.default(mul_233, torch.bfloat16); mul_233 = None + all_gather_into_tensor_321 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_960, 8, '1'); convert_element_type_960 = None + wait_tensor_380 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_321); all_gather_into_tensor_321 = None + split_125 = torch.ops.aten.split.Tensor(wait_tensor_380, 2); wait_tensor_380 = None + getitem_1261 = split_125[0] + getitem_1262 = split_125[1] + getitem_1263 = split_125[2] + getitem_1264 = split_125[3] + getitem_1265 = split_125[4] + getitem_1266 = split_125[5] + getitem_1267 = split_125[6] + getitem_1268 = split_125[7]; split_125 = None + cat_117 = torch.ops.aten.cat.default([getitem_1261, getitem_1262, getitem_1263, getitem_1264, getitem_1265, getitem_1266, getitem_1267, getitem_1268], 1); getitem_1261 = getitem_1262 = getitem_1263 = getitem_1264 = getitem_1265 = getitem_1266 = getitem_1267 = getitem_1268 = None + convert_element_type_961 = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16) + all_gather_into_tensor_322 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_961, 8, '0'); convert_element_type_961 = None + wait_tensor_381 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_322); all_gather_into_tensor_322 = None + permute_319 = torch.ops.aten.permute.default(wait_tensor_381, [1, 0]); wait_tensor_381 = None + view_2103 = torch.ops.aten.view.default(cat_117, [16384, 4096]); cat_117 = None + mm_203 = torch.ops.aten.mm.default(view_2103, permute_319); permute_319 = None + view_2104 = torch.ops.aten.view.default(mm_203, [2, 8192, 512]) + convert_element_type_964 = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16) + all_gather_into_tensor_323 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_964, 8, '0'); convert_element_type_964 = None + wait_tensor_382 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_323); all_gather_into_tensor_323 = None + permute_320 = torch.ops.aten.permute.default(wait_tensor_382, [1, 0]); wait_tensor_382 = None + mm_204 = torch.ops.aten.mm.default(view_2103, permute_320); permute_320 = None + view_2111 = torch.ops.aten.view.default(mm_204, [2, 8192, 128]); mm_204 = None + convert_element_type_967 = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16) + all_gather_into_tensor_324 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_967, 8, '0'); convert_element_type_967 = None + wait_tensor_383 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_324); all_gather_into_tensor_324 = None + permute_321 = torch.ops.aten.permute.default(wait_tensor_383, [1, 0]); wait_tensor_383 = None + mm_205 = torch.ops.aten.mm.default(view_2103, permute_321); view_2103 = permute_321 = None + view_2118 = torch.ops.aten.view.default(mm_205, [2, 8192, 128]) + view_2120 = torch.ops.aten.view.default(view_2104, [2, 8192, -1, 128]); view_2104 = None + view_2121 = torch.ops.aten.view.default(view_2111, [2, 8192, -1, 128]); view_2111 = None + view_2122 = torch.ops.aten.view.default(view_2118, [2, 8192, -1, 128]); view_2118 = None + convert_element_type_970 = torch.ops.prims.convert_element_type.default(view_2120, torch.float32); view_2120 = None + view_2123 = torch.ops.aten.view.default(convert_element_type_970, [2, 8192, 4, -1, 2]); convert_element_type_970 = None + view_as_complex_58 = torch.ops.aten.view_as_complex.default(view_2123); view_2123 = None + convert_element_type_971 = torch.ops.prims.convert_element_type.default(view_2121, torch.float32); view_2121 = None + view_2124 = torch.ops.aten.view.default(convert_element_type_971, [2, 8192, 1, -1, 2]); convert_element_type_971 = None + view_as_complex_59 = torch.ops.aten.view_as_complex.default(view_2124); view_2124 = None + mul_234 = torch.ops.aten.mul.Tensor(view_as_complex_58, view_37); view_as_complex_58 = None + view_as_real_58 = torch.ops.aten.view_as_real.default(mul_234); mul_234 = None + view_2126 = torch.ops.aten.view.default(view_as_real_58, [2, 8192, 4, 128]); view_as_real_58 = None + mul_235 = torch.ops.aten.mul.Tensor(view_as_complex_59, view_37); view_as_complex_59 = None + view_as_real_59 = torch.ops.aten.view_as_real.default(mul_235); mul_235 = None + view_2127 = torch.ops.aten.view.default(view_as_real_59, [2, 8192, 1, 128]); view_as_real_59 = None + convert_element_type_972 = torch.ops.prims.convert_element_type.default(view_2126, torch.bfloat16); view_2126 = None + convert_element_type_973 = torch.ops.prims.convert_element_type.default(view_2127, torch.bfloat16); view_2127 = None + unsqueeze_58 = torch.ops.aten.unsqueeze.default(convert_element_type_973, 3); convert_element_type_973 = None + expand_58 = torch.ops.aten.expand.default(unsqueeze_58, [2, 8192, 1, 4, 128]); unsqueeze_58 = None + view_2128 = torch.ops.aten.view.default(expand_58, [2, 8192, 4, 128]); expand_58 = None + unsqueeze_59 = torch.ops.aten.unsqueeze.default(view_2122, 3); view_2122 = None + expand_59 = torch.ops.aten.expand.default(unsqueeze_59, [2, 8192, 1, 4, 128]); unsqueeze_59 = None + view_2129 = torch.ops.aten.view.default(expand_59, [2, 8192, 4, 128]); expand_59 = None + permute_322 = torch.ops.aten.permute.default(convert_element_type_972, [0, 2, 1, 3]); convert_element_type_972 = None + permute_323 = torch.ops.aten.permute.default(view_2128, [0, 2, 1, 3]); view_2128 = None + permute_324 = torch.ops.aten.permute.default(view_2129, [0, 2, 1, 3]); view_2129 = None + _scaled_dot_product_cudnn_attention_29 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_322, permute_323, permute_324, None, True, 0.0, True); permute_322 = permute_323 = permute_324 = None + getitem_1269 = _scaled_dot_product_cudnn_attention_29[0] + getitem_1270 = _scaled_dot_product_cudnn_attention_29[1] + getitem_1275 = _scaled_dot_product_cudnn_attention_29[6] + getitem_1276 = _scaled_dot_product_cudnn_attention_29[7]; _scaled_dot_product_cudnn_attention_29 = None + permute_325 = torch.ops.aten.permute.default(getitem_1269, [0, 2, 1, 3]) + view_2130 = torch.ops.aten.view.default(permute_325, [2, 8192, -1]); permute_325 = None + convert_element_type_974 = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16) + all_gather_into_tensor_325 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_974, 8, '0'); convert_element_type_974 = None + wait_tensor_384 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_325); all_gather_into_tensor_325 = None + permute_326 = torch.ops.aten.permute.default(wait_tensor_384, [1, 0]); wait_tensor_384 = None + view_2136 = torch.ops.aten.view.default(view_2130, [16384, 512]); view_2130 = None + mm_206 = torch.ops.aten.mm.default(view_2136, permute_326); view_2136 = permute_326 = None + view_2137 = torch.ops.aten.view.default(mm_206, [2, 8192, 4096]); mm_206 = None + split_126 = torch.ops.aten.split.Tensor(view_2137, 1024, 1); view_2137 = None + getitem_1278 = split_126[0] + getitem_1279 = split_126[1] + getitem_1280 = split_126[2] + getitem_1281 = split_126[3] + getitem_1282 = split_126[4] + getitem_1283 = split_126[5] + getitem_1284 = split_126[6] + getitem_1285 = split_126[7]; split_126 = None + cat_118 = torch.ops.aten.cat.default([getitem_1278, getitem_1279, getitem_1280, getitem_1281, getitem_1282, getitem_1283, getitem_1284, getitem_1285]); getitem_1278 = getitem_1279 = getitem_1280 = getitem_1281 = getitem_1282 = getitem_1283 = getitem_1284 = getitem_1285 = None + reduce_scatter_tensor_59 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_118, 'sum', 8, '1'); cat_118 = None + wait_tensor_385 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_59) + add_117 = torch.ops.aten.add.Tensor(add_115, wait_tensor_385); wait_tensor_385 = None + convert_element_type_977 = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16) + all_gather_into_tensor_326 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_977, 8, '0'); convert_element_type_977 = None + wait_tensor_386 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_326); all_gather_into_tensor_326 = None + convert_element_type_978 = torch.ops.prims.convert_element_type.default(add_117, torch.float32) + pow_60 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_978, 2) + mean_59 = torch.ops.aten.mean.dim(pow_60, [2], True); pow_60 = None + add_118 = torch.ops.aten.add.Scalar(mean_59, 1e-05); mean_59 = None + rsqrt_59 = torch.ops.aten.rsqrt.default(add_118); add_118 = None + mul_236 = torch.ops.aten.mul.Tensor(convert_element_type_978, rsqrt_59); convert_element_type_978 = rsqrt_59 = None + mul_237 = torch.ops.aten.mul.Tensor(mul_236, wait_tensor_386); mul_236 = wait_tensor_386 = None + convert_element_type_979 = torch.ops.prims.convert_element_type.default(mul_237, torch.bfloat16); mul_237 = None + all_gather_into_tensor_327 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_979, 8, '1'); convert_element_type_979 = None + wait_tensor_387 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_327); all_gather_into_tensor_327 = None + split_127 = torch.ops.aten.split.Tensor(wait_tensor_387, 2); wait_tensor_387 = None + getitem_1286 = split_127[0] + getitem_1287 = split_127[1] + getitem_1288 = split_127[2] + getitem_1289 = split_127[3] + getitem_1290 = split_127[4] + getitem_1291 = split_127[5] + getitem_1292 = split_127[6] + getitem_1293 = split_127[7]; split_127 = None + cat_119 = torch.ops.aten.cat.default([getitem_1286, getitem_1287, getitem_1288, getitem_1289, getitem_1290, getitem_1291, getitem_1292, getitem_1293], 1); getitem_1286 = getitem_1287 = getitem_1288 = getitem_1289 = getitem_1290 = getitem_1291 = getitem_1292 = getitem_1293 = None + convert_element_type_980 = torch.ops.prims.convert_element_type.default(primals_271, torch.bfloat16) + all_gather_into_tensor_328 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_980, 8, '0'); convert_element_type_980 = None + wait_tensor_388 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_328); all_gather_into_tensor_328 = None + permute_327 = torch.ops.aten.permute.default(wait_tensor_388, [1, 0]); wait_tensor_388 = None + view_2148 = torch.ops.aten.view.default(cat_119, [16384, 4096]); cat_119 = None + mm_207 = torch.ops.aten.mm.default(view_2148, permute_327); permute_327 = None + view_2149 = torch.ops.aten.view.default(mm_207, [2, 8192, 1792]) + convert_element_type_983 = torch.ops.prims.convert_element_type.default(view_2149, torch.float32); view_2149 = None + sigmoid_29 = torch.ops.aten.sigmoid.default(convert_element_type_983) + mul_238 = torch.ops.aten.mul.Tensor(convert_element_type_983, sigmoid_29); convert_element_type_983 = sigmoid_29 = None + convert_element_type_984 = torch.ops.prims.convert_element_type.default(mul_238, torch.bfloat16); mul_238 = None + convert_element_type_985 = torch.ops.prims.convert_element_type.default(primals_272, torch.bfloat16) + all_gather_into_tensor_329 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_985, 8, '0'); convert_element_type_985 = None + wait_tensor_389 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_329); all_gather_into_tensor_329 = None + permute_328 = torch.ops.aten.permute.default(wait_tensor_389, [1, 0]); wait_tensor_389 = None + mm_208 = torch.ops.aten.mm.default(view_2148, permute_328); view_2148 = permute_328 = None + view_2156 = torch.ops.aten.view.default(mm_208, [2, 8192, 1792]); mm_208 = None + mul_239 = torch.ops.aten.mul.Tensor(convert_element_type_984, view_2156); convert_element_type_984 = view_2156 = None + convert_element_type_988 = torch.ops.prims.convert_element_type.default(primals_273, torch.bfloat16) + all_gather_into_tensor_330 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_988, 8, '0'); convert_element_type_988 = None + wait_tensor_390 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_330); all_gather_into_tensor_330 = None + permute_329 = torch.ops.aten.permute.default(wait_tensor_390, [1, 0]); wait_tensor_390 = None + view_2163 = torch.ops.aten.view.default(mul_239, [16384, 1792]); mul_239 = None + mm_209 = torch.ops.aten.mm.default(view_2163, permute_329); view_2163 = permute_329 = None + view_2164 = torch.ops.aten.view.default(mm_209, [2, 8192, 4096]); mm_209 = None + split_128 = torch.ops.aten.split.Tensor(view_2164, 1024, 1); view_2164 = None + getitem_1294 = split_128[0] + getitem_1295 = split_128[1] + getitem_1296 = split_128[2] + getitem_1297 = split_128[3] + getitem_1298 = split_128[4] + getitem_1299 = split_128[5] + getitem_1300 = split_128[6] + getitem_1301 = split_128[7]; split_128 = None + cat_120 = torch.ops.aten.cat.default([getitem_1294, getitem_1295, getitem_1296, getitem_1297, getitem_1298, getitem_1299, getitem_1300, getitem_1301]); getitem_1294 = getitem_1295 = getitem_1296 = getitem_1297 = getitem_1298 = getitem_1299 = getitem_1300 = getitem_1301 = None + reduce_scatter_tensor_60 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_120, 'sum', 8, '1'); cat_120 = None + wait_tensor_391 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_60); reduce_scatter_tensor_60 = None + add_119 = torch.ops.aten.add.Tensor(add_117, wait_tensor_391); add_117 = wait_tensor_391 = None + convert_element_type_991 = torch.ops.prims.convert_element_type.default(primals_274, torch.bfloat16) + all_gather_into_tensor_331 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_991, 8, '0'); convert_element_type_991 = None + wait_tensor_392 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_331); all_gather_into_tensor_331 = None + convert_element_type_992 = torch.ops.prims.convert_element_type.default(add_119, torch.float32) + pow_61 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_992, 2) + mean_60 = torch.ops.aten.mean.dim(pow_61, [2], True); pow_61 = None + add_120 = torch.ops.aten.add.Scalar(mean_60, 1e-05); mean_60 = None + rsqrt_60 = torch.ops.aten.rsqrt.default(add_120); add_120 = None + mul_240 = torch.ops.aten.mul.Tensor(convert_element_type_992, rsqrt_60); convert_element_type_992 = rsqrt_60 = None + mul_241 = torch.ops.aten.mul.Tensor(mul_240, wait_tensor_392); mul_240 = wait_tensor_392 = None + convert_element_type_993 = torch.ops.prims.convert_element_type.default(mul_241, torch.bfloat16); mul_241 = None + all_gather_into_tensor_332 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_993, 8, '1'); convert_element_type_993 = None + wait_tensor_393 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_332); all_gather_into_tensor_332 = None + split_129 = torch.ops.aten.split.Tensor(wait_tensor_393, 2); wait_tensor_393 = None + getitem_1302 = split_129[0] + getitem_1303 = split_129[1] + getitem_1304 = split_129[2] + getitem_1305 = split_129[3] + getitem_1306 = split_129[4] + getitem_1307 = split_129[5] + getitem_1308 = split_129[6] + getitem_1309 = split_129[7]; split_129 = None + cat_121 = torch.ops.aten.cat.default([getitem_1302, getitem_1303, getitem_1304, getitem_1305, getitem_1306, getitem_1307, getitem_1308, getitem_1309], 1); getitem_1302 = getitem_1303 = getitem_1304 = getitem_1305 = getitem_1306 = getitem_1307 = getitem_1308 = getitem_1309 = None + convert_element_type_994 = torch.ops.prims.convert_element_type.default(primals_275, torch.bfloat16) + all_gather_into_tensor_333 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_994, 8, '0'); convert_element_type_994 = None + wait_tensor_394 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_333); all_gather_into_tensor_333 = None + permute_330 = torch.ops.aten.permute.default(wait_tensor_394, [1, 0]); wait_tensor_394 = None + view_2175 = torch.ops.aten.view.default(cat_121, [16384, 4096]); cat_121 = None + mm_210 = torch.ops.aten.mm.default(view_2175, permute_330); permute_330 = None + view_2176 = torch.ops.aten.view.default(mm_210, [2, 8192, 512]) + convert_element_type_997 = torch.ops.prims.convert_element_type.default(primals_276, torch.bfloat16) + all_gather_into_tensor_334 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_997, 8, '0'); convert_element_type_997 = None + wait_tensor_395 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_334); all_gather_into_tensor_334 = None + permute_331 = torch.ops.aten.permute.default(wait_tensor_395, [1, 0]); wait_tensor_395 = None + mm_211 = torch.ops.aten.mm.default(view_2175, permute_331); permute_331 = None + view_2183 = torch.ops.aten.view.default(mm_211, [2, 8192, 128]); mm_211 = None + convert_element_type_1000 = torch.ops.prims.convert_element_type.default(primals_277, torch.bfloat16) + all_gather_into_tensor_335 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1000, 8, '0'); convert_element_type_1000 = None + wait_tensor_396 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_335); all_gather_into_tensor_335 = None + permute_332 = torch.ops.aten.permute.default(wait_tensor_396, [1, 0]); wait_tensor_396 = None + mm_212 = torch.ops.aten.mm.default(view_2175, permute_332); view_2175 = permute_332 = None + view_2190 = torch.ops.aten.view.default(mm_212, [2, 8192, 128]) + view_2192 = torch.ops.aten.view.default(view_2176, [2, 8192, -1, 128]); view_2176 = None + view_2193 = torch.ops.aten.view.default(view_2183, [2, 8192, -1, 128]); view_2183 = None + view_2194 = torch.ops.aten.view.default(view_2190, [2, 8192, -1, 128]); view_2190 = None + convert_element_type_1003 = torch.ops.prims.convert_element_type.default(view_2192, torch.float32); view_2192 = None + view_2195 = torch.ops.aten.view.default(convert_element_type_1003, [2, 8192, 4, -1, 2]); convert_element_type_1003 = None + view_as_complex_60 = torch.ops.aten.view_as_complex.default(view_2195); view_2195 = None + convert_element_type_1004 = torch.ops.prims.convert_element_type.default(view_2193, torch.float32); view_2193 = None + view_2196 = torch.ops.aten.view.default(convert_element_type_1004, [2, 8192, 1, -1, 2]); convert_element_type_1004 = None + view_as_complex_61 = torch.ops.aten.view_as_complex.default(view_2196); view_2196 = None + mul_242 = torch.ops.aten.mul.Tensor(view_as_complex_60, view_37); view_as_complex_60 = None + view_as_real_60 = torch.ops.aten.view_as_real.default(mul_242); mul_242 = None + view_2198 = torch.ops.aten.view.default(view_as_real_60, [2, 8192, 4, 128]); view_as_real_60 = None + mul_243 = torch.ops.aten.mul.Tensor(view_as_complex_61, view_37); view_as_complex_61 = None + view_as_real_61 = torch.ops.aten.view_as_real.default(mul_243); mul_243 = None + view_2199 = torch.ops.aten.view.default(view_as_real_61, [2, 8192, 1, 128]); view_as_real_61 = None + convert_element_type_1005 = torch.ops.prims.convert_element_type.default(view_2198, torch.bfloat16); view_2198 = None + convert_element_type_1006 = torch.ops.prims.convert_element_type.default(view_2199, torch.bfloat16); view_2199 = None + unsqueeze_60 = torch.ops.aten.unsqueeze.default(convert_element_type_1006, 3); convert_element_type_1006 = None + expand_60 = torch.ops.aten.expand.default(unsqueeze_60, [2, 8192, 1, 4, 128]); unsqueeze_60 = None + view_2200 = torch.ops.aten.view.default(expand_60, [2, 8192, 4, 128]); expand_60 = None + unsqueeze_61 = torch.ops.aten.unsqueeze.default(view_2194, 3); view_2194 = None + expand_61 = torch.ops.aten.expand.default(unsqueeze_61, [2, 8192, 1, 4, 128]); unsqueeze_61 = None + view_2201 = torch.ops.aten.view.default(expand_61, [2, 8192, 4, 128]); expand_61 = None + permute_333 = torch.ops.aten.permute.default(convert_element_type_1005, [0, 2, 1, 3]); convert_element_type_1005 = None + permute_334 = torch.ops.aten.permute.default(view_2200, [0, 2, 1, 3]); view_2200 = None + permute_335 = torch.ops.aten.permute.default(view_2201, [0, 2, 1, 3]); view_2201 = None + _scaled_dot_product_cudnn_attention_30 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_333, permute_334, permute_335, None, True, 0.0, True); permute_333 = permute_334 = permute_335 = None + getitem_1310 = _scaled_dot_product_cudnn_attention_30[0] + getitem_1311 = _scaled_dot_product_cudnn_attention_30[1] + getitem_1316 = _scaled_dot_product_cudnn_attention_30[6] + getitem_1317 = _scaled_dot_product_cudnn_attention_30[7]; _scaled_dot_product_cudnn_attention_30 = None + permute_336 = torch.ops.aten.permute.default(getitem_1310, [0, 2, 1, 3]) + view_2202 = torch.ops.aten.view.default(permute_336, [2, 8192, -1]); permute_336 = None + convert_element_type_1007 = torch.ops.prims.convert_element_type.default(primals_278, torch.bfloat16) + all_gather_into_tensor_336 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1007, 8, '0'); convert_element_type_1007 = None + wait_tensor_397 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_336); all_gather_into_tensor_336 = None + permute_337 = torch.ops.aten.permute.default(wait_tensor_397, [1, 0]); wait_tensor_397 = None + view_2208 = torch.ops.aten.view.default(view_2202, [16384, 512]); view_2202 = None + mm_213 = torch.ops.aten.mm.default(view_2208, permute_337); view_2208 = permute_337 = None + view_2209 = torch.ops.aten.view.default(mm_213, [2, 8192, 4096]); mm_213 = None + split_130 = torch.ops.aten.split.Tensor(view_2209, 1024, 1); view_2209 = None + getitem_1319 = split_130[0] + getitem_1320 = split_130[1] + getitem_1321 = split_130[2] + getitem_1322 = split_130[3] + getitem_1323 = split_130[4] + getitem_1324 = split_130[5] + getitem_1325 = split_130[6] + getitem_1326 = split_130[7]; split_130 = None + cat_122 = torch.ops.aten.cat.default([getitem_1319, getitem_1320, getitem_1321, getitem_1322, getitem_1323, getitem_1324, getitem_1325, getitem_1326]); getitem_1319 = getitem_1320 = getitem_1321 = getitem_1322 = getitem_1323 = getitem_1324 = getitem_1325 = getitem_1326 = None + reduce_scatter_tensor_61 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_122, 'sum', 8, '1'); cat_122 = None + wait_tensor_398 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_61) + add_121 = torch.ops.aten.add.Tensor(add_119, wait_tensor_398); wait_tensor_398 = None + convert_element_type_1010 = torch.ops.prims.convert_element_type.default(primals_279, torch.bfloat16) + all_gather_into_tensor_337 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1010, 8, '0'); convert_element_type_1010 = None + wait_tensor_399 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_337); all_gather_into_tensor_337 = None + convert_element_type_1011 = torch.ops.prims.convert_element_type.default(add_121, torch.float32) + pow_62 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1011, 2) + mean_61 = torch.ops.aten.mean.dim(pow_62, [2], True); pow_62 = None + add_122 = torch.ops.aten.add.Scalar(mean_61, 1e-05); mean_61 = None + rsqrt_61 = torch.ops.aten.rsqrt.default(add_122); add_122 = None + mul_244 = torch.ops.aten.mul.Tensor(convert_element_type_1011, rsqrt_61); convert_element_type_1011 = rsqrt_61 = None + mul_245 = torch.ops.aten.mul.Tensor(mul_244, wait_tensor_399); mul_244 = wait_tensor_399 = None + convert_element_type_1012 = torch.ops.prims.convert_element_type.default(mul_245, torch.bfloat16); mul_245 = None + all_gather_into_tensor_338 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1012, 8, '1'); convert_element_type_1012 = None + wait_tensor_400 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_338); all_gather_into_tensor_338 = None + split_131 = torch.ops.aten.split.Tensor(wait_tensor_400, 2); wait_tensor_400 = None + getitem_1327 = split_131[0] + getitem_1328 = split_131[1] + getitem_1329 = split_131[2] + getitem_1330 = split_131[3] + getitem_1331 = split_131[4] + getitem_1332 = split_131[5] + getitem_1333 = split_131[6] + getitem_1334 = split_131[7]; split_131 = None + cat_123 = torch.ops.aten.cat.default([getitem_1327, getitem_1328, getitem_1329, getitem_1330, getitem_1331, getitem_1332, getitem_1333, getitem_1334], 1); getitem_1327 = getitem_1328 = getitem_1329 = getitem_1330 = getitem_1331 = getitem_1332 = getitem_1333 = getitem_1334 = None + convert_element_type_1013 = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16) + all_gather_into_tensor_339 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1013, 8, '0'); convert_element_type_1013 = None + wait_tensor_401 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_339); all_gather_into_tensor_339 = None + permute_338 = torch.ops.aten.permute.default(wait_tensor_401, [1, 0]); wait_tensor_401 = None + view_2220 = torch.ops.aten.view.default(cat_123, [16384, 4096]); cat_123 = None + mm_214 = torch.ops.aten.mm.default(view_2220, permute_338); permute_338 = None + view_2221 = torch.ops.aten.view.default(mm_214, [2, 8192, 1792]) + convert_element_type_1016 = torch.ops.prims.convert_element_type.default(view_2221, torch.float32); view_2221 = None + sigmoid_30 = torch.ops.aten.sigmoid.default(convert_element_type_1016) + mul_246 = torch.ops.aten.mul.Tensor(convert_element_type_1016, sigmoid_30); convert_element_type_1016 = sigmoid_30 = None + convert_element_type_1017 = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None + convert_element_type_1018 = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16) + all_gather_into_tensor_340 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1018, 8, '0'); convert_element_type_1018 = None + wait_tensor_402 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_340); all_gather_into_tensor_340 = None + permute_339 = torch.ops.aten.permute.default(wait_tensor_402, [1, 0]); wait_tensor_402 = None + mm_215 = torch.ops.aten.mm.default(view_2220, permute_339); view_2220 = permute_339 = None + view_2228 = torch.ops.aten.view.default(mm_215, [2, 8192, 1792]); mm_215 = None + mul_247 = torch.ops.aten.mul.Tensor(convert_element_type_1017, view_2228); convert_element_type_1017 = view_2228 = None + convert_element_type_1021 = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16) + all_gather_into_tensor_341 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1021, 8, '0'); convert_element_type_1021 = None + wait_tensor_403 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_341); all_gather_into_tensor_341 = None + permute_340 = torch.ops.aten.permute.default(wait_tensor_403, [1, 0]); wait_tensor_403 = None + view_2235 = torch.ops.aten.view.default(mul_247, [16384, 1792]); mul_247 = None + mm_216 = torch.ops.aten.mm.default(view_2235, permute_340); view_2235 = permute_340 = None + view_2236 = torch.ops.aten.view.default(mm_216, [2, 8192, 4096]); mm_216 = None + split_132 = torch.ops.aten.split.Tensor(view_2236, 1024, 1); view_2236 = None + getitem_1335 = split_132[0] + getitem_1336 = split_132[1] + getitem_1337 = split_132[2] + getitem_1338 = split_132[3] + getitem_1339 = split_132[4] + getitem_1340 = split_132[5] + getitem_1341 = split_132[6] + getitem_1342 = split_132[7]; split_132 = None + cat_124 = torch.ops.aten.cat.default([getitem_1335, getitem_1336, getitem_1337, getitem_1338, getitem_1339, getitem_1340, getitem_1341, getitem_1342]); getitem_1335 = getitem_1336 = getitem_1337 = getitem_1338 = getitem_1339 = getitem_1340 = getitem_1341 = getitem_1342 = None + reduce_scatter_tensor_62 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_124, 'sum', 8, '1'); cat_124 = None + wait_tensor_404 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_62); reduce_scatter_tensor_62 = None + add_123 = torch.ops.aten.add.Tensor(add_121, wait_tensor_404); add_121 = wait_tensor_404 = None + convert_element_type_1024 = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16) + all_gather_into_tensor_342 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1024, 8, '0'); convert_element_type_1024 = None + wait_tensor_405 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_342); all_gather_into_tensor_342 = None + convert_element_type_1025 = torch.ops.prims.convert_element_type.default(add_123, torch.float32) + pow_63 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1025, 2) + mean_62 = torch.ops.aten.mean.dim(pow_63, [2], True); pow_63 = None + add_124 = torch.ops.aten.add.Scalar(mean_62, 1e-05); mean_62 = None + rsqrt_62 = torch.ops.aten.rsqrt.default(add_124); add_124 = None + mul_248 = torch.ops.aten.mul.Tensor(convert_element_type_1025, rsqrt_62); convert_element_type_1025 = rsqrt_62 = None + mul_249 = torch.ops.aten.mul.Tensor(mul_248, wait_tensor_405); mul_248 = wait_tensor_405 = None + convert_element_type_1026 = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16); mul_249 = None + all_gather_into_tensor_343 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1026, 8, '1'); convert_element_type_1026 = None + wait_tensor_406 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_343); all_gather_into_tensor_343 = None + split_133 = torch.ops.aten.split.Tensor(wait_tensor_406, 2); wait_tensor_406 = None + getitem_1343 = split_133[0] + getitem_1344 = split_133[1] + getitem_1345 = split_133[2] + getitem_1346 = split_133[3] + getitem_1347 = split_133[4] + getitem_1348 = split_133[5] + getitem_1349 = split_133[6] + getitem_1350 = split_133[7]; split_133 = None + cat_125 = torch.ops.aten.cat.default([getitem_1343, getitem_1344, getitem_1345, getitem_1346, getitem_1347, getitem_1348, getitem_1349, getitem_1350], 1); getitem_1343 = getitem_1344 = getitem_1345 = getitem_1346 = getitem_1347 = getitem_1348 = getitem_1349 = getitem_1350 = None + convert_element_type_1027 = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16) + all_gather_into_tensor_344 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1027, 8, '0'); convert_element_type_1027 = None + wait_tensor_407 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_344); all_gather_into_tensor_344 = None + permute_341 = torch.ops.aten.permute.default(wait_tensor_407, [1, 0]); wait_tensor_407 = None + view_2247 = torch.ops.aten.view.default(cat_125, [16384, 4096]); cat_125 = None + mm_217 = torch.ops.aten.mm.default(view_2247, permute_341); permute_341 = None + view_2248 = torch.ops.aten.view.default(mm_217, [2, 8192, 512]) + convert_element_type_1030 = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16) + all_gather_into_tensor_345 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1030, 8, '0'); convert_element_type_1030 = None + wait_tensor_408 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_345); all_gather_into_tensor_345 = None + permute_342 = torch.ops.aten.permute.default(wait_tensor_408, [1, 0]); wait_tensor_408 = None + mm_218 = torch.ops.aten.mm.default(view_2247, permute_342); permute_342 = None + view_2255 = torch.ops.aten.view.default(mm_218, [2, 8192, 128]); mm_218 = None + convert_element_type_1033 = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16) + all_gather_into_tensor_346 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1033, 8, '0'); convert_element_type_1033 = None + wait_tensor_409 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_346); all_gather_into_tensor_346 = None + permute_343 = torch.ops.aten.permute.default(wait_tensor_409, [1, 0]); wait_tensor_409 = None + mm_219 = torch.ops.aten.mm.default(view_2247, permute_343); view_2247 = permute_343 = None + view_2262 = torch.ops.aten.view.default(mm_219, [2, 8192, 128]) + view_2264 = torch.ops.aten.view.default(view_2248, [2, 8192, -1, 128]); view_2248 = None + view_2265 = torch.ops.aten.view.default(view_2255, [2, 8192, -1, 128]); view_2255 = None + view_2266 = torch.ops.aten.view.default(view_2262, [2, 8192, -1, 128]); view_2262 = None + convert_element_type_1036 = torch.ops.prims.convert_element_type.default(view_2264, torch.float32); view_2264 = None + view_2267 = torch.ops.aten.view.default(convert_element_type_1036, [2, 8192, 4, -1, 2]); convert_element_type_1036 = None + view_as_complex_62 = torch.ops.aten.view_as_complex.default(view_2267); view_2267 = None + convert_element_type_1037 = torch.ops.prims.convert_element_type.default(view_2265, torch.float32); view_2265 = None + view_2268 = torch.ops.aten.view.default(convert_element_type_1037, [2, 8192, 1, -1, 2]); convert_element_type_1037 = None + view_as_complex_63 = torch.ops.aten.view_as_complex.default(view_2268); view_2268 = None + mul_250 = torch.ops.aten.mul.Tensor(view_as_complex_62, view_37); view_as_complex_62 = None + view_as_real_62 = torch.ops.aten.view_as_real.default(mul_250); mul_250 = None + view_2270 = torch.ops.aten.view.default(view_as_real_62, [2, 8192, 4, 128]); view_as_real_62 = None + mul_251 = torch.ops.aten.mul.Tensor(view_as_complex_63, view_37); view_as_complex_63 = view_37 = None + view_as_real_63 = torch.ops.aten.view_as_real.default(mul_251); mul_251 = None + view_2271 = torch.ops.aten.view.default(view_as_real_63, [2, 8192, 1, 128]); view_as_real_63 = None + convert_element_type_1038 = torch.ops.prims.convert_element_type.default(view_2270, torch.bfloat16); view_2270 = None + convert_element_type_1039 = torch.ops.prims.convert_element_type.default(view_2271, torch.bfloat16); view_2271 = None + unsqueeze_62 = torch.ops.aten.unsqueeze.default(convert_element_type_1039, 3); convert_element_type_1039 = None + expand_62 = torch.ops.aten.expand.default(unsqueeze_62, [2, 8192, 1, 4, 128]); unsqueeze_62 = None + view_2272 = torch.ops.aten.view.default(expand_62, [2, 8192, 4, 128]); expand_62 = None + unsqueeze_63 = torch.ops.aten.unsqueeze.default(view_2266, 3); view_2266 = None + expand_63 = torch.ops.aten.expand.default(unsqueeze_63, [2, 8192, 1, 4, 128]); unsqueeze_63 = None + view_2273 = torch.ops.aten.view.default(expand_63, [2, 8192, 4, 128]); expand_63 = None + permute_344 = torch.ops.aten.permute.default(convert_element_type_1038, [0, 2, 1, 3]); convert_element_type_1038 = None + permute_345 = torch.ops.aten.permute.default(view_2272, [0, 2, 1, 3]); view_2272 = None + permute_346 = torch.ops.aten.permute.default(view_2273, [0, 2, 1, 3]); view_2273 = None + _scaled_dot_product_cudnn_attention_31 = torch.ops.aten._scaled_dot_product_cudnn_attention.default(permute_344, permute_345, permute_346, None, True, 0.0, True); permute_344 = permute_345 = permute_346 = None + getitem_1351 = _scaled_dot_product_cudnn_attention_31[0] + getitem_1352 = _scaled_dot_product_cudnn_attention_31[1] + getitem_1357 = _scaled_dot_product_cudnn_attention_31[6] + getitem_1358 = _scaled_dot_product_cudnn_attention_31[7]; _scaled_dot_product_cudnn_attention_31 = None + permute_347 = torch.ops.aten.permute.default(getitem_1351, [0, 2, 1, 3]) + view_2274 = torch.ops.aten.view.default(permute_347, [2, 8192, -1]); permute_347 = None + convert_element_type_1040 = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16) + all_gather_into_tensor_347 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1040, 8, '0'); convert_element_type_1040 = None + wait_tensor_410 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_347); all_gather_into_tensor_347 = None + permute_348 = torch.ops.aten.permute.default(wait_tensor_410, [1, 0]); wait_tensor_410 = None + view_2280 = torch.ops.aten.view.default(view_2274, [16384, 512]); view_2274 = None + mm_220 = torch.ops.aten.mm.default(view_2280, permute_348); view_2280 = permute_348 = None + view_2281 = torch.ops.aten.view.default(mm_220, [2, 8192, 4096]); mm_220 = None + split_134 = torch.ops.aten.split.Tensor(view_2281, 1024, 1); view_2281 = None + getitem_1360 = split_134[0] + getitem_1361 = split_134[1] + getitem_1362 = split_134[2] + getitem_1363 = split_134[3] + getitem_1364 = split_134[4] + getitem_1365 = split_134[5] + getitem_1366 = split_134[6] + getitem_1367 = split_134[7]; split_134 = None + cat_126 = torch.ops.aten.cat.default([getitem_1360, getitem_1361, getitem_1362, getitem_1363, getitem_1364, getitem_1365, getitem_1366, getitem_1367]); getitem_1360 = getitem_1361 = getitem_1362 = getitem_1363 = getitem_1364 = getitem_1365 = getitem_1366 = getitem_1367 = None + reduce_scatter_tensor_63 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_126, 'sum', 8, '1'); cat_126 = None + wait_tensor_411 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_63) + add_125 = torch.ops.aten.add.Tensor(add_123, wait_tensor_411); wait_tensor_411 = None + convert_element_type_1043 = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16) + all_gather_into_tensor_348 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1043, 8, '0'); convert_element_type_1043 = None + wait_tensor_412 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_348); all_gather_into_tensor_348 = None + convert_element_type_1044 = torch.ops.prims.convert_element_type.default(add_125, torch.float32) + pow_64 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1044, 2) + mean_63 = torch.ops.aten.mean.dim(pow_64, [2], True); pow_64 = None + add_126 = torch.ops.aten.add.Scalar(mean_63, 1e-05); mean_63 = None + rsqrt_63 = torch.ops.aten.rsqrt.default(add_126); add_126 = None + mul_252 = torch.ops.aten.mul.Tensor(convert_element_type_1044, rsqrt_63); convert_element_type_1044 = rsqrt_63 = None + mul_253 = torch.ops.aten.mul.Tensor(mul_252, wait_tensor_412); mul_252 = wait_tensor_412 = None + convert_element_type_1045 = torch.ops.prims.convert_element_type.default(mul_253, torch.bfloat16); mul_253 = None + all_gather_into_tensor_349 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1045, 8, '1'); convert_element_type_1045 = None + wait_tensor_413 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_349); all_gather_into_tensor_349 = None + split_135 = torch.ops.aten.split.Tensor(wait_tensor_413, 2); wait_tensor_413 = None + getitem_1368 = split_135[0] + getitem_1369 = split_135[1] + getitem_1370 = split_135[2] + getitem_1371 = split_135[3] + getitem_1372 = split_135[4] + getitem_1373 = split_135[5] + getitem_1374 = split_135[6] + getitem_1375 = split_135[7]; split_135 = None + cat_127 = torch.ops.aten.cat.default([getitem_1368, getitem_1369, getitem_1370, getitem_1371, getitem_1372, getitem_1373, getitem_1374, getitem_1375], 1); getitem_1368 = getitem_1369 = getitem_1370 = getitem_1371 = getitem_1372 = getitem_1373 = getitem_1374 = getitem_1375 = None + convert_element_type_1046 = torch.ops.prims.convert_element_type.default(primals_289, torch.bfloat16) + all_gather_into_tensor_350 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1046, 8, '0'); convert_element_type_1046 = None + wait_tensor_414 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_350); all_gather_into_tensor_350 = None + permute_349 = torch.ops.aten.permute.default(wait_tensor_414, [1, 0]); wait_tensor_414 = None + view_2292 = torch.ops.aten.view.default(cat_127, [16384, 4096]); cat_127 = None + mm_221 = torch.ops.aten.mm.default(view_2292, permute_349); permute_349 = None + view_2293 = torch.ops.aten.view.default(mm_221, [2, 8192, 1792]) + convert_element_type_1049 = torch.ops.prims.convert_element_type.default(view_2293, torch.float32); view_2293 = None + sigmoid_31 = torch.ops.aten.sigmoid.default(convert_element_type_1049) + mul_254 = torch.ops.aten.mul.Tensor(convert_element_type_1049, sigmoid_31); convert_element_type_1049 = sigmoid_31 = None + convert_element_type_1050 = torch.ops.prims.convert_element_type.default(mul_254, torch.bfloat16); mul_254 = None + convert_element_type_1051 = torch.ops.prims.convert_element_type.default(primals_290, torch.bfloat16) + all_gather_into_tensor_351 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1051, 8, '0'); convert_element_type_1051 = None + wait_tensor_415 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_351); all_gather_into_tensor_351 = None + permute_350 = torch.ops.aten.permute.default(wait_tensor_415, [1, 0]); wait_tensor_415 = None + mm_222 = torch.ops.aten.mm.default(view_2292, permute_350); view_2292 = permute_350 = None + view_2300 = torch.ops.aten.view.default(mm_222, [2, 8192, 1792]); mm_222 = None + mul_255 = torch.ops.aten.mul.Tensor(convert_element_type_1050, view_2300); convert_element_type_1050 = view_2300 = None + convert_element_type_1054 = torch.ops.prims.convert_element_type.default(primals_291, torch.bfloat16) + all_gather_into_tensor_352 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1054, 8, '0'); convert_element_type_1054 = None + wait_tensor_416 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_352); all_gather_into_tensor_352 = None + permute_351 = torch.ops.aten.permute.default(wait_tensor_416, [1, 0]); wait_tensor_416 = None + view_2307 = torch.ops.aten.view.default(mul_255, [16384, 1792]); mul_255 = None + mm_223 = torch.ops.aten.mm.default(view_2307, permute_351); view_2307 = permute_351 = None + view_2308 = torch.ops.aten.view.default(mm_223, [2, 8192, 4096]); mm_223 = None + split_136 = torch.ops.aten.split.Tensor(view_2308, 1024, 1); view_2308 = None + getitem_1376 = split_136[0] + getitem_1377 = split_136[1] + getitem_1378 = split_136[2] + getitem_1379 = split_136[3] + getitem_1380 = split_136[4] + getitem_1381 = split_136[5] + getitem_1382 = split_136[6] + getitem_1383 = split_136[7]; split_136 = None + cat_128 = torch.ops.aten.cat.default([getitem_1376, getitem_1377, getitem_1378, getitem_1379, getitem_1380, getitem_1381, getitem_1382, getitem_1383]); getitem_1376 = getitem_1377 = getitem_1378 = getitem_1379 = getitem_1380 = getitem_1381 = getitem_1382 = getitem_1383 = None + reduce_scatter_tensor_64 = torch.ops._c10d_functional.reduce_scatter_tensor.default(cat_128, 'sum', 8, '1'); cat_128 = None + wait_tensor_417 = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_64) + add_127 = torch.ops.aten.add.Tensor(add_125, wait_tensor_417); add_125 = wait_tensor_417 = None + convert_element_type_1057 = torch.ops.prims.convert_element_type.default(primals_292, torch.bfloat16) + all_gather_into_tensor_353 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1057, 8, '0'); convert_element_type_1057 = None + wait_tensor_418 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_353); all_gather_into_tensor_353 = None + convert_element_type_1058 = torch.ops.prims.convert_element_type.default(add_127, torch.float32); add_127 = None + pow_65 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_1058, 2) + mean_64 = torch.ops.aten.mean.dim(pow_65, [2], True); pow_65 = None + add_128 = torch.ops.aten.add.Scalar(mean_64, 1e-05); mean_64 = None + rsqrt_64 = torch.ops.aten.rsqrt.default(add_128); add_128 = None + mul_256 = torch.ops.aten.mul.Tensor(convert_element_type_1058, rsqrt_64); convert_element_type_1058 = None + mul_257 = torch.ops.aten.mul.Tensor(mul_256, wait_tensor_418); mul_256 = wait_tensor_418 = None + convert_element_type_1059 = torch.ops.prims.convert_element_type.default(mul_257, torch.bfloat16); mul_257 = None + all_gather_into_tensor_354 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1059, 8, '1'); convert_element_type_1059 = None + wait_tensor_419 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_354); all_gather_into_tensor_354 = None + split_137 = torch.ops.aten.split.Tensor(wait_tensor_419, 2); wait_tensor_419 = None + getitem_1384 = split_137[0] + getitem_1385 = split_137[1] + getitem_1386 = split_137[2] + getitem_1387 = split_137[3] + getitem_1388 = split_137[4] + getitem_1389 = split_137[5] + getitem_1390 = split_137[6] + getitem_1391 = split_137[7]; split_137 = None + cat_129 = torch.ops.aten.cat.default([getitem_1384, getitem_1385, getitem_1386, getitem_1387, getitem_1388, getitem_1389, getitem_1390, getitem_1391], 1); getitem_1384 = getitem_1385 = getitem_1386 = getitem_1387 = getitem_1388 = getitem_1389 = getitem_1390 = getitem_1391 = None + convert_element_type_1060 = torch.ops.prims.convert_element_type.default(primals_293, torch.bfloat16) + all_gather_into_tensor_355 = torch.ops._c10d_functional.all_gather_into_tensor.default(convert_element_type_1060, 8, '0'); convert_element_type_1060 = None + wait_tensor_420 = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_355); all_gather_into_tensor_355 = None + permute_352 = torch.ops.aten.permute.default(wait_tensor_420, [1, 0]); wait_tensor_420 = None + view_2319 = torch.ops.aten.view.default(cat_129, [16384, 4096]); cat_129 = None + mm_224 = torch.ops.aten.mm.default(view_2319, permute_352); permute_352 = None + view_2320 = torch.ops.aten.view.default(mm_224, [2, 8192, 16032]); mm_224 = None + return (view_2320, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30, primals_31, primals_32, primals_33, primals_34, primals_35, primals_36, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_46, primals_47, primals_48, primals_49, primals_50, primals_51, primals_52, primals_53, primals_54, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_64, primals_65, primals_66, primals_67, primals_68, primals_69, primals_70, primals_71, primals_72, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_82, primals_83, primals_84, primals_85, primals_86, primals_87, primals_88, primals_89, primals_90, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_100, primals_101, primals_102, primals_103, primals_104, primals_105, primals_106, primals_107, primals_108, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_118, primals_119, primals_120, primals_121, primals_122, primals_123, primals_124, primals_125, primals_126, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_136, primals_137, primals_138, primals_139, primals_140, primals_141, primals_142, primals_143, primals_144, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_154, primals_155, primals_156, primals_157, primals_158, primals_159, primals_160, primals_161, primals_162, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_172, primals_173, primals_174, primals_175, primals_176, primals_177, primals_178, primals_179, primals_180, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_190, primals_191, primals_192, primals_193, primals_194, primals_195, primals_196, primals_197, primals_198, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_208, primals_209, primals_210, primals_211, primals_212, primals_213, primals_214, primals_215, primals_216, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_226, primals_227, primals_228, primals_229, primals_230, primals_231, primals_232, primals_233, primals_234, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_244, primals_245, primals_246, primals_247, primals_248, primals_249, primals_250, primals_251, primals_252, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_262, primals_263, primals_264, primals_265, primals_266, primals_267, primals_268, primals_269, primals_270, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_280, primals_281, primals_282, primals_283, primals_284, primals_285, primals_286, primals_287, primals_288, primals_289, primals_290, primals_291, primals_292, primals_293, wait_tensor_1, mm, mm_2, getitem_80, getitem_81, getitem_86, getitem_87, reduce_scatter_tensor_1, mm_4, add_3, mm_7, mm_9, getitem_121, getitem_122, getitem_127, getitem_128, reduce_scatter_tensor_3, mm_11, add_7, mm_14, mm_16, getitem_162, getitem_163, getitem_168, getitem_169, reduce_scatter_tensor_5, mm_18, add_11, mm_21, mm_23, getitem_203, getitem_204, getitem_209, getitem_210, reduce_scatter_tensor_7, mm_25, add_15, mm_28, mm_30, getitem_244, getitem_245, getitem_250, getitem_251, reduce_scatter_tensor_9, mm_32, add_19, mm_35, mm_37, getitem_285, getitem_286, getitem_291, getitem_292, reduce_scatter_tensor_11, mm_39, add_23, mm_42, mm_44, getitem_326, getitem_327, getitem_332, getitem_333, reduce_scatter_tensor_13, mm_46, add_27, mm_49, mm_51, getitem_367, getitem_368, getitem_373, getitem_374, reduce_scatter_tensor_15, mm_53, add_31, mm_56, mm_58, getitem_408, getitem_409, getitem_414, getitem_415, reduce_scatter_tensor_17, mm_60, add_35, mm_63, mm_65, getitem_449, getitem_450, getitem_455, getitem_456, reduce_scatter_tensor_19, mm_67, add_39, mm_70, mm_72, getitem_490, getitem_491, getitem_496, getitem_497, reduce_scatter_tensor_21, mm_74, add_43, mm_77, mm_79, getitem_531, getitem_532, getitem_537, getitem_538, reduce_scatter_tensor_23, mm_81, add_47, mm_84, mm_86, getitem_572, getitem_573, getitem_578, getitem_579, reduce_scatter_tensor_25, mm_88, add_51, mm_91, mm_93, getitem_613, getitem_614, getitem_619, getitem_620, reduce_scatter_tensor_27, mm_95, add_55, mm_98, mm_100, getitem_654, getitem_655, getitem_660, getitem_661, reduce_scatter_tensor_29, mm_102, add_59, mm_105, mm_107, getitem_695, getitem_696, getitem_701, getitem_702, reduce_scatter_tensor_31, mm_109, add_63, mm_112, mm_114, getitem_736, getitem_737, getitem_742, getitem_743, reduce_scatter_tensor_33, mm_116, add_67, mm_119, mm_121, getitem_777, getitem_778, getitem_783, getitem_784, reduce_scatter_tensor_35, mm_123, add_71, mm_126, mm_128, getitem_818, getitem_819, getitem_824, getitem_825, reduce_scatter_tensor_37, mm_130, add_75, mm_133, mm_135, getitem_859, getitem_860, getitem_865, getitem_866, reduce_scatter_tensor_39, mm_137, add_79, mm_140, mm_142, getitem_900, getitem_901, getitem_906, getitem_907, reduce_scatter_tensor_41, mm_144, add_83, mm_147, mm_149, getitem_941, getitem_942, getitem_947, getitem_948, reduce_scatter_tensor_43, mm_151, add_87, mm_154, mm_156, getitem_982, getitem_983, getitem_988, getitem_989, reduce_scatter_tensor_45, mm_158, add_91, mm_161, mm_163, getitem_1023, getitem_1024, getitem_1029, getitem_1030, reduce_scatter_tensor_47, mm_165, add_95, mm_168, mm_170, getitem_1064, getitem_1065, getitem_1070, getitem_1071, reduce_scatter_tensor_49, mm_172, add_99, mm_175, mm_177, getitem_1105, getitem_1106, getitem_1111, getitem_1112, reduce_scatter_tensor_51, mm_179, add_103, mm_182, mm_184, getitem_1146, getitem_1147, getitem_1152, getitem_1153, reduce_scatter_tensor_53, mm_186, add_107, mm_189, mm_191, getitem_1187, getitem_1188, getitem_1193, getitem_1194, reduce_scatter_tensor_55, mm_193, add_111, mm_196, mm_198, getitem_1228, getitem_1229, getitem_1234, getitem_1235, reduce_scatter_tensor_57, mm_200, add_115, mm_203, mm_205, getitem_1269, getitem_1270, getitem_1275, getitem_1276, reduce_scatter_tensor_59, mm_207, add_119, mm_210, mm_212, getitem_1310, getitem_1311, getitem_1316, getitem_1317, reduce_scatter_tensor_61, mm_214, add_123, mm_217, mm_219, getitem_1351, getitem_1352, getitem_1357, getitem_1358, reduce_scatter_tensor_63, mm_221, reduce_scatter_tensor_64, rsqrt_64, view_2319) + +def load_args(reader): + buf0 = reader.storage(None, 131072, device=device(type='cuda', index=0), dtype_hint=torch.int64) + reader.tensor(buf0, (2, 8192), dtype=torch.int64, is_leaf=True) # primals_1 + buf1 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf1, (2004, 4096), is_leaf=True) # primals_2 + buf2 = reader.storage(None, 4194304, device=device(type='cuda', index=0), dtype_hint=torch.complex64) + reader.tensor(buf2, (8192, 64), dtype=torch.complex64, is_leaf=True) # primals_3 + buf3 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf3, (512,), is_leaf=True) # primals_4 + buf4 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf4, (64, 4096), is_leaf=True) # primals_5 + buf5 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf5, (16, 4096), is_leaf=True) # primals_6 + buf6 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf6, (16, 4096), is_leaf=True) # primals_7 + buf7 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf7, (512, 512), is_leaf=True) # primals_8 + buf8 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf8, (512,), is_leaf=True) # primals_9 + buf9 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf9, (224, 4096), is_leaf=True) # primals_10 + buf10 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf10, (224, 4096), is_leaf=True) # primals_11 + buf11 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf11, (512, 1792), is_leaf=True) # primals_12 + buf12 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf12, (512,), is_leaf=True) # primals_13 + buf13 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf13, (64, 4096), is_leaf=True) # primals_14 + buf14 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf14, (16, 4096), is_leaf=True) # primals_15 + buf15 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf15, (16, 4096), is_leaf=True) # primals_16 + buf16 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf16, (512, 512), is_leaf=True) # primals_17 + buf17 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf17, (512,), is_leaf=True) # primals_18 + buf18 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf18, (224, 4096), is_leaf=True) # primals_19 + buf19 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf19, (224, 4096), is_leaf=True) # primals_20 + buf20 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf20, (512, 1792), is_leaf=True) # primals_21 + buf21 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf21, (512,), is_leaf=True) # primals_22 + buf22 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf22, (64, 4096), is_leaf=True) # primals_23 + buf23 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf23, (16, 4096), is_leaf=True) # primals_24 + buf24 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf24, (16, 4096), is_leaf=True) # primals_25 + buf25 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf25, (512, 512), is_leaf=True) # primals_26 + buf26 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf26, (512,), is_leaf=True) # primals_27 + buf27 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf27, (224, 4096), is_leaf=True) # primals_28 + buf28 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf28, (224, 4096), is_leaf=True) # primals_29 + buf29 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf29, (512, 1792), is_leaf=True) # primals_30 + buf30 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf30, (512,), is_leaf=True) # primals_31 + buf31 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf31, (64, 4096), is_leaf=True) # primals_32 + buf32 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf32, (16, 4096), is_leaf=True) # primals_33 + buf33 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf33, (16, 4096), is_leaf=True) # primals_34 + buf34 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf34, (512, 512), is_leaf=True) # primals_35 + buf35 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf35, (512,), is_leaf=True) # primals_36 + buf36 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf36, (224, 4096), is_leaf=True) # primals_37 + buf37 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf37, (224, 4096), is_leaf=True) # primals_38 + buf38 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf38, (512, 1792), is_leaf=True) # primals_39 + buf39 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf39, (512,), is_leaf=True) # primals_40 + buf40 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf40, (64, 4096), is_leaf=True) # primals_41 + buf41 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf41, (16, 4096), is_leaf=True) # primals_42 + buf42 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf42, (16, 4096), is_leaf=True) # primals_43 + buf43 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf43, (512, 512), is_leaf=True) # primals_44 + buf44 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf44, (512,), is_leaf=True) # primals_45 + buf45 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf45, (224, 4096), is_leaf=True) # primals_46 + buf46 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf46, (224, 4096), is_leaf=True) # primals_47 + buf47 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf47, (512, 1792), is_leaf=True) # primals_48 + buf48 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf48, (512,), is_leaf=True) # primals_49 + buf49 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf49, (64, 4096), is_leaf=True) # primals_50 + buf50 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf50, (16, 4096), is_leaf=True) # primals_51 + buf51 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf51, (16, 4096), is_leaf=True) # primals_52 + buf52 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf52, (512, 512), is_leaf=True) # primals_53 + buf53 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf53, (512,), is_leaf=True) # primals_54 + buf54 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf54, (224, 4096), is_leaf=True) # primals_55 + buf55 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf55, (224, 4096), is_leaf=True) # primals_56 + buf56 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf56, (512, 1792), is_leaf=True) # primals_57 + buf57 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf57, (512,), is_leaf=True) # primals_58 + buf58 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf58, (64, 4096), is_leaf=True) # primals_59 + buf59 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf59, (16, 4096), is_leaf=True) # primals_60 + buf60 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf60, (16, 4096), is_leaf=True) # primals_61 + buf61 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf61, (512, 512), is_leaf=True) # primals_62 + buf62 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf62, (512,), is_leaf=True) # primals_63 + buf63 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf63, (224, 4096), is_leaf=True) # primals_64 + buf64 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf64, (224, 4096), is_leaf=True) # primals_65 + buf65 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf65, (512, 1792), is_leaf=True) # primals_66 + buf66 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf66, (512,), is_leaf=True) # primals_67 + buf67 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf67, (64, 4096), is_leaf=True) # primals_68 + buf68 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf68, (16, 4096), is_leaf=True) # primals_69 + buf69 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf69, (16, 4096), is_leaf=True) # primals_70 + buf70 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf70, (512, 512), is_leaf=True) # primals_71 + buf71 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf71, (512,), is_leaf=True) # primals_72 + buf72 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf72, (224, 4096), is_leaf=True) # primals_73 + buf73 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf73, (224, 4096), is_leaf=True) # primals_74 + buf74 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf74, (512, 1792), is_leaf=True) # primals_75 + buf75 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf75, (512,), is_leaf=True) # primals_76 + buf76 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf76, (64, 4096), is_leaf=True) # primals_77 + buf77 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf77, (16, 4096), is_leaf=True) # primals_78 + buf78 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf78, (16, 4096), is_leaf=True) # primals_79 + buf79 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf79, (512, 512), is_leaf=True) # primals_80 + buf80 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf80, (512,), is_leaf=True) # primals_81 + buf81 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf81, (224, 4096), is_leaf=True) # primals_82 + buf82 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf82, (224, 4096), is_leaf=True) # primals_83 + buf83 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf83, (512, 1792), is_leaf=True) # primals_84 + buf84 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf84, (512,), is_leaf=True) # primals_85 + buf85 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf85, (64, 4096), is_leaf=True) # primals_86 + buf86 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf86, (16, 4096), is_leaf=True) # primals_87 + buf87 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf87, (16, 4096), is_leaf=True) # primals_88 + buf88 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf88, (512, 512), is_leaf=True) # primals_89 + buf89 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf89, (512,), is_leaf=True) # primals_90 + buf90 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf90, (224, 4096), is_leaf=True) # primals_91 + buf91 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf91, (224, 4096), is_leaf=True) # primals_92 + buf92 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf92, (512, 1792), is_leaf=True) # primals_93 + buf93 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf93, (512,), is_leaf=True) # primals_94 + buf94 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf94, (64, 4096), is_leaf=True) # primals_95 + buf95 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf95, (16, 4096), is_leaf=True) # primals_96 + buf96 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf96, (16, 4096), is_leaf=True) # primals_97 + buf97 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf97, (512, 512), is_leaf=True) # primals_98 + buf98 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf98, (512,), is_leaf=True) # primals_99 + buf99 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf99, (224, 4096), is_leaf=True) # primals_100 + buf100 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf100, (224, 4096), is_leaf=True) # primals_101 + buf101 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf101, (512, 1792), is_leaf=True) # primals_102 + buf102 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf102, (512,), is_leaf=True) # primals_103 + buf103 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf103, (64, 4096), is_leaf=True) # primals_104 + buf104 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf104, (16, 4096), is_leaf=True) # primals_105 + buf105 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf105, (16, 4096), is_leaf=True) # primals_106 + buf106 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf106, (512, 512), is_leaf=True) # primals_107 + buf107 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf107, (512,), is_leaf=True) # primals_108 + buf108 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf108, (224, 4096), is_leaf=True) # primals_109 + buf109 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf109, (224, 4096), is_leaf=True) # primals_110 + buf110 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf110, (512, 1792), is_leaf=True) # primals_111 + buf111 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf111, (512,), is_leaf=True) # primals_112 + buf112 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf112, (64, 4096), is_leaf=True) # primals_113 + buf113 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf113, (16, 4096), is_leaf=True) # primals_114 + buf114 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf114, (16, 4096), is_leaf=True) # primals_115 + buf115 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf115, (512, 512), is_leaf=True) # primals_116 + buf116 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf116, (512,), is_leaf=True) # primals_117 + buf117 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf117, (224, 4096), is_leaf=True) # primals_118 + buf118 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf118, (224, 4096), is_leaf=True) # primals_119 + buf119 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf119, (512, 1792), is_leaf=True) # primals_120 + buf120 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf120, (512,), is_leaf=True) # primals_121 + buf121 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf121, (64, 4096), is_leaf=True) # primals_122 + buf122 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf122, (16, 4096), is_leaf=True) # primals_123 + buf123 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf123, (16, 4096), is_leaf=True) # primals_124 + buf124 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf124, (512, 512), is_leaf=True) # primals_125 + buf125 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf125, (512,), is_leaf=True) # primals_126 + buf126 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf126, (224, 4096), is_leaf=True) # primals_127 + buf127 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf127, (224, 4096), is_leaf=True) # primals_128 + buf128 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf128, (512, 1792), is_leaf=True) # primals_129 + buf129 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf129, (512,), is_leaf=True) # primals_130 + buf130 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf130, (64, 4096), is_leaf=True) # primals_131 + buf131 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf131, (16, 4096), is_leaf=True) # primals_132 + buf132 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf132, (16, 4096), is_leaf=True) # primals_133 + buf133 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf133, (512, 512), is_leaf=True) # primals_134 + buf134 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf134, (512,), is_leaf=True) # primals_135 + buf135 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf135, (224, 4096), is_leaf=True) # primals_136 + buf136 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf136, (224, 4096), is_leaf=True) # primals_137 + buf137 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf137, (512, 1792), is_leaf=True) # primals_138 + buf138 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf138, (512,), is_leaf=True) # primals_139 + buf139 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf139, (64, 4096), is_leaf=True) # primals_140 + buf140 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf140, (16, 4096), is_leaf=True) # primals_141 + buf141 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf141, (16, 4096), is_leaf=True) # primals_142 + buf142 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf142, (512, 512), is_leaf=True) # primals_143 + buf143 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf143, (512,), is_leaf=True) # primals_144 + buf144 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf144, (224, 4096), is_leaf=True) # primals_145 + buf145 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf145, (224, 4096), is_leaf=True) # primals_146 + buf146 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf146, (512, 1792), is_leaf=True) # primals_147 + buf147 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf147, (512,), is_leaf=True) # primals_148 + buf148 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf148, (64, 4096), is_leaf=True) # primals_149 + buf149 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf149, (16, 4096), is_leaf=True) # primals_150 + buf150 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf150, (16, 4096), is_leaf=True) # primals_151 + buf151 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf151, (512, 512), is_leaf=True) # primals_152 + buf152 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf152, (512,), is_leaf=True) # primals_153 + buf153 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf153, (224, 4096), is_leaf=True) # primals_154 + buf154 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf154, (224, 4096), is_leaf=True) # primals_155 + buf155 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf155, (512, 1792), is_leaf=True) # primals_156 + buf156 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf156, (512,), is_leaf=True) # primals_157 + buf157 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf157, (64, 4096), is_leaf=True) # primals_158 + buf158 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf158, (16, 4096), is_leaf=True) # primals_159 + buf159 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf159, (16, 4096), is_leaf=True) # primals_160 + buf160 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf160, (512, 512), is_leaf=True) # primals_161 + buf161 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf161, (512,), is_leaf=True) # primals_162 + buf162 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf162, (224, 4096), is_leaf=True) # primals_163 + buf163 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf163, (224, 4096), is_leaf=True) # primals_164 + buf164 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf164, (512, 1792), is_leaf=True) # primals_165 + buf165 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf165, (512,), is_leaf=True) # primals_166 + buf166 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf166, (64, 4096), is_leaf=True) # primals_167 + buf167 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf167, (16, 4096), is_leaf=True) # primals_168 + buf168 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf168, (16, 4096), is_leaf=True) # primals_169 + buf169 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf169, (512, 512), is_leaf=True) # primals_170 + buf170 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf170, (512,), is_leaf=True) # primals_171 + buf171 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf171, (224, 4096), is_leaf=True) # primals_172 + buf172 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf172, (224, 4096), is_leaf=True) # primals_173 + buf173 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf173, (512, 1792), is_leaf=True) # primals_174 + buf174 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf174, (512,), is_leaf=True) # primals_175 + buf175 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf175, (64, 4096), is_leaf=True) # primals_176 + buf176 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf176, (16, 4096), is_leaf=True) # primals_177 + buf177 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf177, (16, 4096), is_leaf=True) # primals_178 + buf178 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf178, (512, 512), is_leaf=True) # primals_179 + buf179 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf179, (512,), is_leaf=True) # primals_180 + buf180 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf180, (224, 4096), is_leaf=True) # primals_181 + buf181 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf181, (224, 4096), is_leaf=True) # primals_182 + buf182 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf182, (512, 1792), is_leaf=True) # primals_183 + buf183 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf183, (512,), is_leaf=True) # primals_184 + buf184 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf184, (64, 4096), is_leaf=True) # primals_185 + buf185 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf185, (16, 4096), is_leaf=True) # primals_186 + buf186 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf186, (16, 4096), is_leaf=True) # primals_187 + buf187 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf187, (512, 512), is_leaf=True) # primals_188 + buf188 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf188, (512,), is_leaf=True) # primals_189 + buf189 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf189, (224, 4096), is_leaf=True) # primals_190 + buf190 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf190, (224, 4096), is_leaf=True) # primals_191 + buf191 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf191, (512, 1792), is_leaf=True) # primals_192 + buf192 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf192, (512,), is_leaf=True) # primals_193 + buf193 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf193, (64, 4096), is_leaf=True) # primals_194 + buf194 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf194, (16, 4096), is_leaf=True) # primals_195 + buf195 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf195, (16, 4096), is_leaf=True) # primals_196 + buf196 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf196, (512, 512), is_leaf=True) # primals_197 + buf197 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf197, (512,), is_leaf=True) # primals_198 + buf198 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf198, (224, 4096), is_leaf=True) # primals_199 + buf199 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf199, (224, 4096), is_leaf=True) # primals_200 + buf200 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf200, (512, 1792), is_leaf=True) # primals_201 + buf201 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf201, (512,), is_leaf=True) # primals_202 + buf202 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf202, (64, 4096), is_leaf=True) # primals_203 + buf203 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf203, (16, 4096), is_leaf=True) # primals_204 + buf204 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf204, (16, 4096), is_leaf=True) # primals_205 + buf205 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf205, (512, 512), is_leaf=True) # primals_206 + buf206 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf206, (512,), is_leaf=True) # primals_207 + buf207 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf207, (224, 4096), is_leaf=True) # primals_208 + buf208 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf208, (224, 4096), is_leaf=True) # primals_209 + buf209 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf209, (512, 1792), is_leaf=True) # primals_210 + buf210 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf210, (512,), is_leaf=True) # primals_211 + buf211 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf211, (64, 4096), is_leaf=True) # primals_212 + buf212 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf212, (16, 4096), is_leaf=True) # primals_213 + buf213 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf213, (16, 4096), is_leaf=True) # primals_214 + buf214 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf214, (512, 512), is_leaf=True) # primals_215 + buf215 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf215, (512,), is_leaf=True) # primals_216 + buf216 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf216, (224, 4096), is_leaf=True) # primals_217 + buf217 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf217, (224, 4096), is_leaf=True) # primals_218 + buf218 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf218, (512, 1792), is_leaf=True) # primals_219 + buf219 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf219, (512,), is_leaf=True) # primals_220 + buf220 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf220, (64, 4096), is_leaf=True) # primals_221 + buf221 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf221, (16, 4096), is_leaf=True) # primals_222 + buf222 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf222, (16, 4096), is_leaf=True) # primals_223 + buf223 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf223, (512, 512), is_leaf=True) # primals_224 + buf224 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf224, (512,), is_leaf=True) # primals_225 + buf225 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf225, (224, 4096), is_leaf=True) # primals_226 + buf226 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf226, (224, 4096), is_leaf=True) # primals_227 + buf227 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf227, (512, 1792), is_leaf=True) # primals_228 + buf228 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf228, (512,), is_leaf=True) # primals_229 + buf229 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf229, (64, 4096), is_leaf=True) # primals_230 + buf230 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf230, (16, 4096), is_leaf=True) # primals_231 + buf231 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf231, (16, 4096), is_leaf=True) # primals_232 + buf232 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf232, (512, 512), is_leaf=True) # primals_233 + buf233 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf233, (512,), is_leaf=True) # primals_234 + buf234 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf234, (224, 4096), is_leaf=True) # primals_235 + buf235 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf235, (224, 4096), is_leaf=True) # primals_236 + buf236 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf236, (512, 1792), is_leaf=True) # primals_237 + buf237 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf237, (512,), is_leaf=True) # primals_238 + buf238 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf238, (64, 4096), is_leaf=True) # primals_239 + buf239 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf239, (16, 4096), is_leaf=True) # primals_240 + buf240 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf240, (16, 4096), is_leaf=True) # primals_241 + buf241 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf241, (512, 512), is_leaf=True) # primals_242 + buf242 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf242, (512,), is_leaf=True) # primals_243 + buf243 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf243, (224, 4096), is_leaf=True) # primals_244 + buf244 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf244, (224, 4096), is_leaf=True) # primals_245 + buf245 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf245, (512, 1792), is_leaf=True) # primals_246 + buf246 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf246, (512,), is_leaf=True) # primals_247 + buf247 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf247, (64, 4096), is_leaf=True) # primals_248 + buf248 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf248, (16, 4096), is_leaf=True) # primals_249 + buf249 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf249, (16, 4096), is_leaf=True) # primals_250 + buf250 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf250, (512, 512), is_leaf=True) # primals_251 + buf251 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf251, (512,), is_leaf=True) # primals_252 + buf252 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf252, (224, 4096), is_leaf=True) # primals_253 + buf253 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf253, (224, 4096), is_leaf=True) # primals_254 + buf254 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf254, (512, 1792), is_leaf=True) # primals_255 + buf255 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf255, (512,), is_leaf=True) # primals_256 + buf256 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf256, (64, 4096), is_leaf=True) # primals_257 + buf257 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf257, (16, 4096), is_leaf=True) # primals_258 + buf258 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf258, (16, 4096), is_leaf=True) # primals_259 + buf259 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf259, (512, 512), is_leaf=True) # primals_260 + buf260 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf260, (512,), is_leaf=True) # primals_261 + buf261 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf261, (224, 4096), is_leaf=True) # primals_262 + buf262 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf262, (224, 4096), is_leaf=True) # primals_263 + buf263 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf263, (512, 1792), is_leaf=True) # primals_264 + buf264 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf264, (512,), is_leaf=True) # primals_265 + buf265 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf265, (64, 4096), is_leaf=True) # primals_266 + buf266 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf266, (16, 4096), is_leaf=True) # primals_267 + buf267 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf267, (16, 4096), is_leaf=True) # primals_268 + buf268 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf268, (512, 512), is_leaf=True) # primals_269 + buf269 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf269, (512,), is_leaf=True) # primals_270 + buf270 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf270, (224, 4096), is_leaf=True) # primals_271 + buf271 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf271, (224, 4096), is_leaf=True) # primals_272 + buf272 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf272, (512, 1792), is_leaf=True) # primals_273 + buf273 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf273, (512,), is_leaf=True) # primals_274 + buf274 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf274, (64, 4096), is_leaf=True) # primals_275 + buf275 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf275, (16, 4096), is_leaf=True) # primals_276 + buf276 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf276, (16, 4096), is_leaf=True) # primals_277 + buf277 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf277, (512, 512), is_leaf=True) # primals_278 + buf278 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf278, (512,), is_leaf=True) # primals_279 + buf279 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf279, (224, 4096), is_leaf=True) # primals_280 + buf280 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf280, (224, 4096), is_leaf=True) # primals_281 + buf281 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf281, (512, 1792), is_leaf=True) # primals_282 + buf282 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf282, (512,), is_leaf=True) # primals_283 + buf283 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf283, (64, 4096), is_leaf=True) # primals_284 + buf284 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf284, (16, 4096), is_leaf=True) # primals_285 + buf285 = reader.storage(None, 262144, device=device(type='cuda', index=0)) + reader.tensor(buf285, (16, 4096), is_leaf=True) # primals_286 + buf286 = reader.storage(None, 1048576, device=device(type='cuda', index=0)) + reader.tensor(buf286, (512, 512), is_leaf=True) # primals_287 + buf287 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf287, (512,), is_leaf=True) # primals_288 + buf288 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf288, (224, 4096), is_leaf=True) # primals_289 + buf289 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf289, (224, 4096), is_leaf=True) # primals_290 + buf290 = reader.storage(None, 3670016, device=device(type='cuda', index=0)) + reader.tensor(buf290, (512, 1792), is_leaf=True) # primals_291 + buf291 = reader.storage(None, 2048, device=device(type='cuda', index=0)) + reader.tensor(buf291, (512,), is_leaf=True) # primals_292 + buf292 = reader.storage(None, 32833536, device=device(type='cuda', index=0)) + reader.tensor(buf292, (2004, 4096), is_leaf=True) # primals_293 + +load_args._version = 0 + +def get_pg_config(): + return {'0': {'size': 8, 'rank': 0}, '1': {'size': 8, 'rank': 0}} + +def get_colls_estimations_file(): + return "colls8_8.table" diff --git a/autoparallel/tools/overlap_simulator/run.py b/autoparallel/tools/overlap_simulator/run.py new file mode 100644 index 00000000..0cb6ebee --- /dev/null +++ b/autoparallel/tools/overlap_simulator/run.py @@ -0,0 +1,849 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +""" +Overlap Scheduling Experiments Runner + +This script runs overlap scheduling experiments with various bucketing strategies +on different model variants and configurations. +""" + +# Standard library imports +import argparse +import copy +import dataclasses +import logging +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +# Third-party imports +import torch + +# Local imports +import torch._dynamo.config +import torch._functorch.config +import torch._inductor.config +import torch._inductor.inductor_prims +import torch.distributed as dist +import torch.fx as fx +import torch.fx.experimental._config +from torch import device, tensor # noqa: F401 (used by repro modules) +from torch._dynamo.testing import rand_strided # noqa: F401 (used by repro modules) +from torch.fx.operator_schemas import normalize_function + +from autoparallel.graph_passes.debug_helpers import create_execution_trace + +# Constants +DEFAULT_VARIANT = "llama3_8b_bw_256_2d" + +BYTES_PER_MB = 1024 * 1024 +MS_TO_US_MULTIPLIER = 1000 + +# Logging configuration +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def configure_torch() -> None: + """Configure torch settings for overlap scheduling experiments.""" + torch._inductor.config.allow_buffer_reuse = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.max_autotune = False + torch._inductor.config.coordinate_descent_tuning = False + torch._inductor.config.deterministic = False + torch._inductor.config.aten_distributed_optimizations.collective_bucketing = True + torch._inductor.config.triton.store_cubin = False + torch._inductor.config.test_configs.runtime_triton_dtype_assert = False + torch._functorch.config.functionalize_rng_ops = False + torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = True + torch._functorch.config.unlift_effect_tokens = False + torch._functorch.config.selective_decompose = False + + +@dataclasses.dataclass +class Stats: + """Statistics for graph collective operations.""" + + num_ag: int # Number of all-gather operations + num_rs: int # Number of reduce-scatter operations + num_ar: int # Number of all-reduce operations + runtime: float # Total runtime in milliseconds + peak_memory_gb: float # Peak memory in GB + + def __str__(self) -> str: + return ( + f"AG:{self.num_ag}, RS:{self.num_rs}, AR:{self.num_ar}, " + f"Runtime:{self.runtime:.2f}ms, PeakMem:{self.peak_memory_gb:.2f}GB" + ) + + +@dataclasses.dataclass +class VariantConfig: + """Configuration for a model variant.""" + + repro_class: type + load_args_func: Callable + get_pg_config_func: Callable + get_colls_file_func: Callable + # Maps repro group names to table group names (e.g. {'513': '1'} means + # EP group "513" uses intranode timings from table group "1"). + # When set, table entries are remapped to match the repro's group names/sizes. + get_colls_group_mapping_func: Optional[Callable] = None + + +class CollectiveEstimationParser: + """Parser for collective estimation table files.""" + + @staticmethod + def parse_table(file_path: str) -> Dict[Tuple[str, int, str], Dict[int, float]]: + """ + Parse the collectives estimations table file. + + Args: + file_path: Path to the table file + + Returns: + Dict mapping (group_name, group_size, collective_name) -> {size_mb: time_ms} + """ + result: Dict[Tuple[str, int, str], Dict[int, float]] = {} + + try: + with open(file_path, "r") as f: + lines = f.readlines() + except FileNotFoundError: + logger.error(f"Collective estimation file not found: {file_path}") + return result + except Exception as e: + logger.error(f"Error reading collective estimation file {file_path}: {e}") + return result + + if len(lines) < 2: + logger.warning( + f"Collective estimation file {file_path} has insufficient data" + ) + return result + + # Parse header to get size columns + header = lines[0] + size_columns: List[int] = [] + for part in header.split(): + if part.endswith("MB"): + try: + size_mb = int(part.replace("MB", "")) + size_columns.append(size_mb) + except ValueError: + continue + + # Process data lines (skip separator line) + for line_num, line in enumerate(lines[2:], start=3): + line = line.strip() + if not line: + continue + + parts = line.split() + if len(parts) < 3 + len(size_columns): + logger.warning(f"Insufficient data in line {line_num} of {file_path}") + continue + + try: + group_name = parts[0] + group_size = int(parts[1]) + collective = parts[2] + + size_to_time: Dict[int, float] = {} + for i, size_mb in enumerate(size_columns): + time_ms = float(parts[3 + i]) + size_to_time[size_mb] = time_ms + + result[(group_name, group_size, collective)] = size_to_time + + except (ValueError, IndexError) as e: + logger.warning(f"Error parsing line {line_num} in {file_path}: {e}") + continue + + logger.info(f"Parsed {len(result)} collective entries from {file_path}") + return result + + @staticmethod + def interpolate_time(size_to_time: Dict[int, float], size_mb: float) -> float: + """ + Interpolate or extrapolate time for a given size in MB. + + Args: + size_to_time: Mapping of size (MB) to time (ms) + size_mb: Target size in MB + + Returns: + Estimated time in milliseconds + """ + if not size_to_time: + return 0.0 + + sorted_sizes = sorted(size_to_time.keys()) + + # For sizes less than 1MB, use 1MB value + if size_mb < 1.0: + return size_to_time.get(1, size_to_time[sorted_sizes[0]]) + + # Exact match + size_int = int(size_mb) + if size_int in size_to_time: + return size_to_time[size_int] + + # Find surrounding points for interpolation + lower_size = None + upper_size = None + + for s in sorted_sizes: + if s <= size_mb: + lower_size = s + if s >= size_mb and upper_size is None: + upper_size = s + + # Extrapolation cases + if lower_size is None: + # Below minimum - use first two points + if len(sorted_sizes) >= 2: + s1, s2 = sorted_sizes[0], sorted_sizes[1] + t1, t2 = size_to_time[s1], size_to_time[s2] + slope = (t2 - t1) / (s2 - s1) + return max(0.0, t1 + slope * (size_mb - s1)) + return size_to_time[sorted_sizes[0]] + + if upper_size is None: + # Above maximum - use last two points + if len(sorted_sizes) >= 2: + s1, s2 = sorted_sizes[-2], sorted_sizes[-1] + t1, t2 = size_to_time[s1], size_to_time[s2] + slope = (t2 - t1) / (s2 - s1) + return max(0.0, t2 + slope * (size_mb - s2)) + return size_to_time[sorted_sizes[-1]] + + # Interpolation between two points + if lower_size == upper_size: + return size_to_time[lower_size] + + t1, t2 = size_to_time[lower_size], size_to_time[upper_size] + fraction = (size_mb - lower_size) / (upper_size - lower_size) + return t1 + fraction * (t2 - t1) + + +class NodeEstimator: + """Handles runtime estimation for nodes in the computation graph.""" + + def __init__( + self, + nodes_estimations_dict: Dict[fx.Node, float], + collective_table: Dict[Tuple[str, int, str], Dict[int, float]], + ): + self.node_names_ests = { + n.name: est for n, est in nodes_estimations_dict.items() + } + self.collective_table = collective_table + + @staticmethod + def get_hint(x: Union[int, torch.SymInt]) -> Optional[int]: + """Extract concrete int from SymInt if needed.""" + if isinstance(x, int): + return x + if hasattr(x, "node") and hasattr(x.node, "hint"): + return x.node.hint + return None + + @staticmethod + def get_tensor_bytes(node: fx.Node) -> Optional[int]: + """Get the size in bytes of the tensor produced by this node.""" + if "val" not in node.meta: + return None + + t = node.meta["val"] + if not isinstance(t, torch.Tensor): + return None + + shape = [NodeEstimator.get_hint(dim) for dim in t.shape] + if any(s is None for s in shape): + return None + + numel = 1 + for dim in shape: + numel *= dim # type: ignore[operator] + return numel * t.dtype.itemsize + + def get_collective_info(self, node: fx.Node) -> Optional[Tuple[str, int, int, str]]: + """ + Extract collective type, group_size, tensor bytes, and group_name. + + Returns: + (collective_name, group_size, tensor_bytes, group_name) or None + """ + if node.op != "call_function": + return None + + target_str = str(node.target) + collective_name = None + group_size = None + + # Determine collective type and extract group_size + if "all_gather_into_tensor" in target_str: + collective_name = "all_gather_into_tensor" + if len(node.args) >= 2: + group_size = ( + self.get_hint(node.args[1]) # type: ignore[arg-type] + if hasattr(node.args[1], "node") + else node.args[1] + ) + if isinstance(node.args[1], int): + group_size = node.args[1] + + elif "reduce_scatter_tensor" in target_str: + collective_name = "reduce_scatter_tensor" + if len(node.args) >= 3: + group_size = ( + self.get_hint(node.args[2]) # type: ignore[arg-type] + if hasattr(node.args[2], "node") + else node.args[2] + ) + if isinstance(node.args[2], int): + group_size = node.args[2] + + elif "all_reduce" in target_str: + collective_name = "all_reduce" + # No explicit group_size in args for all_reduce + + else: + return None + + # Get tensor bytes from input tensor + input_node = node.args[0] if node.args else None + if not isinstance(input_node, fx.Node): + return None + + tensor_bytes = self.get_tensor_bytes(input_node) + if tensor_bytes is None: + return None + + # Extract group_name + try: + group_name = get_group_name(node) + except Exception as e: + logger.warning(f"Failed to extract group name from node {node.name}: {e}") + group_name = "" + + return (collective_name, group_size, tensor_bytes, group_name) # type: ignore[return-value] + + def estimate(self, node: fx.Node) -> float: + """Estimate execution time for a node in milliseconds.""" + # Override collectives with measured table values + coll_info = self.get_collective_info(node) + if coll_info is not None: + collective_name, group_size, tensor_bytes, node_group_name = coll_info + size_mb = tensor_bytes / BYTES_PER_MB + + if group_size is not None: + for ( + table_group, + gs, + cn, + ), size_to_time in self.collective_table.items(): + if ( + gs == group_size + and cn == collective_name + and table_group in node_group_name + ): + return CollectiveEstimationParser.interpolate_time( + size_to_time, size_mb + ) + else: + for ( + table_group, + gs, + cn, + ), size_to_time in self.collective_table.items(): + if cn == collective_name and table_group in node_group_name: + return CollectiveEstimationParser.interpolate_time( + size_to_time, size_mb + ) + + # For all non-collective nodes, use pre-computed estimations from gather_node_runtime_estimations + return self.node_names_ests.get(node.name, 0.0) + + +class ExperimentRunner: + """Main experiment runner for overlap scheduling.""" + + def __init__(self, variant_config: VariantConfig): + self.variant_config = variant_config + self.setup_process_groups() + + def setup_process_groups(self) -> None: + """Set up fake process groups matching the repro's pg_config.""" + from torch.testing._internal.distributed.fake_pg import FakeStore + + pg_config = self.variant_config.get_pg_config_func() + world_size = max(info["size"] for info in pg_config.values()) + + store = FakeStore() + dist.init_process_group( + backend="fake", rank=0, world_size=world_size, store=store + ) + default_pg = dist.distributed_c10d._get_default_group() + + # Create all subgroups (new_group auto-registers under generated names) + pgs = {} + for name, info in pg_config.items(): + if info["size"] == world_size: + pgs[name] = default_pg + else: + pgs[name] = dist.new_group(list(range(info["size"]))) + + # Re-register under the desired names + torch._C._distributed_c10d._unregister_all_process_groups() + for name, pg in pgs.items(): + torch._C._distributed_c10d._register_process_group(name, pg) + + def run_experiment( + self, + variant_name: str, + save_traces: bool = False, + trace_output_dir: str = ".", + ) -> Tuple[Stats, Stats, Optional[Dict[str, Any]], Optional[Dict[str, Any]]]: + """Run overlap scheduling experiment. + + Returns: + Tuple of (stats_before, stats_after, trace_before, trace_after) + """ + try: + # Setup model and graph + mod = self.variant_config.repro_class() + + with torch.no_grad(): + from torch._dynamo.debug_utils import InputReader + from torch.fx.experimental.proxy_tensor import make_fx + + reader = InputReader() + self.variant_config.load_args_func(reader) + args = reader.args + + gm = make_fx(mod, tracing_mode="fake")(*args) + + from torch._inductor.fx_passes.overlap_scheduling import ( # type: ignore[attr-defined] + gather_node_runtime_estimations, + schedule_overlap_bucketing, + ) + + # Gather estimations once — reused for both before/after stats and scheduling + node_estimations, fusion_region_of = gather_node_runtime_estimations(gm) + + colls_file_path = resolve_colls_file_path( + self.variant_config.get_colls_file_func() + ) + estimator = self._create_estimator(node_estimations, colls_file_path) + + # Stats before optimization + gm_before = copy.deepcopy(gm) + before_trace_path = ( + os.path.join(trace_output_dir, f"{variant_name}_before_trace.json") + if save_traces + else None + ) + stats_before, trace_before = self.calculate_stats( + gm_before, + estimator, + f"{variant_name}_before", + save_trace_path=before_trace_path, + ) + + # Run scheduling with pre-computed estimations and fusion regions + gm_after = schedule_overlap_bucketing( # type: ignore[call-arg] + gm, + collective_bucketing=True, + insert_overlap_deps=False, + node_runtime_estimations=node_estimations, + fusion_region_of=fusion_region_of, + ) + # gather runtime estimations as bucketing adds another collectives and ops + ( + node_estimations_after, + fusion_region_of, + ) = gather_node_runtime_estimations(gm_after) + estimator = self._create_estimator( + node_estimations_after, colls_file_path + ) + + # Stats after optimization (same estimations, reordered graph) + after_trace_path = ( + os.path.join(trace_output_dir, f"{variant_name}_after_trace.json") + if save_traces + else None + ) + stats_after, trace_after = self.calculate_stats( + gm_after, + estimator, + f"{variant_name}_after", + save_trace_path=after_trace_path, + ) + + return stats_before, stats_after, trace_before, trace_after + + except Exception as e: + logger.error(f"Error running experiment for {variant_name}: {e}") + raise + + def _create_estimator( + self, nodes_estimations_dict: Dict[fx.Node, float], colls_file_path: str + ) -> NodeEstimator: + """Create a node estimator with collective table.""" + raw_table = CollectiveEstimationParser.parse_table(colls_file_path) + + mapping_func = self.variant_config.get_colls_group_mapping_func + if mapping_func is not None: + # Remap table entries: replace table group names/sizes with repro's + pg_config = self.variant_config.get_pg_config_func() + mapping = mapping_func() + collective_table = {} + for repro_group, table_group in mapping.items(): + repro_size = pg_config[repro_group]["size"] + for (tg, gs, cn), timings in raw_table.items(): + if tg == table_group: + collective_table[(repro_group, repro_size, cn)] = timings + else: + collective_table = raw_table + + return NodeEstimator(nodes_estimations_dict, collective_table) + + def calculate_stats( + self, + gm: fx.GraphModule, + estimator: NodeEstimator, + name: str, + save_trace_path: Optional[str] = None, + ) -> Tuple[Stats, Dict[str, Any]]: + """Calculate statistics for a graph module. + + Returns: + Tuple of (Stats, trace_dict) + """ + from torch._inductor.fx_passes.memory_estimator import build_memory_profile + + num_ag = num_rs = num_ar = 0 + + for node in gm.graph.nodes: + if node.op == "call_function": + target_str = str(node.target) + if "all_gather_into_tensor" in target_str: + num_ag += 1 + elif "reduce_scatter_tensor" in target_str: + num_rs += 1 + elif "all_reduce" in target_str: + num_ar += 1 + + trace = create_execution_trace( + gm, estimator.estimate, name=name, file_path=save_trace_path + ) + + # Calculate total runtime + max_end_time = 0.0 + for event in trace.get("traceEvents", []): + ts = event.get("ts", 0) + dur = event.get("dur", 0) + end_time = ts + dur + max_end_time = max(max_end_time, end_time) + + runtime_ms = max_end_time / MS_TO_US_MULTIPLIER # Convert back to ms + + # Calculate peak memory + memory_profile = build_memory_profile( + gm.graph, is_releasable=lambda n: n.op != "placeholder" + ) + peak_memory_gb = max(memory_profile) / 2**30 if memory_profile else 0.0 + + return ( + Stats( + num_ag=num_ag, + num_rs=num_rs, + num_ar=num_ar, + runtime=runtime_ms, + peak_memory_gb=peak_memory_gb, + ), + trace, + ) + + def cleanup(self) -> None: + """Clean up process groups.""" + dist.destroy_process_group() + + +def get_group_name(n: fx.Node) -> str: + """Extract the group name from a collective operation node.""" + opt_args_kwargs = normalize_function( + n.target, # type: ignore[arg-type] + args=n.args, + kwargs=n.kwargs, + normalize_to_only_use_kwargs=True, + ) + assert opt_args_kwargs is not None + _, kwargs = opt_args_kwargs + return kwargs["group_name"] + + +def resolve_colls_file_path(filename: str) -> str: + """Resolve collective estimations filename to full path relative to run.py.""" + script_dir = os.path.dirname(os.path.abspath(__file__)) + return os.path.join(script_dir, filename) + + +def get_variant_configs() -> Dict[str, VariantConfig]: + """Get all available variant configurations.""" + # Import all repro modules + from repro_dsv3_bw_64 import Repro as Repro_dsv3_bw_64 + from repro_dsv3_bw_64 import get_colls_estimations_file as get_colls_file_dsv3_bw_64 + from repro_dsv3_bw_64 import get_colls_group_mapping as get_colls_mapping_dsv3_bw_64 + from repro_dsv3_bw_64 import get_pg_config as get_pg_config_dsv3_bw_64 + from repro_dsv3_bw_64 import load_args as load_args_dsv3_bw_64 + from repro_dsv3_bw_128 import Repro as Repro_dsv3_bw_128 + from repro_dsv3_bw_128 import ( + get_colls_estimations_file as get_colls_file_dsv3_bw_128, + ) + from repro_dsv3_bw_128 import ( + get_colls_group_mapping as get_colls_mapping_dsv3_bw_128, + ) + from repro_dsv3_bw_128 import get_pg_config as get_pg_config_dsv3_bw_128 + from repro_dsv3_bw_128 import load_args as load_args_dsv3_bw_128 + from repro_dsv3_fw_64 import Repro as Repro_dsv3_fw_64 + from repro_dsv3_fw_64 import get_colls_estimations_file as get_colls_file_dsv3_fw_64 + from repro_dsv3_fw_64 import get_colls_group_mapping as get_colls_mapping_dsv3_fw_64 + from repro_dsv3_fw_64 import get_pg_config as get_pg_config_dsv3_fw_64 + from repro_dsv3_fw_64 import load_args as load_args_dsv3_fw_64 + from repro_dsv3_fw_128 import Repro as Repro_dsv3_fw_128 + from repro_dsv3_fw_128 import ( + get_colls_estimations_file as get_colls_file_dsv3_fw_128, + ) + from repro_dsv3_fw_128 import ( + get_colls_group_mapping as get_colls_mapping_dsv3_fw_128, + ) + from repro_dsv3_fw_128 import get_pg_config as get_pg_config_dsv3_fw_128 + from repro_dsv3_fw_128 import load_args as load_args_dsv3_fw_128 + from repro_llama3_8b_bw_64_1d import Repro as Repro_bw_64_1d + from repro_llama3_8b_bw_64_1d import ( + get_colls_estimations_file as get_colls_file_bw_64_1d, + ) + from repro_llama3_8b_bw_64_1d import get_pg_config as get_pg_config_bw_64_1d + from repro_llama3_8b_bw_64_1d import load_args as load_args_bw_64_1d + from repro_llama3_8b_bw_64_2d import Repro as Repro_bw_64_2d + from repro_llama3_8b_bw_64_2d import ( + get_colls_estimations_file as get_colls_file_bw_64_2d, + ) + from repro_llama3_8b_bw_64_2d import get_pg_config as get_pg_config_bw_64_2d + from repro_llama3_8b_bw_64_2d import load_args as load_args_bw_64_2d + from repro_llama3_8b_bw_256_1d import Repro as Repro_bw_256_1d + from repro_llama3_8b_bw_256_1d import ( + get_colls_estimations_file as get_colls_file_bw_256_1d, + ) + from repro_llama3_8b_bw_256_1d import get_pg_config as get_pg_config_bw_256_1d + from repro_llama3_8b_bw_256_1d import load_args as load_args_bw_256_1d + from repro_llama3_8b_bw_256_2d import Repro as Repro_bw_256_2d + from repro_llama3_8b_bw_256_2d import ( + get_colls_estimations_file as get_colls_file_bw_256_2d, + ) + from repro_llama3_8b_bw_256_2d import get_pg_config as get_pg_config_bw_256_2d + from repro_llama3_8b_bw_256_2d import load_args as load_args_bw_256_2d + from repro_llama3_8b_fw_64_1d import Repro as Repro_fw_64_1d + from repro_llama3_8b_fw_64_1d import ( + get_colls_estimations_file as get_colls_file_fw_64_1d, + ) + from repro_llama3_8b_fw_64_1d import get_pg_config as get_pg_config_fw_64_1d + from repro_llama3_8b_fw_64_1d import load_args as load_args_fw_64_1d + from repro_llama3_8b_fw_64_2d import Repro as Repro_fw_64_2d + from repro_llama3_8b_fw_64_2d import ( + get_colls_estimations_file as get_colls_file_fw_64_2d, + ) + from repro_llama3_8b_fw_64_2d import get_pg_config as get_pg_config_fw_64_2d + from repro_llama3_8b_fw_64_2d import load_args as load_args_fw_64_2d + from repro_llama3_8b_fw_256_1d import Repro as Repro_fw_256_1d + from repro_llama3_8b_fw_256_1d import ( + get_colls_estimations_file as get_colls_file_fw_256_1d, + ) + from repro_llama3_8b_fw_256_1d import get_pg_config as get_pg_config_fw_256_1d + from repro_llama3_8b_fw_256_1d import load_args as load_args_fw_256_1d + from repro_llama3_8b_fw_256_2d import Repro as Repro_fw_256_2d + from repro_llama3_8b_fw_256_2d import ( + get_colls_estimations_file as get_colls_file_fw_256_2d, + ) + from repro_llama3_8b_fw_256_2d import get_pg_config as get_pg_config_fw_256_2d + from repro_llama3_8b_fw_256_2d import load_args as load_args_fw_256_2d + + return { + "llama3_8b_bw_256_2d": VariantConfig( + Repro_bw_256_2d, + load_args_bw_256_2d, + get_pg_config_bw_256_2d, + get_colls_file_bw_256_2d, + ), + "llama3_8b_bw_256_1d": VariantConfig( + Repro_bw_256_1d, + load_args_bw_256_1d, + get_pg_config_bw_256_1d, + get_colls_file_bw_256_1d, + ), + "llama3_8b_bw_64_2d": VariantConfig( + Repro_bw_64_2d, + load_args_bw_64_2d, + get_pg_config_bw_64_2d, + get_colls_file_bw_64_2d, + ), + "llama3_8b_bw_64_1d": VariantConfig( + Repro_bw_64_1d, + load_args_bw_64_1d, + get_pg_config_bw_64_1d, + get_colls_file_bw_64_1d, + ), + "llama3_8b_fw_256_2d": VariantConfig( + Repro_fw_256_2d, + load_args_fw_256_2d, + get_pg_config_fw_256_2d, + get_colls_file_fw_256_2d, + ), + "llama3_8b_fw_256_1d": VariantConfig( + Repro_fw_256_1d, + load_args_fw_256_1d, + get_pg_config_fw_256_1d, + get_colls_file_fw_256_1d, + ), + "llama3_8b_fw_64_2d": VariantConfig( + Repro_fw_64_2d, + load_args_fw_64_2d, + get_pg_config_fw_64_2d, + get_colls_file_fw_64_2d, + ), + "llama3_8b_fw_64_1d": VariantConfig( + Repro_fw_64_1d, + load_args_fw_64_1d, + get_pg_config_fw_64_1d, + get_colls_file_fw_64_1d, + ), + "dsv3_fw_64": VariantConfig( + Repro_dsv3_fw_64, + load_args_dsv3_fw_64, + get_pg_config_dsv3_fw_64, + get_colls_file_dsv3_fw_64, + get_colls_mapping_dsv3_fw_64, + ), + "dsv3_bw_64": VariantConfig( + Repro_dsv3_bw_64, + load_args_dsv3_bw_64, + get_pg_config_dsv3_bw_64, + get_colls_file_dsv3_bw_64, + get_colls_mapping_dsv3_bw_64, + ), + "dsv3_fw_128": VariantConfig( + Repro_dsv3_fw_128, + load_args_dsv3_fw_128, + get_pg_config_dsv3_fw_128, + get_colls_file_dsv3_fw_128, + get_colls_mapping_dsv3_fw_128, + ), + "dsv3_bw_128": VariantConfig( + Repro_dsv3_bw_128, + load_args_dsv3_bw_128, + get_pg_config_dsv3_bw_128, + get_colls_file_dsv3_bw_128, + get_colls_mapping_dsv3_bw_128, + ), + } + + +def create_argument_parser() -> argparse.ArgumentParser: + """Create and configure the argument parser.""" + parser = argparse.ArgumentParser( + description="Run overlap scheduling experiments", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python run.py --variant dsv3_bw_64 --save-traces + python run.py --variant dsv3_bw_128 + python run.py --variant llama3_8b_bw_256_2d +""", + ) + + variant_choices = list(get_variant_configs().keys()) + parser.add_argument( + "--variant", + type=str, + default=DEFAULT_VARIANT, + choices=variant_choices, + help=f"Model variant (default: {DEFAULT_VARIANT})", + ) + + parser.add_argument( + "--save-traces", + action="store_true", + help="Save simulated traces to JSON files for visualization", + ) + + parser.add_argument( + "--trace-output-dir", + type=str, + default=".", + help="Directory to save trace files (default: current directory)", + ) + + return parser + + +def main() -> None: + """Main entry point.""" + # Configure torch before any experiments + configure_torch() + + # Parse arguments + parser = create_argument_parser() + args = parser.parse_args() + + # Get variant configuration + variant_configs = get_variant_configs() + if args.variant not in variant_configs: + logger.error(f"Unknown variant: {args.variant}") + return + + variant_config = variant_configs[args.variant] + + # Run experiment + try: + logger.info( + f"Running overlap scheduling experiment for variant: {args.variant}" + ) + + runner = ExperimentRunner(variant_config) + stats_before, stats_after, trace_before, trace_after = runner.run_experiment( + args.variant, + save_traces=args.save_traces, + trace_output_dir=args.trace_output_dir, + ) + + # Print results + logger.info("Experiment completed successfully") + print(f"\nResults for {args.variant}:") + print(f"BEFORE: {stats_before}") + print(f"AFTER: {stats_after}") + + # Calculate improvement + if stats_before.runtime > 0: + improvement = ( + (stats_before.runtime - stats_after.runtime) / stats_before.runtime + ) * 100 + print(f"Runtime improvement: {improvement:.2f}%") + + runner.cleanup() + + except Exception as e: + logger.error(f"Experiment failed: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index a78962c7..70a1d2bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,3 +43,13 @@ exclude = [ [tool.hatch.metadata] allow-direct-references = true + +[tool.black] +exclude = "autoparallel/tools/overlap_simulator/repro_.*\\.py" + +[tool.isort] +profile = "black" +skip_glob = ["autoparallel/tools/overlap_simulator/repro_*.py"] + +[tool.mypy] +exclude = ["autoparallel/tools/overlap_simulator/repro_.*\\.py"]